Skip to content

Commit 74388d5

Browse files
committed
Fix type issues with ndarray.tolist()
1 parent 77fb42c commit 74388d5

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

graphdatascience/graph/ogb_loader.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,12 @@ def _parse_homogeneous(self, dataset: HomogeneousOGBNDataset) -> tuple[list[pd.D
108108
"nodeId": list(range(node_count)),
109109
}
110110
if "node_feat" in graph and graph["node_feat"] is not None:
111-
node_dict["features"] = graph["node_feat"].tolist()
111+
node_dict["features"] = graph["node_feat"].tolist() # type: ignore
112112

113113
if len(dataset.labels[0]) == 1:
114114
node_dict["classLabel"] = [cl[0] for cl in dataset.labels]
115115
else:
116-
node_dict["classLabel"] = dataset.labels.tolist()
116+
node_dict["classLabel"] = dataset.labels.tolist() # type: ignore
117117

118118
split = dataset.get_idx_split()
119119
node_labels = ["Train" for _ in range(node_count)]
@@ -170,13 +170,13 @@ def _parse_heterogeneous(self, dataset: HeterogeneousOGBNDataset) -> tuple[list[
170170
}
171171

172172
if node_label in node_features:
173-
node_dict["features"] = node_features[node_label].tolist()
173+
node_dict["features"] = node_features[node_label].tolist() # type: ignore
174174

175175
if node_label in class_labels:
176176
if len(class_labels[node_label]) == 1:
177177
node_dict["classLabel"] = [cl[0] for cl in class_labels[node_label]]
178178
else:
179-
node_dict["classLabel"] = class_labels[node_label].tolist()
179+
node_dict["classLabel"] = class_labels[node_label].tolist() # type: ignore
180180

181181
node_id_offsets[node_label] = current_offset
182182
current_offset += node_count
@@ -243,7 +243,7 @@ def _parse_homogeneous(self, dataset: HomogeneousOGBLDataset) -> tuple[list[pd.D
243243
"labels": "N",
244244
}
245245
if "node_feat" in graph and graph["node_feat"] is not None:
246-
node_dict["features"] = graph["node_feat"].tolist()
246+
node_dict["features"] = graph["node_feat"].tolist() # type: ignore
247247
nodes = pd.DataFrame(node_dict)
248248

249249
self._logger.info("Preparing relationship data for transfer to server...")
@@ -283,7 +283,7 @@ def _parse_heterogeneous(self, dataset: HeterogeneousOGBLDataset) -> tuple[list[
283283
}
284284

285285
if node_label in node_features:
286-
node_dict["features"] = node_features[node_label].tolist()
286+
node_dict["features"] = node_features[node_label].tolist() # type: ignore
287287

288288
node_id_offsets[node_label] = current_offset
289289
current_offset += node_count

graphdatascience/query_runner/cypher_graph_constructor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,8 @@ def run(self, node_df: DataFrame, relationship_df: DataFrame) -> None:
386386
)
387387

388388
def _node_query(self, node_df: DataFrame) -> tuple[str, list[list[Any]]]:
389-
node_list = node_df.values.tolist()
389+
# ignore type as tolist return depends on number of dimensions)
390+
node_list: list[list[Any]] = node_df.values.tolist() # type: ignore
390391
node_columns = list(node_df.columns)
391392
node_id_index = node_columns.index("nodeId")
392393

@@ -411,7 +412,7 @@ def _node_query(self, node_df: DataFrame) -> tuple[str, list[list[Any]]]:
411412
return f"UNWIND $nodes as node RETURN node[{node_id_index}] as id{label_query}{property_query}", node_list
412413

413414
def _relationship_query(self, rel_df: DataFrame) -> tuple[str, list[list[Any]]]:
414-
rel_list = rel_df.values.tolist()
415+
rel_list: list[list[Any]] = rel_df.values.tolist() # type: ignore
415416
rel_columns = list(rel_df.columns)
416417
source_id_index = rel_columns.index("sourceNodeId")
417418
target_id_index = rel_columns.index("targetNodeId")

0 commit comments

Comments
 (0)