Skip to content

Commit d1073c8

Browse files
authored
Merge pull request #273 from 2hdddg/user-impersonation
Feature: user impersonation Requires Neo4j 4.4.
2 parents c461ea8 + edbe0cf commit d1073c8

22 files changed

+623
-177
lines changed

neo4j/db/connection.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@ type Command struct {
4545
}
4646

4747
type TxConfig struct {
48-
Mode AccessMode
49-
Bookmarks []string
50-
Timeout time.Duration
51-
Meta map[string]interface{}
48+
Mode AccessMode
49+
Bookmarks []string
50+
Timeout time.Duration
51+
ImpersonatedUser string
52+
Meta map[string]interface{}
5253
}
5354

5455
// Connection defines an abstract database server connection.
@@ -96,16 +97,21 @@ type Connection interface {
9697
// Gets routing table for specified database name or the default database if
9798
// database equals DefaultDatabase. If the underlying connection does not support
9899
// multiple databases, DefaultDatabase should be used as database.
99-
GetRoutingTable(context map[string]string, bookmarks []string, database string) (*RoutingTable, error)
100+
// If user impersonation is used (impersonatedUser != "") and default database is used
101+
// the database name in the returned routing table will contain the actual name of the
102+
// configured default database for the impersonated user. If no impersonation is used
103+
// database name in routing table will be set to the name of the requested database.
104+
GetRoutingTable(context map[string]string, bookmarks []string, database, impersonatedUser string) (*RoutingTable, error)
100105
// Sets Bolt message logger on already initialized connections
101106
SetBoltLogger(boltLogger log.BoltLogger)
102107
}
103108

104109
type RoutingTable struct {
105-
TimeToLive int
106-
Routers []string
107-
Readers []string
108-
Writers []string
110+
TimeToLive int
111+
DatabaseName string
112+
Routers []string
113+
Readers []string
114+
Writers []string
109115
}
110116

111117
// Marker for using the default database instance.

neo4j/db/errors.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,14 @@ func (e *Neo4jError) IsRetriableCluster() bool {
9898
return false
9999
}
100100

101-
type RoutingNotSupportedError struct {
102-
Server string
101+
type FeatureNotSupportedError struct {
102+
Server string
103+
Feature string
104+
Reason string
103105
}
104106

105-
func (e *RoutingNotSupportedError) Error() string {
106-
return fmt.Sprintf("%s does not support routing", e.Server)
107+
func (e *FeatureNotSupportedError) Error() string {
108+
return fmt.Sprintf("Server %s does not support: %s (%s)", e.Server, e.Feature, e.Reason)
107109
}
108110

109111
type UnsupportedTypeError struct {

neo4j/directrouter.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ package neo4j
2121

2222
import (
2323
"context"
24+
"github.com/neo4j/neo4j-go-driver/v4/neo4j/db"
2425
"github.com/neo4j/neo4j-go-driver/v4/neo4j/log"
2526
)
2627

@@ -37,6 +38,10 @@ func (r *directRouter) Writers(ctx context.Context, bookmarks []string, database
3738
return []string{r.address}, nil
3839
}
3940

41+
func (r *directRouter) GetNameOfDefaultDatabase(ctx context.Context, bookmarks []string, user string, boltLogger log.BoltLogger) (string, error) {
42+
return db.DefaultDatabase, nil
43+
}
44+
4045
func (r *directRouter) Invalidate(database string) {
4146
}
4247

neo4j/driver.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,13 @@ func routingContextFromUrl(useRouting bool, u *url.URL) (map[string]string, erro
224224
}
225225

226226
type sessionRouter interface {
227+
// Returns list of servers that can serve reads on the requested database.
227228
Readers(ctx context.Context, bookmarks []string, database string, boltLogger log.BoltLogger) ([]string, error)
229+
// Returns list of servers that can serve writes on the requested database.
228230
Writers(ctx context.Context, bookmarks []string, database string, boltLogger log.BoltLogger) ([]string, error)
231+
// Returns name of default database for specified user. The correct database name is needed when
232+
// requesting readers or writers.
233+
GetNameOfDefaultDatabase(ctx context.Context, bookmarks []string, user string, boltLogger log.BoltLogger) (string, error)
229234
Invalidate(database string)
230235
CleanUp()
231236
}
@@ -253,15 +258,18 @@ func (d *driver) Session(accessMode AccessMode, bookmarks ...string) (Session, e
253258
Message: "Trying to create session on closed driver",
254259
}
255260
}
261+
sessConfig := SessionConfig{
262+
AccessMode: accessMode,
263+
Bookmarks: bookmarks,
264+
DatabaseName: db.DefaultDatabase,
265+
}
256266
return newSession(
257-
d.config, d.router,
258-
d.pool, db.AccessMode(accessMode), bookmarks, db.DefaultDatabase, 0, d.log, nil), nil
267+
d.config, sessConfig, d.router, d.pool, d.log), nil
259268
}
260269

261270
func (d *driver) NewSession(config SessionConfig) Session {
262-
databaseName := db.DefaultDatabase
263-
if config.DatabaseName != "" {
264-
databaseName = config.DatabaseName
271+
if config.DatabaseName == "" {
272+
config.DatabaseName = db.DefaultDatabase
265273
}
266274

267275
d.mut.Lock()
@@ -270,10 +278,7 @@ func (d *driver) NewSession(config SessionConfig) Session {
270278
return &sessionWithError{
271279
err: &UsageError{Message: "Trying to create session on closed driver"}}
272280
}
273-
return newSession(
274-
d.config, d.router,
275-
d.pool, db.AccessMode(config.AccessMode), config.Bookmarks, databaseName, config.FetchSize,
276-
d.log, config.BoltLogger)
281+
return newSession(d.config, config, d.router, d.pool, d.log)
277282
}
278283

279284
func (d *driver) VerifyConnectivity() error {

neo4j/error.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func IsTransactionExecutionLimit(err error) bool {
113113
// TokenExpiredError represent errors caused by the driver not being able to connect to Neo4j services,
114114
// or lost connections.
115115
type TokenExpiredError struct {
116-
Code string
116+
Code string
117117
Message string
118118
}
119119

@@ -129,8 +129,9 @@ func wrapError(err error) error {
129129
return &ConnectivityError{inner: err}
130130
}
131131
switch e := err.(type) {
132-
case *db.UnsupportedTypeError:
133-
// Usage of a type not supported by database network protocol
132+
case *db.UnsupportedTypeError, *db.FeatureNotSupportedError:
133+
// Usage of a type not supported by database network protocol or feature
134+
// not supported by current version or edition.
134135
return &UsageError{Message: err.Error()}
135136
case *connector.TlsError, *connector.ConnectError:
136137
return &ConnectivityError{inner: err}

neo4j/internal/bolt/bolt3.go

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func NewBolt3(serverName string, conn net.Conn, log log.Logger, boltLog log.Bolt
9797
boltLogger: boltLog,
9898
},
9999
connReadTimeout: -1,
100-
logger: log,
100+
logger: log,
101101
},
102102
birthDate: time.Now(),
103103
log: log,
@@ -220,6 +220,9 @@ func (b *bolt3) TxBegin(txConfig db.TxConfig) (db.TxHandle, error) {
220220
if err := b.assertState(bolt3_ready); err != nil {
221221
return 0, err
222222
}
223+
if err := b.checkImpersonation(txConfig.ImpersonatedUser); err != nil {
224+
return 0, err
225+
}
223226

224227
tx := &internalTx3{
225228
mode: txConfig.Mode,
@@ -469,6 +472,9 @@ func (b *bolt3) Run(runCommand db.Command, txConfig db.TxConfig) (db.StreamHandl
469472
if err := b.assertState(bolt3_streaming, bolt3_ready); err != nil {
470473
return nil, err
471474
}
475+
if err := b.checkImpersonation(txConfig.ImpersonatedUser); err != nil {
476+
return nil, err
477+
}
472478

473479
tx := internalTx3{
474480
mode: txConfig.Mode,
@@ -698,13 +704,22 @@ func (b *bolt3) Reset() {
698704
}
699705
}
700706

701-
func (b *bolt3) GetRoutingTable(context map[string]string, bookmarks []string, database string) (*db.RoutingTable, error) {
707+
func (b *bolt3) checkImpersonation(impersonatedUser string) error {
708+
if impersonatedUser != "" {
709+
return &db.FeatureNotSupportedError{Server: b.serverName, Feature: "user impersonation", Reason: "requires least server v4.4"}
710+
}
711+
return nil
712+
}
713+
714+
func (b *bolt3) GetRoutingTable(context map[string]string, bookmarks []string, database, impersonatedUser string) (*db.RoutingTable, error) {
702715
if err := b.assertState(bolt3_ready); err != nil {
703716
return nil, err
704717
}
705-
706718
if database != db.DefaultDatabase {
707-
return nil, errors.New("Bolt 3 does not support routing to a specifiec database name")
719+
return nil, &db.FeatureNotSupportedError{Server: b.serverName, Feature: "route to database", Reason: "requires at least server v4"}
720+
}
721+
if err := b.checkImpersonation(impersonatedUser); err != nil {
722+
return nil, err
708723
}
709724

710725
// Only available when Neo4j is setup with clustering
@@ -718,7 +733,7 @@ func (b *bolt3) GetRoutingTable(context map[string]string, bookmarks []string, d
718733
// Give a better error
719734
dbError, isDbError := err.(*db.Neo4jError)
720735
if isDbError && dbError.Code == "Neo.ClientError.Procedure.ProcedureNotFound" {
721-
return nil, &db.RoutingNotSupportedError{Server: b.serverName}
736+
return nil, &db.FeatureNotSupportedError{Server: b.serverName, Feature: "routing", Reason: "requires cluster setup"}
722737
}
723738
return nil, err
724739
}
@@ -737,6 +752,8 @@ func (b *bolt3) GetRoutingTable(context map[string]string, bookmarks []string, d
737752
if table == nil {
738753
return nil, errors.New("Unable to parse routing table")
739754
}
755+
// Just because
756+
table.DatabaseName = db.DefaultDatabase
740757

741758
return table, nil
742759
}

neo4j/internal/bolt/bolt4.go

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,12 @@ const (
4545
const bolt4_fetchsize = 1000
4646

4747
type internalTx4 struct {
48-
mode db.AccessMode
49-
bookmarks []string
50-
timeout time.Duration
51-
txMeta map[string]interface{}
52-
databaseName string
48+
mode db.AccessMode
49+
bookmarks []string
50+
timeout time.Duration
51+
txMeta map[string]interface{}
52+
databaseName string
53+
impersonatedUser string
5354
}
5455

5556
func (i *internalTx4) toMeta() map[string]interface{} {
@@ -70,6 +71,9 @@ func (i *internalTx4) toMeta() map[string]interface{} {
7071
if i.databaseName != db.DefaultDatabase {
7172
meta["db"] = i.databaseName
7273
}
74+
if i.impersonatedUser != "" {
75+
meta["imp_user"] = i.impersonatedUser
76+
}
7377
return meta
7478
}
7579

@@ -110,7 +114,7 @@ func NewBolt4(serverName string, conn net.Conn, log log.Logger, boltLog log.Bolt
110114
boltLogger: boltLog,
111115
},
112116
connReadTimeout: -1,
113-
logger: log,
117+
logger: log,
114118
},
115119
}
116120
b.out = outgoing{
@@ -263,6 +267,13 @@ func (b *bolt4) connect(minor int, auth map[string]interface{}, userAgent string
263267
return nil
264268
}
265269

270+
func (b *bolt4) checkImpersonationAndVersion(impersonatedUser string) error {
271+
if impersonatedUser != "" && b.minor < 4 {
272+
return &db.FeatureNotSupportedError{Server: b.serverName, Feature: "user impersonation", Reason: "requires at least server v4.4"}
273+
}
274+
return nil
275+
}
276+
266277
func (b *bolt4) TxBegin(txConfig db.TxConfig) (db.TxHandle, error) {
267278
// Ok, to begin transaction while streaming auto-commit, just empty the stream and continue.
268279
if b.state == bolt4_streaming {
@@ -277,12 +288,17 @@ func (b *bolt4) TxBegin(txConfig db.TxConfig) (db.TxHandle, error) {
277288
return 0, err
278289
}
279290

291+
if err := b.checkImpersonationAndVersion(txConfig.ImpersonatedUser); err != nil {
292+
return 0, err
293+
}
294+
280295
tx := internalTx4{
281-
mode: txConfig.Mode,
282-
bookmarks: txConfig.Bookmarks,
283-
timeout: txConfig.Timeout,
284-
txMeta: txConfig.Meta,
285-
databaseName: b.databaseName,
296+
mode: txConfig.Mode,
297+
bookmarks: txConfig.Bookmarks,
298+
timeout: txConfig.Timeout,
299+
txMeta: txConfig.Meta,
300+
databaseName: b.databaseName,
301+
impersonatedUser: txConfig.ImpersonatedUser,
286302
}
287303

288304
// If there are bookmarks, begin the transaction immediately for backwards compatible
@@ -608,12 +624,17 @@ func (b *bolt4) Run(cmd db.Command, txConfig db.TxConfig) (db.StreamHandle, erro
608624
return nil, err
609625
}
610626

627+
if err := b.checkImpersonationAndVersion(txConfig.ImpersonatedUser); err != nil {
628+
return 0, err
629+
}
630+
611631
tx := internalTx4{
612-
mode: txConfig.Mode,
613-
bookmarks: txConfig.Bookmarks,
614-
timeout: txConfig.Timeout,
615-
txMeta: txConfig.Meta,
616-
databaseName: b.databaseName,
632+
mode: txConfig.Mode,
633+
bookmarks: txConfig.Bookmarks,
634+
timeout: txConfig.Timeout,
635+
txMeta: txConfig.Meta,
636+
databaseName: b.databaseName,
637+
impersonatedUser: txConfig.ImpersonatedUser,
617638
}
618639
stream, err := b.run(cmd.Cypher, cmd.Params, cmd.FetchSize, &tx)
619640
if err != nil {
@@ -873,19 +894,42 @@ func (b *bolt4) Reset() {
873894
}
874895
}
875896

876-
func (b *bolt4) GetRoutingTable(context map[string]string, bookmarks []string, database string) (*db.RoutingTable, error) {
897+
func (b *bolt4) GetRoutingTable(context map[string]string, bookmarks []string, database, impersonatedUser string) (*db.RoutingTable, error) {
877898
if err := b.assertState(bolt4_ready); err != nil {
878899
return nil, err
879900
}
880901

881902
b.log.Infof(log.Bolt4, b.logId, "Retrieving routing table")
903+
if b.minor > 3 {
904+
extras := map[string]interface{}{}
905+
if database != db.DefaultDatabase {
906+
extras["db"] = database
907+
}
908+
if impersonatedUser != "" {
909+
extras["imp_user"] = impersonatedUser
910+
}
911+
b.out.appendRoute(context, bookmarks, extras)
912+
b.out.send(b.conn)
913+
succ := b.receiveSuccess()
914+
if b.err != nil {
915+
return nil, b.err
916+
}
917+
return succ.routingTable, nil
918+
}
919+
920+
if err := b.checkImpersonationAndVersion(impersonatedUser); err != nil {
921+
return nil, err
922+
}
923+
882924
if b.minor > 2 {
883-
b.out.appendRoute(context, bookmarks, database)
925+
b.out.appendRouteToV43(context, bookmarks, database)
884926
b.out.send(b.conn)
885927
succ := b.receiveSuccess()
886928
if b.err != nil {
887929
return nil, b.err
888930
}
931+
// On this version we will not receive the database name
932+
succ.routingTable.DatabaseName = database
889933
return succ.routingTable, nil
890934
}
891935
return b.callGetRoutingTable(context, bookmarks, database)
@@ -927,6 +971,8 @@ func (b *bolt4) callGetRoutingTable(context map[string]string, bookmarks []strin
927971
if table == nil {
928972
return nil, errors.New("Unable to parse routing table")
929973
}
974+
// On this version we will not recive the database name
975+
table.DatabaseName = database
930976
return table, nil
931977
}
932978

0 commit comments

Comments
 (0)