Skip to content

Commit 0668996

Browse files
yongjiajunnabekens-hennge
committed
refactor: Make SPF2IPResolver thread-safe
Co-Authored-By: Tanabe Ken-ichi <nabeken@tknetworks.org> Co-Authored-By: s-hennge <156737143+s-hennge@users.noreply.github.com>
1 parent cda9a9b commit 0668996

File tree

2 files changed

+40
-42
lines changed

2 files changed

+40
-42
lines changed

spf2ip.go

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,8 @@ var (
2525
)
2626

2727
type SPF2IPResolver struct {
28-
ipVersion int
2928
netResolver NetResolver
3029
debugLogging bool
31-
32-
// Map of visited domains in the current path for loop detection: domain -> struct{}
33-
domainsVisitedInCurrentPath map[string]struct{}
34-
// Map of resolved IPs for each domain: domain -> IPs -> struct{}
35-
resolvedIPsCache map[string]map[string]struct{}
3630
}
3731

3832
//go:generate mockgen -package spf2ip -source spf2ip.go -destination netresolver_mock.go
@@ -54,11 +48,9 @@ func (r *SPF2IPResolver) Resolve(ctx context.Context, domain string, ipVersion i
5448
return nil, fmt.Errorf("%w: %d", ErrInvalidIPVersion, ipVersion)
5549
}
5650

57-
r.ipVersion = ipVersion
58-
r.domainsVisitedInCurrentPath = make(map[string]struct{})
59-
r.resolvedIPsCache = make(map[string]map[string]struct{})
60-
61-
finalIPs, err := r.processDomain(ctx, domain, 0)
51+
finalIPs, err := r.processDomain(
52+
ctx, ipVersion, make(map[string]struct{}), make(map[string]map[string]struct{}), domain, 0,
53+
)
6254
if err != nil {
6355
return nil, err
6456
}
@@ -77,27 +69,34 @@ func (r *SPF2IPResolver) Resolve(ctx context.Context, domain string, ipVersion i
7769
return result, nil
7870
}
7971

80-
func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth int) (map[string]struct{}, error) {
72+
func (r *SPF2IPResolver) processDomain(
73+
ctx context.Context,
74+
ipVersion int,
75+
domainsVisitedInCurrentPath map[string]struct{},
76+
resolvedIPsCache map[string]map[string]struct{},
77+
domain string,
78+
depth int,
79+
) (map[string]struct{}, error) {
8180
if depth > maxSPFIncludeDepth {
8281
return nil, fmt.Errorf("%w: %s (depth %d)", ErrExceededMaxDepth, domain, depth)
8382
}
8483

8584
domain = strings.ToLower(strings.TrimSpace(domain))
8685

8786
// Check if this domain's SPF is already resolved and cached.
88-
if cachedIPs, found := r.resolvedIPsCache[domain]; found {
87+
if cachedIPs, found := resolvedIPsCache[domain]; found {
8988
r.debugLogPrintf("Debug: Using cached result for domain: %s", domain)
9089
return deepCopyMap(cachedIPs), nil
9190
}
9291

9392
// Check for loops in the current resolution path.
94-
if _, visited := r.domainsVisitedInCurrentPath[domain]; visited {
93+
if _, visited := domainsVisitedInCurrentPath[domain]; visited {
9594
r.debugLogPrintf("Debug: Loop detected for domain %s", domain)
9695
return nil, fmt.Errorf("%w: %s", ErrLoopDetected, domain)
9796
}
9897

99-
r.domainsVisitedInCurrentPath[domain] = struct{}{}
100-
defer delete(r.domainsVisitedInCurrentPath, domain)
98+
domainsVisitedInCurrentPath[domain] = struct{}{}
99+
defer delete(domainsVisitedInCurrentPath, domain)
101100

102101
r.debugLogPrintf("Debug: Processing domain: %s (depth %d)", domain, depth)
103102

@@ -106,14 +105,14 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
106105
spfString, err := r.getSPFRecord(ctx, domain)
107106
if err != nil && !errors.Is(err, errIgnorableDNSErr) {
108107
r.debugLogPrintf("Debug: Failed to get SPF record for %s: %v", domain, err)
109-
r.resolvedIPsCache[domain] = nil
108+
resolvedIPsCache[domain] = nil
110109

111110
return nil, fmt.Errorf("spf2ip: failed to get SPF record for %s: %w", domain, err)
112111
}
113112

114113
if spfString == "" {
115114
r.debugLogPrintf("Debug: No SPF record found for %s, treating as empty", domain)
116-
r.resolvedIPsCache[domain] = currentDomainIPs
115+
resolvedIPsCache[domain] = currentDomainIPs
117116

118117
return currentDomainIPs, nil
119118
}
@@ -140,23 +139,23 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
140139

141140
switch mechanism {
142141
case "ip4":
143-
if r.ipVersion == ipv4 {
144-
if err := r.addIPOrCIDRToSet(value, currentDomainIPs); err != nil {
142+
if ipVersion == ipv4 {
143+
if err := r.addIPOrCIDRToSet(ipVersion, value, currentDomainIPs); err != nil {
145144
return nil, fmt.Errorf("spf2ip: failed to add IP/CIDR for ip4 mechanism in %s: %w", domain, err)
146145
}
147146
}
148147

149148
case "ip6":
150-
if r.ipVersion == ipv6 {
151-
if err := r.addIPOrCIDRToSet(value, currentDomainIPs); err != nil {
149+
if ipVersion == ipv6 {
150+
if err := r.addIPOrCIDRToSet(ipVersion, value, currentDomainIPs); err != nil {
152151
return nil, fmt.Errorf("spf2ip: failed to add IP/CIDR for ip6 mechanism in %s: %w", domain, err)
153152
}
154153
}
155154

156155
case "a":
157156
targetHost, maskSuffix := parseSPFMechanismTargetAndMask(domain, value)
158157

159-
ips, err := r.netResolver.LookupIP(ctx, r.lookupIPNetwork(), targetHost)
158+
ips, err := r.netResolver.LookupIP(ctx, lookupIPNetwork(ipVersion), targetHost)
160159
if err != nil {
161160
if isDNSErrIgnorable(err) {
162161
r.debugLogPrintf("Debug: Ignorable DNS error for A/AAAA lookup of %s (directive in %s): %v", targetHost, domain, err)
@@ -167,7 +166,7 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
167166
}
168167

169168
for _, ip := range ips {
170-
if err := r.addIPOrCIDRToSet(ip.String()+maskSuffix, currentDomainIPs); err != nil {
169+
if err := r.addIPOrCIDRToSet(ipVersion, ip.String()+maskSuffix, currentDomainIPs); err != nil {
171170
return nil, fmt.Errorf("spf2ip: failed to add IP/CIDR for A mechanism in %s: %w", domain, err)
172171
}
173172
}
@@ -188,7 +187,7 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
188187
for _, mx := range mxs {
189188
mxHost := strings.TrimSuffix(mx.Host, ".")
190189

191-
ips, err := r.netResolver.LookupIP(ctx, r.lookupIPNetwork(), mxHost)
190+
ips, err := r.netResolver.LookupIP(ctx, lookupIPNetwork(ipVersion), mxHost)
192191
if err != nil {
193192
if !isDNSErrIgnorable(err) {
194193
return nil, fmt.Errorf("A/AAAA lookup failed for MX host %s (directive in %s): %w", mxHost, domain, err)
@@ -200,7 +199,7 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
200199
}
201200

202201
for _, ip := range ips {
203-
if err := r.addIPOrCIDRToSet(ip.String()+maskSuffix, currentDomainIPs); err != nil {
202+
if err := r.addIPOrCIDRToSet(ipVersion, ip.String()+maskSuffix, currentDomainIPs); err != nil {
204203
return nil, fmt.Errorf("spf2ip: failed to add IP/CIDR for MX mechanism in %s: %w", domain, err)
205204
}
206205
}
@@ -209,12 +208,12 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
209208
case "include":
210209
if value == "" {
211210
r.debugLogPrintf("Debug: 'include' modifier without domain in %s", domain)
212-
r.resolvedIPsCache[domain] = nil
211+
resolvedIPsCache[domain] = nil
213212

214213
return nil, fmt.Errorf("spf2ip: include without domain in %s", domain)
215214
}
216215

217-
includedIPs, includeErr := r.processDomain(ctx, value, depth+1)
216+
includedIPs, includeErr := r.processDomain(ctx, ipVersion, domainsVisitedInCurrentPath, resolvedIPsCache, value, depth+1)
218217
if includeErr != nil {
219218
return nil, fmt.Errorf("spf2ip: include failed for %s (directive in %s): %w", value, domain, includeErr)
220219
}
@@ -226,16 +225,16 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
226225
case "redirect":
227226
if value == "" {
228227
r.debugLogPrintf("Debug: 'redirect' modifier without domain in %s", domain)
229-
r.resolvedIPsCache[domain] = nil
228+
resolvedIPsCache[domain] = nil
230229

231230
return nil, fmt.Errorf("spf2ip: redirect without domain in %s", domain)
232231
}
233232

234233
r.debugLogPrintf("Debug: Redirecting from %s to %s. Discarding IPs found so far for %s.", domain, value, domain)
235234

236235
// The result of this domain's processing is now entirely determined by the redirect target.
237-
redirectedIPs, redirectErr := r.processDomain(ctx, value, depth+1)
238-
r.resolvedIPsCache[domain] = deepCopyMap(redirectedIPs) // Overwrite cache with redirected IPs
236+
redirectedIPs, redirectErr := r.processDomain(ctx, ipVersion, domainsVisitedInCurrentPath, resolvedIPsCache, value, depth+1)
237+
resolvedIPsCache[domain] = deepCopyMap(redirectedIPs) // Overwrite cache with redirected IPs
239238

240239
return redirectedIPs, redirectErr
241240

@@ -248,14 +247,14 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
248247
}
249248
}
250249

251-
r.resolvedIPsCache[domain] = deepCopyMap(currentDomainIPs)
250+
resolvedIPsCache[domain] = deepCopyMap(currentDomainIPs)
252251

253252
return currentDomainIPs, nil
254253
}
255254

256255
// lookupIPNetwork returns the appropriate network type for IP lookups based on the resolver's IP version.
257-
func (r *SPF2IPResolver) lookupIPNetwork() string {
258-
switch r.ipVersion {
256+
func lookupIPNetwork(ipVersion int) string {
257+
switch ipVersion {
259258
case ipv4:
260259
return "ip4"
261260
case ipv6:
@@ -319,35 +318,35 @@ func (r *SPF2IPResolver) getSPFRecord(ctx context.Context, domain string) (strin
319318
}
320319

321320
// addIPOrCIDRToSet adds an IP or CIDR block to the target set, ensuring that plain IPs are converted to CIDR notation.
322-
func (r *SPF2IPResolver) addIPOrCIDRToSet(value string, targetSet map[string]struct{}) error {
321+
func (r *SPF2IPResolver) addIPOrCIDRToSet(ipVersion int, value string, targetSet map[string]struct{}) error {
323322
value = strings.TrimSpace(value)
324323

325324
// Try CIDR first
326325
if ip, ipNet, err := net.ParseCIDR(value); err == nil {
327-
if (r.ipVersion == ipv4 && ip.To4() != nil) || (r.ipVersion == ipv6 && ip.To4() == nil && ip.To16() != nil) {
326+
if (ipVersion == ipv4 && ip.To4() != nil) || (ipVersion == ipv6 && ip.To4() == nil && ip.To16() != nil) {
328327
targetSet[ipNet.String()] = struct{}{}
329328
return nil
330329
}
331330

332-
return fmt.Errorf("spf2ip: CIDR '%s' is not of the required IP version (v%d)", value, r.ipVersion)
331+
return fmt.Errorf("spf2ip: CIDR '%s' is not of the required IP version (v%d)", value, ipVersion)
333332
}
334333

335334
// Try plain IP
336335
ip := net.ParseIP(value)
337336
if ip != nil {
338-
if r.ipVersion == ipv4 && ip.To4() != nil {
337+
if ipVersion == ipv4 && ip.To4() != nil {
339338
// Convert IPv4 to CIDR notation
340339
targetSet[ip.To4().String()+"/32"] = struct{}{}
341340
return nil
342341
}
343342

344-
if r.ipVersion == ipv6 && ip.To4() == nil && ip.To16() != nil {
343+
if ipVersion == ipv6 && ip.To4() == nil && ip.To16() != nil {
345344
// Convert IPv6 to CIDR notation
346345
targetSet[ip.String()+"/128"] = struct{}{}
347346
return nil
348347
}
349348

350-
return fmt.Errorf("spf2ip: IP address '%s' is not of the required IP version (v%d)", value, r.ipVersion)
349+
return fmt.Errorf("spf2ip: IP address '%s' is not of the required IP version (v%d)", value, ipVersion)
351350
}
352351

353352
return fmt.Errorf("spf2ip: value '%s' is not a valid IP address or CIDR block", value)

spf2ip_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -583,8 +583,7 @@ func TestAddIPOrCIDRToSet(t *testing.T) {
583583
t.Run(description, func(t *testing.T) {
584584
ipsMap := make(map[string]struct{})
585585

586-
s.spf2IPResolver.ipVersion = tc.ipVersion
587-
err := s.spf2IPResolver.addIPOrCIDRToSet(tc.value, ipsMap)
586+
err := s.spf2IPResolver.addIPOrCIDRToSet(tc.ipVersion, tc.value, ipsMap)
588587

589588
assert.Equal(t, tc.expectErr, err != nil)
590589

0 commit comments

Comments
 (0)