Skip to content

Commit 40e79b3

Browse files
committed
Fixes for function memoization
1 parent f18baf6 commit 40e79b3

File tree

3 files changed

+183
-40
lines changed

3 files changed

+183
-40
lines changed

src/main/java/net/tascalate/concurrent/core/FunctionMemoization.java

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
/**
2+
* Copyright 2015-2024 Valery Silaev (http://vsilaev.com)
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
116
package net.tascalate.concurrent.core;
217

318
import java.lang.ref.Reference;
@@ -7,7 +22,7 @@
722
import java.util.function.Function;
823

924
class FunctionMemoization<K, V> implements Function<K, V> {
10-
private final ConcurrentMap<K, Object> producerMutexes = new ConcurrentHashMap<>();
25+
private final KeyedLocks<K> producerMutexes = new KeyedLocks<>();
1126
private final ConcurrentMap<Object, Object> valueMap = new ConcurrentHashMap<>();
1227

1328
private final Function<? super K, ? extends V> fn;
@@ -45,50 +60,36 @@ public V apply(K key) {
4560
}
4661
}
4762

48-
Object mutex = getOrCreateMutex(key);
49-
synchronized (mutex) {
50-
try {
51-
// Double-check after getting mutex
52-
valueRef = valueMap.get(lookupKeyRef);
53-
value = valueRef == null ? null : valueRefType.dereference(valueRef);
54-
if (value == null) {
55-
value = fn.apply(key);
56-
valueMap.put(
57-
keyRefType.createKeyReference(key, queue),
58-
valueRefType.createValueReference(value)
59-
);
60-
}
61-
} finally {
62-
producerMutexes.remove(key, mutex);
63+
try (KeyedLocks.Lock lock = producerMutexes.acquire(key)) {
64+
// Double-check after getting mutex
65+
valueRef = valueMap.get(lookupKeyRef);
66+
value = valueRef == null ? null : valueRefType.dereference(valueRef);
67+
if (value == null) {
68+
value = fn.apply(key);
69+
valueMap.put(
70+
keyRefType.createKeyReference(key, queue),
71+
valueRefType.createValueReference(value)
72+
);
6373
}
74+
} catch (InterruptedException ex) {
75+
throw new RuntimeException(ex);
6476
}
65-
6677
return value;
6778
}
6879

6980
public V forget(K key) {
70-
Object mutex = getOrCreateMutex(key);
71-
synchronized (mutex) {
72-
try {
73-
Object valueRef = valueMap.remove(keyRefType.createLookupKey(key));
74-
return valueRef == null ? null : valueRefType.dereference(valueRef);
75-
} finally {
76-
producerMutexes.remove(key, mutex);
77-
}
78-
}
79-
}
80-
81-
private Object getOrCreateMutex(K key) {
82-
Object createdMutex = new byte[0];
83-
Object existingMutex = producerMutexes.putIfAbsent(key, createdMutex);
84-
if (existingMutex != null) {
85-
return existingMutex;
86-
} else {
87-
return createdMutex;
81+
try (KeyedLocks.Lock lock = producerMutexes.acquire(key)) {
82+
Object valueRef = valueMap.remove(keyRefType.createLookupKey(key));
83+
return valueRef == null ? null : valueRefType.dereference(valueRef);
84+
} catch (InterruptedException ex) {
85+
throw new RuntimeException(ex);
8886
}
8987
}
9088

9189
private void expungeStaleEntries() {
90+
if (null == queue) {
91+
return;
92+
}
9293
for (Reference<? extends K> ref; (ref = queue.poll()) != null;) {
9394
@SuppressWarnings("unchecked")
9495
Reference<K> keyRef = (Reference<K>) ref;
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/**
2+
* Copyright 2015-2024 Valery Silaev (http://vsilaev.com)
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package net.tascalate.concurrent.core;
17+
18+
import java.util.concurrent.ConcurrentHashMap;
19+
import java.util.concurrent.ConcurrentMap;
20+
import java.util.concurrent.CountDownLatch;
21+
22+
final class KeyedLocks<K> {
23+
private final ConcurrentMap<K, Lock> locksByKey = new ConcurrentHashMap<>();
24+
25+
public Lock acquire(K key) throws InterruptedException {
26+
@SuppressWarnings("resource")
27+
Lock ourLock = Lock.acquire(() -> locksByKey.remove(key));
28+
while (true) {
29+
Lock theirLock = locksByKey.putIfAbsent(key, ourLock);
30+
if (theirLock == null) {
31+
// No other locks, we are the owner
32+
return ourLock;
33+
}
34+
if (theirLock.tryAcquire(false)) {
35+
// Reentrant call
36+
return theirLock;
37+
}
38+
// Wait for other lock release and re-try
39+
theirLock.await();
40+
}
41+
}
42+
43+
final static class Lock implements AutoCloseable {
44+
private final CountDownLatch mutex;
45+
private final long threadId;
46+
private final Runnable cleanup;
47+
private int lockedCount = 1;
48+
49+
private Lock(long threadId, Runnable cleanup) {
50+
this.threadId = threadId;
51+
this.cleanup = cleanup;
52+
this.mutex = new CountDownLatch(1);
53+
}
54+
55+
static Lock acquire(Runnable cleanup) {
56+
return new Lock(currentThreadId(), cleanup);
57+
}
58+
59+
boolean sameThread(long currentThreadId, boolean throwError) {
60+
if (currentThreadId != threadId) {
61+
if (throwError) {
62+
return invalidThreadContext("The lock modified from the thread " + currentThreadId + " but was accuried in the thread " + threadId);
63+
} else {
64+
return false;
65+
}
66+
} else {
67+
return true;
68+
}
69+
70+
}
71+
72+
void await() throws InterruptedException {
73+
mutex.await();
74+
}
75+
76+
boolean tryAcquire(boolean throwError) {
77+
long currentThreadId = currentThreadId();
78+
if (threadId != currentThreadId) {
79+
if (throwError) {
80+
return invalidThreadContext("Trying to re-acquire lock from the thread " + currentThreadId + " but it was accuried in the thread " + threadId);
81+
} else {
82+
return false;
83+
}
84+
}
85+
lockedCount++;
86+
return true;
87+
}
88+
89+
boolean tryRelease(boolean throwError) {
90+
long currentThreadId = currentThreadId();
91+
if (threadId != currentThreadId) {
92+
if (throwError) {
93+
return invalidThreadContext("Trying to release lock from the thread " + currentThreadId + " but it was accuried in the thread " + threadId);
94+
} else {
95+
return false;
96+
}
97+
}
98+
if (lockedCount < 1) {
99+
return false;
100+
} else if (--lockedCount == 0) {
101+
cleanup.run();
102+
mutex.countDown();
103+
return true;
104+
} else {
105+
return true;
106+
}
107+
}
108+
109+
public boolean release() {
110+
return tryRelease(false);
111+
}
112+
113+
@Override
114+
public void close() {
115+
tryRelease(true);
116+
}
117+
118+
private static long currentThreadId() {
119+
return Thread.currentThread().getId();
120+
}
121+
122+
private static boolean invalidThreadContext(String message) {
123+
throw new IllegalStateException(message);
124+
}
125+
}
126+
}

src/main/java/net/tascalate/concurrent/core/ReferenceType.java

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,25 @@
1+
/**
2+
* Copyright 2015-2024 Valery Silaev (http://vsilaev.com)
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
116
package net.tascalate.concurrent.core;
217

318
import java.lang.ref.Reference;
419
import java.lang.ref.ReferenceQueue;
520
import java.lang.ref.SoftReference;
621
import java.lang.ref.WeakReference;
22+
import java.util.Objects;
723

824
enum ReferenceType {
925
HARD() {
@@ -84,9 +100,9 @@ static final class LookupKey<K> {
84100
@Override
85101
public boolean equals(Object other) {
86102
if (other instanceof LookupKey<?>) {
87-
return ((LookupKey<?>) other).key == key;
103+
return key.equals(((LookupKey<?>)other).key);
88104
} else {
89-
return ((Reference<?>) other).get() == key;
105+
return key.equals(((Reference<?>)other).get());
90106
}
91107
}
92108

@@ -113,7 +129,7 @@ public int hashCode() {
113129
@Override
114130
public boolean equals(Object other) {
115131
if (other instanceof WeakKey<?>) {
116-
return ((WeakKey<?>) other).get() == get();
132+
return Objects.equals(((WeakKey<?>) other).get(), get());
117133
} else {
118134
return other.equals(this);
119135
}
@@ -142,7 +158,7 @@ public int hashCode() {
142158
@Override
143159
public boolean equals(Object other) {
144160
if (other instanceof SoftKey<?>) {
145-
return ((SoftKey<?>) other).get() == get();
161+
return Objects.equals(((SoftKey<?>) other).get(), get());
146162
} else {
147163
return other.equals(this);
148164
}
@@ -153,4 +169,4 @@ public String toString() {
153169
return String.valueOf(get());
154170
}
155171
}
156-
}
172+
}

0 commit comments

Comments
 (0)