@@ -3,6 +3,7 @@ package geoblock_test
3
3
import (
4
4
"context"
5
5
"fmt"
6
+ "io"
6
7
"net/http"
7
8
"net/http/httptest"
8
9
"os"
@@ -106,7 +107,7 @@ func TestAllowedCountry(t *testing.T) {
106
107
cfg .Countries = append (cfg .Countries , "CH" )
107
108
108
109
ctx := context .Background ()
109
- next := http .HandlerFunc (func (_ http.ResponseWriter , _ * http.Request ) {})
110
+ next := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) { _ , _ = w . Write ([] byte ( "Allowed request" )) })
110
111
111
112
handler , err := geoblock .New (ctx , next , cfg , "GeoBlock" )
112
113
if err != nil {
@@ -124,7 +125,19 @@ func TestAllowedCountry(t *testing.T) {
124
125
125
126
handler .ServeHTTP (recorder , req )
126
127
127
- assertStatusCode (t , recorder .Result (), http .StatusOK )
128
+ recorderResult := recorder .Result ()
129
+
130
+ assertStatusCode (t , recorderResult , http .StatusOK )
131
+
132
+ body , err := io .ReadAll (recorderResult .Body )
133
+ if err != nil {
134
+ t .Fatal (err )
135
+ }
136
+
137
+ expectedBody := "Allowed request"
138
+ if string (body ) != expectedBody {
139
+ t .Fatalf ("expected body %q, got %q" , expectedBody , string (body ))
140
+ }
128
141
}
129
142
130
143
func TestMultipleAllowedCountry (t * testing.T ) {
@@ -385,7 +398,7 @@ func TestDeniedCountry(t *testing.T) {
385
398
cfg .Countries = append (cfg .Countries , "CH" )
386
399
387
400
ctx := context .Background ()
388
- next := http .HandlerFunc (func (_ http.ResponseWriter , _ * http.Request ) {})
401
+ next := http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) { _ , _ = w . Write ([] byte ( "Allowed request" )) })
389
402
390
403
handler , err := geoblock .New (ctx , next , cfg , "GeoBlock" )
391
404
if err != nil {
@@ -403,7 +416,19 @@ func TestDeniedCountry(t *testing.T) {
403
416
404
417
handler .ServeHTTP (recorder , req )
405
418
406
- assertStatusCode (t , recorder .Result (), http .StatusForbidden )
419
+ recorderResult := recorder .Result ()
420
+
421
+ assertStatusCode (t , recorderResult , http .StatusForbidden )
422
+
423
+ body , err := io .ReadAll (recorderResult .Body )
424
+ if err != nil {
425
+ t .Fatal (err )
426
+ }
427
+
428
+ expectedBody := ""
429
+ if string (body ) != expectedBody {
430
+ t .Fatalf ("expected body %q, got %q" , expectedBody , string (body ))
431
+ }
407
432
}
408
433
409
434
func TestDeniedCountryWithRedirect (t * testing.T ) {
0 commit comments