Skip to content

Commit d33ba07

Browse files
CalCravenchrisjonesBSUpre-commit-ci[bot]
authored
fix label handling during flatten (#1208)
* fix label handling during flatten * Change the reset_labels method for compound.py to label the container lists with the format 'all-{name}s' for clarity * Add Ruff to pre-commit hooks (#1207) * update CI and precommit files * add ruff changes * remove gmso lines * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change error type in test * raise the error that is created * remove duplicate windows 3.12 test * fix precommit errors * fix import error * fix CI error --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix labeling of windows compounds * fix references in monomers tests --------- Co-authored-by: Chris Jones <50423140+chrisjonesBSU@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0800b14 commit d33ba07

File tree

3 files changed

+79
-56
lines changed

3 files changed

+79
-56
lines changed

mbuild/compound.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -699,12 +699,13 @@ def add(
699699

700700
if label.endswith("[$]"):
701701
label = label[:-3]
702-
if label not in self.labels:
703-
self.labels[label] = []
702+
all_label = "all-" + label + "s"
703+
if all_label not in self.labels:
704+
self.labels[all_label] = []
704705
label_pattern = label + "[{}]"
705706

706-
count = len(self.labels[label])
707-
self.labels[label].append(new_child)
707+
count = len(self.labels[all_label])
708+
self.labels[all_label].append(new_child)
708709
label = label_pattern.format(count)
709710

710711
if not replace and label in self.labels:
@@ -825,7 +826,21 @@ def _check_if_empty(child):
825826
self.reset_labels()
826827

827828
def reset_labels(self):
828-
"""Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N], where N-1 is the number of ports."""
829+
"""Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N], where N-1 is the number of ports.
830+
831+
Notes
832+
-----
833+
Will renumber the labels in a given Compound. Duplicated labels are named in the format "{name}[$]", where the $ stands in for the 0-indexed
834+
number in the Compound hierarchy with given "name".
835+
836+
i.e. self.labels.keys() = ["CH2", "CH2", "CH2"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"]
837+
and
838+
i.e. self.labels.keys() = ["CH2[1]", "CH2[3]", "CH2[5]"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"]
839+
840+
Additonally, if it doesn't exist, duplicated labels that are numbered as above with the "[$]" will also be put into a list index.
841+
self.labels.keys() = ["CH2", "CH2", "CH2"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"] as shown above, but also
842+
have a label of self.labels["all-CH2s"], which is a list of all CH2 children in the Compound.
843+
"""
829844
new_labels = OrderedDict()
830845
hoisted_children = {
831846
key: val
@@ -856,16 +871,16 @@ def reset_labels(self):
856871
if "port" in label:
857872
label = "port[$]"
858873
else:
859-
label = "{0}[$]".format(child.name)
860-
874+
label = f"{child.name}[$]"
861875
if label.endswith("[$]"):
862876
label = label[:-3]
863-
if label not in new_labels:
864-
new_labels[label] = []
877+
all_label = "all-" + label + "s"
878+
if all_label not in new_labels:
879+
new_labels[all_label] = []
865880
label_pattern = label + "[{}]"
866881

867-
count = len(new_labels[label])
868-
new_labels[label].append(child)
882+
count = len(new_labels[all_label])
883+
new_labels[all_label].append(child)
869884
label = label_pattern.format(count)
870885
new_labels[label] = child
871886
self.labels = new_labels
@@ -1880,6 +1895,9 @@ def flatten(self, inplace=True):
18801895
for neighbor in nx.neighbors(bond_graph, particle):
18811896
new_bonds.append((particle, neighbor))
18821897

1898+
# Remove all labels which refer to children in the hierarchy
1899+
self.labels.clear()
1900+
18831901
# Remove all the children
18841902
if inplace:
18851903
for child in children_list:
@@ -1896,6 +1914,7 @@ def flatten(self, inplace=True):
18961914
comp = clone(self)
18971915
comp.flatten(inplace=True)
18981916
return comp
1917+
self.reset_labels()
18991918

19001919
def update_coordinates(self, filename, update_port_locations=True):
19011920
"""Update the coordinates of this Compound from a file.

mbuild/tests/test_compound.py

Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ def test_add_by_list(self, h2o):
604604
temp_comp.add(comp_list, label=label_list)
605605
a = [k for k, v in temp_comp.labels.items()]
606606
assert a == [
607-
"water",
607+
"all-waters",
608608
"water[0]",
609609
"water[1]",
610610
"water[2]",
@@ -783,42 +783,14 @@ def test_remove(self, ethane):
783783

784784
# Test to reset labels after hydrogens
785785
ethane6 = mb.clone(ethane)
786-
ethane6.flatten()
787786
hydrogens = ethane6.particles_by_name("H")
788-
ethane6.remove(hydrogens)
787+
ethane6.remove(hydrogens, reset_labels=True)
789788
assert list(ethane6.labels.keys()) == [
790789
"methyl1",
791790
"methyl2",
792-
"C",
793-
"C[0]",
794-
"H",
795-
"C[1]",
796-
"port",
797-
"port[1]",
798-
"port[3]",
799-
"port[5]",
800-
"port[7]",
801-
"port[9]",
802-
"port[11]",
803-
]
804-
805-
ethane7 = mb.clone(ethane)
806-
ethane7.flatten()
807-
hydrogens = ethane7.particles_by_name("H")
808-
ethane7.remove(hydrogens, reset_labels=True)
809-
810-
assert list(ethane7.labels.keys()) == [
811-
"C",
812-
"C[0]",
813-
"C[1]",
814-
"port",
815-
"port[0]",
816-
"port[1]",
817-
"port[2]",
818-
"port[3]",
819-
"port[4]",
820-
"port[5]",
821791
]
792+
assert ethane6.available_ports() == []
793+
assert len(ethane6.all_ports()) == 6
822794

823795
def test_remove_many(self, ethane):
824796
ethane.remove([ethane.children[0], ethane.children[1]])
@@ -1041,6 +1013,31 @@ def test_flatten_box_of_eth(self, ethane):
10411013
box_of_eth.flatten()
10421014
assert len(box_of_eth.children) == box_of_eth.n_particles == 8 * 2
10431015
assert box_of_eth.n_bonds == 7 * 2
1016+
assert list(box_of_eth.labels.keys()) == [
1017+
"all-Cs",
1018+
"C[0]",
1019+
"all-Hs",
1020+
"H[0]",
1021+
"H[1]",
1022+
"H[2]",
1023+
"C[1]",
1024+
"H[3]",
1025+
"H[4]",
1026+
"H[5]",
1027+
"C[2]",
1028+
"H[6]",
1029+
"H[7]",
1030+
"H[8]",
1031+
"C[3]",
1032+
"H[9]",
1033+
"H[10]",
1034+
"H[11]",
1035+
]
1036+
1037+
def test_flatten_then_fill_box(self, benzene):
1038+
benzene.flatten(inplace=True)
1039+
benzene_box = mb.packing.fill_box(compound=benzene, n_compounds=2, density=0.3)
1040+
assert next(iter(benzene_box.particles())).root.bond_graph
10441041

10451042
def test_flatten_with_port(self, ethane):
10461043
ethane.remove(ethane[2])
@@ -1726,7 +1723,7 @@ def test_energy_minimize_shift_com(self, octane):
17261723
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
17271724
)
17281725
def test_energy_minimize_shift_anchor(self, octane):
1729-
anchor_compound = octane.labels["chain"].labels["CH3"][0]
1726+
anchor_compound = octane.labels["chain"].labels["CH3[0]"]
17301727
pos_old = anchor_compound.pos
17311728
octane.energy_minimize(anchor=anchor_compound)
17321729
# check to see if COM of the anchor Compound
@@ -1738,9 +1735,9 @@ def test_energy_minimize_shift_anchor(self, octane):
17381735
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
17391736
)
17401737
def test_energy_minimize_fix_compounds(self, octane):
1741-
methyl_end0 = octane.labels["chain"].labels["CH3"][0]
1742-
methyl_end1 = octane.labels["chain"].labels["CH3"][1]
1743-
carbon_end = octane.labels["chain"].labels["CH3"][0].labels["C"][0]
1738+
methyl_end0 = octane.labels["chain"].labels["CH3[0]"]
1739+
methyl_end1 = octane.labels["chain"].labels["CH3[0]"]
1740+
carbon_end = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"]
17441741
not_in_compound = mb.Compound(name="H")
17451742

17461743
# fix the whole molecule and make sure positions are close
@@ -1827,9 +1824,9 @@ def test_energy_minimize_fix_compounds(self, octane):
18271824
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
18281825
)
18291826
def test_energy_minimize_ignore_compounds(self, octane):
1830-
methyl_end0 = octane.labels["chain"].labels["CH3"][0]
1831-
methyl_end1 = octane.labels["chain"].labels["CH3"][1]
1832-
carbon_end = octane.labels["chain"].labels["CH3"][0].labels["C"][0]
1827+
methyl_end0 = octane.labels["chain"].labels["CH3[0]"]
1828+
methyl_end1 = octane.labels["chain"].labels["CH3[1]"]
1829+
carbon_end = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"]
18331830
not_in_compound = mb.Compound(name="H")
18341831

18351832
# fix the whole molecule and make sure positions are close
@@ -1859,12 +1856,12 @@ def test_energy_minimize_ignore_compounds(self, octane):
18591856
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
18601857
)
18611858
def test_energy_minimize_distance_constraints(self, octane):
1862-
methyl_end0 = octane.labels["chain"].labels["CH3"][0]
1863-
methyl_end1 = octane.labels["chain"].labels["CH3"][1]
1859+
methyl_end0 = octane.labels["chain"].labels["CH3[0]"]
1860+
methyl_end1 = octane.labels["chain"].labels["CH3[1]"]
18641861

1865-
carbon_end0 = octane.labels["chain"].labels["CH3"][0].labels["C"][0]
1866-
carbon_end1 = octane.labels["chain"].labels["CH3"][1].labels["C"][0]
1867-
h_end0 = octane.labels["chain"].labels["CH3"][0].labels["H"][0]
1862+
carbon_end0 = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"]
1863+
carbon_end1 = octane.labels["chain"].labels["CH3[1]"].labels["C[0]"]
1864+
h_end0 = octane.labels["chain"].labels["CH3[0]"].labels["H[0]"]
18681865

18691866
not_in_compound = mb.Compound(name="H")
18701867

@@ -2539,3 +2536,10 @@ def test_catalog_bondgraph_types(self, benzene):
25392536
catalog_bondgraph_type(compound.children[1][0], compound.bond_graph)
25402537
== "particle_graph"
25412538
)
2539+
2540+
def test_reset_labels(self):
2541+
ethane = mb.load("CC", smiles=True)
2542+
Hs = ethane.particles_by_name("H")
2543+
ethane.remove(Hs, reset_labels=True)
2544+
ports = set(f"port[{i}]" for i in range(6))
2545+
assert ports.issubset(set(ethane.labels.keys()))

mbuild/tests/test_json_formats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_label_consistency(self):
9999
parent.add(CH3())
100100
compound_to_json(parent, "parent.json", include_ports=True)
101101
parent_copy = compound_from_json("parent.json")
102-
assert len(parent_copy["CH2"]) == len(parent["CH2"])
102+
assert len(parent_copy["all-CH2s"]) == len(parent["all-CH2s"])
103103
assert parent_copy.labels.keys() == parent.labels.keys()
104104
for child, child_copy in zip(parent.successors(), parent_copy.successors()):
105105
assert child.labels.keys() == child_copy.labels.keys()

0 commit comments

Comments
 (0)