diff --git a/client/nginx.go b/client/nginx.go index 75e87c9e..7a76d41c 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -11,7 +11,10 @@ import ( "reflect" "slices" "strings" + "sync" "time" + + "golang.org/x/sync/errgroup" ) const ( @@ -123,6 +126,40 @@ func (internalError *internalError) Wrap(err string) *internalError { return internalError } +// this is an internal representation of the Stats object including endpoint and streamEndpoint lists. +type extendedStats struct { + endpoints []string + streamEndpoints []string + Stats +} + +func defaultStats() *extendedStats { + return &extendedStats{ + endpoints: []string{}, + streamEndpoints: []string{}, + Stats: Stats{ + Upstreams: map[string]Upstream{}, + ServerZones: map[string]ServerZone{}, + StreamServerZones: map[string]StreamServerZone{}, + StreamUpstreams: map[string]StreamUpstream{}, + Slabs: map[string]Slab{}, + Caches: map[string]HTTPCache{}, + HTTPLimitConnections: map[string]LimitConnection{}, + StreamLimitConnections: map[string]LimitConnection{}, + HTTPLimitRequests: map[string]HTTPLimitRequest{}, + Resolvers: map[string]Resolver{}, + LocationZones: map[string]LocationZone{}, + StreamZoneSync: nil, + Workers: []*Workers{}, + NginxInfo: NginxInfo{}, + SSL: SSL{}, + Connections: Connections{}, + HTTPRequests: HTTPRequests{}, + Processes: Processes{}, + }, + } +} + // Stats represents NGINX Plus stats fetched from the NGINX Plus API. // https://nginx.org/en/docs/http/ngx_http_api_module.html type Stats struct { @@ -890,9 +927,13 @@ func (client *NginxClient) getIDOfHTTPServer(upstream string, name string) (int, } func (client *NginxClient) get(path string, data interface{}) error { - ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + timeoutCtx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) defer cancel() + return client.getWithContext(timeoutCtx, path, data) +} + +func (client *NginxClient) getWithContext(ctx context.Context, path string, data interface{}) error { url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) @@ -924,9 +965,13 @@ func (client *NginxClient) get(path string, data interface{}) error { } func (client *NginxClient) post(path string, input interface{}) error { - ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + timeoutCtx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) defer cancel() + return client.postWithContext(timeoutCtx, path, input) +} + +func (client *NginxClient) postWithContext(ctx context.Context, path string, input interface{}) error { url := fmt.Sprintf("%v/%v/%v", client.apiEndpoint, client.apiVersion, path) jsonInput, err := json.Marshal(input) @@ -956,9 +1001,13 @@ func (client *NginxClient) post(path string, input interface{}) error { } func (client *NginxClient) delete(path string, expectedStatusCode int) error { - ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + timeoutCtx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) defer cancel() + return client.deleteWithContext(timeoutCtx, path, expectedStatusCode) +} + +func (client *NginxClient) deleteWithContext(ctx context.Context, path string, expectedStatusCode int) error { path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, path, nil) @@ -981,9 +1030,13 @@ func (client *NginxClient) delete(path string, expectedStatusCode int) error { } func (client *NginxClient) patch(path string, input interface{}, expectedStatusCode int) error { - ctx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) + timeoutCtx, cancel := context.WithTimeout(context.Background(), client.ctxTimeout) defer cancel() + return client.patchWithContext(timeoutCtx, path, input, expectedStatusCode) +} + +func (client *NginxClient) patchWithContext(ctx context.Context, path string, input interface{}, expectedStatusCode int) error { path = fmt.Sprintf("%v/%v/%v/", client.apiEndpoint, client.apiVersion, path) jsonInput, err := json.Marshal(input) @@ -1199,149 +1252,321 @@ func determineStreamUpdates(updatedServers []StreamUpstreamServer, nginxServers return } -// GetStats gets process, slab, connection, request, ssl, zone, stream zone, upstream and stream upstream related stats from the NGINX Plus API. -func (client *NginxClient) GetStats() (*Stats, error) { - endpoints, err := client.GetAvailableEndpoints() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } +// GetStatsWithContext gets process, slab, connection, request, ssl, zone, stream zone, upstream and stream upstream related stats from the NGINX Plus API. +func (client *NginxClient) GetStatsWithContext(ctx context.Context) (*Stats, error) { + initialGroup, initialCtx := errgroup.WithContext(ctx) + var mu sync.Mutex + stats := defaultStats() + // Collecting initial stats + initialGroup.Go(func() error { + endpoints, err := client.GetAvailableEndpointsWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get available Endpoints: %w", err) + } - info, err := client.GetNginxInfo() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + mu.Lock() + stats.endpoints = endpoints + mu.Unlock() + return nil + }) - caches, err := client.GetCaches() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + initialGroup.Go(func() error { + nginxInfo, err := client.GetNginxInfoWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get NGINX info: %w", err) + } - processes, err := client.GetProcesses() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + mu.Lock() + stats.NginxInfo = *nginxInfo + mu.Unlock() - slabs, err := client.GetSlabs() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + return nil + }) - cons, err := client.GetConnections() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + initialGroup.Go(func() error { + caches, err := client.GetCachesWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get Caches: %w", err) + } - requests, err := client.GetHTTPRequests() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + mu.Lock() + stats.Caches = *caches + mu.Unlock() - ssl, err := client.GetSSL() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + return nil + }) - zones, err := client.GetServerZones() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + initialGroup.Go(func() error { + processes, err := client.GetProcessesWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get Process information: %w", err) + } - upstreams, err := client.GetUpstreams() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + mu.Lock() + stats.Processes = *processes + mu.Unlock() - locationZones, err := client.GetLocationZones() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + return nil + }) - resolvers, err := client.GetResolvers() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + initialGroup.Go(func() error { + slabs, err := client.GetSlabsWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get Slabs: %w", err) + } - limitReqs, err := client.GetHTTPLimitReqs() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + mu.Lock() + stats.Slabs = *slabs + mu.Unlock() - limitConnsHTTP, err := client.GetHTTPConnectionsLimit() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + return nil + }) - workers, err := client.GetWorkers() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + initialGroup.Go(func() error { + httpRequests, err := client.GetHTTPRequestsWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get HTTP Requests: %w", err) + } + + mu.Lock() + stats.HTTPRequests = *httpRequests + mu.Unlock() - streamZones := &StreamServerZones{} - streamUpstreams := &StreamUpstreams{} - limitConnsStream := &StreamLimitConnections{} - var streamZoneSync *StreamZoneSync + return nil + }) - if slices.Contains(endpoints, "stream") { - streamEndpoints, err := client.GetAvailableStreamEndpoints() + initialGroup.Go(func() error { + ssl, err := client.GetSSLWithContext(initialCtx) if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) + return fmt.Errorf("failed to get SSL: %w", err) } - if slices.Contains(streamEndpoints, "server_zones") { - streamZones, err = client.GetStreamServerZones() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + mu.Lock() + stats.SSL = *ssl + mu.Unlock() + + return nil + }) + + initialGroup.Go(func() error { + serverZones, err := client.GetServerZonesWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get Server Zones: %w", err) } - if slices.Contains(streamEndpoints, "upstreams") { - streamUpstreams, err = client.GetStreamUpstreams() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + mu.Lock() + stats.ServerZones = *serverZones + mu.Unlock() + + return nil + }) + + initialGroup.Go(func() error { + upstreams, err := client.GetUpstreamsWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get Upstreams: %w", err) } - if slices.Contains(streamEndpoints, "limit_conns") { - limitConnsStream, err = client.GetStreamConnectionsLimit() - if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) - } + mu.Lock() + stats.Upstreams = *upstreams + mu.Unlock() + + return nil + }) + + initialGroup.Go(func() error { + locationZones, err := client.GetLocationZonesWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get Location Zones: %w", err) + } + + mu.Lock() + stats.LocationZones = *locationZones + mu.Unlock() + + return nil + }) + + initialGroup.Go(func() error { + resolvers, err := client.GetResolversWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get Resolvers: %w", err) + } + + mu.Lock() + stats.Resolvers = *resolvers + mu.Unlock() + + return nil + }) + + initialGroup.Go(func() error { + httpLimitRequests, err := client.GetHTTPLimitReqsWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get HTTPLimitRequests: %w", err) + } + + mu.Lock() + stats.HTTPLimitRequests = *httpLimitRequests + mu.Unlock() + + return nil + }) + + initialGroup.Go(func() error { + httpLimitConnections, err := client.GetHTTPConnectionsLimitWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get HTTPLimitConnections: %w", err) } - if slices.Contains(streamEndpoints, "zone_sync") { - streamZoneSync, err = client.GetStreamZoneSync() + mu.Lock() + stats.HTTPLimitConnections = *httpLimitConnections + mu.Unlock() + + return nil + }) + + initialGroup.Go(func() error { + workers, err := client.GetWorkersWithContext(initialCtx) + if err != nil { + return fmt.Errorf("failed to get Workers: %w", err) + } + + mu.Lock() + stats.Workers = workers + mu.Unlock() + + return nil + }) + + if err := initialGroup.Wait(); err != nil { + return nil, fmt.Errorf("error returned from contacting Plus API: %w", err) + } + + // Process stream endpoints if they exist + if slices.Contains(stats.endpoints, "stream") { + availableStreamGroup, asgCtx := errgroup.WithContext(ctx) + + availableStreamGroup.Go(func() error { + streamEndpoints, err := client.GetAvailableStreamEndpointsWithContext(asgCtx) if err != nil { - return nil, fmt.Errorf("failed to get stats: %w", err) + return fmt.Errorf("failed to get available Stream Endpoints: %w", err) } + + mu.Lock() + stats.streamEndpoints = streamEndpoints + mu.Unlock() + + return nil + }) + + if err := availableStreamGroup.Wait(); err != nil { + return nil, fmt.Errorf("no useful metrics found in stream stats: %w", err) + } + + streamGroup, sgCtx := errgroup.WithContext(ctx) + + if slices.Contains(stats.streamEndpoints, "server_zones") { + streamGroup.Go(func() error { + streamServerZones, err := client.GetStreamServerZonesWithContext(sgCtx) + if err != nil { + return fmt.Errorf("failed to get streamServerZones: %w", err) + } + + mu.Lock() + stats.StreamServerZones = *streamServerZones + mu.Unlock() + + return nil + }) + } + + if slices.Contains(stats.streamEndpoints, "upstreams") { + streamGroup.Go(func() error { + streamUpstreams, err := client.GetStreamUpstreamsWithContext(sgCtx) + if err != nil { + return fmt.Errorf("failed to get StreamUpstreams: %w", err) + } + + mu.Lock() + stats.StreamUpstreams = *streamUpstreams + mu.Unlock() + + return nil + }) + } + + if slices.Contains(stats.streamEndpoints, "limit_conns") { + streamGroup.Go(func() error { + streamConnectionsLimit, err := client.GetStreamConnectionsLimitWithContext(sgCtx) + if err != nil { + return fmt.Errorf("failed to get StreamLimitConnections: %w", err) + } + + mu.Lock() + stats.StreamLimitConnections = *streamConnectionsLimit + mu.Unlock() + + return nil + }) + + streamGroup.Go(func() error { + streamZoneSync, err := client.GetStreamZoneSyncWithContext(sgCtx) + if err != nil { + return fmt.Errorf("failed to get StreamZoneSync: %w", err) + } + + mu.Lock() + stats.StreamZoneSync = streamZoneSync + mu.Unlock() + + return nil + }) + } + + if err := streamGroup.Wait(); err != nil { + return nil, fmt.Errorf("no useful metrics found in stream stats: %w", err) } } - return &Stats{ - NginxInfo: *info, - Caches: *caches, - Processes: *processes, - Slabs: *slabs, - Connections: *cons, - HTTPRequests: *requests, - SSL: *ssl, - ServerZones: *zones, - StreamServerZones: *streamZones, - Upstreams: *upstreams, - StreamUpstreams: *streamUpstreams, - StreamZoneSync: streamZoneSync, - LocationZones: *locationZones, - Resolvers: *resolvers, - HTTPLimitRequests: *limitReqs, - HTTPLimitConnections: *limitConnsHTTP, - StreamLimitConnections: *limitConnsStream, - Workers: workers, - }, nil + // Report connection metrics separately so it does not influence the results + connectionsGroup, cgCtx := errgroup.WithContext(ctx) + + connectionsGroup.Go(func() error { + // replace this call with a context specific call + connections, err := client.GetConnectionsWithContext(cgCtx) + if err != nil { + return fmt.Errorf("failed to get connections: %w", err) + } + + mu.Lock() + stats.Connections = *connections + mu.Unlock() + + return nil + }) + + if err := connectionsGroup.Wait(); err != nil { + return nil, fmt.Errorf("connections metrics not found: %w", err) + } + + return &stats.Stats, nil +} + +// GetStats gets process, slab, connection, request, ssl, zone, stream zone, upstream and stream upstream related stats from the NGINX Plus API. +func (client *NginxClient) GetStats() (*Stats, error) { + return client.GetStatsWithContext(context.Background()) } // GetAvailableEndpoints returns available endpoints in the API. func (client *NginxClient) GetAvailableEndpoints() ([]string, error) { + return client.GetAvailableEndpointsWithContext(context.Background()) +} + +// GetAvailableEndpointsWithContext returns available endpoints in the API. +func (client *NginxClient) GetAvailableEndpointsWithContext(ctx context.Context) ([]string, error) { var endpoints []string - err := client.get("", &endpoints) + err := client.getWithContext(ctx, "", &endpoints) if err != nil { return nil, fmt.Errorf("failed to get endpoints: %w", err) } @@ -1350,8 +1575,13 @@ func (client *NginxClient) GetAvailableEndpoints() ([]string, error) { // GetAvailableStreamEndpoints returns available stream endpoints in the API. func (client *NginxClient) GetAvailableStreamEndpoints() ([]string, error) { + return client.GetAvailableStreamEndpointsWithContext(context.Background()) +} + +// GetAvailableStreamEndpointsWithContext returns available stream endpoints in the API with a context. +func (client *NginxClient) GetAvailableStreamEndpointsWithContext(ctx context.Context) ([]string, error) { var endpoints []string - err := client.get("stream", &endpoints) + err := client.getWithContext(ctx, "stream", &endpoints) if err != nil { return nil, fmt.Errorf("failed to get endpoints: %w", err) } @@ -1360,8 +1590,13 @@ func (client *NginxClient) GetAvailableStreamEndpoints() ([]string, error) { // GetNginxInfo returns Nginx stats. func (client *NginxClient) GetNginxInfo() (*NginxInfo, error) { + return client.GetNginxInfoWithContext(context.Background()) +} + +// GetNginxInfoWithContext returns Nginx stats with a context. +func (client *NginxClient) GetNginxInfoWithContext(ctx context.Context) (*NginxInfo, error) { var info NginxInfo - err := client.get("nginx", &info) + err := client.getWithContext(ctx, "nginx", &info) if err != nil { return nil, fmt.Errorf("failed to get info: %w", err) } @@ -1370,8 +1605,13 @@ func (client *NginxClient) GetNginxInfo() (*NginxInfo, error) { // GetCaches returns Cache stats. func (client *NginxClient) GetCaches() (*Caches, error) { + return client.GetCachesWithContext(context.Background()) +} + +// GetCachesWithContext returns Cache stats with a context. +func (client *NginxClient) GetCachesWithContext(ctx context.Context) (*Caches, error) { var caches Caches - err := client.get("http/caches", &caches) + err := client.getWithContext(ctx, "http/caches", &caches) if err != nil { return nil, fmt.Errorf("failed to get caches: %w", err) } @@ -1380,8 +1620,13 @@ func (client *NginxClient) GetCaches() (*Caches, error) { // GetSlabs returns Slabs stats. func (client *NginxClient) GetSlabs() (*Slabs, error) { + return client.GetSlabsWithContext(context.Background()) +} + +// GetSlabsWithContext returns Slabs stats with a context. +func (client *NginxClient) GetSlabsWithContext(ctx context.Context) (*Slabs, error) { var slabs Slabs - err := client.get("slabs", &slabs) + err := client.getWithContext(ctx, "slabs", &slabs) if err != nil { return nil, fmt.Errorf("failed to get slabs: %w", err) } @@ -1390,8 +1635,13 @@ func (client *NginxClient) GetSlabs() (*Slabs, error) { // GetConnections returns Connections stats. func (client *NginxClient) GetConnections() (*Connections, error) { + return client.GetConnectionsWithContext(context.Background()) +} + +// GetConnectionsWithContext returns Connections stats with a context. +func (client *NginxClient) GetConnectionsWithContext(ctx context.Context) (*Connections, error) { var cons Connections - err := client.get("connections", &cons) + err := client.getWithContext(ctx, "connections", &cons) if err != nil { return nil, fmt.Errorf("failed to get connections: %w", err) } @@ -1400,8 +1650,13 @@ func (client *NginxClient) GetConnections() (*Connections, error) { // GetHTTPRequests returns http/requests stats. func (client *NginxClient) GetHTTPRequests() (*HTTPRequests, error) { + return client.GetHTTPRequestsWithContext(context.Background()) +} + +// GetHTTPRequestsWithContext returns http/requests stats with a context. +func (client *NginxClient) GetHTTPRequestsWithContext(ctx context.Context) (*HTTPRequests, error) { var requests HTTPRequests - err := client.get("http/requests", &requests) + err := client.getWithContext(ctx, "http/requests", &requests) if err != nil { return nil, fmt.Errorf("failed to get http requests: %w", err) } @@ -1410,8 +1665,13 @@ func (client *NginxClient) GetHTTPRequests() (*HTTPRequests, error) { // GetSSL returns SSL stats. func (client *NginxClient) GetSSL() (*SSL, error) { + return client.GetSSLWithContext(context.Background()) +} + +// GetSSLWithContext returns SSL stats with a context. +func (client *NginxClient) GetSSLWithContext(ctx context.Context) (*SSL, error) { var ssl SSL - err := client.get("ssl", &ssl) + err := client.getWithContext(ctx, "ssl", &ssl) if err != nil { return nil, fmt.Errorf("failed to get ssl: %w", err) } @@ -1420,8 +1680,13 @@ func (client *NginxClient) GetSSL() (*SSL, error) { // GetServerZones returns http/server_zones stats. func (client *NginxClient) GetServerZones() (*ServerZones, error) { + return client.GetServerZonesWithContext(context.Background()) +} + +// GetServerZonesWithContext returns http/server_zones stats with a context. +func (client *NginxClient) GetServerZonesWithContext(ctx context.Context) (*ServerZones, error) { var zones ServerZones - err := client.get("http/server_zones", &zones) + err := client.getWithContext(ctx, "http/server_zones", &zones) if err != nil { return nil, fmt.Errorf("failed to get server zones: %w", err) } @@ -1430,8 +1695,13 @@ func (client *NginxClient) GetServerZones() (*ServerZones, error) { // GetStreamServerZones returns stream/server_zones stats. func (client *NginxClient) GetStreamServerZones() (*StreamServerZones, error) { + return client.GetStreamServerZonesWithContext(context.Background()) +} + +// GetStreamServerZonesWithContext returns stream/server_zones stats with a context. +func (client *NginxClient) GetStreamServerZonesWithContext(ctx context.Context) (*StreamServerZones, error) { var zones StreamServerZones - err := client.get("stream/server_zones", &zones) + err := client.getWithContext(ctx, "stream/server_zones", &zones) if err != nil { var ie *internalError if errors.As(err, &ie) { @@ -1446,8 +1716,13 @@ func (client *NginxClient) GetStreamServerZones() (*StreamServerZones, error) { // GetUpstreams returns http/upstreams stats. func (client *NginxClient) GetUpstreams() (*Upstreams, error) { + return client.GetUpstreamsWithContext(context.Background()) +} + +// GetUpstreamsWithContext returns http/upstreams stats with a context. +func (client *NginxClient) GetUpstreamsWithContext(ctx context.Context) (*Upstreams, error) { var upstreams Upstreams - err := client.get("http/upstreams", &upstreams) + err := client.getWithContext(ctx, "http/upstreams", &upstreams) if err != nil { return nil, fmt.Errorf("failed to get upstreams: %w", err) } @@ -1456,8 +1731,13 @@ func (client *NginxClient) GetUpstreams() (*Upstreams, error) { // GetStreamUpstreams returns stream/upstreams stats. func (client *NginxClient) GetStreamUpstreams() (*StreamUpstreams, error) { + return client.GetStreamUpstreamsWithContext(context.Background()) +} + +// GetStreamUpstreamsWithContext returns stream/upstreams stats with a context. +func (client *NginxClient) GetStreamUpstreamsWithContext(ctx context.Context) (*StreamUpstreams, error) { var upstreams StreamUpstreams - err := client.get("stream/upstreams", &upstreams) + err := client.getWithContext(ctx, "stream/upstreams", &upstreams) if err != nil { var ie *internalError if errors.As(err, &ie) { @@ -1472,8 +1752,13 @@ func (client *NginxClient) GetStreamUpstreams() (*StreamUpstreams, error) { // GetStreamZoneSync returns stream/zone_sync stats. func (client *NginxClient) GetStreamZoneSync() (*StreamZoneSync, error) { + return client.GetStreamZoneSyncWithContext(context.Background()) +} + +// GetStreamZoneSyncWithContext returns stream/zone_sync stats with a context. +func (client *NginxClient) GetStreamZoneSyncWithContext(ctx context.Context) (*StreamZoneSync, error) { var streamZoneSync StreamZoneSync - err := client.get("stream/zone_sync", &streamZoneSync) + err := client.getWithContext(ctx, "stream/zone_sync", &streamZoneSync) if err != nil { var ie *internalError if errors.As(err, &ie) { @@ -1489,11 +1774,16 @@ func (client *NginxClient) GetStreamZoneSync() (*StreamZoneSync, error) { // GetLocationZones returns http/location_zones stats. func (client *NginxClient) GetLocationZones() (*LocationZones, error) { + return client.GetLocationZonesWithContext(context.Background()) +} + +// GetLocationZonesWithContext returns http/location_zones stats with a context. +func (client *NginxClient) GetLocationZonesWithContext(ctx context.Context) (*LocationZones, error) { var locationZones LocationZones if client.apiVersion < 5 { return &locationZones, nil } - err := client.get("http/location_zones", &locationZones) + err := client.getWithContext(ctx, "http/location_zones", &locationZones) if err != nil { return nil, fmt.Errorf("failed to get location zones: %w", err) } @@ -1503,11 +1793,16 @@ func (client *NginxClient) GetLocationZones() (*LocationZones, error) { // GetResolvers returns Resolvers stats. func (client *NginxClient) GetResolvers() (*Resolvers, error) { + return client.GetResolversWithContext(context.Background()) +} + +// GetResolversWithContext returns Resolvers stats with a context. +func (client *NginxClient) GetResolversWithContext(ctx context.Context) (*Resolvers, error) { var resolvers Resolvers if client.apiVersion < 5 { return &resolvers, nil } - err := client.get("resolvers", &resolvers) + err := client.getWithContext(ctx, "resolvers", &resolvers) if err != nil { return nil, fmt.Errorf("failed to get resolvers: %w", err) } @@ -1517,8 +1812,13 @@ func (client *NginxClient) GetResolvers() (*Resolvers, error) { // GetProcesses returns Processes stats. func (client *NginxClient) GetProcesses() (*Processes, error) { + return client.GetProcessesWithContext(context.Background()) +} + +// GetProcessesWithContext returns Processes stats with a context. +func (client *NginxClient) GetProcessesWithContext(ctx context.Context) (*Processes, error) { var processes Processes - err := client.get("processes", &processes) + err := client.getWithContext(ctx, "processes", &processes) if err != nil { return nil, fmt.Errorf("failed to get processes: %w", err) } @@ -1748,11 +2048,16 @@ func addPortToServer(server string) string { // GetHTTPLimitReqs returns http/limit_reqs stats. func (client *NginxClient) GetHTTPLimitReqs() (*HTTPLimitRequests, error) { + return client.GetHTTPLimitReqsWithContext(context.Background()) +} + +// GetHTTPLimitReqsWithContext returns http/limit_reqs stats with a context. +func (client *NginxClient) GetHTTPLimitReqsWithContext(ctx context.Context) (*HTTPLimitRequests, error) { var limitReqs HTTPLimitRequests if client.apiVersion < 6 { return &limitReqs, nil } - err := client.get("http/limit_reqs", &limitReqs) + err := client.getWithContext(ctx, "http/limit_reqs", &limitReqs) if err != nil { return nil, fmt.Errorf("failed to get http limit requests: %w", err) } @@ -1761,11 +2066,16 @@ func (client *NginxClient) GetHTTPLimitReqs() (*HTTPLimitRequests, error) { // GetHTTPConnectionsLimit returns http/limit_conns stats. func (client *NginxClient) GetHTTPConnectionsLimit() (*HTTPLimitConnections, error) { + return client.GetHTTPConnectionsLimitWithContext(context.Background()) +} + +// GetHTTPConnectionsLimitWithContext returns http/limit_conns stats with a context. +func (client *NginxClient) GetHTTPConnectionsLimitWithContext(ctx context.Context) (*HTTPLimitConnections, error) { var limitConns HTTPLimitConnections if client.apiVersion < 6 { return &limitConns, nil } - err := client.get("http/limit_conns", &limitConns) + err := client.getWithContext(ctx, "http/limit_conns", &limitConns) if err != nil { return nil, fmt.Errorf("failed to get http connections limit: %w", err) } @@ -1774,11 +2084,16 @@ func (client *NginxClient) GetHTTPConnectionsLimit() (*HTTPLimitConnections, err // GetStreamConnectionsLimit returns stream/limit_conns stats. func (client *NginxClient) GetStreamConnectionsLimit() (*StreamLimitConnections, error) { + return client.GetStreamConnectionsLimitWithContext(context.Background()) +} + +// GetStreamConnectionsLimitWithContext returns stream/limit_conns stats with a context. +func (client *NginxClient) GetStreamConnectionsLimitWithContext(ctx context.Context) (*StreamLimitConnections, error) { var limitConns StreamLimitConnections if client.apiVersion < 6 { return &limitConns, nil } - err := client.get("stream/limit_conns", &limitConns) + err := client.getWithContext(ctx, "stream/limit_conns", &limitConns) if err != nil { var ie *internalError if errors.As(err, &ie) { @@ -1793,11 +2108,16 @@ func (client *NginxClient) GetStreamConnectionsLimit() (*StreamLimitConnections, // GetWorkers returns workers stats. func (client *NginxClient) GetWorkers() ([]*Workers, error) { + return client.GetWorkersWithContext(context.Background()) +} + +// GetWorkersWithContext returns workers stats with a context. +func (client *NginxClient) GetWorkersWithContext(ctx context.Context) ([]*Workers, error) { var workers []*Workers if client.apiVersion < 9 { return workers, nil } - err := client.get("workers", &workers) + err := client.getWithContext(ctx, "workers", &workers) if err != nil { return nil, fmt.Errorf("failed to get workers: %w", err) } diff --git a/client/nginx_test.go b/client/nginx_test.go index 312cf73c..0ad650c9 100644 --- a/client/nginx_test.go +++ b/client/nginx_test.go @@ -1,10 +1,12 @@ package client import ( + "context" "net/http" "net/http/httptest" "reflect" "strings" + "sync" "testing" "time" ) @@ -622,23 +624,44 @@ func TestClientWithHTTPClient(t *testing.T) { } func TestGetStats_NoStreamEndpoint(t *testing.T) { + tests := []struct { + ctx context.Context + name string + }{ + { + ctx: nil, + name: "no context test", + }, + { + ctx: context.Background(), + name: "with context test", + }, + } + var err error + var client *NginxClient + var writeLock sync.Mutex + t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeLock.Lock() + defer writeLock.Unlock() + switch { case r.RequestURI == "/": - _, err := w.Write([]byte(`[4, 5, 6, 7, 8, 9]`)) + + _, err = w.Write([]byte(`[4, 5, 6, 7, 8, 9]`)) if err != nil { t.Fatalf("unexpected error: %v", err) } case r.RequestURI == "/7/": - _, err := w.Write([]byte(`["nginx","processes","connections","slabs","http","resolvers","ssl"]`)) + _, err = w.Write([]byte(`["nginx","processes","connections","slabs","http","resolvers","ssl"]`)) if err != nil { t.Fatalf("unexpected error: %v", err) } case strings.HasPrefix(r.RequestURI, "/7/stream"): t.Fatal("Stream endpoint should not be called since it does not exist.") default: - _, err := w.Write([]byte(`{}`)) + _, err = w.Write([]byte(`{}`)) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -647,7 +670,7 @@ func TestGetStats_NoStreamEndpoint(t *testing.T) { defer ts.Close() // Test creating a new client with a supported API version on the server - client, err := NewNginxClient(ts.URL, WithAPIVersion(7), WithCheckAPI()) + client, err = NewNginxClient(ts.URL, WithAPIVersion(7), WithCheckAPI()) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -655,9 +678,19 @@ func TestGetStats_NoStreamEndpoint(t *testing.T) { t.Fatalf("client is nil") } - stats, err := client.GetStats() - if err != nil { - t.Fatalf("unexpected error: %v", err) + var stats *Stats + for _, test := range tests { + if test.ctx == nil { + stats, err = client.GetStats() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } else { + stats, err = client.GetStatsWithContext(test.ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } } if !reflect.DeepEqual(stats.StreamServerZones, StreamServerZones{}) { @@ -675,6 +708,20 @@ func TestGetStats_NoStreamEndpoint(t *testing.T) { } func TestGetStats_SSL(t *testing.T) { + tests := []struct { + ctx context.Context + name string + }{ + { + ctx: nil, + name: "no context test", + }, + { + ctx: context.Background(), + name: "with context test", + }, + } + t.Parallel() ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { @@ -708,6 +755,11 @@ func TestGetStats_SSL(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + case strings.HasPrefix(r.RequestURI, "/8/stream"): + _, err := w.Write([]byte(`[""]`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } default: _, err := w.Write([]byte(`{}`)) if err != nil { @@ -726,30 +778,41 @@ func TestGetStats_SSL(t *testing.T) { t.Fatalf("client is nil") } - stats, err := client.GetStats() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + var stats *Stats - testStats := SSL{ - Handshakes: 79572, - HandshakesFailed: 21025, - SessionReuses: 15762, - NoCommonProtocol: 4, - NoCommonCipher: 2, - HandshakeTimeout: 0, - PeerRejectedCert: 0, - VerifyFailures: VerifyFailures{ - NoCert: 0, - ExpiredCert: 2, - RevokedCert: 1, - HostnameMismatch: 2, - Other: 1, - }, - } + for _, test := range tests { + if test.ctx == nil { + stats, err = client.GetStats() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } else { + stats, err = client.GetStatsWithContext(test.ctx) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + } + + testStats := SSL{ + Handshakes: 79572, + HandshakesFailed: 21025, + SessionReuses: 15762, + NoCommonProtocol: 4, + NoCommonCipher: 2, + HandshakeTimeout: 0, + PeerRejectedCert: 0, + VerifyFailures: VerifyFailures{ + NoCert: 0, + ExpiredCert: 2, + RevokedCert: 1, + HostnameMismatch: 2, + Other: 1, + }, + } - if !reflect.DeepEqual(stats.SSL, testStats) { - t.Fatalf("SSL stats: expected %v, actual %v", testStats, stats.SSL) + if !reflect.DeepEqual(stats.SSL, testStats) { + t.Fatalf("SSL stats: expected %v, actual %v", testStats, stats.SSL) + } } } diff --git a/go.mod b/go.mod index 5b6000bc..918fd882 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/nginxinc/nginx-plus-go-client go 1.22.6 + +require golang.org/x/sync v0.8.0 diff --git a/go.sum b/go.sum index e69de29b..e584c1bd 100644 --- a/go.sum +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=