diff --git a/go.mod b/go.mod index 027ddc8f0..39c1b710e 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ toolchain go1.24.2 require ( filippo.io/edwards25519 v1.1.0 github.com/BurntSushi/toml v1.3.2 - github.com/Masterminds/semver v1.5.0 github.com/go-sql-driver/mysql v1.7.1 github.com/goccy/go-json v0.10.2 github.com/google/uuid v1.3.0 diff --git a/go.sum b/go.sum index 78f0c521d..1975f684f 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,6 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= -github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= -github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/mysql/util.go b/mysql/util.go index f8e5813fb..2e426262f 100644 --- a/mysql/util.go +++ b/mysql/util.go @@ -2,6 +2,7 @@ package mysql import ( "bytes" + "cmp" "compress/zlib" "crypto/rand" "crypto/rsa" @@ -13,11 +14,11 @@ import ( "io" mrand "math/rand" "runtime" + "strconv" "strings" "time" "filippo.io/edwards25519" - "github.com/Masterminds/semver" "github.com/go-mysql-org/go-mysql/utils" "github.com/pingcap/errors" ) @@ -452,21 +453,46 @@ func ErrorEqual(err1, err2 error) bool { return e1.Error() == e2.Error() } +func compareSubVersion(typ, a, b, aFull, bFull string) (int, error) { + if a == "" || b == "" { + return 0, nil + } + + var aNum, bNum int + var err error + + if aNum, err = strconv.Atoi(a); err != nil { + return 0, fmt.Errorf("cannot parse %s version %s of %s", typ, a, aFull) + } + if bNum, err = strconv.Atoi(b); err != nil { + return 0, fmt.Errorf("cannot parse %s version %s of %s", typ, b, bFull) + } + + return cmp.Compare(aNum, bNum), nil +} + +// Compares version triplet strings, ignoring anything past `-` in version. +// A version string like 8.0 will compare as if third triplet were a wildcard. +// A version string like 8 will compare as if second & third triplets were wildcards. func CompareServerVersions(a, b string) (int, error) { - var ( - aVer, bVer *semver.Version - err error - ) + aNumbers, _, _ := strings.Cut(a, "-") + bNumbers, _, _ := strings.Cut(b, "-") - if aVer, err = semver.NewVersion(a); err != nil { - return 0, fmt.Errorf("cannot parse %q as semver: %w", a, err) + aMajor, aRest, _ := strings.Cut(aNumbers, ".") + bMajor, bRest, _ := strings.Cut(bNumbers, ".") + + if majorCompare, err := compareSubVersion("major", aMajor, bMajor, a, b); err != nil || majorCompare != 0 { + return majorCompare, err } - if bVer, err = semver.NewVersion(b); err != nil { - return 0, fmt.Errorf("cannot parse %q as semver: %w", b, err) + aMinor, aPatch, _ := strings.Cut(aRest, ".") + bMinor, bPatch, _ := strings.Cut(bRest, ".") + + if minorCompare, err := compareSubVersion("minor", aMinor, bMinor, a, b); err != nil || minorCompare != 0 { + return minorCompare, err } - return aVer.Compare(bVer), nil + return compareSubVersion("patch", aPatch, bPatch, a, b) } var encodeRef = map[byte]byte{ diff --git a/mysql/util_test.go b/mysql/util_test.go index 175e907e4..c068c8991 100644 --- a/mysql/util_test.go +++ b/mysql/util_test.go @@ -15,13 +15,29 @@ func TestCompareServerVersions(t *testing.T) { }{ {A: "1.2.3", B: "1.2.3", Expect: 0}, {A: "5.6-999", B: "8.0", Expect: -1}, + {A: "5.6.3-999", B: "5.6", Expect: 0}, + {A: "5.6.3-999", B: "5.5-tag", Expect: 1}, {A: "8.0.32-0ubuntu0.20.04.2", B: "8.0.28", Expect: 1}, + {A: "a.b.c", B: "8.0", Expect: 2}, } for _, test := range tests { got, err := CompareServerVersions(test.A, test.B) - require.NoError(t, err) - require.Equal(t, test.Expect, got) + if test.Expect == 2 { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, test.Expect, got) + } + + // test logic is commutative + got, err = CompareServerVersions(test.B, test.A) + if test.Expect == 2 { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, -test.Expect, got) + } } }