Skip to content
This repository was archived by the owner on Mar 19, 2026. It is now read-only.

Commit 713612c

Browse files
committed
feat(workflow): enhance node definition and execution with data handling and validation
1 parent 669beb1 commit 713612c

3 files changed

Lines changed: 502 additions & 161 deletions

File tree

packages/workflow/src/index.ts

Lines changed: 162 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ import {
77
NodeId,
88
NodeIO,
99
NodeResult,
10+
WorkflowCallbacks,
1011
WorkflowNode,
11-
WorkflowOptions
12+
WorkflowOptions,
13+
WorkflowResult
1214
} from './types.ts'
1315

1416
export * from './types.ts'
@@ -17,12 +19,17 @@ export function createNodeFactory(): NodeFactory {
1719
const nodes = new Map<string, NodeDefinition>()
1820

1921
return {
20-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
21-
registerNode: (type: string, definition: NodeDefinition<any, any>) => {
22+
registerNode: (
23+
type: string,
24+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
25+
definition: NodeDefinition<any, any, any>
26+
) => {
2227
nodes.set(type, definition)
2328
},
24-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
25-
getNode: (type: string) => nodes.get(type) as NodeDefinition<any, any>
29+
30+
getNode: (type: string) =>
31+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
32+
nodes.get(type) as NodeDefinition<any, any, any>
2633
}
2734
}
2835

@@ -56,28 +63,116 @@ function detectCircularDependencies(nodes: WorkflowNode[]): boolean {
5663
return nodes.some((node) => hasCycle(node.id))
5764
}
5865

66+
function validateNodeInput(
67+
node: WorkflowNode,
68+
definition: NodeDefinition,
69+
input: NodeIO,
70+
context: NodeContext
71+
): boolean {
72+
if (!definition.inputSchema) return true
73+
try {
74+
definition.inputSchema.parse(input)
75+
return true
76+
} catch {
77+
return false
78+
}
79+
}
80+
81+
function validateNodeOutput(
82+
definition: NodeDefinition,
83+
output: NodeIO,
84+
context: NodeContext
85+
): boolean {
86+
if (!definition.outputSchema) return true
87+
try {
88+
for (const [key, schema] of Object.entries(definition.outputSchema)) {
89+
if (key in output) {
90+
schema.parse(output[key])
91+
}
92+
}
93+
return true
94+
} catch {
95+
return false
96+
}
97+
}
98+
99+
function checkAndMarkSkippedNodes(
100+
nodes: WorkflowNode[],
101+
completed: Set<NodeId>,
102+
failed: Set<NodeId>,
103+
skipped: Set<NodeId>,
104+
results: Record<NodeId, NodeResult>,
105+
callbacks?: WorkflowCallbacks
106+
): void {
107+
nodes.forEach((node) => {
108+
if (
109+
completed.has(node.id) ||
110+
failed.has(node.id) ||
111+
skipped.has(node.id)
112+
) {
113+
return
114+
}
115+
116+
const cannotBeExecuted = node.dependencies.some((dep) => {
117+
const depId = resolveDependency(dep)
118+
if (failed.has(depId)) return true
119+
120+
if (completed.has(depId) && typeof dep === 'object' && dep.portId) {
121+
const depOutput = results[depId]?.output || {}
122+
return depOutput[dep.portId] === undefined
123+
}
124+
125+
return false
126+
})
127+
128+
if (cannotBeExecuted) {
129+
skipped.add(node.id)
130+
results[node.id] = {
131+
state: 'skipped' as const,
132+
output: {}
133+
}
134+
callbacks?.onNodeSkipped?.(node.id, node.type)
135+
}
136+
})
137+
}
138+
59139
function getExecutableNodes(
60140
nodes: WorkflowNode[],
61141
completed: Set<NodeId>,
62-
failed: Set<NodeId>
142+
failed: Set<NodeId>,
143+
skipped: Set<NodeId>,
144+
results: Record<NodeId, NodeResult>
63145
): WorkflowNode[] {
64146
return nodes.filter((node) => {
65-
if (completed.has(node.id) || failed.has(node.id)) return false
147+
if (
148+
completed.has(node.id) ||
149+
failed.has(node.id) ||
150+
skipped.has(node.id)
151+
)
152+
return false
153+
66154
return node.dependencies.every((dep) => {
67155
const depId = resolveDependency(dep)
68-
return completed.has(depId) && !failed.has(depId)
156+
const depResult = completed.has(depId)
157+
158+
if (typeof dep === 'object' && dep.portId) {
159+
const depOutput = results[depId]?.output || {}
160+
return depResult && depOutput[dep.portId] !== undefined
161+
}
162+
163+
return depResult && !failed.has(depId)
69164
})
70165
})
71166
}
72167

73-
function validateNodeInput(
168+
function validateNodeData(
74169
node: WorkflowNode,
75170
definition: NodeDefinition,
76-
input: NodeIO
171+
context: NodeContext
77172
): boolean {
78-
if (!definition.inputSchema) return true
173+
if (!definition.dataSchema || !node.data) return true
79174
try {
80-
definition.inputSchema.parse(input)
175+
definition.dataSchema.parse(node.data)
81176
return true
82177
} catch {
83178
return false
@@ -89,25 +184,34 @@ export async function executeWorkflow(
89184
factory: NodeFactory,
90185
initialContext: NodeContext = { variables: {}, metadata: {} },
91186
options: WorkflowOptions = {}
92-
): Promise<Record<NodeId, NodeResult>> {
187+
): Promise<WorkflowResult> {
93188
const { maxRetries = 3, maxParallel = 4, callbacks = {} } = options
189+
const results: Record<NodeId, NodeResult> = {}
94190

95191
if (detectCircularDependencies(nodes)) {
96192
throw new Error('Circular dependencies detected in workflow')
97193
}
98194

99195
const completed = new Set<NodeId>()
100196
const failed = new Set<NodeId>()
101-
const results: Record<NodeId, NodeResult> = {}
197+
const skipped = new Set<NodeId>()
102198
const context: NodeContext = { ...initialContext }
199+
let lastExecutedNode: WorkflowNode | undefined
103200

104-
while (completed.size + failed.size < nodes.length) {
105-
const executableNodes = getExecutableNodes(nodes, completed, failed)
201+
while (completed.size + failed.size + skipped.size < nodes.length) {
202+
const executableNodes = getExecutableNodes(
203+
nodes,
204+
completed,
205+
failed,
206+
skipped,
207+
results
208+
)
106209
if (executableNodes.length === 0) break
107210

108211
const executions = executableNodes
109212
.slice(0, maxParallel)
110213
.map(async (node) => {
214+
lastExecutedNode = node
111215
const definition = factory.getNode(node.type)
112216
if (!definition) {
113217
failed.add(node.id)
@@ -120,6 +224,15 @@ export async function executeWorkflow(
120224

121225
callbacks.onNodeStart?.(node.id, node.type)
122226

227+
if (!validateNodeData(node, definition, context)) {
228+
const error = new Error(`Invalid data for node ${node.id}`)
229+
timing.endTime = Date.now()
230+
timing.duration = timing.endTime - timing.startTime
231+
callbacks.onNodeError?.(node.id, node.type, error, timing)
232+
failed.add(node.id)
233+
return
234+
}
235+
123236
const input =
124237
node.dependencies.length === 0
125238
? { ...context.variables }
@@ -138,7 +251,8 @@ export async function executeWorkflow(
138251
return { ...acc, ...depOutput }
139252
}, {} as NodeIO)
140253

141-
if (!validateNodeInput(node, definition, input)) {
254+
if (!validateNodeInput(node, definition, input, context)) {
255+
console.log(node, input)
142256
const error = new Error(`Invalid input for node ${node.id}`)
143257
timing.endTime = Date.now()
144258
timing.duration = timing.endTime - timing.startTime
@@ -150,12 +264,22 @@ export async function executeWorkflow(
150264
let retries = 0
151265
while (retries <= maxRetries) {
152266
try {
153-
const output = await definition.run(input, context)
267+
const output = await definition.run(
268+
input,
269+
context,
270+
node.data
271+
)
154272
timing.endTime = Date.now()
155273
timing.duration = timing.endTime - timing.startTime
156274

275+
if (!validateNodeOutput(definition, output, context)) {
276+
throw new Error(
277+
`Invalid output for node ${node.id}`
278+
)
279+
}
280+
157281
const result: NodeResult = {
158-
state: 'completed',
282+
state: 'completed' as const,
159283
output
160284
}
161285

@@ -167,6 +291,16 @@ export async function executeWorkflow(
167291
result,
168292
timing
169293
)
294+
295+
// Check for nodes that should be skipped after this node completes
296+
checkAndMarkSkippedNodes(
297+
nodes,
298+
completed,
299+
failed,
300+
skipped,
301+
results,
302+
callbacks
303+
)
170304
break
171305
} catch (error) {
172306
retries++
@@ -175,7 +309,7 @@ export async function executeWorkflow(
175309
timing.duration = timing.endTime - timing.startTime
176310

177311
const result: NodeResult = {
178-
state: 'failed',
312+
state: 'failed' as const,
179313
output: {},
180314
error: error as Error
181315
}
@@ -195,6 +329,13 @@ export async function executeWorkflow(
195329
await Promise.all(executions)
196330
}
197331

332+
const lastNodeResult: NodeResult = lastExecutedNode
333+
? results[lastExecutedNode.id]
334+
: {
335+
state: 'skipped' as const,
336+
output: {}
337+
}
338+
198339
callbacks.onWorkflowComplete?.(results)
199-
return results
340+
return [lastNodeResult, results]
200341
}

packages/workflow/src/types.ts

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,17 @@ export type NodeResult = {
2222
error?: Error
2323
}
2424

25+
export type WorkflowResult = [NodeResult, Record<NodeId, NodeResult>]
26+
2527
export type NodeDefinition<
2628
TInput extends NodeIO = NodeIO,
27-
TOutput extends NodeIO = NodeIO
29+
TOutput extends NodeIO = NodeIO,
30+
TData = unknown
2831
> = {
29-
run: (input: TInput, context: NodeContext) => Promise<TOutput>
32+
run: (input: TInput, context: NodeContext, data?: TData) => Promise<TOutput>
3033
inputSchema?: z.ZodType<TInput>
31-
outputSchema?: z.ZodType<TOutput>
34+
outputSchema?: Record<string, z.ZodType<unknown>>
35+
dataSchema?: z.ZodType<TData>
3236
}
3337

3438
export type NodeDependency =
@@ -44,6 +48,7 @@ export type WorkflowNode = {
4448
type: string
4549
dependencies: NodeDependency[]
4650
config?: Record<string, unknown>
51+
data?: unknown
4752
}
4853

4954
export type NodeExecutionTiming = {
@@ -82,15 +87,17 @@ export type NodeFactory = {
8287
// eslint-disable-next-line @typescript-eslint/no-explicit-any
8388
registerNode<
8489
TInput extends NodeIO = NodeIO<unknown>,
85-
TOutput extends NodeIO = NodeIO<unknown>
90+
TOutput extends NodeIO = NodeIO<unknown>,
91+
TData = unknown
8692
>(
8793
type: string,
88-
definition: NodeDefinition<TInput, TOutput>
94+
definition: NodeDefinition<TInput, TOutput, TData>
8995
): void
9096
getNode<
9197
TInput extends NodeIO = NodeIO<unknown>,
92-
TOutput extends NodeIO = NodeIO<unknown>
98+
TOutput extends NodeIO = NodeIO<unknown>,
99+
TData = unknown
93100
>(
94101
type: string
95-
): NodeDefinition<TInput, TOutput> | undefined
102+
): NodeDefinition<TInput, TOutput, TData> | undefined
96103
}

0 commit comments

Comments
 (0)