Skip to content

Commit 4ee260f

Browse files
committed
refactor: Cleanup and add function clarifications
1 parent 530890f commit 4ee260f

File tree

4 files changed

+33
-31
lines changed

4 files changed

+33
-31
lines changed

cidr_sorter.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ func (s cidrSorter) Less(i, j int) bool {
4242
return p1 < p2 // Smaller prefix (wider network) comes first
4343
}
4444

45+
// parseCIDR parses a CIDR notation string and returns the IP, prefix length, and whether it's IPv4.
4546
func parseCIDR(s string) (ip net.IP, prefix int, isV4 bool) {
4647
_, ipNet, err := net.ParseCIDR(s)
4748
if err != nil {

spf2ip.go

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,11 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
104104
currentDomainIPs := make(map[string]struct{})
105105

106106
spfString, err := r.getSPFRecord(ctx, domain)
107-
if err != nil {
108-
if !errors.Is(err, errIgnorableDNSErr) {
109-
log.Printf("Warning: Failed to get SPF record for %s: %v", domain, err)
110-
r.resolvedIPsCache[domain] = nil
107+
if err != nil && !errors.Is(err, errIgnorableDNSErr) {
108+
log.Printf("Warning: Failed to get SPF record for %s: %v", domain, err)
109+
r.resolvedIPsCache[domain] = nil
111110

112-
return nil, err
113-
}
111+
return nil, err
114112
}
115113

116114
if spfString == "" {
@@ -209,37 +207,37 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
209207
}
210208

211209
case "include":
212-
if value != "" {
213-
includedIPs, includeErr := r.processDomain(ctx, value, depth+1)
214-
if includeErr != nil {
215-
return nil, fmt.Errorf("include failed for %s (directive in %s): %w", value, domain, includeErr)
216-
} else {
217-
for ip := range includedIPs {
218-
currentDomainIPs[ip] = struct{}{}
219-
}
220-
}
221-
} else {
210+
if value == "" {
222211
log.Printf("Warning: 'include' modifier without domain in %s", domain)
223212
r.resolvedIPsCache[domain] = nil
224213

225214
return nil, fmt.Errorf("include without domain in %s", domain)
226215
}
227216

228-
case "redirect":
229-
if value != "" {
230-
r.debugLogPrintf("Debug: Redirecting from %s to %s. Discarding IPs found so far for %s.", domain, value, domain)
217+
includedIPs, includeErr := r.processDomain(ctx, value, depth+1)
218+
if includeErr != nil {
219+
return nil, fmt.Errorf("include failed for %s (directive in %s): %w", value, domain, includeErr)
220+
}
221+
222+
for ip := range includedIPs {
223+
currentDomainIPs[ip] = struct{}{}
224+
}
231225

232-
// The result of this domain's processing is now entirely determined by the redirect target.
233-
redirectedIPs, redirectErr := r.processDomain(ctx, value, depth+1)
234-
r.resolvedIPsCache[domain] = deepCopyMap(redirectedIPs) // Overwrite cache with redirected IPs
226+
case "redirect":
227+
if value == "" {
228+
log.Printf("Warning: 'redirect' modifier without domain in %s", domain)
229+
r.resolvedIPsCache[domain] = nil
235230

236-
return redirectedIPs, redirectErr
231+
return nil, fmt.Errorf("redirect without domain in %s", domain)
237232
}
238233

239-
log.Printf("Warning: 'redirect' modifier without domain in %s", domain)
240-
r.resolvedIPsCache[domain] = nil
234+
r.debugLogPrintf("Debug: Redirecting from %s to %s. Discarding IPs found so far for %s.", domain, value, domain)
235+
236+
// The result of this domain's processing is now entirely determined by the redirect target.
237+
redirectedIPs, redirectErr := r.processDomain(ctx, value, depth+1)
238+
r.resolvedIPsCache[domain] = deepCopyMap(redirectedIPs) // Overwrite cache with redirected IPs
241239

242-
return nil, fmt.Errorf("redirect without domain in %s", domain)
240+
return redirectedIPs, redirectErr
243241

244242
case "exists", "ptr", "all":
245243
// These mechanisms don't directly define IPs in the same way, ignore for IP extraction.
@@ -255,6 +253,7 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
255253
return currentDomainIPs, nil
256254
}
257255

256+
// lookupIPNetwork returns the appropriate network type for IP lookups based on the resolver's IP version.
258257
func (r *SPF2IPResolver) lookupIPNetwork() string {
259258
switch r.ipVersion {
260259
case ipv4:
@@ -266,6 +265,7 @@ func (r *SPF2IPResolver) lookupIPNetwork() string {
266265
}
267266
}
268267

268+
// parseSPFMechanismTargetAndMask extracts the target host and mask suffix from an SPF mechanism value.
269269
func parseSPFMechanismTargetAndMask(defaultDomain, mechanismValue string) (targetHost, maskSuffix string) {
270270
targetHost = defaultDomain
271271
maskSuffix = ""
@@ -353,7 +353,8 @@ func (r *SPF2IPResolver) addIPOrCIDRToSet(value string, targetSet map[string]str
353353
return fmt.Errorf("value '%s' is not a valid IP address or CIDR block", value)
354354
}
355355

356-
func (r *SPF2IPResolver) debugLogPrintf(format string, args ...interface{}) {
356+
// debugLogPrintf logs debug messages if debug logging is enabled.
357+
func (r *SPF2IPResolver) debugLogPrintf(format string, args ...any) {
357358
if r.debugLogging {
358359
log.Printf(format, args...)
359360
}

spf2ip_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func TestResolve(t *testing.T) {
3737
[]string{"v=spf1 ip4:1.2.3.0/28 include:first.included.com include:second.included.com a mx:mx.example.com -all"}, nil,
3838
).Times(1)
3939
s.netResolver.EXPECT().LookupTXT(gomock.Any(), "first.included.com").Return(
40-
[]string{"v=spf1 ip4:5.6.7.8 ip6:2001:db8::1 include:included.again.com -all"}, nil,
40+
[]string{"v=spf1 ip4:5.6.7.8 ip6:2001:db8::1 include:nxdomain.included.com -all"}, nil,
4141
).Times(1)
4242
s.netResolver.EXPECT().LookupTXT(gomock.Any(), "second.included.com").Return(
4343
[]string{"v=spf1 include:first.included.com ip4:1.0.0.0/24 -all"}, nil,
@@ -57,7 +57,7 @@ func TestResolve(t *testing.T) {
5757
s.netResolver.EXPECT().LookupIP(gomock.Any(), "ip4", "example.com").Return(
5858
[]net.IP{net.ParseIP("8.8.8.8")}, nil,
5959
).Times(1)
60-
s.netResolver.EXPECT().LookupTXT(gomock.Any(), "included.again.com").Return(
60+
s.netResolver.EXPECT().LookupTXT(gomock.Any(), "nxdomain.included.com").Return(
6161
nil, &net.DNSError{IsNotFound: true}, // Simulating an ignorable DNS error
6262
).Times(1)
6363

@@ -304,7 +304,7 @@ func TestResolve_ExceededMaxDepthErr(t *testing.T) {
304304
[]string{"v=spf1 include:included0.com -all"}, nil,
305305
).Times(1)
306306

307-
for i := 0; i < maxSPFIncludeDepth; i++ { // excluding the first include
307+
for i := range maxSPFIncludeDepth { // excluding the first include
308308
s.netResolver.EXPECT().LookupTXT(gomock.Any(), fmt.Sprintf("included%d.com", i)).Return(
309309
[]string{fmt.Sprintf("v=spf1 include:included%d.com -all", i+1)}, nil,
310310
).Times(1)

utils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import (
44
"maps"
55
)
66

7-
// Helper to copy a map, as maps are reference types.
7+
// deepCopyMap is a helper for copying a map, as maps are reference types.
88
func deepCopyMap(originalMap map[string]struct{}) map[string]struct{} {
99
if originalMap == nil {
1010
return nil

0 commit comments

Comments
 (0)