diff --git a/graphdatascience/model/v2/model.py b/graphdatascience/model/v2/model.py new file mode 100644 index 000000000..7f34fbeab --- /dev/null +++ b/graphdatascience/model/v2/model.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from abc import ABC +from typing import Optional + +from graphdatascience.model.v2.model_api import ModelApi +from graphdatascience.model.v2.model_details import ModelDetails + + +# Compared to v1 Model offering typed parameters for predict endpoints +class Model(ABC): + def __init__(self, name: str, model_api: ModelApi): + self._name = name + self._model_api = model_api + + # TODO estimate mode, predict modes on here? + # implement Cypher and Arrow info_provider and stuff + + def name(self) -> str: + """ + Get the name of the model. + + Returns: + The name of the model. + + """ + return self._name + + def details(self) -> ModelDetails: + return self._model_api.get(self._name) + + def exists(self) -> bool: + """ + Check whether the model exists. + + Returns: + True if the model exists, False otherwise. + + """ + return self._model_api.exists(self._name) + + def drop(self, failIfMissing: bool = False) -> Optional[ModelDetails]: + """ + Drop the model. + + Args: + failIfMissing: If True, an error is thrown if the model does not exist. If False, no error is thrown. + + Returns: + The result of the drop operation. + + """ + return self._model_api.drop(self._name, failIfMissing) + + def __str__(self) -> str: + return f"{self.__class__.__name__}(name={self.name()}, type={self.details().type})" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.details().model_dump()})" diff --git a/graphdatascience/model/v2/model_api.py b/graphdatascience/model/v2/model_api.py new file mode 100644 index 000000000..d0bbc36c4 --- /dev/null +++ b/graphdatascience/model/v2/model_api.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from graphdatascience.model.v2.model_details import ModelDetails + + +class ModelApi(ABC): + """ + Abstract base class defining the API for model operations. + This class is intended to be subclassed by specific model implementations. + """ + + @abstractmethod + def exists(self, model: str) -> bool: + """ + Check if a specific model exists. + + Args: + model: The name of the model. + + Returns: + True if the model exists, False otherwise. + """ + pass + + @abstractmethod + def get(self, model: str) -> ModelDetails: + """ + Get the details of a specific model. + + Args: + model: The name of the model. + + Returns: + The details of the model. + """ + pass + + @abstractmethod + def drop(self, model: str, fail_if_missing: bool) -> Optional[ModelDetails]: + """ + Drop a specific model. + + Args: + model: The name of the model. + fail_if_missing: If True, an error is thrown if the model does not exist. If False, no error is thrown. + + Returns: + The result of the drop operation. + """ + pass diff --git a/graphdatascience/model/v2/model_details.py b/graphdatascience/model/v2/model_details.py new file mode 100644 index 000000000..356fb13c3 --- /dev/null +++ b/graphdatascience/model/v2/model_details.py @@ -0,0 +1,20 @@ +import datetime +from typing import Any + +from pydantic import BaseModel, Field +from pydantic.alias_generators import to_camel + + +class ModelDetails(BaseModel, alias_generator=to_camel): + name: str = Field(alias="modelName") + type: str = Field(alias="modelType") + train_config: dict[str, Any] + graph_schema: dict[str, Any] + loaded: bool + stored: bool + published: bool + model_info: dict[str, Any] # TODO better typing in actual model? + creation_time: datetime.datetime + + def __getitem__(self, item: str) -> Any: + return getattr(self, item) diff --git a/graphdatascience/procedure_surface/api/graphsage_predict_endpoints.py b/graphdatascience/procedure_surface/api/graphsage_predict_endpoints.py new file mode 100644 index 000000000..87b188819 --- /dev/null +++ b/graphdatascience/procedure_surface/api/graphsage_predict_endpoints.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +from pandas import DataFrame + +from graphdatascience.procedure_surface.api.base_result import BaseResult +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult + +from ...graph.graph_object import Graph + + +class GraphSagePredictEndpoints(ABC): + """ + Abstract base class defining the API for the GraphSage algorithm. + """ + + @abstractmethod + def stream(self, G: Graph, **config: Any) -> DataFrame: + pass + + @abstractmethod + def write(self, G: Graph, **config: Any) -> GraphSageWriteResult: + pass + + @abstractmethod + def mutate(self, G: Graph, **config: Any) -> GraphSageMutateResult: + pass + + @abstractmethod + def estimate(self, G: Graph, **config: Any) -> EstimationResult: + pass + + +class GraphSageMutateResult(BaseResult): + node_count: int + node_properties_written: int + pre_processing_millis: int + compute_millis: int + mutate_millis: int + configuration: dict[str, Any] + + +class GraphSageWriteResult(BaseResult): + node_count: int + node_properties_written: int + pre_processing_millis: int + compute_millis: int + write_millis: int + configuration: dict[str, Any] diff --git a/graphdatascience/procedure_surface/api/graphsage_train_endpoints.py b/graphdatascience/procedure_surface/api/graphsage_train_endpoints.py new file mode 100644 index 000000000..fd63a2718 --- /dev/null +++ b/graphdatascience/procedure_surface/api/graphsage_train_endpoints.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +from graphdatascience.procedure_surface.api.base_result import BaseResult +from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2 + +from ...graph.graph_object import Graph + + +class GraphSageTrainEndpoints(ABC): + """ + Abstract base class defining the API for the GraphSage algorithm. + """ + + @abstractmethod + def train( + self, + G: Graph, + model_name: str, + feature_properties: List[str], + activation_function: Optional[Any] = None, + negative_sample_weight: Optional[int] = None, + embedding_dimension: Optional[int] = None, + tolerance: Optional[float] = None, + learning_rate: Optional[float] = None, + max_iterations: Optional[int] = None, + sample_sizes: Optional[List[int]] = None, + aggregator: Optional[Any] = None, + penalty_l2: Optional[float] = None, + search_depth: Optional[int] = None, + epochs: Optional[int] = None, + projected_feature_dimension: Optional[int] = None, + batch_sampling_ratio: Optional[float] = None, + store_model_to_disk: Optional[bool] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + relationship_weight_property: Optional[str] = None, + random_seed: Optional[Any] = None, + ) -> tuple[GraphSageModelV2, GraphSageTrainResult]: + """ + Trains a GraphSage model on the given graph. + + Parameters + ---------- + G : Graph + The graph to run the algorithm on + model_name : str + Name under which the model will be stored + feature_properties : List[str] + The names of the node properties to use as input features + activation_function : Optional[Any], default=None + The activation function to apply after each layer + negative_sample_weight : Optional[int], default=None + Weight of negative samples in the loss function + embedding_dimension : Optional[int], default=None + The dimension of the generated embeddings + tolerance : Optional[float], default=None + Tolerance for early stopping based on loss improvement + learning_rate : Optional[float], default=None + Learning rate for the training optimization + max_iterations : Optional[int], default=None + Maximum number of training iterations + sample_sizes : Optional[List[int]], default=None + Number of neighbors to sample at each layer + aggregator : Optional[Any], default=None + The aggregator function for neighborhood aggregation + penalty_l2 : Optional[float], default=None + L2 regularization penalty + search_depth : Optional[int], default=None + Maximum search depth for neighbor sampling + epochs : Optional[int], default=None + Number of training epochs + projected_feature_dimension : Optional[int], default=None + Dimension to project input features to before training + batch_sampling_ratio : Optional[float], default=None + Ratio of nodes to sample for each training batch + store_model_to_disk : Optional[bool], default=None + Whether to persist the model to disk + relationship_types : Optional[List[str]], default=None + The relationship types used to select relationships for this algorithm run + node_labels : Optional[List[str]], default=None + The node labels used to select nodes for this algorithm run + username : Optional[str] = None + The username to attribute the procedure run to + log_progress : Optional[bool], default=None + Whether to log progress + sudo : Optional[bool], default=None + Override memory estimation limits + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + batch_size : Optional[int], default=None + Batch size for training + relationship_weight_property : Optional[str], default=None + The property name that contains weight + random_seed : Optional[Any], default=None + Random seed for reproducible results + + Returns + ------- + GraphSageModelV2 + Trained model + """ + + +class GraphSageTrainResult(BaseResult): + model_info: dict[str, Any] + configuration: dict[str, Any] + train_millis: int diff --git a/graphdatascience/procedure_surface/api/model/graphsage_model.py b/graphdatascience/procedure_surface/api/model/graphsage_model.py new file mode 100644 index 000000000..4b1b532f4 --- /dev/null +++ b/graphdatascience/procedure_surface/api/model/graphsage_model.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from typing import Optional + +from pandas import DataFrame + +from graphdatascience.model.v2.model_api import ModelApi +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.graphsage_predict_endpoints import ( + GraphSageMutateResult, + GraphSagePredictEndpoints, + GraphSageWriteResult, +) + +from ....graph.graph_object import Graph +from ....graph.graph_type_check import graph_type_check +from ....model.v2.model import Model + + +class GraphSageModelV2(Model): + """ + Represents a GraphSAGE model in the model catalog. + Construct this using :func:`gds.graphSage.train()`. + """ + + def __init__(self, name: str, model_api: ModelApi, predict_endpoints: GraphSagePredictEndpoints) -> None: + super().__init__(name, model_api) + self._predict_endpoints = predict_endpoints + + @graph_type_check + def predict_write( + self, + G: Graph, + write_property: str, + relationship_types: Optional[list[str]] = None, + node_labels: Optional[list[str]] = None, + batch_size: Optional[int] = None, + concurrency: Optional[int] = None, + write_concurrency: Optional[int] = None, + write_to_result_store: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + sudo: Optional[bool] = None, + job_id: Optional[str] = None, + ) -> GraphSageWriteResult: + """ + Generate embeddings for the given graph and write the results to the database. + + Args: + G: The graph to generate embeddings for. + write_property: The property to write the embeddings to. + relationship_types: The relationship types to consider. + node_labels: The node labels to consider. + batch_size: The batch size for prediction. + concurrency: The concurrency for computation. + write_concurrency: The concurrency for writing. + write_to_result_store: Whether to write to the result store. + log_progress: Whether to log progress. + username: The username for the operation. + sudo: Whether to use sudo privileges. + job_id: The job ID for the operation. + + Returns: + The result of the write operation. + + """ + return self._predict_endpoints.write( + G, + modelName=self.name(), + writeProperty=write_property, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + batchSize=batch_size, + concurrency=concurrency, + writeConcurrency=write_concurrency, + writeToResultStore=write_to_result_store, + logProgress=log_progress, + username=username, + sudo=sudo, + jobId=job_id, + ) + + def predict_stream( + self, + G: Graph, + relationship_types: Optional[list[str]] = None, + node_labels: Optional[list[str]] = None, + batch_size: Optional[int] = None, + concurrency: Optional[int] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + sudo: Optional[bool] = None, + job_id: Optional[str] = None, + ) -> DataFrame: + """ + Generate embeddings for the given graph and stream the results. + + Args: + G: The graph to generate embeddings for. + relationship_types: The relationship types to consider. + node_labels: The node labels to consider. + batch_size: The batch size for prediction. + concurrency: The concurrency for computation. + log_progress: Whether to log progress. + username: The username for the operation. + sudo: Whether to use sudo privileges. + job_id: The job ID for the operation. + + Returns: + The streaming results as a DataFrame. + + """ + return self._predict_endpoints.stream( + G, + modelName=self.name(), + relationshipTypes=relationship_types, + nodeLabels=node_labels, + batchSize=batch_size, + concurrency=concurrency, + logProgress=log_progress, + username=username, + sudo=sudo, + jobId=job_id, + ) + + def predict_mutate( + self, + G: Graph, + mutate_property: str, + relationship_types: Optional[list[str]] = None, + node_labels: Optional[list[str]] = None, + batch_size: Optional[int] = None, + concurrency: Optional[int] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + sudo: Optional[bool] = None, + job_id: Optional[str] = None, + ) -> GraphSageMutateResult: + """ + Generate embeddings for the given graph and mutate the graph with the results. + + Args: + G: The graph to generate embeddings for. + mutate_property: The property to mutate with the embeddings. + relationship_types: The relationship types to consider. + node_labels: The node labels to consider. + batch_size: The batch size for prediction. + concurrency: The concurrency for computation. + log_progress: Whether to log progress. + username: The username for the operation. + sudo: Whether to use sudo privileges. + job_id: The job ID for the operation. + + Returns: + The result of the mutate operation. + + """ + return self._predict_endpoints.mutate( + G, + modelName=self.name(), + mutateProperty=mutate_property, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + batchSize=batch_size, + concurrency=concurrency, + logProgress=log_progress, + username=username, + sudo=sudo, + jobId=job_id, + ) + + @graph_type_check + def predict_estimate( + self, + G: Graph, + relationship_types: Optional[list[str]] = None, + node_labels: Optional[list[str]] = None, + batch_size: Optional[int] = None, + concurrency: Optional[int] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + sudo: Optional[bool] = None, + job_id: Optional[str] = None, + ) -> EstimationResult: + """ + Estimate the memory needed to generate embeddings for the given graph and write the results to the database. + + Args: + G: The graph to generate embeddings for. + relationship_types: The relationship types to consider. + node_labels: The node labels to consider. + batch_size: The batch size for prediction. + concurrency: The concurrency for computation. + log_progress: Whether to log progress. + username: The username for the operation. + sudo: Whether to use sudo privileges. + job_id: The job ID for the operation. + + Returns: + The memory needed to generate embeddings for the given graph and write the results to the database. + + """ + return self._predict_endpoints.estimate( + G, + modelName=self.name(), + relationshipTypes=relationship_types, + nodeLabels=node_labels, + batchSize=batch_size, + concurrency=concurrency, + logProgress=log_progress, + username=username, + sudo=sudo, + jobId=job_id, + ) diff --git a/graphdatascience/procedure_surface/arrow/graphsage_predict_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/graphsage_predict_arrow_endpoints.py new file mode 100644 index 000000000..13fa1127e --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/graphsage_predict_arrow_endpoints.py @@ -0,0 +1,63 @@ +from typing import Any + +from pandas import DataFrame + +from graphdatascience.graph.graph_object import Graph +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.graphsage_predict_endpoints import ( + GraphSageMutateResult, + GraphSagePredictEndpoints, + GraphSageWriteResult, +) + +from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from .model_api_arrow import ModelApiArrow +from .node_property_endpoints import NodePropertyEndpoints + + +class GraphSagePredictArrowEndpoints(GraphSagePredictEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client = arrow_client + self._node_property_endpoints = NodePropertyEndpoints(arrow_client) + self._model_api = ModelApiArrow(arrow_client) + + def stream(self, G: Graph, **config: Any) -> DataFrame: + config = self._node_property_endpoints.create_base_config(G, **config) + + return self._node_property_endpoints.run_job_and_stream("v2/embeddings.graphSage", G, config) + + def write(self, G: Graph, **config: Any) -> GraphSageWriteResult: + config = self._node_property_endpoints.create_base_config(G, **config) + + raw_result = self._node_property_endpoints.run_job_and_write( + "v2/embeddings.graphSage", + G, + config, + config.get("writeConcurrency"), + config.get("concurrency"), + ) + + return GraphSageWriteResult(**raw_result) + + def mutate(self, G: Graph, **config: Any) -> GraphSageMutateResult: + config = self._node_property_endpoints.create_base_config(G, **config) + + mutateProperty = config.pop("mutateProperty", "") + + raw_result = self._node_property_endpoints.run_job_and_mutate( + "v2/embeddings.graphSage", + G, + config, + mutateProperty, + ) + + return GraphSageMutateResult(**raw_result) + + def estimate(self, G: Graph, **config: Any) -> EstimationResult: + config = self._node_property_endpoints.create_estimate_config(**config) + + return self._node_property_endpoints.estimate( + "v2/embeddings.graphSage.estimate", + G, + config, + ) diff --git a/graphdatascience/procedure_surface/arrow/graphsage_train_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/graphsage_train_arrow_endpoints.py new file mode 100644 index 000000000..e3ef22ff9 --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/graphsage_train_arrow_endpoints.py @@ -0,0 +1,89 @@ +from typing import Any, List, Optional + +from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.arrow.graphsage_predict_arrow_endpoints import GraphSagePredictArrowEndpoints + +from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from ...graph.graph_object import Graph +from ..api.graphsage_train_endpoints import ( + GraphSageTrainEndpoints, + GraphSageTrainResult, +) +from .model_api_arrow import ModelApiArrow +from .node_property_endpoints import NodePropertyEndpoints + + +class GraphSageTrainArrowEndpoints(GraphSageTrainEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client = arrow_client + self._node_property_endpoints = NodePropertyEndpoints(arrow_client) + self._model_api = ModelApiArrow(arrow_client) + + def train( + self, + G: Graph, + model_name: str, + feature_properties: List[str], + activation_function: Optional[Any] = None, + negative_sample_weight: Optional[int] = None, + embedding_dimension: Optional[int] = None, + tolerance: Optional[float] = None, + learning_rate: Optional[float] = None, + max_iterations: Optional[int] = None, + sample_sizes: Optional[List[int]] = None, + aggregator: Optional[Any] = None, + penalty_l2: Optional[float] = None, + search_depth: Optional[int] = None, + epochs: Optional[int] = None, + projected_feature_dimension: Optional[int] = None, + batch_sampling_ratio: Optional[float] = None, + store_model_to_disk: Optional[bool] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + relationship_weight_property: Optional[str] = None, + random_seed: Optional[Any] = None, + ) -> tuple[GraphSageModelV2, GraphSageTrainResult]: + config = self._node_property_endpoints.create_base_config( + G, + model_name=model_name, + feature_properties=feature_properties, + activation_function=activation_function, + negative_sample_weight=negative_sample_weight, + embedding_dimension=embedding_dimension, + tolerance=tolerance, + learning_rate=learning_rate, + max_iterations=max_iterations, + sample_sizes=sample_sizes, + aggregator=aggregator, + penalty_l2=penalty_l2, + search_depth=search_depth, + epochs=epochs, + projected_feature_dimension=projected_feature_dimension, + batch_sampling_ratio=batch_sampling_ratio, + store_model_to_disk=store_model_to_disk, + relationship_types=relationship_types, + node_labels=node_labels, + username=username, + log_progress=log_progress, + sudo=sudo, + concurrency=concurrency, + job_id=job_id, + batch_size=batch_size, + relationship_weight_property=relationship_weight_property, + random_seed=random_seed, + ) + + result = self._node_property_endpoints.run_job_and_get_summary("v2/embeddings.graphSage.train", G, config) + + model = GraphSageModelV2( + model_name, self._model_api, predict_endpoints=GraphSagePredictArrowEndpoints(self._arrow_client) + ) + train_result = GraphSageTrainResult(**result) + + return model, train_result diff --git a/graphdatascience/procedure_surface/arrow/model_api_arrow.py b/graphdatascience/procedure_surface/arrow/model_api_arrow.py new file mode 100644 index 000000000..7fa730c72 --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/model_api_arrow.py @@ -0,0 +1,57 @@ +import datetime +import json +import re +from typing import Any, Optional + +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize +from graphdatascience.model.v2.model_api import ModelApi +from graphdatascience.model.v2.model_details import ModelDetails + + +class ModelApiArrow(ModelApi): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client: AuthenticatedArrowClient = arrow_client + super().__init__() + + def exists(self, model: str) -> bool: + raw_result = self._arrow_client.do_action_with_retry( + "v2/model.exists", payload=json.dumps({"modelName": model}).encode("utf-8") + ) + result = deserialize(raw_result) + + if not result: + return False + + return True + + def get(self, model: str) -> ModelDetails: + raw_result = self._arrow_client.do_action_with_retry( + "v2/model.get", payload=json.dumps({"modelName": model}).encode("utf-8") + ) + result = deserialize(raw_result) + + if not result: + raise ValueError(f"There is no '{model}' in the model catalog") + + return self._parse_model_details(result[0]) + + def drop(self, model: str, fail_if_missing: bool) -> Optional[ModelDetails]: + raw_result = self._arrow_client.do_action_with_retry( + "v2/model.drop", payload=json.dumps({"modelName": model, "failIfMissing": fail_if_missing}).encode("utf-8") + ) + result = deserialize(raw_result) + + if not result: + return None + + return self._parse_model_details(result[0]) + + def _parse_model_details(self, input: dict[str, Any]) -> ModelDetails: + creation_time = input.pop("creationTime") + if creation_time and isinstance(creation_time, str): + # Trim microseconds from 9 digits to 6 digits + trimmed = re.sub(r"\.(\d{6})\d+", r".\1", creation_time) + input["creationTime"] = datetime.datetime.strptime(trimmed, "%Y-%m-%dT%H:%M:%S.%fZ[%Z]") + + return ModelDetails(**input) diff --git a/graphdatascience/procedure_surface/cypher/graphsage_predict_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/graphsage_predict_cypher_endpoints.py new file mode 100644 index 000000000..9c43d2a8f --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/graphsage_predict_cypher_endpoints.py @@ -0,0 +1,58 @@ +from typing import Any + +from pandas import DataFrame + +from graphdatascience.call_parameters import CallParameters +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.graphsage_predict_endpoints import ( + GraphSageMutateResult, + GraphSagePredictEndpoints, + GraphSageWriteResult, +) +from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter + +from ...graph.graph_object import Graph +from ...query_runner.query_runner import QueryRunner + + +class GraphSagePredictCypherEndpoints(GraphSagePredictEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def stream(self, G: Graph, **config: Any) -> DataFrame: + config = ConfigConverter.convert_to_gds_config(**config) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + return self._query_runner.call_procedure(endpoint="gds.beta.graphSage.stream", params=params) + + def write(self, G: Graph, **config: Any) -> GraphSageWriteResult: + config = ConfigConverter.convert_to_gds_config(**config) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + raw_result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.write", params=params) + + return GraphSageWriteResult(**raw_result.iloc[0].to_dict()) + + def mutate(self, G: Graph, **config: Any) -> GraphSageMutateResult: + config = ConfigConverter.convert_to_gds_config(**config) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + raw_result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.mutate", params=params) + + return GraphSageMutateResult(**raw_result.iloc[0].to_dict()) + + def estimate(self, G: Graph, **config: Any) -> EstimationResult: + config = ConfigConverter.convert_to_gds_config(**config) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + raw_result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.stream.estimate", params=params) + + return EstimationResult(**raw_result.iloc[0].to_dict()) diff --git a/graphdatascience/procedure_surface/cypher/graphsage_train_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/graphsage_train_cypher_endpoints.py new file mode 100644 index 000000000..4364e3cee --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/graphsage_train_cypher_endpoints.py @@ -0,0 +1,89 @@ +from typing import Any, List, Optional + +from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.cypher.graphsage_predict_cypher_endpoints import GraphSagePredictCypherEndpoints +from graphdatascience.procedure_surface.cypher.model_api_cypher import ModelApiCypher + +from ...call_parameters import CallParameters +from ...graph.graph_object import Graph +from ...query_runner.query_runner import QueryRunner +from ..api.graphsage_train_endpoints import ( + GraphSageTrainEndpoints, + GraphSageTrainResult, +) +from ..utils.config_converter import ConfigConverter + + +class GraphSageTrainCypherEndpoints(GraphSageTrainEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def train( + self, + G: Graph, + model_name: str, + feature_properties: List[str], + activation_function: Optional[Any] = None, + negative_sample_weight: Optional[int] = None, + embedding_dimension: Optional[int] = None, + tolerance: Optional[float] = None, + learning_rate: Optional[float] = None, + max_iterations: Optional[int] = None, + sample_sizes: Optional[List[int]] = None, + aggregator: Optional[Any] = None, + penalty_l2: Optional[float] = None, + search_depth: Optional[int] = None, + epochs: Optional[int] = None, + projected_feature_dimension: Optional[int] = None, + batch_sampling_ratio: Optional[float] = None, + store_model_to_disk: Optional[bool] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + relationship_weight_property: Optional[str] = None, + random_seed: Optional[Any] = None, + ) -> tuple[GraphSageModelV2, GraphSageTrainResult]: + config = ConfigConverter.convert_to_gds_config( + model_name=model_name, + feature_properties=feature_properties, + activation_function=activation_function, + negative_sample_weight=negative_sample_weight, + embedding_dimension=embedding_dimension, + tolerance=tolerance, + learning_rate=learning_rate, + max_iterations=max_iterations, + sample_sizes=sample_sizes, + aggregator=aggregator, + penalty_l2=penalty_l2, + search_depth=search_depth, + epochs=epochs, + projected_feature_dimension=projected_feature_dimension, + batch_sampling_ratio=batch_sampling_ratio, + store_model_to_disk=store_model_to_disk, + relationship_types=relationship_types, + node_labels=node_labels, + username=username, + log_progress=log_progress, + sudo=sudo, + concurrency=concurrency, + job_id=job_id, + batch_size=batch_size, + relationship_weight_property=relationship_weight_property, + random_seed=random_seed, + ) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.train", params=params).iloc[0] + + return GraphSageModelV2( + name=model_name, + model_api=ModelApiCypher(self._query_runner), + predict_endpoints=GraphSagePredictCypherEndpoints(self._query_runner), + ), GraphSageTrainResult(**result.to_dict()) diff --git a/graphdatascience/procedure_surface/cypher/model_api_cypher.py b/graphdatascience/procedure_surface/cypher/model_api_cypher.py new file mode 100644 index 000000000..01772829f --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/model_api_cypher.py @@ -0,0 +1,49 @@ +from typing import Any, Optional + +import neo4j + +from graphdatascience.call_parameters import CallParameters +from graphdatascience.model.v2.model_api import ModelApi +from graphdatascience.model.v2.model_details import ModelDetails +from graphdatascience.query_runner.query_runner import QueryRunner + + +class ModelApiCypher(ModelApi): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + super().__init__() + + def exists(self, model: str) -> bool: + params = CallParameters(name=model) + + result = self._query_runner.call_procedure("gds.model.exists", params=params, custom_error=False) + if result.empty: + return False + + return result.iloc[0]["exists"] # type: ignore + + def get(self, model: str) -> ModelDetails: + params = CallParameters(name=model) + + result = self._query_runner.call_procedure("gds.model.list", params=params, custom_error=False) + if result.empty: + raise ValueError(f"There is no '{model}' in the model catalog") + + return self._to_model_details(result.iloc[0].to_dict()) + + def drop(self, model: str, fail_if_missing: bool) -> Optional[ModelDetails]: + params = CallParameters(model_name=model, fail_if_missing=fail_if_missing) + + result = self._query_runner.call_procedure("gds.model.drop", params=params, custom_error=False) + + if result.empty: + return None + + return self._to_model_details(result.iloc[0].to_dict()) + + def _to_model_details(self, result: dict[str, Any]) -> ModelDetails: + creation_time = result.get("creationTime", None) + if creation_time and isinstance(creation_time, neo4j.time.DateTime): + result["creationTime"] = creation_time.to_native() + + return ModelDetails(**result) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py new file mode 100644 index 000000000..bf31e91e2 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py @@ -0,0 +1,56 @@ +import json +from typing import Generator + +import pytest + +from graphdatascience import Graph +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.procedure_surface.arrow.graphsage_train_arrow_endpoints import GraphSageTrainArrowEndpoints +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph + + +@pytest.fixture +def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, None, None]: + gdl = """ + CREATE + (a: Node {feature: 1.0}), + (b: Node {feature: 2.0}), + (c: Node {feature: 3.0}), + (d: Node {feature: 4.0}), + (a)-[:REL]->(b), + (b)-[:REL]->(c), + (c)-[:REL]->(d), + (d)-[:REL]->(a) + """ + + yield create_graph(arrow_client, "g", gdl) + arrow_client.do_action("v2/graph.drop", json.dumps({"graphName": "g"}).encode("utf-8")) + + +@pytest.fixture +def graphsage_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[GraphSageTrainArrowEndpoints, None, None]: + yield GraphSageTrainArrowEndpoints(arrow_client) + + +def test_graphsage_train(graphsage_endpoints: GraphSageTrainArrowEndpoints, sample_graph: Graph) -> None: + """Test GraphSage train operation.""" + model, result = graphsage_endpoints.train( + G=sample_graph, + model_name="testGraphSageModel", + feature_properties=["feature"], + embedding_dimension=1, + epochs=1, # Use minimal epochs for faster testing + max_iterations=1, # Use minimal iterations for faster testing + ) + + # Check the result + assert result.train_millis >= 0 + assert result.configuration is not None + assert result.model_info is not None + + # Check the model + assert model.name() == "testGraphSageModel" + assert model.exists() + + # Clean up the model + model.drop() diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_predict_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_predict_arrow_endpoints.py new file mode 100644 index 000000000..618a9cfaa --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_predict_arrow_endpoints.py @@ -0,0 +1,74 @@ +import json +from typing import Generator + +import pytest + +from graphdatascience import Graph +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.arrow.graphsage_train_arrow_endpoints import GraphSageTrainArrowEndpoints +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph + + +@pytest.fixture +def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, None, None]: + gdl = """ + (a: Node {feature: 1.0}), + (b: Node {feature: 2.0}), + (c: Node {feature: 3.0}), + (d: Node {feature: 4.0}), + (a)-[:REL]->(b), + (b)-[:REL]->(c), + (c)-[:REL]->(d), + (d)-[:REL]->(a) + """ + + yield create_graph(arrow_client, "g", gdl) + arrow_client.do_action("v2/graph.drop", json.dumps({"graphName": "g"}).encode("utf-8")) + + +@pytest.fixture +def gs_model(arrow_client: AuthenticatedArrowClient, sample_graph: Graph) -> Generator[GraphSageModelV2, None, None]: + model, _ = GraphSageTrainArrowEndpoints(arrow_client).train( + G=sample_graph, + model_name="gs-model", + feature_properties=["feature"], + embedding_dimension=1, + sample_sizes=[1], + max_iterations=1, + search_depth=1, + ) + + yield model + + arrow_client.do_action_with_retry("v2/model.drop", json.dumps({"modelName": model.name()}).encode("utf-8")) + + +def test_stream(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_stream(sample_graph, concurrency=4) + + assert set(result.columns) == {"nodeId", "embedding"} + assert len(result) == 4 + + +def test_mutate(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_mutate(sample_graph, concurrency=4, mutate_property="embedding") + + assert result.node_properties_written == 4 + + +def test_write(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + with pytest.raises(Exception, match="Write back client is not initialized"): + gs_model.predict_write(sample_graph, write_property="embedding", concurrency=4, write_concurrency=2) + + +def test_estimate(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_estimate(sample_graph, concurrency=4) + + assert result.node_count == 4 + assert result.relationship_count == 4 + assert "KiB" in result.required_memory + assert result.bytes_min > 0 + assert result.bytes_max > 0 + assert result.heap_percentage_min > 0 + assert result.heap_percentage_max > 0 diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py new file mode 100644 index 000000000..98021454a --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py @@ -0,0 +1,85 @@ +import json +from typing import Generator + +import pytest +from pyarrow.flight import FlightServerError + +from graphdatascience import Graph +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.procedure_surface.arrow.graphsage_train_arrow_endpoints import GraphSageTrainArrowEndpoints +from graphdatascience.procedure_surface.arrow.model_api_arrow import ModelApiArrow +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph + + +@pytest.fixture +def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, None, None]: + gdl = """ + (a: Node {age: 1}) + (b: Node {age: 2}) + (c: Node {age: 3}) + (d: Node {age: 4}) + (e: Node {age: 5}) + (f: Node {age: 6}) + (a)-[:REL]->(b) + (b)-[:REL]->(c) + (c)-[:REL]->(a) + (d)-[:REL]->(e) + (e)-[:REL]->(f) + (f)-[:REL]->(d) + """ + + yield create_graph(arrow_client, "model_api_g", gdl) + arrow_client.do_action("v2/graph.drop", json.dumps({"graphName": "model_api_g"}).encode("utf-8")) + + +@pytest.fixture +def gs_model(arrow_client: AuthenticatedArrowClient, sample_graph: Graph) -> Generator[str, None, None]: + model, _ = GraphSageTrainArrowEndpoints(arrow_client).train( + G=sample_graph, + model_name="gs-model", + feature_properties=["age"], + embedding_dimension=1, + sample_sizes=[1], + max_iterations=1, + search_depth=1, + ) + + yield model.name() + + arrow_client.do_action_with_retry("v2/model.drop", json.dumps({"modelName": model.name()}).encode("utf-8")) + + +@pytest.fixture +def model_api(arrow_client: AuthenticatedArrowClient) -> Generator[ModelApiArrow, None, None]: + yield ModelApiArrow(arrow_client) + + +def test_model_get(gs_model: str, model_api: ModelApiArrow) -> None: + model = model_api.get(gs_model) + + assert model.name == gs_model + assert model.type == "graphSage" + + with pytest.raises(ValueError, match="There is no 'nonexistent-model' in the model catalog"): + model_api.get("nonexistent-model") + + +def test_model_exists(gs_model: str, model_api: ModelApiArrow) -> None: + assert model_api.exists(gs_model) + assert not model_api.exists("nonexistent-model") + + +def test_model_delete(gs_model: str, model_api: ModelApiArrow) -> None: + model_details = model_api.drop(gs_model, fail_if_missing=False) + + assert model_details is not None + assert model_details.name == gs_model + + # Check that the model no longer exists + assert not model_api.exists(gs_model) + + # Attempt to drop a non-existing model + assert model_api.drop("nonexistent-model", fail_if_missing=False) is None + + with pytest.raises(FlightServerError, match="Model with name `nonexistent-model` does not exist"): + model_api.drop("nonexistent-model", fail_if_missing=True) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py new file mode 100644 index 000000000..264f3f08a --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py @@ -0,0 +1,54 @@ +from typing import Generator + +import pytest + +from graphdatascience import Graph, QueryRunner +from graphdatascience.procedure_surface.cypher.graphsage_train_cypher_endpoints import GraphSageTrainCypherEndpoints + + +@pytest.fixture +def sample_graph_with_features(query_runner: QueryRunner) -> Generator[Graph, None, None]: + create_statement = """ + CREATE + (a: Node {feature: 1.0}), + (b: Node {feature: 2.0}), + (c: Node {feature: 3.0}), + (a)-[:REL]->(c), + (b)-[:REL]->(c) + """ + + query_runner.run_cypher(create_statement) + + query_runner.run_cypher(""" + MATCH (n) + OPTIONAL MATCH (n)-[r]->(m) + WITH gds.graph.project('g', n, m, {sourceNodeProperties: properties(n), targetNodeProperties: properties(m)}) AS G + RETURN G + """) + + yield Graph("g", query_runner) + + query_runner.run_cypher("CALL gds.graph.drop('g')") + query_runner.run_cypher("MATCH (n) DETACH DELETE n") + + +@pytest.fixture +def graphsage_endpoints(query_runner: QueryRunner) -> Generator[GraphSageTrainCypherEndpoints, None, None]: + yield GraphSageTrainCypherEndpoints(query_runner) + + +def test_graphsage_train(graphsage_endpoints: GraphSageTrainCypherEndpoints, sample_graph_with_features: Graph) -> None: + """Test GraphSage train operation.""" + model, train_result = graphsage_endpoints.train( + G=sample_graph_with_features, + model_name="testModel", + feature_properties=["feature"], + embedding_dimension=1, + epochs=1, # Use minimal epochs for faster testing + max_iterations=1, # Use minimal iterations for faster testing + ) + + assert train_result.train_millis >= 0 + assert train_result.model_info is not None + assert train_result.configuration is not None + assert model.name() == "testModel" diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_predict_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_predict_cypher_endpoints.py new file mode 100644 index 000000000..bd94c2d13 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_predict_cypher_endpoints.py @@ -0,0 +1,82 @@ +from typing import Generator + +import pytest + +from graphdatascience import Graph +from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.cypher.graphsage_train_cypher_endpoints import GraphSageTrainCypherEndpoints +from graphdatascience.query_runner.query_runner import QueryRunner + + +@pytest.fixture +def sample_graph(query_runner: QueryRunner) -> Generator[Graph, None, None]: + create_statement = """ + CREATE + (a: Node {feature: 1.0}), + (b: Node {feature: 2.0}), + (c: Node {feature: 3.0}), + (a)-[:REL]->(c), + (b)-[:REL]->(c) + """ + + query_runner.run_cypher(create_statement) + + query_runner.run_cypher(""" + MATCH (n) + OPTIONAL MATCH (n)-[r]->(m) + WITH gds.graph.project('g', n, m, {sourceNodeProperties: properties(n), targetNodeProperties: properties(m)}) AS G + RETURN G + """) + + yield Graph("g", query_runner) + + query_runner.run_cypher("CALL gds.graph.drop('g')") + query_runner.run_cypher("MATCH (n) DETACH DELETE n") + + +@pytest.fixture +def gs_model(query_runner: QueryRunner, sample_graph: Graph) -> Generator[GraphSageModelV2, None, None]: + model, _ = GraphSageTrainCypherEndpoints(query_runner).train( + G=sample_graph, + model_name="gs-model", + feature_properties=["feature"], + embedding_dimension=1, + sample_sizes=[1], + max_iterations=1, + search_depth=1, + ) + + yield model + + query_runner.run_cypher("CALL gds.model.drop('gs-model')") + + +def test_stream(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_stream(sample_graph, concurrency=4) + + assert set(result.columns) == {"nodeId", "embedding"} + assert len(result) == 3 + + +def test_mutate(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_mutate(sample_graph, concurrency=4, mutate_property="embedding") + + assert result.node_properties_written == 3 + + +def test_write(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_write(sample_graph, write_property="embedding", concurrency=4, write_concurrency=2) + + assert result.node_properties_written == 3 + + +def test_estimate(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_estimate(sample_graph, concurrency=4) + + assert result.node_count == 3 + assert result.relationship_count == 2 + assert "KiB" in result.required_memory + assert result.bytes_min > 0 + assert result.bytes_max > 0 + assert result.heap_percentage_min > 0 + assert result.heap_percentage_max > 0 diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_model_api_cypher.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_model_api_cypher.py new file mode 100644 index 000000000..428b0d24f --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_model_api_cypher.py @@ -0,0 +1,83 @@ +from typing import Generator + +import pytest +from neo4j.exceptions import Neo4jError + +from graphdatascience import Graph, QueryRunner +from graphdatascience.procedure_surface.cypher.model_api_cypher import ModelApiCypher + + +@pytest.fixture +def sample_graph(query_runner: QueryRunner) -> Generator[Graph, None, None]: + create_statement = """ + CREATE + (a: Node {age: 1}), + (b: Node {age: 2}), + (c: Node {age: 3}), + (a)-[:REL]->(c), + (b)-[:REL]->(c) + """ + + query_runner.run_cypher(create_statement) + + query_runner.run_cypher(""" + MATCH (n) + OPTIONAL MATCH (n)-[r]->(m) + WITH gds.graph.project('g', n, m, {sourceNodeProperties: properties(n), targetNodeProperties: properties(m)}) AS G + RETURN G + """) + + yield Graph("g", query_runner) + + query_runner.run_cypher("CALL gds.graph.drop('g')") + query_runner.run_cypher("MATCH (n) DETACH DELETE n") + + +@pytest.fixture +def gs_model(query_runner: QueryRunner, sample_graph: Graph) -> Generator[str, None, None]: + train_result = query_runner.run_cypher( + "CALL gds.beta.graphSage.train($graph, {modelName: 'gs-model', featureProperties:['age'], embeddingDimension: 1, sampleSizes: [1], maxIterations: 1, searchDepth: 1})", + {"graph": sample_graph.name()}, + ) + + model_name = train_result.iloc[0]["modelInfo"]["modelName"] + + yield model_name # type: ignore + + query_runner.run_cypher("CALL gds.model.drop($name, false)", {"name": model_name}) + + +@pytest.fixture +def model_api(query_runner: QueryRunner) -> Generator[ModelApiCypher, None, None]: + yield ModelApiCypher(query_runner) + + +def test_model_get(gs_model: str, model_api: ModelApiCypher) -> None: + model = model_api.get(gs_model) + + assert model.name == gs_model + assert model.type == "graphSage" + + with pytest.raises(ValueError, match="There is no 'nonexistent-model' in the model catalog"): + model_api.get("nonexistent-model") + + +def test_model_exists(gs_model: str, model_api: ModelApiCypher) -> None: + assert model_api.exists(gs_model) + assert not model_api.exists("nonexistent-model") + + +def test_model_delete(gs_model: str, model_api: ModelApiCypher) -> None: + model_details = model_api.drop(gs_model, fail_if_missing=False) + + assert model_details is not None + assert model_details.name == gs_model + + # Check that the model no longer exists + assert not model_api.exists(gs_model) + + # Attempt to drop a non-existing model + assert model_api.drop("nonexistent-model", fail_if_missing=False) is None + + with pytest.raises(Neo4jError, match="Model with name `nonexistent-model` does not exist"): + model_api.drop("nonexistent-model", fail_if_missing=True)