diff --git a/pkg/cluster/instance.go b/pkg/cluster/instance.go index f3c23b446..16e448db6 100644 --- a/pkg/cluster/instance.go +++ b/pkg/cluster/instance.go @@ -129,16 +129,25 @@ func (i *Instance) PodName() string { // WhitelistCIDR returns the CIDR range to whitelist for GR based on the Pod's IP. func (i *Instance) WhitelistCIDR() (string, error) { - switch i.IP.To4()[0] { - case 10: - return "10.0.0.0/8", nil - case 172: - return "172.16.0.0/12", nil - case 192: - return "192.168.0.0/16", nil - default: - return "", errors.Errorf("pod IP %q is not a private IPv4 address", i.IP.String()) + var privateRanges []*net.IPNet + + for _, addrRange := range []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "100.64.0.0/10", // IPv4 shared address space (RFC 6598), improperly used by kops + } { + _, block, _ := net.ParseCIDR(addrRange) + privateRanges = append(privateRanges, block) } + + for _, block := range privateRanges { + if block.Contains(i.IP) { + return block.String(), nil + } + } + + return "", errors.Errorf("pod IP %q is not a private IPv4 address", i.IP.String()) } // statefulPodRegex is a regular expression that extracts the parent StatefulSet diff --git a/pkg/cluster/instance_test.go b/pkg/cluster/instance_test.go index 342d95812..a50d25181 100644 --- a/pkg/cluster/instance_test.go +++ b/pkg/cluster/instance_test.go @@ -15,6 +15,7 @@ package cluster import ( + "net" "testing" "github.com/stretchr/testify/assert" @@ -75,3 +76,29 @@ func TestGetPodName(t *testing.T) { }) } } + +func TestWhitelistCIDR(t *testing.T) { + testCases := []struct { + ip string + expected string + }{ + {ip: "192.168.0.1", expected: "192.168.0.0/16"}, + {ip: "192.167.0.1", expected: ""}, + {ip: "10.1.1.1", expected: "10.0.0.0/8"}, + {ip: "172.15.0.1", expected: ""}, + {ip: "172.16.0.1", expected: "172.16.0.0/12"}, + {ip: "172.17.0.1", expected: "172.16.0.0/12"}, + {ip: "100.64.0.1", expected: "100.64.0.0/10"}, + {ip: "100.63.0.1", expected: ""}, + {ip: "1.2.3.4", expected: ""}, + } + + for _, tt := range testCases { + i := Instance{IP: net.ParseIP(tt.ip)} + + cidr, _ := i.WhitelistCIDR() + if cidr != tt.expected { + t.Errorf("ip: %v, cidr: %v, expected: %v", tt.ip, cidr, tt.expected) + } + } +}