Skip to content

Commit 0697847

Browse files
committed
fix: don't include generic types in variants
1 parent 803c965 commit 0697847

File tree

6 files changed

+53
-1
lines changed

6 files changed

+53
-1
lines changed

bin/hermit.hcl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
env = {
2+
"PATH": "${HERMIT_ENV}/scripts:${PATH}",
3+
}

cmd/go-check-sumtype/main.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"flag"
45
"log"
56
"os"
67
"strings"
@@ -11,7 +12,8 @@ import (
1112

1213
func main() {
1314
log.SetFlags(0)
14-
if len(os.Args) < 2 {
15+
flag.Parse()
16+
if len(flag.Args()) < 1 {
1517
log.Fatalf("Usage: sumtype <packages>\n")
1618
}
1719
args := os.Args[1:]

decl.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ func findSumTypeDecls(pkgs []*packages.Package) ([]sumTypeDecl, error) {
5757
}
5858
pos = pkg.Fset.Position(tspec.Pos())
5959
decl := sumTypeDecl{Package: pkg, TypeName: tspec.Name.Name, Pos: pos}
60+
debugf("found sum type decl: %s.%s", decl.Package.PkgPath, decl.TypeName)
6061
decls = append(decls, decl)
6162
break
6263
}

def.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
package gochecksumtype
22

33
import (
4+
"flag"
45
"fmt"
56
"go/token"
67
"go/types"
8+
"log"
79
)
810

11+
var debug = flag.Bool("debug", false, "enable debug logging")
12+
13+
func debugf(format string, args ...interface{}) {
14+
if *debug {
15+
log.Printf(format, args...)
16+
}
17+
}
18+
919
// Error as returned by Run()
1020
type Error interface {
1121
error
@@ -107,6 +117,7 @@ func newSumTypeDef(pkg *types.Package, decl sumTypeDecl) (*sumTypeDef, error) {
107117
Decl: decl,
108118
Ty: iface,
109119
}
120+
debugf("searching for variants of %s.%s\n", pkg.Path(), decl.TypeName)
110121
for _, name := range pkg.Scope().Names() {
111122
obj, ok := pkg.Scope().Lookup(name).(*types.TypeName)
112123
if !ok {
@@ -116,7 +127,12 @@ func newSumTypeDef(pkg *types.Package, decl sumTypeDecl) (*sumTypeDef, error) {
116127
if types.Identical(ty.Underlying(), iface) {
117128
continue
118129
}
130+
// Skip generic types.
131+
if named, ok := ty.(*types.Named); ok && named.TypeParams() != nil {
132+
continue
133+
}
119134
if types.Implements(ty, iface) || types.Implements(types.NewPointer(ty), iface) {
135+
debugf(" found variant: %s.%s\n", pkg.Path(), obj.Name())
120136
def.Variants = append(def.Variants, obj)
121137
}
122138
}

scripts/go-check-sumtype

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/bin/bash
2+
set -euo pipefail
3+
basedir="$(dirname "$0")/.."
4+
name="$(basename "$0")"
5+
dest="${basedir}/build/devel"
6+
mkdir -p "$dest"
7+
(cd "${basedir}" && ./bin/go build -ldflags="-s -w -buildid=" -o "$dest/${name}" "./cmd/${name}") && exec "$dest/${name}" "$@"

testdata/sum.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package testdata
2+
3+
//sumtype:decl
4+
type Sum interface{ sum() }
5+
6+
type A struct{}
7+
8+
func (A) sum() {}
9+
10+
type B struct{}
11+
12+
func (B) sum() {}
13+
14+
type C[T any] struct{}
15+
16+
func (C[T]) sum() {}
17+
18+
func SumSwitch(x Sum) {
19+
switch x.(type) {
20+
case A:
21+
case B:
22+
}
23+
}

0 commit comments

Comments
 (0)