Skip to content

Draft: Fix race condition during CSM message exchange #611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion net/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ func TestConnWriteWithContext(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tcpConn, err := net.Dial("tcp", listener.Addr().String())
dialer := net.Dialer{}
tcpConn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String())
require.NoError(t, err)
c := NewConn(tcpConn)
defer func() {
Expand Down
19 changes: 12 additions & 7 deletions net/tlslistener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,16 @@ func TestTLSListenerAcceptWithContext(t *testing.T) {
cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
assert.NoError(t, err)

c, err := tls.DialWithDialer(&net.Dialer{
Timeout: time.Millisecond * 400,
}, "tcp", listener.Addr().String(), &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{cert},
})
d := &tls.Dialer{
NetDialer: &net.Dialer{
Timeout: time.Millisecond * 400,
},
Config: &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{cert},
},
}
c, err := d.DialContext(context.Background(), "tcp", listener.Addr().String())
if err != nil {
continue
}
Expand Down Expand Up @@ -186,7 +190,8 @@ func TestTLSListenerCheckForInfinitLoop(t *testing.T) {
cert, err := tls.X509KeyPair(CertPEMBlock, KeyPEMBlock)
assert.NoError(t, err)
func() {
conn, err := net.Dial("tcp", listener.Addr().String())
dialer := net.Dialer{}
conn, err := dialer.DialContext(context.Background(), "tcp", listener.Addr().String())
if err != nil {
return
}
Expand Down
16 changes: 15 additions & 1 deletion options/tcpOptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package options

import (
"crypto/tls"
"time"

tcpClient "github.com/plgd-dev/go-coap/v3/tcp/client"
tcpServer "github.com/plgd-dev/go-coap/v3/tcp/server"
Expand Down Expand Up @@ -34,11 +35,24 @@ func (o DisableTCPSignalMessageCSMOpt) TCPClientApply(cfg *tcpClient.Config) {
cfg.DisableTCPSignalMessageCSM = true
}

// WithDisableTCPSignalMessageCSM don't send CSM when client conn is created.
func WithDisableTCPSignalMessageCSM() DisableTCPSignalMessageCSMOpt {
return DisableTCPSignalMessageCSMOpt{}
}

type CSMExchangeTimeoutOpt struct {
timeout time.Duration
}

func (o CSMExchangeTimeoutOpt) TCPClientApply(cfg *tcpClient.Config) {
cfg.CSMExchangeTimeout = o.timeout
}

func WithCSMExchangeTimeout(timeout time.Duration) CSMExchangeTimeoutOpt {
return CSMExchangeTimeoutOpt{
timeout: timeout,
}
}

// TLSOpt tls configuration option.
type TLSOpt struct {
tlsCfg *tls.Config
Expand Down
34 changes: 33 additions & 1 deletion tcp/client.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package tcp

import (
"context"
"crypto/tls"
"fmt"
"net"
"time"

"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
"github.com/plgd-dev/go-coap/v3/message/pool"
coapNet "github.com/plgd-dev/go-coap/v3/net"
"github.com/plgd-dev/go-coap/v3/net/blockwise"
Expand All @@ -30,7 +32,11 @@ func Dial(target string, opts ...Option) (*client.Conn, error) {
var conn net.Conn
var err error
if cfg.TLSCfg != nil {
conn, err = tls.DialWithDialer(cfg.Dialer, cfg.Net, target, cfg.TLSCfg)
d := &tls.Dialer{
NetDialer: cfg.Dialer,
Config: cfg.TLSCfg,
}
conn, err = d.DialContext(context.Background(), cfg.Net, target)
} else {
conn, err = cfg.Dialer.DialContext(cfg.Ctx, cfg.Net, target)
}
Expand Down Expand Up @@ -100,12 +106,38 @@ func Client(conn net.Conn, opts ...Option) *client.Conn {
return cc.Context().Err() == nil
})

var csmExchangeDone chan struct{}
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
csmExchangeDone = make(chan struct{})

cc.SetTCPSignalReceivedHandler(func(code codes.Code) {
if code == codes.CSM {
close(csmExchangeDone)
}
})
}

go func() {
err := cc.Run()
if err != nil {
cfg.Errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err))
}
}()

// if CSM messages are enabled, wait for the CSM messages to be exchanged
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
select {
case <-time.After(cfg.CSMExchangeTimeout):
err := fmt.Errorf("%v: timeout waiting for CSM exchange with peer", cc.RemoteAddr())
cfg.Errors(err)
cc.Close() // Close connection on timeout
return nil // or return cc with an error state
case <-csmExchangeDone:
// CSM exchange completed successfully
}
// Clear the handler after exchange is complete or timed out
cc.SetTCPSignalReceivedHandler(nil)
}

return cc
}
1 change: 1 addition & 0 deletions tcp/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@ type Config struct {
DisablePeerTCPSignalMessageCSMs bool
CloseSocket bool
DisableTCPSignalMessageCSM bool
CSMExchangeTimeout time.Duration
}
40 changes: 35 additions & 5 deletions tcp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"sync"
"time"

"github.com/plgd-dev/go-coap/v3/message"
Expand All @@ -27,6 +28,8 @@ type InactivityMonitor interface {
CheckInactivity(now time.Time, cc *Conn)
}

type TCPSignalReceivedHandler func(codes.Code)

type (
HandlerFunc = func(*responsewriter.ResponseWriter[*Conn], *pool.Message)
ErrorFunc = func(error)
Expand All @@ -51,6 +54,8 @@ type Conn struct {
blockwiseSZX blockwise.SZX
peerMaxMessageSize atomic.Uint32
disablePeerTCPSignalMessageCSMs bool
tcpSignalReceivedHandler TCPSignalReceivedHandler
handlerMutex sync.RWMutex
peerBlockWiseTranferEnabled atomic.Bool

receivedMessageReader *client.ReceivedMessageReader[*Conn]
Expand Down Expand Up @@ -267,6 +272,12 @@ func (cc *Conn) Run() (err error) {
return cc.session.Run(cc)
}

func (cc *Conn) SetTCPSignalReceivedHandler(handler TCPSignalReceivedHandler) {
cc.handlerMutex.Lock()
defer cc.handlerMutex.Unlock()
cc.tcpSignalReceivedHandler = handler
}

// AddOnClose calls function on close connection event.
func (cc *Conn) AddOnClose(f EventFunc) {
cc.session.AddOnClose(f)
Expand Down Expand Up @@ -370,6 +381,14 @@ func (cc *Conn) sendPong(token message.Token) error {
return cc.Session().WriteMessage(req)
}

func (cc *Conn) handleTCPSignalReceived(code codes.Code) {
cc.handlerMutex.RLock()
defer cc.handlerMutex.RUnlock()
if cc.tcpSignalReceivedHandler != nil {
cc.tcpSignalReceivedHandler(code)
}
}

func (cc *Conn) handleSignals(r *pool.Message) bool {
switch r.Code() {
case codes.CSM:
Expand All @@ -382,6 +401,9 @@ func (cc *Conn) handleSignals(r *pool.Message) bool {
if r.HasOption(message.TCPBlockWiseTransfer) {
cc.peerBlockWiseTranferEnabled.Store(true)
}

// signal CSM message is received.
cc.handleTCPSignalReceived(codes.CSM)
return true
case codes.Ping:
// if r.HasOption(message.TCPCustody) {
Expand All @@ -390,21 +412,29 @@ func (cc *Conn) handleSignals(r *pool.Message) bool {
if err := cc.sendPong(r.Token()); err != nil && !coapNet.IsConnectionBrokenError(err) {
cc.Session().errors(fmt.Errorf("cannot handle ping signal: %w", err))
}

cc.handleTCPSignalReceived(codes.Ping)
return true
case codes.Pong:
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
cc.processReceivedMessage(r, cc, h)
}

cc.handleTCPSignalReceived(codes.Pong)
return true
case codes.Release:
// if r.HasOption(message.TCPAlternativeAddress) {
// TODO
// }

cc.handleTCPSignalReceived(codes.Release)
return true
case codes.Abort:
// if r.HasOption(message.TCPBadCSMOption) {
// TODO
// }
return true
case codes.Pong:
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
cc.processReceivedMessage(r, cc, h)
}

cc.handleTCPSignalReceived(codes.Abort)
return true
}
return false
Expand Down
3 changes: 2 additions & 1 deletion tcp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ func TestServerKeepAliveMonitor(t *testing.T) {
assert.NoError(t, errS)
}()

cc, err := net.Dial("tcp", ld.Addr().String())
dialer := net.Dialer{}
cc, err := dialer.DialContext(context.Background(), "tcp", ld.Addr().String())
require.NoError(t, err)
defer func() {
_ = cc.Close()
Expand Down
Loading