|
6 | 6 | # Project Website: https://github.com/a-r-j/graphein
|
7 | 7 | # Code Repository: https://github.com/a-r-j/graphein
|
8 | 8 |
|
| 9 | +import collections |
9 | 10 | import os
|
10 | 11 | from typing import List, Optional, Union
|
11 | 12 |
|
@@ -218,13 +219,24 @@ def protein_to_pyg(
|
218 | 219 | if store_het:
|
219 | 220 | hetatms = df.loc[df.record_name == "HETATM"]
|
220 | 221 | all_hets = list(set(hetatms.residue_name))
|
221 |
| - het_coords = {} |
| 222 | + het_data = collections.defaultdict(dict) |
222 | 223 | for het in all_hets:
|
223 |
| - het_coords[het] = torch.tensor( |
| 224 | + het_data[het]["coords"] = torch.tensor( |
224 | 225 | hetatms.loc[hetatms.residue_name == het][
|
225 | 226 | ["x_coord", "y_coord", "z_coord"]
|
226 | 227 | ].values
|
227 | 228 | )
|
| 229 | + het_data[het]["atoms"] = hetatms.loc[hetatms.residue_name == het][ |
| 230 | + "atom_name" |
| 231 | + ].values |
| 232 | + het_data[het]["residue_number"] = torch.tensor( |
| 233 | + hetatms.loc[hetatms.residue_name == het][ |
| 234 | + "residue_number" |
| 235 | + ].values |
| 236 | + ) |
| 237 | + het_data[het]["element_symbol"] = hetatms.loc[ |
| 238 | + hetatms.residue_name == het |
| 239 | + ]["element_symbol"].values |
228 | 240 |
|
229 | 241 | df = df.loc[df.record_name == "ATOM"]
|
230 | 242 | if remove_nonstandard:
|
@@ -260,7 +272,7 @@ def protein_to_pyg(
|
260 | 272 | chains=protein_df_to_chain_tensor(df),
|
261 | 273 | )
|
262 | 274 | if store_het:
|
263 |
| - out.hetatms = [het_coords] |
| 275 | + out.hetatms = [het_data] |
264 | 276 |
|
265 | 277 | if store_bfactor:
|
266 | 278 | # group by residue_id and average b_factor per residue
|
|
0 commit comments