Skip to content

Commit cbc7e73

Browse files
committed
Add error fields to RetrieveError
1 parent 62b4eed commit cbc7e73

File tree

3 files changed

+83
-12
lines changed

3 files changed

+83
-12
lines changed

internal/token.go

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,18 @@ type Token struct {
5555
}
5656

5757
// tokenJSON is the struct representing the HTTP response from OAuth2
58-
// providers returning a token in JSON form.
58+
// providers returning a token or error in JSON form.
59+
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
5960
type tokenJSON struct {
6061
AccessToken string `json:"access_token"`
6162
TokenType string `json:"token_type"`
6263
RefreshToken string `json:"refresh_token"`
6364
ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
65+
// error fields
66+
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
67+
Error string `json:"error"`
68+
ErrorDescription string `json:"error_description"`
69+
ErrorUri string `json:"error_uri"`
6470
}
6571

6672
func (e *tokenJSON) expiry() (t time.Time) {
@@ -236,21 +242,29 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
236242
if err != nil {
237243
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
238244
}
239-
if code := r.StatusCode; code < 200 || code > 299 {
240-
return nil, &RetrieveError{
241-
Response: r,
242-
Body: body,
243-
}
245+
246+
failureStatus := r.StatusCode < 200 || r.StatusCode > 299
247+
retrieveError := &RetrieveError{
248+
Response: r,
249+
Body: body,
250+
// attempt to populate error detail below
244251
}
245252

246253
var token *Token
247254
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
248255
switch content {
249256
case "application/x-www-form-urlencoded", "text/plain":
257+
// some endpoints such as GitHub return a query string https://docs.github.com/en/developers/apps/building-oauth-apps/authorizing-oauth-apps#response-1
250258
vals, err := url.ParseQuery(string(body))
251259
if err != nil {
252-
return nil, err
260+
if failureStatus {
261+
return nil, retrieveError
262+
}
263+
return nil, fmt.Errorf("oauth2: cannot parse response: %v", err)
253264
}
265+
retrieveError.ErrorCode = vals.Get("error")
266+
retrieveError.ErrorDescription = vals.Get("error_description")
267+
retrieveError.ErrorUri = vals.Get("error_uri")
254268
token = &Token{
255269
AccessToken: vals.Get("access_token"),
256270
TokenType: vals.Get("token_type"),
@@ -263,10 +277,17 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
263277
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
264278
}
265279
default:
280+
// spec says to return JSON https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
266281
var tj tokenJSON
267282
if err = json.Unmarshal(body, &tj); err != nil {
268-
return nil, err
283+
if failureStatus {
284+
return nil, retrieveError
285+
}
286+
return nil, fmt.Errorf("oauth2: cannot parse json: %v", err)
269287
}
288+
retrieveError.ErrorCode = tj.Error
289+
retrieveError.ErrorDescription = tj.ErrorDescription
290+
retrieveError.ErrorUri = tj.ErrorUri
270291
token = &Token{
271292
AccessToken: tj.AccessToken,
272293
TokenType: tj.TokenType,
@@ -276,15 +297,25 @@ func doTokenRoundTrip(ctx context.Context, req *http.Request) (*Token, error) {
276297
}
277298
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
278299
}
300+
// according to spec, servers should respond status 400 in error case
301+
// https://www.rfc-editor.org/rfc/rfc6749#section-5.2
302+
// but some unorthodox servers respond 200 in error case
303+
if failureStatus || retrieveError.ErrorCode != "" {
304+
return nil, retrieveError
305+
}
279306
if token.AccessToken == "" {
280307
return nil, errors.New("oauth2: server response missing access_token")
281308
}
282309
return token, nil
283310
}
284311

312+
// mirrors oauth2.RetrieveError
285313
type RetrieveError struct {
286-
Response *http.Response
287-
Body []byte
314+
Response *http.Response
315+
Body []byte
316+
ErrorCode string
317+
ErrorDescription string
318+
ErrorUri string
288319
}
289320

290321
func (r *RetrieveError) Error() string {

oauth2_test.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,7 @@ func TestTokenRetrieveError(t *testing.T) {
484484
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
485485
}
486486
w.Header().Set("Content-type", "application/json")
487+
// "The authorization server responds with an HTTP 400 (Bad Request)" https://www.rfc-editor.org/rfc/rfc6749#section-5.2
487488
w.WriteHeader(http.StatusBadRequest)
488489
w.Write([]byte(`{"error": "invalid_grant"}`))
489490
}))
@@ -493,7 +494,7 @@ func TestTokenRetrieveError(t *testing.T) {
493494
if err == nil {
494495
t.Fatalf("got no error, expected one")
495496
}
496-
_, ok := err.(*RetrieveError)
497+
re, ok := err.(*RetrieveError)
497498
if !ok {
498499
t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
499500
}
@@ -502,6 +503,39 @@ func TestTokenRetrieveError(t *testing.T) {
502503
if errStr := err.Error(); errStr != expected {
503504
t.Fatalf("got %#v, expected %#v", errStr, expected)
504505
}
506+
expected = "invalid_grant"
507+
if re.ErrorCode != expected {
508+
t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected)
509+
}
510+
}
511+
512+
// TestTokenRetrieveError200 tests handling of unorthodox server that returns 200 in error case
513+
func TestTokenRetrieveError200(t *testing.T) {
514+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
515+
if r.URL.String() != "/token" {
516+
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
517+
}
518+
w.Header().Set("Content-type", "application/json")
519+
w.Write([]byte(`{"error": "invalid_grant"}`))
520+
}))
521+
defer ts.Close()
522+
conf := newConf(ts.URL)
523+
_, err := conf.Exchange(context.Background(), "exchange-code")
524+
if err == nil {
525+
t.Fatalf("got no error, expected one")
526+
}
527+
re, ok := err.(*RetrieveError)
528+
if !ok {
529+
t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
530+
}
531+
expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "200 OK", `{"error": "invalid_grant"}`)
532+
if errStr := err.Error(); errStr != expected {
533+
t.Fatalf("got %#v, expected %#v", errStr, expected)
534+
}
535+
expected = "invalid_grant"
536+
if re.ErrorCode != expected {
537+
t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected)
538+
}
505539
}
506540

507541
func TestRefreshToken_RefreshTokenReplacement(t *testing.T) {

token.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,18 @@ func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error)
165165
}
166166

167167
// RetrieveError is the error returned when the token endpoint returns a
168-
// non-2XX HTTP status code.
168+
// non-2XX HTTP status code or populates rfc6749 error parameter.
169169
type RetrieveError struct {
170170
Response *http.Response
171171
// Body is the body that was consumed by reading Response.Body.
172172
// It may be truncated.
173173
Body []byte
174+
// rfc6749 error parameter https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
175+
ErrorCode string
176+
// rfc6749 error_description parameter https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
177+
ErrorDescription string
178+
// rfc6749 error_uri parameter https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
179+
ErrorUri string
174180
}
175181

176182
func (r *RetrieveError) Error() string {

0 commit comments

Comments
 (0)