diff --git a/net/conn_test.go b/net/conn_test.go index 93e7d476..85ab479e 100644 --- a/net/conn_test.go +++ b/net/conn_test.go @@ -65,7 +65,8 @@ func TestConnWriteWithContext(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tcpConn, err := net.Dial("tcp", listener.Addr().String()) + dialer := net.Dialer{} + tcpConn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String()) require.NoError(t, err) c := NewConn(tcpConn) defer func() { diff --git a/net/tlslistener_test.go b/net/tlslistener_test.go index c36ea990..bd25554c 100644 --- a/net/tlslistener_test.go +++ b/net/tlslistener_test.go @@ -100,12 +100,16 @@ func TestTLSListenerAcceptWithContext(t *testing.T) { cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) assert.NoError(t, err) - c, err := tls.DialWithDialer(&net.Dialer{ - Timeout: time.Millisecond * 400, - }, "tcp", listener.Addr().String(), &tls.Config{ - InsecureSkipVerify: true, - Certificates: []tls.Certificate{cert}, - }) + d := &tls.Dialer{ + NetDialer: &net.Dialer{ + Timeout: time.Millisecond * 400, + }, + Config: &tls.Config{ + InsecureSkipVerify: true, + Certificates: []tls.Certificate{cert}, + }, + } + c, err := d.DialContext(context.Background(), "tcp", listener.Addr().String()) if err != nil { continue } @@ -186,7 +190,8 @@ func TestTLSListenerCheckForInfinitLoop(t *testing.T) { cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock) assert.NoError(t, err) func() { - conn, err := net.Dial("tcp", listener.Addr().String()) + dialer := net.Dialer{} + conn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String()) if err != nil { return } diff --git a/options/tcpOptions.go b/options/tcpOptions.go index ec2f9fb6..9e8cf5a9 100644 --- a/options/tcpOptions.go +++ b/options/tcpOptions.go @@ -2,6 +2,7 @@ package options import ( "crypto/tls" + "time" tcpClient "github.com/plgd-dev/go-coap/v3/tcp/client" tcpServer "github.com/plgd-dev/go-coap/v3/tcp/server" @@ -34,11 +35,24 @@ func (o DisableTCPSignalMessageCSMOpt) TCPClientApply(cfg *tcpClient.Config) { cfg.DisableTCPSignalMessageCSM = true } -// WithDisableTCPSignalMessageCSM don't send CSM when client conn is created. func WithDisableTCPSignalMessageCSM() DisableTCPSignalMessageCSMOpt { return DisableTCPSignalMessageCSMOpt{} } +type CSMExchangeTimeoutOpt struct { + timeout time.Duration +} + +func (o CSMExchangeTimeoutOpt) TCPClientApply(cfg *tcpClient.Config) { + cfg.CSMExchangeTimeout = o.timeout +} + +func WithCSMExchangeTimeout(timeout time.Duration) CSMExchangeTimeoutOpt { + return CSMExchangeTimeoutOpt{ + timeout: timeout, + } +} + // TLSOpt tls configuration option. type TLSOpt struct { tlsCfg *tls.Config diff --git a/tcp/client.go b/tcp/client.go index 54349433..69bbc56c 100644 --- a/tcp/client.go +++ b/tcp/client.go @@ -7,6 +7,7 @@ import ( "time" "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/codes" "github.com/plgd-dev/go-coap/v3/message/pool" coapNet "github.com/plgd-dev/go-coap/v3/net" "github.com/plgd-dev/go-coap/v3/net/blockwise" @@ -30,7 +31,11 @@ func Dial(target string, opts ...Option) (*client.Conn, error) { var conn net.Conn var err error if cfg.TLSCfg != nil { - conn, err = tls.DialWithDialer(cfg.Dialer, cfg.Net, target, cfg.TLSCfg) + d := &tls.Dialer{ + NetDialer: cfg.Dialer, + Config: cfg.TLSCfg, + } + conn, err = d.DialContext(cfg.Ctx, cfg.Net, target) } else { conn, err = cfg.Dialer.DialContext(cfg.Ctx, cfg.Net, target) } @@ -38,11 +43,11 @@ func Dial(target string, opts ...Option) (*client.Conn, error) { return nil, err } opts = append(opts, options.WithCloseSocket()) - return Client(conn, opts...), nil + return Client(conn, opts...) } // Client creates client over tcp/tcp-tls connection. -func Client(conn net.Conn, opts ...Option) *client.Conn { +func Client(conn net.Conn, opts ...Option) (*client.Conn, error) { cfg := client.DefaultConfig for _, o := range opts { o.TCPClientApply(&cfg) @@ -100,6 +105,17 @@ func Client(conn net.Conn, opts ...Option) *client.Conn { return cc.Context().Err() == nil }) + var csmExchangeDone chan struct{} + if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs { + csmExchangeDone = make(chan struct{}) + + cc.SetTCPSignalReceivedHandler(func(code codes.Code) { + if code == codes.CSM { + close(csmExchangeDone) + } + }) + } + go func() { err := cc.Run() if err != nil { @@ -107,5 +123,20 @@ func Client(conn net.Conn, opts ...Option) *client.Conn { } }() - return cc + // if CSM messages are enabled, wait for the CSM messages to be exchanged + if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs { + select { + case <-time.After(cfg.CSMExchangeTimeout): + err := fmt.Errorf("%v: timeout waiting for CSM exchange with peer", cc.RemoteAddr()) + cfg.Errors(err) + cc.Close() // Close connection on timeout + return nil, err // or return cc with an error state + case <-csmExchangeDone: + // CSM exchange completed successfully + } + // Clear the handler after exchange is complete or timed out + cc.SetTCPSignalReceivedHandler(nil) + } + + return cc, nil } diff --git a/tcp/client/config.go b/tcp/client/config.go index b3c71857..0eda7c9e 100644 --- a/tcp/client/config.go +++ b/tcp/client/config.go @@ -50,4 +50,5 @@ type Config struct { DisablePeerTCPSignalMessageCSMs bool CloseSocket bool DisableTCPSignalMessageCSM bool + CSMExchangeTimeout time.Duration } diff --git a/tcp/client/conn.go b/tcp/client/conn.go index 018ead4d..1f8a0342 100644 --- a/tcp/client/conn.go +++ b/tcp/client/conn.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "sync" "time" "github.com/plgd-dev/go-coap/v3/message" @@ -27,6 +28,8 @@ type InactivityMonitor interface { CheckInactivity(now time.Time, cc *Conn) } +type TCPSignalReceivedHandler func(codes.Code) + type ( HandlerFunc = func(*responsewriter.ResponseWriter[*Conn], *pool.Message) ErrorFunc = func(error) @@ -51,6 +54,8 @@ type Conn struct { blockwiseSZX blockwise.SZX peerMaxMessageSize atomic.Uint32 disablePeerTCPSignalMessageCSMs bool + tcpSignalReceivedHandler TCPSignalReceivedHandler + handlerMutex sync.RWMutex peerBlockWiseTranferEnabled atomic.Bool receivedMessageReader *client.ReceivedMessageReader[*Conn] @@ -267,6 +272,12 @@ func (cc *Conn) Run() (err error) { return cc.session.Run(cc) } +func (cc *Conn) SetTCPSignalReceivedHandler(handler TCPSignalReceivedHandler) { + cc.handlerMutex.Lock() + defer cc.handlerMutex.Unlock() + cc.tcpSignalReceivedHandler = handler +} + // AddOnClose calls function on close connection event. func (cc *Conn) AddOnClose(f EventFunc) { cc.session.AddOnClose(f) @@ -370,6 +381,14 @@ func (cc *Conn) sendPong(token message.Token) error { return cc.Session().WriteMessage(req) } +func (cc *Conn) handleTCPSignalReceived(code codes.Code) { + cc.handlerMutex.RLock() + defer cc.handlerMutex.RUnlock() + if cc.tcpSignalReceivedHandler != nil { + cc.tcpSignalReceivedHandler(code) + } +} + func (cc *Conn) handleSignals(r *pool.Message) bool { switch r.Code() { case codes.CSM: @@ -382,6 +401,9 @@ func (cc *Conn) handleSignals(r *pool.Message) bool { if r.HasOption(message.TCPBlockWiseTransfer) { cc.peerBlockWiseTranferEnabled.Store(true) } + + // signal CSM message is received. + cc.handleTCPSignalReceived(codes.CSM) return true case codes.Ping: // if r.HasOption(message.TCPCustody) { @@ -390,21 +412,29 @@ func (cc *Conn) handleSignals(r *pool.Message) bool { if err := cc.sendPong(r.Token()); err != nil && !coapNet.IsConnectionBrokenError(err) { cc.Session().errors(fmt.Errorf("cannot handle ping signal: %w", err)) } + + cc.handleTCPSignalReceived(codes.Ping) + return true + case codes.Pong: + if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok { + cc.processReceivedMessage(r, cc, h) + } + + cc.handleTCPSignalReceived(codes.Pong) return true case codes.Release: // if r.HasOption(message.TCPAlternativeAddress) { // TODO // } + + cc.handleTCPSignalReceived(codes.Release) return true case codes.Abort: // if r.HasOption(message.TCPBadCSMOption) { // TODO // } - return true - case codes.Pong: - if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok { - cc.processReceivedMessage(r, cc, h) - } + + cc.handleTCPSignalReceived(codes.Abort) return true } return false diff --git a/tcp/client_test.go b/tcp/client_test.go index 5c4a1198..4cddc71e 100644 --- a/tcp/client_test.go +++ b/tcp/client_test.go @@ -18,6 +18,7 @@ import ( "github.com/plgd-dev/go-coap/v3/options/config" "github.com/plgd-dev/go-coap/v3/pkg/runner/periodic" "github.com/plgd-dev/go-coap/v3/tcp/client" + "github.com/plgd-dev/go-coap/v3/tcp/server" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" @@ -839,3 +840,77 @@ func TestConnRequestMonitorDropRequest(t *testing.T) { require.Error(t, err) require.ErrorIs(t, err, context.DeadlineExceeded) } + +func TestConnWithCSMExchangeTimeout(t *testing.T) { + type args struct { + clientOptions []Option + serverOptions []server.Option + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "client-server-no-csm", + args: args{ + clientOptions: []Option{}, + serverOptions: []server.Option{}, + }, + wantErr: false, + }, + { + name: "client-server-csm-success", + args: args{ + clientOptions: []Option{ + options.WithCSMExchangeTimeout(time.Second * 3), + }, + }, + wantErr: false, + }, + { + name: "client-server-csm-timeout", + args: args{ + clientOptions: []Option{ + options.WithCSMExchangeTimeout(time.Second * 3), + }, + serverOptions: []server.Option{ + options.WithDisableTCPSignalMessageCSM(), + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l, err := coapNet.NewTCPListener("tcp", "") + require.NoError(t, err) + defer func() { + errC := l.Close() + require.NoError(t, errC) + }() + var wg sync.WaitGroup + defer wg.Wait() + + s := NewServer(tt.args.serverOptions...) + defer s.Stop() + wg.Add(1) + go func() { + defer wg.Done() + errS := s.Serve(l) + assert.NoError(t, errS) + }() + + client, err := Dial(l.Addr().String(), + tt.args.clientOptions...) + if tt.wantErr { + require.Nil(t, client) + require.Error(t, err) + } else { + require.NotNil(t, client) + require.NoError(t, err) + } + }) + } +} diff --git a/tcp/server_test.go b/tcp/server_test.go index 124bd992..a7ed20e0 100644 --- a/tcp/server_test.go +++ b/tcp/server_test.go @@ -301,7 +301,8 @@ func TestServerKeepAliveMonitor(t *testing.T) { assert.NoError(t, errS) }() - cc, err := net.Dial("tcp", ld.Addr().String()) + dialer := net.Dialer{} + cc, err := dialer.DialContext(context.Background(), "tcp", ld.Addr().String()) require.NoError(t, err) defer func() { _ = cc.Close()