Skip to content

Commit 60d5d32

Browse files
Peter Wilhelmsson2hdddg
authored andcommitted
Use routing per database from API level
1 parent 7dea0ef commit 60d5d32

File tree

7 files changed

+73
-57
lines changed

7 files changed

+73
-57
lines changed

neo4j/directrouter.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ type directRouter struct {
2424
server string
2525
}
2626

27-
func (r *directRouter) Readers() ([]string, error) {
27+
func (r *directRouter) Readers(database string) ([]string, error) {
2828
return []string{r.server}, nil
2929
}
3030

31-
func (r *directRouter) Writers() ([]string, error) {
31+
func (r *directRouter) Writers(database string) ([]string, error) {
3232
return []string{r.server}, nil
3333
}
3434

35-
func (r *directRouter) Invalidate() {
35+
func (r *directRouter) Invalidate(database string) {
3636
}

neo4j/driver.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,9 @@ func routingContextFromUrl(u *url.URL) (map[string]string, error) {
174174
}
175175

176176
type sessionRouter interface {
177-
Readers() ([]string, error)
178-
Writers() ([]string, error)
179-
Invalidate()
177+
Readers(database string) ([]string, error)
178+
Writers(database string) ([]string, error)
179+
Invalidate(database string)
180180
}
181181

182182
type driver struct {

neo4j/internal/router/readtable.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828

2929
// Tries to read routing table from any of the specified routers using new or existing connection
3030
// from the supplied pool.
31-
func readTable(ctx context.Context, pool Pool, routers []string, routerContext map[string]string) (*db.RoutingTable, error) {
31+
func readTable(ctx context.Context, pool Pool, database string, routers []string, routerContext map[string]string) (*db.RoutingTable, error) {
3232
// Preserve last error to be returned, set a default for case of no routers
3333
var err error = &ReadRoutingTableError{}
3434

@@ -55,7 +55,7 @@ func readTable(ctx context.Context, pool Pool, routers []string, routerContext m
5555
}
5656

5757
var table *db.RoutingTable
58-
table, err = discovery.GetRoutingTable(db.DefaultDatabase, routerContext)
58+
table, err = discovery.GetRoutingTable(database, routerContext)
5959
if err == nil {
6060
return table, nil
6161
}

neo4j/internal/router/readtable_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ func TestReadTableTable(ot *testing.T) {
146146
ot.Run(c.name, func(t *testing.T) {
147147
ctx, cancel := context.WithCancel(context.Background())
148148
c.pool.cancel = cancel
149-
table, err := readTable(ctx, c.pool, c.routers, nil)
149+
table, err := readTable(ctx, c.pool, "dbname", c.routers, nil)
150150
c.assert(t, table, err)
151151
if c.numReturns != len(c.pool.returned) {
152152
t.Errorf("Expected %d returned connections but %d was returned", c.numReturns, len(c.pool.returned))

neo4j/internal/router/router.go

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,17 @@ import (
3434

3535
const missingWriterRetries = 100
3636

37+
type databaseRouter struct {
38+
dueUnix int64
39+
table *db.RoutingTable
40+
}
41+
3742
// Thread safe
3843
type Router struct {
3944
routerContext map[string]string
4045
pool Pool
41-
table *db.RoutingTable
42-
dueUnix int64
43-
tableMut sync.Mutex
46+
dbRouters map[string]*databaseRouter
47+
dbRoutersMut sync.Mutex
4448
now func() time.Time
4549
sleep func(time.Duration)
4650
rootRouter string
@@ -63,6 +67,7 @@ func New(rootRouter string, getRouters func() []string, routerContext map[string
6367
getRouters: getRouters,
6468
routerContext: routerContext,
6569
pool: pool,
70+
dbRouters: make(map[string]*databaseRouter),
6671
now: time.Now,
6772
sleep: time.Sleep,
6873
log: logger,
@@ -72,54 +77,57 @@ func New(rootRouter string, getRouters func() []string, routerContext map[string
7277
return r
7378
}
7479

75-
func (r *Router) getTable() (*db.RoutingTable, error) {
76-
r.tableMut.Lock()
77-
defer r.tableMut.Unlock()
78-
80+
func (r *Router) getTable(database string) (*db.RoutingTable, error) {
7981
now := r.now()
8082

81-
if r.table != nil && now.Unix() < r.dueUnix {
82-
return r.table, nil
83+
r.dbRoutersMut.Lock()
84+
defer r.dbRoutersMut.Unlock()
85+
86+
dbRouter := r.dbRouters[database]
87+
if dbRouter != nil && now.Unix() < dbRouter.dueUnix {
88+
return dbRouter.table, nil
8389
}
8490

8591
var routers []string
86-
if r.table != nil {
87-
routers = r.table.Routers
92+
if dbRouter != nil {
93+
routers = dbRouter.table.Routers
8894
}
8995
if len(routers) == 0 {
9096
routers = []string{r.rootRouter}
9197
}
9298

93-
r.log.Infof(r.logId, "Reading routing table from any of %v", routers)
94-
table, err := readTable(context.Background(), r.pool, routers, r.routerContext)
99+
r.log.Infof(r.logId, "Reading routing table for '%s' from any of %v", database, routers)
100+
table, err := readTable(context.Background(), r.pool, database, routers, r.routerContext)
95101
if err != nil {
96102
// Use hook to retrieve possibly different set of routers and retry
97103
if r.getRouters != nil {
98104
routers = r.getRouters()
99-
table, err = readTable(context.Background(), r.pool, routers, r.routerContext)
105+
table, err = readTable(context.Background(), r.pool, database, routers, r.routerContext)
100106
}
101107
if err != nil {
102108
r.log.Error(r.logId, err)
103109
return nil, err
104110
}
105111
}
106-
r.table = table
107-
r.dueUnix = now.Add(time.Duration(table.TimeToLive) * time.Second).Unix()
108-
r.log.Debugf(r.logId, "New routing table, TTL %d", table.TimeToLive)
112+
r.dbRouters[database] = &databaseRouter{
113+
table: table,
114+
dueUnix: now.Add(time.Duration(table.TimeToLive) * time.Second).Unix(),
115+
}
116+
r.log.Debugf(r.logId, "New routing table for '%s', TTL %d", database, table.TimeToLive)
109117

110118
return table, nil
111119
}
112120

113-
func (r *Router) Readers() ([]string, error) {
114-
table, err := r.getTable()
121+
func (r *Router) Readers(database string) ([]string, error) {
122+
table, err := r.getTable(database)
115123
if err != nil {
116124
return nil, err
117125
}
118126
return table.Readers, nil
119127
}
120128

121-
func (r *Router) Writers() ([]string, error) {
122-
table, err := r.getTable()
129+
func (r *Router) Writers(database string) ([]string, error) {
130+
table, err := r.getTable(database)
123131
if err != nil {
124132
return nil, err
125133
}
@@ -133,11 +141,8 @@ func (r *Router) Writers() ([]string, error) {
133141
}
134142
r.log.Debugf(r.logId, "Invalidating routing table, no writers")
135143
r.sleep(100 * time.Millisecond)
136-
r.tableMut.Lock()
137-
// Reset due time to keep list of routers
138-
r.dueUnix = 0
139-
r.tableMut.Unlock()
140-
table, err = r.getTable()
144+
r.Invalidate(database)
145+
table, err = r.getTable(database)
141146
if err != nil {
142147
return nil, err
143148
}
@@ -153,11 +158,14 @@ func (r *Router) Context() map[string]string {
153158
return r.routerContext
154159
}
155160

156-
func (r *Router) Invalidate() {
157-
r.tableMut.Lock()
158-
defer r.tableMut.Unlock()
161+
func (r *Router) Invalidate(database string) {
162+
r.log.Infof(r.logId, "Invalidating routing table for '%s'", database)
163+
r.dbRoutersMut.Lock()
164+
defer r.dbRoutersMut.Unlock()
159165
// Reset due time to the 70s, this will make next access refresh the routing table using
160166
// last set of routers instead of the original one.
161-
r.dueUnix = 0
162-
r.log.Infof(r.logId, "Invalidating routing table")
167+
dbRouter := r.dbRouters[database]
168+
if dbRouter != nil {
169+
dbRouter.dueUnix = 0
170+
}
163171
}

neo4j/internal/router/router_test.go

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ var logger = &log.ConsoleLogger{Errors: true, Infos: true, Warns: true}
3737
// Verifies that concurrent access works as expected relying on the race detector to
3838
// report supicious behavior.
3939
func TestMultithreading(t *testing.T) {
40-
wg := sync.WaitGroup{}
41-
wg.Add(2)
4240

4341
// Setup a router that needs to read the routing table essentially on every access to
4442
// stress threading a bit more.
@@ -52,28 +50,34 @@ func TestMultithreading(t *testing.T) {
5250
}
5351
n := time.Now()
5452
router := New("router", func() []string { return []string{} }, nil, pool, logger)
53+
mut := sync.Mutex{}
5554
router.now = func() time.Time {
55+
// Need to lock here to make race detector happy
56+
mut.Lock()
57+
defer mut.Unlock()
5658
n = n.Add(time.Duration(table.TimeToLive) * time.Second * 2)
5759
return n
5860
}
5961

62+
dbName := "dbname"
63+
wg := sync.WaitGroup{}
64+
wg.Add(2)
6065
consumer := func() {
6166
for i := 0; i < 30; i++ {
62-
readers, err := router.Readers()
67+
readers, err := router.Readers(dbName)
6368
if len(readers) != 2 {
6469
t.Error("Wrong number of readers")
6570
}
6671
if err != nil {
6772
t.Error(err)
6873
}
69-
writers, err := router.Writers()
74+
writers, err := router.Writers(dbName)
7075
if len(writers) != 1 {
7176
t.Error("Wrong number of writers")
7277
}
7378
if err != nil {
7479
t.Error(err)
7580
}
76-
7781
}
7882
wg.Done()
7983
}
@@ -106,27 +110,28 @@ func TestRespectsTimeToLiveAndInvalidate(t *testing.T) {
106110
router.now = func() time.Time {
107111
return n
108112
}
113+
dbName := "dbname"
109114

110115
// First access should trigger initial table read
111-
router.Readers()
116+
router.Readers(dbName)
112117
assertNum(t, numfetch, 1, "Should have fetched initial")
113118

114119
// Second access with time set to same should not trigger a read
115-
router.Readers()
120+
router.Readers(dbName)
116121
assertNum(t, numfetch, 1, "Should not have have fetched")
117122

118123
// Third access with time passed table due should trigger fetch
119124
n = n.Add(2 * time.Second)
120-
router.Readers()
125+
router.Readers(dbName)
121126
assertNum(t, numfetch, 2, "Should have have fetched")
122127

123128
// Just another one to make sure we're cached
124-
router.Readers()
129+
router.Readers(dbName)
125130
assertNum(t, numfetch, 2, "Should not have have fetched")
126131

127132
// Invalidate should force fetching
128-
router.Invalidate()
129-
router.Readers()
133+
router.Invalidate(dbName)
134+
router.Readers(dbName)
130135
assertNum(t, numfetch, 3, "Should have have fetched")
131136
}
132137

@@ -143,9 +148,10 @@ func TestUseGetRoutersHookWhenInitialRouterFails(t *testing.T) {
143148
rootRouter := "rootRouter"
144149
backupRouters := []string{"bup1", "bup2"}
145150
router := New(rootRouter, func() []string { return backupRouters }, nil, pool, logger)
151+
dbName := "dbname"
146152

147153
// Trigger read of routing table
148-
router.Readers()
154+
router.Readers(dbName)
149155

150156
expected := []string{rootRouter}
151157
expected = append(expected, backupRouters...)
@@ -170,9 +176,10 @@ func TestWritersFailAfterNRetries(t *testing.T) {
170176
router.sleep = func(time.Duration) {
171177
numsleep++
172178
}
179+
dbName := "dbname"
173180

174181
// Should trigger a lot of retries to get a writer until it finally fails
175-
writers, err := router.Writers()
182+
writers, err := router.Writers(dbName)
176183
if err == nil {
177184
t.Error("Should have failed")
178185
}
@@ -206,10 +213,11 @@ func TestWritersRetriesWhenNoWriters(t *testing.T) {
206213
router.sleep = func(time.Duration) {
207214
numsleep++
208215
}
216+
dbName := "dbname"
209217

210218
// Should trigger initial table read that contains no writers and a second table read
211219
// that gets the writers
212-
writers, err := router.Writers()
220+
writers, err := router.Writers(dbName)
213221
if err != nil {
214222
t.Errorf("Got error: %s", err)
215223
}

neo4j/session.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ func (s *session) runRetriable(
325325
switch {
326326
case e.IsRetriableCluster():
327327
// Force routing tables to be updated before trying again
328-
s.router.Invalidate()
328+
s.router.Invalidate(s.databaseName)
329329
maxClusterErrors--
330330
if maxClusterErrors < 0 {
331331
s.log.Errorf(s.logId, "Retriable transaction failed due to encountering too many cluster errors")
@@ -362,9 +362,9 @@ func (s *session) borrowConn(mode db.AccessMode) error {
362362
var servers []string
363363
var err error
364364
if mode == db.ReadMode {
365-
servers, err = s.router.Readers()
365+
servers, err = s.router.Readers(s.databaseName)
366366
} else {
367-
servers, err = s.router.Writers()
367+
servers, err = s.router.Writers(s.databaseName)
368368
}
369369
if err != nil {
370370
return err

0 commit comments

Comments
 (0)