From 4a49fd53ef8f62096ce7772c8bf89ad085f9fe6e Mon Sep 17 00:00:00 2001 From: Dylan WAY Date: Wed, 15 Jan 2025 11:51:46 -0700 Subject: [PATCH 1/2] Add Content-Type header to NGINX client PATCH requests --- client/nginx.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/nginx.go b/client/nginx.go index 4317ad95..dd712fcc 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -1046,6 +1046,7 @@ func (client *NginxClient) patch(ctx context.Context, path string, input interfa if err != nil { return fmt.Errorf("failed to create a patch request: %w", err) } + req.Header.Set("Content-Type", "application/json") resp, err := client.httpClient.Do(req) if err != nil { From 30363d4ff3f4459b2eb74a2c11c1ddfee9ff8a89 Mon Sep 17 00:00:00 2001 From: Dylan WAY Date: Wed, 22 Jan 2025 11:14:40 -0700 Subject: [PATCH 2/2] Update to reduce unnecessary API calls and sanitize input Update UpdateHTTPServers and UpdateStreamServers: - No longer make extra GET requests for each PUT and DELETE request. - Removes identical duplicate servers. - Returns errors for duplicate servers with different parameters. --- client/nginx.go | 206 ++++++++++++---- client/nginx_test.go | 550 +++++++++++++++++++++++++++++++++---------- tests/client_test.go | 70 +++++- 3 files changed, 642 insertions(+), 184 deletions(-) diff --git a/client/nginx.go b/client/nginx.go index dd712fcc..4a44d98c 100644 --- a/client/nginx.go +++ b/client/nginx.go @@ -48,6 +48,7 @@ var ( ErrServerExists = errors.New("server already exists") ErrNotSupported = errors.New("not supported") ErrInvalidTimeout = errors.New("invalid timeout") + ErrParameterMismatch = errors.New("encountered duplicate server with different parameters") ErrPlusVersionNotFound = errors.New("plus version not found in the input string") ) @@ -775,9 +776,13 @@ func (client *NginxClient) AddHTTPServer(ctx context.Context, upstream string, s if id != -1 { return fmt.Errorf("failed to add %v server to %v upstream: %w", server.Server, upstream, ErrServerExists) } + err = client.addHTTPServer(ctx, upstream, server) + return err +} +func (client *NginxClient) addHTTPServer(ctx context.Context, upstream string, server UpstreamServer) error { path := fmt.Sprintf("http/upstreams/%v/servers/", upstream) - err = client.post(ctx, path, &server) + err := client.post(ctx, path, &server) if err != nil { return fmt.Errorf("failed to add %v server to %v upstream: %w", server.Server, upstream, err) } @@ -794,9 +799,13 @@ func (client *NginxClient) DeleteHTTPServer(ctx context.Context, upstream string if id == -1 { return fmt.Errorf("failed to remove %v server from %v upstream: %w", server, upstream, ErrServerNotFound) } + err = client.deleteHTTPServer(ctx, upstream, server, id) + return err +} - path := fmt.Sprintf("http/upstreams/%v/servers/%v", upstream, id) - err = client.delete(ctx, path, http.StatusOK) +func (client *NginxClient) deleteHTTPServer(ctx context.Context, upstream, server string, serverID int) error { + path := fmt.Sprintf("http/upstreams/%v/servers/%v", upstream, serverID) + err := client.delete(ctx, path, http.StatusOK) if err != nil { return fmt.Errorf("failed to remove %v server from %v upstream: %w", server, upstream, err) } @@ -809,6 +818,8 @@ func (client *NginxClient) DeleteHTTPServer(ctx context.Context, upstream string // Servers that aren't in the slice, but exist in NGINX, will be removed from NGINX. // Servers that are in the slice and exist in NGINX, but have different parameters, will be updated. // The client will attempt to update all servers, returning all the errors that occurred. +// If there are duplicate servers with equivalent parameters, the duplicates will be ignored. +// If there are duplicate servers with different parameters, those server entries will be ignored and an error returned. func (client *NginxClient) UpdateHTTPServers(ctx context.Context, upstream string, servers []UpstreamServer) (added []UpstreamServer, deleted []UpstreamServer, updated []UpstreamServer, err error) { serversInNginx, err := client.GetHTTPServers(ctx, upstream) if err != nil { @@ -822,10 +833,12 @@ func (client *NginxClient) UpdateHTTPServers(ctx context.Context, upstream strin formattedServers = append(formattedServers, server) } + formattedServers, err = deduplicateServers(upstream, formattedServers) + toAdd, toDelete, toUpdate := determineUpdates(formattedServers, serversInNginx) for _, server := range toAdd { - addErr := client.AddHTTPServer(ctx, upstream, server) + addErr := client.addHTTPServer(ctx, upstream, server) if addErr != nil { err = errors.Join(err, addErr) continue @@ -834,7 +847,7 @@ func (client *NginxClient) UpdateHTTPServers(ctx context.Context, upstream strin } for _, server := range toDelete { - deleteErr := client.DeleteHTTPServer(ctx, upstream, server.Server) + deleteErr := client.deleteHTTPServer(ctx, upstream, server.Server, server.ID) if deleteErr != nil { err = errors.Join(err, deleteErr) continue @@ -858,46 +871,82 @@ func (client *NginxClient) UpdateHTTPServers(ctx context.Context, upstream strin return added, deleted, updated, err } -// haveSameParameters checks if a given server has the same parameters as a server already present in NGINX. Order matters. -func haveSameParameters(newServer UpstreamServer, serverNGX UpstreamServer) bool { - newServer.ID = serverNGX.ID +func deduplicateServers(upstream string, servers []UpstreamServer) ([]UpstreamServer, error) { + type serverCheck struct { + server UpstreamServer + valid bool + } - if serverNGX.MaxConns != nil && newServer.MaxConns == nil { - newServer.MaxConns = &defaultMaxConns + serverMap := make(map[string]*serverCheck, len(servers)) + var err error + for _, server := range servers { + if prev, ok := serverMap[server.Server]; ok { + if !prev.valid { + continue + } + if !server.hasSameParametersAs(prev.server) { + prev.valid = false + err = errors.Join(err, fmt.Errorf( + "failed to update %s server to %s upstream: %w", + server.Server, upstream, ErrParameterMismatch)) + } + continue + } + serverMap[server.Server] = &serverCheck{server, true} } + retServers := make([]UpstreamServer, 0, len(serverMap)) + for _, server := range servers { + if check, ok := serverMap[server.Server]; ok && check.valid { + retServers = append(retServers, server) + delete(serverMap, server.Server) + } + } + return retServers, err +} - if serverNGX.MaxFails != nil && newServer.MaxFails == nil { - newServer.MaxFails = &defaultMaxFails +// hasSameParametersAs checks if a given server has the same parameters. +func (s UpstreamServer) hasSameParametersAs(compareServer UpstreamServer) bool { + s.ID = compareServer.ID + s.applyDefaults() + compareServer.applyDefaults() + return reflect.DeepEqual(s, compareServer) +} + +func (s *UpstreamServer) applyDefaults() { + if s.MaxConns == nil { + s.MaxConns = &defaultMaxConns } - if serverNGX.FailTimeout != "" && newServer.FailTimeout == "" { - newServer.FailTimeout = defaultFailTimeout + if s.MaxFails == nil { + s.MaxFails = &defaultMaxFails } - if serverNGX.SlowStart != "" && newServer.SlowStart == "" { - newServer.SlowStart = defaultSlowStart + if s.FailTimeout == "" { + s.FailTimeout = defaultFailTimeout } - if serverNGX.Backup != nil && newServer.Backup == nil { - newServer.Backup = &defaultBackup + if s.SlowStart == "" { + s.SlowStart = defaultSlowStart } - if serverNGX.Down != nil && newServer.Down == nil { - newServer.Down = &defaultDown + if s.Backup == nil { + s.Backup = &defaultBackup } - if serverNGX.Weight != nil && newServer.Weight == nil { - newServer.Weight = &defaultWeight + if s.Down == nil { + s.Down = &defaultDown } - return reflect.DeepEqual(newServer, serverNGX) + if s.Weight == nil { + s.Weight = &defaultWeight + } } func determineUpdates(updatedServers []UpstreamServer, nginxServers []UpstreamServer) (toAdd []UpstreamServer, toRemove []UpstreamServer, toUpdate []UpstreamServer) { for _, server := range updatedServers { updateFound := false for _, serverNGX := range nginxServers { - if server.Server == serverNGX.Server && !haveSameParameters(server, serverNGX) { + if server.Server == serverNGX.Server && !server.hasSameParametersAs(serverNGX) { server.ID = serverNGX.ID updateFound = true break @@ -1089,9 +1138,13 @@ func (client *NginxClient) AddStreamServer(ctx context.Context, upstream string, if id != -1 { return fmt.Errorf("failed to add %v stream server to %v upstream: %w", server.Server, upstream, ErrServerExists) } + err = client.addStreamServer(ctx, upstream, server) + return err +} +func (client *NginxClient) addStreamServer(ctx context.Context, upstream string, server StreamUpstreamServer) error { path := fmt.Sprintf("stream/upstreams/%v/servers/", upstream) - err = client.post(ctx, path, &server) + err := client.post(ctx, path, &server) if err != nil { return fmt.Errorf("failed to add %v stream server to %v upstream: %w", server.Server, upstream, err) } @@ -1107,9 +1160,13 @@ func (client *NginxClient) DeleteStreamServer(ctx context.Context, upstream stri if id == -1 { return fmt.Errorf("failed to remove %v stream server from %v upstream: %w", server, upstream, ErrServerNotFound) } + err = client.deleteStreamServer(ctx, upstream, server, id) + return err +} - path := fmt.Sprintf("stream/upstreams/%v/servers/%v", upstream, id) - err = client.delete(ctx, path, http.StatusOK) +func (client *NginxClient) deleteStreamServer(ctx context.Context, upstream, server string, serverID int) error { + path := fmt.Sprintf("stream/upstreams/%v/servers/%v", upstream, serverID) + err := client.delete(ctx, path, http.StatusOK) if err != nil { return fmt.Errorf("failed to remove %v stream server from %v upstream: %w", server, upstream, err) } @@ -1121,6 +1178,8 @@ func (client *NginxClient) DeleteStreamServer(ctx context.Context, upstream stri // Servers that aren't in the slice, but exist in NGINX, will be removed from NGINX. // Servers that are in the slice and exist in NGINX, but have different parameters, will be updated. // The client will attempt to update all servers, returning all the errors that occurred. +// If there are duplicate servers with equivalent parameters, the duplicates will be ignored. +// If there are duplicate servers with different parameters, those server entries will be ignored and an error returned. func (client *NginxClient) UpdateStreamServers(ctx context.Context, upstream string, servers []StreamUpstreamServer) (added []StreamUpstreamServer, deleted []StreamUpstreamServer, updated []StreamUpstreamServer, err error) { serversInNginx, err := client.GetStreamServers(ctx, upstream) if err != nil { @@ -1133,10 +1192,12 @@ func (client *NginxClient) UpdateStreamServers(ctx context.Context, upstream str formattedServers = append(formattedServers, server) } + formattedServers, err = deduplicateStreamServers(upstream, formattedServers) + toAdd, toDelete, toUpdate := determineStreamUpdates(formattedServers, serversInNginx) for _, server := range toAdd { - addErr := client.AddStreamServer(ctx, upstream, server) + addErr := client.addStreamServer(ctx, upstream, server) if addErr != nil { err = errors.Join(err, addErr) continue @@ -1145,7 +1206,7 @@ func (client *NginxClient) UpdateStreamServers(ctx context.Context, upstream str } for _, server := range toDelete { - deleteErr := client.DeleteStreamServer(ctx, upstream, server.Server) + deleteErr := client.deleteStreamServer(ctx, upstream, server.Server, server.ID) if deleteErr != nil { err = errors.Join(err, deleteErr) continue @@ -1184,45 +1245,82 @@ func (client *NginxClient) getIDOfStreamServer(ctx context.Context, upstream str return -1, nil } -// haveSameParametersForStream checks if a given server has the same parameters as a server already present in NGINX. Order matters. -func haveSameParametersForStream(newServer StreamUpstreamServer, serverNGX StreamUpstreamServer) bool { - newServer.ID = serverNGX.ID - if serverNGX.MaxConns != nil && newServer.MaxConns == nil { - newServer.MaxConns = &defaultMaxConns +func deduplicateStreamServers(upstream string, servers []StreamUpstreamServer) ([]StreamUpstreamServer, error) { + type serverCheck struct { + server StreamUpstreamServer + valid bool + } + + serverMap := make(map[string]*serverCheck, len(servers)) + var err error + for _, server := range servers { + if prev, ok := serverMap[server.Server]; ok { + if !prev.valid { + continue + } + if !server.hasSameParametersAs(prev.server) { + prev.valid = false + err = errors.Join(err, fmt.Errorf( + "failed to update stream %s server to %s upstream: %w", + server.Server, upstream, ErrParameterMismatch)) + } + continue + } + serverMap[server.Server] = &serverCheck{server, true} + } + retServers := make([]StreamUpstreamServer, 0, len(serverMap)) + for _, server := range servers { + if check, ok := serverMap[server.Server]; ok && check.valid { + retServers = append(retServers, server) + delete(serverMap, server.Server) + } } + return retServers, err +} + +// hasSameParametersAs checks if a given server has the same parameters. +func (s StreamUpstreamServer) hasSameParametersAs(compareServer StreamUpstreamServer) bool { + s.ID = compareServer.ID + s.applyDefaults() + compareServer.applyDefaults() + return reflect.DeepEqual(s, compareServer) +} - if serverNGX.MaxFails != nil && newServer.MaxFails == nil { - newServer.MaxFails = &defaultMaxFails +func (s *StreamUpstreamServer) applyDefaults() { + if s.MaxConns == nil { + s.MaxConns = &defaultMaxConns } - if serverNGX.FailTimeout != "" && newServer.FailTimeout == "" { - newServer.FailTimeout = defaultFailTimeout + if s.MaxFails == nil { + s.MaxFails = &defaultMaxFails } - if serverNGX.SlowStart != "" && newServer.SlowStart == "" { - newServer.SlowStart = defaultSlowStart + if s.FailTimeout == "" { + s.FailTimeout = defaultFailTimeout } - if serverNGX.Backup != nil && newServer.Backup == nil { - newServer.Backup = &defaultBackup + if s.SlowStart == "" { + s.SlowStart = defaultSlowStart } - if serverNGX.Down != nil && newServer.Down == nil { - newServer.Down = &defaultDown + if s.Backup == nil { + s.Backup = &defaultBackup } - if serverNGX.Weight != nil && newServer.Weight == nil { - newServer.Weight = &defaultWeight + if s.Down == nil { + s.Down = &defaultDown } - return reflect.DeepEqual(newServer, serverNGX) + if s.Weight == nil { + s.Weight = &defaultWeight + } } func determineStreamUpdates(updatedServers []StreamUpstreamServer, nginxServers []StreamUpstreamServer) (toAdd []StreamUpstreamServer, toRemove []StreamUpstreamServer, toUpdate []StreamUpstreamServer) { for _, server := range updatedServers { updateFound := false for _, serverNGX := range nginxServers { - if server.Server == serverNGX.Server && !haveSameParametersForStream(server, serverNGX) { + if server.Server == serverNGX.Server && !server.hasSameParametersAs(serverNGX) { server.ID = serverNGX.ID updateFound = true break @@ -1950,9 +2048,13 @@ func (client *NginxClient) deleteKeyValPairs(ctx context.Context, zone string, s return nil } -// UpdateHTTPServer updates the server of the upstream. +// UpdateHTTPServer updates the server of the upstream with the matching server ID. func (client *NginxClient) UpdateHTTPServer(ctx context.Context, upstream string, server UpstreamServer) error { path := fmt.Sprintf("http/upstreams/%v/servers/%v", upstream, server.ID) + // The server ID is expected in the URI, but not expected in the body. + // The NGINX API will return + // {"error":{"status":400,"text":"unknown parameter \"id\"","code":"UpstreamConfFormatError"} + // if the ID field is present. server.ID = 0 err := client.patch(ctx, path, &server, http.StatusOK) if err != nil { @@ -1962,9 +2064,13 @@ func (client *NginxClient) UpdateHTTPServer(ctx context.Context, upstream string return nil } -// UpdateStreamServer updates the stream server of the upstream. +// UpdateStreamServer updates the stream server of the upstream with the matching server ID. func (client *NginxClient) UpdateStreamServer(ctx context.Context, upstream string, server StreamUpstreamServer) error { path := fmt.Sprintf("stream/upstreams/%v/servers/%v", upstream, server.ID) + // The server ID is expected in the URI, but not expected in the body. + // The NGINX API will return + // {"error":{"status":400,"text":"unknown parameter \"id\"","code":"UpstreamConfFormatError"} + // if the ID field is present. server.ID = 0 err := client.patch(ctx, path, &server, http.StatusOK) if err != nil { diff --git a/client/nginx_test.go b/client/nginx_test.go index 920e9c72..46467477 100644 --- a/client/nginx_test.go +++ b/client/nginx_test.go @@ -477,9 +477,9 @@ func TestHaveSameParameters(t *testing.T) { for _, test := range tests { t.Run(test.msg, func(t *testing.T) { t.Parallel() - result := haveSameParameters(test.server, test.serverNGX) + result := test.server.hasSameParametersAs(test.serverNGX) if result != test.expected { - t.Errorf("haveSameParameters(%v, %v) returned %v but expected %v", test.server, test.serverNGX, result, test.expected) + t.Errorf("(%v) hasSameParametersAs (%v) returned %v but expected %v", test.server, test.serverNGX, result, test.expected) } }) } @@ -562,9 +562,9 @@ func TestHaveSameParametersForStream(t *testing.T) { for _, test := range tests { t.Run(test.msg, func(t *testing.T) { t.Parallel() - result := haveSameParametersForStream(test.server, test.serverNGX) + result := test.server.hasSameParametersAs(test.serverNGX) if result != test.expected { - t.Errorf("haveSameParametersForStream(%v, %v) returned %v but expected %v", test.server, test.serverNGX, result, test.expected) + t.Errorf("(%v) hasSameParametersAs (%v) returned %v but expected %v", test.server, test.serverNGX, result, test.expected) } }) } @@ -982,174 +982,464 @@ func TestExtractPlusVersionNegativeCase(t *testing.T) { } } -func TestClientHTTPUpdateServers(t *testing.T) { +func TestUpdateHTTPServers(t *testing.T) { t.Parallel() - responses := []response{ - // response for first serversInNginx GET servers - { - statusCode: http.StatusOK, - servers: []UpstreamServer{}, + testcases := map[string]struct { + reqServers []UpstreamServer + responses []response + expAdded, expDeleted, expUpdated int + expErr bool + }{ + "successfully add 1 server": { + reqServers: []UpstreamServer{{Server: "127.0.0.1:80"}}, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + }, + // response for addHTTPServer POST server for http server + { + statusCode: http.StatusCreated, + }, + }, + expAdded: 1, }, - // response for AddHTTPServer GET servers for http server - { - statusCode: http.StatusOK, - servers: []UpstreamServer{}, + "successfully update 1 server": { + reqServers: []UpstreamServer{{Server: "127.0.0.1:80"}}, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{ + {ID: 1, Server: "127.0.0.1:80", Route: "/test"}, + }, + }, + // response for UpdateHTTPServer PATCH server for http server + { + statusCode: http.StatusOK, + }, + }, + expUpdated: 1, }, - // response for AddHTTPServer POST server for http server - { - statusCode: http.StatusInternalServerError, - servers: []UpstreamServer{}, + "successfully delete 1 server": { + reqServers: []UpstreamServer{{Server: "127.0.0.1:80"}}, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{ + {ID: 1, Server: "127.0.0.1:80"}, + {ID: 2, Server: "127.0.0.2:80"}, + }, + }, + // response for deleteHTTPServer DELETE server for http server + { + statusCode: http.StatusOK, + }, + }, + expDeleted: 1, }, - // response for AddHTTPServer GET servers for https server - { - statusCode: http.StatusOK, - servers: []UpstreamServer{}, + "successfully add 1 server, update 1 server, delete 1 server": { + reqServers: []UpstreamServer{ + {Server: "127.0.0.1:80", Route: "/test"}, + {Server: "127.0.0.2:80"}, + }, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{ + {ID: 1, Server: "127.0.0.1:80"}, + {ID: 2, Server: "127.0.0.3:80"}, + }, + }, + // response for addHTTPServer POST server for http server + { + statusCode: http.StatusCreated, + }, + // response for deleteHTTPServer DELETE server for http server + { + statusCode: http.StatusOK, + }, + // response for UpdateHTTPServer PATCH server for http server + { + statusCode: http.StatusOK, + }, + }, + expAdded: 1, + expUpdated: 1, + expDeleted: 1, }, - // response for AddHTTPServer POST server for https server - { - statusCode: http.StatusCreated, - servers: []UpstreamServer{}, + "successfully add 1 server with ignored identical duplicate": { + reqServers: []UpstreamServer{ + {Server: "127.0.0.1:80", Route: "/test"}, + {Server: "127.0.0.1", Route: "/test"}, + {Server: "127.0.0.1:80", Route: "/test", MaxConns: &defaultMaxConns}, + {Server: "127.0.0.1:80", Route: "/test", Backup: &defaultBackup}, + {Server: "127.0.0.1", Route: "/test", SlowStart: defaultSlowStart}, + }, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for addHTTPServer POST server for http server + { + statusCode: http.StatusCreated, + }, + }, + expAdded: 1, + }, + "successfully add 1 server, receive 1 error for non-identical duplicates": { + reqServers: []UpstreamServer{ + {Server: "127.0.0.1:80", Route: "/test"}, + {Server: "127.0.0.1:80", Route: "/test"}, + {Server: "127.0.0.2:80", Route: "/test1"}, + {Server: "127.0.0.2:80", Route: "/test2"}, + {Server: "127.0.0.2:80", Route: "/test3"}, + }, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for addHTTPServer POST server for http server + { + statusCode: http.StatusCreated, + }, + }, + expAdded: 1, + expErr: true, + }, + "successfully add 1 server, receive 1 error": { + reqServers: []UpstreamServer{ + {Server: "127.0.0.1:80"}, + {Server: "127.0.0.1:443"}, + }, + responses: []response{ // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for addHTTPServer POST server for server1 + { + statusCode: http.StatusInternalServerError, + servers: []UpstreamServer{}, + }, + // response for addHTTPServer POST server for server2 + { + statusCode: http.StatusCreated, + servers: []UpstreamServer{}, + }, + }, + expAdded: 1, + expErr: true, }, } - handler := &fakeHandler{ - func(w http.ResponseWriter, _ *http.Request) { - if len(responses) == 0 { - t.Fatal("ran out of responses") - } + for name, tc := range testcases { + t.Run(name, func(t *testing.T) { + t.Parallel() - re := responses[0] - responses = responses[1:] + var requests []*http.Request + handler := &fakeHandler{ + func(w http.ResponseWriter, r *http.Request) { + requests = append(requests, r) - w.WriteHeader(re.statusCode) + if len(tc.responses) == 0 { + t.Fatal("ran out of responses") + } + if r.Method == http.MethodPost || r.Method == http.MethodPut { + contentType, ok := r.Header["Content-Type"] + if !ok { + t.Fatalf("expected request type %s to have a Content-Type header", r.Method) + } + if len(contentType) != 1 || contentType[0] != "application/json" { + t.Fatalf("expected request type %s to have a Content-Type header value of 'application/json'", r.Method) + } + } - resp, err := json.Marshal(re.servers) - if err != nil { - t.Fatal(err) - } - _, err = w.Write(resp) - if err != nil { - t.Fatal(err) - } - }, - } + re := tc.responses[0] + tc.responses = tc.responses[1:] - server := httptest.NewServer(handler) - defer server.Close() + w.WriteHeader(re.statusCode) - client, err := NewNginxClient(server.URL, WithHTTPClient(&http.Client{})) - if err != nil { - t.Fatal(err) - } + resp, err := json.Marshal(re.servers) + if err != nil { + t.Fatal(err) + } + _, err = w.Write(resp) + if err != nil { + t.Fatal(err) + } + }, + } - httpServer := UpstreamServer{Server: "127.0.0.1:80"} - httpsServer := UpstreamServer{Server: "127.0.0.1:443"} + server := httptest.NewServer(handler) + defer server.Close() - // we expect that we will get an error for the 500 error encountered when putting the http server - // but we also expect that we have the https server added - added, _, _, err := client.UpdateHTTPServers(context.TODO(), "fakeUpstream", []UpstreamServer{ - httpServer, - httpsServer, - }) - if err == nil { - t.Fatal("expected to receive an error for 500 response when adding first server") - } + client, err := NewNginxClient(server.URL, WithHTTPClient(&http.Client{})) + if err != nil { + t.Fatal(err) + } - if len(added) != 1 { - t.Fatalf("expected to get one added server, instead got %d", len(added)) - } + added, deleted, updated, err := client.UpdateHTTPServers(context.Background(), "fakeUpstream", tc.reqServers) + if tc.expErr && err == nil { + t.Fatal("expected to receive an error") + } + if !tc.expErr && err != nil { + t.Fatalf("received an unexpected error: %v", err) + } - if !reflect.DeepEqual(httpsServer, added[0]) { - t.Errorf("expected: %v got: %v", httpsServer, added[0]) + if len(added) != tc.expAdded { + t.Fatalf("expected to get %d added server(s), instead got %d", tc.expAdded, len(added)) + } + if len(deleted) != tc.expDeleted { + t.Fatalf("expected to get %d deleted server(s), instead got %d", tc.expDeleted, len(deleted)) + } + if len(updated) != tc.expUpdated { + t.Fatalf("expected to get %d updated server(s), instead got %d", tc.expUpdated, len(updated)) + } + if len(tc.responses) != 0 { + t.Fatalf("did not use all expected responses, %d unused", len(tc.responses)) + } + }) } } -func TestClientStreamUpdateServers(t *testing.T) { +func TestUpdateStreamServers(t *testing.T) { t.Parallel() - responses := []response{ - // response for first serversInNginx GET servers - { - statusCode: http.StatusOK, - servers: []UpstreamServer{}, + testcases := map[string]struct { + reqServers []StreamUpstreamServer + responses []response + expAdded, expDeleted, expUpdated int + expErr bool + }{ + "successfully add 1 server": { + reqServers: []StreamUpstreamServer{{Server: "127.0.0.1:80"}}, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + }, + // response for addStreamServer POST server for stream server + { + statusCode: http.StatusCreated, + }, + }, + expAdded: 1, }, - // response for AddStreamServer GET servers for streamServer1 - { - statusCode: http.StatusOK, - servers: []UpstreamServer{}, + "successfully update 1 server": { + reqServers: []StreamUpstreamServer{{Server: "127.0.0.1:80"}}, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []StreamUpstreamServer{ + {ID: 1, Server: "127.0.0.1:80", SlowStart: "30s"}, + }, + }, + // response for UpdateStreamServer PATCH server for stream server + { + statusCode: http.StatusOK, + }, + }, + expUpdated: 1, }, - // response for AddStreamServer POST server for streamServer1 - { - statusCode: http.StatusInternalServerError, - servers: []UpstreamServer{}, + "successfully delete 1 server": { + reqServers: []StreamUpstreamServer{{Server: "127.0.0.1:80"}}, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []StreamUpstreamServer{ + {ID: 1, Server: "127.0.0.1:80"}, + {ID: 2, Server: "127.0.0.2:80"}, + }, + }, + // response for deleteStreamServer DELETE server for stream server + { + statusCode: http.StatusOK, + }, + }, + expDeleted: 1, }, - // response for AddStreamServer GET servers for streamServer2 - { - statusCode: http.StatusOK, - servers: []UpstreamServer{}, + "successfully add 1 server, update 1 server, delete 1 server": { + reqServers: []StreamUpstreamServer{ + {Server: "127.0.0.1:80", SlowStart: "30s"}, + {Server: "127.0.0.2:80"}, + }, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []StreamUpstreamServer{ + {ID: 1, Server: "127.0.0.1:80"}, + {ID: 2, Server: "127.0.0.3:80"}, + }, + }, + // response for addStreamServer POST server for stream server + { + statusCode: http.StatusCreated, + }, + // response for deleteStreamServer DELETE server for stream server + { + statusCode: http.StatusOK, + }, + // response for UpdateStreamServer PATCH server for stream server + { + statusCode: http.StatusOK, + }, + }, + expAdded: 1, + expUpdated: 1, + expDeleted: 1, }, - // response for AddStreamServer POST server for streamServer2 - { - statusCode: http.StatusCreated, - servers: []UpstreamServer{}, + "successfully add 1 server with ignored identical duplicate": { + reqServers: []StreamUpstreamServer{ + {Server: "127.0.0.1:80", SlowStart: "30s"}, + {Server: "127.0.0.1", SlowStart: "30s"}, + {Server: "127.0.0.1:80", SlowStart: "30s", MaxConns: &defaultMaxConns}, + {Server: "127.0.0.1", SlowStart: "30s", MaxFails: &defaultMaxFails}, + {Server: "127.0.0.1", SlowStart: "30s", FailTimeout: defaultFailTimeout}, + }, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for addStreamServer POST server for stream server + { + statusCode: http.StatusCreated, + }, + }, + expAdded: 1, + }, + "successfully add 1 server, receive 1 error for non-identical duplicates": { + reqServers: []StreamUpstreamServer{ + {Server: "127.0.0.1:80", SlowStart: "30s"}, + {Server: "127.0.0.1:80", SlowStart: "30s"}, + {Server: "127.0.0.2:80", SlowStart: "10s"}, + {Server: "127.0.0.2:80", SlowStart: "20s"}, + {Server: "127.0.0.2:80", SlowStart: "30s"}, + }, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for addStreamServer POST server for stream server + { + statusCode: http.StatusCreated, + }, + }, + expAdded: 1, + expErr: true, + }, + "successfully add 1 server, receive 1 error": { + reqServers: []StreamUpstreamServer{ + {Server: "127.0.0.1:2000"}, + {Server: "127.0.0.1:3000"}, + }, + responses: []response{ + // response for first serversInNginx GET servers + { + statusCode: http.StatusOK, + servers: []UpstreamServer{}, + }, + // response for addStreamServer POST server for server1 + { + statusCode: http.StatusInternalServerError, + servers: []UpstreamServer{}, + }, + // response for addStreamServer POST server for server2 + { + statusCode: http.StatusCreated, + servers: []UpstreamServer{}, + }, + }, + expAdded: 1, + expErr: true, }, } - handler := &fakeHandler{ - func(w http.ResponseWriter, _ *http.Request) { - if len(responses) == 0 { - t.Fatal("ran out of responses") - } + for name, tc := range testcases { + t.Run(name, func(t *testing.T) { + t.Parallel() - re := responses[0] - responses = responses[1:] + var requests []*http.Request + handler := &fakeHandler{ + func(w http.ResponseWriter, r *http.Request) { + requests = append(requests, r) - w.WriteHeader(re.statusCode) + if len(tc.responses) == 0 { + t.Fatal("ran out of responses") + } + if r.Method == http.MethodPost || r.Method == http.MethodPut { + contentType, ok := r.Header["Content-Type"] + if !ok { + t.Fatalf("expected request type %s to have a Content-Type header", r.Method) + } + if len(contentType) != 1 || contentType[0] != "application/json" { + t.Fatalf("expected request type %s to have a Content-Type header value of 'application/json'", r.Method) + } + } - resp, err := json.Marshal(re.servers) - if err != nil { - t.Fatal(err) - } - _, err = w.Write(resp) - if err != nil { - t.Fatal(err) - } - }, - } + re := tc.responses[0] + tc.responses = tc.responses[1:] - server := httptest.NewServer(handler) - defer server.Close() + w.WriteHeader(re.statusCode) - client, err := NewNginxClient(server.URL, WithHTTPClient(&http.Client{})) - if err != nil { - t.Fatal(err) - } - - streamServer1 := StreamUpstreamServer{Server: "127.0.0.1:2000"} - streamServer2 := StreamUpstreamServer{Server: "127.0.0.1:3000"} + resp, err := json.Marshal(re.servers) + if err != nil { + t.Fatal(err) + } + _, err = w.Write(resp) + if err != nil { + t.Fatal(err) + } + }, + } - // we expect that we will get an error for the 500 error encountered when putting server1 - // but we also expect that we get the second server added - added, _, _, err := client.UpdateStreamServers(context.TODO(), "fakeUpstream", []StreamUpstreamServer{ - streamServer1, - streamServer2, - }) - if err == nil { - t.Fatal("expected to receive an error for 500 response when adding first server") - } + server := httptest.NewServer(handler) + defer server.Close() - if len(added) != 1 { - t.Fatalf("expected to get one added server, instead got %d", len(added)) - } + client, err := NewNginxClient(server.URL, WithHTTPClient(&http.Client{})) + if err != nil { + t.Fatal(err) + } - if !reflect.DeepEqual(streamServer2, added[0]) { - t.Errorf("expected: %v got: %v", streamServer2, added[0]) + added, deleted, updated, err := client.UpdateStreamServers(context.Background(), "fakeUpstream", tc.reqServers) + if tc.expErr && err == nil { + t.Fatal("expected to receive an error") + } + if !tc.expErr && err != nil { + t.Fatalf("received an unexpected error: %v", err) + } + if len(added) != tc.expAdded { + t.Fatalf("expected to get %d added server(s), instead got %d", tc.expAdded, len(added)) + } + if len(deleted) != tc.expDeleted { + t.Fatalf("expected to get %d deleted server(s), instead got %d", tc.expDeleted, len(deleted)) + } + if len(updated) != tc.expUpdated { + t.Fatalf("expected to get %d updated server(s), instead got %d", tc.expUpdated, len(updated)) + } + if len(tc.responses) != 0 { + t.Fatalf("did not use all expected responses, %d unused", len(tc.responses)) + } + }) } } type response struct { - servers []UpstreamServer + servers interface{} statusCode int } diff --git a/tests/client_test.go b/tests/client_test.go index 4f1fe6fa..cdc185a1 100644 --- a/tests/client_test.go +++ b/tests/client_test.go @@ -61,8 +61,39 @@ func TestStreamClient(t *testing.T) { t.Errorf("Adding a duplicated server succeeded") } - // test deleting a stream server + // test updating a stream server + streamServers, err := c.GetStreamServers(ctx, streamUpstream) + if err != nil { + t.Errorf("Error getting stream servers: %v", err) + } + if len(streamServers) != 1 { + t.Errorf("Expected 1 servers, got %v", streamServers) + } + + streamServers[0].SlowStart = "30s" + err = c.UpdateStreamServer(ctx, streamUpstream, streamServers[0]) + if err != nil { + t.Errorf("Error when updating a server: %v", err) + } + streamServers, err = c.GetStreamServers(ctx, streamUpstream) + if err != nil { + t.Errorf("Error getting stream servers: %v", err) + } + if len(streamServers) != 1 { + t.Errorf("Expected 1 servers, got %v", streamServers) + } + if streamServers[0].SlowStart != "30s" { + t.Errorf("The server wasn't successfully updated: expected a 'SlowStart' of 30s, actual was %s", streamServers[0].SlowStart) + } + + streamServers[0].ID++ + err = c.UpdateStreamServer(ctx, streamUpstream, streamServers[0]) + if err == nil { + t.Errorf("Updating a server without a matching server ID succeeded") + } + + // test deleting a stream server err = c.DeleteStreamServer(ctx, streamUpstream, streamServer.Server) if err != nil { t.Fatalf("Error when deleting a server: %v", err) @@ -73,7 +104,7 @@ func TestStreamClient(t *testing.T) { t.Errorf("Deleting a nonexisting server succeeded") } - streamServers, err := c.GetStreamServers(ctx, streamUpstream) + streamServers, err = c.GetStreamServers(ctx, streamUpstream) if err != nil { t.Errorf("Error getting stream servers: %v", err) } @@ -340,8 +371,39 @@ func TestClient(t *testing.T) { t.Errorf("Adding a duplicated server succeeded") } - // test deleting a http server + // test updating an http server + servers, err := c.GetHTTPServers(ctx, upstream) + if err != nil { + t.Errorf("Error getting servers: %v", err) + } + if len(servers) != 1 { + t.Errorf("Expected 1 servers, got %v", servers) + } + + servers[0].SlowStart = "30s" + err = c.UpdateHTTPServer(ctx, upstream, servers[0]) + if err != nil { + t.Errorf("Error when updating a server: %v", err) + } + servers, err = c.GetHTTPServers(ctx, upstream) + if err != nil { + t.Errorf("Error getting servers: %v", err) + } + if len(servers) != 1 { + t.Errorf("Expected 1 servers, got %v", servers) + } + if servers[0].SlowStart != "30s" { + t.Errorf("The server wasn't successfully updated: expected a 'SlowStart' of 30s, actual was %s", servers[0].SlowStart) + } + + servers[0].ID++ + err = c.UpdateHTTPServer(ctx, upstream, servers[0]) + if err == nil { + t.Errorf("Updating a server without a matching server ID succeeded") + } + + // test deleting a http server err = c.DeleteHTTPServer(ctx, upstream, server.Server) if err != nil { t.Fatalf("Error when deleting a server: %v", err) @@ -381,7 +443,7 @@ func TestClient(t *testing.T) { // test getting servers - servers, err := c.GetHTTPServers(ctx, upstream) + servers, err = c.GetHTTPServers(ctx, upstream) if err != nil { t.Fatalf("Error when getting servers: %v", err) }