@@ -656,7 +656,6 @@ func TestInvalidApiResponse(t *testing.T) {
656
656
var apiStub = httptest .NewServer (http .HandlerFunc (apiHandlerInvalid ))
657
657
658
658
cfg := createTesterConfig ()
659
- fmt .Println (apiStub .URL )
660
659
cfg .API = apiStub .URL + "/{ip}"
661
660
cfg .Countries = append (cfg .Countries , "CH" )
662
661
@@ -689,7 +688,6 @@ func TestApiResponseTimeoutAllowed(t *testing.T) {
689
688
var apiStub = httptest .NewServer (http .HandlerFunc (apiTimeout ))
690
689
691
690
cfg := createTesterConfig ()
692
- fmt .Println (apiStub .URL )
693
691
cfg .API = apiStub .URL + "/{ip}"
694
692
cfg .Countries = append (cfg .Countries , "CH" )
695
693
cfg .APITimeoutMs = 5
@@ -724,7 +722,6 @@ func TestApiResponseTimeoutNotAllowed(t *testing.T) {
724
722
var apiStub = httptest .NewServer (http .HandlerFunc (apiTimeout ))
725
723
726
724
cfg := createTesterConfig ()
727
- fmt .Println (apiStub .URL )
728
725
cfg .API = apiStub .URL + "/{ip}"
729
726
cfg .Countries = append (cfg .Countries , "CH" )
730
727
cfg .APITimeoutMs = 5
@@ -782,6 +779,41 @@ func TestExplicitlyAllowedIP(t *testing.T) {
782
779
assertStatusCode (t , recorder .Result (), http .StatusOK )
783
780
}
784
781
782
+ func TestExplicitlyAllowedIPWithIPCountryHeader (t * testing.T ) {
783
+ // set up our fake api server
784
+ apiHandler := & CountryCodeHandler {ResponseCountryCode : "CA" }
785
+ var apiStub = httptest .NewServer (apiHandler )
786
+
787
+ cfg := createTesterConfig ()
788
+ cfg .API = apiStub .URL + "/{ip}"
789
+ cfg .Countries = append (cfg .Countries , "CH" )
790
+ cfg .AllowedIPAddresses = append (cfg .AllowedIPAddresses , caExampleIP )
791
+ cfg .LogLocalRequests = true
792
+ cfg .AddCountryHeader = true
793
+
794
+ ctx := context .Background ()
795
+ next := http .HandlerFunc (func (_ http.ResponseWriter , _ * http.Request ) {})
796
+
797
+ handler , err := geoblock .New (ctx , next , cfg , "GeoBlock" )
798
+ if err != nil {
799
+ t .Fatal (err )
800
+ }
801
+
802
+ recorder := httptest .NewRecorder ()
803
+
804
+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , "http://localhost" , nil )
805
+ if err != nil {
806
+ t .Fatal (err )
807
+ }
808
+
809
+ req .Header .Add (xForwardedFor , caExampleIP )
810
+
811
+ handler .ServeHTTP (recorder , req )
812
+
813
+ assertStatusCode (t , recorder .Result (), http .StatusOK )
814
+ assertRequestHeader (t , req , CountryHeader , "CA" )
815
+ }
816
+
785
817
func TestExplicitlyAllowedIPNoMatch (t * testing.T ) {
786
818
cfg := createTesterConfig ()
787
819
cfg .Countries = append (cfg .Countries , "CA" )
@@ -1075,8 +1107,6 @@ func assertStatusCode(t *testing.T, req *http.Response, expected int) {
1075
1107
func assertRequestHeader (t * testing.T , req * http.Request , key string , expected string ) {
1076
1108
t .Helper ()
1077
1109
1078
- fmt .Println (req .Header .Get (key ))
1079
-
1080
1110
if received := req .Header .Get (key ); received != expected {
1081
1111
t .Errorf ("header value mismatch: %s: %s <> %s" , key , expected , received )
1082
1112
}
0 commit comments