Skip to content

feat(plugin): Use gRPC interface for codegen plugin communication #2930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion buf.gen.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ plugins:
- plugin: buf.build/protocolbuffers/go:v1.30.0
out: internal
opt: paths=source_relative
- plugin: buf.build/community/planetscale-vtprotobuf:v0.4.0
- plugin: buf.build/grpc/go:v1.3.0
out: internal
opt: paths=source_relative
7 changes: 4 additions & 3 deletions cmd/sqlc-gen-json/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/sqlc-dev/sqlc/internal/codegen/json"
"github.com/sqlc-dev/sqlc/internal/plugin"
"google.golang.org/protobuf/proto"
)

func main() {
Expand All @@ -19,19 +20,19 @@ func main() {
}

func run() error {
var req plugin.CodeGenRequest
var req plugin.GenerateRequest
reqBlob, err := io.ReadAll(os.Stdin)
if err != nil {
return err
}
if err := req.UnmarshalVT(reqBlob); err != nil {
if err := proto.Unmarshal(reqBlob, &req); err != nil {
return err
}
resp, err := json.Generate(context.Background(), &req)
if err != nil {
return err
}
respBlob, err := resp.MarshalVT()
respBlob, err := proto.Marshal(resp)
if err != nil {
return err
}
Expand Down
8 changes: 5 additions & 3 deletions internal/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sync"

"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/status"

"github.com/sqlc-dev/sqlc/internal/codegen/golang"
Expand Down Expand Up @@ -380,10 +381,10 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
return c.Result(), false
}

func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, result *compiler.Result) (string, *plugin.CodeGenResponse, error) {
func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, result *compiler.Result) (string, *plugin.GenerateResponse, error) {
defer trace.StartRegion(ctx, "codegen").End()
req := codeGenRequest(result, combo)
var handler ext.Handler
var handler grpc.ClientConnInterface
var out string
switch {
case sql.Plugin != nil:
Expand Down Expand Up @@ -453,6 +454,7 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re
default:
return "", nil, fmt.Errorf("missing language backend")
}
resp, err := handler.Generate(ctx, req)
client := plugin.NewCodegenServiceClient(handler)
resp, err := client.Generate(ctx, req)
return out, resp, err
}
4 changes: 2 additions & 2 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ func pluginQueryParam(p compiler.Parameter) *plugin.Parameter {
}
}

func codeGenRequest(r *compiler.Result, settings config.CombinedSettings) *plugin.CodeGenRequest {
return &plugin.CodeGenRequest{
func codeGenRequest(r *compiler.Result, settings config.CombinedSettings) *plugin.GenerateRequest {
return &plugin.GenerateRequest{
Settings: pluginSettings(r, settings),
Catalog: pluginCatalog(r.Catalog),
Queries: pluginQueries(r),
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/vet.go
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ func (c *checker) checkSQL(ctx context.Context, s config.SQL) error {
return nil
}

func vetConfig(req *plugin.CodeGenRequest) *vet.Config {
func vetConfig(req *plugin.GenerateRequest) *vet.Config {
return &vet.Config{
Version: req.Settings.Version,
Engine: req.Settings.Engine,
Expand Down
6 changes: 3 additions & 3 deletions internal/codegen/golang/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (t *tmplCtx) codegenQueryRetval(q Query) (string, error) {
}
}

func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
options, err := opts.Parse(req)
if err != nil {
return nil, err
Expand All @@ -127,7 +127,7 @@ func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenR
return generate(req, options, enums, structs, queries)
}

func generate(req *plugin.CodeGenRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.CodeGenResponse, error) {
func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, structs []Struct, queries []Query) (*plugin.GenerateResponse, error) {
i := &importer{
Options: options,
Queries: queries,
Expand Down Expand Up @@ -282,7 +282,7 @@ func generate(req *plugin.CodeGenRequest, options *opts.Options, enums []Enum, s
return nil, err
}
}
resp := plugin.CodeGenResponse{}
resp := plugin.GenerateResponse{}

for filename, code := range output {
resp.Files = append(resp.Files, &plugin.File{
Expand Down
6 changes: 3 additions & 3 deletions internal/codegen/golang/go_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) {
func addExtraGoStructTags(tags map[string]string, req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) {
for _, override := range options.Overrides {
oride := override.ShimOverride
if oride.GoType.StructTags == nil {
Expand All @@ -33,7 +33,7 @@ func addExtraGoStructTags(tags map[string]string, req *plugin.CodeGenRequest, op
}
}

func goType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
func goType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
// Check if the column's type has been overridden
for _, override := range options.Overrides {
oride := override.ShimOverride
Expand Down Expand Up @@ -63,7 +63,7 @@ func goType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Colum
return typ
}

func goInnerType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
func goInnerType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
columnType := sdk.DataType(col.Type)
notNull := col.NotNull || col.IsArray

Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/mysql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func mysqlType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
func mysqlType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
columnType := sdk.DataType(col.Type)
notNull := col.NotNull || col.IsArray
unsigned := col.Unsigned
Expand Down
6 changes: 3 additions & 3 deletions internal/codegen/golang/opts/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type GlobalOptions struct {
Rename map[string]string `json:"rename,omitempty" yaml:"rename"`
}

func Parse(req *plugin.CodeGenRequest) (*Options, error) {
func Parse(req *plugin.GenerateRequest) (*Options, error) {
options, err := parseOpts(req)
if err != nil {
return nil, err
Expand All @@ -68,7 +68,7 @@ func Parse(req *plugin.CodeGenRequest) (*Options, error) {
return options, nil
}

func parseOpts(req *plugin.CodeGenRequest) (*Options, error) {
func parseOpts(req *plugin.GenerateRequest) (*Options, error) {
var options Options
if len(req.PluginOptions) == 0 {
return &options, nil
Expand All @@ -91,7 +91,7 @@ func parseOpts(req *plugin.CodeGenRequest) (*Options, error) {
return &options, nil
}

func parseGlobalOpts(req *plugin.CodeGenRequest) (*GlobalOptions, error) {
func parseGlobalOpts(req *plugin.GenerateRequest) (*GlobalOptions, error) {
var options GlobalOptions
if len(req.GlobalOptions) == 0 {
return &options, nil
Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/opts/override.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (o *Override) Matches(n *plugin.Identifier, defaultSchema string) bool {
return true
}

func (o *Override) parse(req *plugin.CodeGenRequest) (err error) {
func (o *Override) parse(req *plugin.GenerateRequest) (err error) {
// validate deprecated postgres_type field
if o.Deprecated_PostgresType != "" {
fmt.Fprintf(os.Stderr, "WARNING: \"postgres_type\" is deprecated. Instead, use \"db_type\" to specify a type override.\n")
Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/opts/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ type ShimOverride struct {
GoType *ShimGoType
}

func shimOverride(req *plugin.CodeGenRequest, o *Override) *ShimOverride {
func shimOverride(req *plugin.GenerateRequest, o *Override) *ShimOverride {
var column string
var table plugin.Identifier

Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/postgresql_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func parseIdentifierString(name string) (*plugin.Identifier, error) {
}
}

func postgresType(req *plugin.CodeGenRequest, options *opts.Options, col *plugin.Column) string {
func postgresType(req *plugin.GenerateRequest, options *opts.Options, col *plugin.Column) string {
columnType := sdk.DataType(col.Type)
notNull := col.NotNull || col.IsArray
driver := parseDriver(options.SqlPackage)
Expand Down
8 changes: 4 additions & 4 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func buildEnums(req *plugin.CodeGenRequest, options *opts.Options) []Enum {
func buildEnums(req *plugin.GenerateRequest, options *opts.Options) []Enum {
var enums []Enum
for _, schema := range req.Catalog.Schemas {
if schema.Name == "pg_catalog" || schema.Name == "information_schema" {
Expand Down Expand Up @@ -59,7 +59,7 @@ func buildEnums(req *plugin.CodeGenRequest, options *opts.Options) []Enum {
return enums
}

func buildStructs(req *plugin.CodeGenRequest, options *opts.Options) []Struct {
func buildStructs(req *plugin.GenerateRequest, options *opts.Options) []Struct {
var structs []Struct
for _, schema := range req.Catalog.Schemas {
if schema.Name == "pg_catalog" || schema.Name == "information_schema" {
Expand Down Expand Up @@ -182,7 +182,7 @@ func argName(name string) string {
return out
}

func buildQueries(req *plugin.CodeGenRequest, options *opts.Options, structs []Struct) ([]Query, error) {
func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs []Struct) ([]Query, error) {
qs := make([]Query, 0, len(req.Queries))
for _, query := range req.Queries {
if query.Name == "" {
Expand Down Expand Up @@ -332,7 +332,7 @@ func putOutColumns(query *plugin.Query) bool {
// JSON tags: count, count_2, count_2
//
// This is unlikely to happen, so don't fix it yet
func columnsToStruct(req *plugin.CodeGenRequest, options *opts.Options, name string, columns []goColumn, useID bool) (*Struct, error) {
func columnsToStruct(req *plugin.GenerateRequest, options *opts.Options, name string, columns []goColumn, useID bool) (*Struct, error) {
gs := Struct{
Name: name,
}
Expand Down
2 changes: 1 addition & 1 deletion internal/codegen/golang/sqlite_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func sqliteType(req *plugin.CodeGenRequest, col *plugin.Column) string {
func sqliteType(req *plugin.GenerateRequest, col *plugin.Column) string {
dt := strings.ToLower(sdk.DataType(col.Type))
notNull := col.NotNull || col.IsArray

Expand Down
6 changes: 3 additions & 3 deletions internal/codegen/json/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"github.com/sqlc-dev/sqlc/internal/plugin"
)

func parseOptions(req *plugin.CodeGenRequest) (*opts, error) {
func parseOptions(req *plugin.GenerateRequest) (*opts, error) {
if len(req.PluginOptions) == 0 {
return new(opts), nil
}
Expand All @@ -25,7 +25,7 @@ func parseOptions(req *plugin.CodeGenRequest) (*opts, error) {
return options, nil
}

func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
func Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
options, err := parseOptions(req)
if err != nil {
return nil, err
Expand Down Expand Up @@ -57,7 +57,7 @@ func Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenR
if err != nil {
return nil, err
}
return &plugin.CodeGenResponse{
return &plugin.GenerateResponse{
Files: []*plugin.File{
{
Name: filename,
Expand Down
37 changes: 33 additions & 4 deletions internal/ext/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,51 @@ package ext

import (
"context"
"fmt"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/sqlc-dev/sqlc/internal/plugin"
)

type Handler interface {
Generate(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)
Generate(context.Context, *plugin.GenerateRequest) (*plugin.GenerateResponse, error)

Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error
NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error)
}

type wrapper struct {
fn func(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)
fn func(context.Context, *plugin.GenerateRequest) (*plugin.GenerateResponse, error)
}

func (w *wrapper) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
func (w *wrapper) Generate(ctx context.Context, req *plugin.GenerateRequest) (*plugin.GenerateResponse, error) {
return w.fn(ctx, req)
}

func HandleFunc(fn func(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)) Handler {
func (w *wrapper) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
req, ok := args.(*plugin.GenerateRequest)
if !ok {
return fmt.Errorf("args isn't a GenerateRequest")
}
resp, ok := reply.(*plugin.GenerateResponse)
if !ok {
return fmt.Errorf("reply isn't a GenerateResponse")
}
res, err := w.Generate(ctx, req)
if err != nil {
return err
}
resp.Files = res.Files
return nil
}

func (w *wrapper) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return nil, status.Error(codes.Unimplemented, "")
}

func HandleFunc(fn func(context.Context, *plugin.GenerateRequest) (*plugin.GenerateResponse, error)) Handler {
return &wrapper{fn}
}
Loading