Skip to content

Commit 576dcb4

Browse files
committed
Allow nil results
1 parent 027891f commit 576dcb4

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

neo4j/session.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ func (s *session) runRetriable(
259259
DatabaseName: s.databaseName,
260260
}
261261
for state.Continue() {
262-
if workResult := s.tryRun(&state, mode, &config, work); workResult != nil {
262+
if workResult, successfullyCompleted := s.tryRun(&state, mode, &config, work); successfullyCompleted {
263263
return workResult, nil
264264
}
265265
}
@@ -292,12 +292,12 @@ func (s *session) WriteTransaction(
292292
return s.runRetriable(db.WriteMode, work, configurers...)
293293
}
294294

295-
func (s *session) tryRun(state *retry.State, mode db.AccessMode, config *TransactionConfig, work TransactionWork) interface{} {
295+
func (s *session) tryRun(state *retry.State, mode db.AccessMode, config *TransactionConfig, work TransactionWork) (interface{}, bool) {
296296
// Establish new connection
297297
conn, err := s.getConnection(mode)
298298
if err != nil {
299299
state.OnFailure(conn, err, false)
300-
return nil
300+
return nil, false
301301
}
302302
defer s.pool.Return(conn)
303303
txHandle, err := conn.TxBegin(db.TxConfig{
@@ -308,7 +308,7 @@ func (s *session) tryRun(state *retry.State, mode db.AccessMode, config *Transac
308308
})
309309
if err != nil {
310310
state.OnFailure(conn, err, false)
311-
return nil
311+
return nil, false
312312
}
313313

314314
tx := retryableTransaction{conn: conn, fetchSize: s.fetchSize, txHandle: txHandle}
@@ -321,17 +321,17 @@ func (s *session) tryRun(state *retry.State, mode db.AccessMode, config *Transac
321321
// but instead rely on pool invoking reset on the connection, that
322322
// will do an implicit rollback.
323323
state.OnFailure(conn, err, false)
324-
return nil
324+
return nil, false
325325
}
326326

327327
err = conn.TxCommit(txHandle)
328328
if err != nil {
329329
state.OnFailure(conn, err, true)
330-
return nil
330+
return nil, false
331331
}
332332

333333
s.retrieveBookmarks(conn)
334-
return x
334+
return x, true
335335
}
336336

337337
func (s *session) getServers(ctx context.Context, mode db.AccessMode) ([]string, error) {

neo4j/session_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,22 @@ func TestSession(st *testing.T) {
152152
}()
153153
_, _ = newSession.WriteTransaction(transactionFunction)
154154
})
155+
156+
rt.Run("tx function panic returns conn to pool and bubbles up", func(t *testing.T) {
157+
_, pool, newSession := createSession()
158+
pool.BorrowConn = &ConnFake{Alive: true}
159+
transactionFunction := func(Transaction) (interface{}, error) {
160+
return nil, nil
161+
}
162+
163+
result, err := newSession.WriteTransaction(transactionFunction)
164+
if result != nil {
165+
t.Errorf("expected nil result")
166+
}
167+
if err != nil {
168+
t.Errorf("expected nil error")
169+
}
170+
})
155171
})
156172

157173
st.Run("Bookmarking", func(bt *testing.T) {

0 commit comments

Comments
 (0)