@@ -196,6 +196,8 @@ def visit_standard_method_call(self, node):
196
196
self .visit (node .receiver .args [1 ])
197
197
self .write (", 0.0)" )
198
198
else :
199
+ if self .funcname .startswith ("model_" ) or self .funcname .startswith ("init_" ):
200
+ node .receiver .cpp_struct_name = self .struct_name_for (node .receiver .name )
199
201
self .visit (z (node ))
200
202
else :
201
203
if not node .args :
@@ -213,9 +215,11 @@ def visit_method_call(self, node):
213
215
"%s.%s" % (self .visit (node .receiver ), self .write (node .message ))
214
216
215
217
def visit_local (self , node ):
216
- sn = self .struct_name_for (node .name )
217
- if sn :
218
- self .write (f"{ sn } ." )
218
+ #if node.name not in list(map(lambda p: p.name, self.params)):
219
+ if self .funcname .startswith ("model_" ) or self .funcname .startswith ("init_" ):
220
+ sn = self .struct_name_for (node .name )
221
+ if sn :
222
+ self .write (f"{ sn } ." )
219
223
self .write (node .name )
220
224
221
225
def visit_index (self , node ):
@@ -539,27 +543,27 @@ def visit_function_definition(self, node):
539
543
self .newline (node )
540
544
self .write (self .doc .outputs_doc )
541
545
self .newline (node )
542
- self .indentation += 1
543
- for arg in z : # self.add_features(node) :
544
- if "feat" in dir (arg ):
545
- if arg .feat in ("IN" , "INOUT" ):
546
- self .newline (node )
547
- if self .model and arg .name not in self .modparam :
548
- # self.visit_decl(arg.pseudo_type)
549
- # if arg.pseudo_type[0] in ["list", "array"]:
550
- # self.write("&")
551
- # self.write(f" {arg.name}")
552
- if node .name .startswith ("init_" ):
553
- if arg .name in self .exogenous :
554
- self .write (f" = { self .cpp_struct_names ['ex' ]} .get{ arg .name } ()" )
555
- elif arg .pseudo_type [0 ] == "list" :
556
- self .write (f" = std::vector<{ self .types [arg .pseudo_type [1 ]]} >()" )
557
- elif arg .pseudo_type [0 ] == "array" :
558
- x = arg .elts [0 ].value if "value" in dir (arg .elts [0 ]) else arg .elts [0 ].name
559
- if not x :
560
- self .write (f" = std::vector<{ self .types [arg .pseudo_type [1 ]]} >()" )
561
- else :
562
- self .write (f" = std::vector<{ self .types [arg .pseudo_type [1 ]]} >({ x } )" )
546
+ # self.indentation += 1
547
+ # for arg in z: # self.add_features(node) :
548
+ # if "feat" in dir(arg):
549
+ # if arg.feat in ("IN", "INOUT"):
550
+ # self.newline(node)
551
+ # if self.model and arg.name not in self.modparam:
552
+ # self.visit_decl(arg.pseudo_type)
553
+ # if arg.pseudo_type[0] in ["list", "array"]:
554
+ # self.write("&")
555
+ # self.write(f" {arg.name}")
556
+ # if node.name.startswith("init_"):
557
+ # if arg.name in self.exogenous:
558
+ # self.write(f" = {self.cpp_struct_names['ex']}.get{arg.name}()")
559
+ # elif arg.pseudo_type[0] == "list":
560
+ # self.write(f" = std::vector<{self.types[arg.pseudo_type[1]]}>()")
561
+ # elif arg.pseudo_type[0] == "array":
562
+ # x = arg.elts[0].value if "value" in dir(arg.elts[0]) else arg.elts[0].name
563
+ # if not x:
564
+ # self.write(f" = std::vector<{self.types[arg.pseudo_type[1]]}>()")
565
+ # else:
566
+ # self.write(f" = std::vector<{self.types[arg.pseudo_type[1]]}>({x})")
563
567
#else:
564
568
# # make left hand side a reference to the result in case of lists and arrays
565
569
# if arg.name in self.states and not arg.name.endswith("_t1"):
@@ -573,7 +577,7 @@ def visit_function_definition(self, node):
573
577
# elif arg.name in self.exogenous:
574
578
# self.write(f" = {self.cpp_struct_names['ex']}.get{arg.name}()")
575
579
#self.write(";")
576
- self .indentation -= 1
580
+ # self.indentation -= 1
577
581
self .body (node .block )
578
582
self .newline (node )
579
583
self .visit_return (node )
@@ -672,15 +676,19 @@ def visit_str(self, node):
672
676
673
677
def visit_declaration (self , node ):
674
678
self .newline (node )
679
+ is_init_or_model_func = self .funcname .startswith ("model_" ) or self .funcname .startswith ("init_" )
675
680
for n in node .decl :
676
681
dn = dir (n )
677
682
self .newline (node )
678
683
if 'value' not in dn and not isinstance (n .pseudo_type , list ) and n .pseudo_type != "datetime" :
679
- self .write (self .types [n .pseudo_type ])
680
- self .write (f' { n .name } ;' )
684
+ if "feat" not in dn or ("feat" in dn and n .feat not in ("OUT" , "INOUT" )):
685
+ self .write (self .types [n .pseudo_type ])
686
+ self .write (f' { n .name } ;' )
681
687
elif 'elements' not in dn and n .type in ("list" , "array" ):
682
- if "feat" in dn and n .feat == "OUT" :
683
- pass
688
+ if "feat" in dn and n .feat in ("OUT" , "INOUT" ) and is_init_or_model_func :
689
+ if n .type == "array" and "elts" in dn and n .elts :
690
+ self .write (f"{ self .struct_name_for (n .name )} .{ n .name } = std::vector<{ self .types [n .pseudo_type [1 ]]} >" )
691
+ self .write (f"({ n .elts [0 ].name if 'name' in dir (n .elts [0 ]) else n .elts [0 ].value } );" )
684
692
else :
685
693
if n .type == "list" :
686
694
self .write (f"std::vector<{ self .types [n .pseudo_type [1 ]]} > { n .name } ;" )
@@ -691,29 +699,28 @@ def visit_declaration(self, node):
691
699
self .write (f"std::vector<{ self .types [n .pseudo_type [1 ]]} > { n .name } " )
692
700
self .write (f"({ n .elts [0 ].name if 'name' in dir (n .elts [0 ]) else n .elts [0 ].value } );" )
693
701
elif 'value' in dn and n .type in ("int" , "float" , "str" , "bool" ):
694
- if "feat" in dn and n .feat == "OUT" :
695
- pass
702
+ if "feat" in dn and n .feat in ( "OUT" , "INOUT" ) and is_init_or_model_func :
703
+ self . write ( f" { self . struct_name_for ( n . name ) } . { n . name } = " )
696
704
else :
697
- self .write (f"{ self .types [n .type ]} { n .name } " )
698
- self .write (" = " )
699
- if n .type == "local" :
700
- self .write (n .value )
701
- else :
702
- self .visit (n )
703
- self .write (";" )
705
+ self .write (f"{ self .types [n .type ]} { n .name } = " )
706
+ if n .type == "local" :
707
+ self .write (n .value )
708
+ else :
709
+ self .visit (n )
710
+ self .write (";" )
704
711
elif n .type == 'datetime' :
705
- self .newline (node )
706
- self .write ("DateTime " )
707
- self .write (n .name )
712
+ if "feat" in dn and n .feat in ("OUT" , "INOUT" ) and is_init_or_model_func :
713
+ self .write (f"{ self .struct_name_for (n .name )} .{ n .name } " )
714
+ else :
715
+ self .write (f"DateTime { n .name } " )
708
716
if "elts" in dir (n ):
709
717
self .write (" = " )
710
718
self .visit (n .elts )
711
719
self .write (";" )
712
720
elif 'elements' in dn and n .type in ("list" , "tuple" ):
713
- if "feat" in dn and n .feat in ("OUT" , "INOUT" ):
721
+ if "feat" in dn and n .feat in ("OUT" , "INOUT" ) and is_init_or_model_func :
714
722
if n .type == "list" :
715
- self .write (f"{ self .struct_name_for (n .name )} .{ n .name } " )
716
- self .write (" = " )
723
+ self .write (f"{ self .struct_name_for (n .name )} .{ n .name } = " )
717
724
self .write (u'{' )
718
725
self .comma_separated_list (n .elements )
719
726
self .write (u'};' )
@@ -769,12 +776,12 @@ def visit_tuple_decl(self, node):
769
776
def visit_float_decl (self , node , pa = None ):
770
777
self .write (self .types [node ])
771
778
if pa :
772
- dpa = dir ( pa )
779
+ dpa = pa . __dict__
773
780
if "name" in dpa :
774
781
self .write (f" { pa .name } " )
775
782
if "default_val" in dpa and pa .default_val :
776
783
self .write (f"{{{ pa .default_val } }}" )
777
- else :
784
+ elif dpa . get ( "feat" , None ) != "IN" :
778
785
self .write (f"{{{ 0.0 } }}" )
779
786
780
787
def visit_datetime_decl (self , node ):
@@ -783,12 +790,12 @@ def visit_datetime_decl(self, node):
783
790
def visit_int_decl (self , node , pa = None ):
784
791
self .write (self .types [node ])
785
792
if pa :
786
- dpa = dir ( pa )
793
+ dpa = pa . __dict__
787
794
if "name" in dpa :
788
795
self .write (f" { pa .name } " )
789
796
if "default_val" in dpa and pa .default_val :
790
797
self .write (f"{{{ pa .default_val } }}" )
791
- else :
798
+ elif dpa . get ( "feat" , None ) != "IN" :
792
799
self .write (f"{{{ 0 } }}" )
793
800
794
801
def visit_str_decl (self , node , pa = None ):
@@ -798,17 +805,17 @@ def visit_str_decl(self, node, pa=None):
798
805
if "name" in dpa :
799
806
self .write (f" { pa .name } " )
800
807
if "default_val" in dpa and pa .default_val :
801
- self .write (f"{{{ pa .default_val } }}" )
808
+ self .write (f"{{\" { pa .default_val } \" }}" )
802
809
803
810
def visit_bool_decl (self , node , pa = None ):
804
811
self .write (self .types [node ])
805
812
if pa :
806
- dpa = dir ( pa )
813
+ dpa = pa . __dict__
807
814
if "name" in dpa :
808
815
self .write (f" { pa .name } " )
809
816
if "default_val" in dpa and pa .default_val :
810
817
self .write (f"{{{ pa .default_val } }}" )
811
- else :
818
+ elif dpa . get ( "feat" , None ) != "IN" :
812
819
self .write (f"{{{ False } }}" )
813
820
814
821
def visit_array_decl (self , node , pa = None ):
0 commit comments