Skip to content

Commit 3acad13

Browse files
committed
add returning X-IPCountry header for explicitly allowed IP addresses (fixes #76)
1 parent e927a5c commit 3acad13

File tree

2 files changed

+77
-6
lines changed

2 files changed

+77
-6
lines changed

geoblock.go

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,12 @@ func (a *GeoBlock) allowDenyIPAddress(requestIPAddr *net.IP, req *http.Request)
229229

230230
// check if the request IP address is explicitly allowed
231231
if ipInSlice(*requestIPAddr, a.allowedIPAddresses) {
232+
if a.addCountryHeader {
233+
ok, countryCode := a.cachedRequestIP(requestIPAddr, req)
234+
if ok && len(countryCode) > 0 {
235+
req.Header.Set(countryHeader, countryCode)
236+
}
237+
}
232238
if a.logAllowedRequests {
233239
a.infoLogger.Printf("%s: request allowed [%s] since the IP address is explicitly allowed", a.name, requestIPAddr)
234240
}
@@ -238,6 +244,12 @@ func (a *GeoBlock) allowDenyIPAddress(requestIPAddr *net.IP, req *http.Request)
238244
// check if the request IP address is contained within one of the explicitly allowed IP address ranges
239245
for _, ipRange := range a.allowedIPRanges {
240246
if ipRange.Contains(*requestIPAddr) {
247+
if a.addCountryHeader {
248+
ok, countryCode := a.cachedRequestIP(requestIPAddr, req)
249+
if ok && len(countryCode) > 0 {
250+
req.Header.Set(countryHeader, countryCode)
251+
}
252+
}
241253
if a.logLocalRequests {
242254
a.infoLogger.Printf("%s: request allowed [%s] since the IP address is explicitly allowed", a.name, requestIPAddr)
243255
}
@@ -281,7 +293,6 @@ func (a *GeoBlock) allowDenyCachedRequestIP(requestIPAddr *net.IP, req *http.Req
281293
// check if existing entry was made more than a month ago, if so update the entry
282294
if time.Since(entry.Timestamp).Hours() >= numberOfHoursInMonth && a.forceMonthlyUpdate {
283295
entry, err = a.createNewIPEntry(req, ipAddressString)
284-
285296
if err != nil {
286297
return false, ""
287298
}
@@ -303,6 +314,36 @@ func (a *GeoBlock) allowDenyCachedRequestIP(requestIPAddr *net.IP, req *http.Req
303314
return true, entry.Country
304315
}
305316

317+
func (a *GeoBlock) cachedRequestIP(requestIPAddr *net.IP, req *http.Request) (bool, string) {
318+
ipAddressString := requestIPAddr.String()
319+
cacheEntry, ok := a.database.Get(ipAddressString)
320+
321+
var entry ipEntry
322+
var err error
323+
if !ok {
324+
entry, err = a.createNewIPEntry(req, ipAddressString)
325+
if err != nil {
326+
return false, ""
327+
}
328+
} else {
329+
entry = cacheEntry.(ipEntry)
330+
}
331+
332+
if a.logAPIRequests {
333+
a.infoLogger.Println("Loaded from database: ", entry)
334+
}
335+
336+
// check if existing entry was made more than a month ago, if so update the entry
337+
if time.Since(entry.Timestamp).Hours() >= numberOfHoursInMonth && a.forceMonthlyUpdate {
338+
entry, err = a.createNewIPEntry(req, ipAddressString)
339+
if err != nil {
340+
return false, ""
341+
}
342+
}
343+
344+
return true, entry.Country
345+
}
346+
306347
func (a *GeoBlock) collectRemoteIP(req *http.Request) ([]*net.IP, error) {
307348
var ipList []*net.IP
308349

geoblock_test.go

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,6 @@ func TestInvalidApiResponse(t *testing.T) {
656656
var apiStub = httptest.NewServer(http.HandlerFunc(apiHandlerInvalid))
657657

658658
cfg := createTesterConfig()
659-
fmt.Println(apiStub.URL)
660659
cfg.API = apiStub.URL + "/{ip}"
661660
cfg.Countries = append(cfg.Countries, "CH")
662661

@@ -689,7 +688,6 @@ func TestApiResponseTimeoutAllowed(t *testing.T) {
689688
var apiStub = httptest.NewServer(http.HandlerFunc(apiTimeout))
690689

691690
cfg := createTesterConfig()
692-
fmt.Println(apiStub.URL)
693691
cfg.API = apiStub.URL + "/{ip}"
694692
cfg.Countries = append(cfg.Countries, "CH")
695693
cfg.APITimeoutMs = 5
@@ -724,7 +722,6 @@ func TestApiResponseTimeoutNotAllowed(t *testing.T) {
724722
var apiStub = httptest.NewServer(http.HandlerFunc(apiTimeout))
725723

726724
cfg := createTesterConfig()
727-
fmt.Println(apiStub.URL)
728725
cfg.API = apiStub.URL + "/{ip}"
729726
cfg.Countries = append(cfg.Countries, "CH")
730727
cfg.APITimeoutMs = 5
@@ -782,6 +779,41 @@ func TestExplicitlyAllowedIP(t *testing.T) {
782779
assertStatusCode(t, recorder.Result(), http.StatusOK)
783780
}
784781

782+
func TestExplicitlyAllowedIPWithIPCountryHeader(t *testing.T) {
783+
// set up our fake api server
784+
apiHandler := &CountryCodeHandler{ResponseCountryCode: "CA"}
785+
var apiStub = httptest.NewServer(apiHandler)
786+
787+
cfg := createTesterConfig()
788+
cfg.API = apiStub.URL + "/{ip}"
789+
cfg.Countries = append(cfg.Countries, "CH")
790+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, caExampleIP)
791+
cfg.LogLocalRequests = true
792+
cfg.AddCountryHeader = true
793+
794+
ctx := context.Background()
795+
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})
796+
797+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
798+
if err != nil {
799+
t.Fatal(err)
800+
}
801+
802+
recorder := httptest.NewRecorder()
803+
804+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
805+
if err != nil {
806+
t.Fatal(err)
807+
}
808+
809+
req.Header.Add(xForwardedFor, caExampleIP)
810+
811+
handler.ServeHTTP(recorder, req)
812+
813+
assertStatusCode(t, recorder.Result(), http.StatusOK)
814+
assertRequestHeader(t, req, CountryHeader, "CA")
815+
}
816+
785817
func TestExplicitlyAllowedIPNoMatch(t *testing.T) {
786818
cfg := createTesterConfig()
787819
cfg.Countries = append(cfg.Countries, "CA")
@@ -1075,8 +1107,6 @@ func assertStatusCode(t *testing.T, req *http.Response, expected int) {
10751107
func assertRequestHeader(t *testing.T, req *http.Request, key string, expected string) {
10761108
t.Helper()
10771109

1078-
fmt.Println(req.Header.Get(key))
1079-
10801110
if received := req.Header.Get(key); received != expected {
10811111
t.Errorf("header value mismatch: %s: %s <> %s", key, expected, received)
10821112
}

0 commit comments

Comments
 (0)