Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/account/link_key/link_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func (h *handler) checkIfAlreadyLinked() (bool, error) {
ownerAddr := common.HexToAddress(h.settings.Workflow.UserWorkflowSettings.WorkflowOwnerAddress)
ui.Dim("Checking existing registrations...")

linked, err := h.wrc.IsOwnerLinked(ownerAddr)
linked, err := h.wrc.IsOwnerLinked(context.Background(), ownerAddr)
if err != nil {
return false, fmt.Errorf("failed to check owner link status: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/account/unlink_key/unlink_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func (h *handler) unlinkOwner(owner string, resp initiateUnlinkingResponse) erro
func (h *handler) checkIfAlreadyLinked() (bool, error) {
ownerAddr := common.HexToAddress(h.settings.Workflow.UserWorkflowSettings.WorkflowOwnerAddress)

linked, err := h.wrc.IsOwnerLinked(ownerAddr)
linked, err := h.wrc.IsOwnerLinked(context.Background(), ownerAddr)
if err != nil {
return false, fmt.Errorf("failed to check owner link status: %w", err)
}
Expand Down
7 changes: 4 additions & 3 deletions cmd/client/tx.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -119,7 +120,7 @@ type RawTx struct {
// return txOpts, nil
//}

func (c *TxClient) executeTransactionByTxType(txFn func(opts *bind.TransactOpts) (*types.Transaction, error), funName string, validationEvent string, args ...any) (TxOutput, error) {
func (c *TxClient) executeTransactionByTxType(ctx context.Context, txFn func(opts *bind.TransactOpts) (*types.Transaction, error), funName string, validationEvent string, args ...any) (TxOutput, error) {
switch c.config.TxType {
case Regular:
simulateTx, err := txFn(cmdCommon.SimTransactOpts())
Expand All @@ -138,7 +139,7 @@ func (c *TxClient) executeTransactionByTxType(txFn func(opts *bind.TransactOpts)
Value: simulateTx.Value(),
Data: simulateTx.Data(),
}
estimatedGas, gasErr := c.EthClient.Client.EstimateGas(c.EthClient.Context, msg)
estimatedGas, gasErr := c.EthClient.Client.EstimateGas(ctx, msg)
if gasErr != nil {
c.Logger.Warn().Err(gasErr).Msg("Failed to estimate gas usage")
}
Expand All @@ -159,7 +160,7 @@ func (c *TxClient) executeTransactionByTxType(txFn func(opts *bind.TransactOpts)

// Calculate and print total cost for sending the transaction on-chain
if gasErr == nil {
gasPriceWei, gasPriceErr := c.EthClient.Client.SuggestGasPrice(c.EthClient.Context)
gasPriceWei, gasPriceErr := c.EthClient.Client.SuggestGasPrice(ctx)
if gasPriceErr != nil {
Comment on lines 123 to 164
c.Logger.Warn().Err(gasPriceErr).Msg("Failed to fetch gas price")
} else {
Expand Down
37 changes: 23 additions & 14 deletions cmd/client/workflow_registry_v2_client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"encoding/hex"
"errors"
"fmt"
Expand All @@ -24,6 +25,14 @@ type workflowRegistryV2Contract interface {
IsRequestAllowlisted(opts *bind.CallOpts, owner common.Address, requestDigest [32]byte) (bool, error)
}

func (wrc *WorkflowRegistryV2Client) callOpts(ctx context.Context) *bind.CallOpts {
opts := wrc.EthClient.NewCallOpts()
if ctx != nil {
opts.Context = ctx
}
return opts
}

type WorkflowRegistryV2Client struct {
TxClient
ContractAddress common.Address
Expand Down Expand Up @@ -73,7 +82,7 @@ func (wrc *WorkflowRegistryV2Client) LinkOwner(validityTimestamp *big.Int, proof
txFn := func(opts *bind.TransactOpts) (*types.Transaction, error) {
return contract.LinkOwner(opts, validityTimestamp, proof, signature)
}
txOut, err := wrc.executeTransactionByTxType(txFn, "LinkOwner", "OwnershipLinkUpdated", validityTimestamp, proof, signature)
txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "LinkOwner", "OwnershipLinkUpdated", validityTimestamp, proof, signature)
if err != nil {
wrc.Logger.Error().
Str("contract", contract.Address().Hex()).
Expand Down Expand Up @@ -105,7 +114,7 @@ func (wrc *WorkflowRegistryV2Client) UnlinkOwner(owner common.Address, validityT
txFn := func(opts *bind.TransactOpts) (*types.Transaction, error) {
return contract.UnlinkOwner(opts, owner, validityTimestamp, signature)
}
txOut, err := wrc.executeTransactionByTxType(txFn, "UnlinkOwner", "OwnershipLinkUpdated", owner, validityTimestamp, signature)
txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "UnlinkOwner", "OwnershipLinkUpdated", owner, validityTimestamp, signature)
if err != nil {
wrc.Logger.Error().
Str("contract", contract.Address().Hex()).
Expand Down Expand Up @@ -414,7 +423,7 @@ func (wrc *WorkflowRegistryV2Client) IsAllowedSigner(signer common.Address) (boo
return ok, err
}

func (wrc *WorkflowRegistryV2Client) IsOwnerLinked(owner common.Address) (bool, error) {
func (wrc *WorkflowRegistryV2Client) IsOwnerLinked(ctx context.Context, owner common.Address) (bool, error) {
contract, err := workflow_registry_v2_wrapper.NewWorkflowRegistry(wrc.ContractAddress, wrc.EthClient.Client)
if err != nil {
wrc.Logger.Error().
Expand All @@ -424,7 +433,7 @@ func (wrc *WorkflowRegistryV2Client) IsOwnerLinked(owner common.Address) (bool,
}

result, err := callContractMethodV2(wrc, func() (bool, error) {
return contract.IsOwnerLinked(wrc.EthClient.NewCallOpts(), owner)
return contract.IsOwnerLinked(wrc.callOpts(ctx), owner)
})
if err != nil {
wrc.Logger.Error().
Expand Down Expand Up @@ -468,7 +477,7 @@ func (wrc *WorkflowRegistryV2Client) TypeAndVersion() (string, error) {
return tv, err
}

func (wrc *WorkflowRegistryV2Client) UpsertWorkflow(params RegisterWorkflowV2Parameters) (*TxOutput, error) {
func (wrc *WorkflowRegistryV2Client) UpsertWorkflow(ctx context.Context, params RegisterWorkflowV2Parameters) (*TxOutput, error) {
contract, err := workflow_registry_v2_wrapper.NewWorkflowRegistry(wrc.ContractAddress, wrc.EthClient.Client)
if err != nil {
wrc.Logger.Error().
Expand All @@ -492,7 +501,7 @@ func (wrc *WorkflowRegistryV2Client) UpsertWorkflow(params RegisterWorkflowV2Par
params.KeepAlive,
)
}
txOut, err := wrc.executeTransactionByTxType(txFn, "UpsertWorkflow", "WorkflowRegistered|WorkflowUpdated",
txOut, err := wrc.executeTransactionByTxType(ctx, txFn, "UpsertWorkflow", "WorkflowRegistered|WorkflowUpdated",
params.WorkflowName,
params.Tag,
params.WorkflowID,
Expand All @@ -513,31 +522,31 @@ func (wrc *WorkflowRegistryV2Client) UpsertWorkflow(params RegisterWorkflowV2Par
return &txOut, nil
}

func (wrc *WorkflowRegistryV2Client) GetWorkflow(owner common.Address, workflowName, tag string) (workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) {
func (wrc *WorkflowRegistryV2Client) GetWorkflow(ctx context.Context, owner common.Address, workflowName, tag string) (workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) {
contract, err := workflow_registry_v2_wrapper.NewWorkflowRegistry(wrc.ContractAddress, wrc.EthClient.Client)
if err != nil {
wrc.Logger.Error().Err(err).Msg("Failed to connect for GetWorkflow")
return workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView{}, err
}

result, err := callContractMethodV2(wrc, func() (workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) {
return contract.GetWorkflow(wrc.EthClient.NewCallOpts(), owner, workflowName, tag)
return contract.GetWorkflow(wrc.callOpts(ctx), owner, workflowName, tag)
})
if err != nil {
wrc.Logger.Error().Err(err).Msg("GetWorkflow call failed")
}
return result, err
}

func (wrc *WorkflowRegistryV2Client) GetWorkflowListByOwnerAndName(owner common.Address, workflowName string, start, limit *big.Int) ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) {
func (wrc *WorkflowRegistryV2Client) GetWorkflowListByOwnerAndName(ctx context.Context, owner common.Address, workflowName string, start, limit *big.Int) ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) {
contract, err := workflow_registry_v2_wrapper.NewWorkflowRegistry(wrc.ContractAddress, wrc.EthClient.Client)
if err != nil {
wrc.Logger.Error().Err(err).Msg("Failed to connect for GetWorkflowListByOwnerAndName")
return nil, err
}

result, err := callContractMethodV2(wrc, func() ([]workflow_registry_v2_wrapper.WorkflowRegistryWorkflowMetadataView, error) {
return contract.GetWorkflowListByOwnerAndName(wrc.EthClient.NewCallOpts(), owner, workflowName, start, limit)
return contract.GetWorkflowListByOwnerAndName(wrc.callOpts(ctx), owner, workflowName, start, limit)
})
if err != nil {
wrc.Logger.Error().Err(err).Msg("GetWorkflowListByOwnerAndName call failed")
Expand Down Expand Up @@ -619,7 +628,7 @@ func (wrc *WorkflowRegistryV2Client) DeleteWorkflow(workflowID [32]byte) (*TxOut
txFn := func(opts *bind.TransactOpts) (*types.Transaction, error) {
return contract.DeleteWorkflow(opts, workflowID)
}
txOut, err := wrc.executeTransactionByTxType(txFn, "DeleteWorkflow", "WorkflowDeleted", workflowID)
txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "DeleteWorkflow", "WorkflowDeleted", workflowID)
if err != nil {
wrc.Logger.Error().
Str("contract", contract.Address().Hex()).
Expand All @@ -646,7 +655,7 @@ func (wrc *WorkflowRegistryV2Client) BatchPauseWorkflows(workflowIDs [][32]byte)
workflowIDs,
)
}
txOut, err := wrc.executeTransactionByTxType(txFn, "BatchPauseWorkflows", "WorkflowStatusUpdated", workflowIDs)
txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "BatchPauseWorkflows", "WorkflowStatusUpdated", workflowIDs)
if err != nil {
wrc.Logger.Error().
Str("contract", contract.Address().Hex()).
Expand All @@ -670,7 +679,7 @@ func (wrc *WorkflowRegistryV2Client) ActivateWorkflow(workflowID [32]byte, donFa
txFn := func(opts *bind.TransactOpts) (*types.Transaction, error) {
return contract.ActivateWorkflow(opts, workflowID, donFamily)
}
txOut, err := wrc.executeTransactionByTxType(txFn, "ActivateWorkflow", "WorkflowActivated", workflowID, donFamily)
txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "ActivateWorkflow", "WorkflowActivated", workflowID, donFamily)
if err != nil {
wrc.Logger.Error().
Str("contract", contract.Address().Hex()).
Expand Down Expand Up @@ -772,7 +781,7 @@ func (wrc *WorkflowRegistryV2Client) AllowlistRequest(requestDigest [32]byte, du
txFn := func(opts *bind.TransactOpts) (*types.Transaction, error) {
return contract.AllowlistRequest(opts, requestDigest, deadline)
}
txOut, err := wrc.executeTransactionByTxType(txFn, "AllowlistRequest", "RequestAllowlisted", requestDigest, duration)
txOut, err := wrc.executeTransactionByTxType(context.Background(), txFn, "AllowlistRequest", "RequestAllowlisted", requestDigest, duration)
if err != nil {
wrc.Logger.Error().
Str("contract", wrc.ContractAddress.Hex()).
Expand Down
15 changes: 8 additions & 7 deletions cmd/common/compile.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -33,15 +34,15 @@ type WorkflowCompileOptions struct {
}

// getBuildCmd returns a single step that builds the workflow and returns the WASM bytes.
func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCompileOptions) (func() ([]byte, error), error) {
func getBuildCmd(ctx context.Context, workflowRootFolder, mainFile, language string, opts WorkflowCompileOptions) (func() ([]byte, error), error) {
tmpPath := filepath.Join(workflowRootFolder, ".cre_build_tmp.wasm")
switch language {
case constants.WorkflowLanguageTypeScript:
args := []string{"cre-compile", mainFile, tmpPath}
if opts.SkipTypeChecks {
args = append(args, SkipTypeChecksFlag)
}
cmd := exec.Command("bun", args...)
cmd := exec.CommandContext(ctx, "bun", args...)
cmd.Dir = workflowRootFolder
return func() ([]byte, error) {
out, err := cmd.CombinedOutput()
Expand All @@ -67,7 +68,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom
if opts.StripSymbols {
ldflags = "-buildid= -w -s"
}
cmd := exec.Command(
cmd := exec.CommandContext(ctx,
"go", "build",
"-o", tmpPath,
"-trimpath",
Expand All @@ -92,7 +93,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom
if err != nil {
return nil, err
}
makeCmd := exec.Command("make", "build")
makeCmd := exec.CommandContext(ctx, "make", "build")
makeCmd.Dir = makeRoot
builtPath := filepath.Join(makeRoot, defaultWasmOutput)
return func() ([]byte, error) {
Expand All @@ -108,7 +109,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom
if opts.StripSymbols {
ldflags = "-buildid= -w -s"
}
cmd := exec.Command(
cmd := exec.CommandContext(ctx,
"go", "build",
"-o", tmpPath,
"-trimpath",
Expand All @@ -135,7 +136,7 @@ func getBuildCmd(workflowRootFolder, mainFile, language string, opts WorkflowCom
// opts.StripSymbols: for Go builds, true strips debug symbols (deploy); false keeps them (simulate).
// opts.SkipTypeChecks: for TypeScript, passes SkipTypeChecksFlag to cre-compile.
// For custom Makefile WASM builds, StripSymbols and SkipTypeChecks have no effect.
func CompileWorkflowToWasm(workflowPath string, opts WorkflowCompileOptions) ([]byte, error) {
func CompileWorkflowToWasm(ctx context.Context, workflowPath string, opts WorkflowCompileOptions) ([]byte, error) {
workflowRootFolder, workflowMainFile, err := WorkflowPathRootAndMain(workflowPath)
if err != nil {
return nil, fmt.Errorf("workflow path: %w", err)
Expand Down Expand Up @@ -167,7 +168,7 @@ func CompileWorkflowToWasm(workflowPath string, opts WorkflowCompileOptions) ([]
return nil, fmt.Errorf("unsupported workflow language for file %s", workflowMainFile)
}

buildStep, err := getBuildCmd(workflowRootFolder, workflowMainFile, language, opts)
buildStep, err := getBuildCmd(ctx, workflowRootFolder, workflowMainFile, language, opts)
if err != nil {
Comment on lines 139 to 172
return nil, err
}
Expand Down
17 changes: 9 additions & 8 deletions cmd/common/compile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package common

import (
"bytes"
"context"
"io"
"os"
"os/exec"
Expand Down Expand Up @@ -47,29 +48,29 @@ func TestFindMakefileRoot(t *testing.T) {
func TestCompileWorkflowToWasm_Go_Success(t *testing.T) {
t.Run("basic_workflow", func(t *testing.T) {
path := deployTestdataPath("basic_workflow", "main.go")
wasm, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true})
wasm, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true})
require.NoError(t, err)
assert.NotEmpty(t, wasm)
})

t.Run("configless_workflow", func(t *testing.T) {
path := deployTestdataPath("configless_workflow", "main.go")
wasm, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true})
wasm, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true})
require.NoError(t, err)
assert.NotEmpty(t, wasm)
})

t.Run("missing_go_mod", func(t *testing.T) {
path := deployTestdataPath("missing_go_mod", "main.go")
wasm, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true})
wasm, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true})
require.NoError(t, err)
assert.NotEmpty(t, wasm)
})
}

func TestCompileWorkflowToWasm_Go_Malformed_Fails(t *testing.T) {
path := deployTestdataPath("malformed_workflow", "main.go")
_, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true})
_, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true})
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to compile workflow")
assert.Contains(t, err.Error(), "undefined: sdk.RemovedFunctionThatFailsCompilation")
Expand All @@ -80,7 +81,7 @@ func TestCompileWorkflowToWasm_Wasm_Success(t *testing.T) {
_ = os.Remove(wasmPath)
t.Cleanup(func() { _ = os.Remove(wasmPath) })

wasm, err := CompileWorkflowToWasm(wasmPath, WorkflowCompileOptions{StripSymbols: true})
wasm, err := CompileWorkflowToWasm(context.Background(), wasmPath, WorkflowCompileOptions{StripSymbols: true})
require.NoError(t, err)
assert.NotEmpty(t, wasm)

Expand All @@ -96,14 +97,14 @@ func TestCompileWorkflowToWasm_Wasm_Fails(t *testing.T) {
wasmPath := filepath.Join(wasmDir, "workflow.wasm")
require.NoError(t, os.WriteFile(wasmPath, []byte("not really wasm"), 0600))

_, err := CompileWorkflowToWasm(wasmPath, WorkflowCompileOptions{StripSymbols: true})
_, err := CompileWorkflowToWasm(context.Background(), wasmPath, WorkflowCompileOptions{StripSymbols: true})
require.Error(t, err)
assert.Contains(t, err.Error(), "no Makefile found")
})

t.Run("make_build_fails", func(t *testing.T) {
path := deployTestdataPath("wasm_make_fails", "wasm", "workflow.wasm")
_, err := CompileWorkflowToWasm(path, WorkflowCompileOptions{StripSymbols: true})
_, err := CompileWorkflowToWasm(context.Background(), path, WorkflowCompileOptions{StripSymbols: true})
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to compile workflow")
assert.Contains(t, err.Error(), "build output:")
Expand Down Expand Up @@ -138,7 +139,7 @@ func TestCompileWorkflowToWasm_TS_Success(t *testing.T) {
"include": ["main.ts"]
}
`), 0600))
wasm, err := CompileWorkflowToWasm(mainPath, WorkflowCompileOptions{StripSymbols: true})
wasm, err := CompileWorkflowToWasm(context.Background(), mainPath, WorkflowCompileOptions{StripSymbols: true})
if err != nil {
t.Skipf("TS compile failed (published cre-sdk may lack full layout): %v", err)
}
Expand Down
11 changes: 8 additions & 3 deletions cmd/common/fetch.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -28,12 +29,16 @@ func IsURL(s string) bool {
}

// FetchURL performs an HTTP GET and returns the response body bytes.
func FetchURL(url string) ([]byte, error) {
resp, err := http.Get(url) //nolint:gosec,noctx
func FetchURL(ctx context.Context, url string) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("HTTP GET %s: %w", url, err)
}
defer resp.Body.Close()
resp, err := http.DefaultClient.Do(req) //nolint:gosec // URL validated by caller
Comment on lines 31 to +37
if err != nil {
return nil, fmt.Errorf("HTTP GET %s: %w", url, err)
}
defer func() { _ = resp.Body.Close() }()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("HTTP GET %s returned status %d", url, resp.StatusCode)
Expand Down
Loading
Loading