Skip to content

Commit a52ff03

Browse files
authored
fix: Add thread safeties to subscriptions list (#247)
1 parent b60b5c6 commit a52ff03

File tree

4 files changed

+197
-38
lines changed

4 files changed

+197
-38
lines changed

Source/HiveMQtt/Client/HiveMQClient.cs

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ public HiveMQClient(HiveMQClientOptions? options = null)
6666
/// <inheritdoc />
6767
public List<Subscription> Subscriptions { get; } = new();
6868

69+
private SemaphoreSlim SubscriptionsSemaphore { get; } = new(1, 1);
70+
6971
/// <inheritdoc />
7072
public bool IsConnected() => this.Connection.State == ConnectState.Connected;
7173

@@ -479,8 +481,16 @@ void TaskHandler(object? sender, OnSubAckReceivedEventArgs args)
479481
subscription.MessageReceivedHandler = handler.Value;
480482
}
481483
}
484+
}
482485

483-
this.Subscriptions.Add(subscription);
486+
try
487+
{
488+
await this.SubscriptionsSemaphore.WaitAsync().ConfigureAwait(false);
489+
this.Subscriptions.AddRange(subscribeResult.Subscriptions);
490+
}
491+
finally
492+
{
493+
_ = this.SubscriptionsSemaphore.Release();
484494
}
485495

486496
// Fire the corresponding event
@@ -508,9 +518,17 @@ public async Task<UnsubscribeResult> UnsubscribeAsync(string topic)
508518
/// <inheritdoc />
509519
public async Task<UnsubscribeResult> UnsubscribeAsync(Subscription subscription)
510520
{
511-
if (!this.Subscriptions.Contains(subscription))
521+
try
512522
{
513-
throw new HiveMQttClientException("No such subscription found. Make sure to take subscription(s) from HiveMQClient.Subscriptions[] or HiveMQClient.GetSubscriptionByTopic().");
523+
await this.SubscriptionsSemaphore.WaitAsync().ConfigureAwait(false);
524+
if (!this.Subscriptions.Contains(subscription))
525+
{
526+
throw new HiveMQttClientException("No such subscription found. Make sure to take subscription(s) from HiveMQClient.Subscriptions[] or HiveMQClient.GetSubscriptionByTopic().");
527+
}
528+
}
529+
finally
530+
{
531+
_ = this.SubscriptionsSemaphore.Release();
514532
}
515533

516534
var unsubOptions = new UnsubscribeOptionsBuilder()
@@ -523,11 +541,22 @@ public async Task<UnsubscribeResult> UnsubscribeAsync(Subscription subscription)
523541
/// <inheritdoc />
524542
public async Task<UnsubscribeResult> UnsubscribeAsync(List<Subscription> subscriptions)
525543
{
526-
for (var i = 0; i < subscriptions.Count; i++)
544+
HashSet<Subscription> currentSubscriptions;
545+
try
527546
{
528-
if (!this.Subscriptions.Contains(subscriptions[i]))
547+
await this.SubscriptionsSemaphore.WaitAsync().ConfigureAwait(false);
548+
currentSubscriptions = this.Subscriptions.ToHashSet();
549+
}
550+
finally
551+
{
552+
_ = this.SubscriptionsSemaphore.Release();
553+
}
554+
555+
foreach (var sub in subscriptions)
556+
{
557+
if (!currentSubscriptions.Contains(sub))
529558
{
530-
throw new HiveMQttClientException("No such subscription found. Make sure to take subscription(s) from HiveMQClient.Subscriptions[] or HiveMQClient.GetSubscriptionByTopic().");
559+
throw new HiveMQttClientException("No such subscription found. Make sure to take subscription(s) from HiveMQClient.Subscriptions[] or HiveMQClient.GetSubscriptionByTopic().");
531560
}
532561
}
533562

@@ -589,13 +618,28 @@ void TaskHandler(object? sender, OnUnsubAckReceivedEventArgs args)
589618
};
590619

591620
var counter = 0;
621+
var subscriptionsToRemove = new List<Subscription>();
592622
foreach (var reasonCode in unsubAck.ReasonCodes)
593623
{
594624
unsubscribeResult.Subscriptions[counter].UnsubscribeReasonCode = reasonCode;
595625
if (reasonCode == UnsubAckReasonCode.Success)
596626
{
597-
// Remove the subscription from the client
598-
this.Subscriptions.Remove(unsubscribeResult.Subscriptions[counter]);
627+
// Collect subscriptions which need to be removed
628+
subscriptionsToRemove.Add(unsubscribeResult.Subscriptions[counter]);
629+
}
630+
}
631+
632+
if (subscriptionsToRemove.Count != 0)
633+
{
634+
try
635+
{
636+
// remove subscriptions from client while locking them
637+
await this.SubscriptionsSemaphore.WaitAsync().ConfigureAwait(false);
638+
_ = this.Subscriptions.RemoveAll(subscriptionsToRemove.Contains);
639+
}
640+
finally
641+
{
642+
_ = this.SubscriptionsSemaphore.Release();
599643
}
600644
}
601645

Source/HiveMQtt/Client/HiveMQClientEvents.cs

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -266,26 +266,46 @@ internal virtual void OnMessageReceivedEventLauncher(PublishPacket packet)
266266
messageHandled = true;
267267
}
268268

269+
if (packet.Message.Topic is null)
270+
{
271+
return;
272+
}
273+
269274
// Per Subscription Event Handler
270-
foreach (var subscription in this.Subscriptions)
275+
// use ToList, so the iteration goes through a copy and changes at the list make not problems
276+
// otherwise it would be necessary to lock the Subscriptions with the semaphore of HiveMQClient
277+
List<Subscription> tempList;
278+
try
279+
{
280+
this.SubscriptionsSemaphore.Wait();
281+
tempList = this.Subscriptions.ToList();
282+
}
283+
finally
284+
{
285+
_ = this.SubscriptionsSemaphore.Release();
286+
}
287+
288+
var matchingSubscriptions = tempList.Where(sub =>
289+
sub.MessageReceivedHandler is not null &&
290+
MatchTopic(sub.TopicFilter.Topic, packet.Message.Topic));
291+
292+
foreach (var subscription in matchingSubscriptions)
271293
{
272-
if (packet.Message.Topic != null && MatchTopic(subscription.TopicFilter.Topic, packet.Message.Topic))
294+
// We have a per-subscription message handler.
295+
_ = Task.Run(() =>
273296
{
274-
if (subscription.MessageReceivedHandler != null && subscription.MessageReceivedHandler.GetInvocationList().Length > 0)
297+
try
275298
{
276-
// We have a per-subscription message handler.
277-
_ = Task.Run(() => subscription.MessageReceivedHandler?.Invoke(this, eventArgs)).ContinueWith(
278-
t =>
279-
{
280-
if (t.IsFaulted)
281-
{
282-
Logger.Error($"per-subscription MessageReceivedEventLauncher faulted ({packet.Message.Topic}): {t.Exception?.Message}");
283-
}
284-
}, TaskScheduler.Default);
285-
286-
messageHandled = true;
299+
subscription.MessageReceivedHandler?.Invoke(this, eventArgs);
287300
}
288-
}
301+
catch (Exception e)
302+
{
303+
Logger.Error(
304+
$"per-subscription MessageReceivedEventLauncher faulted ({packet.Message.Topic}): {e.Message}");
305+
}
306+
});
307+
308+
messageHandled = true;
289309
}
290310

291311
if (!messageHandled)

Source/HiveMQtt/Client/HiveMQClientUtil.cs

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,18 @@ public partial class HiveMQClient : IDisposable, IHiveMQClient
3131
/// <returns>A boolean indicating whether the subscription exists.</returns>
3232
internal bool SubscriptionExists(Subscription subscription)
3333
{
34-
if (this.Subscriptions.Contains(subscription))
34+
List<Subscription> tempList;
35+
try
3536
{
36-
return true;
37+
this.SubscriptionsSemaphore.Wait();
38+
tempList = this.Subscriptions.ToList();
3739
}
38-
39-
foreach (var candidate in this.Subscriptions)
40+
finally
4041
{
41-
if (candidate.TopicFilter.Topic == subscription.TopicFilter.Topic)
42-
{
43-
return true;
44-
}
42+
_ = this.SubscriptionsSemaphore.Release();
4543
}
4644

47-
return false;
45+
return tempList.Any(s => s.TopicFilter.Topic == subscription.TopicFilter.Topic);
4846
}
4947

5048
/// <summary>
@@ -54,15 +52,18 @@ internal bool SubscriptionExists(Subscription subscription)
5452
/// <returns>The subscription or null if not found.</returns>
5553
internal Subscription? GetSubscriptionByTopic(string topic)
5654
{
57-
foreach (var subscription in this.Subscriptions)
55+
List<Subscription> tempList;
56+
try
5857
{
59-
if (subscription.TopicFilter.Topic == topic)
60-
{
61-
return subscription;
62-
}
58+
this.SubscriptionsSemaphore.Wait();
59+
tempList = this.Subscriptions.ToList();
60+
}
61+
finally
62+
{
63+
_ = this.SubscriptionsSemaphore.Release();
6364
}
6465

65-
return null;
66+
return tempList.FirstOrDefault(s => s.TopicFilter.Topic == topic);
6667
}
6768

6869
/// <summary>
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
namespace HiveMQtt.Test.HiveMQClient;
2+
3+
using System.Collections.Concurrent;
4+
using Client;
5+
using Client.Options;
6+
using MQTT5.ReasonCodes;
7+
using MQTT5.Types;
8+
using Xunit;
9+
using Xunit.Abstractions;
10+
11+
public class ThreadSafeSubscribeUnsubscribeTest
12+
{
13+
private readonly ITestOutputHelper testOutputHelper;
14+
15+
public ThreadSafeSubscribeUnsubscribeTest(ITestOutputHelper testOutputHelper)
16+
{
17+
this.testOutputHelper = testOutputHelper;
18+
}
19+
20+
[Fact]
21+
public async Task SubscribeUnsubscribe_InManyThreadsAsync()
22+
{
23+
const int workerCount = 100;
24+
const int iterationsPerWorker = 100;
25+
const int topicsPerIteration = 10;
26+
const int publishesPerIteration = 10;
27+
const int totalExpectedSuccesses = workerCount * iterationsPerWorker;
28+
29+
var options = new HiveMQClientOptionsBuilder().WithClientId("ConcurrentSubscribeUnsubscribeAndPublish").Build();
30+
options.ResponseTimeoutInMs = 20000;
31+
var client = new HiveMQClient(options);
32+
Assert.NotNull(client);
33+
34+
var connectResult = await client.ConnectAsync().ConfigureAwait(false);
35+
Assert.True(connectResult.ReasonCode == ConnAckReasonCode.Success);
36+
Assert.True(client.IsConnected());
37+
38+
client.OnMessageReceived += (_, args) => { };
39+
_ = await client.SubscribeAsync("/test/#").ConfigureAwait(false);
40+
41+
var exceptionMessages = new ConcurrentBag<string>();
42+
var successCount = 0;
43+
var tasks = new List<Task>();
44+
45+
foreach (var workerId in Enumerable.Range(0, workerCount))
46+
{
47+
tasks.Add(Task.Run(async () =>
48+
{
49+
for (var i = 0; i < iterationsPerWorker; i++)
50+
{
51+
var topicPrefix = $"/test/topic/{workerId}/{i}";
52+
53+
var topicsToManage = Enumerable.Range(0, topicsPerIteration)
54+
.Select(j => $"{topicPrefix}/{(char)('a' + j)}")
55+
.ToList();
56+
57+
try
58+
{
59+
var topicFilters = topicsToManage.Select(topic => new TopicFilter(topic, QualityOfService.ExactlyOnceDelivery)).ToList();
60+
var subscribeOptions = new SubscribeOptions { TopicFilters = topicFilters };
61+
_ = await client.SubscribeAsync(subscribeOptions).ConfigureAwait(false);
62+
63+
var publishTasks = new List<Task>(publishesPerIteration * 3);
64+
for (var j = 0; j < publishesPerIteration; j++)
65+
{
66+
publishTasks.Add(client.PublishAsync(topicsToManage.First(), "Hello World"));
67+
publishTasks.Add(client.PublishAsync(topicsToManage.Last(), "Hello World"));
68+
publishTasks.Add(client.PublishAsync("/unknown/topic", "Hello World"));
69+
}
70+
71+
await Task.WhenAll(publishTasks).ConfigureAwait(false);
72+
73+
var subscriptions = topicsToManage.Select(topic => new Subscription(new TopicFilter(topic))).ToList();
74+
var unsubscribeOptions = new UnsubscribeOptions { Subscriptions = subscriptions };
75+
_ = await client.UnsubscribeAsync(unsubscribeOptions).ConfigureAwait(false);
76+
77+
_ = Interlocked.Increment(ref successCount);
78+
}
79+
catch (Exception e)
80+
{
81+
var errorMessage = $"Worker {workerId}, Iteration {i}: {e}";
82+
exceptionMessages.Add(errorMessage);
83+
this.testOutputHelper.WriteLine(errorMessage);
84+
}
85+
}
86+
}));
87+
}
88+
89+
await Task.WhenAll(tasks).ConfigureAwait(false);
90+
91+
Assert.Equal(totalExpectedSuccesses, successCount);
92+
Assert.Empty(exceptionMessages);
93+
}
94+
}

0 commit comments

Comments
 (0)