Skip to content

Commit 021b6aa

Browse files
committed
feat(plugin): Use gRPC interface for codegen plugin communication
1 parent 9cd9139 commit 021b6aa

File tree

8 files changed

+219
-39
lines changed

8 files changed

+219
-39
lines changed

internal/cmd/generate.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"sync"
1515

1616
"golang.org/x/sync/errgroup"
17+
"google.golang.org/grpc"
1718
"google.golang.org/grpc/status"
1819

1920
"github.com/sqlc-dev/sqlc/internal/codegen/golang"
@@ -383,7 +384,7 @@ func parse(ctx context.Context, name, dir string, sql config.SQL, combo config.C
383384
func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, result *compiler.Result) (string, *plugin.CodeGenResponse, error) {
384385
defer trace.StartRegion(ctx, "codegen").End()
385386
req := codeGenRequest(result, combo)
386-
var handler ext.Handler
387+
var handler grpc.ClientConnInterface
387388
var out string
388389
switch {
389390
case sql.Plugin != nil:
@@ -453,6 +454,7 @@ func codegen(ctx context.Context, combo config.CombinedSettings, sql outPair, re
453454
default:
454455
return "", nil, fmt.Errorf("missing language backend")
455456
}
456-
resp, err := handler.Generate(ctx, req)
457+
client := plugin.NewCodeGeneratorClient(handler)
458+
resp, err := client.Generate(ctx, req)
457459
return out, resp, err
458460
}

internal/ext/handler.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,19 @@ package ext
22

33
import (
44
"context"
5+
"fmt"
6+
7+
"google.golang.org/grpc"
8+
"google.golang.org/grpc/codes"
9+
"google.golang.org/grpc/status"
510

611
"github.com/sqlc-dev/sqlc/internal/plugin"
712
)
813

914
type Handler interface {
1015
Generate(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)
16+
Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error
17+
NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error)
1118
}
1219

1320
type wrapper struct {
@@ -18,6 +25,27 @@ func (w *wrapper) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*pl
1825
return w.fn(ctx, req)
1926
}
2027

28+
func (w *wrapper) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
29+
req, ok := args.(*plugin.CodeGenRequest)
30+
if !ok {
31+
return fmt.Errorf("args isn't a CodeGenRequest")
32+
}
33+
resp, ok := reply.(*plugin.CodeGenResponse)
34+
if !ok {
35+
return fmt.Errorf("reply isn't a CodeGenResponse")
36+
}
37+
res, err := w.Generate(ctx, req)
38+
if err != nil {
39+
return err
40+
}
41+
resp.Files = res.Files
42+
return nil
43+
}
44+
45+
func (w *wrapper) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
46+
return nil, status.Error(codes.Unimplemented, "")
47+
}
48+
2149
func HandleFunc(fn func(context.Context, *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error)) Handler {
2250
return &wrapper{fn}
2351
}

internal/ext/process/gen.go

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ import (
88
"os"
99
"os/exec"
1010

11+
"google.golang.org/grpc"
12+
"google.golang.org/grpc/codes"
13+
"google.golang.org/grpc/status"
1114
"google.golang.org/protobuf/proto"
1215

1316
"github.com/sqlc-dev/sqlc/internal/plugin"
@@ -18,20 +21,24 @@ type Runner struct {
1821
Env []string
1922
}
2023

21-
// TODO: Update the gen func signature to take a ctx
22-
func (r Runner) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
24+
func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
25+
req, ok := args.(*plugin.CodeGenRequest)
26+
if !ok {
27+
return fmt.Errorf("args isn't a CodeGenRequest")
28+
}
29+
2330
stdin, err := proto.Marshal(req)
2431
if err != nil {
25-
return nil, fmt.Errorf("failed to encode codegen request: %s", err)
32+
return fmt.Errorf("failed to encode codegen request: %w", err)
2633
}
2734

2835
// Check if the output plugin exists
2936
path, err := exec.LookPath(r.Cmd)
3037
if err != nil {
31-
return nil, fmt.Errorf("process: %s not found", r.Cmd)
38+
return fmt.Errorf("process: %s not found", r.Cmd)
3239
}
3340

34-
cmd := exec.CommandContext(ctx, path)
41+
cmd := exec.CommandContext(ctx, path, method)
3542
cmd.Stdin = bytes.NewReader(stdin)
3643
cmd.Env = []string{
3744
fmt.Sprintf("SQLC_VERSION=%s", req.SqlcVersion),
@@ -50,13 +57,21 @@ func (r Runner) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plug
5057
if errors.As(err, &exit) {
5158
stderr = string(exit.Stderr)
5259
}
53-
return nil, fmt.Errorf("process: error running command %s", stderr)
60+
return fmt.Errorf("process: error running command %s", stderr)
61+
}
62+
63+
resp, ok := reply.(*plugin.CodeGenResponse)
64+
if !ok {
65+
return fmt.Errorf("reply isn't a CodeGenResponse")
5466
}
5567

56-
var resp plugin.CodeGenResponse
57-
if err := proto.Unmarshal(out, &resp); err != nil {
58-
return nil, fmt.Errorf("process: failed to read codegen resp: %s", err)
68+
if err := proto.Unmarshal(out, resp); err != nil {
69+
return fmt.Errorf("process: failed to read codegen resp: %w", err)
5970
}
6071

61-
return &resp, nil
72+
return nil
73+
}
74+
75+
func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
76+
return nil, status.Error(codes.Unimplemented, "")
6277
}

internal/ext/wasm/nowasm.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,18 @@ package wasm
44

55
import (
66
"context"
7-
"fmt"
7+
8+
"google.golang.org/grpc"
9+
"google.golang.org/grpc/codes"
10+
"google.golang.org/grpc/status"
811

912
"github.com/sqlc-dev/sqlc/internal/plugin"
1013
)
1114

12-
func (r *Runner) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
13-
return nil, fmt.Errorf("sqlc built without wasmtime support")
15+
func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
16+
return status.Error(codes.FailedPrecondition, "sqlc built without wasmtime support")
17+
}
18+
19+
func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
20+
return nil, status.Error(codes.Unimplemented, "")
1421
}

internal/ext/wasm/wasm.go

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ import (
1919

2020
wasmtime "github.com/bytecodealliance/wasmtime-go/v14"
2121
"golang.org/x/sync/singleflight"
22+
"google.golang.org/grpc"
23+
"google.golang.org/grpc/codes"
24+
"google.golang.org/grpc/status"
2225

2326
"github.com/sqlc-dev/sqlc/internal/cache"
2427
"github.com/sqlc-dev/sqlc/internal/info"
@@ -206,29 +209,34 @@ func removePGCatalog(req *plugin.CodeGenRequest) {
206209
req.Catalog.Schemas = filtered
207210
}
208211

209-
func (r *Runner) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
212+
func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error {
213+
req, ok := args.(*plugin.CodeGenRequest)
214+
if !ok {
215+
return status.Error(codes.InvalidArgument, "args isn't a CodeGenRequest")
216+
}
217+
210218
// Remove the pg_catalog schema. Its sheer size causes unknown issues with wasm plugins
211219
removePGCatalog(req)
212220

213221
stdinBlob, err := req.MarshalVT()
214222
if err != nil {
215-
return nil, err
223+
return fmt.Errorf("failed to encode codegen request: %w", err)
216224
}
217225

218226
engine := wasmtime.NewEngine()
219227
module, err := r.loadModule(ctx, engine)
220228
if err != nil {
221-
return nil, fmt.Errorf("loadModule: %w", err)
229+
return fmt.Errorf("loadModule: %w", err)
222230
}
223231

224232
linker := wasmtime.NewLinker(engine)
225233
if err := linker.DefineWasi(); err != nil {
226-
return nil, err
234+
return err
227235
}
228236

229237
dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out")
230238
if err != nil {
231-
return nil, fmt.Errorf("temp dir: %w", err)
239+
return fmt.Errorf("temp dir: %w", err)
232240
}
233241

234242
defer os.RemoveAll(dir)
@@ -237,11 +245,12 @@ func (r *Runner) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plu
237245
stdoutPath := filepath.Join(dir, "stdout")
238246

239247
if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil {
240-
return nil, fmt.Errorf("write file: %w", err)
248+
return fmt.Errorf("write file: %w", err)
241249
}
242250

243251
// Configure WASI imports to write stdout into a file.
244252
wasiConfig := wasmtime.NewWasiConfig()
253+
wasiConfig.SetArgv([]string{method})
245254
wasiConfig.SetStdinFile(stdinPath)
246255
wasiConfig.SetStdoutFile(stdoutPath)
247256
wasiConfig.SetStderrFile(stderrPath)
@@ -261,31 +270,43 @@ func (r *Runner) Generate(ctx context.Context, req *plugin.CodeGenRequest) (*plu
261270
err = linker.DefineModule(store, "", module)
262271
linkRegion.End()
263272
if err != nil {
264-
return nil, fmt.Errorf("define wasi: %w", err)
273+
return fmt.Errorf("define wasi: %w", err)
265274
}
266275

267276
// Run the function
268277
fn, err := linker.GetDefault(store, "")
269278
if err != nil {
270-
return nil, fmt.Errorf("wasi: get default: %w", err)
279+
return fmt.Errorf("wasi: get default: %w", err)
271280
}
272281

273282
callRegion := trace.StartRegion(ctx, "call _start")
274283
_, err = fn.Call(store)
275284
callRegion.End()
276285

277286
if cerr := checkError(err, stderrPath); cerr != nil {
278-
return nil, cerr
287+
return cerr
279288
}
280289

281290
// Print WASM stdout
282291
stdoutBlob, err := os.ReadFile(stdoutPath)
283292
if err != nil {
284-
return nil, fmt.Errorf("read file: %w", err)
293+
return fmt.Errorf("read file: %w", err)
285294
}
286295

287-
var resp plugin.CodeGenResponse
288-
return &resp, resp.UnmarshalVT(stdoutBlob)
296+
resp, ok := reply.(*plugin.CodeGenResponse)
297+
if !ok {
298+
return fmt.Errorf("reply isn't a CodeGenResponse")
299+
}
300+
301+
if err := resp.UnmarshalVT(stdoutBlob); err != nil {
302+
return err
303+
}
304+
305+
return nil
306+
}
307+
308+
func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
309+
return nil, status.Error(codes.Unimplemented, "")
289310
}
290311

291312
func checkError(err error, stderrPath string) error {

internal/plugin/codegen.pb.go

Lines changed: 19 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)