diff --git a/README.md b/README.md index 35d16ac..004009b 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ conda env create -f environment.yml Additional dependencies are needed to compile the [protobufs](./protobufs/): ```bash -conda install -c conda-forge protobuf +conda install -c conda-forge libprotobuf-static pip install --pre betterproto[compiler] ``` diff --git a/pytensor_federated/__init__.py b/pytensor_federated/__init__.py index 21bdc4b..b169566 100644 --- a/pytensor_federated/__init__.py +++ b/pytensor_federated/__init__.py @@ -19,4 +19,4 @@ from .service import ArraysToArraysService, ArraysToArraysServiceClient from .signatures import ComputeFunc, LogpFunc, LogpGradFunc -__version__ = "1.0.1" +__version__ = "1.0.2" diff --git a/pytensor_federated/npproto/__init__.py b/pytensor_federated/npproto/__init__.py index b3bf697..23fbb79 100644 --- a/pytensor_federated/npproto/__init__.py +++ b/pytensor_federated/npproto/__init__.py @@ -12,8 +12,8 @@ @dataclass(eq=False, repr=False) class Ndarray(betterproto.Message): """ - Represents a NumPy array of arbitrary shape or dtype. Note that the array - must support the buffer protocol. + Represents a NumPy array of arbitrary shape or dtype. + Note that the array must support the buffer protocol. """ data: bytes = betterproto.bytes_field(1) diff --git a/pytensor_federated/rpc.py b/pytensor_federated/rpc.py index abb2b5b..4ce182c 100644 --- a/pytensor_federated/rpc.py +++ b/pytensor_federated/rpc.py @@ -91,14 +91,12 @@ async def evaluate( async def evaluate_stream( self, - input_arrays_iterator: Union[ - AsyncIterable["InputArrays"], Iterable["InputArrays"] - ], + input_arrays_iterator: Union[AsyncIterable[InputArrays], Iterable[InputArrays]], *, timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, metadata: Optional["MetadataLike"] = None - ) -> AsyncIterator["OutputArrays"]: + ) -> AsyncIterator[OutputArrays]: async for response in self._stream_stream( "/ArraysToArraysService/EvaluateStream", input_arrays_iterator, @@ -129,12 +127,13 @@ async def get_load( class ArraysToArraysServiceBase(ServiceBase): + async def evaluate(self, input_arrays: "InputArrays") -> "OutputArrays": raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) async def evaluate_stream( - self, input_arrays_iterator: AsyncIterator["InputArrays"] - ) -> AsyncIterator["OutputArrays"]: + self, input_arrays_iterator: AsyncIterator[InputArrays] + ) -> AsyncIterator[OutputArrays]: raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) yield OutputArrays() diff --git a/requirements.txt b/requirements.txt index 8fe4dd9..1891d9c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -betterproto==2.0.0b6 +betterproto==2.0.0b7 black isort nest-asyncio