Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion cmd/add.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"io"
"strings"
"time"

"github.com/spf13/cobra"

Expand All @@ -29,6 +30,10 @@ type AddCmd struct {
Source string
Format internalcmd.OutputFormat
AllowDeprecated bool
CacheDisabled bool
CacheRefresh bool
CacheDir string
CacheTTL string
cfgLoader config.Loader
packagePrinter output.Printer[config.ServerEntry]
registryBuilder registry.Builder
Expand Down Expand Up @@ -100,6 +105,39 @@ func NewAddCmd(baseCmd *internalcmd.BaseCmd, opt ...cmdopts.CmdOption) (*cobra.C
"Optional, allows server installations marked as deprecated to be added",
)

cobraCommand.Flags().BoolVar(
&c.CacheDisabled,
"no-cache",
false,
"Disable registry manifest caching",
)

cobraCommand.Flags().BoolVar(
&c.CacheRefresh,
"refresh-cache",
false,
"Force refresh of cached registry manifests",
)

defaultCacheDir, err := regopts.DefaultCacheDir()
if err != nil {
return nil, fmt.Errorf("error getting default cache directory: %w", err)
}

cobraCommand.Flags().StringVar(
&c.CacheDir,
"cache-dir",
defaultCacheDir,
"Directory for caching registry manifests",
)

cobraCommand.Flags().StringVar(
&c.CacheTTL,
"cache-ttl",
regopts.DefaultCacheTTL().String(),
"Time-to-live for cached registry manifests (e.g. 1h, 30m, 24h)",
)

return cobraCommand, nil
}

Expand Down Expand Up @@ -130,7 +168,17 @@ func (c *AddCmd) run(cmd *cobra.Command, args []string) error {
return handler.HandleError(err)
}

reg, err := c.registryBuilder.Build()
cacheTTL, err := time.ParseDuration(c.CacheTTL)
if err != nil {
return handler.HandleError(fmt.Errorf("invalid cache TTL: %w", err))
}

reg, err := c.registryBuilder.Build(
regopts.WithCaching(!c.CacheDisabled),
regopts.WithRefreshCache(c.CacheRefresh),
regopts.WithCacheDir(c.CacheDir),
regopts.WithCacheTTL(cacheTTL),
)
if err != nil {
return handler.HandleError(err)
}
Expand Down
224 changes: 205 additions & 19 deletions cmd/add_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import (
"bytes"
"errors"
"io"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/mozilla-ai/mcpd/v2/internal/cmd"
Expand Down Expand Up @@ -72,7 +73,7 @@ type fakeBuilder struct {
err error
}

func (f *fakeBuilder) Build() (registry.PackageProvider, error) {
func (f *fakeBuilder) Build(_ ...options.BuildOption) (registry.PackageProvider, error) {
return f.reg, f.err
}

Expand Down Expand Up @@ -115,10 +116,10 @@ func TestAddCmd_Success(t *testing.T) {

err = cmdObj.Execute()
require.NoError(t, err)
assert.Contains(t, buf.String(), "✓ Added server")
assert.True(t, cfg.addCalled)
assert.Equal(t, "server1", cfg.entry.Name)
assert.Equal(t, "uvx::mcp-server-1@1.2.3", cfg.entry.Package)
require.Contains(t, buf.String(), "✓ Added server")
require.True(t, cfg.addCalled)
require.Equal(t, "server1", cfg.entry.Name)
require.Equal(t, "uvx::mcp-server-1@1.2.3", cfg.entry.Package)
}

func TestAddCmd_MissingArgs(t *testing.T) {
Expand All @@ -132,7 +133,7 @@ func TestAddCmd_MissingArgs(t *testing.T) {

err = cmdObj.Execute()
require.Error(t, err)
assert.Contains(t, err.Error(), "server name is required")
require.Contains(t, err.Error(), "server name is required")
}

func TestAddCmd_RegistryFails(t *testing.T) {
Expand All @@ -145,7 +146,7 @@ func TestAddCmd_RegistryFails(t *testing.T) {
cmdObj.SetArgs([]string{"server1"})
err = cmdObj.Execute()
require.Error(t, err)
assert.Contains(t, err.Error(), "registry error")
require.Contains(t, err.Error(), "registry error")
}

func TestAddCmd_BasicServerAdd(t *testing.T) {
Expand Down Expand Up @@ -187,14 +188,14 @@ func TestAddCmd_BasicServerAdd(t *testing.T) {

// Output assertions
outStr := o.String()
assert.Contains(t, outStr, "✓ Added server 'testserver'")
assert.Contains(t, outStr, "version: latest")
require.Contains(t, outStr, "✓ Added server 'testserver'")
require.Contains(t, outStr, "version: latest")

// Config assertions
require.True(t, cfg.addCalled)
assert.Equal(t, "testserver", cfg.entry.Name)
assert.Equal(t, "uvx::mcp-server-testserver@latest", cfg.entry.Package)
assert.ElementsMatch(t, []string{"tool1", "tool2", "tool3"}, cfg.entry.Tools)
require.Equal(t, "testserver", cfg.entry.Name)
require.Equal(t, "uvx::mcp-server-testserver@latest", cfg.entry.Package)
require.ElementsMatch(t, []string{"tool1", "tool2", "tool3"}, cfg.entry.Tools)
}

func TestAddCmd_ServerWithArguments(t *testing.T) {
Expand Down Expand Up @@ -396,11 +397,11 @@ func TestAddCmd_ServerWithArguments(t *testing.T) {

// Verify config was called with correct arguments
require.True(t, cfg.addCalled)
assert.Equal(t, tc.pkg.ID, cfg.entry.Name)
assert.ElementsMatch(t, tc.expectedRequiredEnvs, cfg.entry.RequiredEnvVars)
assert.ElementsMatch(t, tc.expectedRequiredPositionals, cfg.entry.RequiredPositionalArgs)
assert.ElementsMatch(t, tc.expectedRequiredValues, cfg.entry.RequiredValueArgs)
assert.ElementsMatch(t, tc.expectedRequiredBools, cfg.entry.RequiredBoolArgs)
require.Equal(t, tc.pkg.ID, cfg.entry.Name)
require.ElementsMatch(t, tc.expectedRequiredEnvs, cfg.entry.RequiredEnvVars)
require.ElementsMatch(t, tc.expectedRequiredPositionals, cfg.entry.RequiredPositionalArgs)
require.ElementsMatch(t, tc.expectedRequiredValues, cfg.entry.RequiredValueArgs)
require.ElementsMatch(t, tc.expectedRequiredBools, cfg.entry.RequiredBoolArgs)
})
}
}
Expand Down Expand Up @@ -476,7 +477,7 @@ func TestSelectRuntime(t *testing.T) {
require.EqualError(t, err, "no supported runtimes found")
} else {
require.NoError(t, err)
assert.Equal(t, tc.expectedRuntime, got)
require.Equal(t, tc.expectedRuntime, got)
}
})
}
Expand Down Expand Up @@ -761,3 +762,188 @@ func TestParseServerEntry(t *testing.T) {
})
}
}

func TestAddCmd_CacheTTL(t *testing.T) {
t.Parallel()

tests := []struct {
name string
ttl string
expectedError string
}{
{
name: "valid cache TTL",
ttl: "1h",
},
{
name: "invalid cache TTL",
ttl: "invalid",
expectedError: "invalid cache TTL: time: invalid duration \"invalid\"",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

pkg := packages.Server{
ID: "testserver",
Name: "testserver",
Tools: []packages.Tool{
{Name: "tool1"},
},
Installations: map[runtime.Runtime]packages.Installation{
runtime.UVX: {
Runtime: "uvx",
Package: "mcp-server-testserver",
Version: "latest",
Recommended: true,
},
},
}

cfg := &fakeConfig{}
cmdObj, err := NewAddCmd(
&cmd.BaseCmd{},
cmdopts.WithConfigLoader(&fakeLoader{cfg: cfg}),
cmdopts.WithRegistryBuilder(&fakeBuilder{reg: &fakeRegistry{pkg: pkg}}),
)
require.NoError(t, err)

cmdObj.SetOut(io.Discard)
cmdObj.SetErr(io.Discard)
cmdObj.SetArgs([]string{"testserver", "--cache-ttl", tc.ttl})

err = cmdObj.Execute()
if tc.expectedError != "" {
require.Error(t, err)
require.EqualError(t, err, tc.expectedError)
} else {
require.NoError(t, err)
require.True(t, cfg.addCalled)
require.Equal(t, "testserver", cfg.entry.Name)
}
})
}
}

func TestAddCmd_CacheFlagsWithTempDir(t *testing.T) {
t.Parallel()

tests := []struct {
name string
setupCmd func(t *testing.T, tempDir string) []string
}{
{
name: "custom cache directory",
setupCmd: func(t *testing.T, tempDir string) []string {
return []string{"testserver", "--cache-dir", tempDir}
},
},
{
name: "both custom cache flags",
setupCmd: func(t *testing.T, tempDir string) []string {
return []string{"testserver", "--cache-dir", tempDir, "--cache-ttl", "30m"}
},
},
{
name: "cache disabled with custom settings",
setupCmd: func(t *testing.T, tempDir string) []string {
return []string{"testserver", "--no-cache", "--cache-dir", tempDir, "--cache-ttl", "2h"}
},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

tempDir := t.TempDir()
args := tc.setupCmd(t, tempDir)

pkg := packages.Server{
ID: "testserver",
Name: "testserver",
Tools: []packages.Tool{
{Name: "tool1"},
},
Installations: map[runtime.Runtime]packages.Installation{
runtime.UVX: {
Runtime: "uvx",
Package: "mcp-server-testserver",
Version: "latest",
Recommended: true,
},
},
}

cfg := &fakeConfig{}
cmdObj, err := NewAddCmd(
&cmd.BaseCmd{},
cmdopts.WithConfigLoader(&fakeLoader{cfg: cfg}),
cmdopts.WithRegistryBuilder(&fakeBuilder{reg: &fakeRegistry{pkg: pkg}}),
)
require.NoError(t, err)

cmdObj.SetOut(io.Discard)
cmdObj.SetErr(io.Discard)
cmdObj.SetArgs(args)

err = cmdObj.Execute()
require.NoError(t, err)
require.True(t, cfg.addCalled)
require.Equal(t, "testserver", cfg.entry.Name)

// tempDir is available here for any cache directory verification
})
}
}

func TestAddCmd_NoCacheDirectoryCreatedWhenDisabled(t *testing.T) {
t.Parallel()

tempDir := t.TempDir()
cacheSubDir := filepath.Join(tempDir, "should-not-be-created")

// Verify the cache directory doesn't exist initially.
_, err := os.Stat(cacheSubDir)
require.True(t, os.IsNotExist(err), "Cache directory should not exist initially")

pkg := packages.Server{
ID: "testserver",
Name: "testserver",
Tools: []packages.Tool{
{Name: "tool1"},
},
Installations: map[runtime.Runtime]packages.Installation{
runtime.UVX: {
Runtime: "uvx",
Package: "mcp-server-testserver",
Version: "latest",
Recommended: true,
},
},
}

cfg := &fakeConfig{}
cmdObj, err := NewAddCmd(
&cmd.BaseCmd{},
cmdopts.WithConfigLoader(&fakeLoader{cfg: cfg}),
cmdopts.WithRegistryBuilder(&fakeBuilder{reg: &fakeRegistry{pkg: pkg}}),
)
require.NoError(t, err)

cmdObj.SetOut(io.Discard)
cmdObj.SetErr(io.Discard)
// Use --no-cache with custom cache directory - directory should NOT be created.
cmdObj.SetArgs([]string{"testserver", "--no-cache", "--cache-dir", cacheSubDir})

err = cmdObj.Execute()
require.NoError(t, err)
require.True(t, cfg.addCalled)
require.Equal(t, "testserver", cfg.entry.Name)

// Verify the cache directory was never created.
_, err = os.Stat(cacheSubDir)
require.True(t, os.IsNotExist(err), "Cache directory should not be created when --no-cache is used")
}
Loading