Skip to content

Commit 32e3a1f

Browse files
author
Pascal Minder
committed
Fix serving http content if request should be denied
1 parent ff6385e commit 32e3a1f

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

geoblock.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,10 +200,10 @@ func (a *GeoBlock) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
200200
if len(a.redirectURLIfDenied) != 0 {
201201
rw.Header().Set("Location", a.redirectURLIfDenied)
202202
rw.WriteHeader(http.StatusFound)
203-
a.next.ServeHTTP(rw, req)
203+
return
204204
} else {
205205
rw.WriteHeader(a.httpStatusCodeDeniedRequest)
206-
a.next.ServeHTTP(rw, req)
206+
return
207207
}
208208
}
209209
}

geoblock_test.go

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package geoblock_test
33
import (
44
"context"
55
"fmt"
6+
"io"
67
"net/http"
78
"net/http/httptest"
89
"os"
@@ -106,7 +107,7 @@ func TestAllowedCountry(t *testing.T) {
106107
cfg.Countries = append(cfg.Countries, "CH")
107108

108109
ctx := context.Background()
109-
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})
110+
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("Allowed request")) })
110111

111112
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
112113
if err != nil {
@@ -124,7 +125,19 @@ func TestAllowedCountry(t *testing.T) {
124125

125126
handler.ServeHTTP(recorder, req)
126127

127-
assertStatusCode(t, recorder.Result(), http.StatusOK)
128+
recorderResult := recorder.Result()
129+
130+
assertStatusCode(t, recorderResult, http.StatusOK)
131+
132+
body, err := io.ReadAll(recorderResult.Body)
133+
if err != nil {
134+
t.Fatal(err)
135+
}
136+
137+
expectedBody := "Allowed request"
138+
if string(body) != expectedBody {
139+
t.Fatalf("expected body %q, got %q", expectedBody, string(body))
140+
}
128141
}
129142

130143
func TestMultipleAllowedCountry(t *testing.T) {
@@ -385,7 +398,7 @@ func TestDeniedCountry(t *testing.T) {
385398
cfg.Countries = append(cfg.Countries, "CH")
386399

387400
ctx := context.Background()
388-
next := http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {})
401+
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("Allowed request")) })
389402

390403
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
391404
if err != nil {
@@ -403,7 +416,19 @@ func TestDeniedCountry(t *testing.T) {
403416

404417
handler.ServeHTTP(recorder, req)
405418

406-
assertStatusCode(t, recorder.Result(), http.StatusForbidden)
419+
recorderResult := recorder.Result()
420+
421+
assertStatusCode(t, recorderResult, http.StatusForbidden)
422+
423+
body, err := io.ReadAll(recorderResult.Body)
424+
if err != nil {
425+
t.Fatal(err)
426+
}
427+
428+
expectedBody := ""
429+
if string(body) != expectedBody {
430+
t.Fatalf("expected body %q, got %q", expectedBody, string(body))
431+
}
407432
}
408433

409434
func TestDeniedCountryWithRedirect(t *testing.T) {

0 commit comments

Comments
 (0)