Skip to content

Commit 068c138

Browse files
committed
feat: locks can never expire
1 parent 5ae4e94 commit 068c138

File tree

8 files changed

+98
-39
lines changed

8 files changed

+98
-39
lines changed

mutex.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ func (m *mutex) lockId() string {
1818
return m.ns + ":" + m.id
1919
}
2020

21-
func (m *mutex) GetLockId() string {
21+
func (m *mutex) GetId() string {
2222
return m.lockId()
2323
}
2424

25-
func (m *mutex) GetLockOwner() string {
25+
func (m *mutex) GetOwner() string {
2626
return m.owner
2727
}
2828

@@ -71,7 +71,7 @@ func newMutex(provider Provider, id string, opts ...Option) Mutex {
7171

7272
// String implements print interface.
7373
func (m *mutex) String() string {
74-
return "Mutex(" + m.provider.Name() + ":" + m.GetLockId() + ")"
74+
return "Mutex(" + m.provider.Name() + ":" + m.GetId() + ")"
7575
}
7676

7777
// Lock locks the named resourc

mutex_test.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,23 @@ import (
88
"github.com/stretchr/testify/assert"
99
)
1010

11-
func runBasicLockTests(t *testing.T, provider Provider) {
11+
func runLockTestsWithoutLifetime(t *testing.T, provider Provider) {
12+
factory := New(provider, WithNamespace("deadlock"))
13+
m1 := factory.New("build-images")
14+
m2 := factory.New("build-images")
15+
m3 := factory.New("start-containers")
16+
17+
assert.NoError(t, m1.Lock())
18+
assert.ErrorIs(t, m1.Lock(), ErrAlreadyLocked)
19+
assert.ErrorIs(t, m2.Lock(), ErrAlreadyLocked)
20+
assert.NoError(t, m3.Lock())
21+
22+
assert.NoError(t, m1.Unlock())
23+
assert.ErrorIs(t, m2.Unlock(), ErrNotLocked)
24+
assert.NoError(t, m3.Unlock())
25+
}
26+
27+
func runLockTestsWithLifetime(t *testing.T, provider Provider) {
1228
factory := New(provider, WithLockLifetime(1*time.Second))
1329
m := factory.New("johndoe", WithNamespace("questions"))
1430
expectedMutexDisplayName := fmt.Sprintf("Mutex(%s:questions:johndoe)", provider.Name())
@@ -51,7 +67,7 @@ func testLockContention(t *testing.T, m Mutex) {
5167
func testUnlockAfterOwnerChange(t *testing.T, m1, m2 Mutex) {
5268
assert.NoError(t, m1.Lock())
5369
assert.ErrorIs(t, m2.Lock(), ErrAlreadyLocked)
54-
time.Sleep(10 * time.Millisecond) // m1 expired (released by system)
70+
time.Sleep(50 * time.Millisecond) // m1 expired (released by system)
5571
assert.NoError(t, m2.Lock()) // m2 can obtain the lock, since m1 is expired
5672
assert.ErrorIs(t, m1.Unlock(), ErrNotLocked)
5773
}

mysql.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ const (
1515

1616
mysqlLockSQL = `INSERT INTO %s (id, owner, expire_at) VALUES (?, ?, ?)
1717
ON DUPLICATE KEY UPDATE
18-
owner = IF(expire_at < ?, VALUES(owner), owner),
19-
expire_at = IF(expire_at < ?, VALUES(expire_at), expire_at);`
18+
owner = IF(expire_at > 0 AND expire_at < ?, VALUES(owner), owner),
19+
expire_at = IF(expire_at > 0 AND expire_at < ?, VALUES(expire_at), expire_at);`
2020

21-
mysqlUnlockSQL = `DELETE FROM %s WHERE id = ? AND owner = ? AND expire_at >= ?;`
21+
mysqlUnlockSQL = `DELETE FROM %s WHERE id = ? AND owner = ? AND (expire_at = 0 OR expire_at >= ?);`
2222
)
2323

2424
type mysqlProvider struct {
@@ -65,11 +65,10 @@ func (p *mysqlProvider) init() error {
6565

6666
func (p *mysqlProvider) Lock(lock NamedLock) error {
6767
now := time.Now()
68-
expireAt := now.Add(lock.GetLifetime())
6968
rs, err := p.lockStmt.Exec(
70-
lock.GetLockId(),
71-
lock.GetLockOwner(),
72-
expireAt.UnixNano(),
69+
lock.GetId(),
70+
lock.GetOwner(),
71+
computeExpireAt(now, lock.GetLifetime()),
7372
now.UnixNano(),
7473
now.UnixNano(),
7574
)
@@ -88,8 +87,8 @@ func (p *mysqlProvider) Lock(lock NamedLock) error {
8887

8988
func (p *mysqlProvider) Unlock(lock NamedLock) error {
9089
rs, err := p.unlockStmt.Exec(
91-
lock.GetLockId(),
92-
lock.GetLockOwner(),
90+
lock.GetId(),
91+
lock.GetOwner(),
9392
time.Now().UnixNano(),
9493
)
9594
if err != nil {
@@ -108,3 +107,10 @@ func (p *mysqlProvider) Unlock(lock NamedLock) error {
108107
func formatSQL(sqlTemplate, tableName string) string {
109108
return fmt.Sprintf(sqlTemplate, tableName)
110109
}
110+
111+
func computeExpireAt(now time.Time, lifetime time.Duration) int64 {
112+
if lifetime == 0 {
113+
return 0 // never expire
114+
}
115+
return now.Add(lifetime).UnixNano()
116+
}

postgres.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ const (
1515

1616
pgLockSQL = `INSERT INTO %s AS t (id, owner, expire_at) VALUES ($1, $2, $3)
1717
ON CONFLICT (id) DO UPDATE
18-
SET owner = $2, expire_at = $3 WHERE t.id = $1 AND t.expire_at < $4;`
18+
SET owner = $2, expire_at = $3
19+
WHERE t.id = $1 AND t.expire_at > 0 AND t.expire_at < $4;`
1920

20-
pgUnlockSQL = `DELETE FROM %s WHERE id = $1 AND owner = $2 AND expire_at >= $3;`
21+
pgUnlockSQL = `DELETE FROM %s WHERE id = $1 AND owner = $2 AND (expire_at = 0 OR expire_at >= $3);`
2122
)
2223

2324
type postgreSQLProvider mysqlProvider
@@ -58,11 +59,10 @@ func (p *postgreSQLProvider) init() error {
5859

5960
func (p *postgreSQLProvider) Lock(lock NamedLock) error {
6061
now := time.Now()
61-
expireAt := now.Add(lock.GetLifetime())
6262
rs, err := p.lockStmt.Exec(
63-
lock.GetLockId(),
64-
lock.GetLockOwner(),
65-
expireAt.UnixNano(),
63+
lock.GetId(),
64+
lock.GetOwner(),
65+
computeExpireAt(now, lock.GetLifetime()),
6666
now.UnixNano(),
6767
)
6868
if err != nil {
@@ -80,8 +80,8 @@ func (p *postgreSQLProvider) Lock(lock NamedLock) error {
8080

8181
func (p *postgreSQLProvider) Unlock(lock NamedLock) error {
8282
rs, err := p.unlockStmt.Exec(
83-
lock.GetLockId(),
84-
lock.GetLockOwner(),
83+
lock.GetId(),
84+
lock.GetOwner(),
8585
time.Now().UnixNano(),
8686
)
8787
if err != nil {

provider.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ package distlock
33
import "time"
44

55
type NamedLock interface {
6-
GetLockId() string
7-
GetLockOwner() string
6+
GetId() string
7+
GetOwner() string
88
GetLifetime() time.Duration
99
}
1010

redis.go

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,28 @@ func (p *redisProvider) Lock(lock NamedLock) error {
4141
conn := p.pool.Get()
4242
defer conn.Close()
4343

44-
// SET key value PX milliseconds NX
45-
// PX: Set the specified expire time, in milliseconds.
46-
// NX: Only set the key if it does not already exist.
47-
reply, err := conn.Do(
48-
"SET", lock.GetLockId(), lock.GetLockOwner(),
49-
"PX", lock.GetLifetime().Nanoseconds()/int64(time.Millisecond),
50-
"NX",
44+
var (
45+
reply interface{}
46+
err error
5147
)
48+
49+
lifetime := lock.GetLifetime()
50+
if lifetime > 0 {
51+
// SET key value PX milliseconds NX
52+
// PX: Set the specified expire time, in milliseconds.
53+
// NX: Only set the key if it does not already exist.
54+
reply, err = conn.Do(
55+
"SET", lock.GetId(), lock.GetOwner(),
56+
"PX", lock.GetLifetime().Nanoseconds()/int64(time.Millisecond),
57+
"NX",
58+
)
59+
} else { // never expire
60+
reply, err = conn.Do(
61+
"SET", lock.GetId(), lock.GetOwner(),
62+
"NX",
63+
)
64+
}
65+
5266
if err != nil {
5367
return fmt.Errorf("redis SET: %w", err)
5468
}
@@ -64,7 +78,7 @@ func (p *redisProvider) Unlock(lock NamedLock) error {
6478
defer conn.Close()
6579

6680
command := redis.NewScript(1, unlockScript)
67-
ret, err := redis.Int(command.Do(conn, lock.GetLockId(), lock.GetLockOwner()))
81+
ret, err := redis.Int(command.Do(conn, lock.GetId(), lock.GetOwner()))
6882
if err != nil {
6983
return fmt.Errorf("redis EVAL: %w", err)
7084
}

redis_test.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,17 @@ var (
2525
}
2626
)
2727

28+
func cleanupRedis() {
29+
conn := redisPool.Get()
30+
defer conn.Close()
31+
32+
conn.Do("FLUSHDB")
33+
}
34+
2835
func TestRedisProvider(t *testing.T) {
36+
cleanupRedis()
37+
2938
provider, _ := NewRedisProvider(redisPool)
30-
runBasicLockTests(t, provider)
39+
runLockTestsWithLifetime(t, provider)
40+
runLockTestsWithoutLifetime(t, provider)
3141
}

sql_test.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,29 @@ import (
88
_ "github.com/lib/pq"
99
)
1010

11+
const TestTableName = "ggicci_distlock_test"
12+
13+
func cleanupMySQL(db *sql.DB) {
14+
_, _ = db.Exec(formatSQL("DROP TABLE IF EXISTS %s", TestTableName))
15+
}
16+
17+
func cleanupPostgreSQL(db *sql.DB) {
18+
_, _ = db.Exec(formatSQL("DROP TABLE IF EXISTS %s", TestTableName))
19+
}
20+
1121
func TestMySQLProvider(t *testing.T) {
1222
db, err := sql.Open("mysql", "root@tcp(localhost:3306)/test")
1323
if err != nil {
1424
t.Fatal(err)
1525
}
26+
cleanupMySQL(db)
1627

17-
provider, err := NewMySQLProvider(db, "distlocks")
28+
provider, err := NewMySQLProvider(db, TestTableName)
1829
if err != nil {
1930
t.Fatalf("could not create provider: %s", err)
2031
}
21-
22-
runBasicLockTests(t, provider)
32+
runLockTestsWithoutLifetime(t, provider)
33+
runLockTestsWithLifetime(t, provider)
2334
}
2435

2536
func TestPostgreSQLProvider(t *testing.T) {
@@ -31,10 +42,12 @@ func TestPostgreSQLProvider(t *testing.T) {
3142
t.Fatal(err)
3243
}
3344

34-
provider, err := NewPostgreSQLProvider(db, "distlocks")
45+
cleanupPostgreSQL(db)
46+
47+
provider, err := NewPostgreSQLProvider(db, TestTableName)
3548
if err != nil {
3649
t.Fatalf("could not create provider: %s", err)
3750
}
38-
39-
runBasicLockTests(t, provider)
51+
runLockTestsWithoutLifetime(t, provider)
52+
runLockTestsWithLifetime(t, provider)
4053
}

0 commit comments

Comments
 (0)