Skip to content

Commit d954794

Browse files
authored
Add case-insensitive handling for server and tool names (#148)
* Normalize server names in ClientManager for consistent lookups * Normalize tool names in ClientManager and API handlers * Add comprehensive tests for case-insensitive behavior * Update documentation to reflect normalization
1 parent 48fd523 commit d954794

File tree

5 files changed

+357
-3
lines changed

5 files changed

+357
-3
lines changed

internal/api/servers.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/mozilla-ai/mcpd/v2/internal/contracts"
1414
"github.com/mozilla-ai/mcpd/v2/internal/errors"
15+
"github.com/mozilla-ai/mcpd/v2/internal/filter"
1516
)
1617

1718
// ServersResponse represents the wrapped API response for a list of servers.
@@ -117,7 +118,7 @@ func handleServerTools(accessor contracts.MCPClientAccessor, name string) (*Tool
117118
// Only return data on allowed tools.
118119
tools := make([]Tool, 0, len(result.Tools))
119120
for _, tool := range result.Tools {
120-
if slices.Contains(allowedTools, tool.Name) {
121+
if slices.Contains(allowedTools, filter.NormalizeString(tool.Name)) {
121122
data, err := DomainTool(tool).ToAPIType()
122123
if err != nil {
123124
return nil, err
@@ -153,7 +154,9 @@ func handleServerToolCall(
153154
return nil, fmt.Errorf("%w: %s", errors.ErrToolsNotFound, server)
154155
}
155156

156-
if !slices.Contains(allowedTools, tool) {
157+
// Normalize the tool name before comparing.
158+
normalizedToolName := filter.NormalizeString(tool)
159+
if !slices.Contains(allowedTools, normalizedToolName) {
157160
return nil, fmt.Errorf("%w: %s/%s", errors.ErrToolForbidden, server, tool)
158161
}
159162

internal/api/servers_test.go

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
package api
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/mark3labs/mcp-go/client"
8+
"github.com/mark3labs/mcp-go/mcp"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/mozilla-ai/mcpd/v2/internal/errors"
13+
)
14+
15+
// mockMCPClientAccessor implements the MCPClientAccessor interface for testing.
16+
type mockMCPClientAccessor struct {
17+
clients map[string]client.MCPClient
18+
tools map[string][]string
19+
}
20+
21+
func newMockMCPClientAccessor() *mockMCPClientAccessor {
22+
return &mockMCPClientAccessor{
23+
clients: make(map[string]client.MCPClient),
24+
tools: make(map[string][]string),
25+
}
26+
}
27+
28+
func (m *mockMCPClientAccessor) Add(name string, c client.MCPClient, tools []string) {
29+
m.clients[name] = c
30+
m.tools[name] = tools
31+
}
32+
33+
func (m *mockMCPClientAccessor) Client(name string) (client.MCPClient, bool) {
34+
c, ok := m.clients[name]
35+
return c, ok
36+
}
37+
38+
func (m *mockMCPClientAccessor) Tools(name string) ([]string, bool) {
39+
tools, ok := m.tools[name]
40+
return tools, ok
41+
}
42+
43+
func (m *mockMCPClientAccessor) List() []string {
44+
names := make([]string, 0, len(m.clients))
45+
for name := range m.clients {
46+
names = append(names, name)
47+
}
48+
return names
49+
}
50+
51+
func (m *mockMCPClientAccessor) Remove(name string) {
52+
delete(m.clients, name)
53+
delete(m.tools, name)
54+
}
55+
56+
// mockMCPClient implements the client.MCPClient interface for testing.
57+
type mockMCPClient struct {
58+
listToolsResult *mcp.ListToolsResult
59+
listToolsError error
60+
callToolResult *mcp.CallToolResult
61+
callToolError error
62+
}
63+
64+
func (m *mockMCPClient) Initialize(_ context.Context, _ mcp.InitializeRequest) (*mcp.InitializeResult, error) {
65+
return nil, nil
66+
}
67+
68+
func (m *mockMCPClient) Ping(_ context.Context) error {
69+
return nil
70+
}
71+
72+
func (m *mockMCPClient) ListResourcesByPage(
73+
_ context.Context,
74+
_ mcp.ListResourcesRequest,
75+
) (*mcp.ListResourcesResult, error) {
76+
return nil, nil
77+
}
78+
79+
func (m *mockMCPClient) ListResources(
80+
_ context.Context,
81+
_ mcp.ListResourcesRequest,
82+
) (*mcp.ListResourcesResult, error) {
83+
return nil, nil
84+
}
85+
86+
func (m *mockMCPClient) ListResourceTemplatesByPage(
87+
_ context.Context,
88+
_ mcp.ListResourceTemplatesRequest,
89+
) (*mcp.ListResourceTemplatesResult, error) {
90+
return nil, nil
91+
}
92+
93+
func (m *mockMCPClient) ListResourceTemplates(
94+
_ context.Context,
95+
_ mcp.ListResourceTemplatesRequest,
96+
) (*mcp.ListResourceTemplatesResult, error) {
97+
return nil, nil
98+
}
99+
100+
func (m *mockMCPClient) ReadResource(
101+
_ context.Context,
102+
_ mcp.ReadResourceRequest,
103+
) (*mcp.ReadResourceResult, error) {
104+
return nil, nil
105+
}
106+
107+
func (m *mockMCPClient) Subscribe(_ context.Context, _ mcp.SubscribeRequest) error {
108+
return nil
109+
}
110+
111+
func (m *mockMCPClient) Unsubscribe(_ context.Context, _ mcp.UnsubscribeRequest) error {
112+
return nil
113+
}
114+
115+
func (m *mockMCPClient) ListPromptsByPage(
116+
_ context.Context,
117+
_ mcp.ListPromptsRequest,
118+
) (*mcp.ListPromptsResult, error) {
119+
return nil, nil
120+
}
121+
122+
func (m *mockMCPClient) ListPrompts(
123+
_ context.Context,
124+
_ mcp.ListPromptsRequest,
125+
) (*mcp.ListPromptsResult, error) {
126+
return nil, nil
127+
}
128+
129+
func (m *mockMCPClient) GetPrompt(_ context.Context, _ mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
130+
return nil, nil
131+
}
132+
133+
func (m *mockMCPClient) ListToolsByPage(
134+
_ context.Context,
135+
_ mcp.ListToolsRequest,
136+
) (*mcp.ListToolsResult, error) {
137+
return m.listToolsResult, m.listToolsError
138+
}
139+
140+
func (m *mockMCPClient) ListTools(_ context.Context, _ mcp.ListToolsRequest) (*mcp.ListToolsResult, error) {
141+
return m.listToolsResult, m.listToolsError
142+
}
143+
144+
func (m *mockMCPClient) CallTool(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
145+
return m.callToolResult, m.callToolError
146+
}
147+
148+
func (m *mockMCPClient) SetLevel(_ context.Context, _ mcp.SetLevelRequest) error {
149+
return nil
150+
}
151+
152+
func (m *mockMCPClient) Complete(_ context.Context, _ mcp.CompleteRequest) (*mcp.CompleteResult, error) {
153+
return nil, nil
154+
}
155+
156+
func (m *mockMCPClient) Close() error {
157+
return nil
158+
}
159+
160+
func (m *mockMCPClient) OnNotification(_ func(notification mcp.JSONRPCNotification)) {}
161+
162+
func TestHandleServerTools_CaseInsensitiveFiltering(t *testing.T) {
163+
t.Parallel()
164+
165+
accessor := newMockMCPClientAccessor()
166+
167+
// Mock client returns tools with mixed case.
168+
mockClient := &mockMCPClient{
169+
listToolsResult: &mcp.ListToolsResult{
170+
Tools: []mcp.Tool{
171+
{Name: "GetTime", Description: "Gets current time"},
172+
{Name: "SET_ALARM", Description: "Sets an alarm"},
173+
{Name: "list_events", Description: "Lists events"},
174+
},
175+
},
176+
}
177+
178+
// Server has allowed tools in mixed case, but they should be normalized for comparison.
179+
allowedTools := []string{"gettime", "set_alarm"}
180+
accessor.Add("testserver", mockClient, allowedTools)
181+
182+
result, err := handleServerTools(accessor, "testserver")
183+
require.NoError(t, err)
184+
require.NotNil(t, result)
185+
186+
// Should return 2 tools that match (case-insensitive).
187+
assert.Len(t, result.Body.Tools, 2)
188+
189+
toolNames := make([]string, len(result.Body.Tools))
190+
for i, tool := range result.Body.Tools {
191+
toolNames[i] = tool.Name
192+
}
193+
194+
// Verify the correct tools are returned.
195+
assert.Contains(t, toolNames, "GetTime")
196+
assert.Contains(t, toolNames, "SET_ALARM")
197+
assert.NotContains(t, toolNames, "list_events")
198+
}
199+
200+
func TestHandleServerToolCall_ToolNameNormalization(t *testing.T) {
201+
t.Parallel()
202+
203+
accessor := newMockMCPClientAccessor()
204+
205+
// Mock client that will be called.
206+
mockClient := &mockMCPClient{
207+
callToolResult: &mcp.CallToolResult{
208+
Content: []mcp.Content{
209+
mcp.TextContent{Text: "Tool executed successfully"},
210+
},
211+
},
212+
}
213+
214+
// Allowed tools are stored in normalized form.
215+
allowedTools := []string{"gettime", "setalarm"}
216+
accessor.Add("testserver", mockClient, allowedTools)
217+
218+
// Call with mixed case tool name - should be normalized and match.
219+
result, err := handleServerToolCall(accessor, "testserver", " GetTime ", map[string]any{})
220+
require.NoError(t, err)
221+
require.NotNil(t, result)
222+
223+
assert.Equal(t, "Tool executed successfully", result.Body)
224+
}
225+
226+
func TestHandleServerToolCall_ToolNotAllowed(t *testing.T) {
227+
t.Parallel()
228+
229+
accessor := newMockMCPClientAccessor()
230+
231+
mockClient := &mockMCPClient{}
232+
allowedTools := []string{"gettime"}
233+
accessor.Add("testserver", mockClient, allowedTools)
234+
235+
// Try to call a tool that's not in the allowed list.
236+
result, err := handleServerToolCall(accessor, "testserver", "forbidden_tool", map[string]any{})
237+
require.Error(t, err)
238+
require.Nil(t, result)
239+
240+
assert.ErrorIs(t, err, errors.ErrToolForbidden)
241+
}
242+
243+
func TestHandleServerToolCall_ServerNotFound(t *testing.T) {
244+
t.Parallel()
245+
246+
accessor := newMockMCPClientAccessor()
247+
248+
result, err := handleServerToolCall(accessor, "nonexistent", "tool", map[string]any{})
249+
require.Error(t, err)
250+
require.Nil(t, result)
251+
252+
assert.ErrorIs(t, err, errors.ErrServerNotFound)
253+
}
254+
255+
func TestHandleServerTools_ServerNotFound(t *testing.T) {
256+
t.Parallel()
257+
258+
accessor := newMockMCPClientAccessor()
259+
260+
result, err := handleServerTools(accessor, "nonexistent")
261+
require.Error(t, err)
262+
require.Nil(t, result)
263+
264+
assert.ErrorIs(t, err, errors.ErrServerNotFound)
265+
}
266+
267+
func TestHandleServerTools_NoTools(t *testing.T) {
268+
t.Parallel()
269+
270+
accessor := newMockMCPClientAccessor()
271+
mockClient := &mockMCPClient{}
272+
273+
// Add server with no tools.
274+
accessor.Add("testserver", mockClient, []string{})
275+
276+
result, err := handleServerTools(accessor, "testserver")
277+
require.Error(t, err)
278+
require.Nil(t, result)
279+
280+
assert.ErrorIs(t, err, errors.ErrToolsNotFound)
281+
}

internal/context/context_test.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,12 @@ DATABASE_URL = "${UNDEFINED_DB_URL}"`
590590
require.True(t, exists, "test-server should exist")
591591

592592
// All undefined variables should expand to empty strings
593-
require.Equal(t, []string{"--token", "", "--config="}, server.Args, "Undefined vars in args should expand to empty strings")
593+
require.Equal(
594+
t,
595+
[]string{"--token", "", "--config="},
596+
server.Args,
597+
"Undefined vars in args should expand to empty strings",
598+
)
594599
require.Equal(t, map[string]string{
595600
"API_KEY": "",
596601
"DATABASE_URL": "",

internal/daemon/client_manager.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"sync"
55

66
"github.com/mark3labs/mcp-go/client"
7+
8+
"github.com/mozilla-ai/mcpd/v2/internal/filter"
79
)
810

911
// ClientManager holds active client connections and their associated tool lists.
@@ -24,28 +26,35 @@ func NewClientManager() *ClientManager {
2426
}
2527

2628
// Add registers a client and its tools by server name.
29+
// The server name and tool names are normalized (lowercase, trimmed) for consistent lookups.
2730
// This method is safe for concurrent use.
2831
func (cm *ClientManager) Add(name string, c client.MCPClient, tools []string) {
32+
name = filter.NormalizeString(name)
33+
tools = filter.NormalizeSlice(tools)
2934
cm.mu.Lock()
3035
defer cm.mu.Unlock()
3136
cm.clients[name] = c
3237
cm.serverTools[name] = tools
3338
}
3439

3540
// Client returns the client for the given server name.
41+
// The server name is normalized for case-insensitive lookup.
3642
// It returns a boolean to indicate whether the client was found.
3743
// This method is safe for concurrent use.
3844
func (cm *ClientManager) Client(name string) (client.MCPClient, bool) {
45+
name = filter.NormalizeString(name)
3946
cm.mu.RLock()
4047
defer cm.mu.RUnlock()
4148
c, ok := cm.clients[name]
4249
return c, ok
4350
}
4451

4552
// Tools returns the tools for the given server name.
53+
// The server name is normalized for case-insensitive lookup.
4654
// It returns a boolean to indicate whether the tools were found.
4755
// This method is safe for concurrent use.
4856
func (cm *ClientManager) Tools(name string) ([]string, bool) {
57+
name = filter.NormalizeString(name)
4958
cm.mu.RLock()
5059
defer cm.mu.RUnlock()
5160
t, ok := cm.serverTools[name]
@@ -65,8 +74,10 @@ func (cm *ClientManager) List() []string {
6574
}
6675

6776
// Remove deletes the client and its tools by server name.
77+
// The server name is normalized for case-insensitive lookup.
6878
// This method is safe for concurrent use.
6979
func (cm *ClientManager) Remove(name string) {
80+
name = filter.NormalizeString(name)
7081
cm.mu.Lock()
7182
defer cm.mu.Unlock()
7283
delete(cm.clients, name)

0 commit comments

Comments
 (0)