Skip to content

Commit ba9cff9

Browse files
fix: return 401 for expired oauth2 tokens (#18)
This change updates the OAuth2 middleware to return a 401 Unauthorized response when a token is expired. The previous behavior of returning 400 was non-compliant. We've also added endpoints to introspect access tokens.
1 parent 2cc2c38 commit ba9cff9

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

cmd/server/main.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ func main() {
3030
flag.Parse()
3131

3232
r := mux.NewRouter()
33+
r.HandleFunc("/oauth2/token", auth.HandleOAuth2InspectToken).Methods(http.MethodGet)
3334
r.HandleFunc("/oauth2/token", auth.HandleOAuth2).Methods(http.MethodPost)
3435
r.HandleFunc("/auth", auth.HandleAuth).Methods(http.MethodPost)
3536
r.HandleFunc("/auth/customsecurity/{customSchemeType}", auth.HandleCustomAuth).Methods(http.MethodGet)

internal/auth/oauth2.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"log"
88
"net/http"
9+
"strings"
910
"sync"
1011
"sync/atomic"
1112
"time"
@@ -75,6 +76,38 @@ type OAuth2TokenResponse struct {
7576
ExpiresIn int `json:"expires_in"`
7677
}
7778

79+
func HandleOAuth2InspectToken(w http.ResponseWriter, r *http.Request) {
80+
enc := json.NewEncoder(w)
81+
enc.SetIndent("", " ")
82+
w.Header().Set("Content-Type", "application/json")
83+
84+
authz := r.Header.Get("Authorization")
85+
if authz == "" {
86+
http.Error(w, `{"error": "unauthorized"}`, http.StatusUnauthorized)
87+
return
88+
}
89+
if !strings.HasPrefix(authz, "Bearer ") {
90+
http.Error(w, `{"error": "invalid authorization"}`, http.StatusBadRequest)
91+
return
92+
}
93+
94+
token := authz[len("Bearer "):]
95+
claims, err := ParseToken(token)
96+
if err != nil {
97+
http.Error(w, `{"error": "invalid token"}`, http.StatusBadRequest)
98+
return
99+
}
100+
101+
updatedExpiry := GetTokenExpiry(claims)
102+
103+
claims["exp"] = float64(updatedExpiry.Unix())
104+
105+
if err := enc.Encode(claims); err != nil {
106+
http.Error(w, `{"error": "failed to encode response"}`, http.StatusInternalServerError)
107+
return
108+
}
109+
}
110+
78111
func HandleOAuth2(w http.ResponseWriter, r *http.Request) {
79112
enc := json.NewEncoder(w)
80113
enc.SetIndent("", " ")
@@ -165,6 +198,10 @@ func HandleOAuth2(w http.ResponseWriter, r *http.Request) {
165198

166199
now := time.Now()
167200
expires := now.Add(time.Hour)
201+
forcedExpiry := r.Header.Get("x-oauth2-expire-at")
202+
if exp, err := time.Parse(time.RFC3339, forcedExpiry); err == nil {
203+
expires = exp
204+
}
168205

169206
accessTokenID := gofakeit.UUID()
170207
accessTokenClaims := jwt.MapClaims{
@@ -248,6 +285,24 @@ func RefreshToken(refreshClaims jwt.MapClaims) {
248285
tokenDB.Store(tokenID, expiry)
249286
}
250287

288+
func GetTokenExpiry(tokenClaims jwt.MapClaims) time.Time {
289+
tokenDBLastAccess.Store(time.Now())
290+
291+
tokenID := tokenClaims["id"].(string)
292+
293+
exp, found := tokenDB.Load(tokenID)
294+
if found {
295+
return exp.(time.Time)
296+
}
297+
298+
expiryClaim, err := tokenClaims.GetExpirationTime()
299+
if err != nil {
300+
panic(err)
301+
}
302+
303+
return expiryClaim.Time
304+
}
305+
251306
func IsTokenExpired(tokenClaims jwt.MapClaims) bool {
252307
tokenDBLastAccess.Store(time.Now())
253308

internal/middleware/oauth2.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ func OAuth2(h http.Handler) http.Handler {
3131
}
3232

3333
if auth.IsTokenExpired(claims) {
34-
auth.SendOAuth2Error(w, auth.ErrCodeInvalidRequest, "token has expired")
34+
w.Header().Set("Content-Type", "application/json")
35+
http.Error(w, `{"error": "token has expired"}`, http.StatusUnauthorized)
3536
return
3637
}
3738

0 commit comments

Comments
 (0)