Skip to content

Commit 8f16fdb

Browse files
committed
add returning X-IPCountry header for explicitly allowed IP addresses (fixes #76)
1 parent ef98da2 commit 8f16fdb

File tree

2 files changed

+77
-6
lines changed

2 files changed

+77
-6
lines changed

geoblock.go

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,12 @@ func (a *GeoBlock) allowDenyIPAddress(requestIPAddr *net.IP, req *http.Request)
229229

230230
// check if the request IP address is explicitly allowed
231231
if ipInSlice(*requestIPAddr, a.allowedIPAddresses) {
232+
if a.addCountryHeader {
233+
ok, countryCode := a.cachedRequestIP(requestIPAddr, req)
234+
if ok && len(countryCode) > 0 {
235+
req.Header.Set(countryHeader, countryCode)
236+
}
237+
}
232238
if a.logAllowedRequests {
233239
a.infoLogger.Printf("%s: request allowed [%s] since the IP address is explicitly allowed", a.name, requestIPAddr)
234240
}
@@ -238,6 +244,12 @@ func (a *GeoBlock) allowDenyIPAddress(requestIPAddr *net.IP, req *http.Request)
238244
// check if the request IP address is contained within one of the explicitly allowed IP address ranges
239245
for _, ipRange := range a.allowedIPRanges {
240246
if ipRange.Contains(*requestIPAddr) {
247+
if a.addCountryHeader {
248+
ok, countryCode := a.cachedRequestIP(requestIPAddr, req)
249+
if ok && len(countryCode) > 0 {
250+
req.Header.Set(countryHeader, countryCode)
251+
}
252+
}
241253
if a.logLocalRequests {
242254
a.infoLogger.Printf("%s: request allowed [%s] since the IP address is explicitly allowed", a.name, requestIPAddr)
243255
}
@@ -281,7 +293,6 @@ func (a *GeoBlock) allowDenyCachedRequestIP(requestIPAddr *net.IP, req *http.Req
281293
// check if existing entry was made more than a month ago, if so update the entry
282294
if time.Since(entry.Timestamp).Hours() >= numberOfHoursInMonth && a.forceMonthlyUpdate {
283295
entry, err = a.createNewIPEntry(req, ipAddressString)
284-
285296
if err != nil {
286297
return false, ""
287298
}
@@ -303,6 +314,36 @@ func (a *GeoBlock) allowDenyCachedRequestIP(requestIPAddr *net.IP, req *http.Req
303314
return true, entry.Country
304315
}
305316

317+
func (a *GeoBlock) cachedRequestIP(requestIPAddr *net.IP, req *http.Request) (bool, string) {
318+
ipAddressString := requestIPAddr.String()
319+
cacheEntry, ok := a.database.Get(ipAddressString)
320+
321+
var entry ipEntry
322+
var err error
323+
if !ok {
324+
entry, err = a.createNewIPEntry(req, ipAddressString)
325+
if err != nil {
326+
return false, ""
327+
}
328+
} else {
329+
entry = cacheEntry.(ipEntry)
330+
}
331+
332+
if a.logAPIRequests {
333+
a.infoLogger.Println("Loaded from database: ", entry)
334+
}
335+
336+
// check if existing entry was made more than a month ago, if so update the entry
337+
if time.Since(entry.Timestamp).Hours() >= numberOfHoursInMonth && a.forceMonthlyUpdate {
338+
entry, err = a.createNewIPEntry(req, ipAddressString)
339+
if err != nil {
340+
return false, ""
341+
}
342+
}
343+
344+
return true, entry.Country
345+
}
346+
306347
func (a *GeoBlock) collectRemoteIP(req *http.Request) ([]*net.IP, error) {
307348
var ipList []*net.IP
308349

geoblock_test.go

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,6 @@ func TestInvalidApiResponse(t *testing.T) {
737737
var apiStub = httptest.NewServer(http.HandlerFunc(apiHandlerInvalid))
738738

739739
cfg := createTesterConfig()
740-
fmt.Println(apiStub.URL)
741740
cfg.API = apiStub.URL + "/{ip}"
742741
cfg.Countries = append(cfg.Countries, "CH")
743742

@@ -770,7 +769,6 @@ func TestApiResponseTimeoutAllowed(t *testing.T) {
770769
var apiStub = httptest.NewServer(http.HandlerFunc(apiTimeout))
771770

772771
cfg := createTesterConfig()
773-
fmt.Println(apiStub.URL)
774772
cfg.API = apiStub.URL + "/{ip}"
775773
cfg.Countries = append(cfg.Countries, "CH")
776774
cfg.APITimeoutMs = 5
@@ -805,7 +803,6 @@ func TestApiResponseTimeoutNotAllowed(t *testing.T) {
805803
var apiStub = httptest.NewServer(http.HandlerFunc(apiTimeout))
806804

807805
cfg := createTesterConfig()
808-
fmt.Println(apiStub.URL)
809806
cfg.API = apiStub.URL + "/{ip}"
810807
cfg.Countries = append(cfg.Countries, "CH")
811808
cfg.APITimeoutMs = 5
@@ -863,6 +860,41 @@ func TestExplicitlyAllowedIP(t *testing.T) {
863860
assertStatusCode(t, recorder.Result(), http.StatusOK)
864861
}
865862

863+
func TestExplicitlyAllowedIPWithIPCountryHeader(t *testing.T) {
864+
// set up our fake api server
865+
apiHandler := &CountryCodeHandler{ResponseCountryCode: "CA"}
866+
var apiStub = httptest.NewServer(apiHandler)
867+
868+
cfg := createTesterConfig()
869+
cfg.API = apiStub.URL + "/{ip}"
870+
cfg.Countries = append(cfg.Countries, "CH")
871+
cfg.AllowedIPAddresses = append(cfg.AllowedIPAddresses, caExampleIP)
872+
cfg.LogLocalRequests = true
873+
cfg.AddCountryHeader = true
874+
875+
ctx := context.Background()
876+
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})
877+
878+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
879+
if err != nil {
880+
t.Fatal(err)
881+
}
882+
883+
recorder := httptest.NewRecorder()
884+
885+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
886+
if err != nil {
887+
t.Fatal(err)
888+
}
889+
890+
req.Header.Add(xForwardedFor, caExampleIP)
891+
892+
handler.ServeHTTP(recorder, req)
893+
894+
assertStatusCode(t, recorder.Result(), http.StatusOK)
895+
assertRequestHeader(t, req, CountryHeader, "CA")
896+
}
897+
866898
func TestExplicitlyAllowedIPNoMatch(t *testing.T) {
867899
cfg := createTesterConfig()
868900
cfg.Countries = append(cfg.Countries, "CA")
@@ -1156,8 +1188,6 @@ func assertStatusCode(t *testing.T, req *http.Response, expected int) {
11561188
func assertRequestHeader(t *testing.T, req *http.Request, key string, expected string) {
11571189
t.Helper()
11581190

1159-
fmt.Println(req.Header.Get(key))
1160-
11611191
if received := req.Header.Get(key); received != expected {
11621192
t.Errorf("header value mismatch: %s: %s <> %s", key, expected, received)
11631193
}

0 commit comments

Comments
 (0)