@@ -737,7 +737,6 @@ func TestInvalidApiResponse(t *testing.T) {
737
737
var apiStub = httptest .NewServer (http .HandlerFunc (apiHandlerInvalid ))
738
738
739
739
cfg := createTesterConfig ()
740
- fmt .Println (apiStub .URL )
741
740
cfg .API = apiStub .URL + "/{ip}"
742
741
cfg .Countries = append (cfg .Countries , "CH" )
743
742
@@ -770,7 +769,6 @@ func TestApiResponseTimeoutAllowed(t *testing.T) {
770
769
var apiStub = httptest .NewServer (http .HandlerFunc (apiTimeout ))
771
770
772
771
cfg := createTesterConfig ()
773
- fmt .Println (apiStub .URL )
774
772
cfg .API = apiStub .URL + "/{ip}"
775
773
cfg .Countries = append (cfg .Countries , "CH" )
776
774
cfg .APITimeoutMs = 5
@@ -805,7 +803,6 @@ func TestApiResponseTimeoutNotAllowed(t *testing.T) {
805
803
var apiStub = httptest .NewServer (http .HandlerFunc (apiTimeout ))
806
804
807
805
cfg := createTesterConfig ()
808
- fmt .Println (apiStub .URL )
809
806
cfg .API = apiStub .URL + "/{ip}"
810
807
cfg .Countries = append (cfg .Countries , "CH" )
811
808
cfg .APITimeoutMs = 5
@@ -863,6 +860,41 @@ func TestExplicitlyAllowedIP(t *testing.T) {
863
860
assertStatusCode (t , recorder .Result (), http .StatusOK )
864
861
}
865
862
863
+ func TestExplicitlyAllowedIPWithIPCountryHeader (t * testing.T ) {
864
+ // set up our fake api server
865
+ apiHandler := & CountryCodeHandler {ResponseCountryCode : "CA" }
866
+ var apiStub = httptest .NewServer (apiHandler )
867
+
868
+ cfg := createTesterConfig ()
869
+ cfg .API = apiStub .URL + "/{ip}"
870
+ cfg .Countries = append (cfg .Countries , "CH" )
871
+ cfg .AllowedIPAddresses = append (cfg .AllowedIPAddresses , caExampleIP )
872
+ cfg .LogLocalRequests = true
873
+ cfg .AddCountryHeader = true
874
+
875
+ ctx := context .Background ()
876
+ next := http .HandlerFunc (func (_ http.ResponseWriter , _ * http.Request ) {})
877
+
878
+ handler , err := geoblock .New (ctx , next , cfg , "GeoBlock" )
879
+ if err != nil {
880
+ t .Fatal (err )
881
+ }
882
+
883
+ recorder := httptest .NewRecorder ()
884
+
885
+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , "http://localhost" , nil )
886
+ if err != nil {
887
+ t .Fatal (err )
888
+ }
889
+
890
+ req .Header .Add (xForwardedFor , caExampleIP )
891
+
892
+ handler .ServeHTTP (recorder , req )
893
+
894
+ assertStatusCode (t , recorder .Result (), http .StatusOK )
895
+ assertRequestHeader (t , req , CountryHeader , "CA" )
896
+ }
897
+
866
898
func TestExplicitlyAllowedIPNoMatch (t * testing.T ) {
867
899
cfg := createTesterConfig ()
868
900
cfg .Countries = append (cfg .Countries , "CA" )
@@ -1156,8 +1188,6 @@ func assertStatusCode(t *testing.T, req *http.Response, expected int) {
1156
1188
func assertRequestHeader (t * testing.T , req * http.Request , key string , expected string ) {
1157
1189
t .Helper ()
1158
1190
1159
- fmt .Println (req .Header .Get (key ))
1160
-
1161
1191
if received := req .Header .Get (key ); received != expected {
1162
1192
t .Errorf ("header value mismatch: %s: %s <> %s" , key , expected , received )
1163
1193
}
0 commit comments