diff --git a/cmd/account/link_key/link_key.go b/cmd/account/link_key/link_key.go index 6be66976..88787c7b 100644 --- a/cmd/account/link_key/link_key.go +++ b/cmd/account/link_key/link_key.go @@ -388,7 +388,7 @@ func (h *handler) checkIfAlreadyLinked() (bool, error) { ownerAddr := common.HexToAddress(h.settings.Workflow.UserWorkflowSettings.WorkflowOwnerAddress) ui.Dim("Checking existing registrations...") - linked, err := h.wrc.IsOwnerLinked(ownerAddr) + linked, err := h.wrc.IsOwnerLinked(context.Background(), ownerAddr) if err != nil { return false, fmt.Errorf("failed to check owner link status: %w", err) } diff --git a/cmd/account/unlink_key/unlink_key.go b/cmd/account/unlink_key/unlink_key.go index 3c50e373..3c36e8bf 100644 --- a/cmd/account/unlink_key/unlink_key.go +++ b/cmd/account/unlink_key/unlink_key.go @@ -346,7 +346,7 @@ func (h *handler) unlinkOwner(owner string, resp initiateUnlinkingResponse) erro func (h *handler) checkIfAlreadyLinked() (bool, error) { ownerAddr := common.HexToAddress(h.settings.Workflow.UserWorkflowSettings.WorkflowOwnerAddress) - linked, err := h.wrc.IsOwnerLinked(ownerAddr) + linked, err := h.wrc.IsOwnerLinked(context.Background(), ownerAddr) if err != nil { return false, fmt.Errorf("failed to check owner link status: %w", err) } diff --git a/cmd/client/tx.go b/cmd/client/tx.go index 1d7715cb..30f2913f 100644 --- a/cmd/client/tx.go +++ b/cmd/client/tx.go @@ -1,6 +1,7 @@ package client import ( + "context" "encoding/json" "errors" "fmt" @@ -119,7 +120,7 @@ type RawTx struct { // return txOpts, nil //} -func (c *TxClient) executeTransactionByTxType(txFn func(opts *bind.TransactOpts) (*types.Transaction, error), funName string, validationEvent string, args ...any) (TxOutput, error) { +func (c *TxClient) executeTransactionByTxType(ctx context.Context, txFn func(opts *bind.TransactOpts) (*types.Transaction, error), funName string, validationEvent string, args ...any) (TxOutput, error) { switch c.config.TxType { case Regular: simulateTx, err := txFn(cmdCommon.SimTransactOpts()) @@ -138,7 +139,7 @@ func (c *TxClient) executeTransactionByTxType(txFn func(opts *bind.TransactOpts) Value: simulateTx.Value(), Data: simulateTx.Data(), } - estimatedGas, gasErr := c.EthClient.Client.EstimateGas(c.EthClient.Context, msg) + estimatedGas, gasErr := c.EthClient.Client.EstimateGas(ctx, msg) if gasErr != nil { c.Logger.Warn().Err(gasErr).Msg("Failed to estimate gas usage") } @@ -159,7 +160,7 @@ func (c *TxClient) executeTransactionByTxType(txFn func(opts *bind.TransactOpts) // Calculate and print total cost for sending the transaction on-chain if gasErr == nil { - gasPriceWei, gasPriceErr := c.EthClient.Client.SuggestGasPrice(c.EthClient.Context) + gasPriceWei, gasPriceErr := c.EthClient.Client.SuggestGasPrice(ctx) if gasPriceErr != nil { c.Logger.Warn().Err(gasPriceErr).Msg("Failed to fetch gas price") } else { diff --git a/cmd/client/workflow_registry_v2_client.go b/cmd/client/workflow_registry_v2_client.go index a8dd6c5f..7747b16f 100644 --- a/cmd/client/workflow_registry_v2_client.go +++ b/cmd/client/workflow_registry_v2_client.go @@ -1,6 +1,7 @@ package client import ( + "context" "encoding/hex" "errors" "fmt" @@ -24,6 +25,14 @@ type workflowRegistryV2Contract interface { IsRequestAllowlisted(opts *bind.CallOpts, owner common.Address, requestDigest [32]byte) (bool, error) } +func (wrc *WorkflowRegistryV2Client) callOpts(ctx context.Context) *bind.CallOpts { + opts := wrc.EthClient.NewCallOpts() + if ctx != nil { + opts.Context = ctx + } + return opts +} + type WorkflowRegistryV2Client struct { TxClient ContractAddress common.Address @@ -73,7 +82,7 @@ func (wrc *WorkflowRegistryV2Client) LinkOwner(validityTimestamp *big.Int, proof txFn := func(opts *bind.TransactOpts) (*types.Transaction, error) { return contract.LinkOwner(opts, validityTimestamp, proof, signature) } - txOut, err := wrc.executeTransactionByTxType(txFn, "LinkOwner", "OwnershipLinkUpdated", validityTimestamp, proof, signature) + txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "LinkOwner", "OwnershipLinkUpdated", validityTimestamp, proof, signature) if err != nil { wrc.Logger.Error(). Str("contract", contract.Address().Hex()). @@ -105,7 +114,7 @@ func (wrc *WorkflowRegistryV2Client) UnlinkOwner(owner common.Address, validityT txFn := func(opts *bind.TransactOpts) (*types.Transaction, error) { return contract.UnlinkOwner(opts, owner, validityTimestamp, signature) } - txOut, err := wrc.executeTransactionByTxType(txFn, "UnlinkOwner", "OwnershipLinkUpdated", owner, validityTimestamp, signature) + txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "UnlinkOwner", "OwnershipLinkUpdated", owner, validityTimestamp, signature) if err != nil { wrc.Logger.Error(). Str("contract", contract.Address().Hex()). @@ -414,7 +423,7 @@ func (wrc *WorkflowRegistryV2Client) IsAllowedSigner(signer common.Address) (boo return ok, err } -func (wrc *WorkflowRegistryV2Client) IsOwnerLinked(owner common.Address) (bool, error) { +func (wrc *WorkflowRegistryV2Client) IsOwnerLinked(ctx context.Context, owner common.Address) (bool, error) { contract, err := workflow_registry_v2_wrapper.NewWorkflowRegistry(wrc.ContractAddress, wrc.EthClient.Client) if err != nil { wrc.Logger.Error(). @@ -424,7 +433,7 @@ func (wrc *WorkflowRegistryV2Client) IsOwnerLinked(owner common.Address) (bool, } result, err := callContractMethodV2(wrc, func() (bool, error) { - return contract.IsOwnerLinked(wrc.EthClient.NewCallOpts(), owner) + return contract.IsOwnerLinked(wrc.callOpts(ctx), owner) }) if err != nil { wrc.Logger.Error(). @@ -468,7 +477,7 @@ func (wrc *WorkflowRegistryV2Client) TypeAndVersion() (string, error) { return tv, err } -func (wrc *WorkflowRegistryV2Client) UpsertWorkflow(params RegisterWorkflowV2Parameters) (*TxOutput, error) { +func (wrc *WorkflowRegistryV2Client) UpsertWorkflow(ctx context.Context, params RegisterWorkflowV2Parameters) (*TxOutput, error) { contract, err := workflow_registry_v2_wrapper.NewWorkflowRegistry(wrc.ContractAddress, wrc.EthClient.Client) if err != nil { wrc.Logger.Error(). @@ -492,7 +501,7 @@ func (wrc *WorkflowRegistryV2Client) UpsertWorkflow(params RegisterWorkflowV2Par params.KeepAlive, ) } - txOut, err := wrc.executeTransactionByTxType(txFn, "UpsertWorkflow", "WorkflowRegistered|WorkflowUpdated", + txOut, err := wrc.executeTransactionByTxType(ctx, txFn, "UpsertWorkflow", "WorkflowRegistered|WorkflowUpdated", params.WorkflowName, params.Tag, params.WorkflowID, @@ -513,7 +522,7 @@ func (wrc *WorkflowRegistryV2Client) UpsertWorkflow(params RegisterWorkflowV2Par return &txOut, nil } -func (wrc *WorkflowRegistryV2Client) GetWorkflow(owner common.Address, workflowName, tag string) (workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) { +func (wrc *WorkflowRegistryV2Client) GetWorkflow(ctx context.Context, owner common.Address, workflowName, tag string) (workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) { contract, err := workflow_registry_v2_wrapper.NewWorkflowRegistry(wrc.ContractAddress, wrc.EthClient.Client) if err != nil { wrc.Logger.Error().Err(err).Msg("Failed to connect for GetWorkflow") @@ -521,7 +530,7 @@ func (wrc *WorkflowRegistryV2Client) GetWorkflow(owner common.Address, workflowN } result, err := callContractMethodV2(wrc, func() (workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) { - return contract.GetWorkflow(wrc.EthClient.NewCallOpts(), owner, workflowName, tag) + return contract.GetWorkflow(wrc.callOpts(ctx), owner, workflowName, tag) }) if err != nil { wrc.Logger.Error().Err(err).Msg("GetWorkflow call failed") @@ -529,7 +538,7 @@ func (wrc *WorkflowRegistryV2Client) GetWorkflow(owner common.Address, workflowN return result, err } -func (wrc *WorkflowRegistryV2Client) GetWorkflowListByOwnerAndName(owner common.Address, workflowName string, start, limit *big.Int) ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) { +func (wrc *WorkflowRegistryV2Client) GetWorkflowListByOwnerAndName(ctx context.Context, owner common.Address, workflowName string, start, limit *big.Int) ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) { contract, err := workflow_registry_v2_wrapper.NewWorkflowRegistry(wrc.ContractAddress, wrc.EthClient.Client) if err != nil { wrc.Logger.Error().Err(err).Msg("Failed to connect for GetWorkflowListByOwnerAndName") @@ -537,7 +546,7 @@ func (wrc *WorkflowRegistryV2Client) GetWorkflowListByOwnerAndName(owner common. } result, err := callContractMethodV2(wrc, func() ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) { - return contract.GetWorkflowListByOwnerAndName(wrc.EthClient.NewCallOpts(), owner, workflowName, start, limit) + return contract.GetWorkflowListByOwnerAndName(wrc.callOpts(ctx), owner, workflowName, start, limit) }) if err != nil { wrc.Logger.Error().Err(err).Msg("GetWorkflowListByOwnerAndName call failed") @@ -619,7 +628,7 @@ func (wrc *WorkflowRegistryV2Client) DeleteWorkflow(workflowID [32]byte) (*TxOut txFn := func(opts *bind.TransactOpts) (*types.Transaction, error) { return contract.DeleteWorkflow(opts, workflowID) } - txOut, err := wrc.executeTransactionByTxType(txFn, "DeleteWorkflow", "WorkflowDeleted", workflowID) + txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "DeleteWorkflow", "WorkflowDeleted", workflowID) if err != nil { wrc.Logger.Error(). Str("contract", contract.Address().Hex()). @@ -646,7 +655,7 @@ func (wrc *WorkflowRegistryV2Client) BatchPauseWorkflows(workflowIDs [][32]byte) workflowIDs, ) } - txOut, err := wrc.executeTransactionByTxType(txFn, "BatchPauseWorkflows", "WorkflowStatusUpdated", workflowIDs) + txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "BatchPauseWorkflows", "WorkflowStatusUpdated", workflowIDs) if err != nil { wrc.Logger.Error(). Str("contract", contract.Address().Hex()). @@ -670,7 +679,7 @@ func (wrc *WorkflowRegistryV2Client) ActivateWorkflow(workflowID [32]byte, donFa txFn := func(opts *bind.TransactOpts) (*types.Transaction, error) { return contract.ActivateWorkflow(opts, workflowID, donFamily) } - txOut, err := wrc.executeTransactionByTxType(txFn, "ActivateWorkflow", "WorkflowActivated", workflowID, donFamily) + txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "ActivateWorkflow", "WorkflowActivated", workflowID, donFamily) if err != nil { wrc.Logger.Error(). Str("contract", contract.Address().Hex()). @@ -772,7 +781,7 @@ func (wrc *WorkflowRegistryV2Client) AllowlistRequest(requestDigest [32]byte, du txFn := func(opts *bind.TransactOpts) (*types.Transaction, error) { return contract.AllowlistRequest(opts, requestDigest, deadline) } - txOut, err := wrc.executeTransactionByTxType(txFn, "AllowlistRequest", "RequestAllowlisted", requestDigest, duration) + txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "AllowlistRequest", "RequestAllowlisted", requestDigest, duration) if err != nil { wrc.Logger.Error(). Str("contract", wrc.ContractAddress.Hex()). diff --git a/cmd/common/compile.go b/cmd/common/compile.go index 2456ec77..d500e6e9 100644 --- a/cmd/common/compile.go +++ b/cmd/common/compile.go @@ -1,6 +1,7 @@ package common import ( + "context" "errors" "fmt" "os" @@ -33,7 +34,7 @@ type WorkflowCompileOptions struct { } // getBuildCmd returns a single step that builds the workflow and returns the WASM bytes. -func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCompileOptions) (func() ([]byte, error), error) { +func getBuildCmd(ctx context.Context, workflowRootFolder, mainFile, language string, opts WorkflowCompileOptions) (func() ([]byte, error), error) { tmpPath := filepath.Join(workflowRootFolder, ".cre_build_tmp.wasm") switch language { case constants.WorkflowLanguageTypeScript: @@ -41,7 +42,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom if opts.SkipTypeChecks { args = append(args, SkipTypeChecksFlag) } - cmd := exec.Command("bun", args...) + cmd := exec.CommandContext(ctx, "bun", args...) cmd.Dir = workflowRootFolder return func() ([]byte, error) { out, err := cmd.CombinedOutput() @@ -67,7 +68,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom if opts.StripSymbols { ldflags = "-buildid= -w -s" } - cmd := exec.Command( + cmd := exec.CommandContext(ctx, "go", "build", "-o", tmpPath, "-trimpath", @@ -92,7 +93,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom if err != nil { return nil, err } - makeCmd := exec.Command("make", "build") + makeCmd := exec.CommandContext(ctx, "make", "build") makeCmd.Dir = makeRoot builtPath := filepath.Join(makeRoot, defaultWasmOutput) return func() ([]byte, error) { @@ -108,7 +109,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom if opts.StripSymbols { ldflags = "-buildid= -w -s" } - cmd := exec.Command( + cmd := exec.CommandContext(ctx, "go", "build", "-o", tmpPath, "-trimpath", @@ -135,7 +136,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom // opts.StripSymbols: for Go builds, true strips debug symbols (deploy); false keeps them (simulate). // opts.SkipTypeChecks: for TypeScript, passes SkipTypeChecksFlag to cre-compile. // For custom Makefile WASM builds, StripSymbols and SkipTypeChecks have no effect. -func CompileWorkflowToWasm(workflowPath string, opts WorkflowCompileOptions) ([]byte, error) { +func CompileWorkflowToWasm(ctx context.Context, workflowPath string, opts WorkflowCompileOptions) ([]byte, error) { workflowRootFolder, workflowMainFile, err := WorkflowPathRootAndMain(workflowPath) if err != nil { return nil, fmt.Errorf("workflow path: %w", err) @@ -167,7 +168,7 @@ func CompileWorkflowToWasm(workflowPath string, opts WorkflowCompileOptions) ([] return nil, fmt.Errorf("unsupported workflow language for file %s", workflowMainFile) } - buildStep, err := getBuildCmd(workflowRootFolder, workflowMainFile, language, opts) + buildStep, err := getBuildCmd(ctx, workflowRootFolder, workflowMainFile, language, opts) if err != nil { return nil, err } diff --git a/cmd/common/compile_test.go b/cmd/common/compile_test.go index 4813681a..42dfa1a4 100644 --- a/cmd/common/compile_test.go +++ b/cmd/common/compile_test.go @@ -2,6 +2,7 @@ package common import ( "bytes" + "context" "io" "os" "os/exec" @@ -47,21 +48,21 @@ func TestFindMakefileRoot(t *testing.T) { func TestCompileWorkflowToWasm_Go_Success(t *testing.T) { t.Run("basic_workflow", func(t *testing.T) { path := deployTestdataPath("basic_workflow", "main.go") - wasm, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true}) + wasm, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true}) require.NoError(t, err) assert.NotEmpty(t, wasm) }) t.Run("configless_workflow", func(t *testing.T) { path := deployTestdataPath("configless_workflow", "main.go") - wasm, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true}) + wasm, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true}) require.NoError(t, err) assert.NotEmpty(t, wasm) }) t.Run("missing_go_mod", func(t *testing.T) { path := deployTestdataPath("missing_go_mod", "main.go") - wasm, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true}) + wasm, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true}) require.NoError(t, err) assert.NotEmpty(t, wasm) }) @@ -69,7 +70,7 @@ func TestCompileWorkflowToWasm_Go_Success(t *testing.T) { func TestCompileWorkflowToWasm_Go_Malformed_Fails(t *testing.T) { path := deployTestdataPath("malformed_workflow", "main.go") - _, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true}) + _, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true}) require.Error(t, err) assert.Contains(t, err.Error(), "failed to compile workflow") assert.Contains(t, err.Error(), "undefined: sdk.RemovedFunctionThatFailsCompilation") @@ -80,7 +81,7 @@ func TestCompileWorkflowToWasm_Wasm_Success(t *testing.T) { _ = os.Remove(wasmPath) t.Cleanup(func() { _ = os.Remove(wasmPath) }) - wasm, err := CompileWorkflowToWasm(wasmPath, WorkflowCompileOptions{StripSymbols: true}) + wasm, err := CompileWorkflowToWasm(context.Background(), wasmPath, WorkflowCompileOptions{StripSymbols: true}) require.NoError(t, err) assert.NotEmpty(t, wasm) @@ -96,14 +97,14 @@ func TestCompileWorkflowToWasm_Wasm_Fails(t *testing.T) { wasmPath := filepath.Join(wasmDir, "workflow.wasm") require.NoError(t, os.WriteFile(wasmPath, []byte("not really wasm"), 0600)) - _, err := CompileWorkflowToWasm(wasmPath, WorkflowCompileOptions{StripSymbols: true}) + _, err := CompileWorkflowToWasm(context.Background(), wasmPath, WorkflowCompileOptions{StripSymbols: true}) require.Error(t, err) assert.Contains(t, err.Error(), "no Makefile found") }) t.Run("make_build_fails", func(t *testing.T) { path := deployTestdataPath("wasm_make_fails", "wasm", "workflow.wasm") - _, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true}) + _, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true}) require.Error(t, err) assert.Contains(t, err.Error(), "failed to compile workflow") assert.Contains(t, err.Error(), "build output:") @@ -138,7 +139,7 @@ func TestCompileWorkflowToWasm_TS_Success(t *testing.T) { "include": ["main.ts"] } `), 0600)) - wasm, err := CompileWorkflowToWasm(mainPath, WorkflowCompileOptions{StripSymbols: true}) + wasm, err := CompileWorkflowToWasm(context.Background(), mainPath, WorkflowCompileOptions{StripSymbols: true}) if err != nil { t.Skipf("TS compile failed (published cre-sdk may lack full layout): %v", err) } diff --git a/cmd/common/fetch.go b/cmd/common/fetch.go index 5f8ee4f4..9a715d6d 100644 --- a/cmd/common/fetch.go +++ b/cmd/common/fetch.go @@ -1,6 +1,7 @@ package common import ( + "context" "fmt" "io" "net/http" @@ -28,12 +29,16 @@ func IsURL(s string) bool { } // FetchURL performs an HTTP GET and returns the response body bytes. -func FetchURL(url string) ([]byte, error) { - resp, err := http.Get(url) //nolint:gosec,noctx +func FetchURL(ctx context.Context, url string) ([]byte, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("HTTP GET %s: %w", url, err) } - defer resp.Body.Close() + resp, err := http.DefaultClient.Do(req) //nolint:gosec // URL validated by caller + if err != nil { + return nil, fmt.Errorf("HTTP GET %s: %w", url, err) + } + defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("HTTP GET %s returned status %d", url, resp.StatusCode) diff --git a/cmd/common/fetch_test.go b/cmd/common/fetch_test.go index 10a6ade6..0291f3dd 100644 --- a/cmd/common/fetch_test.go +++ b/cmd/common/fetch_test.go @@ -1,6 +1,7 @@ package common import ( + "context" "net/http" "net/http/httptest" "testing" @@ -42,7 +43,7 @@ func TestFetchURL(t *testing.T) { })) defer srv.Close() - data, err := FetchURL(srv.URL) + data, err := FetchURL(context.Background(), srv.URL) require.NoError(t, err) assert.Equal(t, body, data) }) @@ -53,13 +54,13 @@ func TestFetchURL(t *testing.T) { })) defer srv.Close() - _, err := FetchURL(srv.URL) + _, err := FetchURL(context.Background(), srv.URL) require.Error(t, err) assert.Contains(t, err.Error(), "returned status 404") }) t.Run("unreachable host", func(t *testing.T) { - _, err := FetchURL("http://127.0.0.1:1") + _, err := FetchURL(context.Background(), "http://127.0.0.1:1") require.Error(t, err) }) } diff --git a/cmd/secrets/common/handler.go b/cmd/secrets/common/handler.go index d0ba25f2..feef196a 100644 --- a/cmd/secrets/common/handler.go +++ b/cmd/secrets/common/handler.go @@ -737,7 +737,7 @@ func (h *Handler) EnsureOwnerLinkedOrFail() error { } ownerAddr := common.HexToAddress(h.OwnerAddress) - linked, err := h.Wrc.IsOwnerLinked(ownerAddr) + linked, err := h.Wrc.IsOwnerLinked(context.Background(), ownerAddr) if err != nil { return fmt.Errorf("failed to check owner link status: %w", err) } diff --git a/cmd/workflow/activate/registry_activate_strategy_onchain.go b/cmd/workflow/activate/registry_activate_strategy_onchain.go index ea9df4dd..0a73ba24 100644 --- a/cmd/workflow/activate/registry_activate_strategy_onchain.go +++ b/cmd/workflow/activate/registry_activate_strategy_onchain.go @@ -1,6 +1,7 @@ package activate import ( + "context" "encoding/hex" "fmt" "math/big" @@ -59,7 +60,7 @@ func (a *onchainRegistryActivateStrategy) Activate() error { ownerAddr := common.HexToAddress(workflowOwner) const pageLimit = 200 - workflows, err := a.wrc.GetWorkflowListByOwnerAndName(ownerAddr, workflowName, big.NewInt(0), big.NewInt(pageLimit)) + workflows, err := a.wrc.GetWorkflowListByOwnerAndName(context.Background(), ownerAddr, workflowName, big.NewInt(0), big.NewInt(pageLimit)) if err != nil { return fmt.Errorf("failed to get workflow list: %w", err) } diff --git a/cmd/workflow/activate/registry_activate_strategy_private.go b/cmd/workflow/activate/registry_activate_strategy_private.go index 1b1e1df9..b7c0f09c 100644 --- a/cmd/workflow/activate/registry_activate_strategy_private.go +++ b/cmd/workflow/activate/registry_activate_strategy_private.go @@ -1,6 +1,7 @@ package activate import ( + "context" "fmt" "github.com/smartcontractkit/cre-cli/internal/client/graphqlclient" @@ -32,7 +33,7 @@ func (a *privateRegistryActivateStrategy) Activate() error { ui.Dim(fmt.Sprintf("Fetching workflow to activate... Name=%s", workflowName)) - workflow, err := a.prc.GetWorkflowByName(workflowName) + workflow, err := a.prc.GetWorkflowByName(context.Background(), workflowName) if err != nil { return fmt.Errorf("failed to get workflow: %w", err) } diff --git a/cmd/workflow/build/build.go b/cmd/workflow/build/build.go index f92f6973..63b79edd 100644 --- a/cmd/workflow/build/build.go +++ b/cmd/workflow/build/build.go @@ -1,6 +1,7 @@ package build import ( + "context" "fmt" "os" "path/filepath" @@ -26,7 +27,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { outputPath, _ := cmd.Flags().GetString("output") skipTypeChecks, _ := cmd.Flags().GetBool(cmdcommon.SkipTypeChecksCLIFlag) - return execute(args[0], outputPath, skipTypeChecks) + return execute(cmd.Context(), args[0], outputPath, skipTypeChecks) }, } buildCmd.Flags().StringP("output", "o", "", "Output file path for the compiled WASM binary (default: /binary.wasm)") @@ -34,7 +35,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { return buildCmd } -func execute(workflowFolder, outputPath string, skipTypeChecks bool) error { +func execute(ctx context.Context, workflowFolder, outputPath string, skipTypeChecks bool) error { workflowDir, err := filepath.Abs(workflowFolder) if err != nil { return fmt.Errorf("resolve workflow folder: %w", err) @@ -60,7 +61,7 @@ func execute(workflowFolder, outputPath string, skipTypeChecks bool) error { outputPath = cmdcommon.EnsureWasmExtension(outputPath) ui.Dim("Compiling workflow...") - wasmBytes, err := cmdcommon.CompileWorkflowToWasm(resolvedPath, cmdcommon.WorkflowCompileOptions{ + wasmBytes, err := cmdcommon.CompileWorkflowToWasm(ctx, resolvedPath, cmdcommon.WorkflowCompileOptions{ StripSymbols: true, SkipTypeChecks: skipTypeChecks, }) diff --git a/cmd/workflow/delete/registry_delete_strategy_onchain.go b/cmd/workflow/delete/registry_delete_strategy_onchain.go index 1dd7335b..00bcd53c 100644 --- a/cmd/workflow/delete/registry_delete_strategy_onchain.go +++ b/cmd/workflow/delete/registry_delete_strategy_onchain.go @@ -1,6 +1,7 @@ package delete import ( + "context" "encoding/hex" "errors" "fmt" @@ -56,7 +57,7 @@ func (a *onchainRegistryDeleteStrategy) FetchWorkflows() ([]WorkflowToDelete, er workflowName := h.inputs.WorkflowName workflowOwner := common.HexToAddress(h.inputs.WorkflowOwner) - allWorkflows, err := a.wrc.GetWorkflowListByOwnerAndName(workflowOwner, workflowName, big.NewInt(0), big.NewInt(100)) + allWorkflows, err := a.wrc.GetWorkflowListByOwnerAndName(context.Background(), workflowOwner, workflowName, big.NewInt(0), big.NewInt(100)) if err != nil { return nil, fmt.Errorf("failed to get workflow list: %w", err) } diff --git a/cmd/workflow/delete/registry_delete_strategy_private.go b/cmd/workflow/delete/registry_delete_strategy_private.go index 22f832bf..5a5a2ef4 100644 --- a/cmd/workflow/delete/registry_delete_strategy_private.go +++ b/cmd/workflow/delete/registry_delete_strategy_private.go @@ -1,6 +1,7 @@ package delete import ( + "context" "fmt" "github.com/smartcontractkit/cre-cli/internal/client/graphqlclient" @@ -32,7 +33,7 @@ func (a *privateRegistryDeleteStrategy) FetchWorkflows() ([]WorkflowToDelete, er ui.Dim(fmt.Sprintf("Fetching workflow to delete... Name=%s", workflowName)) - workflow, err := a.prc.GetWorkflowByName(workflowName) + workflow, err := a.prc.GetWorkflowByName(context.Background(), workflowName) if err != nil { return nil, fmt.Errorf("failed to get workflow: %w", err) } diff --git a/cmd/workflow/deploy/artifacts.go b/cmd/workflow/deploy/artifacts.go index 070b1421..f4512b13 100644 --- a/cmd/workflow/deploy/artifacts.go +++ b/cmd/workflow/deploy/artifacts.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "fmt" "github.com/smartcontractkit/cre-cli/internal/client/graphqlclient" @@ -8,7 +9,7 @@ import ( "github.com/smartcontractkit/cre-cli/internal/ui" ) -func (h *handler) uploadArtifacts() error { +func (h *handler) uploadArtifacts(ctx context.Context) error { if h.workflowArtifact == nil { return fmt.Errorf("workflowArtifact is nil") } @@ -45,7 +46,7 @@ func (h *handler) uploadArtifacts() error { if !binaryFromURL { ui.Success(fmt.Sprintf("Loaded binary from: %s", h.inputs.OutputPath)) - binaryResp, err := storageClient.UploadArtifactWithRetriesAndGetURL( + binaryResp, err := storageClient.UploadArtifactWithRetriesAndGetURL(ctx, workflowID, storageclient.ArtifactTypeBinary, binaryData, "application/octet-stream") if err != nil { return fmt.Errorf("uploading binary artifact: %w", err) @@ -58,7 +59,7 @@ func (h *handler) uploadArtifacts() error { if !configFromURL && len(configData) > 0 { ui.Success(fmt.Sprintf("Loaded config from: %s", h.inputs.ConfigPath)) var err error - configURL, err = storageClient.UploadArtifactWithRetriesAndGetURL( + configURL, err = storageClient.UploadArtifactWithRetriesAndGetURL(ctx, workflowID, storageclient.ArtifactTypeConfig, configData, "text/plain") if err != nil { return fmt.Errorf("uploading config artifact: %w", err) diff --git a/cmd/workflow/deploy/artifacts_test.go b/cmd/workflow/deploy/artifacts_test.go index 24833d9c..be0acf93 100644 --- a/cmd/workflow/deploy/artifacts_test.go +++ b/cmd/workflow/deploy/artifacts_test.go @@ -1,6 +1,7 @@ package deploy import ( + "context" //nolint:gosec "encoding/json" "errors" @@ -99,7 +100,7 @@ func TestUpload_SuccessAndErrorCases(t *testing.T) { ConfigData: []byte("configdata"), WorkflowID: "workflow-id", } - err := h.uploadArtifacts() + err := h.uploadArtifacts(context.Background()) require.NoError(t, err) require.Equal(t, "http://origin/get", h.inputs.BinaryURL) require.Equal(t, "http://origin/get", *h.inputs.ConfigURL) @@ -110,12 +111,12 @@ func TestUpload_SuccessAndErrorCases(t *testing.T) { ConfigData: nil, WorkflowID: "workflow-id", } - err = h.uploadArtifacts() + err = h.uploadArtifacts(context.Background()) require.NoError(t, err) // Error: workflowArtifact is nil h.workflowArtifact = nil - err = h.uploadArtifacts() + err = h.uploadArtifacts(context.Background()) require.ErrorContains(t, err, "workflowArtifact is nil") // Error: empty BinaryData @@ -124,7 +125,7 @@ func TestUpload_SuccessAndErrorCases(t *testing.T) { ConfigData: []byte("configdata"), WorkflowID: "workflow-id", } - err = h.uploadArtifacts() + err = h.uploadArtifacts(context.Background()) require.ErrorContains(t, err, "uploading binary artifact: content is empty for artifactType BINARY") // Error: workflowID is empty @@ -133,7 +134,7 @@ func TestUpload_SuccessAndErrorCases(t *testing.T) { ConfigData: []byte("configdata"), WorkflowID: "", } - err = h.uploadArtifacts() + err = h.uploadArtifacts(context.Background()) require.ErrorContains(t, err, "workflowID is empty") } @@ -174,7 +175,7 @@ func TestUploadArtifactToStorageService_OriginError(t *testing.T) { ConfigData: []byte("configdata"), WorkflowID: "workflow-id", } - err := h.uploadArtifacts() + err := h.uploadArtifacts(context.Background()) require.ErrorContains(t, err, "upload to origin") } @@ -240,7 +241,7 @@ func TestUploadArtifactToStorageService_AlreadyExistsError(t *testing.T) { ConfigData: []byte("configdata"), WorkflowID: "workflow-id", } - err := h.uploadArtifacts() + err := h.uploadArtifacts(context.Background()) require.NoError(t, err) require.Equal(t, "http://origin/get", h.inputs.BinaryURL) require.Equal(t, "http://origin/get", *h.inputs.ConfigURL) @@ -291,7 +292,7 @@ func TestUpload_UsesResolvedWorkflowOwnerForPresignedUrls(t *testing.T) { WorkflowID: "workflow-id", } - err := h.uploadArtifacts() + err := h.uploadArtifacts(context.Background()) require.NoError(t, err) require.NotEmpty(t, ownersUsed) for _, owner := range ownersUsed { diff --git a/cmd/workflow/deploy/auto_link.go b/cmd/workflow/deploy/auto_link.go index 2ee140bf..9d0ea6d7 100644 --- a/cmd/workflow/deploy/auto_link.go +++ b/cmd/workflow/deploy/auto_link.go @@ -22,10 +22,10 @@ const ( ) // ensureOwnerLinkedOrFail checks if the owner is linked and attempts auto-link if needed -func (h *handler) ensureOwnerLinkedOrFail(onChain *settings.OnChainRegistry) error { +func (h *handler) ensureOwnerLinkedOrFail(ctx context.Context, onChain *settings.OnChainRegistry) error { ownerAddr := common.HexToAddress(h.inputs.WorkflowOwner) - linked, err := h.wrc.IsOwnerLinked(ownerAddr) + linked, err := h.wrc.IsOwnerLinked(ctx, ownerAddr) if err != nil { return fmt.Errorf("failed to check owner link status: %w", err) } @@ -34,7 +34,7 @@ func (h *handler) ensureOwnerLinkedOrFail(onChain *settings.OnChainRegistry) err if linked { // Owner is linked on contract, now verify it's linked to the current user's account - linkedToCurrentUser, err := h.checkLinkStatusViaGraphQL(ownerAddr) + linkedToCurrentUser, err := h.checkLinkStatusViaGraphQL(ctx, ownerAddr) if err != nil { return fmt.Errorf("failed to validate key ownership: %w", err) } @@ -55,7 +55,7 @@ func (h *handler) ensureOwnerLinkedOrFail(onChain *settings.OnChainRegistry) err ui.Success(fmt.Sprintf("Auto-link successful: owner=%s", ownerAddr.Hex())) // Wait for linking process to complete - if err := h.waitForBackendLinkProcessing(ownerAddr); err != nil { + if err := h.waitForBackendLinkProcessing(ctx, ownerAddr); err != nil { return fmt.Errorf("linking process failed: %w", err) } @@ -63,17 +63,17 @@ func (h *handler) ensureOwnerLinkedOrFail(onChain *settings.OnChainRegistry) err } // autoLinkMSIGAndExit handles MSIG auto-link and exits if manual intervention is needed -func (h *handler) autoLinkMSIGAndExit(onChain *settings.OnChainRegistry) (halt bool, err error) { +func (h *handler) autoLinkMSIGAndExit(ctx context.Context, onChain *settings.OnChainRegistry) (halt bool, err error) { ownerAddr := common.HexToAddress(h.inputs.WorkflowOwner) - linked, err := h.wrc.IsOwnerLinked(ownerAddr) + linked, err := h.wrc.IsOwnerLinked(ctx, ownerAddr) if err != nil { return false, fmt.Errorf("failed to check owner link status: %w", err) } if linked { // Owner is linked on contract, now verify it's linked to the current user's account - linkedToCurrentUser, err := h.checkLinkStatusViaGraphQL(ownerAddr) + linkedToCurrentUser, err := h.checkLinkStatusViaGraphQL(ctx, ownerAddr) if err != nil { return false, fmt.Errorf("failed to validate MSIG key ownership: %w", err) } @@ -115,7 +115,7 @@ func (h *handler) tryAutoLink(onChain *settings.OnChainRegistry) error { } // checkLinkStatusViaGraphQL checks if the owner is linked and verified by querying the service -func (h *handler) checkLinkStatusViaGraphQL(ownerAddr common.Address) (bool, error) { +func (h *handler) checkLinkStatusViaGraphQL(ctx context.Context, ownerAddr common.Address) (bool, error) { const query = ` query { listWorkflowOwners(filters: { linkStatus: LINKED_ONLY }) { @@ -137,7 +137,7 @@ func (h *handler) checkLinkStatusViaGraphQL(ownerAddr common.Address) (bool, err } gql := graphqlclient.New(h.credentials, h.environmentSet, h.log) - if err := gql.Execute(context.Background(), req, &resp); err != nil { + if err := gql.Execute(ctx, req, &resp); err != nil { return false, fmt.Errorf("GraphQL query failed: %w", err) } @@ -169,7 +169,7 @@ func (h *handler) checkLinkStatusViaGraphQL(ownerAddr common.Address) (bool, err } // waitForBackendLinkProcessing polls the service until the link is processed -func (h *handler) waitForBackendLinkProcessing(ownerAddr common.Address) error { +func (h *handler) waitForBackendLinkProcessing(ctx context.Context, ownerAddr common.Address) error { const maxAttempts = 5 const retryDelay = 3 * time.Second const initialBlockWait = 36 * time.Second // Wait for 3 block confirmations (~12s per block) @@ -181,11 +181,13 @@ func (h *handler) waitForBackendLinkProcessing(ownerAddr common.Address) error { ui.Line() // Wait for 3 block confirmations before polling - time.Sleep(initialBlockWait) + if err := sleepWithContext(ctx, initialBlockWait); err != nil { + return err + } err := retry.Do( func() error { - linked, err := h.checkLinkStatusViaGraphQL(ownerAddr) + linked, err := h.checkLinkStatusViaGraphQL(ctx, ownerAddr) if err != nil { h.log.Warn().Err(err).Msg("Failed to check link status") return err // Return error to trigger retry @@ -199,6 +201,7 @@ func (h *handler) waitForBackendLinkProcessing(ownerAddr common.Address) error { retry.Delay(retryDelay), retry.DelayType(retry.FixedDelay), // Use fixed 3s delay between retries retry.LastErrorOnly(true), + retry.Context(ctx), retry.OnRetry(func(n uint, err error) { h.log.Debug().Uint("attempt", n+1).Uint("maxAttempts", maxAttempts).Err(err).Msg("Retrying link status check") ui.Dim(fmt.Sprintf(" Waiting for verification... (attempt %d/%d)", n+1, maxAttempts)) @@ -212,3 +215,14 @@ func (h *handler) waitForBackendLinkProcessing(ownerAddr common.Address) error { ui.Success(fmt.Sprintf("Linking verified: owner=%s", ownerAddr.Hex())) return nil } + +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/cmd/workflow/deploy/auto_link_test.go b/cmd/workflow/deploy/auto_link_test.go index e192ccfa..f12d767b 100644 --- a/cmd/workflow/deploy/auto_link_test.go +++ b/cmd/workflow/deploy/auto_link_test.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -164,7 +165,7 @@ func TestCheckLinkStatusViaGraphQL(t *testing.T) { // Test the function ownerAddr := common.HexToAddress(tt.ownerAddress) - result, err := h.checkLinkStatusViaGraphQL(ownerAddr) + result, err := h.checkLinkStatusViaGraphQL(context.Background(), ownerAddr) if tt.expectError { assert.Error(t, err) @@ -335,7 +336,7 @@ func TestWaitForBackendLinkProcessing(t *testing.T) { // Test the function ownerAddr := common.HexToAddress(tt.ownerAddress) - err := h.waitForBackendLinkProcessing(ownerAddr) + err := h.waitForBackendLinkProcessing(context.Background(), ownerAddr) if tt.expectError { assert.Error(t, err) diff --git a/cmd/workflow/deploy/compile.go b/cmd/workflow/deploy/compile.go index ecb3c064..ffa42bdf 100644 --- a/cmd/workflow/deploy/compile.go +++ b/cmd/workflow/deploy/compile.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "fmt" "os" @@ -9,7 +10,7 @@ import ( "github.com/smartcontractkit/cre-cli/internal/ui" ) -func (h *handler) Compile() error { +func (h *handler) Compile(ctx context.Context) error { if !h.validated { return fmt.Errorf("handler h.inputs not validated") } @@ -67,7 +68,7 @@ func (h *handler) Compile() error { h.runtimeContext.Workflow.Language = cmdcommon.GetWorkflowLanguage(workflowMainFile) } - wasmFile, err = cmdcommon.CompileWorkflowToWasm(resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ + wasmFile, err = cmdcommon.CompileWorkflowToWasm(ctx, resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ StripSymbols: true, SkipTypeChecks: h.inputs.SkipTypeChecks, }) diff --git a/cmd/workflow/deploy/compile_test.go b/cmd/workflow/deploy/compile_test.go index 149ac19a..c1f907e1 100644 --- a/cmd/workflow/deploy/compile_test.go +++ b/cmd/workflow/deploy/compile_test.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "encoding/base64" "errors" "io" @@ -271,7 +272,7 @@ func runCompile(simulatedEnvironment *chainsim.SimulatedEnvironment, inputs Inpu return err } - return handler.Compile() + return handler.Compile(context.Background()) } // outputPathWithExtensions returns the path with .wasm.br.b64 appended as in Compile(). @@ -286,7 +287,7 @@ func outputPathWithExtensions(path string) string { // file content equals CompileWorkflowToWasm(workflowPath) + brotli + base64. func assertCompileOutputMatchesUnderlying(t *testing.T, simulatedEnvironment *chainsim.SimulatedEnvironment, inputs Inputs, ownerType string) { t.Helper() - wasm, err := cmdcommon.CompileWorkflowToWasm(inputs.WorkflowPath, cmdcommon.WorkflowCompileOptions{ + wasm, err := cmdcommon.CompileWorkflowToWasm(context.Background(), inputs.WorkflowPath, cmdcommon.WorkflowCompileOptions{ StripSymbols: true, SkipTypeChecks: inputs.SkipTypeChecks, }) @@ -435,7 +436,7 @@ func TestCompileWithWasmPath(t *testing.T) { handler.validated = true // Compile() with URL wasm should return nil (skips compile entirely). - err := handler.Compile() + err := handler.Compile(context.Background()) require.NoError(t, err) }) diff --git a/cmd/workflow/deploy/deploy.go b/cmd/workflow/deploy/deploy.go index ba6c37c6..bc5db8f3 100644 --- a/cmd/workflow/deploy/deploy.go +++ b/cmd/workflow/deploy/deploy.go @@ -219,18 +219,18 @@ func (h *handler) Execute(ctx context.Context) error { return err } - if err := h.prepareArtifacts(); err != nil { + if err := h.prepareArtifacts(ctx); err != nil { return err } - if err := adapter.RunPreDeployChecks(); err != nil { + if err := adapter.RunPreDeployChecks(ctx); err != nil { if errors.Is(err, errDeployHalted) { return nil } return err } - exists, existingStatus, err := adapter.CheckWorkflowExists( + exists, existingStatus, err := adapter.CheckWorkflowExists(ctx, h.inputs.WorkflowOwner, h.inputs.WorkflowName, h.inputs.WorkflowTag, @@ -248,11 +248,11 @@ func (h *handler) Execute(ctx context.Context) error { ui.Line() ui.Dim("Uploading files...") - if err := h.uploadArtifacts(); err != nil { + if err := h.uploadArtifacts(ctx); err != nil { return fmt.Errorf("failed to upload workflow: %w", err) } - err = adapter.Upsert() + err = adapter.Upsert(ctx) if err == nil { warnIfPausedWorkflowUpdate(h.existingWorkflowStatus) } @@ -262,7 +262,7 @@ func (h *handler) Execute(ctx context.Context) error { // prepareArtifacts handles compile/fetch, artifact preparation, and hashing. // Artifact upload is deferred to the deploy service so it runs after any // existing-workflow update confirmation. -func (h *handler) prepareArtifacts() error { +func (h *handler) prepareArtifacts(ctx context.Context) error { workflowcommon.DisplayWorkflowDetails( h.settings, h.runtimeContext, @@ -274,14 +274,14 @@ func (h *handler) prepareArtifacts() error { if cmdcommon.IsURL(h.inputs.WasmPath) { h.inputs.BinaryURL = h.inputs.WasmPath ui.Dim("Fetching binary from URL for workflow ID computation...") - fetched, err := cmdcommon.FetchURL(h.inputs.WasmPath) + fetched, err := cmdcommon.FetchURL(ctx, h.inputs.WasmPath) if err != nil { return fmt.Errorf("failed to fetch binary from URL: %w", err) } h.urlBinaryData = fetched ui.Success(fmt.Sprintf("Using binary URL: %s", h.inputs.WasmPath)) } else { - if err := h.Compile(); err != nil { + if err := h.Compile(ctx); err != nil { return fmt.Errorf("failed to compile workflow: %w", err) } } @@ -291,7 +291,7 @@ func (h *handler) prepareArtifacts() error { h.inputs.ConfigURL = &url h.inputs.ConfigPath = "" ui.Dim("Fetching config from URL for workflow ID computation...") - fetched, err := cmdcommon.FetchURL(url) + fetched, err := cmdcommon.FetchURL(ctx, url) if err != nil { return fmt.Errorf("failed to fetch config from URL: %w", err) } diff --git a/cmd/workflow/deploy/deploy_test.go b/cmd/workflow/deploy/deploy_test.go index 72632ce6..11dd3861 100644 --- a/cmd/workflow/deploy/deploy_test.go +++ b/cmd/workflow/deploy/deploy_test.go @@ -819,7 +819,7 @@ func (f fakeUserDonLimitClient) CheckUserDonLimit(owner common.Address, donFamil return nil } -func (f fakeUserDonLimitClient) GetWorkflowListByOwnerAndName(common.Address, string, *big.Int, *big.Int) ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) { +func (f fakeUserDonLimitClient) GetWorkflowListByOwnerAndName(context.Context, common.Address, string, *big.Int, *big.Int) ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) { return f.workflowsByOwnerName, nil } @@ -879,7 +879,7 @@ func TestCheckUserDonLimitBeforeDeploy(t *testing.T) { } nameLookup := fakeUserDonLimitClient{} - err := checkUserDonLimitBeforeDeploy(client, nameLookup, owner, donFamily, workflowName, true, nil) + err := checkUserDonLimitBeforeDeploy(context.Background(), client, nameLookup, owner, donFamily, workflowName, true, nil) require.Error(t, err) assert.Contains(t, err.Error(), "workflow limit reached") }) @@ -898,7 +898,7 @@ func TestCheckUserDonLimitBeforeDeploy(t *testing.T) { }, } - err := checkUserDonLimitBeforeDeploy(client, nameLookup, owner, donFamily, workflowName, false, nil) + err := checkUserDonLimitBeforeDeploy(context.Background(), client, nameLookup, owner, donFamily, workflowName, false, nil) require.NoError(t, err) }) @@ -912,7 +912,7 @@ func TestCheckUserDonLimitBeforeDeploy(t *testing.T) { nameLookup := fakeUserDonLimitClient{} existingStatus := uint8(0) - err := checkUserDonLimitBeforeDeploy(client, nameLookup, owner, donFamily, workflowName, true, &existingStatus) + err := checkUserDonLimitBeforeDeploy(context.Background(), client, nameLookup, owner, donFamily, workflowName, true, &existingStatus) require.NoError(t, err) }) } diff --git a/cmd/workflow/deploy/limits.go b/cmd/workflow/deploy/limits.go index 4a93f003..62ec1409 100644 --- a/cmd/workflow/deploy/limits.go +++ b/cmd/workflow/deploy/limits.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "fmt" "math/big" @@ -16,7 +17,7 @@ const ( ) type workflowNameLookupClient interface { - GetWorkflowListByOwnerAndName(owner common.Address, workflowName string, start, limit *big.Int) ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) + GetWorkflowListByOwnerAndName(ctx context.Context, owner common.Address, workflowName string, start, limit *big.Int) ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) } type userDonLimitChecker interface { @@ -24,6 +25,7 @@ type userDonLimitChecker interface { } func checkUserDonLimitBeforeDeploy( + ctx context.Context, limitChecker userDonLimitChecker, nameLookup workflowNameLookupClient, owner common.Address, @@ -38,7 +40,7 @@ func checkUserDonLimitBeforeDeploy( pending := uint32(1) if !keepAlive { - activeSameName, err := countActiveWorkflowsByOwnerNameAndDON(nameLookup, owner, workflowName, donFamily) + activeSameName, err := countActiveWorkflowsByOwnerNameAndDON(ctx, nameLookup, owner, workflowName, donFamily) if err != nil { return fmt.Errorf("failed to check active workflows for %s on DON %s: %w", workflowName, donFamily, err) } @@ -57,6 +59,7 @@ func checkUserDonLimitBeforeDeploy( } func countActiveWorkflowsByOwnerNameAndDON( + ctx context.Context, wrc workflowNameLookupClient, owner common.Address, workflowName string, @@ -67,7 +70,7 @@ func countActiveWorkflowsByOwnerNameAndDON( limit := big.NewInt(workflowListPageSize) for { - list, err := wrc.GetWorkflowListByOwnerAndName(owner, workflowName, start, limit) + list, err := wrc.GetWorkflowListByOwnerAndName(ctx, owner, workflowName, start, limit) if err != nil { return 0, err } diff --git a/cmd/workflow/deploy/private_registry_test.go b/cmd/workflow/deploy/private_registry_test.go index 7ba1047b..5fdc58d1 100644 --- a/cmd/workflow/deploy/private_registry_test.go +++ b/cmd/workflow/deploy/private_registry_test.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "encoding/base64" "encoding/hex" "encoding/json" @@ -314,7 +315,7 @@ func TestCheckWorkflowExists_PrivateRegistry(t *testing.T) { h.environmentSet.GraphQLURL = gqlServer.URL strategy := newPrivateRegistryDeployStrategy(h) - exists, status, err := strategy.CheckWorkflowExists("", "jnowak-workflow-test-v5", "", tt.workflowID) + exists, status, err := strategy.CheckWorkflowExists(context.Background(), "", "jnowak-workflow-test-v5", "", tt.workflowID) if tt.wantErr { require.Error(t, err) if tt.errMsg != "" { diff --git a/cmd/workflow/deploy/register.go b/cmd/workflow/deploy/register.go index 6c47e237..1d884bde 100644 --- a/cmd/workflow/deploy/register.go +++ b/cmd/workflow/deploy/register.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "encoding/hex" "fmt" "time" @@ -14,7 +15,7 @@ import ( "github.com/smartcontractkit/cre-cli/internal/ui" ) -func (h *handler) upsert(onChain *settings.OnChainRegistry) error { +func (h *handler) upsert(ctx context.Context, onChain *settings.OnChainRegistry) error { if !h.validated { return fmt.Errorf("handler inputs not validated") } @@ -23,7 +24,7 @@ func (h *handler) upsert(onChain *settings.OnChainRegistry) error { if err != nil { return err } - return h.handleUpsert(params, onChain) + return h.handleUpsert(ctx, params, onChain) } func (h *handler) prepareUpsertParams() (client.RegisterWorkflowV2Parameters, error) { @@ -53,11 +54,11 @@ func (h *handler) prepareUpsertParams() (client.RegisterWorkflowV2Parameters, er }, nil } -func (h *handler) handleUpsert(params client.RegisterWorkflowV2Parameters, onChain *settings.OnChainRegistry) error { +func (h *handler) handleUpsert(ctx context.Context, params client.RegisterWorkflowV2Parameters, onChain *settings.OnChainRegistry) error { workflowName := h.inputs.WorkflowName workflowTag := h.inputs.WorkflowTag h.log.Debug().Interface("Workflow parameters", params).Msg("Registering workflow...") - txOut, err := h.wrc.UpsertWorkflow(params) + txOut, err := h.wrc.UpsertWorkflow(ctx, params) if err != nil { return fmt.Errorf("failed to register workflow: %w", err) } diff --git a/cmd/workflow/deploy/register_test.go b/cmd/workflow/deploy/register_test.go index da3b0241..9864118b 100644 --- a/cmd/workflow/deploy/register_test.go +++ b/cmd/workflow/deploy/register_test.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "path/filepath" "testing" @@ -65,7 +66,7 @@ func TestWorkflowUpsert(t *testing.T) { onChain, err := settings.AsOnChain(ctx.ResolvedRegistry, "test") require.NoError(t, err) - err = handler.upsert(onChain) + err = handler.upsert(context.Background(), onChain) require.NoError(t, err) }) } diff --git a/cmd/workflow/deploy/registry_deploy_strategy.go b/cmd/workflow/deploy/registry_deploy_strategy.go index 434b62de..4db3e806 100644 --- a/cmd/workflow/deploy/registry_deploy_strategy.go +++ b/cmd/workflow/deploy/registry_deploy_strategy.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "errors" "github.com/smartcontractkit/cre-cli/internal/settings" @@ -18,15 +19,15 @@ type registryDeployStrategy interface { // RunPreDeployChecks validates readiness and runs registry-specific // prechecks (ownership linking, duplicate detection, etc.). // Return errDeployHalted to stop the deploy without returning an error. - RunPreDeployChecks() error + RunPreDeployChecks(ctx context.Context) error // CheckWorkflowExists returns whether a same-name workflow exists for this // registry target and includes the existing workflow status for updates. - CheckWorkflowExists(workflowOwner, workflowName, workflowTag, workflowID string) (bool, *uint8, error) + CheckWorkflowExists(ctx context.Context, workflowOwner, workflowName, workflowTag, workflowID string) (bool, *uint8, error) // Upsert registers or updates the workflow in the target registry // and displays the result. - Upsert() error + Upsert(ctx context.Context) error } // newRegistryDeployStrategy returns the appropriate strategy for the given target. diff --git a/cmd/workflow/deploy/registry_deploy_strategy_onchain.go b/cmd/workflow/deploy/registry_deploy_strategy_onchain.go index ea6f1583..95c64806 100644 --- a/cmd/workflow/deploy/registry_deploy_strategy_onchain.go +++ b/cmd/workflow/deploy/registry_deploy_strategy_onchain.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "fmt" "sync" @@ -44,18 +45,31 @@ func newOnchainRegistryDeployStrategy(h *handler) (*onchainRegistryDeployStrateg return a, nil } -func (a *onchainRegistryDeployStrategy) RunPreDeployChecks() error { +func (a *onchainRegistryDeployStrategy) waitInit(ctx context.Context) error { + done := make(chan struct{}) + go func() { + a.wg.Wait() + close(done) + }() + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + return a.initErr + } +} + +func (a *onchainRegistryDeployStrategy) RunPreDeployChecks(ctx context.Context) error { h := a.h - a.wg.Wait() - if a.initErr != nil { - return a.initErr + if err := a.waitInit(ctx); err != nil { + return err } ui.Line() ui.Dim("Verifying ownership...") if h.settings.Workflow.UserWorkflowSettings.WorkflowOwnerType == constants.WorkflowOwnerTypeMSIG { - halt, err := h.autoLinkMSIGAndExit(a.onChain) + halt, err := h.autoLinkMSIGAndExit(ctx, a.onChain) if err != nil { return fmt.Errorf("failed to check/handle MSIG owner link status: %w", err) } @@ -63,7 +77,7 @@ func (a *onchainRegistryDeployStrategy) RunPreDeployChecks() error { return errDeployHalted } } else { - if err := h.ensureOwnerLinkedOrFail(a.onChain); err != nil { + if err := h.ensureOwnerLinkedOrFail(ctx, a.onChain); err != nil { return err } } @@ -71,8 +85,8 @@ func (a *onchainRegistryDeployStrategy) RunPreDeployChecks() error { return nil } -func (a *onchainRegistryDeployStrategy) CheckWorkflowExists(workflowOwner, workflowName, workflowTag, workflowID string) (bool, *uint8, error) { - workflow, err := a.wrc.GetWorkflow(common.HexToAddress(workflowOwner), workflowName, workflowTag) +func (a *onchainRegistryDeployStrategy) CheckWorkflowExists(ctx context.Context, workflowOwner, workflowName, workflowTag, workflowID string) (bool, *uint8, error) { + workflow, err := a.wrc.GetWorkflow(ctx, common.HexToAddress(workflowOwner), workflowName, workflowTag) if err != nil { return false, nil, err } @@ -87,10 +101,10 @@ func (a *onchainRegistryDeployStrategy) CheckWorkflowExists(workflowOwner, workf return false, nil, nil } -func (a *onchainRegistryDeployStrategy) Upsert() error { +func (a *onchainRegistryDeployStrategy) Upsert(ctx context.Context) error { h := a.h - if err := checkUserDonLimitBeforeDeploy( + if err := checkUserDonLimitBeforeDeploy(ctx, a.wrc, a.wrc, common.HexToAddress(h.inputs.WorkflowOwner), @@ -104,7 +118,7 @@ func (a *onchainRegistryDeployStrategy) Upsert() error { ui.Line() ui.Dim("Preparing deployment transaction...") - if err := h.upsert(a.onChain); err != nil { + if err := h.upsert(ctx, a.onChain); err != nil { return fmt.Errorf("failed to register workflow: %w", err) } return nil diff --git a/cmd/workflow/deploy/registry_deploy_strategy_private.go b/cmd/workflow/deploy/registry_deploy_strategy_private.go index dc14b21c..132192f6 100644 --- a/cmd/workflow/deploy/registry_deploy_strategy_private.go +++ b/cmd/workflow/deploy/registry_deploy_strategy_private.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "fmt" "strings" @@ -27,14 +28,14 @@ func (a *privateRegistryDeployStrategy) ensureClient() { } } -func (a *privateRegistryDeployStrategy) RunPreDeployChecks() error { +func (a *privateRegistryDeployStrategy) RunPreDeployChecks(context.Context) error { return nil } -func (a *privateRegistryDeployStrategy) CheckWorkflowExists(_, workflowName, _, workflowID string) (bool, *uint8, error) { +func (a *privateRegistryDeployStrategy) CheckWorkflowExists(ctx context.Context, _, workflowName, _, workflowID string) (bool, *uint8, error) { a.ensureClient() - workflow, err := a.prc.GetWorkflowByName(workflowName) + workflow, err := a.prc.GetWorkflowByName(ctx, workflowName) if err == nil { if workflow.WorkflowID == workflowID { return false, nil, fmt.Errorf("workflow with id %s already exists", workflowID) @@ -48,7 +49,7 @@ func (a *privateRegistryDeployStrategy) CheckWorkflowExists(_, workflowName, _, return false, nil, err } -func (a *privateRegistryDeployStrategy) Upsert() error { +func (a *privateRegistryDeployStrategy) Upsert(ctx context.Context) error { a.ensureClient() h := a.h @@ -57,7 +58,7 @@ func (a *privateRegistryDeployStrategy) Upsert() error { ui.Line() ui.Dim(fmt.Sprintf("Registering workflow in private registry (workflowID: %s)...", input.WorkflowID)) - result, err := a.prc.UpsertWorkflowInRegistry(input) + result, err := a.prc.UpsertWorkflowInRegistry(ctx, input) if err != nil { return fmt.Errorf("failed to register workflow in private registry: %w", err) } diff --git a/cmd/workflow/hash/hash.go b/cmd/workflow/hash/hash.go index 49533870..7efdb434 100644 --- a/cmd/workflow/hash/hash.go +++ b/cmd/workflow/hash/hash.go @@ -1,6 +1,7 @@ package hash import ( + "context" "fmt" "os" "strings" @@ -62,7 +63,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { DerivedOwner: runtimeContext.DerivedWorkflowOwner, } - return Execute(inputs) + return Execute(cmd.Context(), inputs) }, } @@ -81,8 +82,8 @@ func New(runtimeContext *runtime.Context) *cobra.Command { return hashCmd } -func Execute(inputs Inputs) error { - rawBinary, err := loadBinary(inputs.WasmPath, inputs.WorkflowPath, inputs.SkipTypeChecks) +func Execute(ctx context.Context, inputs Inputs) error { + rawBinary, err := loadBinary(ctx, inputs.WasmPath, inputs.WorkflowPath, inputs.SkipTypeChecks) if err != nil { return err } @@ -92,7 +93,7 @@ func Execute(inputs Inputs) error { return fmt.Errorf("failed to compress binary: %w", err) } - config, err := loadConfig(inputs.ConfigPath) + config, err := loadConfig(ctx, inputs.ConfigPath) if err != nil { return err } @@ -190,11 +191,11 @@ func isPrivateRegistryID(deploymentRegistry string) bool { return strings.EqualFold(deploymentRegistry, "private") } -func loadBinary(wasmFlag, workflowPathFromSettings string, skipTypeChecks bool) ([]byte, error) { +func loadBinary(ctx context.Context, wasmFlag, workflowPathFromSettings string, skipTypeChecks bool) ([]byte, error) { if wasmFlag != "" { if cmdcommon.IsURL(wasmFlag) { ui.Dim("Fetching WASM binary from URL...") - data, err := cmdcommon.FetchURL(wasmFlag) + data, err := cmdcommon.FetchURL(ctx, wasmFlag) if err != nil { return nil, fmt.Errorf("failed to fetch WASM from URL: %w", err) } @@ -221,7 +222,7 @@ func loadBinary(wasmFlag, workflowPathFromSettings string, skipTypeChecks bool) spinner := ui.NewSpinner() spinner.Start("Compiling workflow...") - wasmBytes, err := cmdcommon.CompileWorkflowToWasm(resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ + wasmBytes, err := cmdcommon.CompileWorkflowToWasm(ctx, resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ StripSymbols: true, SkipTypeChecks: skipTypeChecks, }) @@ -235,13 +236,13 @@ func loadBinary(wasmFlag, workflowPathFromSettings string, skipTypeChecks bool) return wasmBytes, nil } -func loadConfig(configPath string) ([]byte, error) { +func loadConfig(ctx context.Context, configPath string) ([]byte, error) { if configPath == "" { return nil, nil } if cmdcommon.IsURL(configPath) { ui.Dim("Fetching config from URL...") - data, err := cmdcommon.FetchURL(configPath) + data, err := cmdcommon.FetchURL(ctx, configPath) if err != nil { return nil, fmt.Errorf("failed to fetch config from URL: %w", err) } diff --git a/cmd/workflow/hash/hash_test.go b/cmd/workflow/hash/hash_test.go index 0ed08a82..14809402 100644 --- a/cmd/workflow/hash/hash_test.go +++ b/cmd/workflow/hash/hash_test.go @@ -1,6 +1,7 @@ package hash import ( + "context" "crypto/sha256" "encoding/hex" "io" @@ -80,7 +81,7 @@ func TestExecute_WithForUser(t *testing.T) { WorkflowName: "test-workflow", } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.NoError(t, err) } @@ -94,7 +95,7 @@ func TestExecute_WithoutForUser_UsesPrivateKey(t *testing.T) { PrivateKey: testPrivateKey, } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.NoError(t, err) } @@ -107,7 +108,7 @@ func TestExecute_WithoutForUser_NoKey_Errors(t *testing.T) { WorkflowName: "test-workflow", } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.Error(t, err) assert.Contains(t, err.Error(), "--public_key") } @@ -173,7 +174,7 @@ func TestExecute_HashesAreDeterministic(t *testing.T) { "workflow ID should start with version byte 00") // Running Execute should succeed (hashes are printed via ui, verified above) - err = Execute(inputs) + err = Execute(context.Background(), inputs) require.NoError(t, err) } @@ -187,7 +188,7 @@ func TestExecute_EmptyConfig(t *testing.T) { WorkflowName: "test-workflow", } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.NoError(t, err) } @@ -201,7 +202,7 @@ func TestExecute_OffChainRequiresPublicKey(t *testing.T) { RegistryType: settings.RegistryTypeOffChain, } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.Error(t, err) assert.Contains(t, err.Error(), "--public_key") } @@ -218,7 +219,7 @@ func TestExecute_OffChainUsesPublicKey(t *testing.T) { DerivedOwner: testDerivedOwner, } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.NoError(t, err) } @@ -233,7 +234,7 @@ func TestExecute_OffChainUsesDerivedOwner(t *testing.T) { DerivedOwner: testDerivedOwner, } - err := Execute(inputs) + err := Execute(context.Background(), inputs) require.NoError(t, err) } diff --git a/cmd/workflow/pause/registry_pause_strategy_onchain.go b/cmd/workflow/pause/registry_pause_strategy_onchain.go index f038d6aa..5818230e 100644 --- a/cmd/workflow/pause/registry_pause_strategy_onchain.go +++ b/cmd/workflow/pause/registry_pause_strategy_onchain.go @@ -1,6 +1,7 @@ package pause import ( + "context" "encoding/hex" "fmt" "math/big" @@ -164,7 +165,7 @@ func (a *onchainRegistryPauseStrategy) Pause() error { func fetchAllWorkflows( wrc interface { - GetWorkflowListByOwnerAndName(owner common.Address, workflowName string, start, limit *big.Int) ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) + GetWorkflowListByOwnerAndName(ctx context.Context, owner common.Address, workflowName string, start, limit *big.Int) ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) }, owner common.Address, name string, @@ -177,7 +178,7 @@ func fetchAllWorkflows( ) for { - list, err := wrc.GetWorkflowListByOwnerAndName(owner, name, start, limit) + list, err := wrc.GetWorkflowListByOwnerAndName(context.Background(), owner, name, start, limit) if err != nil { return nil, err } diff --git a/cmd/workflow/pause/registry_pause_strategy_private.go b/cmd/workflow/pause/registry_pause_strategy_private.go index f0f05534..2324c7be 100644 --- a/cmd/workflow/pause/registry_pause_strategy_private.go +++ b/cmd/workflow/pause/registry_pause_strategy_private.go @@ -1,6 +1,7 @@ package pause import ( + "context" "fmt" "github.com/smartcontractkit/cre-cli/internal/client/graphqlclient" @@ -32,7 +33,7 @@ func (a *privateRegistryPauseStrategy) Pause() error { ui.Dim(fmt.Sprintf("Fetching workflow to pause... Name=%s", workflowName)) - workflow, err := a.prc.GetWorkflowByName(workflowName) + workflow, err := a.prc.GetWorkflowByName(context.Background(), workflowName) if err != nil { return fmt.Errorf("failed to get workflow: %w", err) } diff --git a/cmd/workflow/simulate/simulate.go b/cmd/workflow/simulate/simulate.go index 9f374991..7b0d340c 100644 --- a/cmd/workflow/simulate/simulate.go +++ b/cmd/workflow/simulate/simulate.go @@ -90,7 +90,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { if err != nil { return err } - return handler.Execute(inputs) + return handler.Execute(cmd.Context(), inputs) }, } @@ -252,14 +252,14 @@ func (h *handler) ValidateInputs(inputs Inputs) error { return nil } -func (h *handler) Execute(inputs Inputs) error { +func (h *handler) Execute(ctx context.Context, inputs Inputs) error { var wasmFileBinary []byte var err error if inputs.WasmPath != "" { if cmdcommon.IsURL(inputs.WasmPath) { ui.Dim("Fetching WASM binary from URL...") - wasmFileBinary, err = cmdcommon.FetchURL(inputs.WasmPath) + wasmFileBinary, err = cmdcommon.FetchURL(ctx, inputs.WasmPath) if err != nil { return fmt.Errorf("failed to fetch WASM from URL: %w", err) } @@ -298,7 +298,7 @@ func (h *handler) Execute(inputs Inputs) error { spinner := ui.NewSpinner() spinner.Start("Compiling workflow...") - wasmFileBinary, err = cmdcommon.CompileWorkflowToWasm(resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ + wasmFileBinary, err = cmdcommon.CompileWorkflowToWasm(ctx, resolvedWorkflowPath, cmdcommon.WorkflowCompileOptions{ StripSymbols: false, SkipTypeChecks: inputs.SkipTypeChecks, }) @@ -343,7 +343,7 @@ func (h *handler) Execute(inputs Inputs) error { var config []byte if cmdcommon.IsURL(inputs.ConfigPath) { ui.Dim("Fetching config from URL...") - config, err = cmdcommon.FetchURL(inputs.ConfigPath) + config, err = cmdcommon.FetchURL(ctx, inputs.ConfigPath) if err != nil { return fmt.Errorf("failed to fetch config from URL: %w", err) } diff --git a/cmd/workflow/simulate/simulate_test.go b/cmd/workflow/simulate/simulate_test.go index 1d24423a..24d7b057 100644 --- a/cmd/workflow/simulate/simulate_test.go +++ b/cmd/workflow/simulate/simulate_test.go @@ -1,6 +1,7 @@ package simulate import ( + "context" "encoding/base64" "fmt" "io" @@ -89,7 +90,7 @@ func TestBlankWorkflowSimulation(t *testing.T) { require.NoError(t, err) // Execute the simulation. We expect this to compile the workflow and run the simulator successfully. - err = handler.Execute(inputs) + err = handler.Execute(context.Background(), inputs) require.NoError(t, err, "Execute should not return an error") } diff --git a/internal/client/privateregistryclient/privateregistryclient.go b/internal/client/privateregistryclient/privateregistryclient.go index 5e8f084b..05882585 100644 --- a/internal/client/privateregistryclient/privateregistryclient.go +++ b/internal/client/privateregistryclient/privateregistryclient.go @@ -29,8 +29,11 @@ func (c *Client) SetServiceTimeout(timeout time.Duration) { c.serviceTimeout = timeout } -func (c *Client) CreateServiceContextWithTimeout() (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), c.serviceTimeout) //nolint:gosec // G118 -- cancel is deferred by callers +func (c *Client) CreateServiceContextWithTimeout(parent context.Context) (context.Context, context.CancelFunc) { + if parent == nil { + parent = context.Background() + } + return context.WithTimeout(parent, c.serviceTimeout) //nolint:gosec // G118 -- cancel is deferred by callers } type OffchainWorkflow struct { @@ -131,7 +134,7 @@ type GetOffchainWorkflowByNameResponse struct { Workflow OffchainWorkflow `json:"workflow"` } -func (c *Client) GetWorkflowByName(workflowName string) (OffchainWorkflow, error) { +func (c *Client) GetWorkflowByName(ctx context.Context, workflowName string) (OffchainWorkflow, error) { if workflowName == "" { return OffchainWorkflow{}, fmt.Errorf("workflowName is required") } @@ -165,10 +168,10 @@ query GetOffchainWorkflowByName($request: GetOffchainWorkflowByNameRequest!) { GetOffchainWorkflowByName GetOffchainWorkflowByNameResponse `json:"getOffchainWorkflowByName"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + serviceCtx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() - if err := c.graphql.Execute(ctx, req, &container); err != nil { + if err := c.graphql.Execute(serviceCtx, req, &container); err != nil { return OffchainWorkflow{}, fmt.Errorf("get workflow by name in registry: %w", err) } @@ -178,7 +181,7 @@ query GetOffchainWorkflowByName($request: GetOffchainWorkflowByNameRequest!) { return container.GetOffchainWorkflowByName.Workflow, nil } -func (c *Client) UpsertWorkflowInRegistry(workflow OffchainWorkflowInput) (OffchainWorkflow, error) { +func (c *Client) UpsertWorkflowInRegistry(ctx context.Context, workflow OffchainWorkflowInput) (OffchainWorkflow, error) { if err := validateUpsertWorkflowInput(workflow); err != nil { return OffchainWorkflow{}, err } @@ -209,10 +212,10 @@ mutation UpsertOffchainWorkflow($request: UpsertOffchainWorkflowRequest!) { UpsertOffchainWorkflow UpsertOffchainWorkflowResponse `json:"upsertOffchainWorkflow"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + serviceCtx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() - if err := c.graphql.Execute(ctx, req, &container); err != nil { + if err := c.graphql.Execute(serviceCtx, req, &container); err != nil { return OffchainWorkflow{}, fmt.Errorf("upsert workflow in registry: %w", err) } @@ -253,7 +256,7 @@ mutation PauseOffchainWorkflow($request: PauseOffchainWorkflowRequest!) { PauseOffchainWorkflow PauseOffchainWorkflowResponse `json:"pauseOffchainWorkflow"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + ctx, cancel := c.CreateServiceContextWithTimeout(context.Background()) defer cancel() if err := c.graphql.Execute(ctx, req, &container); err != nil { @@ -297,7 +300,7 @@ mutation ActivateOffchainWorkflow($request: ActivateOffchainWorkflowRequest!) { ActivateOffchainWorkflow ActivateOffchainWorkflowResponse `json:"activateOffchainWorkflow"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + ctx, cancel := c.CreateServiceContextWithTimeout(context.Background()) defer cancel() if err := c.graphql.Execute(ctx, req, &container); err != nil { @@ -329,7 +332,7 @@ mutation DeleteOffchainWorkflow($request: DeleteOffchainWorkflowRequest!) { DeleteOffchainWorkflow DeleteOffchainWorkflowResponse `json:"deleteOffchainWorkflow"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + ctx, cancel := c.CreateServiceContextWithTimeout(context.Background()) defer cancel() if err := c.graphql.Execute(ctx, req, &container); err != nil { diff --git a/internal/client/privateregistryclient/privateregistryclient_test.go b/internal/client/privateregistryclient/privateregistryclient_test.go index 578fbd06..59fce799 100644 --- a/internal/client/privateregistryclient/privateregistryclient_test.go +++ b/internal/client/privateregistryclient/privateregistryclient_test.go @@ -1,6 +1,7 @@ package privateregistryclient import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -147,7 +148,7 @@ func TestUpsertWorkflowInRegistry(t *testing.T) { configURL := "s3://config" tag := "v1" attributes := "{\"region\":\"us-east-1\"}" - result, err := client.UpsertWorkflowInRegistry(OffchainWorkflowInput{ + result, err := client.UpsertWorkflowInRegistry(context.Background(), OffchainWorkflowInput{ WorkflowID: "wf-123", Status: WorkflowStatusActive, WorkflowName: "registry-workflow", @@ -187,7 +188,7 @@ func TestUpsertWorkflowInRegistry_GQLError(t *testing.T) { defer srv.Close() client := newTestPrivateRegistryClient(t, srv.URL) - _, err := client.UpsertWorkflowInRegistry(OffchainWorkflowInput{ + _, err := client.UpsertWorkflowInRegistry(context.Background(), OffchainWorkflowInput{ WorkflowID: "wf-123", Status: WorkflowStatusActive, WorkflowName: "registry-workflow", @@ -236,7 +237,7 @@ func TestGetWorkflowByName(t *testing.T) { defer srv.Close() client := newTestPrivateRegistryClient(t, srv.URL) - result, err := client.GetWorkflowByName("registry-workflow") + result, err := client.GetWorkflowByName(context.Background(), "registry-workflow") require.NoError(t, err) assert.Contains(t, capturedQuery, "query GetOffchainWorkflowByName") @@ -262,7 +263,7 @@ func TestGetWorkflowByName_GQLError(t *testing.T) { defer srv.Close() client := newTestPrivateRegistryClient(t, srv.URL) - _, err := client.GetWorkflowByName("registry-workflow") + _, err := client.GetWorkflowByName(context.Background(), "registry-workflow") require.Error(t, err) assert.Contains(t, err.Error(), "get workflow by name in registry") assert.Contains(t, err.Error(), "cre api error: workflow not found") @@ -273,7 +274,7 @@ func TestGetWorkflowByName_EmptyName(t *testing.T) { logger := testutil.NewTestLogger() client := New(nil, logger) - _, err := client.GetWorkflowByName("") + _, err := client.GetWorkflowByName(context.Background(), "") require.EqualError(t, err, "workflowName is required") } @@ -414,7 +415,7 @@ func TestCreateServiceContextWithTimeout(t *testing.T) { client := New(nil, &logger) client.SetServiceTimeout(150 * time.Millisecond) - ctx, cancel := client.CreateServiceContextWithTimeout() + ctx, cancel := client.CreateServiceContextWithTimeout(context.Background()) defer cancel() deadline, ok := ctx.Deadline() diff --git a/internal/client/storageclient/storageclient.go b/internal/client/storageclient/storageclient.go index 5eb9a829..8a6c4b90 100644 --- a/internal/client/storageclient/storageclient.go +++ b/internal/client/storageclient/storageclient.go @@ -69,15 +69,21 @@ func (c *Client) SetHTTPTimeout(timeout time.Duration) { c.httpTimeout = timeout } -func (c *Client) CreateServiceContextWithTimeout() (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), c.serviceTimeout) //nolint:gosec // G118 -- cancel is deferred by all callers +func (c *Client) CreateServiceContextWithTimeout(parent context.Context) (context.Context, context.CancelFunc) { + if parent == nil { + parent = context.Background() + } + return context.WithTimeout(parent, c.serviceTimeout) //nolint:gosec // G118 -- cancel is deferred by all callers } -func (c *Client) CreateHttpContextWithTimeout() (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), c.httpTimeout) //nolint:gosec // G118 -- cancel is deferred by all callers +func (c *Client) CreateHttpContextWithTimeout(parent context.Context) (context.Context, context.CancelFunc) { + if parent == nil { + parent = context.Background() + } + return context.WithTimeout(parent, c.httpTimeout) //nolint:gosec // G118 -- cancel is deferred by all callers } -func (c *Client) GeneratePostUrlForArtifact(workflowId string, artifactType ArtifactType, content []byte) (GeneratePresignedPostUrlForArtifactResponse, error) { +func (c *Client) GeneratePostUrlForArtifact(ctx context.Context, workflowId string, artifactType ArtifactType, content []byte) (GeneratePresignedPostUrlForArtifactResponse, error) { const mutation = ` mutation GeneratePresignedPostUrlForArtifact($artifact: GeneratePresignedPostUrlRequest!) { generatePresignedPostUrlForArtifact(artifact: $artifact) { @@ -102,11 +108,11 @@ mutation GeneratePresignedPostUrlForArtifact($artifact: GeneratePresignedPostUrl GeneratePresignedPostUrlForArtifact GeneratePresignedPostUrlForArtifactResponse `json:"generatePresignedPostUrlForArtifact"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + serviceCtx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() if err := c.graphql. - Execute(ctx, req, &container); err != nil { + Execute(serviceCtx, req, &container); err != nil { return GeneratePresignedPostUrlForArtifactResponse{}, err } @@ -116,7 +122,7 @@ mutation GeneratePresignedPostUrlForArtifact($artifact: GeneratePresignedPostUrl return container.GeneratePresignedPostUrlForArtifact, nil } -func (c *Client) GenerateUnsignedGetUrlForArtifact(workflowId string, artifactType ArtifactType) (GenerateUnsignedGetUrlForArtifactResponse, error) { +func (c *Client) GenerateUnsignedGetUrlForArtifact(ctx context.Context, workflowId string, artifactType ArtifactType) (GenerateUnsignedGetUrlForArtifactResponse, error) { const mutation = ` mutation GenerateUnsignedGetUrlForArtifact($artifact: GenerateUnsignedGetUrlRequest!) { generateUnsignedGetUrlForArtifact(artifact: $artifact) { @@ -134,11 +140,11 @@ mutation GenerateUnsignedGetUrlForArtifact($artifact: GenerateUnsignedGetUrlRequ GenerateUnsignedGetUrlForArtifact GenerateUnsignedGetUrlForArtifactResponse `json:"generateUnsignedGetUrlForArtifact"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + serviceCtx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() if err := c.graphql. - Execute(ctx, req, &container); err != nil { + Execute(serviceCtx, req, &container); err != nil { return GenerateUnsignedGetUrlForArtifactResponse{}, err } @@ -154,7 +160,7 @@ func calculateContentHash(content []byte) string { return contentHash } -func (c *Client) UploadToOrigin(g GeneratePresignedPostUrlForArtifactResponse, content []byte, contentType string) error { +func (c *Client) UploadToOrigin(ctx context.Context, g GeneratePresignedPostUrlForArtifactResponse, content []byte, contentType string) error { c.log.Debug().Str("URL", g.PresignedPostURL).Msg("Uploading content to origin") var b bytes.Buffer @@ -197,10 +203,10 @@ func (c *Client) UploadToOrigin(g GeneratePresignedPostUrlForArtifactResponse, c return err } - ctx, cancel := c.CreateHttpContextWithTimeout() + httpCtx, cancel := c.CreateHttpContextWithTimeout(ctx) defer cancel() - httpReq, err := http.NewRequestWithContext(ctx, "POST", g.PresignedPostURL, &b) + httpReq, err := http.NewRequestWithContext(httpCtx, "POST", g.PresignedPostURL, &b) if err != nil { c.log.Error().Err(err).Msg("Failed to create HTTP request") return err @@ -231,6 +237,7 @@ func (c *Client) UploadToOrigin(g GeneratePresignedPostUrlForArtifactResponse, c } func (c *Client) UploadArtifactWithRetriesAndGetURL( + ctx context.Context, workflowID string, artifactType ArtifactType, content []byte, @@ -251,7 +258,7 @@ func (c *Client) UploadArtifactWithRetriesAndGetURL( err := retry.Do( func() error { var err error - g, err = c.GeneratePostUrlForArtifact(workflowID, artifactType, content) + g, err = c.GeneratePostUrlForArtifact(ctx, workflowID, artifactType, content) if err != nil { if strings.Contains(err.Error(), "already exists") { shouldUpload = false @@ -264,6 +271,7 @@ func (c *Client) UploadArtifactWithRetriesAndGetURL( }, retry.Attempts(3), retry.LastErrorOnly(true), + retry.Context(ctx), ) if err != nil { c.log.Error().Err(err).Msg("Failed to generate presigned post URL for artifact") @@ -276,10 +284,11 @@ func (c *Client) UploadArtifactWithRetriesAndGetURL( if shouldUpload { err = retry.Do( func() error { - return c.UploadToOrigin(g, content, contentType) + return c.UploadToOrigin(ctx, g, content, contentType) }, retry.Attempts(3), retry.LastErrorOnly(true), + retry.Context(ctx), ) if err != nil { c.log.Error().Err(err).Msg("Failed to upload content to origin") @@ -290,7 +299,7 @@ func (c *Client) UploadArtifactWithRetriesAndGetURL( var g2 GenerateUnsignedGetUrlForArtifactResponse err = retry.Do( func() error { - g2, err = c.GenerateUnsignedGetUrlForArtifact(workflowID, artifactType) + g2, err = c.GenerateUnsignedGetUrlForArtifact(ctx, workflowID, artifactType) if err != nil { return fmt.Errorf("generate unsigned get url: %w", err) } @@ -298,6 +307,7 @@ func (c *Client) UploadArtifactWithRetriesAndGetURL( }, retry.Attempts(3), retry.LastErrorOnly(true), + retry.Context(ctx), ) if err != nil { c.log.Error().Err(err).Msg("Failed to generate unsigned get URL for artifact") diff --git a/internal/testutil/workflow/test_workflow.go b/internal/testutil/workflow/test_workflow.go index 5c43c1cd..7451fd06 100644 --- a/internal/testutil/workflow/test_workflow.go +++ b/internal/testutil/workflow/test_workflow.go @@ -1,6 +1,7 @@ package workflowtest import ( + "context" "testing" "github.com/ethereum/go-ethereum/common" @@ -26,6 +27,6 @@ func RegisterWorkflow(t *testing.T, wrc *client.WorkflowRegistryV2Client, workfl DonFamily: "1", } - _, err = wrc.UpsertWorkflow(params) + _, err = wrc.UpsertWorkflow(context.Background(), params) require.NoError(t, err, "Failed to register workflow") }