25
25
)
26
26
27
27
type SPF2IPResolver struct {
28
- ipVersion int
29
28
netResolver NetResolver
30
29
debugLogging bool
31
-
32
- // Map of visited domains in the current path for loop detection: domain -> struct{}
33
- domainsVisitedInCurrentPath map [string ]struct {}
34
- // Map of resolved IPs for each domain: domain -> IPs -> struct{}
35
- resolvedIPsCache map [string ]map [string ]struct {}
36
30
}
37
31
38
32
//go:generate mockgen -package spf2ip -source spf2ip.go -destination netresolver_mock.go
@@ -54,11 +48,9 @@ func (r *SPF2IPResolver) Resolve(ctx context.Context, domain string, ipVersion i
54
48
return nil , fmt .Errorf ("%w: %d" , ErrInvalidIPVersion , ipVersion )
55
49
}
56
50
57
- r .ipVersion = ipVersion
58
- r .domainsVisitedInCurrentPath = make (map [string ]struct {})
59
- r .resolvedIPsCache = make (map [string ]map [string ]struct {})
60
-
61
- finalIPs , err := r .processDomain (ctx , domain , 0 )
51
+ finalIPs , err := r .processDomain (
52
+ ctx , ipVersion , make (map [string ]struct {}), make (map [string ]map [string ]struct {}), domain , 0 ,
53
+ )
62
54
if err != nil {
63
55
return nil , err
64
56
}
@@ -77,27 +69,34 @@ func (r *SPF2IPResolver) Resolve(ctx context.Context, domain string, ipVersion i
77
69
return result , nil
78
70
}
79
71
80
- func (r * SPF2IPResolver ) processDomain (ctx context.Context , domain string , depth int ) (map [string ]struct {}, error ) {
72
+ func (r * SPF2IPResolver ) processDomain (
73
+ ctx context.Context ,
74
+ ipVersion int ,
75
+ domainsVisitedInCurrentPath map [string ]struct {},
76
+ resolvedIPsCache map [string ]map [string ]struct {},
77
+ domain string ,
78
+ depth int ,
79
+ ) (map [string ]struct {}, error ) {
81
80
if depth > maxSPFIncludeDepth {
82
81
return nil , fmt .Errorf ("%w: %s (depth %d)" , ErrExceededMaxDepth , domain , depth )
83
82
}
84
83
85
84
domain = strings .ToLower (strings .TrimSpace (domain ))
86
85
87
86
// Check if this domain's SPF is already resolved and cached.
88
- if cachedIPs , found := r . resolvedIPsCache [domain ]; found {
87
+ if cachedIPs , found := resolvedIPsCache [domain ]; found {
89
88
r .debugLogPrintf ("Debug: Using cached result for domain: %s" , domain )
90
89
return deepCopyMap (cachedIPs ), nil
91
90
}
92
91
93
92
// Check for loops in the current resolution path.
94
- if _ , visited := r . domainsVisitedInCurrentPath [domain ]; visited {
93
+ if _ , visited := domainsVisitedInCurrentPath [domain ]; visited {
95
94
r .debugLogPrintf ("Debug: Loop detected for domain %s" , domain )
96
95
return nil , fmt .Errorf ("%w: %s" , ErrLoopDetected , domain )
97
96
}
98
97
99
- r . domainsVisitedInCurrentPath [domain ] = struct {}{}
100
- defer delete (r . domainsVisitedInCurrentPath , domain )
98
+ domainsVisitedInCurrentPath [domain ] = struct {}{}
99
+ defer delete (domainsVisitedInCurrentPath , domain )
101
100
102
101
r .debugLogPrintf ("Debug: Processing domain: %s (depth %d)" , domain , depth )
103
102
@@ -106,14 +105,14 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
106
105
spfString , err := r .getSPFRecord (ctx , domain )
107
106
if err != nil && ! errors .Is (err , errIgnorableDNSErr ) {
108
107
r .debugLogPrintf ("Debug: Failed to get SPF record for %s: %v" , domain , err )
109
- r . resolvedIPsCache [domain ] = nil
108
+ resolvedIPsCache [domain ] = nil
110
109
111
110
return nil , fmt .Errorf ("spf2ip: failed to get SPF record for %s: %w" , domain , err )
112
111
}
113
112
114
113
if spfString == "" {
115
114
r .debugLogPrintf ("Debug: No SPF record found for %s, treating as empty" , domain )
116
- r . resolvedIPsCache [domain ] = currentDomainIPs
115
+ resolvedIPsCache [domain ] = currentDomainIPs
117
116
118
117
return currentDomainIPs , nil
119
118
}
@@ -140,23 +139,23 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
140
139
141
140
switch mechanism {
142
141
case "ip4" :
143
- if r . ipVersion == ipv4 {
144
- if err := r .addIPOrCIDRToSet (value , currentDomainIPs ); err != nil {
142
+ if ipVersion == ipv4 {
143
+ if err := r .addIPOrCIDRToSet (ipVersion , value , currentDomainIPs ); err != nil {
145
144
return nil , fmt .Errorf ("spf2ip: failed to add IP/CIDR for ip4 mechanism in %s: %w" , domain , err )
146
145
}
147
146
}
148
147
149
148
case "ip6" :
150
- if r . ipVersion == ipv6 {
151
- if err := r .addIPOrCIDRToSet (value , currentDomainIPs ); err != nil {
149
+ if ipVersion == ipv6 {
150
+ if err := r .addIPOrCIDRToSet (ipVersion , value , currentDomainIPs ); err != nil {
152
151
return nil , fmt .Errorf ("spf2ip: failed to add IP/CIDR for ip6 mechanism in %s: %w" , domain , err )
153
152
}
154
153
}
155
154
156
155
case "a" :
157
156
targetHost , maskSuffix := parseSPFMechanismTargetAndMask (domain , value )
158
157
159
- ips , err := r .netResolver .LookupIP (ctx , r . lookupIPNetwork (), targetHost )
158
+ ips , err := r .netResolver .LookupIP (ctx , lookupIPNetwork (ipVersion ), targetHost )
160
159
if err != nil {
161
160
if isDNSErrIgnorable (err ) {
162
161
r .debugLogPrintf ("Debug: Ignorable DNS error for A/AAAA lookup of %s (directive in %s): %v" , targetHost , domain , err )
@@ -167,7 +166,7 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
167
166
}
168
167
169
168
for _ , ip := range ips {
170
- if err := r .addIPOrCIDRToSet (ip .String ()+ maskSuffix , currentDomainIPs ); err != nil {
169
+ if err := r .addIPOrCIDRToSet (ipVersion , ip .String ()+ maskSuffix , currentDomainIPs ); err != nil {
171
170
return nil , fmt .Errorf ("spf2ip: failed to add IP/CIDR for A mechanism in %s: %w" , domain , err )
172
171
}
173
172
}
@@ -188,7 +187,7 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
188
187
for _ , mx := range mxs {
189
188
mxHost := strings .TrimSuffix (mx .Host , "." )
190
189
191
- ips , err := r .netResolver .LookupIP (ctx , r . lookupIPNetwork (), mxHost )
190
+ ips , err := r .netResolver .LookupIP (ctx , lookupIPNetwork (ipVersion ), mxHost )
192
191
if err != nil {
193
192
if ! isDNSErrIgnorable (err ) {
194
193
return nil , fmt .Errorf ("A/AAAA lookup failed for MX host %s (directive in %s): %w" , mxHost , domain , err )
@@ -200,7 +199,7 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
200
199
}
201
200
202
201
for _ , ip := range ips {
203
- if err := r .addIPOrCIDRToSet (ip .String ()+ maskSuffix , currentDomainIPs ); err != nil {
202
+ if err := r .addIPOrCIDRToSet (ipVersion , ip .String ()+ maskSuffix , currentDomainIPs ); err != nil {
204
203
return nil , fmt .Errorf ("spf2ip: failed to add IP/CIDR for MX mechanism in %s: %w" , domain , err )
205
204
}
206
205
}
@@ -209,12 +208,12 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
209
208
case "include" :
210
209
if value == "" {
211
210
r .debugLogPrintf ("Debug: 'include' modifier without domain in %s" , domain )
212
- r . resolvedIPsCache [domain ] = nil
211
+ resolvedIPsCache [domain ] = nil
213
212
214
213
return nil , fmt .Errorf ("spf2ip: include without domain in %s" , domain )
215
214
}
216
215
217
- includedIPs , includeErr := r .processDomain (ctx , value , depth + 1 )
216
+ includedIPs , includeErr := r .processDomain (ctx , ipVersion , domainsVisitedInCurrentPath , resolvedIPsCache , value , depth + 1 )
218
217
if includeErr != nil {
219
218
return nil , fmt .Errorf ("spf2ip: include failed for %s (directive in %s): %w" , value , domain , includeErr )
220
219
}
@@ -226,16 +225,16 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
226
225
case "redirect" :
227
226
if value == "" {
228
227
r .debugLogPrintf ("Debug: 'redirect' modifier without domain in %s" , domain )
229
- r . resolvedIPsCache [domain ] = nil
228
+ resolvedIPsCache [domain ] = nil
230
229
231
230
return nil , fmt .Errorf ("spf2ip: redirect without domain in %s" , domain )
232
231
}
233
232
234
233
r .debugLogPrintf ("Debug: Redirecting from %s to %s. Discarding IPs found so far for %s." , domain , value , domain )
235
234
236
235
// The result of this domain's processing is now entirely determined by the redirect target.
237
- redirectedIPs , redirectErr := r .processDomain (ctx , value , depth + 1 )
238
- r . resolvedIPsCache [domain ] = deepCopyMap (redirectedIPs ) // Overwrite cache with redirected IPs
236
+ redirectedIPs , redirectErr := r .processDomain (ctx , ipVersion , domainsVisitedInCurrentPath , resolvedIPsCache , value , depth + 1 )
237
+ resolvedIPsCache [domain ] = deepCopyMap (redirectedIPs ) // Overwrite cache with redirected IPs
239
238
240
239
return redirectedIPs , redirectErr
241
240
@@ -248,14 +247,14 @@ func (r *SPF2IPResolver) processDomain(ctx context.Context, domain string, depth
248
247
}
249
248
}
250
249
251
- r . resolvedIPsCache [domain ] = deepCopyMap (currentDomainIPs )
250
+ resolvedIPsCache [domain ] = deepCopyMap (currentDomainIPs )
252
251
253
252
return currentDomainIPs , nil
254
253
}
255
254
256
255
// lookupIPNetwork returns the appropriate network type for IP lookups based on the resolver's IP version.
257
- func ( r * SPF2IPResolver ) lookupIPNetwork () string {
258
- switch r . ipVersion {
256
+ func lookupIPNetwork (ipVersion int ) string {
257
+ switch ipVersion {
259
258
case ipv4 :
260
259
return "ip4"
261
260
case ipv6 :
@@ -319,35 +318,35 @@ func (r *SPF2IPResolver) getSPFRecord(ctx context.Context, domain string) (strin
319
318
}
320
319
321
320
// addIPOrCIDRToSet adds an IP or CIDR block to the target set, ensuring that plain IPs are converted to CIDR notation.
322
- func (r * SPF2IPResolver ) addIPOrCIDRToSet (value string , targetSet map [string ]struct {}) error {
321
+ func (r * SPF2IPResolver ) addIPOrCIDRToSet (ipVersion int , value string , targetSet map [string ]struct {}) error {
323
322
value = strings .TrimSpace (value )
324
323
325
324
// Try CIDR first
326
325
if ip , ipNet , err := net .ParseCIDR (value ); err == nil {
327
- if (r . ipVersion == ipv4 && ip .To4 () != nil ) || (r . ipVersion == ipv6 && ip .To4 () == nil && ip .To16 () != nil ) {
326
+ if (ipVersion == ipv4 && ip .To4 () != nil ) || (ipVersion == ipv6 && ip .To4 () == nil && ip .To16 () != nil ) {
328
327
targetSet [ipNet .String ()] = struct {}{}
329
328
return nil
330
329
}
331
330
332
- return fmt .Errorf ("spf2ip: CIDR '%s' is not of the required IP version (v%d)" , value , r . ipVersion )
331
+ return fmt .Errorf ("spf2ip: CIDR '%s' is not of the required IP version (v%d)" , value , ipVersion )
333
332
}
334
333
335
334
// Try plain IP
336
335
ip := net .ParseIP (value )
337
336
if ip != nil {
338
- if r . ipVersion == ipv4 && ip .To4 () != nil {
337
+ if ipVersion == ipv4 && ip .To4 () != nil {
339
338
// Convert IPv4 to CIDR notation
340
339
targetSet [ip .To4 ().String ()+ "/32" ] = struct {}{}
341
340
return nil
342
341
}
343
342
344
- if r . ipVersion == ipv6 && ip .To4 () == nil && ip .To16 () != nil {
343
+ if ipVersion == ipv6 && ip .To4 () == nil && ip .To16 () != nil {
345
344
// Convert IPv6 to CIDR notation
346
345
targetSet [ip .String ()+ "/128" ] = struct {}{}
347
346
return nil
348
347
}
349
348
350
- return fmt .Errorf ("spf2ip: IP address '%s' is not of the required IP version (v%d)" , value , r . ipVersion )
349
+ return fmt .Errorf ("spf2ip: IP address '%s' is not of the required IP version (v%d)" , value , ipVersion )
351
350
}
352
351
353
352
return fmt .Errorf ("spf2ip: value '%s' is not a valid IP address or CIDR block" , value )
0 commit comments