Skip to content

Commit 3608bf7

Browse files
committed
fixes to cpp2 generator to generate at least correct code for soil temp models (as cpp)
1 parent d4be413 commit 3608bf7

File tree

2 files changed

+65
-54
lines changed

2 files changed

+65
-54
lines changed

src/pycropml/transpiler/generators/cpp2Generator.py

Lines changed: 58 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def visit_standard_method_call(self, node):
196196
self.visit(node.receiver.args[1])
197197
self.write(", 0.0)")
198198
else:
199+
if self.funcname.startswith("model_") or self.funcname.startswith("init_"):
200+
node.receiver.cpp_struct_name = self.struct_name_for(node.receiver.name)
199201
self.visit(z(node))
200202
else:
201203
if not node.args:
@@ -213,9 +215,11 @@ def visit_method_call(self, node):
213215
"%s.%s" % (self.visit(node.receiver), self.write(node.message))
214216

215217
def visit_local(self, node):
216-
sn = self.struct_name_for(node.name)
217-
if sn:
218-
self.write(f"{sn}.")
218+
#if node.name not in list(map(lambda p: p.name, self.params)):
219+
if self.funcname.startswith("model_") or self.funcname.startswith("init_"):
220+
sn = self.struct_name_for(node.name)
221+
if sn:
222+
self.write(f"{sn}.")
219223
self.write(node.name)
220224

221225
def visit_index(self, node):
@@ -539,27 +543,27 @@ def visit_function_definition(self, node):
539543
self.newline(node)
540544
self.write(self.doc.outputs_doc)
541545
self.newline(node)
542-
self.indentation += 1
543-
for arg in z: # self.add_features(node) :
544-
if "feat" in dir(arg):
545-
if arg.feat in ("IN", "INOUT"):
546-
self.newline(node)
547-
if self.model and arg.name not in self.modparam:
548-
#self.visit_decl(arg.pseudo_type)
549-
#if arg.pseudo_type[0] in ["list", "array"]:
550-
# self.write("&")
551-
#self.write(f" {arg.name}")
552-
if node.name.startswith("init_"):
553-
if arg.name in self.exogenous:
554-
self.write(f" = {self.cpp_struct_names['ex']}.get{arg.name}()")
555-
elif arg.pseudo_type[0] == "list":
556-
self.write(f" = std::vector<{self.types[arg.pseudo_type[1]]}>()")
557-
elif arg.pseudo_type[0] == "array":
558-
x = arg.elts[0].value if "value" in dir(arg.elts[0]) else arg.elts[0].name
559-
if not x:
560-
self.write(f" = std::vector<{self.types[arg.pseudo_type[1]]}>()")
561-
else:
562-
self.write(f" = std::vector<{self.types[arg.pseudo_type[1]]}>({x})")
546+
#self.indentation += 1
547+
#for arg in z: # self.add_features(node) :
548+
# if "feat" in dir(arg):
549+
# if arg.feat in ("IN", "INOUT"):
550+
# self.newline(node)
551+
# if self.model and arg.name not in self.modparam:
552+
# self.visit_decl(arg.pseudo_type)
553+
# if arg.pseudo_type[0] in ["list", "array"]:
554+
# self.write("&")
555+
# self.write(f" {arg.name}")
556+
# if node.name.startswith("init_"):
557+
# if arg.name in self.exogenous:
558+
# self.write(f" = {self.cpp_struct_names['ex']}.get{arg.name}()")
559+
# elif arg.pseudo_type[0] == "list":
560+
# self.write(f" = std::vector<{self.types[arg.pseudo_type[1]]}>()")
561+
# elif arg.pseudo_type[0] == "array":
562+
# x = arg.elts[0].value if "value" in dir(arg.elts[0]) else arg.elts[0].name
563+
# if not x:
564+
# self.write(f" = std::vector<{self.types[arg.pseudo_type[1]]}>()")
565+
# else:
566+
# self.write(f" = std::vector<{self.types[arg.pseudo_type[1]]}>({x})")
563567
#else:
564568
# # make left hand side a reference to the result in case of lists and arrays
565569
# if arg.name in self.states and not arg.name.endswith("_t1"):
@@ -573,7 +577,7 @@ def visit_function_definition(self, node):
573577
# elif arg.name in self.exogenous:
574578
# self.write(f" = {self.cpp_struct_names['ex']}.get{arg.name}()")
575579
#self.write(";")
576-
self.indentation -= 1
580+
#self.indentation -= 1
577581
self.body(node.block)
578582
self.newline(node)
579583
self.visit_return(node)
@@ -672,15 +676,19 @@ def visit_str(self, node):
672676

673677
def visit_declaration(self, node):
674678
self.newline(node)
679+
is_init_or_model_func = self.funcname.startswith("model_") or self.funcname.startswith("init_")
675680
for n in node.decl:
676681
dn = dir(n)
677682
self.newline(node)
678683
if 'value' not in dn and not isinstance(n.pseudo_type, list) and n.pseudo_type != "datetime":
679-
self.write(self.types[n.pseudo_type])
680-
self.write(f' {n.name};')
684+
if "feat" not in dn or ("feat" in dn and n.feat not in ("OUT", "INOUT")):
685+
self.write(self.types[n.pseudo_type])
686+
self.write(f' {n.name};')
681687
elif 'elements' not in dn and n.type in ("list", "array"):
682-
if "feat" in dn and n.feat == "OUT":
683-
pass
688+
if "feat" in dn and n.feat in ("OUT", "INOUT") and is_init_or_model_func:
689+
if n.type == "array" and "elts" in dn and n.elts:
690+
self.write(f"{self.struct_name_for(n.name)}.{n.name} = std::vector<{self.types[n.pseudo_type[1]]}>")
691+
self.write(f"({n.elts[0].name if 'name' in dir(n.elts[0]) else n.elts[0].value});")
684692
else:
685693
if n.type == "list":
686694
self.write(f"std::vector<{self.types[n.pseudo_type[1]]}> {n.name};")
@@ -691,29 +699,28 @@ def visit_declaration(self, node):
691699
self.write(f"std::vector<{self.types[n.pseudo_type[1]]}> {n.name}")
692700
self.write(f"({n.elts[0].name if 'name' in dir(n.elts[0]) else n.elts[0].value});")
693701
elif 'value' in dn and n.type in ("int", "float", "str", "bool"):
694-
if "feat" in dn and n.feat == "OUT":
695-
pass
702+
if "feat" in dn and n.feat in ("OUT", "INOUT") and is_init_or_model_func:
703+
self.write(f"{self.struct_name_for(n.name)}.{n.name} = ")
696704
else:
697-
self.write(f"{self.types[n.type]} {n.name}")
698-
self.write(" = ")
699-
if n.type == "local":
700-
self.write(n.value)
701-
else:
702-
self.visit(n)
703-
self.write(";")
705+
self.write(f"{self.types[n.type]} {n.name} = ")
706+
if n.type == "local":
707+
self.write(n.value)
708+
else:
709+
self.visit(n)
710+
self.write(";")
704711
elif n.type == 'datetime':
705-
self.newline(node)
706-
self.write("DateTime ")
707-
self.write(n.name)
712+
if "feat" in dn and n.feat in ("OUT", "INOUT") and is_init_or_model_func:
713+
self.write(f"{self.struct_name_for(n.name)}.{n.name}")
714+
else:
715+
self.write(f"DateTime {n.name}")
708716
if "elts" in dir(n):
709717
self.write(" = ")
710718
self.visit(n.elts)
711719
self.write(";")
712720
elif 'elements' in dn and n.type in ("list", "tuple"):
713-
if "feat" in dn and n.feat in ("OUT", "INOUT"):
721+
if "feat" in dn and n.feat in ("OUT", "INOUT") and is_init_or_model_func:
714722
if n.type == "list":
715-
self.write(f"{self.struct_name_for(n.name)}.{n.name}")
716-
self.write(" = ")
723+
self.write(f"{self.struct_name_for(n.name)}.{n.name} = ")
717724
self.write(u'{')
718725
self.comma_separated_list(n.elements)
719726
self.write(u'};')
@@ -769,12 +776,12 @@ def visit_tuple_decl(self, node):
769776
def visit_float_decl(self, node, pa=None):
770777
self.write(self.types[node])
771778
if pa:
772-
dpa = dir(pa)
779+
dpa = pa.__dict__
773780
if "name" in dpa:
774781
self.write(f" {pa.name}")
775782
if "default_val" in dpa and pa.default_val:
776783
self.write(f"{{{pa.default_val}}}")
777-
else:
784+
elif dpa.get("feat", None) != "IN":
778785
self.write(f"{{{0.0}}}")
779786

780787
def visit_datetime_decl(self, node):
@@ -783,12 +790,12 @@ def visit_datetime_decl(self, node):
783790
def visit_int_decl(self, node, pa=None):
784791
self.write(self.types[node])
785792
if pa:
786-
dpa = dir(pa)
793+
dpa = pa.__dict__
787794
if "name" in dpa:
788795
self.write(f" {pa.name}")
789796
if "default_val" in dpa and pa.default_val:
790797
self.write(f"{{{pa.default_val}}}")
791-
else:
798+
elif dpa.get("feat", None) != "IN":
792799
self.write(f"{{{0}}}")
793800

794801
def visit_str_decl(self, node, pa=None):
@@ -798,17 +805,17 @@ def visit_str_decl(self, node, pa=None):
798805
if "name" in dpa:
799806
self.write(f" {pa.name}")
800807
if "default_val" in dpa and pa.default_val:
801-
self.write(f"{{{pa.default_val}}}")
808+
self.write(f"{{\"{pa.default_val}\"}}")
802809

803810
def visit_bool_decl(self, node, pa=None):
804811
self.write(self.types[node])
805812
if pa:
806-
dpa = dir(pa)
813+
dpa = pa.__dict__
807814
if "name" in dpa:
808815
self.write(f" {pa.name}")
809816
if "default_val" in dpa and pa.default_val:
810817
self.write(f"{{{pa.default_val}}}")
811-
else:
818+
elif dpa.get("feat", None) != "IN":
812819
self.write(f"{{{False}}}")
813820

814821
def visit_array_decl(self, node, pa=None):

src/pycropml/transpiler/rules/cpp2Rules.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@ def translateLog(node):
1616

1717
def translateSum(node):
1818
if "name" in dir(node.receiver):
19-
print(node.receiver.y)
19+
#print(node.receiver.y)
20+
if "cpp_struct_name" in dir(node.receiver) and node.receiver.cpp_struct_name is not None:
21+
name = f"{node.receiver.cpp_struct_name}.{node.receiver.name}"
22+
else:
23+
name = node.receiver.name
2024
return Node("call", function="accumulate",
21-
args=[Node("local", name=f"{node.receiver.name}.begin()"),
22-
Node("local", name=f"{node.receiver.name}.end()"),
25+
args=[Node("local", name=f"{name}.begin()"),
26+
Node("local", name=f"{name}.end()"),
2327
Node("int", value="0" if node.receiver.pseudo_type[1] == "int" else "0.0")],
2428
pseudo_type=node.pseudo_type)
2529

0 commit comments

Comments
 (0)