Skip to content

Commit a0ab592

Browse files
committed
First attempt at introducing CORS support
1 parent 4e00b8d commit a0ab592

File tree

3 files changed

+74
-11
lines changed

3 files changed

+74
-11
lines changed

cmd/daemon.go

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ import (
2222
// DaemonCmd should be used to represent the 'daemon' command.
2323
type DaemonCmd struct {
2424
*cmd.BaseCmd
25-
Dev bool
26-
Addr string
27-
cfgLoader config.Loader
28-
ctxLoader configcontext.Loader
25+
Dev bool
26+
Addr string
27+
EnableCORS bool
28+
CORSOrigins []string
29+
cfgLoader config.Loader
30+
ctxLoader configcontext.Loader
2931
}
3032

3133
// NewDaemonCmd creates a newly configured (Cobra) command.
@@ -36,13 +38,14 @@ func NewDaemonCmd(baseCmd *cmd.BaseCmd, opt ...cmdopts.CmdOption) (*cobra.Comman
3638
}
3739

3840
c := &DaemonCmd{
39-
BaseCmd: baseCmd,
40-
cfgLoader: opts.ConfigLoader,
41-
ctxLoader: opts.ContextLoader,
41+
BaseCmd: baseCmd,
42+
cfgLoader: opts.ConfigLoader,
43+
ctxLoader: opts.ContextLoader,
44+
CORSOrigins: []string{},
4245
}
4346

4447
cobraCommand := &cobra.Command{
45-
Use: "daemon [--dev] [--addr]",
48+
Use: "daemon [--dev] [--addr] [--enable-cors] [--cors-origins]",
4649
Short: "Launches an `mcpd` daemon instance",
4750
Long: "Launches an `mcpd` daemon instance, which starts MCP servers and provides routing via HTTP API",
4851
RunE: c.run,
@@ -62,6 +65,21 @@ func NewDaemonCmd(baseCmd *cmd.BaseCmd, opt ...cmdopts.CmdOption) (*cobra.Comman
6265
"Address for the daemon to bind (not applicable in --dev mode)",
6366
)
6467

68+
// Add CORS flags
69+
cobraCommand.Flags().BoolVar(
70+
&c.EnableCORS,
71+
"enable-cors",
72+
false,
73+
"Enable Cross-Origin Resource Sharing (CORS) for browser clients",
74+
)
75+
76+
cobraCommand.Flags().StringSliceVar(
77+
&c.CORSOrigins,
78+
"cors-origins",
79+
[]string{"*"},
80+
"Comma-separated list of allowed CORS origins (default: * for all origins)",
81+
)
82+
6583
cobraCommand.MarkFlagsMutuallyExclusive("dev", "addr")
6684

6785
return cobraCommand, nil
@@ -93,7 +111,9 @@ func (c *DaemonCmd) run(_ *cobra.Command, _ []string) error {
93111
if err != nil {
94112
return fmt.Errorf("error configuring mcpd daemon options: %w", err)
95113
}
96-
d, err := daemon.NewDaemon(addr, opts)
114+
115+
// Pass CORS configuration to daemon creation
116+
d, err := daemon.NewDaemon(addr, opts, c.EnableCORS, c.CORSOrigins)
97117
if err != nil {
98118
return fmt.Errorf("failed to create mcpd daemon instance: %w", err)
99119
}
@@ -128,6 +148,11 @@ func (c *DaemonCmd) run(_ *cobra.Command, _ []string) error {
128148
banner += fmt.Sprintf(" Log file:\t%s => (%s)\n", flags.LogPath, flags.LogLevel)
129149
}
130150

151+
// Add CORS status to banner
152+
if c.EnableCORS {
153+
banner += fmt.Sprintf(" CORS enabled:\t%v (origins: %s)\n", c.EnableCORS, strings.Join(c.CORSOrigins, ", "))
154+
}
155+
131156
banner += "\nPress Ctrl+C to stop.\n\n"
132157
fmt.Print(banner)
133158
}

internal/daemon/daemon.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func NewDaemonOpts(logger hclog.Logger, cfgLoader config.Loader, ctxLoader confi
6464

6565
// NewDaemon creates a new Daemon instance with proper initialization.
6666
// Use this function instead of directly creating a Daemon struct.
67-
func NewDaemon(apiAddr string, opts *Opts) (*Daemon, error) {
67+
func NewDaemon(apiAddr string, opts *Opts, enableCORS bool, corsOrigins []string) (*Daemon, error) {
6868
if err := IsValidAddr(apiAddr); err != nil {
6969
return nil, fmt.Errorf("invalid API address '%s': %w", apiAddr, err)
7070
}
@@ -94,7 +94,7 @@ func NewDaemon(apiAddr string, opts *Opts) (*Daemon, error) {
9494

9595
healthTracker := NewHealthTracker(serverNames)
9696
clientManager := NewClientManager()
97-
apiServer, err := NewApiServer(opts.logger, clientManager, healthTracker, apiAddr)
97+
apiServer, err := NewApiServer(opts.logger, clientManager, healthTracker, apiAddr, enableCORS, corsOrigins)
9898
if err != nil {
9999
return nil, fmt.Errorf("failed to create daemon API server: %w", err)
100100
}

internal/daemon/server.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ import (
77
"net/http"
88
"net/url"
99
"reflect"
10+
"strings"
1011
"time"
1112

1213
"github.com/danielgtaylor/huma/v2"
1314
"github.com/danielgtaylor/huma/v2/adapters/humachi"
1415
"github.com/go-chi/chi/v5"
1516
"github.com/go-chi/chi/v5/middleware"
17+
"github.com/go-chi/cors"
1618
"github.com/hashicorp/go-hclog"
1719

1820
"github.com/mozilla-ai/mcpd/v2/internal/api"
@@ -26,13 +28,17 @@ type ApiServer struct {
2628
healthTracker contracts.MCPHealthMonitor
2729
logger hclog.Logger
2830
addr string
31+
enableCORS bool
32+
corsOrigins []string
2933
}
3034

3135
func NewApiServer(
3236
logger hclog.Logger,
3337
accessor contracts.MCPClientAccessor,
3438
monitor contracts.MCPHealthMonitor,
3539
addr string,
40+
enableCORS bool,
41+
corsOrigins []string,
3642
) (*ApiServer, error) {
3743
if logger == nil || reflect.ValueOf(logger).IsNil() {
3844
return nil, fmt.Errorf("logger cannot be nil")
@@ -52,13 +58,42 @@ func NewApiServer(
5258
clientManager: accessor,
5359
healthTracker: monitor,
5460
addr: addr,
61+
enableCORS: enableCORS,
62+
corsOrigins: corsOrigins,
5563
}, nil
5664
}
5765

5866
func (a *ApiServer) Start(ctx context.Context) error {
5967
// Create router.
6068
mux := chi.NewMux()
6169
mux.Use(middleware.StripSlashes)
70+
71+
// Add CORS middleware if enabled
72+
if a.enableCORS {
73+
a.logger.Info("Enabling CORS", "origins", a.corsOrigins)
74+
75+
corsOptions := cors.Options{
76+
AllowedOrigins: a.corsOrigins,
77+
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
78+
AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token", "X-Requested-With"},
79+
ExposedHeaders: []string{"Link"},
80+
AllowCredentials: false,
81+
MaxAge: 300, // Maximum value not ignored by any of major browsers
82+
}
83+
84+
// Handle wildcard origins properly
85+
for i, origin := range corsOptions.AllowedOrigins {
86+
if origin == "*" {
87+
corsOptions.AllowedOrigins = []string{"*"}
88+
corsOptions.AllowCredentials = false
89+
break
90+
}
91+
corsOptions.AllowedOrigins[i] = strings.TrimSpace(origin)
92+
}
93+
94+
mux.Use(cors.Handler(corsOptions))
95+
}
96+
6297
config := huma.DefaultConfig("mcpd docs", cmd.Version())
6398
router := humachi.New(mux, config)
6499

@@ -85,6 +120,9 @@ func (a *ApiServer) Start(ctx context.Context) error {
85120
// Start the API.
86121
go func() {
87122
a.logger.Info("Starting API server", "address", a.addr, "prefix", apiPathPrefix)
123+
if a.enableCORS {
124+
a.logger.Info("CORS enabled", "origins", a.corsOrigins)
125+
}
88126
if err := srv.ListenAndServe(); err != nil && !stdErrors.Is(err, http.ErrServerClosed) {
89127
errCh <- err
90128
}

0 commit comments

Comments
 (0)