Skip to content

Commit d20ea40

Browse files
committed
add debug flag + tool telemetry
1 parent 4d770f4 commit d20ea40

File tree

1 file changed

+42
-1
lines changed

1 file changed

+42
-1
lines changed

src/mcp_server_datahub/__main__.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,62 @@
11
import importlib.metadata
2+
import logging
23

34
import click
5+
import mcp.types as mt
46
from datahub.ingestion.graph.client import get_default_graph
57
from datahub.ingestion.graph.config import ClientMode
68
from datahub.sdk.main_client import DataHubClient
79
from datahub.telemetry import telemetry
10+
from datahub.utilities.perf_timer import PerfTimer
11+
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
12+
from fastmcp.server.middleware.logging import LoggingMiddleware
813
from typing_extensions import Literal
914

1015
from mcp_server_datahub.mcp_server import mcp, with_datahub_client
1116

17+
logging.basicConfig(level=logging.INFO)
18+
19+
20+
class TelemetryMiddleware(Middleware):
21+
"""Middleware that logs tool calls."""
22+
23+
async def on_call_tool(
24+
self,
25+
context: MiddlewareContext[mt.CallToolRequestParams],
26+
call_next: CallNext[mt.CallToolRequestParams, mt.CallToolResult],
27+
) -> mt.CallToolResult:
28+
with PerfTimer() as timer:
29+
result = await call_next(context)
30+
31+
telemetry.telemetry_instance.ping(
32+
"mcp-server-tool-call",
33+
{
34+
"tool": context.message.name,
35+
"source": context.source,
36+
"type": context.type,
37+
"method": context.method,
38+
"duration_seconds": timer.elapsed_seconds(),
39+
},
40+
)
41+
42+
return result
43+
1244

1345
@click.command()
1446
@click.option(
1547
"--transport",
1648
type=click.Choice(["stdio", "sse", "http"]),
1749
default="stdio",
1850
)
51+
@click.option(
52+
"--debug",
53+
is_flag=True,
54+
default=False,
55+
)
1956
@telemetry.with_telemetry(
2057
capture_kwargs=["transport"],
2158
)
22-
def main(transport: Literal["stdio", "sse", "http"]) -> None:
59+
def main(transport: Literal["stdio", "sse", "http"], debug: bool) -> None:
2360
# Because we want to override the datahub_component, we can't use DataHubClient.from_env()
2461
# and need to use the DataHubClient constructor directly.
2562
mcp_version = importlib.metadata.version("mcp-server-datahub")
@@ -29,6 +66,10 @@ def main(transport: Literal["stdio", "sse", "http"]) -> None:
2966
)
3067
client = DataHubClient(graph=graph)
3168

69+
if debug:
70+
mcp.add_middleware(LoggingMiddleware(include_payloads=True))
71+
mcp.add_middleware(TelemetryMiddleware())
72+
3273
with with_datahub_client(client):
3374
if transport == "http":
3475
mcp.run(transport=transport, show_banner=False, stateless_http=True)

0 commit comments

Comments
 (0)