1
1
import importlib .metadata
2
+ import logging
2
3
3
4
import click
5
+ import mcp .types as mt
4
6
from datahub .ingestion .graph .client import get_default_graph
5
7
from datahub .ingestion .graph .config import ClientMode
6
8
from datahub .sdk .main_client import DataHubClient
7
9
from datahub .telemetry import telemetry
10
+ from datahub .utilities .perf_timer import PerfTimer
11
+ from fastmcp .server .middleware import CallNext , Middleware , MiddlewareContext
12
+ from fastmcp .server .middleware .logging import LoggingMiddleware
8
13
from typing_extensions import Literal
9
14
10
15
from mcp_server_datahub .mcp_server import mcp , with_datahub_client
11
16
17
+ logging .basicConfig (level = logging .INFO )
18
+
19
+
20
+ class TelemetryMiddleware (Middleware ):
21
+ """Middleware that logs tool calls."""
22
+
23
+ async def on_call_tool (
24
+ self ,
25
+ context : MiddlewareContext [mt .CallToolRequestParams ],
26
+ call_next : CallNext [mt .CallToolRequestParams , mt .CallToolResult ],
27
+ ) -> mt .CallToolResult :
28
+ with PerfTimer () as timer :
29
+ result = await call_next (context )
30
+
31
+ telemetry .telemetry_instance .ping (
32
+ "mcp-server-tool-call" ,
33
+ {
34
+ "tool" : context .message .name ,
35
+ "source" : context .source ,
36
+ "type" : context .type ,
37
+ "method" : context .method ,
38
+ "duration_seconds" : timer .elapsed_seconds (),
39
+ },
40
+ )
41
+
42
+ return result
43
+
12
44
13
45
@click .command ()
14
46
@click .option (
15
47
"--transport" ,
16
48
type = click .Choice (["stdio" , "sse" , "http" ]),
17
49
default = "stdio" ,
18
50
)
51
+ @click .option (
52
+ "--debug" ,
53
+ is_flag = True ,
54
+ default = False ,
55
+ )
19
56
@telemetry .with_telemetry (
20
57
capture_kwargs = ["transport" ],
21
58
)
22
- def main (transport : Literal ["stdio" , "sse" , "http" ]) -> None :
59
+ def main (transport : Literal ["stdio" , "sse" , "http" ], debug : bool ) -> None :
23
60
# Because we want to override the datahub_component, we can't use DataHubClient.from_env()
24
61
# and need to use the DataHubClient constructor directly.
25
62
mcp_version = importlib .metadata .version ("mcp-server-datahub" )
@@ -29,6 +66,10 @@ def main(transport: Literal["stdio", "sse", "http"]) -> None:
29
66
)
30
67
client = DataHubClient (graph = graph )
31
68
69
+ if debug :
70
+ mcp .add_middleware (LoggingMiddleware (include_payloads = True ))
71
+ mcp .add_middleware (TelemetryMiddleware ())
72
+
32
73
with with_datahub_client (client ):
33
74
if transport == "http" :
34
75
mcp .run (transport = transport , show_banner = False , stateless_http = True )
0 commit comments