Skip to content

Draft: Fix race condition during CSM message exchange #611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion net/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ func TestConnWriteWithContext(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tcpConn, err := net.Dial("tcp", listener.Addr().String())
dialer := net.Dialer{}
tcpConn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String())
require.NoError(t, err)
c := NewConn(tcpConn)
defer func() {
Expand Down
19 changes: 12 additions & 7 deletions net/tlslistener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,16 @@ func TestTLSListenerAcceptWithContext(t *testing.T) {
cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
assert.NoError(t, err)

c, err := tls.DialWithDialer(&net.Dialer{
Timeout: time.Millisecond * 400,
}, "tcp", listener.Addr().String(), &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{cert},
})
d := &tls.Dialer{
NetDialer: &net.Dialer{
Timeout: time.Millisecond * 400,
},
Config: &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{cert},
},
}
c, err := d.DialContext(context.Background(), "tcp", listener.Addr().String())
if err != nil {
continue
}
Expand Down Expand Up @@ -186,7 +190,8 @@ func TestTLSListenerCheckForInfinitLoop(t *testing.T) {
cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
assert.NoError(t, err)
func() {
conn, err := net.Dial("tcp", listener.Addr().String())
dialer := net.Dialer{}
conn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String())
if err != nil {
return
}
Expand Down
16 changes: 15 additions & 1 deletion options/tcpOptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package options

import (
"crypto/tls"
"time"

tcpClient "github.com/plgd-dev/go-coap/v3/tcp/client"
tcpServer "github.com/plgd-dev/go-coap/v3/tcp/server"
Expand Down Expand Up @@ -34,11 +35,24 @@ func (o DisableTCPSignalMessageCSMOpt) TCPClientApply(cfg *tcpClient.Config) {
cfg.DisableTCPSignalMessageCSM = true
}

// WithDisableTCPSignalMessageCSM don't send CSM when client conn is created.
func WithDisableTCPSignalMessageCSM() DisableTCPSignalMessageCSMOpt {
return DisableTCPSignalMessageCSMOpt{}
}

type CSMExchangeTimeoutOpt struct {
timeout time.Duration
}

func (o CSMExchangeTimeoutOpt) TCPClientApply(cfg *tcpClient.Config) {
cfg.CSMExchangeTimeout = o.timeout
}

func WithCSMExchangeTimeout(timeout time.Duration) CSMExchangeTimeoutOpt {
return CSMExchangeTimeoutOpt{
timeout: timeout,
}
}

// TLSOpt tls configuration option.
type TLSOpt struct {
tlsCfg *tls.Config
Expand Down
39 changes: 35 additions & 4 deletions tcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
"github.com/plgd-dev/go-coap/v3/message/pool"
coapNet "github.com/plgd-dev/go-coap/v3/net"
"github.com/plgd-dev/go-coap/v3/net/blockwise"
Expand All @@ -30,19 +31,23 @@ func Dial(target string, opts ...Option) (*client.Conn, error) {
var conn net.Conn
var err error
if cfg.TLSCfg != nil {
conn, err = tls.DialWithDialer(cfg.Dialer, cfg.Net, target, cfg.TLSCfg)
d := &tls.Dialer{
NetDialer: cfg.Dialer,
Config: cfg.TLSCfg,
}
conn, err = d.DialContext(cfg.Ctx, cfg.Net, target)
} else {
conn, err = cfg.Dialer.DialContext(cfg.Ctx, cfg.Net, target)
}
if err != nil {
return nil, err
}
opts = append(opts, options.WithCloseSocket())
return Client(conn, opts...), nil
return Client(conn, opts...)
}

// Client creates client over tcp/tcp-tls connection.
func Client(conn net.Conn, opts ...Option) *client.Conn {
func Client(conn net.Conn, opts ...Option) (*client.Conn, error) {
cfg := client.DefaultConfig
for _, o := range opts {
o.TCPClientApply(&cfg)
Expand Down Expand Up @@ -100,12 +105,38 @@ func Client(conn net.Conn, opts ...Option) *client.Conn {
return cc.Context().Err() == nil
})

var csmExchangeDone chan struct{}
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
csmExchangeDone = make(chan struct{})

cc.SetTCPSignalReceivedHandler(func(code codes.Code) {
if code == codes.CSM {
close(csmExchangeDone)
}
})
}

go func() {
err := cc.Run()
if err != nil {
cfg.Errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err))
}
}()

return cc
// if CSM messages are enabled, wait for the CSM messages to be exchanged
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
select {
case <-time.After(cfg.CSMExchangeTimeout):
err := fmt.Errorf("%v: timeout waiting for CSM exchange with peer", cc.RemoteAddr())
cfg.Errors(err)
cc.Close() // Close connection on timeout
return nil, err // or return cc with an error state
case <-csmExchangeDone:
// CSM exchange completed successfully
}
// Clear the handler after exchange is complete or timed out
cc.SetTCPSignalReceivedHandler(nil)
}
Comment on lines +108 to +139
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix potential panic on close, ensure handler cleanup on all paths, and avoid timer leaks

Issues:

  • Closing csmExchangeDone without guarding can panic if multiple CSM signals arrive (Line 114).
  • Handler is not cleared on the timeout return path; defer the cleanup to cover all exits (Lines 137-139).
  • time.After allocates a timer that can linger; prefer time.NewTimer and stop it (Lines 129-136).
  • Consider exiting early if the connection closes before the exchange completes.

Proposed fix:

-	var csmExchangeDone chan struct{}
-	if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
-		csmExchangeDone = make(chan struct{})
-
-		cc.SetTCPSignalReceivedHandler(func(code codes.Code) {
-			if code == codes.CSM {
-				close(csmExchangeDone)
-			}
-		})
-	}
+	var csmExchangeDone chan struct{}
+	var csmOnce sync.Once
+	if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
+		csmExchangeDone = make(chan struct{})
+		cc.SetTCPSignalReceivedHandler(func(code codes.Code) {
+			if code == codes.CSM {
+				csmOnce.Do(func() { close(csmExchangeDone) })
+			}
+		})
+		// Ensure handler is always cleared, regardless of outcome.
+		defer cc.SetTCPSignalReceivedHandler(nil)
+	}
@@
-	// if CSM messages are enabled, wait for the CSM messages to be exchanged
-	if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
-		select {
-		case <-time.After(cfg.CSMExchangeTimeout):
-			err := fmt.Errorf("%v: timeout waiting for CSM exchange with peer", cc.RemoteAddr())
-			cfg.Errors(err)
-			cc.Close()      // Close connection on timeout
-			return nil, err // or return cc with an error state
-		case <-csmExchangeDone:
-			// CSM exchange completed successfully
-		}
-		// Clear the handler after exchange is complete or timed out
-		cc.SetTCPSignalReceivedHandler(nil)
-	}
+	// If enabled, wait for the CSM exchange to complete or time out.
+	if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
+		timer := time.NewTimer(cfg.CSMExchangeTimeout)
+		defer timer.Stop()
+		select {
+		case <-timer.C:
+			err := fmt.Errorf("%v: timeout waiting for CSM exchange with peer", cc.RemoteAddr())
+			cfg.Errors(err)
+			_ = cc.Close()
+			return nil, err
+		case <-cc.Done():
+			return nil, fmt.Errorf("%v: connection closed before CSM exchange: %w", cc.RemoteAddr(), cc.Context().Err())
+		case <-csmExchangeDone:
+			// CSM exchange completed successfully
+		}
+	}

And add the import:

 import (
   "crypto/tls"
   "fmt"
   "net"
   "time"
+  "sync"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
var csmExchangeDone chan struct{}
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
csmExchangeDone = make(chan struct{})
cc.SetTCPSignalReceivedHandler(func(code codes.Code) {
if code == codes.CSM {
close(csmExchangeDone)
}
})
}
go func() {
err := cc.Run()
if err != nil {
cfg.Errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err))
}
}()
return cc
// if CSM messages are enabled, wait for the CSM messages to be exchanged
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
select {
case <-time.After(cfg.CSMExchangeTimeout):
err := fmt.Errorf("%v: timeout waiting for CSM exchange with peer", cc.RemoteAddr())
cfg.Errors(err)
cc.Close() // Close connection on timeout
return nil, err // or return cc with an error state
case <-csmExchangeDone:
// CSM exchange completed successfully
}
// Clear the handler after exchange is complete or timed out
cc.SetTCPSignalReceivedHandler(nil)
}
var csmExchangeDone chan struct{}
var csmOnce sync.Once
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
csmExchangeDone = make(chan struct{})
cc.SetTCPSignalReceivedHandler(func(code codes.Code) {
if code == codes.CSM {
csmOnce.Do(func() { close(csmExchangeDone) })
}
})
// Ensure handler is always cleared, regardless of outcome.
defer cc.SetTCPSignalReceivedHandler(nil)
}
go func() {
err := cc.Run()
if err != nil {
cfg.Errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err))
}
}()
// If enabled, wait for the CSM exchange to complete or time out.
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
timer := time.NewTimer(cfg.CSMExchangeTimeout)
defer timer.Stop()
select {
case <-timer.C:
err := fmt.Errorf("%v: timeout waiting for CSM exchange with peer", cc.RemoteAddr())
cfg.Errors(err)
_ = cc.Close()
return nil, err
case <-cc.Done():
return nil, fmt.Errorf("%v: connection closed before CSM exchange: %w", cc.RemoteAddr(), cc.Context().Err())
case <-csmExchangeDone:
// CSM exchange completed successfully
}
}


return cc, nil
}
1 change: 1 addition & 0 deletions tcp/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@ type Config struct {
DisablePeerTCPSignalMessageCSMs bool
CloseSocket bool
DisableTCPSignalMessageCSM bool
CSMExchangeTimeout time.Duration
}
40 changes: 35 additions & 5 deletions tcp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"sync"
"time"

"github.com/plgd-dev/go-coap/v3/message"
Expand All @@ -27,6 +28,8 @@ type InactivityMonitor interface {
CheckInactivity(now time.Time, cc *Conn)
}

type TCPSignalReceivedHandler func(codes.Code)

type (
HandlerFunc = func(*responsewriter.ResponseWriter[*Conn], *pool.Message)
ErrorFunc = func(error)
Expand All @@ -51,6 +54,8 @@ type Conn struct {
blockwiseSZX blockwise.SZX
peerMaxMessageSize atomic.Uint32
disablePeerTCPSignalMessageCSMs bool
tcpSignalReceivedHandler TCPSignalReceivedHandler
handlerMutex sync.RWMutex
peerBlockWiseTranferEnabled atomic.Bool

receivedMessageReader *client.ReceivedMessageReader[*Conn]
Expand Down Expand Up @@ -267,6 +272,12 @@ func (cc *Conn) Run() (err error) {
return cc.session.Run(cc)
}

func (cc *Conn) SetTCPSignalReceivedHandler(handler TCPSignalReceivedHandler) {
cc.handlerMutex.Lock()
defer cc.handlerMutex.Unlock()
cc.tcpSignalReceivedHandler = handler
}

// AddOnClose calls function on close connection event.
func (cc *Conn) AddOnClose(f EventFunc) {
cc.session.AddOnClose(f)
Expand Down Expand Up @@ -370,6 +381,14 @@ func (cc *Conn) sendPong(token message.Token) error {
return cc.Session().WriteMessage(req)
}

func (cc *Conn) handleTCPSignalReceived(code codes.Code) {
cc.handlerMutex.RLock()
defer cc.handlerMutex.RUnlock()
if cc.tcpSignalReceivedHandler != nil {
cc.tcpSignalReceivedHandler(code)
}
}

func (cc *Conn) handleSignals(r *pool.Message) bool {
switch r.Code() {
case codes.CSM:
Expand All @@ -382,6 +401,9 @@ func (cc *Conn) handleSignals(r *pool.Message) bool {
if r.HasOption(message.TCPBlockWiseTransfer) {
cc.peerBlockWiseTranferEnabled.Store(true)
}

// signal CSM message is received.
cc.handleTCPSignalReceived(codes.CSM)
return true
case codes.Ping:
// if r.HasOption(message.TCPCustody) {
Expand All @@ -390,21 +412,29 @@ func (cc *Conn) handleSignals(r *pool.Message) bool {
if err := cc.sendPong(r.Token()); err != nil && !coapNet.IsConnectionBrokenError(err) {
cc.Session().errors(fmt.Errorf("cannot handle ping signal: %w", err))
}

cc.handleTCPSignalReceived(codes.Ping)
return true
case codes.Pong:
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
cc.processReceivedMessage(r, cc, h)
}

cc.handleTCPSignalReceived(codes.Pong)
return true
case codes.Release:
// if r.HasOption(message.TCPAlternativeAddress) {
// TODO
// }

cc.handleTCPSignalReceived(codes.Release)
return true
case codes.Abort:
// if r.HasOption(message.TCPBadCSMOption) {
// TODO
// }
return true
case codes.Pong:
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
cc.processReceivedMessage(r, cc, h)
}

cc.handleTCPSignalReceived(codes.Abort)
return true
}
return false
Expand Down
77 changes: 77 additions & 0 deletions tcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"github.com/plgd-dev/go-coap/v3/options/config"
"github.com/plgd-dev/go-coap/v3/pkg/runner/periodic"
"github.com/plgd-dev/go-coap/v3/tcp/client"
"github.com/plgd-dev/go-coap/v3/tcp/server"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
Expand Down Expand Up @@ -839,3 +840,79 @@
require.Error(t, err)
require.ErrorIs(t, err, context.DeadlineExceeded)
}

func TestConnWithCSMExchangeTimeout(t *testing.T) {

Check failure on line 844 in tcp/client_test.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary leading newline (whitespace)

Check failure on line 845 in tcp/client_test.go

View workflow job for this annotation

GitHub Actions / lint

File is not properly formatted (gofumpt)
type args struct {
clientOptions []Option
serverOptions []server.Option
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "client-server-no-csm",
args: args{
clientOptions: []Option{},
serverOptions: []server.Option{},
},
wantErr: false,
},
{
name: "client-server-csm-success",
args: args{
clientOptions: []Option{
options.WithCSMExchangeTimeout(time.Second * 3),
},
},
wantErr: false,
},
{
name: "client-server-csm-timeout",
args: args{
clientOptions: []Option{
options.WithCSMExchangeTimeout(time.Second * 3),
},
serverOptions: []server.Option{
options.WithDisableTCPSignalMessageCSM(),
},
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

Check failure on line 887 in tcp/client_test.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary leading newline (whitespace)

l, err := coapNet.NewTCPListener("tcp", "")
require.NoError(t, err)
defer func() {
errC := l.Close()
require.NoError(t, errC)
}()
var wg sync.WaitGroup
defer wg.Wait()

s := NewServer(tt.args.serverOptions...)
defer s.Stop()
wg.Add(1)
go func() {
defer wg.Done()
errS := s.Serve(l)
assert.NoError(t, errS)
}()

client, err := Dial(l.Addr().String(),
tt.args.clientOptions...)
if tt.wantErr {
require.Nil(t, client)
require.Error(t, err)
} else {
require.NotNil(t, client)
require.NoError(t, err)
}
})
Comment on lines +905 to +914
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Prevent resource leaks and avoid package shadowing

  • The variable name client shadows the imported client package; prefer cc for consistency with the rest of this file.
  • On the success path the connection is never closed; this can leak goroutines and sockets. Close and wait for Done as in other tests.
-			client, err := Dial(l.Addr().String(),
-				tt.args.clientOptions...)
-			if tt.wantErr {
-				require.Nil(t, client)
-				require.Error(t, err)
-			} else {
-				require.NotNil(t, client)
-				require.NoError(t, err)
-			}
+			cc, err := Dial(l.Addr().String(), tt.args.clientOptions...)
+			if tt.wantErr {
+				require.Nil(t, cc)
+				require.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+			require.NotNil(t, cc)
+			defer func() {
+				_ = cc.Close()
+				<-cc.Done()
+			}()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
client, err := Dial(l.Addr().String(),
tt.args.clientOptions...)
if tt.wantErr {
require.Nil(t, client)
require.Error(t, err)
} else {
require.NotNil(t, client)
require.NoError(t, err)
}
})
cc, err := Dial(l.Addr().String(), tt.args.clientOptions...)
if tt.wantErr {
require.Nil(t, cc)
require.Error(t, err)
return
}
require.NoError(t, err)
require.NotNil(t, cc)
defer func() {
_ = cc.Close()
<-cc.Done()
}()
🤖 Prompt for AI Agents
In tcp/client_test.go around lines 907 to 916, the test shadows the imported
client package by naming the variable "client" and on the success path never
closes the connection which can leak goroutines/sockets; rename the local
variable to "cc" to match the rest of the file and, in the non-error branch,
close the connection and wait for its Done signal (same pattern used in other
tests) before returning from the subtest so resources are released.

}
}
3 changes: 2 additions & 1 deletion tcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ func TestServerKeepAliveMonitor(t *testing.T) {
assert.NoError(t, errS)
}()

cc, err := net.Dial("tcp", ld.Addr().String())
dialer := net.Dialer{}
cc, err := dialer.DialContext(context.Background(), "tcp", ld.Addr().String())
require.NoError(t, err)
defer func() {
_ = cc.Close()
Expand Down
Loading