Skip to content

Commit fa39565

Browse files
Peter Wilhelmsson2hdddg
authored andcommitted
UDS support
Support for connecting to neo4j server via Unix domain socket.
1 parent 8ee6022 commit fa39565

File tree

4 files changed

+103
-48
lines changed

4 files changed

+103
-48
lines changed

neo4j/directrouter.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,15 @@ package neo4j
2121

2222
// A router implementation that never routes
2323
type directRouter struct {
24-
server string
24+
address string
2525
}
2626

2727
func (r *directRouter) Readers(database string) ([]string, error) {
28-
return []string{r.server}, nil
28+
return []string{r.address}, nil
2929
}
3030

3131
func (r *directRouter) Writers(database string) ([]string, error) {
32-
return []string{r.server}, nil
32+
return []string{r.address}, nil
3333
}
3434

3535
func (r *directRouter) Invalidate(database string) {

neo4j/driver.go

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,27 @@ func NewDriver(target string, auth AuthToken, configurers ...func(*Config)) (Dri
8282
if err != nil {
8383
return nil, err
8484
}
85-
if parsed.Port() == "" {
86-
parsed.Host = parsed.Host + ":7687"
87-
}
8885

8986
d := driver{target: parsed}
9087

9188
routing := true
89+
d.connector.Network = "tcp"
90+
address := parsed.Host
9291
switch parsed.Scheme {
9392
case "bolt":
9493
routing = false
9594
d.connector.SkipEncryption = true
95+
case "bolt+unix":
96+
// bolt+unix://<path to socket>
97+
routing = false
98+
d.connector.SkipEncryption = true
99+
d.connector.Network = "unix"
100+
if parsed.Host != "" {
101+
return nil, &UsageError{
102+
Message: fmt.Sprintf("Host part should be empty for scheme %s", parsed.Scheme),
103+
}
104+
}
105+
address = parsed.Path
96106
case "bolt+s":
97107
routing = false
98108
case "bolt+ssc":
@@ -105,10 +115,15 @@ func NewDriver(target string, auth AuthToken, configurers ...func(*Config)) (Dri
105115
case "neo4j+s":
106116
default:
107117
return nil, &UsageError{
108-
Message: fmt.Sprintf("URL scheme %s is not supported", parsed.Scheme),
118+
Message: fmt.Sprintf("URI scheme %s is not supported", parsed.Scheme),
109119
}
110120
}
111121

122+
if parsed.Host != "" && parsed.Port() == "" {
123+
address += ":7687"
124+
parsed.Host = address
125+
}
126+
112127
if !routing && len(parsed.RawQuery) > 0 {
113128
return nil, &UsageError{
114129
Message: fmt.Sprintf("Routing context is not supported for URL scheme %s", parsed.Scheme),
@@ -150,7 +165,7 @@ func NewDriver(target string, auth AuthToken, configurers ...func(*Config)) (Dri
150165
d.pool = pool.New(d.config.MaxConnectionPoolSize, d.config.MaxConnectionLifetime, d.connector.Connect, d.log, d.logId)
151166

152167
if !routing {
153-
d.router = &directRouter{server: parsed.Host}
168+
d.router = &directRouter{address: address}
154169
} else {
155170
var routersResolver func() []string
156171
addressResolverHook := d.config.AddressResolver
@@ -165,10 +180,10 @@ func NewDriver(target string, auth AuthToken, configurers ...func(*Config)) (Dri
165180
}
166181
}
167182
// Let the router use the same logid as the driver to simplify log reading.
168-
d.router = router.New(parsed.Host, routersResolver, routingContext, d.pool, d.log, d.logId)
183+
d.router = router.New(address, routersResolver, routingContext, d.pool, d.log, d.logId)
169184
}
170185

171-
d.log.Infof(log.Driver, d.logId, "Created { target: %s }", target)
186+
d.log.Infof(log.Driver, d.logId, "Created { target: %s }", address)
172187
return &d, nil
173188
}
174189

neo4j/driver_test.go

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ func assertNoRouter(t *testing.T, d Driver) {
3434
t.Error("Expected no router")
3535
}
3636
}
37+
func assertNoRouterAddress(t *testing.T, d Driver, address string) {
38+
t.Helper()
39+
direct := d.(*driver).router.(*directRouter)
40+
if direct.address != address {
41+
t.Errorf("Address mismatch %s vs %s", address, direct.address)
42+
}
43+
}
3744

3845
func assertRouter(t *testing.T, d Driver) {
3946
t.Helper()
@@ -52,46 +59,80 @@ func assertRouterContext(t *testing.T, d Driver, context map[string]string) {
5259
}
5360
}
5461

55-
var uriSchemeTests = []struct {
56-
name string
57-
scheme string
58-
testing string
59-
router bool
60-
}{
61-
{"bolt://", "bolt", "bolt://localhost:7687", false},
62-
{"bolt+s://", "bolt", "bolt://localhost:7687", false},
63-
{"bolt+ssc://", "bolt", "bolt://localhost:7687", false},
64-
{"neo4j://", "neo4j", "neo4j://localhost:7687", true},
65-
{"neo4j+s://", "neo4j", "neo4j://localhost:7687", true},
66-
{"neo4j+ssc://", "neo4j", "neo4j://localhost:7687", true},
62+
func assertSkipEncryption(t *testing.T, d Driver, skipEncryption bool) {
63+
t.Helper()
64+
c := d.(*driver).connector
65+
if c.SkipEncryption != skipEncryption {
66+
t.Errorf("SkipEncryption mismatch, %t vs %t", skipEncryption, c.SkipEncryption)
67+
}
68+
}
69+
70+
func assertSkipVerify(t *testing.T, d Driver, skipVerify bool) {
71+
t.Helper()
72+
c := d.(*driver).connector
73+
if c.SkipVerify != skipVerify {
74+
t.Errorf("SkipVerify mismatch, %t vs %t", skipVerify, c.SkipVerify)
75+
}
6776
}
6877

69-
func TestDriverURISchemesX(t *testing.T) {
78+
func assertNetwork(t *testing.T, d Driver, network string) {
79+
t.Helper()
80+
c := d.(*driver).connector
81+
if c.Network != network {
82+
t.Errorf("Network mismatch, %s vs %s", network, c.Network)
83+
}
84+
}
85+
86+
func TestDriverURISchemes(t *testing.T) {
87+
uriSchemeTests := []struct {
88+
scheme string
89+
testing string
90+
router bool
91+
skipEncryption bool
92+
skipVerify bool
93+
network string
94+
address string
95+
}{
96+
{"bolt", "bolt://localhost:7687", false, true, false, "tcp", "localhost:7687"},
97+
{"bolt+s", "bolt+s://localhost:7687", false, false, false, "tcp", "localhost:7687"},
98+
{"bolt+ssc", "bolt+ssc://localhost:7687", false, false, true, "tcp", "localhost:7687"},
99+
{"bolt+unix", "bolt+unix:///tmp/a.socket", false, true, false, "unix", "/tmp/a.socket"},
100+
{"neo4j", "neo4j://localhost:7687", true, true, false, "tcp", ""},
101+
{"neo4j+s", "neo4j+s://localhost:7687", true, false, false, "tcp", ""},
102+
{"neo4j+ssc", "neo4j+ssc://localhost:7687", true, false, true, "tcp", ""},
103+
}
104+
70105
for _, tt := range uriSchemeTests {
71-
t.Run(tt.name, func(t *testing.T) {
106+
t.Run(tt.scheme, func(t *testing.T) {
72107
driver, err := NewDriver(tt.testing, NoAuth())
73108

74109
AssertNoError(t, err)
75110
AssertStringEqual(t, driver.Target().Scheme, tt.scheme)
76111
if !tt.router {
77112
assertNoRouter(t, driver)
113+
assertNoRouterAddress(t, driver, tt.address)
78114
} else {
79115
assertRouter(t, driver)
80116
}
117+
assertSkipEncryption(t, driver, tt.skipEncryption)
118+
if !tt.skipEncryption {
119+
assertSkipVerify(t, driver, tt.skipVerify)
120+
}
121+
assertNetwork(t, driver, tt.network)
81122
})
82123
}
83124
}
84125

85-
var invalidURISchemeTests = []struct {
86-
name string
87-
scheme string
88-
testing string
89-
}{
90-
{"bolt+routing://", "bolt+routing", "bolt+routing://localhost:7687"},
91-
{"invalid://", "invalid", "invalid://localhost:7687"},
92-
}
126+
func TestDriverInvalidURISchemes(t *testing.T) {
127+
invalidURISchemeTests := []struct {
128+
name string
129+
scheme string
130+
testing string
131+
}{
132+
{"bolt+routing://", "bolt+routing", "bolt+routing://localhost:7687"},
133+
{"invalid://", "invalid", "invalid://localhost:7687"},
134+
}
93135

94-
func TestDriverInvalidURISchemesX(t *testing.T) {
95136
for _, tt := range invalidURISchemeTests {
96137
t.Run(tt.name, func(t *testing.T) {
97138
_, err := NewDriver(tt.testing, NoAuth())
@@ -127,7 +168,6 @@ func TestDriverURIRoutingContext(t *testing.T) {
127168
}
128169

129170
func TestDriverDefaultPort(t *testing.T) {
130-
131171
t.Run("neo4j://localhost should default to port 7687", func(t1 *testing.T) {
132172
driver, err := NewDriver("neo4j://localhost", NoAuth())
133173
driverTarget := driver.Target()
@@ -139,7 +179,6 @@ func TestDriverDefaultPort(t *testing.T) {
139179
}
140180

141181
func TestNewDriverAndClose(t *testing.T) {
142-
143182
driver, err := NewDriver("bolt://localhost:7687", NoAuth())
144183
AssertNoError(t, err)
145184

@@ -172,19 +211,19 @@ func TestNewDriverAndClose(t *testing.T) {
172211
}
173212
}
174213

175-
var driverSessionCreationTests = []struct {
176-
name string
177-
testing string
178-
mode AccessMode
179-
bookmarks []string
180-
}{
181-
{"case one", "bolt://localhost:7687", AccessModeWrite, []string(nil)},
182-
{"case two", "bolt://localhost:7687", AccessModeRead, []string(nil)},
183-
{"case three", "bolt://localhost:7687", AccessModeWrite, []string{"B1", "B2", "B3"}},
184-
{"case four", "bolt://localhost:7687", AccessModeRead, []string{"B1", "B2", "B3", "B4"}},
185-
}
214+
func TestDriverSessionCreation(t *testing.T) {
215+
driverSessionCreationTests := []struct {
216+
name string
217+
testing string
218+
mode AccessMode
219+
bookmarks []string
220+
}{
221+
{"Write", "bolt://localhost:7687", AccessModeWrite, []string(nil)},
222+
{"Read", "bolt://localhost:7687", AccessModeRead, []string(nil)},
223+
{"Write+bookmarks", "bolt://localhost:7687", AccessModeWrite, []string{"B1", "B2", "B3"}},
224+
{"Read+bookmarks", "bolt://localhost:7687", AccessModeRead, []string{"B1", "B2", "B3", "B4"}},
225+
}
186226

187-
func TestDriverSessionCreationX(t *testing.T) {
188227
for _, tt := range driverSessionCreationTests {
189228
t.Run(tt.name, func(t *testing.T) {
190229
driver, err := NewDriver(tt.testing, NoAuth())

neo4j/internal/connector/connector.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ type Connector struct {
4343
Log log.Logger
4444
UserAgent string
4545
RoutingContext map[string]string
46+
Network string
4647
}
4748

4849
type ConnectError struct {
@@ -67,7 +68,7 @@ func (c Connector) Connect(address string) (db.Connection, error) {
6768
dialer.KeepAlive = -1 * time.Second // Turns keep-alive off
6869
}
6970

70-
conn, err := dialer.Dial("tcp", address)
71+
conn, err := dialer.Dial(c.Network, address)
7172
if err != nil {
7273
return nil, &ConnectError{inner: err}
7374
}

0 commit comments

Comments
 (0)