diff --git a/cmd/workflow/activate/activate.go b/cmd/workflow/activate/activate.go index d250fd61..7e63cf1a 100644 --- a/cmd/workflow/activate/activate.go +++ b/cmd/workflow/activate/activate.go @@ -1,6 +1,7 @@ package activate import ( + "context" "fmt" "github.com/rs/zerolog" @@ -47,7 +48,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { if err := handler.ValidateInputs(); err != nil { return err } - return handler.Execute() + return handler.Execute(cmd.Context()) }, } @@ -66,6 +67,7 @@ type handler struct { runtimeContext *runtime.Context validated bool + execCtx context.Context } func newHandler(ctx *runtime.Context) *handler { @@ -110,7 +112,9 @@ func (h *handler) ValidateInputs() error { return nil } -func (h *handler) Execute() error { +func (h *handler) Execute(ctx context.Context) error { + h.execCtx = ctx + if !h.validated { return fmt.Errorf("handler inputs not validated") } diff --git a/cmd/workflow/activate/activate_test.go b/cmd/workflow/activate/activate_test.go index f94522aa..0e9e753a 100644 --- a/cmd/workflow/activate/activate_test.go +++ b/cmd/workflow/activate/activate_test.go @@ -1,6 +1,7 @@ package activate import ( + "context" "errors" "testing" @@ -35,7 +36,7 @@ func TestNonInteractive_WithoutYes_ReturnsError(t *testing.T) { } handler.validated = true - err := handler.Execute() + err := handler.Execute(context.Background()) require.Error(t, err) require.Contains(t, err.Error(), "missing required flags for --non-interactive mode") } @@ -62,7 +63,7 @@ func TestNonInteractive_WithYes_PassesGuard(t *testing.T) { } handler.validated = true - err := handler.Execute() + err := handler.Execute(context.Background()) // Guard passes; error comes from WRC (no matching workflow), not the guard require.Error(t, err) require.NotContains(t, err.Error(), "missing required flags for --non-interactive mode") diff --git a/cmd/workflow/activate/registry_activate_strategy_private.go b/cmd/workflow/activate/registry_activate_strategy_private.go index 1b1e1df9..b73bb4bd 100644 --- a/cmd/workflow/activate/registry_activate_strategy_private.go +++ b/cmd/workflow/activate/registry_activate_strategy_private.go @@ -32,7 +32,7 @@ func (a *privateRegistryActivateStrategy) Activate() error { ui.Dim(fmt.Sprintf("Fetching workflow to activate... Name=%s", workflowName)) - workflow, err := a.prc.GetWorkflowByName(workflowName) + workflow, err := a.prc.GetWorkflowByName(a.h.execCtx, workflowName) if err != nil { return fmt.Errorf("failed to get workflow: %w", err) } @@ -45,7 +45,7 @@ func (a *privateRegistryActivateStrategy) Activate() error { ui.Dim(fmt.Sprintf("Processing activation for workflow ID %s...", workflow.WorkflowID)) - result, err := a.prc.ActivateWorkflowInRegistry(workflow.WorkflowID) + result, err := a.prc.ActivateWorkflowInRegistry(a.h.execCtx, workflow.WorkflowID) if err != nil { return fmt.Errorf("failed to activate workflow in private registry: %w", err) } diff --git a/cmd/workflow/delete/delete.go b/cmd/workflow/delete/delete.go index 60db53c3..f15fb95c 100644 --- a/cmd/workflow/delete/delete.go +++ b/cmd/workflow/delete/delete.go @@ -1,6 +1,7 @@ package delete import ( + "context" "fmt" "io" @@ -44,7 +45,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { if err != nil { return err } - return handler.Execute() + return handler.Execute(cmd.Context()) }, } @@ -66,6 +67,7 @@ type handler struct { runtimeContext *runtime.Context validated bool + execCtx context.Context } func newHandler(ctx *runtime.Context, stdin io.Reader) *handler { @@ -128,11 +130,13 @@ func (h *handler) ValidateInputs() error { return nil } -func (h *handler) Execute() error { +func (h *handler) Execute(ctx context.Context) error { if !h.validated { return fmt.Errorf("handler inputs not validated") } + h.execCtx = ctx + adapter, err := newRegistryDeleteStrategy(h.runtimeContext.ResolvedRegistry, h) if err != nil { return err diff --git a/cmd/workflow/delete/registry_delete_strategy_private.go b/cmd/workflow/delete/registry_delete_strategy_private.go index 22f832bf..b5d52864 100644 --- a/cmd/workflow/delete/registry_delete_strategy_private.go +++ b/cmd/workflow/delete/registry_delete_strategy_private.go @@ -32,7 +32,7 @@ func (a *privateRegistryDeleteStrategy) FetchWorkflows() ([]WorkflowToDelete, er ui.Dim(fmt.Sprintf("Fetching workflow to delete... Name=%s", workflowName)) - workflow, err := a.prc.GetWorkflowByName(workflowName) + workflow, err := a.prc.GetWorkflowByName(a.h.execCtx, workflowName) if err != nil { return nil, fmt.Errorf("failed to get workflow: %w", err) } @@ -55,7 +55,7 @@ func (a *privateRegistryDeleteStrategy) DeleteWorkflows(workflows []WorkflowToDe for _, wf := range workflows { workflowID := wf.RawID.(string) - deletedID, err := a.prc.DeleteWorkflowInRegistry(workflowID) + deletedID, err := a.prc.DeleteWorkflowInRegistry(a.h.execCtx, workflowID) if err != nil { h.log.Error(). Err(err). diff --git a/cmd/workflow/deploy/deploy.go b/cmd/workflow/deploy/deploy.go index 1ce740a9..34f68f9b 100644 --- a/cmd/workflow/deploy/deploy.go +++ b/cmd/workflow/deploy/deploy.go @@ -73,6 +73,8 @@ type handler struct { // existingWorkflowStatus stores the status of an existing workflow when updating. // nil means this is a new workflow, otherwise it contains the current status (0=active, 1=paused). existingWorkflowStatus *uint8 + + execCtx context.Context } var defaultOutputPath = "./binary.wasm.br.b64" @@ -209,6 +211,8 @@ func (h *handler) Execute(ctx context.Context) error { return fmt.Errorf("handler inputs not validated") } + h.execCtx = ctx + deployAccess, err := h.credentials.GetDeploymentAccessStatus() if err != nil { return fmt.Errorf("failed to check deployment access: %w", err) diff --git a/cmd/workflow/deploy/private_registry_test.go b/cmd/workflow/deploy/private_registry_test.go index 714b2918..db0ed61b 100644 --- a/cmd/workflow/deploy/private_registry_test.go +++ b/cmd/workflow/deploy/private_registry_test.go @@ -1,6 +1,7 @@ package deploy import ( + "context" "encoding/base64" "encoding/hex" "encoding/json" @@ -312,6 +313,7 @@ func TestCheckWorkflowExists_PrivateRegistry(t *testing.T) { defer gqlServer.Close() h.environmentSet.GraphQLURL = gqlServer.URL + h.execCtx = context.Background() strategy := newPrivateRegistryDeployStrategy(h) exists, status, err := strategy.CheckWorkflowExists("", "jnowak-workflow-test-v5", "", tt.workflowID) diff --git a/cmd/workflow/deploy/registry_deploy_strategy_private.go b/cmd/workflow/deploy/registry_deploy_strategy_private.go index dee39315..9909800e 100644 --- a/cmd/workflow/deploy/registry_deploy_strategy_private.go +++ b/cmd/workflow/deploy/registry_deploy_strategy_private.go @@ -34,7 +34,7 @@ func (a *privateRegistryDeployStrategy) RunPreDeployChecks() error { func (a *privateRegistryDeployStrategy) CheckWorkflowExists(_, workflowName, _, workflowID string) (bool, *uint8, error) { a.ensureClient() - workflow, err := a.prc.GetWorkflowByName(workflowName) + workflow, err := a.prc.GetWorkflowByName(a.h.execCtx, workflowName) if err == nil { if workflow.WorkflowID == workflowID { return true, offchainStatusToUint8(workflow.Status), fmt.Errorf("workflow with id %s is already registered and unchanged; re-deployment skipped: %w", workflowID, errWorkflowUnchanged) @@ -57,7 +57,7 @@ func (a *privateRegistryDeployStrategy) Upsert() error { ui.Line() ui.Dim(fmt.Sprintf("Registering workflow in private registry (workflowID: %s)...", input.WorkflowID)) - result, err := a.prc.UpsertWorkflowInRegistry(input) + result, err := a.prc.UpsertWorkflowInRegistry(a.h.execCtx, input) if err != nil { return fmt.Errorf("failed to register workflow in private registry: %w", err) } diff --git a/cmd/workflow/pause/pause.go b/cmd/workflow/pause/pause.go index 19eaadf2..fe3382ea 100644 --- a/cmd/workflow/pause/pause.go +++ b/cmd/workflow/pause/pause.go @@ -1,6 +1,7 @@ package pause import ( + "context" "fmt" "github.com/rs/zerolog" @@ -46,7 +47,7 @@ func New(runtimeContext *runtime.Context) *cobra.Command { if err := handler.ValidateInputs(); err != nil { return err } - return handler.Execute() + return handler.Execute(cmd.Context()) }, } @@ -64,6 +65,7 @@ type handler struct { runtimeContext *runtime.Context validated bool + execCtx context.Context } func newHandler(ctx *runtime.Context) *handler { @@ -107,7 +109,9 @@ func (h *handler) ValidateInputs() error { return nil } -func (h *handler) Execute() error { +func (h *handler) Execute(ctx context.Context) error { + h.execCtx = ctx + if !h.validated { return fmt.Errorf("handler inputs not validated") } diff --git a/cmd/workflow/pause/pause_test.go b/cmd/workflow/pause/pause_test.go index 3af6e2f6..04ffa6cf 100644 --- a/cmd/workflow/pause/pause_test.go +++ b/cmd/workflow/pause/pause_test.go @@ -1,6 +1,7 @@ package pause import ( + "context" "errors" "testing" @@ -34,7 +35,7 @@ func TestNonInteractive_WithoutYes_ReturnsError(t *testing.T) { } h.validated = true - err := h.Execute() + err := h.Execute(context.Background()) require.Error(t, err) require.Contains(t, err.Error(), "missing required flags for --non-interactive mode") } @@ -60,7 +61,7 @@ func TestNonInteractive_WithYes_PassesGuard(t *testing.T) { } h.validated = true - err := h.Execute() + err := h.Execute(context.Background()) // Guard passes; error comes from WRC (no matching workflow), not the guard require.Error(t, err) require.NotContains(t, err.Error(), "missing required flags for --non-interactive mode") diff --git a/cmd/workflow/pause/registry_pause_strategy_private.go b/cmd/workflow/pause/registry_pause_strategy_private.go index f0f05534..3c583e07 100644 --- a/cmd/workflow/pause/registry_pause_strategy_private.go +++ b/cmd/workflow/pause/registry_pause_strategy_private.go @@ -32,7 +32,7 @@ func (a *privateRegistryPauseStrategy) Pause() error { ui.Dim(fmt.Sprintf("Fetching workflow to pause... Name=%s", workflowName)) - workflow, err := a.prc.GetWorkflowByName(workflowName) + workflow, err := a.prc.GetWorkflowByName(a.h.execCtx, workflowName) if err != nil { return fmt.Errorf("failed to get workflow: %w", err) } @@ -45,7 +45,7 @@ func (a *privateRegistryPauseStrategy) Pause() error { ui.Dim(fmt.Sprintf("Processing pause for workflow ID %s...", workflow.WorkflowID)) - result, err := a.prc.PauseWorkflowInRegistry(workflow.WorkflowID) + result, err := a.prc.PauseWorkflowInRegistry(a.h.execCtx, workflow.WorkflowID) if err != nil { return fmt.Errorf("failed to pause workflow in private registry: %w", err) } diff --git a/internal/client/privateregistryclient/privateregistryclient.go b/internal/client/privateregistryclient/privateregistryclient.go index 5e8f084b..e91b6f3b 100644 --- a/internal/client/privateregistryclient/privateregistryclient.go +++ b/internal/client/privateregistryclient/privateregistryclient.go @@ -29,8 +29,8 @@ func (c *Client) SetServiceTimeout(timeout time.Duration) { c.serviceTimeout = timeout } -func (c *Client) CreateServiceContextWithTimeout() (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), c.serviceTimeout) //nolint:gosec // G118 -- cancel is deferred by callers +func (c *Client) CreateServiceContextWithTimeout(parent context.Context) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, c.serviceTimeout) //nolint:gosec // G118 -- cancel is deferred by callers } type OffchainWorkflow struct { @@ -131,7 +131,7 @@ type GetOffchainWorkflowByNameResponse struct { Workflow OffchainWorkflow `json:"workflow"` } -func (c *Client) GetWorkflowByName(workflowName string) (OffchainWorkflow, error) { +func (c *Client) GetWorkflowByName(ctx context.Context, workflowName string) (OffchainWorkflow, error) { if workflowName == "" { return OffchainWorkflow{}, fmt.Errorf("workflowName is required") } @@ -165,10 +165,10 @@ query GetOffchainWorkflowByName($request: GetOffchainWorkflowByNameRequest!) { GetOffchainWorkflowByName GetOffchainWorkflowByNameResponse `json:"getOffchainWorkflowByName"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + callCtx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() - if err := c.graphql.Execute(ctx, req, &container); err != nil { + if err := c.graphql.Execute(callCtx, req, &container); err != nil { return OffchainWorkflow{}, fmt.Errorf("get workflow by name in registry: %w", err) } @@ -178,7 +178,7 @@ query GetOffchainWorkflowByName($request: GetOffchainWorkflowByNameRequest!) { return container.GetOffchainWorkflowByName.Workflow, nil } -func (c *Client) UpsertWorkflowInRegistry(workflow OffchainWorkflowInput) (OffchainWorkflow, error) { +func (c *Client) UpsertWorkflowInRegistry(ctx context.Context, workflow OffchainWorkflowInput) (OffchainWorkflow, error) { if err := validateUpsertWorkflowInput(workflow); err != nil { return OffchainWorkflow{}, err } @@ -209,10 +209,10 @@ mutation UpsertOffchainWorkflow($request: UpsertOffchainWorkflowRequest!) { UpsertOffchainWorkflow UpsertOffchainWorkflowResponse `json:"upsertOffchainWorkflow"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + callCtx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() - if err := c.graphql.Execute(ctx, req, &container); err != nil { + if err := c.graphql.Execute(callCtx, req, &container); err != nil { return OffchainWorkflow{}, fmt.Errorf("upsert workflow in registry: %w", err) } @@ -222,7 +222,7 @@ mutation UpsertOffchainWorkflow($request: UpsertOffchainWorkflowRequest!) { return container.UpsertOffchainWorkflow.Workflow, nil } -func (c *Client) PauseWorkflowInRegistry(workflowID string) (OffchainWorkflow, error) { +func (c *Client) PauseWorkflowInRegistry(ctx context.Context, workflowID string) (OffchainWorkflow, error) { if workflowID == "" { return OffchainWorkflow{}, fmt.Errorf("workflowId is required") } @@ -253,10 +253,10 @@ mutation PauseOffchainWorkflow($request: PauseOffchainWorkflowRequest!) { PauseOffchainWorkflow PauseOffchainWorkflowResponse `json:"pauseOffchainWorkflow"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + callCtx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() - if err := c.graphql.Execute(ctx, req, &container); err != nil { + if err := c.graphql.Execute(callCtx, req, &container); err != nil { return OffchainWorkflow{}, fmt.Errorf("pause workflow in registry: %w", err) } @@ -266,7 +266,7 @@ mutation PauseOffchainWorkflow($request: PauseOffchainWorkflowRequest!) { return container.PauseOffchainWorkflow.Workflow, nil } -func (c *Client) ActivateWorkflowInRegistry(workflowID string) (OffchainWorkflow, error) { +func (c *Client) ActivateWorkflowInRegistry(ctx context.Context, workflowID string) (OffchainWorkflow, error) { if workflowID == "" { return OffchainWorkflow{}, fmt.Errorf("workflowId is required") } @@ -297,10 +297,10 @@ mutation ActivateOffchainWorkflow($request: ActivateOffchainWorkflowRequest!) { ActivateOffchainWorkflow ActivateOffchainWorkflowResponse `json:"activateOffchainWorkflow"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + callCtx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() - if err := c.graphql.Execute(ctx, req, &container); err != nil { + if err := c.graphql.Execute(callCtx, req, &container); err != nil { return OffchainWorkflow{}, fmt.Errorf("activate workflow in registry: %w", err) } @@ -310,7 +310,7 @@ mutation ActivateOffchainWorkflow($request: ActivateOffchainWorkflowRequest!) { return container.ActivateOffchainWorkflow.Workflow, nil } -func (c *Client) DeleteWorkflowInRegistry(workflowID string) (string, error) { +func (c *Client) DeleteWorkflowInRegistry(ctx context.Context, workflowID string) (string, error) { if workflowID == "" { return "", fmt.Errorf("workflowId is required") } @@ -329,10 +329,10 @@ mutation DeleteOffchainWorkflow($request: DeleteOffchainWorkflowRequest!) { DeleteOffchainWorkflow DeleteOffchainWorkflowResponse `json:"deleteOffchainWorkflow"` } - ctx, cancel := c.CreateServiceContextWithTimeout() + callCtx, cancel := c.CreateServiceContextWithTimeout(ctx) defer cancel() - if err := c.graphql.Execute(ctx, req, &container); err != nil { + if err := c.graphql.Execute(callCtx, req, &container); err != nil { return "", fmt.Errorf("delete workflow in registry: %w", err) } diff --git a/internal/client/privateregistryclient/privateregistryclient_test.go b/internal/client/privateregistryclient/privateregistryclient_test.go index 578fbd06..8dc4c5db 100644 --- a/internal/client/privateregistryclient/privateregistryclient_test.go +++ b/internal/client/privateregistryclient/privateregistryclient_test.go @@ -1,6 +1,7 @@ package privateregistryclient import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -147,7 +148,7 @@ func TestUpsertWorkflowInRegistry(t *testing.T) { configURL := "s3://config" tag := "v1" attributes := "{\"region\":\"us-east-1\"}" - result, err := client.UpsertWorkflowInRegistry(OffchainWorkflowInput{ + result, err := client.UpsertWorkflowInRegistry(context.Background(), OffchainWorkflowInput{ WorkflowID: "wf-123", Status: WorkflowStatusActive, WorkflowName: "registry-workflow", @@ -187,7 +188,7 @@ func TestUpsertWorkflowInRegistry_GQLError(t *testing.T) { defer srv.Close() client := newTestPrivateRegistryClient(t, srv.URL) - _, err := client.UpsertWorkflowInRegistry(OffchainWorkflowInput{ + _, err := client.UpsertWorkflowInRegistry(context.Background(), OffchainWorkflowInput{ WorkflowID: "wf-123", Status: WorkflowStatusActive, WorkflowName: "registry-workflow", @@ -236,7 +237,7 @@ func TestGetWorkflowByName(t *testing.T) { defer srv.Close() client := newTestPrivateRegistryClient(t, srv.URL) - result, err := client.GetWorkflowByName("registry-workflow") + result, err := client.GetWorkflowByName(context.Background(), "registry-workflow") require.NoError(t, err) assert.Contains(t, capturedQuery, "query GetOffchainWorkflowByName") @@ -262,7 +263,7 @@ func TestGetWorkflowByName_GQLError(t *testing.T) { defer srv.Close() client := newTestPrivateRegistryClient(t, srv.URL) - _, err := client.GetWorkflowByName("registry-workflow") + _, err := client.GetWorkflowByName(context.Background(), "registry-workflow") require.Error(t, err) assert.Contains(t, err.Error(), "get workflow by name in registry") assert.Contains(t, err.Error(), "cre api error: workflow not found") @@ -273,7 +274,7 @@ func TestGetWorkflowByName_EmptyName(t *testing.T) { logger := testutil.NewTestLogger() client := New(nil, logger) - _, err := client.GetWorkflowByName("") + _, err := client.GetWorkflowByName(context.Background(), "") require.EqualError(t, err, "workflowName is required") } @@ -306,7 +307,7 @@ func TestPauseWorkflowInRegistry(t *testing.T) { defer srv.Close() client := newTestPrivateRegistryClient(t, srv.URL) - result, err := client.PauseWorkflowInRegistry("wf-123") + result, err := client.PauseWorkflowInRegistry(context.Background(), "wf-123") require.NoError(t, err) assert.Equal(t, "wf-123", result.WorkflowID) assert.Equal(t, WorkflowStatusPaused, result.Status) @@ -344,7 +345,7 @@ func TestActivateWorkflowInRegistry(t *testing.T) { defer srv.Close() client := newTestPrivateRegistryClient(t, srv.URL) - result, err := client.ActivateWorkflowInRegistry("wf-123") + result, err := client.ActivateWorkflowInRegistry(context.Background(), "wf-123") require.NoError(t, err) assert.Equal(t, WorkflowStatusActive, result.Status) @@ -374,7 +375,7 @@ func TestDeleteWorkflowInRegistry(t *testing.T) { defer srv.Close() client := newTestPrivateRegistryClient(t, srv.URL) - deletedWorkflowID, err := client.DeleteWorkflowInRegistry("wf-123") + deletedWorkflowID, err := client.DeleteWorkflowInRegistry(context.Background(), "wf-123") require.NoError(t, err) assert.Equal(t, "wf-123", deletedWorkflowID) @@ -387,13 +388,13 @@ func TestWorkflowMutations_RequireWorkflowID(t *testing.T) { logger := testutil.NewTestLogger() client := New(nil, logger) - _, pauseErr := client.PauseWorkflowInRegistry("") + _, pauseErr := client.PauseWorkflowInRegistry(context.Background(), "") require.EqualError(t, pauseErr, "workflowId is required") - _, activateErr := client.ActivateWorkflowInRegistry("") + _, activateErr := client.ActivateWorkflowInRegistry(context.Background(), "") require.EqualError(t, activateErr, "workflowId is required") - _, deleteErr := client.DeleteWorkflowInRegistry("") + _, deleteErr := client.DeleteWorkflowInRegistry(context.Background(), "") require.EqualError(t, deleteErr, "workflowId is required") } @@ -414,10 +415,11 @@ func TestCreateServiceContextWithTimeout(t *testing.T) { client := New(nil, &logger) client.SetServiceTimeout(150 * time.Millisecond) - ctx, cancel := client.CreateServiceContextWithTimeout() + parent := context.Background() + callCtx, cancel := client.CreateServiceContextWithTimeout(parent) defer cancel() - deadline, ok := ctx.Deadline() + deadline, ok := callCtx.Deadline() require.True(t, ok) assert.WithinDuration(t, time.Now().Add(150*time.Millisecond), deadline, 100*time.Millisecond) }