Skip to content

Commit 07f19ac

Browse files
committed
Add XForwardedForReverseProxy option
which basically tells GeoBlock to only allow/deny a request based on the first IP address in the X-ForwardedFor HTTP header. This is useful for servers behind e.g. a Cloudflare proxy
1 parent b7a8c91 commit 07f19ac

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

geoblock.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ type Config struct {
4141
APITimeoutMs int `yaml:"apiTimeoutMs"`
4242
IgnoreAPITimeout bool `yaml:"ignoreApiTimeout"`
4343
IPGeolocationHTTPHeaderField string `yaml:"ipGeolocationHttpHeaderField"`
44+
XForwardedForReverseProxy bool `yaml:"xForwardedForReverseProxy"`
4445
CacheSize int `yaml:"cacheSize"`
4546
ForceMonthlyUpdate bool `yaml:"forceMonthlyUpdate"`
4647
AllowUnknownCountries bool `yaml:"allowUnknownCountries"`
@@ -75,6 +76,7 @@ type GeoBlock struct {
7576
apiTimeoutMs int
7677
ignoreAPITimeout bool
7778
iPGeolocationHTTPHeaderField string
79+
xForwardedForReverseProxy bool
7880
forceMonthlyUpdate bool
7981
allowUnknownCountries bool
8082
unknownCountryCode string
@@ -157,6 +159,7 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h
157159
apiTimeoutMs: config.APITimeoutMs,
158160
ignoreAPITimeout: config.IgnoreAPITimeout,
159161
iPGeolocationHTTPHeaderField: config.IPGeolocationHTTPHeaderField,
162+
xForwardedForReverseProxy: config.XForwardedForReverseProxy,
160163
forceMonthlyUpdate: config.ForceMonthlyUpdate,
161164
allowUnknownCountries: config.AllowUnknownCountries,
162165
unknownCountryCode: config.UnknownCountryAPIResponse,
@@ -182,6 +185,11 @@ func (a *GeoBlock) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
182185
return
183186
}
184187

188+
// only keep the first IP address, which should be the client (if the proxy behaves itself), to check if allowed or denied
189+
if a.xForwardedForReverseProxy {
190+
requestIPAddresses = requestIPAddresses[:1]
191+
}
192+
185193
for _, requestIPAddress := range requestIPAddresses {
186194
if !a.allowDenyIPAddress(requestIPAddress, req) {
187195
rw.WriteHeader(a.httpStatusCodeDeniedRequest)

geoblock_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,62 @@ func TestMultipleIpAddressesReverse(t *testing.T) {
209209
assertStatusCode(t, recorder.Result(), http.StatusForbidden)
210210
}
211211

212+
func TestMultipleIpAddressesProxy(t *testing.T) {
213+
cfg := createTesterConfig()
214+
215+
cfg.Countries = append(cfg.Countries, "CA")
216+
cfg.XForwardedForReverseProxy = true
217+
218+
ctx := context.Background()
219+
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
220+
221+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
222+
if err != nil {
223+
t.Fatal(err)
224+
}
225+
226+
recorder := httptest.NewRecorder()
227+
228+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
229+
if err != nil {
230+
t.Fatal(err)
231+
}
232+
233+
req.Header.Add(xForwardedFor, strings.Join([]string{caExampleIP, chExampleIP}, ","))
234+
235+
handler.ServeHTTP(recorder, req)
236+
237+
assertStatusCode(t, recorder.Result(), http.StatusOK)
238+
}
239+
240+
func TestMultipleIpAddressesProxyReverse(t *testing.T) {
241+
cfg := createTesterConfig()
242+
243+
cfg.Countries = append(cfg.Countries, "CA")
244+
cfg.XForwardedForReverseProxy = true
245+
246+
ctx := context.Background()
247+
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
248+
249+
handler, err := geoblock.New(ctx, next, cfg, "GeoBlock")
250+
if err != nil {
251+
t.Fatal(err)
252+
}
253+
254+
recorder := httptest.NewRecorder()
255+
256+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
257+
if err != nil {
258+
t.Fatal(err)
259+
}
260+
261+
req.Header.Add(xForwardedFor, strings.Join([]string{chExampleIP, caExampleIP}, ","))
262+
263+
handler.ServeHTTP(recorder, req)
264+
265+
assertStatusCode(t, recorder.Result(), http.StatusForbidden)
266+
}
267+
212268
func TestAllowedUnknownCountry(t *testing.T) {
213269
cfg := createTesterConfig()
214270

0 commit comments

Comments
 (0)