Skip to content

Commit bd8e59c

Browse files
authored
[4.4] Fix panic when calling driver.Close concurrently (#627)
* Fix panic when calling driver.Close concurrently. Doing so might well still yield unexpected or undesired results, but at least it shouldn't cause a panic.
* TestKit: close resources on disconnect & catch backend panics
* fixup! TestKit: close resources on disconnect & catch backend panics
* Refactor to not change a public interface
1 parent 752180e commit bd8e59c

File tree

19 files changed

+181
-76
lines changed

19 files changed

+181
-76
lines changed

neo4j/driver.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ type Driver interface {
6060
// or error describing the problem.
6161
VerifyConnectivity() error
6262
// Close the driver and all underlying connections
63+
// This function must not be called while the driver is still in use (i.e., concurrently).
6364
Close() error
6465
}
6566

neo4j/internal/bolt/bolt3.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,10 @@ func (b *bolt3) ServerName() string {
119119
return b.serverName
120120
}
121121

122+
func (b *bolt3) ConnId() string {
123+
return b.connId
124+
}
125+
122126
func (b *bolt3) ServerVersion() string {
123127
return b.serverVersion
124128
}

neo4j/internal/bolt/bolt4.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ func (b *bolt4) ServerName() string {
142142
return b.serverName
143143
}
144144

145+
func (b *bolt4) ConnId() string {
146+
return b.connId
147+
}
148+
145149
func (b *bolt4) ServerVersion() string {
146150
return b.serverVersion
147151
}

neo4j/internal/bolt/connect.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
"io"
2727
"net"
2828

29-
"github.com/neo4j/neo4j-go-driver/v4/neo4j/db"
29+
idb "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/db"
3030
"github.com/neo4j/neo4j-go-driver/v4/neo4j/log"
3131
)
3232

@@ -46,7 +46,7 @@ var versions = [4]protocolVersion{
4646

4747
// Connect initiates the negotiation of the Bolt protocol version.
4848
// Returns the instance of bolt protocol implementing the low-level Connection interface.
49-
func Connect(serverName string, conn net.Conn, auth map[string]interface{}, userAgent string, routingContext map[string]string, logger log.Logger, boltLog log.BoltLogger) (db.Connection, error) {
49+
func Connect(serverName string, conn net.Conn, auth map[string]interface{}, userAgent string, routingContext map[string]string, logger log.Logger, boltLog log.BoltLogger) (idb.Connection, error) {
5050
// Perform Bolt handshake to negotiate version
5151
// Send handshake to server
5252
handshake := []byte{

neo4j/internal/connector/connector.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ import (
2828
"net"
2929
"time"
3030

31-
"github.com/neo4j/neo4j-go-driver/v4/neo4j/db"
3231
"github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/bolt"
32+
idb "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/db"
3333
"github.com/neo4j/neo4j-go-driver/v4/neo4j/log"
3434
)
3535

@@ -63,7 +63,7 @@ func (e *TlsError) Error() string {
6363
return e.inner.Error()
6464
}
6565

66-
func (c Connector) Connect(address string, boltLogger log.BoltLogger) (db.Connection, error) {
66+
func (c Connector) Connect(address string, boltLogger log.BoltLogger) (idb.Connection, error) {
6767
if c.SupplyConnection == nil {
6868
c.SupplyConnection = c.createConnection
6969
}

neo4j/internal/db/connection.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
*/
19+
20+
// Package db defines generic database functionality.
21+
package db
22+
23+
import (
24+
"github.com/neo4j/neo4j-go-driver/v4/neo4j/db"
25+
)
26+
27+
// Connection defines an abstract database server connection.
28+
type Connection interface {
29+
db.Connection
30+
// ConnId returns the connection id as assigned by the server ("" if not available)
31+
ConnId() string
32+
}

neo4j/internal/pool/pool.go

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,16 @@ import (
3131
"sync"
3232
"time"
3333

34-
"github.com/neo4j/neo4j-go-driver/v4/neo4j/db"
34+
idb "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/db"
3535
"github.com/neo4j/neo4j-go-driver/v4/neo4j/log"
3636
)
3737

38-
type Connect func(string, log.BoltLogger) (db.Connection, error)
38+
type Connect func(string, log.BoltLogger) (idb.Connection, error)
3939

4040
type qitem struct {
4141
servers []string
4242
wakeup chan bool
43-
conn db.Connection
43+
conn idb.Connection
4444
}
4545

4646
type Pool struct {
@@ -89,12 +89,22 @@ func (p *Pool) Close() {
8989
p.queueMut.Unlock()
9090
// Go through each server, close its idle connections and mark it as closing
9191
p.serversMut.Lock()
92-
for n, s := range p.servers {
93-
s.closeAll()
94-
delete(p.servers, n)
92+
pendingConnections := 0
93+
for _, s := range p.servers {
94+
s.startClosing()
95+
pendingConnections += s.size()
9596
}
9697
p.serversMut.Unlock()
97-
p.log.Infof(log.Pool, p.logId, "Closed")
98+
if pendingConnections == 0 {
99+
p.log.Infof(log.Pool, p.logId, "Closed")
100+
} else {
101+
p.log.Warnf(
102+
log.Pool,
103+
p.logId,
104+
"Called close with %d in-flight connections (will be closed when work is done).",
105+
pendingConnections,
106+
)
107+
}
98108
}
99109

100110
func (p *Pool) anyExistingConnectionsOnServers(serverNames []string) bool {
@@ -145,7 +155,7 @@ func (p *Pool) CleanUp() {
145155
}
146156
}
147157

148-
func (p *Pool) tryBorrow(serverName string, boltLogger log.BoltLogger) (db.Connection, error) {
158+
func (p *Pool) tryBorrow(serverName string, boltLogger log.BoltLogger) (idb.Connection, error) {
149159
// For now, lock complete servers map to avoid over connecting but with the downside
150160
// that long connect times will block connects to other servers as well. To fix this
151161
// we would need to add a pending connect to the server and lock per server.
@@ -205,7 +215,7 @@ func (p *Pool) getPenaltiesForServers(serverNames []string) []serverPenalty {
205215
return penalties
206216
}
207217

208-
func (p *Pool) tryAnyIdle(serverNames []string) db.Connection {
218+
func (p *Pool) tryAnyIdle(serverNames []string) idb.Connection {
209219
p.serversMut.Lock()
210220
defer p.serversMut.Unlock()
211221
for _, serverName := range serverNames {
@@ -224,7 +234,7 @@ func (p *Pool) tryAnyIdle(serverNames []string) db.Connection {
224234
// Borrow tries to borrow an existing database connection or tries to create a new one
225235
// if none exists. The wait flag indicates if the caller wants to wait for a connection
226236
// to be returned if there aren't any idle connections available.
227-
func (p *Pool) Borrow(ctx context.Context, serverNames []string, wait bool, boltLogger log.BoltLogger) (db.Connection, error) {
237+
func (p *Pool) Borrow(ctx context.Context, serverNames []string, wait bool, boltLogger log.BoltLogger) (idb.Connection, error) {
228238
timeOut := func() bool {
229239
select {
230240
case <-ctx.Done():
@@ -248,7 +258,7 @@ func (p *Pool) Borrow(ctx context.Context, serverNames []string, wait bool, bolt
248258
})
249259

250260
var err error
251-
var conn db.Connection
261+
var conn idb.Connection
252262
for _, s := range penalties {
253263
conn, err = p.tryBorrow(s.name, boltLogger)
254264
if err == nil {
@@ -313,7 +323,7 @@ func (p *Pool) Borrow(ctx context.Context, serverNames []string, wait bool, bolt
313323
}
314324
}
315325

316-
func (p *Pool) unreg(serverName string, c db.Connection, now time.Time) {
326+
func (p *Pool) unreg(serverName string, c idb.Connection, now time.Time) {
317327
p.serversMut.Lock()
318328
defer p.serversMut.Unlock()
319329

@@ -345,18 +355,24 @@ func (p *Pool) removeIdleOlderThanOnServer(serverName string, now time.Time, max
345355
server.removeIdleOlderThan(now, maxAge)
346356
}
347357

348-
func (p *Pool) Return(c db.Connection) {
358+
func (p *Pool) Return(c idb.Connection) {
349359
if p.closed {
350360
p.log.Warnf(log.Pool, p.logId, "Trying to return connection to closed pool")
351-
return
352361
}
353362

354363
c.SetBoltLogger(nil)
355364

356365
// Get the name of the server that the connection belongs to.
357366
serverName := c.ServerName()
358367
isAlive := c.IsAlive()
359-
p.log.Debugf(log.Pool, p.logId, "Returning connection to %s {alive:%t}", serverName, isAlive)
368+
p.log.Debugf(
369+
log.Pool,
370+
p.logId,
371+
"Returning connection %s to %s {alive:%t}",
372+
c.ConnId(),
373+
serverName,
374+
isAlive,
375+
)
360376

361377
// If the connection is dead, remove all other idle connections on the same server that are older
362378
// or of the same age as the dead connection, otherwise perform normal cleanup of old connections
@@ -413,6 +429,9 @@ func (p *Pool) Return(c db.Connection) {
413429
server := p.servers[serverName]
414430
if server != nil { // Strange when server not found
415431
server.returnBusy(c)
432+
if server.closing && server.size() == 0 {
433+
delete(p.servers, serverName)
434+
}
416435
} else {
417436
p.log.Warnf(log.Pool, p.logId, "Server %s not found", serverName)
418437
}

neo4j/internal/pool/pool_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import (
2727
"testing"
2828
"time"
2929

30-
"github.com/neo4j/neo4j-go-driver/v4/neo4j/db"
30+
idb "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/db"
3131
"github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/testutil"
3232
"github.com/neo4j/neo4j-go-driver/v4/neo4j/log"
3333
)
@@ -39,12 +39,12 @@ func TestPoolBorrowReturn(ot *testing.T) {
3939
maxAge := 1 * time.Second
4040
birthdate := time.Now()
4141

42-
succeedingConnect := func(s string, _ log.BoltLogger) (db.Connection, error) {
42+
succeedingConnect := func(s string, _ log.BoltLogger) (idb.Connection, error) {
4343
return &testutil.ConnFake{Name: s, Alive: true, Birth: birthdate}, nil
4444
}
4545

4646
failingError := errors.New("whatever")
47-
failingConnect := func(s string, _ log.BoltLogger) (db.Connection, error) {
47+
failingConnect := func(s string, _ log.BoltLogger) (idb.Connection, error) {
4848
return nil, failingError
4949
}
5050

@@ -198,7 +198,7 @@ func TestPoolResourceUsage(ot *testing.T) {
198198
maxAge := 1 * time.Second
199199
birthdate := time.Now()
200200

201-
succeedingConnect := func(s string, _ log.BoltLogger) (db.Connection, error) {
201+
succeedingConnect := func(s string, _ log.BoltLogger) (idb.Connection, error) {
202202
return &testutil.ConnFake{Name: s, Alive: true, Birth: birthdate}, nil
203203
}
204204

@@ -305,12 +305,12 @@ func TestPoolResourceUsage(ot *testing.T) {
305305
func TestPoolCleanup(ot *testing.T) {
306306
birthdate := time.Now()
307307
maxLife := 1 * time.Second
308-
succeedingConnect := func(s string, _ log.BoltLogger) (db.Connection, error) {
308+
succeedingConnect := func(s string, _ log.BoltLogger) (idb.Connection, error) {
309309
return &testutil.ConnFake{Name: s, Alive: true, Birth: birthdate}, nil
310310
}
311311

312312
// Borrows a connection in server A and another in server B
313-
borrowConnections := func(t *testing.T, p *Pool) (db.Connection, db.Connection) {
313+
borrowConnections := func(t *testing.T, p *Pool) (idb.Connection, idb.Connection) {
314314
c1, err := p.Borrow(context.Background(), []string{"A"}, true, nil)
315315
assertConnection(t, c1, err)
316316
c2, err := p.Borrow(context.Background(), []string{"B"}, true, nil)
@@ -352,7 +352,7 @@ func TestPoolCleanup(ot *testing.T) {
352352
})
353353

354354
ot.Run("Should not remove servers with only idle connections but with recent connect failures ", func(t *testing.T) {
355-
failingConnect := func(s string, _ log.BoltLogger) (db.Connection, error) {
355+
failingConnect := func(s string, _ log.BoltLogger) (idb.Connection, error) {
356356
return nil, errors.New("an error")
357357
}
358358
p := New(0, maxLife, failingConnect, logger, "poolid")

neo4j/internal/pool/server.go

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import (
2424
"sync/atomic"
2525
"time"
2626

27-
"github.com/neo4j/neo4j-go-driver/v4/neo4j/db"
27+
idb "github.com/neo4j/neo4j-go-driver/v4/neo4j/internal/db"
2828
)
2929

3030
// Represents a server with a number of connections that are either in use (borrowed) or
@@ -35,14 +35,15 @@ type server struct {
3535
busy list.List
3636
failedConnectAt time.Time
3737
roundRobin uint32
38+
closing bool
3839
}
3940

4041
var sharedRoundRobin uint32
4142

4243
const rememberFailedConnectDuration = 3 * time.Minute
4344

4445
// Returns an idle connection, if any
45-
func (s *server) getIdle() db.Connection {
46+
func (s *server) getIdle() idb.Connection {
4647
// Remove from idle list and add to busy list
4748
e := s.idle.Front()
4849
if e != nil {
@@ -51,7 +52,7 @@ func (s *server) getIdle() db.Connection {
5152
// Update round-robin counter every time we give away a connection and keep track
5253
// of our own round-robin index
5354
s.roundRobin = atomic.AddUint32(&sharedRoundRobin, 1)
54-
return c.(db.Connection)
55+
return c.(idb.Connection)
5556
}
5657
return nil
5758
}
@@ -102,9 +103,13 @@ func (s *server) calculatePenalty(now time.Time) uint32 {
102103
}
103104

104105
// Returns a busy connection; makes it idle, or closes it if the server is closing
105-
func (s *server) returnBusy(c db.Connection) {
106+
func (s *server) returnBusy(c idb.Connection) {
106107
s.unregisterBusy(c)
107-
s.idle.PushFront(c)
108+
if s.closing {
109+
c.Close()
110+
} else {
111+
s.idle.PushFront(c)
112+
}
108113
}
109114

110115
// Number of idle connections
@@ -113,16 +118,16 @@ func (s server) numIdle() int {
113118
}
114119

115120
// Adds a connection to busy list
116-
func (s *server) registerBusy(c db.Connection) {
121+
func (s *server) registerBusy(c idb.Connection) {
117122
// Update round-robin to indicate when this server was last used.
118123
s.roundRobin = atomic.AddUint32(&sharedRoundRobin, 1)
119124
s.busy.PushFront(c)
120125
}
121126

122-
func (s *server) unregisterBusy(c db.Connection) {
127+
func (s *server) unregisterBusy(c idb.Connection) {
123128
found := false
124129
for e := s.busy.Front(); e != nil && !found; e = e.Next() {
125-
x := e.Value.(db.Connection)
130+
x := e.Value.(idb.Connection)
126131
found = x == c
127132
if found {
128133
s.busy.Remove(e)
@@ -139,7 +144,7 @@ func (s *server) removeIdleOlderThan(now time.Time, maxAge time.Duration) {
139144
e := s.idle.Front()
140145
for e != nil {
141146
n := e.Next()
142-
c := e.Value.(db.Connection)
147+
c := e.Value.(idb.Connection)
143148

144149
age := now.Sub(c.Birthdate())
145150
if age >= maxAge {
@@ -151,16 +156,15 @@ func (s *server) removeIdleOlderThan(now time.Time, maxAge time.Duration) {
151156
}
152157
}
153158

154-
func closeAndEmptyConnections(l list.List) {
159+
func closeAndEmptyConnections(l *list.List) {
155160
for e := l.Front(); e != nil; e = e.Next() {
156-
c := e.Value.(db.Connection)
161+
c := e.Value.(idb.Connection)
157162
c.Close()
158163
}
159164
l.Init()
160165
}
161166

162-
func (s *server) closeAll() {
163-
closeAndEmptyConnections(s.idle)
164-
// Closing the busy connections could mean here that we do close from another thread.
165-
closeAndEmptyConnections(s.busy)
167+
func (s *server) startClosing() {
168+
s.closing = true
169+
closeAndEmptyConnections(&s.idle)
166170
}

0 commit comments

Comments
 (0)