From 9d874813030948a5c4898fb1ca61aa818bbec6b0 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 4 Aug 2025 15:01:41 +0000 Subject: [PATCH] Allow enum datatypes to be set to their enum values --- pyproject.toml | 8 +++----- src/fastcs/datatypes.py | 11 +++++++++++ tests/test_attribute.py | 19 +++++++++++++++++++ 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6cc7cca6..d625dc06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=64", "setuptools_scm[toml]>=8"] +requires = ["setuptools>=70.1", "setuptools_scm[toml]>=8"] build-backend = "setuptools.build_meta" [project] @@ -20,7 +20,7 @@ dependencies = [ "pytango", "softioc>=4.5.0", "strawberry-graphql", - "p4p" + "p4p", ] dynamic = ["version"] license.file = "LICENSE" @@ -53,9 +53,7 @@ dev = [ "httpx", "tickit~=0.4.3", ] -demo = [ - "tickit~=0.4.3", -] +demo = ["tickit~=0.4.3"] [project.scripts] fastcs-demo = "fastcs.demo.__main__:main" diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index 9580c357..fda53a39 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -139,6 +139,17 @@ def __post_init__(self): def index_of(self, value: T_Enum) -> int: return self.members.index(value) + def validate(self, value: T) -> T: + enum_vals = [key.value for key in self.dtype] + + if value not in enum_vals and not issubclass(type(value), self.dtype): + raise ValueError( + f"Value '{value}' is not a member of {self.dtype} or of " + f"type {self.dtype}" + ) + + return value + @cached_property def members(self) -> list[T_Enum]: return list(self.enum_cls) diff --git a/tests/test_attribute.py b/tests/test_attribute.py index 2a8e788c..9e85c3ee 100644 --- a/tests/test_attribute.py +++ b/tests/test_attribute.py @@ -1,3 +1,4 @@ +import enum from functools import partial import numpy as np @@ -106,3 +107,21 @@ async def test_handler_initialise(mocker: MockerFixture): def test_validate(datatype, init_args, value): with pytest.raises(ValueError): datatype(**init_args).validate(value) + + +class MyEnum(enum.Enum): + TEST = "Test" + + +class MyOtherEnum(enum.Enum): + TEST = "Test" + + +def test_enum_validate(): + enum_datatype = Enum(MyEnum) + enum_datatype.validate(MyEnum.TEST) + enum_datatype.validate("Test") + with pytest.raises(ValueError): + enum_datatype.validate("BadTest") + with pytest.raises(ValueError): + enum_datatype.validate(MyOtherEnum.TEST)