Skip to content

Commit 2e39215

Browse files
author
Peter Wilhelmsson
committed
Respect connection context timout when retrieving routing table
1 parent 07fb3de commit 2e39215

File tree

7 files changed

+54
-47
lines changed

7 files changed

+54
-47
lines changed

neo4j/directrouter.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,18 @@
1919

2020
package neo4j
2121

22+
import "context"
23+
2224
// A router implementation that never routes
2325
type directRouter struct {
2426
address string
2527
}
2628

27-
func (r *directRouter) Readers(database string) ([]string, error) {
29+
func (r *directRouter) Readers(ctx context.Context, database string) ([]string, error) {
2830
return []string{r.address}, nil
2931
}
3032

31-
func (r *directRouter) Writers(database string) ([]string, error) {
33+
func (r *directRouter) Writers(ctx context.Context, database string) ([]string, error) {
3234
return []string{r.address}, nil
3335
}
3436

neo4j/driver.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
package neo4j
2222

2323
import (
24+
"context"
2425
"fmt"
2526
"net/url"
2627
"strings"
@@ -223,8 +224,8 @@ func routingContextFromUrl(useRouting bool, u *url.URL) (map[string]string, erro
223224
}
224225

225226
type sessionRouter interface {
226-
Readers(database string) ([]string, error)
227-
Writers(database string) ([]string, error)
227+
Readers(ctx context.Context, database string) ([]string, error)
228+
Writers(ctx context.Context, database string) ([]string, error)
228229
Invalidate(database string)
229230
CleanUp()
230231
}

neo4j/internal/pool/pool.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -249,15 +249,15 @@ func (p *Pool) Borrow(ctx context.Context, serverNames []string, wait bool) (db.
249249
var err error
250250
var conn db.Connection
251251
for _, s := range penalties {
252-
// Check if we have timed out
253-
if timeOut() {
254-
return nil, &PoolTimeout{servers: serverNames}
255-
}
256-
257252
conn, err = p.tryBorrow(s.name)
258253
if err == nil {
259254
return conn, nil
260255
}
256+
257+
// Check if we have timed out after failed borrow
258+
if timeOut() {
259+
return nil, &PoolTimeout{servers: serverNames}
260+
}
261261
}
262262

263263
// If there are no connections for any of the servers, there is no point in waiting for anything

neo4j/internal/router/router.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func New(rootRouter string, getRouters func() []string, routerContext map[string
7272
return r
7373
}
7474

75-
func (r *Router) getTable(database string) (*db.RoutingTable, error) {
75+
func (r *Router) getTable(ctx context.Context, database string) (*db.RoutingTable, error) {
7676
now := r.now()
7777

7878
r.dbRoutersMut.Lock()
@@ -92,20 +92,20 @@ func (r *Router) getTable(database string) (*db.RoutingTable, error) {
9292
if dbRouter != nil && len(dbRouter.table.Routers) > 0 {
9393
routers := dbRouter.table.Routers
9494
r.log.Infof(log.Router, r.logId, "Reading routing table for '%s' from previously known routers: %v", database, routers)
95-
table, err = readTable(context.Background(), r.pool, database, routers, r.routerContext)
95+
table, err = readTable(ctx, r.pool, database, routers, r.routerContext)
9696
}
9797

9898
// Try initial router if no routers or failed
99-
if table == nil || err != nil {
99+
if table == nil {
100100
r.log.Infof(log.Router, r.logId, "Reading routing table from initial router: %s", r.rootRouter)
101-
table, err = readTable(context.Background(), r.pool, database, []string{r.rootRouter}, r.routerContext)
101+
table, err = readTable(ctx, r.pool, database, []string{r.rootRouter}, r.routerContext)
102102
}
103103

104104
// Use hook to retrieve possibly different set of routers and retry
105-
if err != nil && r.getRouters != nil {
105+
if table == nil && r.getRouters != nil {
106106
routers := r.getRouters()
107107
r.log.Infof(log.Router, r.logId, "Reading routing table for '%s' from custom routers: %v", routers)
108-
table, err = readTable(context.Background(), r.pool, database, routers, r.routerContext)
108+
table, err = readTable(ctx, r.pool, database, routers, r.routerContext)
109109
}
110110

111111
if err != nil {
@@ -130,8 +130,8 @@ func (r *Router) getTable(database string) (*db.RoutingTable, error) {
130130
return table, nil
131131
}
132132

133-
func (r *Router) Readers(database string) ([]string, error) {
134-
table, err := r.getTable(database)
133+
func (r *Router) Readers(ctx context.Context, database string) ([]string, error) {
134+
table, err := r.getTable(ctx, database)
135135
if err != nil {
136136
return nil, err
137137
}
@@ -146,7 +146,7 @@ func (r *Router) Readers(database string) ([]string, error) {
146146
r.log.Infof(log.Router, r.logId, "Invalidating routing table, no readers")
147147
r.Invalidate(database)
148148
r.sleep(100 * time.Millisecond)
149-
table, err = r.getTable(database)
149+
table, err = r.getTable(ctx, database)
150150
if err != nil {
151151
return nil, err
152152
}
@@ -158,8 +158,8 @@ func (r *Router) Readers(database string) ([]string, error) {
158158
return table.Readers, nil
159159
}
160160

161-
func (r *Router) Writers(database string) ([]string, error) {
162-
table, err := r.getTable(database)
161+
func (r *Router) Writers(ctx context.Context, database string) ([]string, error) {
162+
table, err := r.getTable(ctx, database)
163163
if err != nil {
164164
return nil, err
165165
}
@@ -174,7 +174,7 @@ func (r *Router) Writers(database string) ([]string, error) {
174174
r.log.Infof(log.Router, r.logId, "Invalidating routing table, no writers")
175175
r.Invalidate(database)
176176
r.sleep(100 * time.Millisecond)
177-
table, err = r.getTable(database)
177+
table, err = r.getTable(ctx, database)
178178
if err != nil {
179179
return nil, err
180180
}

neo4j/internal/router/router_test.go

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,14 @@ func TestMultithreading(t *testing.T) {
6565
wg.Add(2)
6666
consumer := func() {
6767
for i := 0; i < 30; i++ {
68-
readers, err := router.Readers(dbName)
68+
readers, err := router.Readers(context.Background(), dbName)
6969
if len(readers) != 2 {
7070
t.Error("Wrong number of readers")
7171
}
7272
if err != nil {
7373
t.Error(err)
7474
}
75-
writers, err := router.Writers(dbName)
75+
writers, err := router.Writers(context.Background(), dbName)
7676
if len(writers) != 1 {
7777
t.Error("Wrong number of writers")
7878
}
@@ -114,25 +114,25 @@ func TestRespectsTimeToLiveAndInvalidate(t *testing.T) {
114114
dbName := "dbname"
115115

116116
// First access should trigger initial table read
117-
router.Readers(dbName)
117+
router.Readers(context.Background(), dbName)
118118
assertNum(t, numfetch, 1, "Should have fetched initial")
119119

120120
// Second access with time set to same should not trigger a read
121-
router.Readers(dbName)
121+
router.Readers(context.Background(), dbName)
122122
assertNum(t, numfetch, 1, "Should not have have fetched")
123123

124124
// Third access with time passed table due should trigger fetch
125125
n = n.Add(2 * time.Second)
126-
router.Readers(dbName)
126+
router.Readers(context.Background(), dbName)
127127
assertNum(t, numfetch, 2, "Should have have fetched")
128128

129129
// Just another one to make sure we're cached
130-
router.Readers(dbName)
130+
router.Readers(context.Background(), dbName)
131131
assertNum(t, numfetch, 2, "Should not have have fetched")
132132

133133
// Invalidate should force fetching
134134
router.Invalidate(dbName)
135-
router.Readers(dbName)
135+
router.Readers(context.Background(), dbName)
136136
assertNum(t, numfetch, 3, "Should have have fetched")
137137
}
138138

@@ -160,13 +160,13 @@ func TestUsesRootRouterWhenPreviousRoutersFails(t *testing.T) {
160160
dbName := "dbname"
161161

162162
// First access should trigger initial table read from root router
163-
router.Readers(dbName)
163+
router.Readers(context.Background(), dbName)
164164
if borrows[0][0] != "rootRouter" {
165165
t.Errorf("Should have connected to root upon first router request")
166166
}
167167
// Next access should go to otherRouter
168168
n = n.Add(2 * time.Second)
169-
router.Readers(dbName)
169+
router.Readers(context.Background(), dbName)
170170
if borrows[1][0] != "otherRouter" {
171171
t.Errorf("Should have queried other router")
172172
}
@@ -191,7 +191,7 @@ func TestUsesRootRouterWhenPreviousRoutersFails(t *testing.T) {
191191
return &testutil.ConnFake{Table: &db.RoutingTable{TimeToLive: 1, Readers: []string{"aReader"}}}, nil
192192
}
193193
n = n.Add(2 * time.Second)
194-
readers, err := router.Readers(dbName)
194+
readers, err := router.Readers(context.Background(), dbName)
195195
if err != nil {
196196
t.Error(err)
197197
}
@@ -219,7 +219,7 @@ func TestUseGetRoutersHookWhenInitialRouterFails(t *testing.T) {
219219
dbName := "dbname"
220220

221221
// Trigger read of routing table
222-
router.Readers(dbName)
222+
router.Readers(context.Background(), dbName)
223223

224224
expected := []string{rootRouter}
225225
expected = append(expected, backupRouters...)
@@ -247,7 +247,7 @@ func TestWritersFailAfterNRetries(t *testing.T) {
247247
dbName := "dbname"
248248

249249
// Should trigger a lot of retries to get a writer until it finally fails
250-
writers, err := router.Writers(dbName)
250+
writers, err := router.Writers(context.Background(), dbName)
251251
if err == nil {
252252
t.Error("Should have failed")
253253
}
@@ -285,7 +285,7 @@ func TestWritersRetriesWhenNoWriters(t *testing.T) {
285285

286286
// Should trigger initial table read that contains no writers and a second table read
287287
// that gets the writers
288-
writers, err := router.Writers(dbName)
288+
writers, err := router.Writers(context.Background(), dbName)
289289
if err != nil {
290290
t.Errorf("Got error: %s", err)
291291
}
@@ -323,7 +323,7 @@ func TestReadersRetriesWhenNoReaders(t *testing.T) {
323323

324324
// Should trigger initial table read that contains no readers and a second table read
325325
// that gets the readers
326-
readers, err := router.Readers(dbName)
326+
readers, err := router.Readers(context.Background(), dbName)
327327
if err != nil {
328328
t.Errorf("Got error: %s", err)
329329
}
@@ -349,8 +349,8 @@ func TestCleanUp(t *testing.T) {
349349
router := New("router", func() []string { return []string{} }, nil, pool, logger, "routerid")
350350
router.now = func() time.Time { return now }
351351

352-
router.Readers("db1")
353-
router.Readers("db2")
352+
router.Readers(context.Background(), "db1")
353+
router.Readers(context.Background(), "db2")
354354

355355
// Should be a router for each requested database
356356
if len(router.dbRouters) != 2 {

neo4j/internal/testutil/routerfake.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
package testutil
2121

22+
import (
23+
"context"
24+
)
25+
2226
type RouterFake struct {
2327
Invalidated bool
2428
InvalidatedDb string
@@ -33,11 +37,11 @@ func (r *RouterFake) Invalidate(database string) {
3337
r.Invalidated = true
3438
}
3539

36-
func (r *RouterFake) Readers(database string) ([]string, error) {
40+
func (r *RouterFake) Readers(ctx context.Context, database string) ([]string, error) {
3741
return r.ReadersRet, r.Err
3842
}
3943

40-
func (r *RouterFake) Writers(database string) ([]string, error) {
44+
func (r *RouterFake) Writers(ctx context.Context, database string) ([]string, error) {
4145
return r.WritersRet, r.Err
4246
}
4347

neo4j/session.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -327,20 +327,15 @@ func (s *session) WriteTransaction(
327327
return s.runRetriable(db.WriteMode, work, configurers...)
328328
}
329329

330-
func (s *session) getServers(mode db.AccessMode) ([]string, error) {
330+
func (s *session) getServers(ctx context.Context, mode db.AccessMode) ([]string, error) {
331331
if mode == db.ReadMode {
332-
return s.router.Readers(s.databaseName)
332+
return s.router.Readers(ctx, s.databaseName)
333333
} else {
334-
return s.router.Writers(s.databaseName)
334+
return s.router.Writers(ctx, s.databaseName)
335335
}
336336
}
337337

338338
func (s *session) getConnection(mode db.AccessMode) (db.Connection, error) {
339-
servers, err := s.getServers(mode)
340-
if err != nil {
341-
return nil, wrapError(err)
342-
}
343-
344339
var ctx context.Context
345340
if s.config.ConnectionAcquisitionTimeout > 0 {
346341
var cancel context.CancelFunc
@@ -351,6 +346,11 @@ func (s *session) getConnection(mode db.AccessMode) (db.Connection, error) {
351346
} else {
352347
ctx = context.Background()
353348
}
349+
servers, err := s.getServers(ctx, mode)
350+
if err != nil {
351+
return nil, wrapError(err)
352+
}
353+
354354
conn, err := s.pool.Borrow(ctx, servers, s.config.ConnectionAcquisitionTimeout != 0)
355355
if err != nil {
356356
return nil, wrapError(err)

0 commit comments

Comments
 (0)