Skip to content

Commit cf2b67e

Browse files
committed
store context.Context to atomic.Value
1 parent 657685a commit cf2b67e

File tree

7 files changed

+41
-24
lines changed

7 files changed

+41
-24
lines changed

dtls/server_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ func createDTLSConfig(ctx context.Context) (serverConfig *piondtls.Config, clien
123123
}
124124

125125
func TestServer_SetContextValueWithPKI(t *testing.T) {
126-
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
126+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
127127
defer cancel()
128128
serverCgf, clientCgf, clientSerial, err := createDTLSConfig(ctx)
129129
require.NoError(t, err)

dtls/session.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net"
77
"sync"
8+
"sync/atomic"
89

910
coapNet "github.com/plgd-dev/go-coap/v2/net"
1011
"github.com/plgd-dev/go-coap/v2/udp/client"
@@ -22,7 +23,7 @@ type Session struct {
2223
onClose []EventFunc
2324

2425
cancel context.CancelFunc
25-
ctx context.Context
26+
ctx atomic.Value
2627
}
2728

2829
func NewSession(
@@ -32,17 +33,18 @@ func NewSession(
3233
closeSocket bool,
3334
) *Session {
3435
ctx, cancel := context.WithCancel(ctx)
35-
return &Session{
36-
ctx: ctx,
36+
s := &Session{
3737
cancel: cancel,
3838
connection: connection,
3939
maxMessageSize: maxMessageSize,
4040
closeSocket: closeSocket,
4141
}
42+
s.ctx.Store(&ctx)
43+
return s
4244
}
4345

4446
func (s *Session) Done() <-chan struct{} {
45-
return s.ctx.Done()
47+
return s.Context().Done()
4648
}
4749

4850
func (s *Session) AddOnClose(f EventFunc) {
@@ -75,13 +77,14 @@ func (s *Session) Close() error {
7577
}
7678

7779
func (s *Session) Context() context.Context {
78-
return s.ctx
80+
return *s.ctx.Load().(*context.Context)
7981
}
8082

8183
func (s *Session) SetContextValue(key interface{}, val interface{}) {
8284
s.mutex.Lock()
8385
defer s.mutex.Unlock()
84-
s.ctx = context.WithValue(s.ctx, key, val)
86+
ctx := context.WithValue(s.Context(), key, val)
87+
s.ctx.Store(&ctx)
8588
}
8689

8790
func (s *Session) WriteMessage(req *pool.Message) error {
@@ -115,7 +118,7 @@ func (s *Session) Run(cc *client.ClientConn) (err error) {
115118
m := make([]byte, s.maxMessageSize)
116119
for {
117120
readBuf := m
118-
readLen, err := s.connection.ReadWithContext(s.ctx, readBuf)
121+
readLen, err := s.connection.ReadWithContext(s.Context(), readBuf)
119122
if err != nil {
120123
return err
121124
}

mux/client.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type Client interface {
2323

2424
RemoteAddr() net.Addr
2525
Context() context.Context
26+
SetContextValue(key interface{}, val interface{})
2627
WriteMessage(req *message.Message) error
2728
Do(req *message.Message) (*message.Message, error)
2829
Close() error

tcp/client.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ func (c *ClientTCP) Context() context.Context {
7272
return c.cc.Context()
7373
}
7474

75+
func (c *ClientTCP) SetContextValue(key interface{}, val interface{}) {
76+
c.cc.Session().SetContextValue(key, val)
77+
}
78+
7579
func (c *ClientTCP) WriteMessage(req *message.Message) error {
7680
r, err := pool.ConvertFrom(req)
7781
if err != nil {

tcp/session.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ type Session struct {
4141
onClose []EventFunc
4242

4343
cancel context.CancelFunc
44-
ctx context.Context
44+
ctx atomic.Value
4545

4646
errSendCSM error
4747
}
@@ -65,7 +65,6 @@ func NewSession(
6565
}
6666

6767
s := &Session{
68-
ctx: ctx,
6968
cancel: cancel,
7069
connection: connection,
7170
handler: handler,
@@ -80,6 +79,8 @@ func NewSession(
8079
disableTCPSignalMessageCSM: disableTCPSignalMessageCSM,
8180
closeSocket: closeSocket,
8281
}
82+
s.ctx.Store(&ctx)
83+
8384
if !disableTCPSignalMessageCSM {
8485
err := s.sendCSM()
8586
if err != nil {
@@ -93,11 +94,12 @@ func NewSession(
9394
func (s *Session) SetContextValue(key interface{}, val interface{}) {
9495
s.mutex.Lock()
9596
defer s.mutex.Unlock()
96-
s.ctx = context.WithValue(s.ctx, key, val)
97+
ctx := context.WithValue(s.Context(), key, val)
98+
s.ctx.Store(&ctx)
9799
}
98100

99101
func (s *Session) Done() <-chan struct{} {
100-
return s.ctx.Done()
102+
return s.Context().Done()
101103
}
102104

103105
func (s *Session) AddOnClose(f EventFunc) {
@@ -134,7 +136,7 @@ func (s *Session) Sequence() uint64 {
134136
}
135137

136138
func (s *Session) Context() context.Context {
137-
return s.ctx
139+
return *s.ctx.Load().(*context.Context)
138140
}
139141

140142
func (s *Session) PeerMaxMessageSize() uint32 {
@@ -245,15 +247,15 @@ func (s *Session) processBuffer(buffer *bytes.Buffer, cc *ClientConn) error {
245247
if n != hdr.TotalLen {
246248
return fmt.Errorf("invalid data: %w", err)
247249
}
248-
req := pool.AcquireMessage(s.ctx)
250+
req := pool.AcquireMessage(s.Context())
249251
_, err = req.Unmarshal(msgRaw)
250252
if err != nil {
251253
pool.ReleaseMessage(req)
252254
return fmt.Errorf("cannot unmarshal with header: %w", err)
253255
}
254256
req.SetSequence(s.Sequence())
255257
s.goPool(func() {
256-
origResp := pool.AcquireMessage(s.ctx)
258+
origResp := pool.AcquireMessage(s.Context())
257259
origResp.SetToken(req.Token())
258260
w := NewResponseWriter(origResp, cc, req.Options())
259261
s.Handle(w, req)
@@ -286,7 +288,7 @@ func (s *Session) sendCSM() error {
286288
if err != nil {
287289
return fmt.Errorf("cannot get token: %w", err)
288290
}
289-
req := pool.AcquireMessage(s.ctx)
291+
req := pool.AcquireMessage(s.Context())
290292
defer pool.ReleaseMessage(req)
291293
req.SetCode(codes.CSM)
292294
req.SetToken(token)
@@ -319,7 +321,7 @@ func (s *Session) Run(cc *ClientConn) (err error) {
319321
if err != nil {
320322
return err
321323
}
322-
readLen, err := s.connection.ReadWithContext(s.ctx, readBuf)
324+
readLen, err := s.connection.ReadWithContext(s.Context(), readBuf)
323325
if err != nil {
324326
return err
325327
}

udp/client/client.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,10 @@ func (c *Client) Context() context.Context {
7272
return c.cc.Context()
7373
}
7474

75+
func (c *Client) SetContextValue(key interface{}, val interface{}) {
76+
c.cc.Session().SetContextValue(key, val)
77+
}
78+
7579
func (c *Client) WriteMessage(req *message.Message) error {
7680
r, err := pool.ConvertFrom(req)
7781
if err != nil {

udp/session.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net"
77
"sync"
8+
"sync/atomic"
89

910
coapNet "github.com/plgd-dev/go-coap/v2/net"
1011
"github.com/plgd-dev/go-coap/v2/udp/client"
@@ -23,7 +24,7 @@ type Session struct {
2324
onClose []EventFunc
2425

2526
cancel context.CancelFunc
26-
ctx context.Context
27+
ctx atomic.Value
2728
}
2829

2930
func NewSession(
@@ -34,24 +35,26 @@ func NewSession(
3435
closeSocket bool,
3536
) *Session {
3637
ctx, cancel := context.WithCancel(ctx)
37-
return &Session{
38-
ctx: ctx,
38+
s := &Session{
3939
cancel: cancel,
4040
connection: connection,
4141
raddr: raddr,
4242
maxMessageSize: maxMessageSize,
4343
closeSocket: closeSocket,
4444
}
45+
s.ctx.Store(&ctx)
46+
return s
4547
}
4648

4749
func (s *Session) SetContextValue(key interface{}, val interface{}) {
4850
s.mutex.Lock()
4951
defer s.mutex.Unlock()
50-
s.ctx = context.WithValue(s.ctx, key, val)
52+
ctx := context.WithValue(s.Context(), key, val)
53+
s.ctx.Store(&ctx)
5154
}
5255

5356
func (s *Session) Done() <-chan struct{} {
54-
return s.ctx.Done()
57+
return s.Context().Done()
5558
}
5659

5760
func (s *Session) AddOnClose(f EventFunc) {
@@ -84,7 +87,7 @@ func (s *Session) Close() error {
8487
}
8588

8689
func (s *Session) Context() context.Context {
87-
return s.ctx
90+
return *s.ctx.Load().(*context.Context)
8891
}
8992

9093
func (s *Session) WriteMessage(req *pool.Message) error {
@@ -109,7 +112,7 @@ func (s *Session) Run(cc *client.ClientConn) (err error) {
109112
m := make([]byte, s.maxMessageSize)
110113
for {
111114
buf := m
112-
n, _, err := s.connection.ReadWithContext(s.ctx, buf)
115+
n, _, err := s.connection.ReadWithContext(s.Context(), buf)
113116
if err != nil {
114117
return err
115118
}

0 commit comments

Comments
 (0)