@@ -2,6 +2,7 @@ package config
2
2
3
3
import (
4
4
"context"
5
+ "errors"
5
6
"fmt"
6
7
"io/ioutil"
7
8
"net/http"
@@ -15,9 +16,11 @@ import (
15
16
"time"
16
17
17
18
"github.com/aws/aws-sdk-go-v2/aws"
19
+ "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
18
20
"github.com/aws/aws-sdk-go-v2/internal/awstesting"
19
21
"github.com/aws/aws-sdk-go-v2/service/sso"
20
22
"github.com/aws/aws-sdk-go-v2/service/sts"
23
+ "github.com/aws/smithy-go"
21
24
"github.com/aws/smithy-go/middleware"
22
25
smithytime "github.com/aws/smithy-go/time"
23
26
)
@@ -471,3 +474,122 @@ func TestResolveCredentialsCacheOptions(t *testing.T) {
471
474
t .Errorf ("expect options to be called" )
472
475
}
473
476
}
477
+
478
+ func TestResolveCredentialsIMDSClient (t * testing.T ) {
479
+ expectEnabled := func (t * testing.T , err error ) {
480
+ if err == nil {
481
+ t .Fatalf ("expect error got none" )
482
+ }
483
+ if e , a := "expected HTTP client error" , err .Error (); ! strings .Contains (a , e ) {
484
+ t .Fatalf ("expected %v error in %v" , e , a )
485
+ }
486
+ }
487
+
488
+ expectDisabled := func (t * testing.T , err error ) {
489
+ var oe * smithy.OperationError
490
+ if ! errors .As (err , & oe ) {
491
+ t .Fatalf ("unexpected error: %v" , err )
492
+ } else {
493
+ e := errors .Unwrap (oe )
494
+ if e == nil {
495
+ t .Fatalf ("unexpected empty operation error: %v" , oe )
496
+ } else {
497
+ if ! strings .HasPrefix (e .Error (), "access disabled to EC2 IMDS" ) {
498
+ t .Fatalf ("unexpected operation error: %v" , oe )
499
+ }
500
+ }
501
+ }
502
+ }
503
+
504
+ testcases := map [string ]struct {
505
+ enabledState imds.ClientEnableState
506
+ envvar string
507
+ expectedState imds.ClientEnableState
508
+ expectedError func (* testing.T , error )
509
+ }{
510
+ "default no options" : {
511
+ expectedState : imds .ClientDefaultEnableState ,
512
+ expectedError : expectEnabled ,
513
+ },
514
+
515
+ "state enabled" : {
516
+ enabledState : imds .ClientEnabled ,
517
+ expectedState : imds .ClientEnabled ,
518
+ expectedError : expectEnabled ,
519
+ },
520
+ "state disabled" : {
521
+ enabledState : imds .ClientDisabled ,
522
+ expectedState : imds .ClientDisabled ,
523
+ expectedError : expectDisabled ,
524
+ },
525
+
526
+ "env var DISABLED true" : {
527
+ envvar : "true" ,
528
+ expectedState : imds .ClientDisabled ,
529
+ expectedError : expectDisabled ,
530
+ },
531
+ "env var DISABLED false" : {
532
+ envvar : "false" ,
533
+ expectedState : imds .ClientEnabled ,
534
+ expectedError : expectEnabled ,
535
+ },
536
+
537
+ "option state enabled overrides env var DISABLED true" : {
538
+ enabledState : imds .ClientEnabled ,
539
+ envvar : "true" ,
540
+ expectedState : imds .ClientEnabled ,
541
+ expectedError : expectEnabled ,
542
+ },
543
+ "option state disabled overrides env var DISABLED false" : {
544
+ enabledState : imds .ClientDisabled ,
545
+ envvar : "false" ,
546
+ expectedState : imds .ClientDisabled ,
547
+ expectedError : expectDisabled ,
548
+ },
549
+ }
550
+
551
+ for name , tc := range testcases {
552
+ t .Run (name , func (t * testing.T ) {
553
+ restoreEnv := awstesting .StashEnv ()
554
+ defer awstesting .PopEnv (restoreEnv )
555
+
556
+ var httpClient HTTPClient
557
+ if tc .expectedState == imds .ClientDisabled {
558
+ httpClient = stubErrorClient {err : fmt .Errorf ("expect HTTP client not to be called" )}
559
+ } else {
560
+ httpClient = stubErrorClient {err : fmt .Errorf ("expected HTTP client error" )}
561
+ }
562
+
563
+ opts := []func (* LoadOptions ) error {
564
+ WithRetryer (func () aws.Retryer { return aws.NopRetryer {} }),
565
+ WithHTTPClient (httpClient ),
566
+ }
567
+
568
+ if tc .enabledState != imds .ClientDefaultEnableState {
569
+ opts = append (opts ,
570
+ WithEC2IMDSClientEnableState (tc .enabledState ),
571
+ )
572
+ }
573
+
574
+ if tc .envvar != "" {
575
+ os .Setenv ("AWS_EC2_METADATA_DISABLED" , tc .envvar )
576
+ }
577
+
578
+ c , err := LoadDefaultConfig (context .TODO (), opts ... )
579
+ if err != nil {
580
+ t .Fatalf ("could not load config: %s" , err )
581
+ }
582
+
583
+ creds := c .Credentials
584
+
585
+ _ , err = creds .Retrieve (context .TODO ())
586
+ tc .expectedError (t , err )
587
+ })
588
+ }
589
+ }
590
+
591
+ type stubErrorClient struct {
592
+ err error
593
+ }
594
+
595
+ func (c stubErrorClient ) Do (* http.Request ) (* http.Response , error ) { return nil , c .err }
0 commit comments