Skip to content

Draft: Fix race condition during CSM message exchange #611

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion options/tcpOptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package options

import (
"crypto/tls"
"time"

tcpClient "github.com/plgd-dev/go-coap/v3/tcp/client"
tcpServer "github.com/plgd-dev/go-coap/v3/tcp/server"
Expand Down Expand Up @@ -34,11 +35,24 @@ func (o DisableTCPSignalMessageCSMOpt) TCPClientApply(cfg *tcpClient.Config) {
cfg.DisableTCPSignalMessageCSM = true
}

// WithDisableTCPSignalMessageCSM don't send CSM when client conn is created.
func WithDisableTCPSignalMessageCSM() DisableTCPSignalMessageCSMOpt {
return DisableTCPSignalMessageCSMOpt{}
}

type CSMExchangeTimeoutOpt struct {
timeout time.Duration
}

func (o CSMExchangeTimeoutOpt) TCPClientApply(cfg *tcpClient.Config) {
cfg.CSMExchangeTimeout = o.timeout
}

func WithCSMExchangeTimeout(timeout time.Duration) CSMExchangeTimeoutOpt {
return CSMExchangeTimeoutOpt{
timeout: timeout,
}
}

// TLSOpt tls configuration option.
type TLSOpt struct {
tlsCfg *tls.Config
Expand Down
21 changes: 21 additions & 0 deletions tcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"time"

"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
"github.com/plgd-dev/go-coap/v3/message/pool"
coapNet "github.com/plgd-dev/go-coap/v3/net"
"github.com/plgd-dev/go-coap/v3/net/blockwise"
Expand All @@ -30,7 +31,7 @@
var conn net.Conn
var err error
if cfg.TLSCfg != nil {
conn, err = tls.DialWithDialer(cfg.Dialer, cfg.Net, target, cfg.TLSCfg)

Check failure on line 34 in tcp/client.go

View workflow job for this annotation

GitHub Actions / lint

crypto/tls.DialWithDialer must not be called. use (*crypto/tls.Dialer).DialContext with NetDialer (noctx)
} else {
conn, err = cfg.Dialer.DialContext(cfg.Ctx, cfg.Net, target)
}
Expand Down Expand Up @@ -100,12 +101,32 @@
return cc.Context().Err() == nil
})

var csmExchangeDone chan struct{}
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
csmExchangeDone = make(chan struct{})

cc.SetTCPSignalReceivedHandler(func(code codes.Code) {
if code == codes.CSM {
close(csmExchangeDone)
}
})
}

go func() {
err := cc.Run()
if err != nil {
cfg.Errors(fmt.Errorf("%v: %w", cc.RemoteAddr(), err))
}
}()

// if CSM messages are enabled, wait for the CSM messages to be exchanged
if cfg.CSMExchangeTimeout != 0 && !cfg.DisablePeerTCPSignalMessageCSMs {
select {
case <-time.After(cfg.CSMExchangeTimeout):
cfg.Errors(fmt.Errorf("%v: timeout waiting for tcp signal csm exchange", cc.RemoteAddr()))
case <-csmExchangeDone:
}
}

return cc
}
1 change: 1 addition & 0 deletions tcp/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,5 @@ type Config struct {
DisablePeerTCPSignalMessageCSMs bool
CloseSocket bool
DisableTCPSignalMessageCSM bool
CSMExchangeTimeout time.Duration
}
40 changes: 35 additions & 5 deletions tcp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"errors"
"fmt"
"net"
"sync"
"time"

"github.com/plgd-dev/go-coap/v3/message"
Expand All @@ -27,6 +28,8 @@
CheckInactivity(now time.Time, cc *Conn)
}

type TCPSignalReceivedHandler func(codes.Code)

type (
HandlerFunc = func(*responsewriter.ResponseWriter[*Conn], *pool.Message)
ErrorFunc = func(error)
Expand All @@ -51,6 +54,8 @@
blockwiseSZX blockwise.SZX
peerMaxMessageSize atomic.Uint32
disablePeerTCPSignalMessageCSMs bool
tcpSignalReceivedHandler TCPSignalReceivedHandler
handlerMutex sync.RWMutex
peerBlockWiseTranferEnabled atomic.Bool

receivedMessageReader *client.ReceivedMessageReader[*Conn]
Expand Down Expand Up @@ -267,6 +272,12 @@
return cc.session.Run(cc)
}

func (cc *Conn) SetTCPSignalReceivedHandler(handler TCPSignalReceivedHandler) {
cc.handlerMutex.Lock()
defer cc.handlerMutex.Unlock()
cc.tcpSignalReceivedHandler = handler
}

// AddOnClose calls function on close connection event.
func (cc *Conn) AddOnClose(f EventFunc) {
cc.session.AddOnClose(f)
Expand Down Expand Up @@ -370,6 +381,14 @@
return cc.Session().WriteMessage(req)
}

func (cc *Conn) handleTcpSignalReceived(code codes.Code) {

Check failure on line 384 in tcp/client/conn.go

View workflow job for this annotation

GitHub Actions / lint

var-naming: method handleTcpSignalReceived should be handleTCPSignalReceived (revive)
cc.handlerMutex.RLock()
defer cc.handlerMutex.RUnlock()
if cc.tcpSignalReceivedHandler != nil {
cc.tcpSignalReceivedHandler(code)
}
}

func (cc *Conn) handleSignals(r *pool.Message) bool {
switch r.Code() {
case codes.CSM:
Expand All @@ -382,6 +401,9 @@
if r.HasOption(message.TCPBlockWiseTransfer) {
cc.peerBlockWiseTranferEnabled.Store(true)
}

// signal CSM message is received.
cc.handleTcpSignalReceived(codes.CSM)
return true
case codes.Ping:
// if r.HasOption(message.TCPCustody) {
Expand All @@ -390,21 +412,29 @@
if err := cc.sendPong(r.Token()); err != nil && !coapNet.IsConnectionBrokenError(err) {
cc.Session().errors(fmt.Errorf("cannot handle ping signal: %w", err))
}

cc.handleTcpSignalReceived(codes.Ping)
return true
case codes.Pong:
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
cc.processReceivedMessage(r, cc, h)
}

cc.handleTcpSignalReceived(codes.Pong)
return true
case codes.Release:
// if r.HasOption(message.TCPAlternativeAddress) {
// TODO
// }

cc.handleTcpSignalReceived(codes.Release)
return true
case codes.Abort:
// if r.HasOption(message.TCPBadCSMOption) {
// TODO
// }
return true
case codes.Pong:
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
cc.processReceivedMessage(r, cc, h)
}

cc.handleTcpSignalReceived(codes.Abort)
return true
}
return false
Expand Down
Loading