Skip to content

Commit 06055a7

Browse files
author
Isaac Peterson
committed
updated i-15 scenario
1 parent 8e54862 commit 06055a7

File tree

4 files changed

+560525
-4188
lines changed

4 files changed

+560525
-4188
lines changed

matsimAI/run_gradient_flow_matching.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from torch.utils.tensorboard import SummaryWriter
77
import datetime
88
from tqdm import tqdm
9+
import xml.etree.ElementTree as ET
10+
from xml.dom import minidom
911

1012
def main(args):
1113

@@ -16,9 +18,33 @@ def main(args):
1618
network_name = Path(args.network_path).stem
1719
save_path = Path(output_path, f"{unique_time_string}_nclusters_{args.num_clusters}_{network_name}")
1820
tensorboard_path = Path(save_path, "logs")
21+
writer = SummaryWriter(tensorboard_path)
22+
23+
1924
os.makedirs(tensorboard_path)
25+
26+
tree = ET.parse(args.config_path)
27+
root = tree.getroot()
28+
for module in root.findall("module"):
29+
if module.attrib.get("name") == "counts":
30+
for param in module.findall("param"):
31+
if param.attrib.get("name") == "countsScaleFactor":
32+
param.attrib["value"] = str(args.percent_pop)
33+
else:
34+
ET.SubElement(module, "param", {"name": "countsScaleFactor", "value": str(args.percent_pop)})
2035

21-
writer = SummaryWriter(tensorboard_path)
36+
temp_path = "temp_output.xml"
37+
tree.write(temp_path, encoding="utf-8", xml_declaration=False)
38+
39+
with open(temp_path, "r", encoding="utf-8") as f:
40+
xml_content = f.read()
41+
42+
with open(args.config_path, "w", encoding="utf-8") as f:
43+
f.write('<?xml version="1.0"?>\n')
44+
f.write('<!DOCTYPE config SYSTEM "http://www.matsim.org/files/dtd/config_v2.dtd">\n')
45+
f.write(xml_content)
46+
47+
os.remove(temp_path)
2248

2349
if not os.path.exists(output_path):
2450
os.makedirs(output_path)
@@ -45,7 +71,8 @@ def main(args):
4571
parameters = [W]
4672

4773
TAM = TAM.reshape(-1, Z_2)
48-
TARGET = dataset.target_graph.edge_attr.to(device).to(torch.float32)
74+
percent_pop_float = float(args.percent_pop) / 100
75+
TARGET = dataset.target_graph.edge_attr.to(device).to(torch.float32) * percent_pop_float
4976

5077
optimizer = torch.optim.Adam(parameters, lr=0.001)
5178
pbar = tqdm(range(args.training_steps))
@@ -70,7 +97,7 @@ def main(args):
7097
if step % args.log_interval == 0:
7198
pbar.set_description(f"Loss: {loss.item()}")
7299
writer.add_scalar("Loss/mse", loss.item(), step)
73-
writer.add_scalar("Logs/mad", torch.abs(R[sensor_idxs] - TARGET[sensor_idxs]).sum() / target_size, step)
100+
writer.add_scalar("Logs/mad", torch.abs(R[sensor_idxs] - (TARGET[sensor_idxs])).sum() / target_size, step)
74101

75102
if step != 0 and \
76103
args.save_interval > 0 and \
@@ -95,7 +122,7 @@ def main(args):
95122
if args.best_plans_save_path is not None:
96123
dataset.save_plans_from_flow_res(
97124
best_model.reshape(args.num_clusters, args.num_clusters, 24),
98-
Path(args.best_plans_save_path, "best_plans.xml")
125+
Path(args.best_plans_save_path)
99126
)
100127

101128
if __name__ == "__main__":
@@ -104,6 +131,8 @@ def main(args):
104131
parser.add_argument("results_path", help="path to the output folder for the results of the algorithm")
105132
parser.add_argument("network_path", help="path to matsim xml network")
106133
parser.add_argument("counts_path", help="path to matsim xml counts")
134+
parser.add_argument("config_path", help="path to matsim xml config path")
135+
parser.add_argument("--percent_pop", type=str, default=100, help="The percentage of the population to match to between 0-100.")
107136
parser.add_argument("--num_clusters", type=int, required=True, help="number of clusters for the network")
108137
parser.add_argument("--training_steps", type=int, required=True, help="number of training iterations")
109138
parser.add_argument("--log_interval", type=int, required=True, help="tensorboard logging interval")

scenarios/i-15-scenario/i-15-config.xml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<?xml version="1.0" ?>
1+
<?xml version="1.0"?>
22
<!DOCTYPE config SYSTEM "http://www.matsim.org/files/dtd/config_v2.dtd">
33
<config>
44
<module name="global">
@@ -18,10 +18,10 @@
1818
<param name="inputCountsFile" value="i-15-counts.xml" />
1919
<param name="writeCountsInterval" value="1" />
2020
<param name="averageCountsOverIterations" value="1" />
21-
</module>
21+
<param name="countsScaleFactor" value="5" /></module>
2222

2323
<module name="controller">
24-
<param name="outputDirectory" value="./output" />
24+
<param name="outputDirectory" value="./i-15-output" />
2525
<param name="firstIteration" value="0" />
2626
<param name="lastIteration" value="100" />
2727
<param name="eventsFileFormat" value="xml" />
@@ -49,7 +49,7 @@
4949
<param name="marginalUtilityOfMoney" value="0" />
5050

5151
<parameterset type="modeParams">
52-
<param name="mode" value="car"/>
52+
<param name="mode" value="car" />
5353
<param name="marginalUtilityOfTraveling_util_hr" value="-6.0" />
5454
</parameterset>
5555

@@ -153,4 +153,4 @@
153153
</module>
154154

155155

156-
</config>
156+
</config>

0 commit comments

Comments
 (0)