diff --git a/arduino/cli/commander.go b/arduino/cli/commander.go index 6c03ed72..cfc15a0a 100644 --- a/arduino/cli/commander.go +++ b/arduino/cli/commander.go @@ -44,9 +44,11 @@ type commander struct { func NewCommander() (arduino.Commander, error) { // Discard arduino-cli log info messages logrus.SetLevel(logrus.ErrorLevel) - // Initialize arduino-cli configuration + + // Initialize arduino-cli configuration. configuration.Settings = configuration.Init(configuration.FindConfigFileInArgsOrWorkingDirectory(os.Args)) - // Create arduino-cli instance, needed to execute arduino-cli commands + + // Create and init an arduino-cli instance, needed to execute arduino-cli commands. inst, err := instance.Create() if err != nil { err = fmt.Errorf("creating arduino-cli instance: %w", err) @@ -61,34 +63,64 @@ func NewCommander() (arduino.Commander, error) { return cmd, nil } +func mergeErrors(err error, errs []error) error { + merr := errors.New("merged errors: ") + empty := true + + if err != nil { + merr = fmt.Errorf("%w%v; ", merr, err) + empty = false + } + + if len(errs) > 0 { + empty = false + for _, e := range errs { + merr = fmt.Errorf("%w%v; ", merr, e) + } + } + + if !empty { + return merr + } + return nil +} + // BoardList executes the 'arduino-cli board list' command // and returns its result. -func (c *commander) BoardList() ([]*rpc.DetectedPort, error) { +func (c *commander) BoardList(ctx context.Context) ([]*rpc.DetectedPort, error) { req := &rpc.BoardListRequest{ Instance: c.Instance, Timeout: time.Second.Milliseconds(), } - ports, errs, err := board.List(req) - if err != nil { - err = fmt.Errorf("%s: %w", "detecting boards", err) - return nil, err + // There is no obvious way to cancel the execution of this command. + // So, we execute it in a goroutine and leave it running alone if ctx gets cancelled. + type resp struct { + err error + ports []*rpc.DetectedPort } + quit := make(chan resp, 1) + go func() { + ports, errs, err := board.List(req) + quit <- resp{err: mergeErrors(err, errs), ports: ports} + close(quit) + }() - if len(errs) > 0 { - err = errors.New("starting discovery procedure: received errors: ") - for _, e := range errs { - err = fmt.Errorf("%w%v; ", err, e) + // Wait for the command to complete or the context to be terminated. + select { + case <-ctx.Done(): + return nil, errors.New("board list command cancelled") + case r := <-quit: + if r.err != nil { + return nil, fmt.Errorf("executing board list command: %w", r.err) } - return nil, err + return r.ports, nil } - - return ports, nil } // UploadBin executes the 'arduino-cli upload -i' command // and returns its result. -func (c *commander) UploadBin(fqbn, bin, address, protocol string) error { +func (c *commander) UploadBin(ctx context.Context, fqbn, bin, address, protocol string) error { req := &rpc.UploadRequest{ Instance: c.Instance, Fqbn: fqbn, @@ -97,11 +129,25 @@ func (c *commander) UploadBin(fqbn, bin, address, protocol string) error { Port: &rpc.Port{Address: address, Protocol: protocol}, Verbose: false, } - l := logrus.StandardLogger().WithField("source", "arduino-cli").Writer() - if _, err := upload.Upload(context.Background(), req, l, l); err != nil { - err = fmt.Errorf("%s: %w", "uploading binary", err) - return err + + // There is no obvious way to cancel the execution of this command. + // So, we execute it in a goroutine and leave it running if ctx gets cancelled. + quit := make(chan error, 1) + go func() { + _, err := upload.Upload(ctx, req, l, l) + quit <- err + close(quit) + }() + + // Wait for the upload to complete or the context to be terminated. + select { + case <-ctx.Done(): + return errors.New("upload cancelled") + case err := <-quit: + if err != nil { + return fmt.Errorf("uploading binary: %w", err) + } + return nil } - return nil } diff --git a/arduino/commander.go b/arduino/commander.go index 1b5ddcd4..a4e7f164 100644 --- a/arduino/commander.go +++ b/arduino/commander.go @@ -18,13 +18,15 @@ package arduino import ( + "context" + rpc "github.com/arduino/arduino-cli/rpc/cc/arduino/cli/commands/v1" ) // Commander of arduino package allows to call // the arduino-cli commands in a programmatic way. type Commander interface { - BoardList() ([]*rpc.DetectedPort, error) - UploadBin(fqbn, bin, address, protocol string) error + BoardList(ctx context.Context) ([]*rpc.DetectedPort, error) + UploadBin(ctx context.Context, fqbn, bin, address, protocol string) error //Compile() error } diff --git a/arduino/grpc/board.go b/arduino/grpc/board.go index 9dfb3f2e..f81228af 100644 --- a/arduino/grpc/board.go +++ b/arduino/grpc/board.go @@ -30,7 +30,7 @@ type boardHandler struct { // BoardList executes the 'arduino-cli board list' command // and returns its result. -func (b boardHandler) BoardList() ([]*rpc.DetectedPort, error) { +func (b boardHandler) BoardList(ctx context.Context) ([]*rpc.DetectedPort, error) { boardListResp, err := b.serviceClient.BoardList(context.Background(), &rpc.BoardListRequest{Instance: b.instance}) diff --git a/arduino/grpc/compile.go b/arduino/grpc/compile.go index 9946b2a3..07b95c72 100644 --- a/arduino/grpc/compile.go +++ b/arduino/grpc/compile.go @@ -38,7 +38,7 @@ func (c compileHandler) Compile() error { // Upload executes the 'arduino-cli upload -i' command // and returns its result. -func (c compileHandler) UploadBin(fqbn, bin, address, protocol string) error { +func (c compileHandler) UploadBin(ctx context.Context, fqbn, bin, address, protocol string) error { stream, err := c.serviceClient.Upload(context.Background(), &rpc.UploadRequest{ Instance: c.instance, diff --git a/cli/dashboard/create.go b/cli/dashboard/create.go index 33078035..752a498b 100644 --- a/cli/dashboard/create.go +++ b/cli/dashboard/create.go @@ -18,6 +18,7 @@ package dashboard import ( + "context" "fmt" "os" "strings" @@ -76,7 +77,7 @@ func runCreateCommand(flags *createFlags) error { params.Name = &flags.name } - dashboard, err := dashboard.Create(params, cred) + dashboard, err := dashboard.Create(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/dashboard/delete.go b/cli/dashboard/delete.go index 215c56db..710da5a6 100644 --- a/cli/dashboard/delete.go +++ b/cli/dashboard/delete.go @@ -18,6 +18,7 @@ package dashboard import ( + "context" "fmt" "os" @@ -60,7 +61,7 @@ func runDeleteCommand(flags *deleteFlags) error { } params := &dashboard.DeleteParams{ID: flags.id} - err = dashboard.Delete(params, cred) + err = dashboard.Delete(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/dashboard/extract.go b/cli/dashboard/extract.go index e2511005..2ce03cb6 100644 --- a/cli/dashboard/extract.go +++ b/cli/dashboard/extract.go @@ -18,6 +18,7 @@ package dashboard import ( + "context" "fmt" "os" @@ -64,7 +65,7 @@ func runExtractCommand(flags *extractFlags) error { ID: flags.id, } - template, err := dashboard.Extract(params, cred) + template, err := dashboard.Extract(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/dashboard/list.go b/cli/dashboard/list.go index 9fe2e657..fe029c3b 100644 --- a/cli/dashboard/list.go +++ b/cli/dashboard/list.go @@ -18,6 +18,7 @@ package dashboard import ( + "context" "fmt" "math" "os" @@ -66,7 +67,7 @@ func runListCommand(flags *listFlags) error { return fmt.Errorf("retrieving credentials: %w", err) } - dash, err := dashboard.List(cred) + dash, err := dashboard.List(context.TODO(), cred) if err != nil { return err } diff --git a/cli/device/create.go b/cli/device/create.go index 6a725728..0c5aabc3 100644 --- a/cli/device/create.go +++ b/cli/device/create.go @@ -18,6 +18,7 @@ package device import ( + "context" "fmt" "os" @@ -27,6 +28,7 @@ import ( "github.com/arduino/arduino-cloud-cli/config" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "go.bug.st/cleanup" ) type createFlags struct { @@ -73,7 +75,10 @@ func runCreateCommand(flags *createFlags) error { params.FQBN = &flags.fqbn } - dev, err := device.Create(params, cred) + ctx, cancel := cleanup.InterruptableContext(context.Background()) + defer cancel() + + dev, err := device.Create(ctx, params, cred) if err != nil { return err } diff --git a/cli/device/creategeneric.go b/cli/device/creategeneric.go index aeb0d343..21fabf6a 100644 --- a/cli/device/creategeneric.go +++ b/cli/device/creategeneric.go @@ -18,6 +18,7 @@ package device import ( + "context" "fmt" "os" @@ -27,6 +28,7 @@ import ( "github.com/arduino/arduino-cloud-cli/config" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "go.bug.st/cleanup" ) type createGenericFlags struct { @@ -66,7 +68,10 @@ func runCreateGenericCommand(flags *createGenericFlags) error { FQBN: flags.fqbn, } - dev, err := device.CreateGeneric(params, cred) + ctx, cancel := cleanup.InterruptableContext(context.Background()) + defer cancel() + + dev, err := device.CreateGeneric(ctx, params, cred) if err != nil { return err } diff --git a/cli/device/createlora.go b/cli/device/createlora.go index b88962dd..e182ca26 100644 --- a/cli/device/createlora.go +++ b/cli/device/createlora.go @@ -18,6 +18,7 @@ package device import ( + "context" "fmt" "os" @@ -27,6 +28,7 @@ import ( "github.com/arduino/arduino-cloud-cli/config" "github.com/sirupsen/logrus" "github.com/spf13/cobra" + "go.bug.st/cleanup" ) type createLoraFlags struct { @@ -80,7 +82,10 @@ func runCreateLoraCommand(flags *createLoraFlags) error { params.FQBN = &flags.fqbn } - dev, err := device.CreateLora(params, cred) + ctx, cancel := cleanup.InterruptableContext(context.Background()) + defer cancel() + + dev, err := device.CreateLora(ctx, params, cred) if err != nil { return err } diff --git a/cli/device/delete.go b/cli/device/delete.go index 204a8282..23af41bb 100644 --- a/cli/device/delete.go +++ b/cli/device/delete.go @@ -18,6 +18,7 @@ package device import ( + "context" "fmt" "os" @@ -72,7 +73,7 @@ func runDeleteCommand(flags *deleteFlags) error { params.ID = &flags.id } - err = device.Delete(params, cred) + err = device.Delete(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/device/list.go b/cli/device/list.go index f7730494..a76a7412 100644 --- a/cli/device/list.go +++ b/cli/device/list.go @@ -18,6 +18,7 @@ package device import ( + "context" "fmt" "os" "strings" @@ -67,7 +68,7 @@ func runListCommand(flags *listFlags) error { } params := &device.ListParams{Tags: flags.tags} - devs, err := device.List(params, cred) + devs, err := device.List(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/device/listfqbn.go b/cli/device/listfqbn.go index d0fe5939..6f9864e2 100644 --- a/cli/device/listfqbn.go +++ b/cli/device/listfqbn.go @@ -18,6 +18,7 @@ package device import ( + "context" "os" "github.com/arduino/arduino-cli/cli/errorcodes" @@ -46,7 +47,7 @@ func initListFQBNCommand() *cobra.Command { func runListFQBNCommand() error { logrus.Info("Listing supported FQBN") - fqbn, err := device.ListFQBN() + fqbn, err := device.ListFQBN(context.TODO()) if err != nil { return err } diff --git a/cli/device/listfrequency.go b/cli/device/listfrequency.go index b1cdc5b8..71ac4452 100644 --- a/cli/device/listfrequency.go +++ b/cli/device/listfrequency.go @@ -18,6 +18,7 @@ package device import ( + "context" "fmt" "os" @@ -53,7 +54,7 @@ func runListFrequencyPlansCommand() error { return fmt.Errorf("retrieving credentials: %w", err) } - freqs, err := device.ListFrequencyPlans(cred) + freqs, err := device.ListFrequencyPlans(context.TODO(), cred) if err != nil { return err } diff --git a/cli/device/tag/create.go b/cli/device/tag/create.go index 38a51b06..22b62a30 100644 --- a/cli/device/tag/create.go +++ b/cli/device/tag/create.go @@ -18,6 +18,7 @@ package tag import ( + "context" "fmt" "os" @@ -73,7 +74,7 @@ func runCreateTagsCommand(flags *createTagsFlags) error { return fmt.Errorf("retrieving credentials: %w", err) } - if err = tag.CreateTags(params, cred); err != nil { + if err = tag.CreateTags(context.TODO(), params, cred); err != nil { return err } diff --git a/cli/device/tag/delete.go b/cli/device/tag/delete.go index 6aee8a67..0779f2d1 100644 --- a/cli/device/tag/delete.go +++ b/cli/device/tag/delete.go @@ -18,6 +18,7 @@ package tag import ( + "context" "fmt" "os" @@ -68,7 +69,7 @@ func runDeleteTagsCommand(flags *deleteTagsFlags) error { Resource: tag.Device, } - err = tag.DeleteTags(params, cred) + err = tag.DeleteTags(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/ota/massupload.go b/cli/ota/massupload.go index 88c1bd62..10675780 100644 --- a/cli/ota/massupload.go +++ b/cli/ota/massupload.go @@ -18,6 +18,7 @@ package ota import ( + "context" "fmt" "os" "sort" @@ -84,7 +85,7 @@ func runMassUploadCommand(flags *massUploadFlags) error { return fmt.Errorf("retrieving credentials: %w", err) } - resp, err := ota.MassUpload(params, cred) + resp, err := ota.MassUpload(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/ota/upload.go b/cli/ota/upload.go index 0c7619a2..0721595a 100644 --- a/cli/ota/upload.go +++ b/cli/ota/upload.go @@ -18,6 +18,7 @@ package ota import ( + "context" "fmt" "os" @@ -69,7 +70,7 @@ func runUploadCommand(flags *uploadFlags) error { File: flags.file, Deferred: flags.deferred, } - err = ota.Upload(params, cred) + err = ota.Upload(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/thing/bind.go b/cli/thing/bind.go index 347ea967..2c2bc8c3 100644 --- a/cli/thing/bind.go +++ b/cli/thing/bind.go @@ -18,6 +18,7 @@ package thing import ( + "context" "fmt" "os" @@ -66,7 +67,7 @@ func runBindCommand(flags *bindFlags) error { ID: flags.id, DeviceID: flags.deviceID, } - if err = thing.Bind(params, cred); err != nil { + if err = thing.Bind(context.TODO(), params, cred); err != nil { return err } diff --git a/cli/thing/clone.go b/cli/thing/clone.go index 4034c628..55408d1d 100644 --- a/cli/thing/clone.go +++ b/cli/thing/clone.go @@ -18,6 +18,7 @@ package thing import ( + "context" "fmt" "os" "strings" @@ -68,7 +69,7 @@ func runCloneCommand(flags *cloneFlags) error { CloneID: flags.cloneID, } - thing, err := thing.Clone(params, cred) + thing, err := thing.Clone(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/thing/create.go b/cli/thing/create.go index 5b7da735..06665627 100644 --- a/cli/thing/create.go +++ b/cli/thing/create.go @@ -18,6 +18,7 @@ package thing import ( + "context" "fmt" "os" "strings" @@ -75,7 +76,7 @@ func runCreateCommand(flags *createFlags) error { params.Name = &flags.name } - thing, err := thing.Create(params, cred) + thing, err := thing.Create(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/thing/delete.go b/cli/thing/delete.go index 4750591a..9a788295 100644 --- a/cli/thing/delete.go +++ b/cli/thing/delete.go @@ -18,6 +18,7 @@ package thing import ( + "context" "fmt" "os" @@ -72,7 +73,7 @@ func runDeleteCommand(flags *deleteFlags) error { params.ID = &flags.id } - err = thing.Delete(params, cred) + err = thing.Delete(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/thing/extract.go b/cli/thing/extract.go index c5125a87..31568351 100644 --- a/cli/thing/extract.go +++ b/cli/thing/extract.go @@ -18,6 +18,7 @@ package thing import ( + "context" "fmt" "os" @@ -64,7 +65,7 @@ func runExtractCommand(flags *extractFlags) error { ID: flags.id, } - template, err := thing.Extract(params, cred) + template, err := thing.Extract(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/thing/list.go b/cli/thing/list.go index ba759708..fe9ac0e1 100644 --- a/cli/thing/list.go +++ b/cli/thing/list.go @@ -18,6 +18,7 @@ package thing import ( + "context" "fmt" "os" "strings" @@ -83,7 +84,7 @@ func runListCommand(flags *listFlags) error { params.DeviceID = &flags.deviceID } - things, err := thing.List(params, cred) + things, err := thing.List(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/thing/tag/create.go b/cli/thing/tag/create.go index 31d87a2f..cd3063b5 100644 --- a/cli/thing/tag/create.go +++ b/cli/thing/tag/create.go @@ -18,6 +18,7 @@ package tag import ( + "context" "fmt" "os" @@ -73,7 +74,7 @@ func runCreateTagsCommand(flags *createTagsFlags) error { return fmt.Errorf("retrieving credentials: %w", err) } - err = tag.CreateTags(params, cred) + err = tag.CreateTags(context.TODO(), params, cred) if err != nil { return err } diff --git a/cli/thing/tag/delete.go b/cli/thing/tag/delete.go index 98a73956..dc86de18 100644 --- a/cli/thing/tag/delete.go +++ b/cli/thing/tag/delete.go @@ -18,6 +18,7 @@ package tag import ( + "context" "fmt" "os" @@ -68,7 +69,7 @@ func runDeleteTagsCommand(flags *deleteTagsFlags) error { Resource: tag.Thing, } - err = tag.DeleteTags(params, cred) + err = tag.DeleteTags(context.TODO(), params, cred) if err != nil { return err } diff --git a/command/dashboard/create.go b/command/dashboard/create.go index 8c2d8c0f..0aab64cb 100644 --- a/command/dashboard/create.go +++ b/command/dashboard/create.go @@ -18,6 +18,7 @@ package dashboard import ( + "context" "errors" "github.com/arduino/arduino-cloud-cli/config" @@ -33,13 +34,13 @@ type CreateParams struct { } // Create allows to create a new dashboard. -func Create(params *CreateParams, cred *config.Credentials) (*DashboardInfo, error) { +func Create(ctx context.Context, params *CreateParams, cred *config.Credentials) (*DashboardInfo, error) { iotClient, err := iot.NewClient(cred) if err != nil { return nil, err } - dashboard, err := template.LoadDashboard(params.Template, params.Override, iotClient) + dashboard, err := template.LoadDashboard(ctx, params.Template, params.Override, iotClient) if err != nil { return nil, err } @@ -53,7 +54,7 @@ func Create(params *CreateParams, cred *config.Credentials) (*DashboardInfo, err return nil, errors.New("dashboard name not specified") } - newDashboard, err := iotClient.DashboardCreate(dashboard) + newDashboard, err := iotClient.DashboardCreate(ctx, dashboard) if err != nil { return nil, err } diff --git a/command/dashboard/delete.go b/command/dashboard/delete.go index e3446525..e5c9d21e 100644 --- a/command/dashboard/delete.go +++ b/command/dashboard/delete.go @@ -18,6 +18,8 @@ package dashboard import ( + "context" + "github.com/arduino/arduino-cloud-cli/config" "github.com/arduino/arduino-cloud-cli/internal/iot" ) @@ -30,11 +32,11 @@ type DeleteParams struct { // Delete command is used to delete a dashboard // from Arduino IoT Cloud. -func Delete(params *DeleteParams, cred *config.Credentials) error { +func Delete(ctx context.Context, params *DeleteParams, cred *config.Credentials) error { iotClient, err := iot.NewClient(cred) if err != nil { return err } - return iotClient.DashboardDelete(params.ID) + return iotClient.DashboardDelete(ctx, params.ID) } diff --git a/command/dashboard/extract.go b/command/dashboard/extract.go index 5b702fe1..e09281f4 100644 --- a/command/dashboard/extract.go +++ b/command/dashboard/extract.go @@ -18,6 +18,7 @@ package dashboard import ( + "context" "fmt" "github.com/arduino/arduino-cloud-cli/config" @@ -33,13 +34,13 @@ type ExtractParams struct { // Extract command is used to extract a dashboard template // from a dashboard on Arduino IoT Cloud. -func Extract(params *ExtractParams, cred *config.Credentials) (map[string]interface{}, error) { +func Extract(ctx context.Context, params *ExtractParams, cred *config.Credentials) (map[string]interface{}, error) { iotClient, err := iot.NewClient(cred) if err != nil { return nil, err } - dashboard, err := iotClient.DashboardShow(params.ID) + dashboard, err := iotClient.DashboardShow(ctx, params.ID) if err != nil { err = fmt.Errorf("%s: %w", "cannot extract dashboard: ", err) return nil, err diff --git a/command/dashboard/list.go b/command/dashboard/list.go index 7ae44dfc..5a2f46a2 100644 --- a/command/dashboard/list.go +++ b/command/dashboard/list.go @@ -18,19 +18,22 @@ package dashboard import ( + "context" + "github.com/arduino/arduino-cloud-cli/config" + "github.com/arduino/arduino-cloud-cli/internal/iot" ) // List command is used to list // the dashboards of Arduino IoT Cloud. -func List(cred *config.Credentials) ([]DashboardInfo, error) { +func List(ctx context.Context, cred *config.Credentials) ([]DashboardInfo, error) { iotClient, err := iot.NewClient(cred) if err != nil { return nil, err } - foundDashboards, err := iotClient.DashboardList() + foundDashboards, err := iotClient.DashboardList(ctx) if err != nil { return nil, err } diff --git a/command/device/create.go b/command/device/create.go index 422bf185..d7741983 100644 --- a/command/device/create.go +++ b/command/device/create.go @@ -18,6 +18,7 @@ package device import ( + "context" "errors" "fmt" @@ -37,13 +38,13 @@ type CreateParams struct { // Create command is used to provision a new arduino device // and to add it to Arduino IoT Cloud. -func Create(params *CreateParams, cred *config.Credentials) (*DeviceInfo, error) { +func Create(ctx context.Context, params *CreateParams, cred *config.Credentials) (*DeviceInfo, error) { comm, err := cli.NewCommander() if err != nil { return nil, err } - ports, err := comm.BoardList() + ports, err := comm.BoardList(ctx) if err != nil { return nil, err } @@ -69,7 +70,7 @@ func Create(params *CreateParams, cred *config.Credentials) (*DeviceInfo, error) } logrus.Info("Creating a new device on the cloud") - dev, err := iotClient.DeviceCreate(board.fqbn, params.Name, board.serial, board.dType) + dev, err := iotClient.DeviceCreate(ctx, board.fqbn, params.Name, board.serial, board.dType) if err != nil { return nil, err } @@ -80,8 +81,9 @@ func Create(params *CreateParams, cred *config.Credentials) (*DeviceInfo, error) board: board, id: dev.Id, } - if err = prov.run(); err != nil { - if errDel := iotClient.DeviceDelete(dev.Id); errDel != nil { + if err = prov.run(ctx); err != nil { + // Don't use the passed context for the cleanup because it could be cancelled. + if errDel := iotClient.DeviceDelete(context.Background(), dev.Id); errDel != nil { return nil, fmt.Errorf( "device was NOT successfully provisioned but " + "now we can't delete it from the cloud - please check " + diff --git a/command/device/creategeneric.go b/command/device/creategeneric.go index 9dd0a5c9..9d72c315 100644 --- a/command/device/creategeneric.go +++ b/command/device/creategeneric.go @@ -18,6 +18,7 @@ package device import ( + "context" "fmt" "github.com/arduino/arduino-cloud-cli/config" @@ -43,20 +44,21 @@ type DeviceGenericInfo struct { } // CreateGeneric command is used to add a new generic device to Arduino IoT Cloud. -func CreateGeneric(params *CreateGenericParams, cred *config.Credentials) (*DeviceGenericInfo, error) { +func CreateGeneric(ctx context.Context, params *CreateGenericParams, cred *config.Credentials) (*DeviceGenericInfo, error) { iotClient, err := iot.NewClient(cred) if err != nil { return nil, err } - dev, err := iotClient.DeviceCreate(params.FQBN, params.Name, "", genericDType) + dev, err := iotClient.DeviceCreate(ctx, params.FQBN, params.Name, "", genericDType) if err != nil { return nil, err } - pass, err := iotClient.DevicePassSet(dev.Id) + pass, err := iotClient.DevicePassSet(ctx, dev.Id) if err != nil { - if errDel := iotClient.DeviceDelete(dev.Id); errDel != nil { + // Don't use the passed context for the cleanup because it could be cancelled. + if errDel := iotClient.DeviceDelete(context.Background(), dev.Id); errDel != nil { return nil, fmt.Errorf( "device was successfully created on IoT-API but " + "now we can't set its secret key nor delete it - please check " + diff --git a/command/device/createlora.go b/command/device/createlora.go index afa77f54..821d5018 100644 --- a/command/device/createlora.go +++ b/command/device/createlora.go @@ -18,6 +18,7 @@ package device import ( + "context" "errors" "fmt" "time" @@ -62,13 +63,13 @@ type CreateLoraParams struct { // CreateLora command is used to provision a new LoRa arduino device // and to add it to Arduino IoT Cloud. -func CreateLora(params *CreateLoraParams, cred *config.Credentials) (*DeviceLoraInfo, error) { +func CreateLora(ctx context.Context, params *CreateLoraParams, cred *config.Credentials) (*DeviceLoraInfo, error) { comm, err := cli.NewCommander() if err != nil { return nil, err } - ports, err := comm.BoardList() + ports, err := comm.BoardList(ctx) if err != nil { return nil, err } @@ -88,21 +89,21 @@ func CreateLora(params *CreateLoraParams, cred *config.Credentials) (*DeviceLora ) } - bin, err := downloadProvisioningFile(board.fqbn) + bin, err := downloadProvisioningFile(ctx, board.fqbn) if err != nil { return nil, err } logrus.Infof("%s", "Uploading deveui sketch on the LoRa board") errMsg := "Error while uploading the LoRa provisioning binary" - err = retry(deveuiUploadAttempts, deveuiUploadWait*time.Millisecond, errMsg, func() error { - return comm.UploadBin(board.fqbn, bin, board.address, board.protocol) + err = retry(ctx, deveuiUploadAttempts, deveuiUploadWait*time.Millisecond, errMsg, func() error { + return comm.UploadBin(ctx, board.fqbn, bin, board.address, board.protocol) }) if err != nil { return nil, fmt.Errorf("failed to upload LoRa provisioning binary: %w", err) } - eui, err := extractEUI(board.address) + eui, err := extractEUI(ctx, board.address) if err != nil { return nil, err } @@ -113,14 +114,15 @@ func CreateLora(params *CreateLoraParams, cred *config.Credentials) (*DeviceLora } logrus.Info("Creating a new device on the cloud") - dev, err := iotClient.DeviceLoraCreate(params.Name, board.serial, board.dType, eui, params.FrequencyPlan) + dev, err := iotClient.DeviceLoraCreate(ctx, params.Name, board.serial, board.dType, eui, params.FrequencyPlan) if err != nil { return nil, err } - devInfo, err := getDeviceLoraInfo(iotClient, dev) + devInfo, err := getDeviceLoraInfo(ctx, iotClient, dev) if err != nil { - errDel := iotClient.DeviceDelete(dev.DeviceId) + // Don't use the passed context for the cleanup because it could be cancelled. + errDel := iotClient.DeviceDelete(context.Background(), dev.DeviceId) if errDel != nil { // Oh no return nil, fmt.Errorf( "device was successfully provisioned and configured on IoT-API but " + @@ -135,12 +137,12 @@ func CreateLora(params *CreateLoraParams, cred *config.Credentials) (*DeviceLora } // extractEUI extracts the EUI from the provisioned lora board. -func extractEUI(port string) (string, error) { +func extractEUI(ctx context.Context, port string) (string, error) { var ser serial.Port logrus.Infof("%s\n", "Connecting to the board through serial port") errMsg := "Error while connecting to the board" - err := retry(serialEUIAttempts, serialEUIWait*time.Millisecond, errMsg, func() error { + err := retry(ctx, serialEUIAttempts, serialEUIWait*time.Millisecond, errMsg, func() error { var err error ser, err = serial.Open(port, &serial.Mode{BaudRate: serialEUIBaudrate}) return err @@ -167,8 +169,8 @@ func extractEUI(port string) (string, error) { return eui, nil } -func getDeviceLoraInfo(iotClient *iot.Client, loraDev *iotclient.ArduinoLoradevicev1) (*DeviceLoraInfo, error) { - dev, err := iotClient.DeviceShow(loraDev.DeviceId) +func getDeviceLoraInfo(ctx context.Context, iotClient *iot.Client, loraDev *iotclient.ArduinoLoradevicev1) (*DeviceLoraInfo, error) { + dev, err := iotClient.DeviceShow(ctx, loraDev.DeviceId) if err != nil { return nil, fmt.Errorf("cannot retrieve device from the cloud: %w", err) } diff --git a/command/device/delete.go b/command/device/delete.go index 7b9de638..1c222dc0 100644 --- a/command/device/delete.go +++ b/command/device/delete.go @@ -18,6 +18,7 @@ package device import ( + "context" "errors" "github.com/arduino/arduino-cloud-cli/config" @@ -36,7 +37,7 @@ type DeleteParams struct { // Delete command is used to delete a device // from Arduino IoT Cloud. -func Delete(params *DeleteParams, cred *config.Credentials) error { +func Delete(ctx context.Context, params *DeleteParams, cred *config.Credentials) error { if params.ID == nil && params.Tags == nil { return errors.New("provide either ID or Tags") } else if params.ID != nil && params.Tags != nil { @@ -53,7 +54,7 @@ func Delete(params *DeleteParams, cred *config.Credentials) error { deviceIDs = append(deviceIDs, *params.ID) } if params.Tags != nil { - dev, err := iotClient.DeviceList(params.Tags) + dev, err := iotClient.DeviceList(ctx, params.Tags) if err != nil { return err } @@ -63,7 +64,7 @@ func Delete(params *DeleteParams, cred *config.Credentials) error { } for _, id := range deviceIDs { - err = iotClient.DeviceDelete(id) + err = iotClient.DeviceDelete(ctx, id) if err != nil { return err } diff --git a/command/device/list.go b/command/device/list.go index aeee20ba..d1f6af0f 100644 --- a/command/device/list.go +++ b/command/device/list.go @@ -18,6 +18,7 @@ package device import ( + "context" "fmt" "github.com/arduino/arduino-cloud-cli/config" @@ -32,13 +33,13 @@ type ListParams struct { // List command is used to list // the devices of Arduino IoT Cloud. -func List(params *ListParams, cred *config.Credentials) ([]DeviceInfo, error) { +func List(ctx context.Context, params *ListParams, cred *config.Credentials) ([]DeviceInfo, error) { iotClient, err := iot.NewClient(cred) if err != nil { return nil, err } - foundDevices, err := iotClient.DeviceList(params.Tags) + foundDevices, err := iotClient.DeviceList(ctx, params.Tags) if err != nil { return nil, err } diff --git a/command/device/listfqbn.go b/command/device/listfqbn.go index 3fd1e8de..4779bb46 100644 --- a/command/device/listfqbn.go +++ b/command/device/listfqbn.go @@ -18,6 +18,7 @@ package device import ( + "context" "encoding/json" "fmt" "io/ioutil" @@ -39,11 +40,17 @@ type FQBNInfo struct { } // ListFQBN command returns a list of the supported FQBN. -func ListFQBN() ([]FQBNInfo, error) { +func ListFQBN(ctx context.Context) ([]FQBNInfo, error) { + url := "https://builder.arduino.cc/v3/boards/" + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("cannot retrieve boards: %w", err) + } + h := &http.Client{Timeout: time.Second * 5} - resp, err := h.Get("https://builder.arduino.cc/v3/boards/") + resp, err := h.Do(req) if err != nil { - return nil, fmt.Errorf("cannot retrieve boards from builder.arduino.cc: %w", err) + return nil, fmt.Errorf("cannot retrieve boards: %w", err) } defer resp.Body.Close() diff --git a/command/device/listfrequency.go b/command/device/listfrequency.go index 6cd14e4a..8f97b956 100644 --- a/command/device/listfrequency.go +++ b/command/device/listfrequency.go @@ -18,6 +18,7 @@ package device import ( + "context" "fmt" "github.com/arduino/arduino-cloud-cli/config" @@ -33,13 +34,13 @@ type FrequencyPlanInfo struct { // ListFrequencyPlans command is used to list // the supported LoRa frequency plans. -func ListFrequencyPlans(cred *config.Credentials) ([]FrequencyPlanInfo, error) { +func ListFrequencyPlans(ctx context.Context, cred *config.Credentials) ([]FrequencyPlanInfo, error) { iotClient, err := iot.NewClient(cred) if err != nil { return nil, err } - foundFreqs, err := iotClient.LoraFrequencyPlansList() + foundFreqs, err := iotClient.LoraFrequencyPlansList(ctx) if err != nil { return nil, err } diff --git a/command/device/provision.go b/command/device/provision.go index 38408666..58781a33 100644 --- a/command/device/provision.go +++ b/command/device/provision.go @@ -18,6 +18,7 @@ package device import ( + "context" "encoding/hex" "fmt" "path/filepath" @@ -34,8 +35,8 @@ import ( // downloadProvisioningFile downloads and returns the absolute path // of the provisioning binary corresponding to the passed fqbn. -func downloadProvisioningFile(fqbn string) (string, error) { - index, err := binary.LoadIndex() +func downloadProvisioningFile(ctx context.Context, fqbn string) (string, error) { + index, err := binary.LoadIndex(ctx) if err != nil { return "", err } @@ -43,7 +44,7 @@ func downloadProvisioningFile(fqbn string) (string, error) { if bin == nil { return "", fmt.Errorf("provisioning binary for board %s not found", fqbn) } - bytes, err := binary.Download(bin) + bytes, err := binary.Download(ctx, bin) if err != nil { return "", fmt.Errorf("downloading provisioning binary: %w", err) } @@ -64,7 +65,7 @@ func downloadProvisioningFile(fqbn string) (string, error) { } type certificateCreator interface { - CertificateCreate(id, csr string) (*iotclient.ArduinoCompressedv2, error) + CertificateCreate(ctx context.Context, id, csr string) (*iotclient.ArduinoCompressedv2, error) } // provision is responsible for running the provisioning @@ -78,44 +79,49 @@ type provision struct { } // run provisioning procedure for boards with crypto-chip. -func (p provision) run() error { - bin, err := downloadProvisioningFile(p.board.fqbn) +func (p provision) run(ctx context.Context) error { + bin, err := downloadProvisioningFile(ctx, p.board.fqbn) if err != nil { return err } - logrus.Infof("%s\n", "Uploading provisioning sketch on the board") - time.Sleep(500 * time.Millisecond) // Try to upload the provisioning sketch + logrus.Infof("%s\n", "Uploading provisioning sketch on the board") + if err = sleepCtx(ctx, 500*time.Millisecond); err != nil { + return err + } errMsg := "Error while uploading the provisioning sketch" - err = retry(5, time.Millisecond*1000, errMsg, func() error { + err = retry(ctx, 5, time.Millisecond*1000, errMsg, func() error { //serialutils.Reset(dev.port, true, nil) - return p.UploadBin(p.board.fqbn, bin, p.board.address, p.board.protocol) + return p.UploadBin(ctx, p.board.fqbn, bin, p.board.address, p.board.protocol) }) if err != nil { return err } - logrus.Infof("%s\n", "Connecting to the board through serial port") // Try to connect to board through the serial port - time.Sleep(1500 * time.Millisecond) + logrus.Infof("%s\n", "Connecting to the board through serial port") + if err = sleepCtx(ctx, 1500*time.Millisecond); err != nil { + return err + } p.ser = serial.NewSerial() errMsg = "Error while connecting to the board" - err = retry(5, time.Millisecond*1000, errMsg, func() error { + err = retry(ctx, 5, time.Millisecond*1000, errMsg, func() error { return p.ser.Connect(p.board.address) }) if err != nil { return err } defer p.ser.Close() + logrus.Infof("%s\n\n", "Connected to the board") // Wait some time before using the serial port - time.Sleep(2000 * time.Millisecond) - logrus.Infof("%s\n\n", "Connected to the board") + if err = sleepCtx(ctx, 2000*time.Millisecond); err != nil { + return err + } // Send configuration commands to the board - err = p.configBoard() - if err != nil { + if err = p.configBoard(ctx); err != nil { return err } @@ -123,109 +129,106 @@ func (p provision) run() error { return nil } -func (p provision) configBoard() error { +func (p provision) configBoard(ctx context.Context) error { logrus.Info("Receiving the certificate") - csr, err := p.ser.SendReceive(serial.CSR, []byte(p.id)) + csr, err := p.ser.SendReceive(ctx, serial.CSR, []byte(p.id)) if err != nil { return err } - cert, err := p.cert.CertificateCreate(p.id, string(csr)) + cert, err := p.cert.CertificateCreate(ctx, p.id, string(csr)) if err != nil { return err } logrus.Info("Requesting begin storage") - err = p.ser.Send(serial.BeginStorage, nil) - if err != nil { + if err = p.ser.Send(ctx, serial.BeginStorage, nil); err != nil { return err } s := strconv.Itoa(cert.NotBefore.Year()) logrus.Info("Sending year: ", s) - err = p.ser.Send(serial.SetYear, []byte(s)) - if err != nil { + if err = p.ser.Send(ctx, serial.SetYear, []byte(s)); err != nil { return err } s = fmt.Sprintf("%02d", int(cert.NotBefore.Month())) logrus.Info("Sending month: ", s) - err = p.ser.Send(serial.SetMonth, []byte(s)) - if err != nil { + if err = p.ser.Send(ctx, serial.SetMonth, []byte(s)); err != nil { return err } s = fmt.Sprintf("%02d", cert.NotBefore.Day()) logrus.Info("Sending day: ", s) - err = p.ser.Send(serial.SetDay, []byte(s)) - if err != nil { + if err = p.ser.Send(ctx, serial.SetDay, []byte(s)); err != nil { return err } s = fmt.Sprintf("%02d", cert.NotBefore.Hour()) logrus.Info("Sending hour: ", s) - err = p.ser.Send(serial.SetHour, []byte(s)) - if err != nil { + if err = p.ser.Send(ctx, serial.SetHour, []byte(s)); err != nil { return err } s = strconv.Itoa(31) logrus.Info("Sending validity: ", s) - err = p.ser.Send(serial.SetValidity, []byte(s)) - if err != nil { + if err = p.ser.Send(ctx, serial.SetValidity, []byte(s)); err != nil { return err } logrus.Info("Sending certificate serial") b, err := hex.DecodeString(cert.Serial) if err != nil { - err = fmt.Errorf("%s: %w", "decoding certificate serial", err) - return err + return fmt.Errorf("decoding certificate serial: %w", err) } - err = p.ser.Send(serial.SetCertSerial, b) - if err != nil { + if err = p.ser.Send(ctx, serial.SetCertSerial, b); err != nil { return err } logrus.Info("Sending certificate authority key") b, err = hex.DecodeString(cert.AuthorityKeyIdentifier) if err != nil { - err = fmt.Errorf("%s: %w", "decoding certificate authority key id", err) - return err + return fmt.Errorf("decoding certificate authority key id: %w", err) } - err = p.ser.Send(serial.SetAuthKey, b) - if err != nil { + if err = p.ser.Send(ctx, serial.SetAuthKey, b); err != nil { return err } logrus.Info("Sending certificate signature") b, err = hex.DecodeString(cert.SignatureAsn1X + cert.SignatureAsn1Y) if err != nil { - err = fmt.Errorf("%s: %w", "decoding certificate signature", err) + err = fmt.Errorf("decoding certificate signature: %w", err) return err } - err = p.ser.Send(serial.SetSignature, b) - if err != nil { + if err = p.ser.Send(ctx, serial.SetSignature, b); err != nil { + return err + } + + if err := sleepCtx(ctx, 1*time.Second); err != nil { return err } - time.Sleep(time.Second) logrus.Info("Requesting end storage") - err = p.ser.Send(serial.EndStorage, nil) - if err != nil { + if err = p.ser.Send(ctx, serial.EndStorage, nil); err != nil { + return err + } + + if err := sleepCtx(ctx, 2*time.Second); err != nil { return err } - time.Sleep(2 * time.Second) logrus.Info("Requesting certificate reconstruction") - err = p.ser.Send(serial.ReconstructCert, nil) - if err != nil { + if err = p.ser.Send(ctx, serial.ReconstructCert, nil); err != nil { return err } return nil } -func retry(tries int, sleep time.Duration, errMsg string, fun func() error) error { +func retry(ctx context.Context, tries int, sleep time.Duration, errMsg string, fun func() error) error { + if err := ctx.Err(); err != nil { + return err + } + var err error for n := 0; n < tries; n++ { err = fun() @@ -233,7 +236,18 @@ func retry(tries int, sleep time.Duration, errMsg string, fun func() error) erro break } logrus.Warningf("%s: %s: %s", errMsg, err.Error(), "\nTrying again...") - time.Sleep(sleep) + if err := sleepCtx(ctx, sleep); err != nil { + return err + } } return err } + +func sleepCtx(ctx context.Context, tm time.Duration) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(tm): + return nil + } +} diff --git a/command/ota/massupload.go b/command/ota/massupload.go index 5f316f20..3cef9909 100644 --- a/command/ota/massupload.go +++ b/command/ota/massupload.go @@ -18,6 +18,7 @@ package ota import ( + "context" "errors" "fmt" "io/ioutil" @@ -51,7 +52,7 @@ type Result struct { // MassUpload command is used to mass upload a firmware OTA, // on devices of Arduino IoT Cloud. -func MassUpload(params *MassUploadParams, cred *config.Credentials) ([]Result, error) { +func MassUpload(ctx context.Context, params *MassUploadParams, cred *config.Credentials) ([]Result, error) { if params.DeviceIDs == nil && params.Tags == nil { return nil, errors.New("provide either DeviceIDs or Tags") } else if params.DeviceIDs != nil && params.Tags != nil { @@ -77,12 +78,12 @@ func MassUpload(params *MassUploadParams, cred *config.Credentials) ([]Result, e } // Prepare the list of device-ids to update - d, err := idsGivenTags(iotClient, params.Tags) + d, err := idsGivenTags(ctx, iotClient, params.Tags) if err != nil { return nil, err } d = append(params.DeviceIDs, d...) - valid, invalid, err := validateDevices(iotClient, d, params.FQBN) + valid, invalid, err := validateDevices(ctx, iotClient, d, params.FQBN) if err != nil { return nil, fmt.Errorf("failed to validate devices: %w", err) } @@ -95,20 +96,20 @@ func MassUpload(params *MassUploadParams, cred *config.Credentials) ([]Result, e expiration = otaDeferredExpirationMins } - res := run(iotClient, valid, otaFile, expiration) + res := run(ctx, iotClient, valid, otaFile, expiration) res = append(res, invalid...) return res, nil } type deviceLister interface { - DeviceList(tags map[string]string) ([]iotclient.ArduinoDevicev2, error) + DeviceList(ctx context.Context, tags map[string]string) ([]iotclient.ArduinoDevicev2, error) } -func idsGivenTags(lister deviceLister, tags map[string]string) ([]string, error) { +func idsGivenTags(ctx context.Context, lister deviceLister, tags map[string]string) ([]string, error) { if tags == nil { return nil, nil } - devs, err := lister.DeviceList(tags) + devs, err := lister.DeviceList(ctx, tags) if err != nil { return nil, fmt.Errorf("%s: %w", "cannot retrieve devices from cloud", err) } @@ -119,8 +120,8 @@ func idsGivenTags(lister deviceLister, tags map[string]string) ([]string, error) return devices, nil } -func validateDevices(lister deviceLister, ids []string, fqbn string) (valid []string, invalid []Result, err error) { - devs, err := lister.DeviceList(nil) +func validateDevices(ctx context.Context, lister deviceLister, ids []string, fqbn string) (valid []string, invalid []Result, err error) { + devs, err := lister.DeviceList(ctx, nil) if err != nil { return nil, nil, fmt.Errorf("%s: %w", "cannot retrieve devices from cloud", err) } @@ -151,10 +152,10 @@ func validateDevices(lister deviceLister, ids []string, fqbn string) (valid []st } type otaUploader interface { - DeviceOTA(id string, file *os.File, expireMins int) error + DeviceOTA(ctx context.Context, id string, file *os.File, expireMins int) error } -func run(uploader otaUploader, ids []string, otaFile string, expiration int) []Result { +func run(ctx context.Context, uploader otaUploader, ids []string, otaFile string, expiration int) []Result { type job struct { id string file *os.File @@ -179,7 +180,7 @@ func run(uploader otaUploader, ids []string, otaFile string, expiration int) []R for i := 0; i < numConcurrentUploads; i++ { go func() { for job := range jobs { - err := uploader.DeviceOTA(job.id, job.file, expiration) + err := uploader.DeviceOTA(ctx, job.id, job.file, expiration) resCh <- Result{ID: job.id, Err: err} } }() diff --git a/command/ota/massupload_test.go b/command/ota/massupload_test.go index fc83187e..f3cec3c3 100644 --- a/command/ota/massupload_test.go +++ b/command/ota/massupload_test.go @@ -1,6 +1,7 @@ package ota import ( + "context" "errors" "os" "strings" @@ -12,11 +13,11 @@ import ( const testFilename = "testdata/empty.bin" type deviceUploaderTest struct { - deviceOTA func(id string, file *os.File, expireMins int) error + deviceOTA func(ctx context.Context, id string, file *os.File, expireMins int) error } -func (d *deviceUploaderTest) DeviceOTA(id string, file *os.File, expireMins int) error { - return d.deviceOTA(id, file, expireMins) +func (d *deviceUploaderTest) DeviceOTA(ctx context.Context, id string, file *os.File, expireMins int) error { + return d.deviceOTA(ctx, id, file, expireMins) } func TestRun(t *testing.T) { @@ -30,7 +31,7 @@ func TestRun(t *testing.T) { okID3 = okPrefix + "-dac4-4a6a-80a4-698062fe2af5" ) mockClient := &deviceUploaderTest{ - deviceOTA: func(id string, file *os.File, expireMins int) error { + deviceOTA: func(ctx context.Context, id string, file *os.File, expireMins int) error { if strings.Split(id, "-")[0] == failPrefix { return errors.New("err") } @@ -39,7 +40,7 @@ func TestRun(t *testing.T) { } devs := []string{okID1, failID1, okID2, failID2, okID3} - res := run(mockClient, devs, testFilename, 0) + res := run(context.TODO(), mockClient, devs, testFilename, 0) if len(res) != len(devs) { t.Errorf("expected %d results, got %d", len(devs), len(res)) } @@ -59,7 +60,7 @@ type deviceListerTest struct { list []iotclient.ArduinoDevicev2 } -func (d *deviceListerTest) DeviceList(tags map[string]string) ([]iotclient.ArduinoDevicev2, error) { +func (d *deviceListerTest) DeviceList(ctx context.Context, tags map[string]string) ([]iotclient.ArduinoDevicev2, error) { return d.list, nil } @@ -88,7 +89,7 @@ func TestValidateDevices(t *testing.T) { idCorrect2, idNotValid, } - v, i, err := validateDevices(&mockDeviceList, ids, correctFQBN) + v, i, err := validateDevices(context.TODO(), &mockDeviceList, ids, correctFQBN) if err != nil { t.Errorf("unexpected error: %s", err.Error()) } diff --git a/command/ota/upload.go b/command/ota/upload.go index 7595ad3a..a98b6e44 100644 --- a/command/ota/upload.go +++ b/command/ota/upload.go @@ -18,6 +18,7 @@ package ota import ( + "context" "fmt" "io/ioutil" "os" @@ -44,13 +45,13 @@ type UploadParams struct { // Upload command is used to upload a firmware OTA, // on a device of Arduino IoT Cloud. -func Upload(params *UploadParams, cred *config.Credentials) error { +func Upload(ctx context.Context, params *UploadParams, cred *config.Credentials) error { iotClient, err := iot.NewClient(cred) if err != nil { return err } - dev, err := iotClient.DeviceShow(params.DeviceID) + dev, err := iotClient.DeviceShow(ctx, params.DeviceID) if err != nil { return err } @@ -78,7 +79,7 @@ func Upload(params *UploadParams, cred *config.Credentials) error { expiration = otaDeferredExpirationMins } - err = iotClient.DeviceOTA(params.DeviceID, file, expiration) + err = iotClient.DeviceOTA(ctx, params.DeviceID, file, expiration) if err != nil { return err } diff --git a/command/tag/create.go b/command/tag/create.go index fc6b9025..7bd8ff1a 100644 --- a/command/tag/create.go +++ b/command/tag/create.go @@ -18,6 +18,7 @@ package tag import ( + "context" "errors" "github.com/arduino/arduino-cloud-cli/config" @@ -34,7 +35,7 @@ type CreateTagsParams struct { // CreateTags allows to create or overwrite tags // on a resource of Arduino IoT Cloud. -func CreateTags(params *CreateTagsParams, cred *config.Credentials) error { +func CreateTags(ctx context.Context, params *CreateTagsParams, cred *config.Credentials) error { iotClient, err := iot.NewClient(cred) if err != nil { return err @@ -42,9 +43,9 @@ func CreateTags(params *CreateTagsParams, cred *config.Credentials) error { switch params.Resource { case Thing: - err = iotClient.ThingTagsCreate(params.ID, params.Tags) + err = iotClient.ThingTagsCreate(ctx, params.ID, params.Tags) case Device: - err = iotClient.DeviceTagsCreate(params.ID, params.Tags) + err = iotClient.DeviceTagsCreate(ctx, params.ID, params.Tags) default: err = errors.New("passed Resource parameter is not valid") } diff --git a/command/tag/delete.go b/command/tag/delete.go index 779db943..ee02854c 100644 --- a/command/tag/delete.go +++ b/command/tag/delete.go @@ -18,6 +18,7 @@ package tag import ( + "context" "errors" "github.com/arduino/arduino-cloud-cli/config" @@ -34,7 +35,7 @@ type DeleteTagsParams struct { // DeleteTags command is used to delete tags of a device // from Arduino IoT Cloud. -func DeleteTags(params *DeleteTagsParams, cred *config.Credentials) error { +func DeleteTags(ctx context.Context, params *DeleteTagsParams, cred *config.Credentials) error { iotClient, err := iot.NewClient(cred) if err != nil { return err @@ -42,9 +43,9 @@ func DeleteTags(params *DeleteTagsParams, cred *config.Credentials) error { switch params.Resource { case Thing: - err = iotClient.ThingTagsDelete(params.ID, params.Keys) + err = iotClient.ThingTagsDelete(ctx, params.ID, params.Keys) case Device: - err = iotClient.DeviceTagsDelete(params.ID, params.Keys) + err = iotClient.DeviceTagsDelete(ctx, params.ID, params.Keys) default: err = errors.New("passed Resource parameter is not valid") } diff --git a/command/thing/bind.go b/command/thing/bind.go index 343c6564..c21f0961 100644 --- a/command/thing/bind.go +++ b/command/thing/bind.go @@ -18,7 +18,10 @@ package thing import ( + "context" + "github.com/arduino/arduino-cloud-cli/config" + "github.com/arduino/arduino-cloud-cli/internal/iot" iotclient "github.com/arduino/iot-client-go" ) @@ -32,7 +35,7 @@ type BindParams struct { // Bind command is used to bind a thing to a device // on Arduino IoT Cloud. -func Bind(params *BindParams, cred *config.Credentials) error { +func Bind(ctx context.Context, params *BindParams, cred *config.Credentials) error { iotClient, err := iot.NewClient(cred) if err != nil { return err @@ -42,7 +45,7 @@ func Bind(params *BindParams, cred *config.Credentials) error { DeviceId: params.DeviceID, } - err = iotClient.ThingUpdate(params.ID, thing, true) + err = iotClient.ThingUpdate(ctx, params.ID, thing, true) if err != nil { return err } diff --git a/command/thing/clone.go b/command/thing/clone.go index c960b3b3..b22e6ae9 100644 --- a/command/thing/clone.go +++ b/command/thing/clone.go @@ -18,6 +18,7 @@ package thing import ( + "context" "fmt" "github.com/arduino/arduino-cloud-cli/config" @@ -32,20 +33,20 @@ type CloneParams struct { } // Clone allows to create a new thing from an already existing one. -func Clone(params *CloneParams, cred *config.Credentials) (*ThingInfo, error) { +func Clone(ctx context.Context, params *CloneParams, cred *config.Credentials) (*ThingInfo, error) { iotClient, err := iot.NewClient(cred) if err != nil { return nil, err } - thing, err := retrieve(iotClient, params.CloneID) + thing, err := retrieve(ctx, iotClient, params.CloneID) if err != nil { return nil, err } thing.Name = params.Name force := true - newThing, err := iotClient.ThingCreate(thing, force) + newThing, err := iotClient.ThingCreate(ctx, thing, force) if err != nil { return nil, err } @@ -58,11 +59,11 @@ func Clone(params *CloneParams, cred *config.Credentials) (*ThingInfo, error) { } type thingFetcher interface { - ThingShow(id string) (*iotclient.ArduinoThing, error) + ThingShow(ctx context.Context, id string) (*iotclient.ArduinoThing, error) } -func retrieve(fetcher thingFetcher, thingID string) (*iotclient.ThingCreate, error) { - clone, err := fetcher.ThingShow(thingID) +func retrieve(ctx context.Context, fetcher thingFetcher, thingID string) (*iotclient.ThingCreate, error) { + clone, err := fetcher.ThingShow(ctx, thingID) if err != nil { return nil, fmt.Errorf("%s: %w", "retrieving the thing to be cloned", err) } diff --git a/command/thing/create.go b/command/thing/create.go index 88588618..a0e447be 100644 --- a/command/thing/create.go +++ b/command/thing/create.go @@ -18,6 +18,7 @@ package thing import ( + "context" "errors" "fmt" @@ -33,7 +34,7 @@ type CreateParams struct { } // Create allows to create a new thing. -func Create(params *CreateParams, cred *config.Credentials) (*ThingInfo, error) { +func Create(ctx context.Context, params *CreateParams, cred *config.Credentials) (*ThingInfo, error) { iotClient, err := iot.NewClient(cred) if err != nil { return nil, err @@ -54,7 +55,7 @@ func Create(params *CreateParams, cred *config.Credentials) (*ThingInfo, error) } force := true - newThing, err := iotClient.ThingCreate(thing, force) + newThing, err := iotClient.ThingCreate(ctx, thing, force) if err != nil { return nil, err } diff --git a/command/thing/delete.go b/command/thing/delete.go index f71a3e12..91af09a6 100644 --- a/command/thing/delete.go +++ b/command/thing/delete.go @@ -18,6 +18,7 @@ package thing import ( + "context" "errors" "github.com/arduino/arduino-cloud-cli/config" @@ -36,7 +37,7 @@ type DeleteParams struct { // Delete command is used to delete a thing // from Arduino IoT Cloud. -func Delete(params *DeleteParams, cred *config.Credentials) error { +func Delete(ctx context.Context, params *DeleteParams, cred *config.Credentials) error { if params.ID == nil && params.Tags == nil { return errors.New("provide either ID or Tags") } else if params.ID != nil && params.Tags != nil { @@ -53,7 +54,7 @@ func Delete(params *DeleteParams, cred *config.Credentials) error { thingIDs = append(thingIDs, *params.ID) } if params.Tags != nil { - th, err := iotClient.ThingList(nil, nil, false, params.Tags) + th, err := iotClient.ThingList(ctx, nil, nil, false, params.Tags) if err != nil { return err } @@ -63,7 +64,7 @@ func Delete(params *DeleteParams, cred *config.Credentials) error { } for _, id := range thingIDs { - err = iotClient.ThingDelete(id) + err = iotClient.ThingDelete(ctx, id) if err != nil { return err } diff --git a/command/thing/extract.go b/command/thing/extract.go index 5541d196..dc4b84d8 100644 --- a/command/thing/extract.go +++ b/command/thing/extract.go @@ -18,6 +18,7 @@ package thing import ( + "context" "fmt" "github.com/arduino/arduino-cloud-cli/config" @@ -33,13 +34,13 @@ type ExtractParams struct { // Extract command is used to extract a thing template // from a thing on Arduino IoT Cloud. -func Extract(params *ExtractParams, cred *config.Credentials) (map[string]interface{}, error) { +func Extract(ctx context.Context, params *ExtractParams, cred *config.Credentials) (map[string]interface{}, error) { iotClient, err := iot.NewClient(cred) if err != nil { return nil, err } - thing, err := iotClient.ThingShow(params.ID) + thing, err := iotClient.ThingShow(ctx, params.ID) if err != nil { err = fmt.Errorf("%s: %w", "cannot extract thing: ", err) return nil, err diff --git a/command/thing/list.go b/command/thing/list.go index 1d88f6d6..b6c3ae43 100644 --- a/command/thing/list.go +++ b/command/thing/list.go @@ -18,6 +18,7 @@ package thing import ( + "context" "fmt" "github.com/arduino/arduino-cloud-cli/config" @@ -35,13 +36,13 @@ type ListParams struct { // List command is used to list // the things of Arduino IoT Cloud. -func List(params *ListParams, cred *config.Credentials) ([]ThingInfo, error) { +func List(ctx context.Context, params *ListParams, cred *config.Credentials) ([]ThingInfo, error) { iotClient, err := iot.NewClient(cred) if err != nil { return nil, err } - foundThings, err := iotClient.ThingList(params.IDs, params.DeviceID, params.Variables, params.Tags) + foundThings, err := iotClient.ThingList(ctx, params.IDs, params.DeviceID, params.Variables, params.Tags) if err != nil { return nil, err } diff --git a/go.mod b/go.mod index cf461529..e6cc7a8c 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/spf13/viper v1.10.1 github.com/stretchr/testify v1.8.0 github.com/xanzy/ssh-agent v0.3.1 // indirect + go.bug.st/cleanup v1.0.0 go.bug.st/serial v1.3.3 go.bug.st/serial.v1 v0.0.0-20191202182710-24a6610f0541 // indirect golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 diff --git a/internal/binary/download.go b/internal/binary/download.go index a00f4ef3..8dc0cfa9 100644 --- a/internal/binary/download.go +++ b/internal/binary/download.go @@ -19,6 +19,7 @@ package binary import ( "bytes" + "context" "crypto" "encoding/hex" "errors" @@ -32,8 +33,8 @@ import ( ) // Download a binary file contained in the binary index. -func Download(bin *IndexBin) ([]byte, error) { - b, err := download(bin.URL) +func Download(ctx context.Context, bin *IndexBin) ([]byte, error) { + b, err := download(ctx, bin.URL) if err != nil { return nil, fmt.Errorf("cannot download binary at %s: %w", bin.URL, err) } @@ -54,11 +55,11 @@ func Download(bin *IndexBin) ([]byte, error) { return b, nil } -func download(url string) ([]byte, error) { +func download(ctx context.Context, url string) ([]byte, error) { cl := http.Client{ Timeout: time.Second * 3, } - req, err := http.NewRequest(http.MethodGet, url, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { err = fmt.Errorf("%s: %w", "request url", err) return nil, err diff --git a/internal/binary/index.go b/internal/binary/index.go index 0a4445eb..f3c2b543 100644 --- a/internal/binary/index.go +++ b/internal/binary/index.go @@ -19,6 +19,7 @@ package binary import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" @@ -56,8 +57,8 @@ type IndexBin struct { // LoadIndex downloads and verifies the index of binaries // contained in 'cloud-downloads'. -func LoadIndex() (*Index, error) { - indexGZ, err := download(IndexGZURL) +func LoadIndex(ctx context.Context) (*Index, error) { + indexGZ, err := download(ctx, IndexGZURL) if err != nil { return nil, fmt.Errorf("cannot download index: %w", err) } @@ -71,7 +72,7 @@ func LoadIndex() (*Index, error) { return nil, fmt.Errorf("cannot read downloaded index: %w", err) } - sig, err := download(IndexSigURL) + sig, err := download(ctx, IndexSigURL) if err != nil { return nil, fmt.Errorf("cannot download index signature: %w", err) } diff --git a/internal/iot/client.go b/internal/iot/client.go index 7145000b..b5beab33 100644 --- a/internal/iot/client.go +++ b/internal/iot/client.go @@ -25,12 +25,13 @@ import ( "github.com/antihax/optional" "github.com/arduino/arduino-cloud-cli/config" iotclient "github.com/arduino/iot-client-go" + "golang.org/x/oauth2" ) // Client can perform actions on Arduino IoT Cloud. type Client struct { - ctx context.Context - api *iotclient.APIClient + api *iotclient.APIClient + token oauth2.TokenSource } // NewClient returns a new client implementing the Client interface. @@ -47,14 +48,19 @@ func NewClient(cred *config.Credentials) (*Client, error) { // DeviceCreate allows to create a new device on Arduino IoT Cloud. // It returns the newly created device, and an error. -func (cl *Client) DeviceCreate(fqbn, name, serial, dType string) (*iotclient.ArduinoDevicev2, error) { +func (cl *Client) DeviceCreate(ctx context.Context, fqbn, name, serial, dType string) (*iotclient.ArduinoDevicev2, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + payload := iotclient.CreateDevicesV2Payload{ Fqbn: fqbn, Name: name, Serial: serial, Type: dType, } - dev, _, err := cl.api.DevicesV2Api.DevicesV2Create(cl.ctx, payload, nil) + dev, _, err := cl.api.DevicesV2Api.DevicesV2Create(ctx, payload, nil) if err != nil { err = fmt.Errorf("creating device, %w", errorDetail(err)) return nil, err @@ -64,7 +70,12 @@ func (cl *Client) DeviceCreate(fqbn, name, serial, dType string) (*iotclient.Ard // DeviceLoraCreate allows to create a new LoRa device on Arduino IoT Cloud. // It returns the LoRa information about the newly created device, and an error. -func (cl *Client) DeviceLoraCreate(name, serial, devType, eui, freq string) (*iotclient.ArduinoLoradevicev1, error) { +func (cl *Client) DeviceLoraCreate(ctx context.Context, name, serial, devType, eui, freq string) (*iotclient.ArduinoLoradevicev1, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + payload := iotclient.CreateLoraDevicesV1Payload{ App: "defaultApp", Eui: eui, @@ -74,7 +85,7 @@ func (cl *Client) DeviceLoraCreate(name, serial, devType, eui, freq string) (*io Type: devType, UserId: "me", } - dev, _, err := cl.api.LoraDevicesV1Api.LoraDevicesV1Create(cl.ctx, payload) + dev, _, err := cl.api.LoraDevicesV1Api.LoraDevicesV1Create(ctx, payload) if err != nil { err = fmt.Errorf("creating lora device: %w", errorDetail(err)) return nil, err @@ -84,17 +95,23 @@ func (cl *Client) DeviceLoraCreate(name, serial, devType, eui, freq string) (*io // DevicePassSet sets the device password to the one suggested by Arduino IoT Cloud. // Returns the set password. -func (cl *Client) DevicePassSet(id string) (*iotclient.ArduinoDevicev2Pass, error) { +func (cl *Client) DevicePassSet(ctx context.Context, id string) (*iotclient.ArduinoDevicev2Pass, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + // Fetch suggested password opts := &iotclient.DevicesV2PassGetOpts{SuggestedPassword: optional.NewBool(true)} - pass, _, err := cl.api.DevicesV2PassApi.DevicesV2PassGet(cl.ctx, id, opts) + pass, _, err := cl.api.DevicesV2PassApi.DevicesV2PassGet(ctx, id, opts) if err != nil { err = fmt.Errorf("fetching device suggested password: %w", errorDetail(err)) return nil, err } + // Set password to the suggested one p := iotclient.Devicev2Pass{Password: pass.SuggestedPassword} - pass, _, err = cl.api.DevicesV2PassApi.DevicesV2PassSet(cl.ctx, id, p) + pass, _, err = cl.api.DevicesV2PassApi.DevicesV2PassSet(ctx, id, p) if err != nil { err = fmt.Errorf("setting device password: %w", errorDetail(err)) return nil, err @@ -104,8 +121,13 @@ func (cl *Client) DevicePassSet(id string) (*iotclient.ArduinoDevicev2Pass, erro // DeviceDelete deletes the device corresponding to the passed ID // from Arduino IoT Cloud. -func (cl *Client) DeviceDelete(id string) error { - _, err := cl.api.DevicesV2Api.DevicesV2Delete(cl.ctx, id, nil) +func (cl *Client) DeviceDelete(ctx context.Context, id string) error { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return err + } + + _, err = cl.api.DevicesV2Api.DevicesV2Delete(ctx, id, nil) if err != nil { err = fmt.Errorf("deleting device: %w", errorDetail(err)) return err @@ -115,7 +137,12 @@ func (cl *Client) DeviceDelete(id string) error { // DeviceList retrieves and returns a list of all Arduino IoT Cloud devices // belonging to the user performing the request. -func (cl *Client) DeviceList(tags map[string]string) ([]iotclient.ArduinoDevicev2, error) { +func (cl *Client) DeviceList(ctx context.Context, tags map[string]string) ([]iotclient.ArduinoDevicev2, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + opts := &iotclient.DevicesV2ListOpts{} if tags != nil { t := make([]string, 0, len(tags)) @@ -126,7 +153,7 @@ func (cl *Client) DeviceList(tags map[string]string) ([]iotclient.ArduinoDevicev opts.Tags = optional.NewInterface(t) } - devices, _, err := cl.api.DevicesV2Api.DevicesV2List(cl.ctx, opts) + devices, _, err := cl.api.DevicesV2Api.DevicesV2List(ctx, opts) if err != nil { err = fmt.Errorf("listing devices: %w", errorDetail(err)) return nil, err @@ -136,8 +163,13 @@ func (cl *Client) DeviceList(tags map[string]string) ([]iotclient.ArduinoDevicev // DeviceShow allows to retrieve a specific device, given its id, // from Arduino IoT Cloud. -func (cl *Client) DeviceShow(id string) (*iotclient.ArduinoDevicev2, error) { - dev, _, err := cl.api.DevicesV2Api.DevicesV2Show(cl.ctx, id, nil) +func (cl *Client) DeviceShow(ctx context.Context, id string) (*iotclient.ArduinoDevicev2, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + + dev, _, err := cl.api.DevicesV2Api.DevicesV2Show(ctx, id, nil) if err != nil { err = fmt.Errorf("retrieving device, %w", errorDetail(err)) return nil, err @@ -147,12 +179,17 @@ func (cl *Client) DeviceShow(id string) (*iotclient.ArduinoDevicev2, error) { // DeviceOTA performs an OTA upload request to Arduino IoT Cloud, passing // the ID of the device to be updated and the actual file containing the OTA firmware. -func (cl *Client) DeviceOTA(id string, file *os.File, expireMins int) error { +func (cl *Client) DeviceOTA(ctx context.Context, id string, file *os.File, expireMins int) error { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return err + } + opt := &iotclient.DevicesV2OtaUploadOpts{ ExpireInMins: optional.NewInt32(int32(expireMins)), Async: optional.NewBool(true), } - _, err := cl.api.DevicesV2OtaApi.DevicesV2OtaUpload(cl.ctx, id, file, opt) + _, err = cl.api.DevicesV2OtaApi.DevicesV2OtaUpload(ctx, id, file, opt) if err != nil { err = fmt.Errorf("uploading device ota: %w", errorDetail(err)) return err @@ -161,10 +198,15 @@ func (cl *Client) DeviceOTA(id string, file *os.File, expireMins int) error { } // DeviceTagsCreate allows to create or overwrite tags on a device of Arduino IoT Cloud. -func (cl *Client) DeviceTagsCreate(id string, tags map[string]string) error { +func (cl *Client) DeviceTagsCreate(ctx context.Context, id string, tags map[string]string) error { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return err + } + for key, val := range tags { t := iotclient.Tag{Key: key, Value: val} - _, err := cl.api.DevicesV2TagsApi.DevicesV2TagsUpsert(cl.ctx, id, t) + _, err := cl.api.DevicesV2TagsApi.DevicesV2TagsUpsert(ctx, id, t) if err != nil { err = fmt.Errorf("cannot create tag %s: %w", key, errorDetail(err)) return err @@ -175,9 +217,14 @@ func (cl *Client) DeviceTagsCreate(id string, tags map[string]string) error { // DeviceTagsDelete deletes the tags of a device of Arduino IoT Cloud, // given the device id and the keys of the tags. -func (cl *Client) DeviceTagsDelete(id string, keys []string) error { +func (cl *Client) DeviceTagsDelete(ctx context.Context, id string, keys []string) error { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return err + } + for _, key := range keys { - _, err := cl.api.DevicesV2TagsApi.DevicesV2TagsDelete(cl.ctx, id, key) + _, err := cl.api.DevicesV2TagsApi.DevicesV2TagsDelete(ctx, id, key) if err != nil { err = fmt.Errorf("cannot delete tag %s: %w", key, errorDetail(err)) return err @@ -188,8 +235,13 @@ func (cl *Client) DeviceTagsDelete(id string, keys []string) error { // LoraFrequencyPlansList retrieves and returns the list of all supported // LoRa frequency plans. -func (cl *Client) LoraFrequencyPlansList() ([]iotclient.ArduinoLorafreqplanv1, error) { - freqs, _, err := cl.api.LoraFreqPlanV1Api.LoraFreqPlanV1List(cl.ctx) +func (cl *Client) LoraFrequencyPlansList(ctx context.Context) ([]iotclient.ArduinoLorafreqplanv1, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + + freqs, _, err := cl.api.LoraFreqPlanV1Api.LoraFreqPlanV1List(ctx) if err != nil { err = fmt.Errorf("listing lora frequency plans: %w", errorDetail(err)) return nil, err @@ -199,14 +251,19 @@ func (cl *Client) LoraFrequencyPlansList() ([]iotclient.ArduinoLorafreqplanv1, e // CertificateCreate allows to upload a certificate on Arduino IoT Cloud. // It returns the certificate parameters populated by the cloud. -func (cl *Client) CertificateCreate(id, csr string) (*iotclient.ArduinoCompressedv2, error) { +func (cl *Client) CertificateCreate(ctx context.Context, id, csr string) (*iotclient.ArduinoCompressedv2, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + cert := iotclient.CreateDevicesV2CertsPayload{ Ca: "Arduino", Csr: csr, Enabled: true, } - newCert, _, err := cl.api.DevicesV2CertsApi.DevicesV2CertsCreate(cl.ctx, id, cert) + newCert, _, err := cl.api.DevicesV2CertsApi.DevicesV2CertsCreate(ctx, id, cert) if err != nil { err = fmt.Errorf("creating certificate, %w", errorDetail(err)) return nil, err @@ -216,9 +273,14 @@ func (cl *Client) CertificateCreate(id, csr string) (*iotclient.ArduinoCompresse } // ThingCreate adds a new thing on Arduino IoT Cloud. -func (cl *Client) ThingCreate(thing *iotclient.ThingCreate, force bool) (*iotclient.ArduinoThing, error) { +func (cl *Client) ThingCreate(ctx context.Context, thing *iotclient.ThingCreate, force bool) (*iotclient.ArduinoThing, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + opt := &iotclient.ThingsV2CreateOpts{Force: optional.NewBool(force)} - newThing, _, err := cl.api.ThingsV2Api.ThingsV2Create(cl.ctx, *thing, opt) + newThing, _, err := cl.api.ThingsV2Api.ThingsV2Create(ctx, *thing, opt) if err != nil { return nil, fmt.Errorf("%s: %w", "adding new thing", errorDetail(err)) } @@ -226,9 +288,14 @@ func (cl *Client) ThingCreate(thing *iotclient.ThingCreate, force bool) (*iotcli } // ThingUpdate updates a thing on Arduino IoT Cloud. -func (cl *Client) ThingUpdate(id string, thing *iotclient.ThingUpdate, force bool) error { +func (cl *Client) ThingUpdate(ctx context.Context, id string, thing *iotclient.ThingUpdate, force bool) error { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return err + } + opt := &iotclient.ThingsV2UpdateOpts{Force: optional.NewBool(force)} - _, _, err := cl.api.ThingsV2Api.ThingsV2Update(cl.ctx, id, *thing, opt) + _, _, err = cl.api.ThingsV2Api.ThingsV2Update(ctx, id, *thing, opt) if err != nil { return fmt.Errorf("%s: %v", "updating thing", errorDetail(err)) } @@ -236,8 +303,13 @@ func (cl *Client) ThingUpdate(id string, thing *iotclient.ThingUpdate, force boo } // ThingDelete deletes a thing from Arduino IoT Cloud. -func (cl *Client) ThingDelete(id string) error { - _, err := cl.api.ThingsV2Api.ThingsV2Delete(cl.ctx, id, nil) +func (cl *Client) ThingDelete(ctx context.Context, id string) error { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return err + } + + _, err = cl.api.ThingsV2Api.ThingsV2Delete(ctx, id, nil) if err != nil { err = fmt.Errorf("deleting thing: %w", errorDetail(err)) return err @@ -247,8 +319,13 @@ func (cl *Client) ThingDelete(id string) error { // ThingShow allows to retrieve a specific thing, given its id, // from Arduino IoT Cloud. -func (cl *Client) ThingShow(id string) (*iotclient.ArduinoThing, error) { - thing, _, err := cl.api.ThingsV2Api.ThingsV2Show(cl.ctx, id, nil) +func (cl *Client) ThingShow(ctx context.Context, id string) (*iotclient.ArduinoThing, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + + thing, _, err := cl.api.ThingsV2Api.ThingsV2Show(ctx, id, nil) if err != nil { err = fmt.Errorf("retrieving thing, %w", errorDetail(err)) return nil, err @@ -257,7 +334,12 @@ func (cl *Client) ThingShow(id string) (*iotclient.ArduinoThing, error) { } // ThingList returns a list of things on Arduino IoT Cloud. -func (cl *Client) ThingList(ids []string, device *string, props bool, tags map[string]string) ([]iotclient.ArduinoThing, error) { +func (cl *Client) ThingList(ctx context.Context, ids []string, device *string, props bool, tags map[string]string) ([]iotclient.ArduinoThing, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + opts := &iotclient.ThingsV2ListOpts{} opts.ShowProperties = optional.NewBool(props) @@ -278,7 +360,7 @@ func (cl *Client) ThingList(ids []string, device *string, props bool, tags map[s opts.Tags = optional.NewInterface(t) } - things, _, err := cl.api.ThingsV2Api.ThingsV2List(cl.ctx, opts) + things, _, err := cl.api.ThingsV2Api.ThingsV2List(ctx, opts) if err != nil { err = fmt.Errorf("retrieving things, %w", errorDetail(err)) return nil, err @@ -287,10 +369,15 @@ func (cl *Client) ThingList(ids []string, device *string, props bool, tags map[s } // ThingTagsCreate allows to create or overwrite tags on a thing of Arduino IoT Cloud. -func (cl *Client) ThingTagsCreate(id string, tags map[string]string) error { +func (cl *Client) ThingTagsCreate(ctx context.Context, id string, tags map[string]string) error { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return err + } + for key, val := range tags { t := iotclient.Tag{Key: key, Value: val} - _, err := cl.api.ThingsV2TagsApi.ThingsV2TagsUpsert(cl.ctx, id, t) + _, err := cl.api.ThingsV2TagsApi.ThingsV2TagsUpsert(ctx, id, t) if err != nil { err = fmt.Errorf("cannot create tag %s: %w", key, errorDetail(err)) return err @@ -301,9 +388,14 @@ func (cl *Client) ThingTagsCreate(id string, tags map[string]string) error { // ThingTagsDelete deletes the tags of a thing of Arduino IoT Cloud, // given the thing id and the keys of the tags. -func (cl *Client) ThingTagsDelete(id string, keys []string) error { +func (cl *Client) ThingTagsDelete(ctx context.Context, id string, keys []string) error { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return err + } + for _, key := range keys { - _, err := cl.api.ThingsV2TagsApi.ThingsV2TagsDelete(cl.ctx, id, key) + _, err := cl.api.ThingsV2TagsApi.ThingsV2TagsDelete(ctx, id, key) if err != nil { err = fmt.Errorf("cannot delete tag %s: %w", key, errorDetail(err)) return err @@ -313,8 +405,13 @@ func (cl *Client) ThingTagsDelete(id string, keys []string) error { } // DashboardCreate adds a new dashboard on Arduino IoT Cloud. -func (cl *Client) DashboardCreate(dashboard *iotclient.Dashboardv2) (*iotclient.ArduinoDashboardv2, error) { - newDashboard, _, err := cl.api.DashboardsV2Api.DashboardsV2Create(cl.ctx, *dashboard, nil) +func (cl *Client) DashboardCreate(ctx context.Context, dashboard *iotclient.Dashboardv2) (*iotclient.ArduinoDashboardv2, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + + newDashboard, _, err := cl.api.DashboardsV2Api.DashboardsV2Create(ctx, *dashboard, nil) if err != nil { return nil, fmt.Errorf("%s: %w", "adding new dashboard", errorDetail(err)) } @@ -323,8 +420,13 @@ func (cl *Client) DashboardCreate(dashboard *iotclient.Dashboardv2) (*iotclient. // DashboardShow allows to retrieve a specific dashboard, given its id, // from Arduino IoT Cloud. -func (cl *Client) DashboardShow(id string) (*iotclient.ArduinoDashboardv2, error) { - dashboard, _, err := cl.api.DashboardsV2Api.DashboardsV2Show(cl.ctx, id, nil) +func (cl *Client) DashboardShow(ctx context.Context, id string) (*iotclient.ArduinoDashboardv2, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + + dashboard, _, err := cl.api.DashboardsV2Api.DashboardsV2Show(ctx, id, nil) if err != nil { err = fmt.Errorf("retrieving dashboard, %w", errorDetail(err)) return nil, err @@ -333,8 +435,13 @@ func (cl *Client) DashboardShow(id string) (*iotclient.ArduinoDashboardv2, error } // DashboardList returns a list of dashboards on Arduino IoT Cloud. -func (cl *Client) DashboardList() ([]iotclient.ArduinoDashboardv2, error) { - dashboards, _, err := cl.api.DashboardsV2Api.DashboardsV2List(cl.ctx, nil) +func (cl *Client) DashboardList(ctx context.Context) ([]iotclient.ArduinoDashboardv2, error) { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return nil, err + } + + dashboards, _, err := cl.api.DashboardsV2Api.DashboardsV2List(ctx, nil) if err != nil { err = fmt.Errorf("listing dashboards: %w", errorDetail(err)) return nil, err @@ -343,8 +450,13 @@ func (cl *Client) DashboardList() ([]iotclient.ArduinoDashboardv2, error) { } // DashboardDelete deletes a dashboard from Arduino IoT Cloud. -func (cl *Client) DashboardDelete(id string) error { - _, err := cl.api.DashboardsV2Api.DashboardsV2Delete(cl.ctx, id, nil) +func (cl *Client) DashboardDelete(ctx context.Context, id string) error { + ctx, err := ctxWithToken(ctx, cl.token) + if err != nil { + return err + } + + _, err = cl.api.DashboardsV2Api.DashboardsV2Delete(ctx, id, nil) if err != nil { err = fmt.Errorf("deleting dashboard: %w", errorDetail(err)) return err @@ -358,15 +470,8 @@ func (cl *Client) setup(client, secret, organization string) error { baseURL = url } - // Get the access token in exchange of client_id and client_secret - tok, err := token(client, secret, baseURL) - if err != nil { - err = fmt.Errorf("cannot retrieve token given client and secret: %w", err) - return err - } - - // We use the token to create a context that will be passed to any API call - cl.ctx = context.WithValue(context.Background(), iotclient.ContextAccessToken, tok.AccessToken) + // Configure a token source given the user's credentials. + cl.token = token(client, secret, baseURL) config := iotclient.NewConfiguration() if organization != "" { diff --git a/internal/iot/token.go b/internal/iot/token.go index dd369c04..2f4b1ead 100644 --- a/internal/iot/token.go +++ b/internal/iot/token.go @@ -19,19 +19,21 @@ package iot import ( "context" + "errors" "fmt" "net/url" "strings" + iotclient "github.com/arduino/iot-client-go" "golang.org/x/oauth2" cc "golang.org/x/oauth2/clientcredentials" ) -func token(client, secret, baseURL string) (*oauth2.Token, error) { - // We need to pass the additional "audience" var to request an access token +func token(client, secret, baseURL string) oauth2.TokenSource { + // We need to pass the additional "audience" var to request an access token. additionalValues := url.Values{} additionalValues.Add("audience", "https://api2.arduino.cc/iot") - // Set up OAuth2 configuration + // Set up OAuth2 configuration. config := cc.Config{ ClientID: client, ClientSecret: secret, @@ -39,10 +41,19 @@ func token(client, secret, baseURL string) (*oauth2.Token, error) { EndpointParams: additionalValues, } - // Get the access token in exchange of client_id and client_secret - t, err := config.Token(context.Background()) - if err != nil && strings.Contains(err.Error(), "401") { - return nil, fmt.Errorf("wrong credentials") + // Retrieve a token source that allows to retrieve tokens + // with an automatic refresh mechanism. + return config.TokenSource(context.Background()) +} + +func ctxWithToken(ctx context.Context, src oauth2.TokenSource) (context.Context, error) { + // Retrieve a valid token from the src. + tok, err := src.Token() + if err != nil { + if strings.Contains(err.Error(), "401") { + return nil, errors.New("wrong credentials") + } + return nil, fmt.Errorf("cannot retrieve a valid token: %w", err) } - return t, err + return context.WithValue(ctx, iotclient.ContextAccessToken, tok.AccessToken), nil } diff --git a/internal/serial/serial.go b/internal/serial/serial.go index 5d3b7dbb..abcb4951 100644 --- a/internal/serial/serial.go +++ b/internal/serial/serial.go @@ -19,6 +19,7 @@ package serial import ( "bytes" + "context" "encoding/binary" "errors" "fmt" @@ -60,7 +61,11 @@ func (s *Serial) Connect(address string) error { } // Send allows to send a provisioning command to a connected arduino device. -func (s *Serial) Send(cmd Command, payload []byte) error { +func (s *Serial) Send(ctx context.Context, cmd Command, payload []byte) error { + if err := ctx.Err(); err != nil { + return err + } + payload = append([]byte{byte(cmd)}, payload...) msg := encode(Cmd, payload) @@ -76,12 +81,11 @@ func (s *Serial) Send(cmd Command, payload []byte) error { // SendReceive allows to send a provisioning command to a connected arduino device. // Then, it waits for a response from the device and, if any, returns it. // If no response is received after 2 seconds, an error is returned. -func (s *Serial) SendReceive(cmd Command, payload []byte) ([]byte, error) { - err := s.Send(cmd, payload) - if err != nil { +func (s *Serial) SendReceive(ctx context.Context, cmd Command, payload []byte) ([]byte, error) { + if err := s.Send(ctx, cmd, payload); err != nil { return nil, err } - return s.receive() + return s.receive(ctx) } // Close should be used when the Serial connection isn't used anymore. @@ -96,7 +100,7 @@ func (s *Serial) Close() error { // TODO: consider refactoring using a more explicit procedure: // start := s.Read(buff, MsgStartLength) // payloadLen := s.Read(buff, payloadFieldLen) -func (s *Serial) receive() ([]byte, error) { +func (s *Serial) receive(ctx context.Context) ([]byte, error) { buff := make([]byte, 1000) var resp []byte @@ -105,6 +109,10 @@ func (s *Serial) receive() ([]byte, error) { // Wait to receive the entire packet that is long as the preamble (from msgStart to payload length field) // plus the actual payload length plus the length of the ending sequence. for received < (payloadLenField+payloadLenFieldLen)+payloadLen+len(msgEnd) { + if err := ctx.Err(); err != nil { + return nil, err + } + n, err := s.port.Read(buff) if err != nil { err = fmt.Errorf("%s: %w", "receiving from serial", err) diff --git a/internal/serial/serial_test.go b/internal/serial/serial_test.go index b5838a10..1c049393 100644 --- a/internal/serial/serial_test.go +++ b/internal/serial/serial_test.go @@ -19,6 +19,7 @@ package serial import ( "bytes" + "context" "testing" "github.com/arduino/arduino-cloud-cli/internal/serial/mocks" @@ -45,7 +46,7 @@ func TestSendReceive(t *testing.T) { mockPort.On("Write", mock.AnythingOfType("[]uint8")).Return(0, nil) mockPort.On("Read", mock.AnythingOfType("[]uint8")).Return(mockRead, nil) - res, err := mockSerial.SendReceive(BeginStorage, []byte{1, 2}) + res, err := mockSerial.SendReceive(context.TODO(), BeginStorage, []byte{1, 2}) if err != nil { t.Error(err) } @@ -64,7 +65,7 @@ func TestSend(t *testing.T) { cmd := SetDay want := []byte{msgStart[0], msgStart[1], 1, 0, 5, 10, 1, 2, 143, 124, msgEnd[0], msgEnd[1]} - err := mockSerial.Send(cmd, payload) + err := mockSerial.Send(context.TODO(), cmd, payload) if err != nil { t.Error(err) } diff --git a/internal/template/dashboard.go b/internal/template/dashboard.go index 82cf6124..aeace762 100644 --- a/internal/template/dashboard.go +++ b/internal/template/dashboard.go @@ -18,6 +18,7 @@ package template import ( + "context" "encoding/json" "fmt" @@ -63,13 +64,13 @@ func (v *variableTemplate) MarshalJSON() ([]byte, error) { // ThingFetcher wraps the method to fetch a thing given its id. type ThingFetcher interface { - ThingShow(id string) (*iotclient.ArduinoThing, error) + ThingShow(ctx context.Context, id string) (*iotclient.ArduinoThing, error) } // getVariableID returns the id of a variable, given its name and its thing id. // If the variable is not found, an error is returned. -func getVariableID(thingID string, variableName string, fetcher ThingFetcher) (string, error) { - thing, err := fetcher.ThingShow(thingID) +func getVariableID(ctx context.Context, thingID string, variableName string, fetcher ThingFetcher) (string, error) { + thing, err := fetcher.ThingShow(ctx, thingID) if err != nil { return "", fmt.Errorf("getting variables of thing %s: %w", thingID, err) } diff --git a/internal/template/load.go b/internal/template/load.go index 9a5c21b7..eca76933 100644 --- a/internal/template/load.go +++ b/internal/template/load.go @@ -18,6 +18,7 @@ package template import ( + "context" "encoding/json" "errors" "fmt" @@ -87,7 +88,7 @@ func LoadThing(file string) (*iotclient.ThingCreate, error) { // LoadDashboard loads a dashboard from a dashboard template file. // It applies the thing overrides specified by the override parameter. // It requires a ThingFetcher to retrieve the actual variable ids. -func LoadDashboard(file string, override map[string]string, thinger ThingFetcher) (*iotclient.Dashboardv2, error) { +func LoadDashboard(ctx context.Context, file string, override map[string]string, thinger ThingFetcher) (*iotclient.Dashboardv2, error) { template := dashboardTemplate{} err := loadTemplate(file, &template) if err != nil { @@ -113,7 +114,7 @@ func LoadDashboard(file string, override map[string]string, thinger ThingFetcher if id, ok := override[variable.ThingID]; ok { variable.ThingID = id } - variable.VariableID, err = getVariableID(variable.ThingID, variable.VariableName, thinger) + variable.VariableID, err = getVariableID(ctx, variable.ThingID, variable.VariableName, thinger) if err != nil { return nil, err } diff --git a/internal/template/load_test.go b/internal/template/load_test.go index 18a5f71e..1d7cecc6 100644 --- a/internal/template/load_test.go +++ b/internal/template/load_test.go @@ -18,6 +18,7 @@ package template import ( + "context" "testing" iotclient "github.com/arduino/iot-client-go" @@ -145,7 +146,7 @@ func TestLoadTemplate(t *testing.T) { type thingShowTest struct{} -func (t *thingShowTest) ThingShow(thingID string) (*iotclient.ArduinoThing, error) { +func (t *thingShowTest) ThingShow(ctx context.Context, thingID string) (*iotclient.ArduinoThing, error) { if thingID == thingOverriddenID { return &iotclient.ArduinoThing{ Properties: []iotclient.ArduinoProperty{ @@ -215,7 +216,7 @@ func TestLoadDashboard(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := LoadDashboard(tt.file, tt.override, mockThingShow) + got, err := LoadDashboard(context.TODO(), tt.file, tt.override, mockThingShow) if err != nil { t.Errorf("%v", err) }