diff --git a/internal/state/resolver/resolver.go b/internal/state/resolver/resolver.go index bd1d733ff3..a862c17c49 100644 --- a/internal/state/resolver/resolver.go +++ b/internal/state/resolver/resolver.go @@ -59,13 +59,38 @@ func (e *ServiceResolverImpl) Resolve(ctx context.Context, svc *v1.Service, port return nil, fmt.Errorf("no endpoints found for Service %s", client.ObjectKeyFromObject(svc)) } - return resolveEndpoints(svc, port, endpointSliceList) + return resolveEndpoints(svc, port, endpointSliceList, initEndpointSetWithCalculatedSize) +} + +type initEndpointSetFunc func([]discoveryV1.EndpointSlice) map[Endpoint]struct{} + +func initEndpointSetWithCalculatedSize(endpointSlices []discoveryV1.EndpointSlice) map[Endpoint]struct{} { + // performance optimization to reduce the cost of growing the map. See the benchamarks for performance comparison. + return make(map[Endpoint]struct{}, calculateReadyEndpoints(endpointSlices)) +} + +func calculateReadyEndpoints(endpointSlices []discoveryV1.EndpointSlice) int { + total := 0 + + for _, eps := range endpointSlices { + for _, endpoint := range eps.Endpoints { + + if !endpointReady(endpoint) { + continue + } + + total += len(endpoint.Addresses) + } + } + + return total } func resolveEndpoints( svc *v1.Service, port int32, endpointSliceList discoveryV1.EndpointSliceList, + initEndpointsSet initEndpointSetFunc, ) ([]Endpoint, error) { svcPort, err := getServicePort(svc, port) if err != nil { @@ -81,7 +106,7 @@ func resolveEndpoints( // Endpoints may be duplicated across multiple EndpointSlices. // Using a set to prevent returning duplicate endpoints. - endpointSet := make(map[Endpoint]struct{}) + endpointSet := initEndpointsSet(filteredSlices) for _, eps := range filteredSlices { for _, endpoint := range eps.Endpoints { diff --git a/internal/state/resolver/resolver_test.go b/internal/state/resolver/resolver_test.go index 3c0c3e1a91..d6844a0a5e 100644 --- a/internal/state/resolver/resolver_test.go +++ b/internal/state/resolver/resolver_test.go @@ -1,9 +1,11 @@ package resolver import ( + "fmt" "testing" "github.com/google/go-cmp/cmp" + . "github.com/onsi/gomega" v1 "k8s.io/api/core/v1" discoveryV1 "k8s.io/api/discovery/v1" "k8s.io/apimachinery/pkg/util/intstr" @@ -470,3 +472,134 @@ func TestFindPort(t *testing.T) { } } } + +func TestCalculateReadyEndpoints(t *testing.T) { + g := NewGomegaWithT(t) + + slices := []discoveryV1.EndpointSlice{ + { + Endpoints: []discoveryV1.Endpoint{ + { + Addresses: []string{"1.0.0.1"}, + Conditions: discoveryV1.EndpointConditions{ + Ready: helpers.GetBoolPointer(true), + }, + }, + { + Addresses: []string{"1.1.0.1", "1.1.0.2", "1.1.0.3, 1.1.0.4, 1.1.0.5"}, + Conditions: discoveryV1.EndpointConditions{ + // nil conditions should be treated as not ready + }, + }, + }, + }, + { + Endpoints: []discoveryV1.Endpoint{ + { + Addresses: []string{"2.0.0.1", "2.0.0.2", "2.0.0.3"}, + Conditions: discoveryV1.EndpointConditions{ + Ready: helpers.GetBoolPointer(true), + }, + }, + }, + }, + } + + result := calculateReadyEndpoints(slices) + + g.Expect(result).To(Equal(4)) +} + +func generateEndpointSliceList(n int) discoveryV1.EndpointSliceList { + const maxEndpointsPerSlice = 100 // use the Kubernetes default max for endpoints in a slice. + + slicesCount := (n + maxEndpointsPerSlice - 1) / maxEndpointsPerSlice + + result := discoveryV1.EndpointSliceList{ + Items: make([]discoveryV1.EndpointSlice, 0, slicesCount), + } + + ready := true + + for i := 0; n > 0; i++ { + c := maxEndpointsPerSlice + if n < maxEndpointsPerSlice { + c = n + } + n -= maxEndpointsPerSlice + + slice := discoveryV1.EndpointSlice{ + Endpoints: make([]discoveryV1.Endpoint, c), + AddressType: discoveryV1.AddressTypeIPv4, + Ports: []discoveryV1.EndpointPort{ + { + Port: nil, // will match any port in the service + }, + }, + } + + for j := 0; j < c; j++ { + slice.Endpoints[j] = discoveryV1.Endpoint{ + Addresses: []string{fmt.Sprintf("10.0.%d.%d", i, j)}, + Conditions: discoveryV1.EndpointConditions{ + Ready: &ready, + }, + } + } + + result.Items = append(result.Items, slice) + } + + return result +} + +func BenchmarkResolve(b *testing.B) { + counts := []int{ + 1, + 2, + 5, + 10, + 25, + 50, + 100, + 500, + 1000, + } + + svc := &v1.Service{ + Spec: v1.ServiceSpec{ + Ports: []v1.ServicePort{ + { + Port: 80, + }, + }, + }, + } + + initEndpointSet := func([]discoveryV1.EndpointSlice) map[Endpoint]struct{} { + return make(map[Endpoint]struct{}) + } + + for _, count := range counts { + list := generateEndpointSliceList(count) + + b.Run(fmt.Sprintf("%d endpoints", count), func(b *testing.B) { + bench(b, svc, list, initEndpointSet, count) + }) + b.Run(fmt.Sprintf("%d endpoints with optimization", count), func(b *testing.B) { + bench(b, svc, list, initEndpointSetWithCalculatedSize, count) + }) + } +} + +func bench(b *testing.B, svc *v1.Service, list discoveryV1.EndpointSliceList, initSet initEndpointSetFunc, n int) { + for i := 0; i < b.N; i++ { + res, err := resolveEndpoints(svc, 80, list, initSet) + if len(res) != n { + b.Fatalf("expected %d endpoints, got %d", n, len(res)) + } + if err != nil { + b.Fatal(err) + } + } +}