6
6
from torch .utils .tensorboard import SummaryWriter
7
7
import datetime
8
8
from tqdm import tqdm
9
+ import xml .etree .ElementTree as ET
10
+ from xml .dom import minidom
9
11
10
12
def main (args ):
11
13
@@ -16,9 +18,33 @@ def main(args):
16
18
network_name = Path (args .network_path ).stem
17
19
save_path = Path (output_path , f"{ unique_time_string } _nclusters_{ args .num_clusters } _{ network_name } " )
18
20
tensorboard_path = Path (save_path , "logs" )
21
+ writer = SummaryWriter (tensorboard_path )
22
+
23
+
19
24
os .makedirs (tensorboard_path )
25
+
26
+ tree = ET .parse (args .config_path )
27
+ root = tree .getroot ()
28
+ for module in root .findall ("module" ):
29
+ if module .attrib .get ("name" ) == "counts" :
30
+ for param in module .findall ("param" ):
31
+ if param .attrib .get ("name" ) == "countsScaleFactor" :
32
+ param .attrib ["value" ] = str (args .percent_pop )
33
+ else :
34
+ ET .SubElement (module , "param" , {"name" : "countsScaleFactor" , "value" : str (args .percent_pop )})
20
35
21
- writer = SummaryWriter (tensorboard_path )
36
+ temp_path = "temp_output.xml"
37
+ tree .write (temp_path , encoding = "utf-8" , xml_declaration = False )
38
+
39
+ with open (temp_path , "r" , encoding = "utf-8" ) as f :
40
+ xml_content = f .read ()
41
+
42
+ with open (args .config_path , "w" , encoding = "utf-8" ) as f :
43
+ f .write ('<?xml version="1.0"?>\n ' )
44
+ f .write ('<!DOCTYPE config SYSTEM "http://www.matsim.org/files/dtd/config_v2.dtd">\n ' )
45
+ f .write (xml_content )
46
+
47
+ os .remove (temp_path )
22
48
23
49
if not os .path .exists (output_path ):
24
50
os .makedirs (output_path )
@@ -45,7 +71,8 @@ def main(args):
45
71
parameters = [W ]
46
72
47
73
TAM = TAM .reshape (- 1 , Z_2 )
48
- TARGET = dataset .target_graph .edge_attr .to (device ).to (torch .float32 )
74
+ percent_pop_float = float (args .percent_pop ) / 100
75
+ TARGET = dataset .target_graph .edge_attr .to (device ).to (torch .float32 ) * percent_pop_float
49
76
50
77
optimizer = torch .optim .Adam (parameters , lr = 0.001 )
51
78
pbar = tqdm (range (args .training_steps ))
@@ -70,7 +97,7 @@ def main(args):
70
97
if step % args .log_interval == 0 :
71
98
pbar .set_description (f"Loss: { loss .item ()} " )
72
99
writer .add_scalar ("Loss/mse" , loss .item (), step )
73
- writer .add_scalar ("Logs/mad" , torch .abs (R [sensor_idxs ] - TARGET [sensor_idxs ]).sum () / target_size , step )
100
+ writer .add_scalar ("Logs/mad" , torch .abs (R [sensor_idxs ] - ( TARGET [sensor_idxs ]) ).sum () / target_size , step )
74
101
75
102
if step != 0 and \
76
103
args .save_interval > 0 and \
@@ -95,7 +122,7 @@ def main(args):
95
122
if args .best_plans_save_path is not None :
96
123
dataset .save_plans_from_flow_res (
97
124
best_model .reshape (args .num_clusters , args .num_clusters , 24 ),
98
- Path (args .best_plans_save_path , "best_plans.xml" )
125
+ Path (args .best_plans_save_path )
99
126
)
100
127
101
128
if __name__ == "__main__" :
@@ -104,6 +131,8 @@ def main(args):
104
131
parser .add_argument ("results_path" , help = "path to the output folder for the results of the algorithm" )
105
132
parser .add_argument ("network_path" , help = "path to matsim xml network" )
106
133
parser .add_argument ("counts_path" , help = "path to matsim xml counts" )
134
+ parser .add_argument ("config_path" , help = "path to matsim xml config path" )
135
+ parser .add_argument ("--percent_pop" , type = str , default = 100 , help = "The percentage of the population to match to between 0-100." )
107
136
parser .add_argument ("--num_clusters" , type = int , required = True , help = "number of clusters for the network" )
108
137
parser .add_argument ("--training_steps" , type = int , required = True , help = "number of training iterations" )
109
138
parser .add_argument ("--log_interval" , type = int , required = True , help = "tensorboard logging interval" )
0 commit comments