Skip to content

Commit 2e75043

Browse files
author
Arian Jamasb
committed
add additional hetatm info to parser
1 parent 53a76be commit 2e75043

File tree

1 file changed

+15
-3
lines changed
  • graphein/protein/tensor

1 file changed

+15
-3
lines changed

graphein/protein/tensor/io.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
# Project Website: https://github.com/a-r-j/graphein
77
# Code Repository: https://github.com/a-r-j/graphein
88

9+
import collections
910
import os
1011
from typing import List, Optional, Union
1112

@@ -218,13 +219,24 @@ def protein_to_pyg(
218219
if store_het:
219220
hetatms = df.loc[df.record_name == "HETATM"]
220221
all_hets = list(set(hetatms.residue_name))
221-
het_coords = {}
222+
het_data = collections.defaultdict(dict)
222223
for het in all_hets:
223-
het_coords[het] = torch.tensor(
224+
het_data[het]["coords"] = torch.tensor(
224225
hetatms.loc[hetatms.residue_name == het][
225226
["x_coord", "y_coord", "z_coord"]
226227
].values
227228
)
229+
het_data[het]["atoms"] = hetatms.loc[hetatms.residue_name == het][
230+
"atom_name"
231+
].values
232+
het_data[het]["residue_number"] = torch.tensor(
233+
hetatms.loc[hetatms.residue_name == het][
234+
"residue_number"
235+
].values
236+
)
237+
het_data[het]["element_symbol"] = hetatms.loc[
238+
hetatms.residue_name == het
239+
]["element_symbol"].values
228240

229241
df = df.loc[df.record_name == "ATOM"]
230242
if remove_nonstandard:
@@ -260,7 +272,7 @@ def protein_to_pyg(
260272
chains=protein_df_to_chain_tensor(df),
261273
)
262274
if store_het:
263-
out.hetatms = [het_coords]
275+
out.hetatms = [het_data]
264276

265277
if store_bfactor:
266278
# group by residue_id and average b_factor per residue

0 commit comments

Comments
 (0)