45 changes: 35 additions & 10 deletions udf/worker/README.md
@@ -19,7 +19,7 @@ WorkerDispatcher -- manages workers, creates sessions
|
v
WorkerSession -- one UDF execution
-| 1. session.init(InitMessage(payload, inputSchema, outputSchema))
+| 1. session.init(Init proto: udf payload + data format + schemas)
| 2. val results = session.process(inputBatches)
| 3. session.close()
```
@@ -34,12 +34,13 @@ provisioning service or daemon).
```
udf/worker/
├── proto/
-│ worker_spec.proto -- UDFWorkerSpecification protobuf (+ generated Java classes)
+│ worker_spec.proto -- UDFWorkerSpecification protobuf
+│ udf_protocol.proto -- UDF execution protocol (Init, UdfPayload, ...)
│ common.proto -- shared enums (UDFWorkerDataFormat, etc.)
└── core/ -- abstract interfaces
WorkerDispatcher.scala -- creates sessions, manages worker lifecycle
-WorkerSession.scala -- per-UDF init/process/cancel/close + InitMessage
+WorkerSession.scala -- per-UDF init/process/cancel/close
WorkerConnection.scala -- transport channel abstraction
WorkerSecurityScope.scala -- security boundary for worker pooling
@@ -55,6 +56,23 @@ worker creation where Spark spawns local OS processes. Future packages
(e.g., `core/indirect/`) can implement alternative creation modes such as
obtaining workers from a provisioning service or daemon.

## Wire protocol

Each UDF execution uses a single bidirectional `Execute` gRPC stream:

```
Engine -> Worker: Init -> PayloadChunk* -> (DataRequest)* -> (Finish | Cancel)
Worker -> Engine: InitResponse -> (DataResponse)* -> (ExecutionError)? -> (FinishResponse | CancelResponse)
```

`DataRequest` and `DataResponse` are independent streams: the worker may emit
`DataResponse` messages at any point after `InitResponse`, including before the
first `DataRequest` arrives. Generator-style UDFs may have zero `DataRequest`
messages, with the engine sending `Finish` directly after `Init`.
`PayloadChunk.last = true` is the canonical end-of-chunking signal.
See `udf/worker/proto/src/main/protobuf/udf_protocol.proto` for the complete
ordering invariants, the gRPC error contract, and the cancel-vs-finish race semantics.
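
As a concrete illustration, the engine-side chunking step can be as small as
the sketch below (the generated `PayloadChunk` builder and its `setData`
accessor are assumed to follow the usual protobuf conventions; `last` is the
field described above):

```scala
import com.google.protobuf.ByteString
import org.apache.spark.udf.worker.PayloadChunk

// Split a serialized UDF payload into PayloadChunk messages.
// The final chunk carries last = true, the end-of-chunking signal.
def payloadChunks(
    payload: Array[Byte],
    chunkSize: Int = 1 << 20): Iterator[PayloadChunk] = {
  val chunks = payload.grouped(chunkSize).toArray
  chunks.iterator.zipWithIndex.map { case (bytes, i) =>
    PayloadChunk.newBuilder()
      .setData(ByteString.copyFrom(bytes))
      .setLast(i == chunks.length - 1)
      .build()
  }
}
```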

### Direct worker creation

`DirectWorkerDispatcher` spawns worker processes locally. On the first
@@ -76,10 +94,12 @@ Workers are terminated via SIGTERM/SIGKILL when the dispatcher is closed.

```scala
import org.apache.spark.udf.worker.{
-  DirectWorker, ProcessCallable, UDFProtoCommunicationPattern,
-  UDFWorkerDataFormat, UDFWorkerProperties, UDFWorkerSpecification,
-  UnixDomainSocket, WorkerCapabilities, WorkerConnectionSpec, WorkerEnvironment}
+  DirectWorker, Init, ProcessCallable, UdfPayload,
+  UDFProtoCommunicationPattern, UDFWorkerDataFormat, UDFWorkerProperties,
+  UDFWorkerSpecification, UnixDomainSocket, WorkerCapabilities,
+  WorkerConnectionSpec, WorkerEnvironment}
import org.apache.spark.udf.worker.core._
+import com.google.protobuf.ByteString

// 1. Define a worker spec (direct creation mode).
val spec = UDFWorkerSpecification.newBuilder()
@@ -112,10 +132,15 @@ val dispatcher: WorkerDispatcher = ...
val session = dispatcher.createSession(securityScope = None)
try {
// 4. Initialize with the serialized function and schemas.
-session.init(InitMessage(
-  functionPayload = serializedFunction,
-  inputSchema = arrowInputSchema,
-  outputSchema = arrowOutputSchema))
+session.init(Init.newBuilder()
+  .setUdf(UdfPayload.newBuilder()
+    .setPayload(ByteString.copyFrom(serializedFunction))
+    .setFormat(payloadFormat) // worker-recognised tag
+    .build())
+  .setDataFormat(UDFWorkerDataFormat.ARROW)
+  .setInputSchema(ByteString.copyFrom(arrowInputSchema))
+  .setOutputSchema(ByteString.copyFrom(arrowOutputSchema))
+  .build())

// 5. Process data -- Iterator in, Iterator out.
val results: Iterator[Array[Byte]] =
@@ -27,6 +27,15 @@ import org.apache.spark.udf.worker.UDFWorkerSpecification
* as security scope). It owns the underlying worker processes and connections,
* handling pooling, reuse, and lifecycle behind the scenes. Spark interacts with
* workers exclusively through the [[WorkerSession]]s returned by [[createSession]].
*
* '''Worker invalidation:''' if a session's Execute stream terminates with a gRPC
* transport error, the worker that backed it MUST NOT be returned to any reuse pool.
* A transport error leaves the worker in an unknown state; only workers that
* complete sessions cleanly (via [[org.apache.spark.udf.worker.FinishResponse]] or
* [[org.apache.spark.udf.worker.CancelResponse]]) are eligible for reuse.
* Implementations are responsible for tracking this condition -- typically
* [[WorkerSession.doProcess]] flags the worker as invalid before [[WorkerSession.doClose]]
* releases it, so the dispatcher can distinguish a clean release from a failed one.
*/
@Experimental
trait WorkerDispatcher extends AutoCloseable {
@@ -19,31 +19,7 @@ package org.apache.spark.udf.worker.core
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.annotation.Experimental

-/**
- * :: Experimental ::
- * Carries all information needed to initialize a UDF execution on a worker.
- *
- * This message is passed to [[WorkerSession#init]] and contains the function
- * definition, schemas, and any additional configuration.
- *
- * Placeholder: will be replaced by a generated proto message once the
- * UDF wire protocol lands. Do not rely on case-class equality --
- * `Array[Byte]` fields compare by reference.
- *
- * @param functionPayload serialized function (e.g., pickled Python, JVM bytes)
- * @param inputSchema serialized input schema (e.g., Arrow schema bytes)
- * @param outputSchema serialized output schema (e.g., Arrow schema bytes)
- * @param properties additional key-value configuration. Can carry
- *                   protocol-specific or engine-specific metadata that
- *                   does not yet have a dedicated field.
- */
-@Experimental
-case class InitMessage(
-    functionPayload: Array[Byte],
-    inputSchema: Array[Byte],
-    outputSchema: Array[Byte],
-    properties: Map[String, String] = Map.empty)
+import org.apache.spark.udf.worker.Init

/**
* :: Experimental ::
@@ -62,7 +38,10 @@ case class InitMessage(
* {{{
* val session = dispatcher.createSession(securityScope = None)
* try {
-*   session.init(InitMessage(functionPayload, inputSchema, outputSchema))
+*   session.init(Init.newBuilder()
+*     .setUdf(UdfPayload.newBuilder().setPayload(callable).setFormat(fmt).build())
+*     .setDataFormat(UDFWorkerDataFormat.ARROW)
+*     .build())
* val results = session.process(inputBatches)
* results.foreach(handleBatch)
* } finally {
@@ -74,7 +53,19 @@ case class InitMessage(
* - [[init]] must be called exactly once before [[process]].
* - [[process]] must be called at most once per session.
* - [[close]] must always be called (use try-finally).
-* - [[cancel]] may be called at any time to abort execution.
* - [[cancel]] may be called at any time, including before [[init]]
* or after [[process]]/[[close]] has returned. Implementations
* treat such calls as a no-op so that callers driven by a task
* interruption listener (which has no view into the session state)
* do not need to coordinate with the thread driving [[process]].
*
* [[cancel]] may be called even after all input data has been
* submitted (i.e. after [[Finish]] has been sent on the transport).
* In that case implementations MUST send [[Cancel]] on the transport
* if [[FinishResponse]] has not yet been received, and MUST be
* prepared to receive either [[FinishResponse]] or [[CancelResponse]].
* See the [[org.apache.spark.udf.worker.Finish]] proto message for
* the full contract.
*
* The lifecycle is enforced here: [[init]] and [[process]] are `final`
* and delegate to [[doInit]] / [[doProcess]] after AtomicBoolean guards.
@@ -93,10 +84,12 @@ abstract class WorkerSession extends AutoCloseable {
*
* Throws `IllegalStateException` if called more than once.
*
-* @param message the initialization parameters including the serialized
-*                function, input/output schemas, and configuration.
+* @param message the [[Init]] proto carrying the UDF body, the wire
+*                data format, optional input/output schemas, and any
+*                engine-side session context the worker needs to start
+*                processing.
*/
-final def init(message: InitMessage): Unit = {
+final def init(message: Init): Unit = {
if (!initialized.compareAndSet(false, true)) {
throw new IllegalStateException("init has already been called on this session")
}
@@ -127,22 +120,81 @@ abstract class WorkerSession extends AutoCloseable {
doProcess(input)
}

-/** Subclass hook for [[init]]. Called once, after the guard. */
-protected def doInit(message: InitMessage): Unit
/**
* Subclass hook for [[init]]. Called once, after the guard.
* Implementations MUST NOT open the Execute gRPC stream before
* this call: [[cancel]] before [[init]] is contractually a no-op
* at the transport level, which only holds if no stream has been
* opened yet.
*/
protected def doInit(message: Init): Unit

-/** Subclass hook for [[process]]. Called at most once, after the guard. */
/**
* Subclass hook for [[process]]. Called at most once, after the guard.
*
* If the Execute stream terminates with a gRPC transport error (i.e.
* the connection broke rather than the worker sending a protocol
* response), the implementation MUST:
* - throw an appropriate exception so the caller observes a failure
* rather than a silent empty result; and
* - ensure the underlying worker is not returned to any reuse pool,
* since a transport error leaves the worker in an unknown state.
* Implementations signal this to the [[WorkerDispatcher]] via
* whatever mechanism the dispatcher provides (e.g. flagging the
* worker as invalid before calling [[doClose]]).
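*
* A sketch of the shape this can take (`driveExecuteStream` and
* `markInvalid` are illustrative names, not part of this API):
* {{{
* override protected def doProcess(
*     input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = {
*   try {
*     driveExecuteStream(input) // send DataRequests, collect DataResponses
*   } catch {
*     case e: io.grpc.StatusRuntimeException =>
*       workerProcess.markInvalid() // broken stream: do not pool this worker
*       throw e
*   }
* }
* }}}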
*/
protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]]

/**
* Requests cancellation of the current UDF execution.
*
* '''Thread-safety:''' implementations must allow [[cancel]] to be called
* from a thread different from the one driving [[process]] (typically a
-* task interruption thread). It may be invoked at any point after
-* [[init]] and should be a no-op if execution has already finished.
+* task interruption thread).
*
* '''Lifecycle:''' [[cancel]] is idempotent and safe at any point in
* the session's life:
* - before [[init]] -- a no-op; the session may still be closed
* normally via [[close]].
* - between [[init]] and [[process]] -- signals that the session
* should be terminated; the caller should not invoke [[process]]
* and should call [[close]] to release resources.
* - during [[process]] (data flowing or awaiting [[FinishResponse]])
* -- sends [[Cancel]] and waits for [[CancelResponse]] or
* [[FinishResponse]] (whichever arrives first).
* - after [[FinishResponse]] or [[CancelResponse]] has been received
* -- a no-op; the stream is already terminated.
*
* Implementations are responsible for the lifecycle-aware behavior
* described above (no-op outside the active window; cancellation
* thrown from a subsequent [[process]] when applicable) so that
* callers (e.g. task interruption listeners) do not need to
* coordinate with the thread driving [[process]].
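*
* Because of this contract, an interruption hook can fire unconditionally
* (the listener API below is hypothetical):
* {{{
* onTaskInterrupted { () => session.cancel() } // safe in any state
* session.cancel()                             // repeated call: no-op
* }}}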
*/
def cancel(): Unit

-/** Closes this session and releases resources. */
-override def close(): Unit
/**
* Closes this session and releases resources. Idempotent; safe to
* call from a `finally` block regardless of whether [[init]],
* [[process]], or [[cancel]] have been invoked.
*
* If [[init]] was called but [[process]] was not (e.g. an exception
* was thrown between the two), [[close]] sends `Cancel` on the
* Execute stream before releasing resources, so the worker can clean
* up deterministically rather than observing a gRPC transport error.
* Subclasses implement [[doClose]] for resource teardown; the base
* class handles the cancel-before-close guarantee automatically.
*/
final override def close(): Unit = {
if (initialized.get() && !processed.get()) {
cancel()
}
doClose()
}

/** Subclass hook for [[close]]. The base class guarantees that
* [[cancel]] has already been called if [[init]] was invoked but
* [[process]] was not.
*/
protected def doClose(): Unit
}
@@ -48,7 +48,7 @@ abstract class DirectWorkerSession(
/** The connection to the worker for this session. */
def connection: WorkerConnection = workerProcess.connection

-override def close(): Unit = {
+override protected def doClose(): Unit = {
if (released.compareAndSet(false, true)) {
workerProcess.releaseSession()
}
@@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.udf.worker.{
-  DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties,
+  DirectWorker, Init, LocalTcpConnection, ProcessCallable, UDFWorkerProperties,
UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec,
WorkerEnvironment}
import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher,
@@ -51,14 +51,14 @@ class SocketFileConnection(socketPath: String)
* TODO: [[cancel]] is a no-op here. Once a concrete [[DirectWorkerSession]]
* with real data-plane wiring lands, add tests exercising cancel() in
* particular: cancel from a different thread than process(), cancel
-* after process() has returned, and cancel before init (should be a
-* no-op). Tracking the thread-safety contract in the docstring on
+* after process() has returned, and cancel before init (should be a no-op).
+* See the thread-safety contract in the docstring on
* [[org.apache.spark.udf.worker.core.WorkerSession.cancel]].
*/
class StubWorkerSession(
workerProcess: DirectWorkerProcess) extends DirectWorkerSession(workerProcess) {

-override protected def doInit(message: InitMessage): Unit = {}
+override protected def doInit(message: Init): Unit = {}

override protected def doProcess(
input: Iterator[Array[Byte]]): Iterator[Array[Byte]] =