Skip to content

Commit 723b659

Browse files
feat(ServiceDiscovery): Only return reachable address records
1 parent cd9b00c commit 723b659

File tree

3 files changed

+197
-1
lines changed

3 files changed

+197
-1
lines changed

src/IPAddressExtensions.cs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Net;
5+
using System.Net.NetworkInformation;
6+
using System.Net.Sockets;
7+
using System.Text;
8+
9+
namespace Makaretu.Dns
10+
{
11+
/// <summary>
12+
/// Extensions for <see cref="IPAddress"/>.
13+
/// </summary>
14+
public static class IPAddressExtensions
15+
{
16+
/// <summary>
17+
/// Gets the subnet mask associated with the IP address.
18+
/// </summary>
19+
/// <param name="address">
20+
/// An IP Addresses.
21+
/// </param>
22+
/// <returns>
23+
/// The subnet mask; ror example "127.0.0.1" returns "255.0.0.0".
24+
/// Or <b>null</b> When <paramref name="address"/> does not belong to
25+
/// the localhost.
26+
/// s</returns>
27+
public static IPAddress GetSubnetMask(this IPAddress address)
28+
{
29+
return NetworkInterface.GetAllNetworkInterfaces()
30+
.SelectMany(nic => nic.GetIPProperties().UnicastAddresses)
31+
.Where(a => a.Address.Equals(address))
32+
.Select(a => a.IPv4Mask)
33+
.FirstOrDefault();
34+
}
35+
36+
/// <summary>
37+
/// Determines if the local IP address can be used by the
38+
/// remote address.
39+
/// </summary>
40+
/// <param name="local"></param>
41+
/// <param name="remote"></param>
42+
/// <returns>
43+
/// <b>true</b> if <paramref name="local"/> can be used by <paramref name="remote"/>;
44+
/// otherwise, <b>false</b>.
45+
/// </returns>
46+
public static bool IsReachable(this IPAddress local, IPAddress remote)
47+
{
48+
// Loopback addresses are only reachable when the remote is
49+
// the same host.
50+
if (local.Equals(IPAddress.Loopback) || local.Equals(IPAddress.IPv6Loopback))
51+
{
52+
return MulticastService.GetIPAddresses().Contains(remote);
53+
}
54+
55+
// IPv4 addresses are reachable when on the same subnet.
56+
if (local.AddressFamily == AddressFamily.InterNetwork && remote.AddressFamily == AddressFamily.InterNetwork)
57+
{
58+
var mask = local.GetSubnetMask();
59+
if (mask != null)
60+
{
61+
var network = IPNetwork.Parse(local, mask);
62+
return network.Contains(remote);
63+
}
64+
}
65+
66+
// IPv6 link local addresses are reachabe when using the same scope id.
67+
if (local.AddressFamily == AddressFamily.InterNetworkV6 && remote.AddressFamily == AddressFamily.InterNetworkV6)
68+
{
69+
if (local.IsIPv6LinkLocal || remote.IsIPv6LinkLocal)
70+
{
71+
return local.Equals(remote);
72+
}
73+
}
74+
75+
// Can not determine reachability, assume that network routing can do it.
76+
return true;
77+
}
78+
}
79+
}

src/ServiceDiscovery.cs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ void OnQuery(object sender, MessageEventArgs e)
143143
if (log.IsDebugEnabled)
144144
log.Debug($"got query for {request.Questions[0].Name} {request.Questions[0].Type}");
145145
var response = NameServer.ResolveAsync(request).Result;
146+
146147
if (response.Status == MessageStatus.NoError)
147148
{
148149
// Many bonjour browsers don't like DNS-SD response
@@ -152,14 +153,25 @@ void OnQuery(object sender, MessageEventArgs e)
152153
response.AdditionalRecords.Clear();
153154
}
154155

156+
// Only return address records that the querier can reach.
157+
response.Answers.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint));
158+
response.AuthorityRecords.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint));
159+
response.AdditionalRecords.RemoveAll(rr => IsUnreachable(rr, e.RemoteEndPoint));
160+
155161
Mdns.SendAnswer(response);
156162
if (log.IsDebugEnabled)
157163
log.Debug($"sent answer {response.Answers[0]}");
158164
//Console.WriteLine($"Response time {(DateTime.Now - request.CreationTime).TotalMilliseconds}ms");
159165
}
160166
}
161167

162-
#region IDisposable Support
168+
bool IsUnreachable(ResourceRecord rr, IPEndPoint sender)
169+
{
170+
var arecord = rr as AddressRecord;
171+
return !arecord?.Address.IsReachable(sender.Address) ?? false;
172+
}
173+
174+
#region IDisposable Support
163175

164176
/// <inheritdoc />
165177
protected virtual void Dispose(bool disposing)

test/IPAddressExtensionsTest.cs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
using Makaretu.Dns;
2+
using Microsoft.VisualStudio.TestTools.UnitTesting;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using System.Net;
7+
using System.Net.NetworkInformation;
8+
using System.Net.Sockets;
9+
10+
namespace Makaretu.Dns
11+
{
12+
13+
[TestClass]
14+
public class IPAddressExtensionsTest
15+
{
16+
[TestMethod]
17+
public void SubnetMask_Ipv4Loopback()
18+
{
19+
var mask = IPAddress.Loopback.GetSubnetMask();
20+
Assert.AreEqual(IPAddress.Parse("255.0.0.0"), mask);
21+
var network = IPNetwork.Parse(IPAddress.Loopback, mask);
22+
Console.Write(network.ToString());
23+
}
24+
25+
[TestMethod]
26+
public void SubnetMask_Ipv6Loopback()
27+
{
28+
var mask = IPAddress.IPv6Loopback.GetSubnetMask();
29+
Assert.AreEqual(IPAddress.Parse("0.0.0.0"), mask);
30+
}
31+
32+
[TestMethod]
33+
public void SubmetMask_NotLocalhost()
34+
{
35+
var mask = IPAddress.Parse("1.1.1.1").GetSubnetMask();
36+
Assert.IsNull(mask);
37+
}
38+
39+
[TestMethod]
40+
public void SubnetMask_All()
41+
{
42+
foreach (var a in MulticastService.GetIPAddresses())
43+
{
44+
var network = IPNetwork.Parse(a, a.GetSubnetMask());
45+
46+
Console.WriteLine($"{a} mask {a.GetSubnetMask()} {network}");
47+
48+
Assert.IsTrue(network.Contains(a), $"{a} is not reachable");
49+
}
50+
}
51+
52+
[TestMethod]
53+
public void LinkLocal()
54+
{
55+
foreach (var a in MulticastService.GetIPAddresses())
56+
{
57+
Console.WriteLine($"{a} ll={a.IsIPv6LinkLocal} ss={a.IsIPv6SiteLocal}");
58+
}
59+
}
60+
61+
[TestMethod]
62+
public void Reachable_Loopback_From_Localhost()
63+
{
64+
var me = IPAddress.Loopback;
65+
foreach (var a in MulticastService.GetIPAddresses())
66+
{
67+
Assert.IsTrue(me.IsReachable(a), $"{a}");
68+
}
69+
Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1")));
70+
Assert.IsFalse(me.IsReachable(IPAddress.Parse("2606:4700:4700::1111")));
71+
72+
me = IPAddress.IPv6Loopback;
73+
foreach (var a in MulticastService.GetIPAddresses())
74+
{
75+
Assert.IsTrue(me.IsReachable(a), $"{a}");
76+
}
77+
Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1")));
78+
Assert.IsFalse(me.IsReachable(IPAddress.Parse("2606:4700:4700::1111")));
79+
}
80+
81+
[TestMethod]
82+
public void Reachable_Ipv4()
83+
{
84+
var me = MulticastService.GetIPAddresses()
85+
.First(a => a.AddressFamily == AddressFamily.InterNetwork && !IPAddress.IsLoopback(a));
86+
Assert.IsTrue(me.IsReachable(me));
87+
Assert.IsFalse(me.IsReachable(IPAddress.Parse("1.1.1.1")));
88+
89+
var nat = IPAddress.Parse("165.84.19.151"); // NAT PCP assigned address
90+
Assert.IsTrue(nat.IsReachable(IPAddress.Parse("1.1.1.1")));
91+
}
92+
93+
[TestMethod]
94+
public void Reachable_Ipv6_LinkLocal()
95+
{
96+
var me1 = IPAddress.Parse("fe80::1:2:3:4%1");
97+
var me2 = IPAddress.Parse("fe80::1:2:3:4%2");
98+
var me5 = IPAddress.Parse("fe80::1:2:3:5%1");
99+
Assert.IsTrue(me1.IsReachable(me1));
100+
Assert.IsTrue(me2.IsReachable(me2));
101+
Assert.IsFalse(me1.IsReachable(me2));
102+
Assert.IsFalse(me1.IsReachable(me5));
103+
}
104+
}
105+
}

0 commit comments

Comments
 (0)