Skip to content

Commit 494647d

Browse files
authored
Merge pull request #175 from plgd-dev/feature/dtlsPeerCertificate
feature/dtls peer certificate
2 parents 9937654 + cf2b67e commit 494647d

22 files changed

+728
-28
lines changed

dtls/options.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ func (o OnNewClientConnOpt) apply(opts *serverOptions) {
188188
}
189189

190190
// WithOnNewClientConn server's notify about new client connection.
191+
//
192+
// Note: Calling `dtlsConn.Close()` is forbidden, and `dtlsConn` should be treated as a
193+
// "read-only" parameter, mainly used to get the peer certificate from the underlining connection
191194
func WithOnNewClientConn(onNewClientConn OnNewClientConnFunc) OnNewClientConnOpt {
192195
return OnNewClientConnOpt{
193196
onNewClientConn: onNewClientConn,

dtls/server.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"sync"
88
"time"
99

10+
"github.com/pion/dtls/v2"
1011
"github.com/plgd-dev/go-coap/v2/message"
1112
"github.com/plgd-dev/go-coap/v2/net/blockwise"
1213
kitSync "github.com/plgd-dev/kit/sync"
@@ -36,7 +37,11 @@ type GoPoolFunc = func(func()) error
3637

3738
type BlockwiseFactoryFunc = func(getSendedRequest func(token message.Token) (blockwise.Message, bool)) *blockwise.BlockWise
3839

39-
type OnNewClientConnFunc = func(cc *client.ClientConn)
40+
// OnNewClientConnFunc is the callback for new connections.
41+
//
42+
// Note: Calling `dtlsConn.Close()` is forbidden, and `dtlsConn` should be treated as a
43+
// "read-only" parameter, mainly used to get the peer certificate from the underlining connection
44+
type OnNewClientConnFunc = func(cc *client.ClientConn, dtlsConn *dtls.Conn)
4045

4146
type GetMIDFunc = func() uint16
4247

@@ -59,7 +64,7 @@ var defaultServerOptions = serverOptions{
5964
blockwiseEnable: true,
6065
blockwiseSZX: blockwise.SZX1024,
6166
blockwiseTransferTimeout: time.Second * 5,
62-
onNewClientConn: func(cc *client.ClientConn) {},
67+
onNewClientConn: func(cc *client.ClientConn, dtlsConn *dtls.Conn) {},
6368
heartBeat: time.Millisecond * 100,
6469
transmissionNStart: time.Second,
6570
transmissionAcknowledgeTimeout: time.Second * 2,
@@ -210,7 +215,8 @@ func (s *Server) Serve(l Listener) error {
210215
wg.Add(1)
211216
cc := s.createClientConn(coapNet.NewConn(rw, coapNet.WithHeartBeat(s.heartBeat)))
212217
if s.onNewClientConn != nil {
213-
s.onNewClientConn(cc)
218+
dtlsConn := rw.(*dtls.Conn)
219+
s.onNewClientConn(cc, dtlsConn)
214220
}
215221
go func() {
216222
defer wg.Done()

dtls/server_test.go

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
11
package dtls_test
22

33
import (
4+
"bytes"
45
"context"
6+
"crypto/tls"
7+
"crypto/x509"
58
"fmt"
9+
"math/big"
610
"sync"
711
"testing"
812
"time"
913

1014
piondtls "github.com/pion/dtls/v2"
1115
"github.com/plgd-dev/go-coap/v2/dtls"
16+
"github.com/plgd-dev/go-coap/v2/examples/dtls/pki"
17+
"github.com/plgd-dev/go-coap/v2/message"
18+
"github.com/plgd-dev/go-coap/v2/message/codes"
1219
coapNet "github.com/plgd-dev/go-coap/v2/net"
1320
"github.com/plgd-dev/go-coap/v2/udp/client"
21+
"github.com/plgd-dev/go-coap/v2/udp/message/pool"
1422
"github.com/stretchr/testify/require"
1523
)
1624

@@ -29,7 +37,7 @@ func TestServer_CleanUpConns(t *testing.T) {
2937

3038
var checkCloseWg sync.WaitGroup
3139
defer checkCloseWg.Wait()
32-
sd := dtls.NewServer(dtls.WithOnNewClientConn(func(cc *client.ClientConn) {
40+
sd := dtls.NewServer(dtls.WithOnNewClientConn(func(cc *client.ClientConn, dtlsConn *piondtls.Conn) {
3341
checkCloseWg.Add(1)
3442
cc.AddOnClose(func() {
3543
checkCloseWg.Done()
@@ -57,3 +65,98 @@ func TestServer_CleanUpConns(t *testing.T) {
5765
err = cc.Ping(ctx)
5866
require.NoError(t, err)
5967
}
68+
69+
func createDTLSConfig(ctx context.Context) (serverConfig *piondtls.Config, clientConfig *piondtls.Config, clientSerial *big.Int, err error) {
70+
// root cert
71+
ca, rootBytes, _, caPriv, err := pki.GenerateCA()
72+
if err != nil {
73+
return
74+
}
75+
// server cert
76+
certBytes, keyBytes, err := pki.GenerateCertificate(ca, caPriv, "server@test.com")
77+
if err != nil {
78+
return
79+
}
80+
certificate, err := pki.LoadKeyAndCertificate(keyBytes, certBytes)
81+
if err != nil {
82+
return
83+
}
84+
// cert pool
85+
certPool, err := pki.LoadCertPool(rootBytes)
86+
if err != nil {
87+
return
88+
}
89+
90+
serverConfig = &piondtls.Config{
91+
Certificates: []tls.Certificate{*certificate},
92+
ExtendedMasterSecret: piondtls.RequireExtendedMasterSecret,
93+
ClientCAs: certPool,
94+
ClientAuth: piondtls.RequireAndVerifyClientCert,
95+
ConnectContextMaker: func() (context.Context, func()) {
96+
return context.WithTimeout(ctx, 30*time.Second)
97+
},
98+
}
99+
100+
// client cert
101+
certBytes, keyBytes, err = pki.GenerateCertificate(ca, caPriv, "client@test.com")
102+
if err != nil {
103+
return
104+
}
105+
certificate, err = pki.LoadKeyAndCertificate(keyBytes, certBytes)
106+
if err != nil {
107+
return
108+
}
109+
clientInfo, err := x509.ParseCertificate(certificate.Certificate[0])
110+
if err != nil {
111+
return
112+
}
113+
clientSerial = clientInfo.SerialNumber
114+
115+
clientConfig = &piondtls.Config{
116+
Certificates: []tls.Certificate{*certificate},
117+
ExtendedMasterSecret: piondtls.RequireExtendedMasterSecret,
118+
RootCAs: certPool,
119+
InsecureSkipVerify: true,
120+
}
121+
122+
return
123+
}
124+
125+
func TestServer_SetContextValueWithPKI(t *testing.T) {
126+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
127+
defer cancel()
128+
serverCgf, clientCgf, clientSerial, err := createDTLSConfig(ctx)
129+
require.NoError(t, err)
130+
131+
ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf)
132+
require.NoError(t, err)
133+
defer ld.Close()
134+
135+
onNewConn := func(cc *client.ClientConn, dtlsConn *piondtls.Conn) {
136+
// set connection context certificate
137+
clientCert, err := x509.ParseCertificate(dtlsConn.ConnectionState().PeerCertificates[0])
138+
require.NoError(t, err)
139+
cc.Session().SetContextValue("client-cert", clientCert)
140+
}
141+
handle := func(w *client.ResponseWriter, r *pool.Message) {
142+
// get certificate from connection context
143+
clientCert := r.Context().Value("client-cert").(*x509.Certificate)
144+
require.Equal(t, clientCert.SerialNumber, clientSerial)
145+
require.NotNil(t, clientCert)
146+
w.SetResponse(codes.Content, message.TextPlain, bytes.NewReader([]byte("done")))
147+
}
148+
149+
sd := dtls.NewServer(dtls.WithHandlerFunc(handle), dtls.WithOnNewClientConn(onNewConn))
150+
defer sd.Stop()
151+
go func() {
152+
err := sd.Serve(ld)
153+
require.NoError(t, err)
154+
}()
155+
156+
cc, err := dtls.Dial(ld.Addr().String(), clientCgf)
157+
require.NoError(t, err)
158+
defer cc.Close()
159+
160+
_, err = cc.Get(ctx, "/")
161+
require.NoError(t, err)
162+
}

dtls/session.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"net"
77
"sync"
8+
"sync/atomic"
89

910
coapNet "github.com/plgd-dev/go-coap/v2/net"
1011
"github.com/plgd-dev/go-coap/v2/udp/client"
@@ -22,7 +23,7 @@ type Session struct {
2223
onClose []EventFunc
2324

2425
cancel context.CancelFunc
25-
ctx context.Context
26+
ctx atomic.Value
2627
}
2728

2829
func NewSession(
@@ -32,17 +33,18 @@ func NewSession(
3233
closeSocket bool,
3334
) *Session {
3435
ctx, cancel := context.WithCancel(ctx)
35-
return &Session{
36-
ctx: ctx,
36+
s := &Session{
3737
cancel: cancel,
3838
connection: connection,
3939
maxMessageSize: maxMessageSize,
4040
closeSocket: closeSocket,
4141
}
42+
s.ctx.Store(&ctx)
43+
return s
4244
}
4345

4446
func (s *Session) Done() <-chan struct{} {
45-
return s.ctx.Done()
47+
return s.Context().Done()
4648
}
4749

4850
func (s *Session) AddOnClose(f EventFunc) {
@@ -75,7 +77,14 @@ func (s *Session) Close() error {
7577
}
7678

7779
func (s *Session) Context() context.Context {
78-
return s.ctx
80+
return *s.ctx.Load().(*context.Context)
81+
}
82+
83+
func (s *Session) SetContextValue(key interface{}, val interface{}) {
84+
s.mutex.Lock()
85+
defer s.mutex.Unlock()
86+
ctx := context.WithValue(s.Context(), key, val)
87+
s.ctx.Store(&ctx)
7988
}
8089

8190
func (s *Session) WriteMessage(req *pool.Message) error {
@@ -109,7 +118,7 @@ func (s *Session) Run(cc *client.ClientConn) (err error) {
109118
m := make([]byte, s.maxMessageSize)
110119
for {
111120
readBuf := m
112-
readLen, err := s.connection.ReadWithContext(s.ctx, readBuf)
121+
readLen, err := s.connection.ReadWithContext(s.Context(), readBuf)
113122
if err != nil {
114123
return err
115124
}

examples/dtls/pki/cert_gen.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package pki
2+
3+
import (
4+
"bytes"
5+
"crypto/ecdsa"
6+
"crypto/elliptic"
7+
"crypto/rand"
8+
"crypto/x509"
9+
"crypto/x509/pkix"
10+
"encoding/pem"
11+
"io"
12+
"math/big"
13+
"net"
14+
"time"
15+
)
16+
17+
var (
18+
algo = elliptic.P256()
19+
notBefore = time.Now()
20+
notAfter = notBefore.Add(time.Hour)
21+
subject = pkix.Name{
22+
Country: []string{"BR"},
23+
Province: []string{"Parana"},
24+
Locality: []string{"Curitiba"},
25+
Organization: []string{"Test"},
26+
CommonName: "test.com",
27+
}
28+
)
29+
30+
func sequentialBytes(n int) io.Reader {
31+
sequence := make([]byte, n)
32+
for i := 0; i < n; i++ {
33+
sequence[i] = byte(i)
34+
}
35+
return bytes.NewReader(sequence)
36+
}
37+
38+
// GenerateCA creates a deterministic certificate authority (for test purposes only)
39+
func GenerateCA() (ca *x509.Certificate, cert, key []byte, priv *ecdsa.PrivateKey, err error) {
40+
priv, err = ecdsa.GenerateKey(algo, sequentialBytes(64))
41+
if err != nil {
42+
return
43+
}
44+
45+
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
46+
serialNumber, err := rand.Int(sequentialBytes(128), serialNumberLimit)
47+
48+
ca = &x509.Certificate{
49+
NotBefore: notBefore,
50+
NotAfter: notAfter,
51+
SerialNumber: serialNumber,
52+
53+
Subject: subject,
54+
EmailAddresses: []string{"ca@test.com"},
55+
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
56+
57+
IsCA: true,
58+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
59+
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
60+
BasicConstraintsValid: true,
61+
}
62+
63+
derBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &priv.PublicKey, priv)
64+
if err != nil {
65+
return
66+
}
67+
cert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
68+
69+
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
70+
if err != nil {
71+
return
72+
}
73+
key = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
74+
75+
return
76+
}
77+
78+
// GenerateCertificate creates a certificate
79+
func GenerateCertificate(ca *x509.Certificate, caPriv *ecdsa.PrivateKey, email string) (cert, key []byte, err error) {
80+
priv, err := ecdsa.GenerateKey(algo, rand.Reader)
81+
if err != nil {
82+
return
83+
}
84+
85+
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
86+
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
87+
if err != nil {
88+
return
89+
}
90+
91+
template := x509.Certificate{
92+
NotBefore: notBefore,
93+
NotAfter: notAfter,
94+
SerialNumber: serialNumber,
95+
96+
Subject: subject,
97+
EmailAddresses: []string{email},
98+
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
99+
100+
SubjectKeyId: []byte{1, 2, 3, 4, 6},
101+
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
102+
KeyUsage: x509.KeyUsageDigitalSignature,
103+
}
104+
105+
derBytes, err := x509.CreateCertificate(rand.Reader, &template, ca, &priv.PublicKey, caPriv)
106+
if err != nil {
107+
return
108+
}
109+
110+
cert = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
111+
112+
privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
113+
if err != nil {
114+
return
115+
}
116+
key = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes})
117+
118+
return
119+
}

examples/dtls/pki/cert_gen_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package pki
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/require"
7+
)
8+
9+
func TestGenerateCA(t *testing.T) {
10+
ca, cert, key, caPriv, err := GenerateCA()
11+
require.NoError(t, err)
12+
require.Contains(t, string(cert), "-----BEGIN CERTIFICATE-----")
13+
require.Contains(t, string(key), "-----BEGIN EC PRIVATE KEY-----")
14+
15+
cert, key, err = GenerateCertificate(ca, caPriv, "cert@test.com")
16+
require.NoError(t, err)
17+
require.Contains(t, string(cert), "-----BEGIN CERTIFICATE-----")
18+
require.Contains(t, string(key), "-----BEGIN EC PRIVATE KEY-----")
19+
}

0 commit comments

Comments
 (0)