diff --git a/client/nginx.go b/client/nginx.go index 87503eca..d5fff60f 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -35,6 +35,7 @@ var ( defaultBackup = false defaultDown = false defaultWeight = 1 + defaultTimeout = 10 * time.Second ) // ErrUnsupportedVer means that client's API version is not supported by NGINX plus API. @@ -46,6 +47,7 @@ type NginxClient struct { apiEndpoint string apiVersion int checkAPI bool + ctxTimeout time.Duration } type Option func(*NginxClient) @@ -546,6 +548,13 @@ func WithCheckAPI() Option { } } +// WithTimeout sets the timeout per request for the client. +func WithTimeout(duration time.Duration) Option { + return func(o *NginxClient) { + o.ctxTimeout = duration + } +} + // NewNginxClient creates a new NginxClient. func NewNginxClient(apiEndpoint string, opts ...Option) (*NginxClient, error) { c := &NginxClient{ @@ -553,6 +562,7 @@ func NewNginxClient(apiEndpoint string, opts ...Option) (*NginxClient, error) { apiEndpoint: apiEndpoint, apiVersion: APIVersion, checkAPI: false, + ctxTimeout: defaultTimeout, } for _, opt := range opts { @@ -567,8 +577,12 @@ func NewNginxClient(apiEndpoint string, opts ...Option) (*NginxClient, error) { return nil, fmt.Errorf("API version %v is not supported by the client", c.apiVersion) } + if c.ctxTimeout <= 0 { + return nil, fmt.Errorf("timeout has to be greater than 0 %v", c.ctxTimeout) + } + if c.checkAPI { - versions, err := getAPIVersions(c.httpClient, apiEndpoint) + versions, err := c.getAPIVersions(c.httpClient, apiEndpoint) if err != nil { return nil, fmt.Errorf("error accessing the API: %w", err) } @@ -596,8 +610,8 @@ func versionSupported(n int) bool { return false } -func getAPIVersions(httpClient *http.Client, endpoint string) (*versions, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) +func (client *NginxClient) getAPIVersions(httpClient *http.Client, endpoint string) (*versions, error) { + ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) @@ -852,7 +866,7 @@ func (client *NginxClient) getIDOfHTTPServer(upstream string, name string) (int, } func (client *NginxClient) get(path string, data interface{}) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) defer cancel() url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path) @@ -886,7 +900,7 @@ func (client *NginxClient) get(path string, data interface{}) error { } func (client *NginxClient) post(path string, input interface{}) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) defer cancel() url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path) @@ -918,7 +932,7 @@ func (client *NginxClient) post(path string, input interface{}) error { } func (client *NginxClient) delete(path string, expectedStatusCode int) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) defer cancel() path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path) @@ -943,7 +957,7 @@ func (client *NginxClient) delete(path string, expectedStatusCode int) error { } func (client *NginxClient) patch(path string, input interface{}, expectedStatusCode int) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) defer cancel() path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path) diff --git a/client/nginx_test.go b/client/nginx_test.go index 1f533972..2cb8a7e8 100644 --- a/client/nginx_test.go +++ b/client/nginx_test.go @@ -6,6 +6,7 @@ import ( "reflect" "strings" "testing" + "time" ) func TestDetermineUpdates(t *testing.T) { @@ -578,6 +579,27 @@ func TestClientWithAPIVersion(t *testing.T) { } } +func TestClientWithTimeout(t *testing.T) { + t.Parallel() + // Test creating a new client with a supported API version on the client + client, err := NewNginxClient("http://api-url", WithTimeout(1*time.Second)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if client == nil { + t.Fatalf("client is nil") + } + + // Test creating a new client with an invalid duration + client, err = NewNginxClient("http://api-url", WithTimeout(-1*time.Second)) + if err == nil { + t.Fatalf("expected error, but got nil") + } + if client != nil { + t.Fatalf("expected client to be nil, but got %v", client) + } +} + func TestClientWithHTTPClient(t *testing.T) { t.Parallel() // Test creating a new client passing a custom HTTP client