Skip to content

Commit ea18c45

Browse files
authored
fix: make all tools async (#34)
1 parent 222ff6e commit ea18c45

File tree

6 files changed

+94
-11
lines changed

6 files changed

+94
-11
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,6 @@ jobs:
2626
- name: Checks
2727
run: make lint-check
2828

29-
# Tests don't run without credentials yet.
30-
# - name: Test
31-
# run: make test
29+
# Note that some tests require credentials that aren't available in CI environment.
30+
- name: Test
31+
run: make test

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ readme = "README.md"
66
requires-python = ">=3.10"
77
dependencies = [
88
"acryl-datahub==1.2.0.1",
9+
"asyncer>=0.0.8",
910
"fastmcp==2.10.5",
1011
"jmespath~=1.0.1",
1112
]

src/mcp_server_datahub/mcp_server.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,21 @@
11
import contextlib
22
import contextvars
3+
import functools
4+
import inspect
35
import pathlib
4-
from typing import Any, Dict, Iterator, List, Optional
6+
from typing import (
7+
Any,
8+
Awaitable,
9+
Callable,
10+
Dict,
11+
Iterator,
12+
List,
13+
Optional,
14+
ParamSpec,
15+
TypeVar,
16+
)
517

18+
import asyncer
619
import jmespath
720
from datahub.errors import ItemNotFoundError
821
from datahub.ingestion.graph.client import DataHubGraph
@@ -13,6 +26,23 @@
1326
from fastmcp import FastMCP
1427
from pydantic import BaseModel
1528

29+
_P = ParamSpec("_P")
30+
_R = TypeVar("_R")
31+
32+
33+
# See https://github.com/jlowin/fastmcp/issues/864#issuecomment-3103678258
34+
# for why we need to wrap sync functions with asyncify.
35+
def async_background(fn: Callable[_P, _R]) -> Callable[_P, Awaitable[_R]]:
36+
if inspect.iscoroutinefunction(fn):
37+
raise RuntimeError("async_background can only be used on non-async functions")
38+
39+
@functools.wraps(fn)
40+
async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
41+
return await asyncer.asyncify(fn)(*args, **kwargs)
42+
43+
return wrapper
44+
45+
1646
mcp = FastMCP[None](name="datahub")
1747

1848

@@ -132,6 +162,7 @@ def _clean_get_entity_response(raw_response: dict) -> dict:
132162

133163

134164
@mcp.tool(description="Get an entity by its DataHub URN.")
165+
@async_background
135166
def get_entity(urn: str) -> dict:
136167
client = get_datahub_client()
137168

@@ -184,6 +215,7 @@ def get_entity(urn: str) -> dict:
184215
```
185216
"""
186217
)
218+
@async_background
187219
def search(
188220
query: str = "*",
189221
filters: Optional[Filter] = None,
@@ -215,6 +247,7 @@ def search(
215247

216248

217249
@mcp.tool(description="Use this tool to get the SQL queries associated with a dataset.")
250+
@async_background
218251
def get_dataset_queries(dataset_urn: str, start: int = 0, count: int = 10) -> dict:
219252
client = get_datahub_client()
220253

@@ -330,6 +363,7 @@ def get_lineage(
330363
Use this tool to get upstream or downstream lineage for any entity, including datasets, schemaFields, dashboards, charts, etc. \
331364
Set upstream to True for upstream lineage, False for downstream lineage."""
332365
)
366+
@async_background
333367
def get_lineage(urn: str, upstream: bool, max_hops: int = 1) -> dict:
334368
client = get_datahub_client()
335369
lineage_api = AssetLineageAPI(client._graph)

tests/test_async_setup.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import inspect
2+
import time
3+
4+
import anyio
5+
import fastmcp.tools.tool
6+
import pytest
7+
8+
from mcp_server_datahub.mcp_server import async_background, mcp
9+
10+
11+
@pytest.mark.anyio
12+
async def test_async_background() -> None:
13+
@async_background
14+
def my_sleep(sec: float) -> None:
15+
time.sleep(sec)
16+
17+
start_time = time.time()
18+
19+
async with anyio.create_task_group() as tg:
20+
tg.start_soon(my_sleep, 0.5)
21+
tg.start_soon(my_sleep, 0.6)
22+
tg.start_soon(my_sleep, 0.7)
23+
24+
end_time = time.time()
25+
duration = end_time - start_time
26+
# The calls should not be serialized, so the duration should be less than the sum of the durations.
27+
assert 0.5 <= duration < 1.8
28+
29+
30+
def test_all_tools_are_async() -> None:
31+
# If any tools are sync, the tool execution will block the main event loop.
32+
for tool in mcp._tool_manager._tools.values():
33+
assert isinstance(tool, fastmcp.tools.tool.FunctionTool)
34+
assert inspect.iscoroutinefunction(tool.fn)

tests/test_mcp_server.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@
2525

2626
@pytest.fixture(autouse=True, scope="session")
2727
def setup_client() -> Iterable[None]:
28-
with with_datahub_client(DataHubClient.from_env()):
28+
try:
29+
client = DataHubClient.from_env()
30+
except Exception as e:
31+
if "`datahub init`" in str(e):
32+
pytest.skip("No credentials available, skipping tests")
33+
raise
34+
with with_datahub_client(client):
2935
yield
3036

3137

@@ -95,9 +101,3 @@ async def test_search(mcp_client: Client) -> None:
95101
)
96102
assert res.is_error is False
97103
assert res.data is not None
98-
99-
100-
if __name__ == "__main__":
101-
import pytest
102-
103-
pytest.main()

uv.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)