-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstreaming.py
More file actions
140 lines (102 loc) · 4.64 KB
/
streaming.py
File metadata and controls
140 lines (102 loc) · 4.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Streaming examples: producer streams and exchange (bidirectional) streams.
Producer streams generate a sequence of Arrow RecordBatches from the server.
Exchange streams allow the client to send data and receive transformed results
in a lockstep request/response pattern.
Run::
python examples/streaming.py
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Protocol, cast
import pyarrow as pa
import pyarrow.compute as pc
from vgi_rpc import (
AnnotatedBatch,
CallContext,
ExchangeState,
OutputCollector,
ProducerState,
Stream,
StreamState,
serve_pipe,
)
# ---------------------------------------------------------------------------
# Producer stream: server generates batches, client iterates
# ---------------------------------------------------------------------------
@dataclass
class CounterState(ProducerState):
    """Producer-stream state that counts from ``current`` up to ``limit``.

    Subclassing ``ProducerState`` means only ``produce(out, ctx)`` has to be
    written; there is no unused ``input`` argument to carry around.
    The stream is ended by calling ``out.finish()``.
    """

    limit: int
    current: int = 0

    def produce(self, out: OutputCollector, ctx: CallContext) -> None:
        """Emit the next single-row batch, or finish once ``limit`` is reached."""
        if self.current < self.limit:
            n = self.current
            # One row per call: the counter value and its square.
            out.emit_pydict({"n": [n], "n_squared": [n**2]})
            self.current += 1
        else:
            out.finish()
# ---------------------------------------------------------------------------
# Exchange stream: client sends data, server transforms and returns it
# ---------------------------------------------------------------------------
@dataclass
class ScaleState(ExchangeState):
    """Exchange-stream state that rescales the ``value`` column by ``factor``.

    Subclassing ``ExchangeState`` means only ``exchange(input, out, ctx)``
    has to be written.  Each call must emit exactly one output batch, and
    ``out.finish()`` must not be called from an exchange stream.
    """

    factor: float

    def exchange(self, input: AnnotatedBatch, out: OutputCollector, ctx: CallContext) -> None:
        """Emit the incoming ``value`` column multiplied by ``self.factor``."""
        values = input.batch.column("value")
        product = pc.multiply(values, self.factor)
        # The cast only informs the type checker; it has no runtime effect.
        out.emit_arrays([cast("pa.Array[Any]", product)])
# ---------------------------------------------------------------------------
# Service definition
# ---------------------------------------------------------------------------
# Output schema of the ``count`` producer stream: two int64 columns.
_COUNTER_SCHEMA = pa.schema(
    [
        pa.field("n", pa.int64()),
        pa.field("n_squared", pa.int64()),
    ]
)
# Schema of the ``scale`` exchange stream, used for both its input and output.
_SCALE_SCHEMA = pa.schema(
    [
        pa.field("value", pa.float64()),
    ]
)
class MathService(Protocol):
    """Protocol describing the example service's two streaming methods.

    ``count`` is a producer stream and ``scale`` is an exchange stream.
    """

    def count(self, limit: int) -> Stream[StreamState]:
        """Produce *limit* batches of (n, n_squared)."""
        ...

    def scale(self, factor: float) -> Stream[StreamState]:
        """Multiply incoming values by *factor* and return the results."""
        ...
class MathServiceImpl:
    """Implementation of the ``MathService`` methods."""

    def count(self, limit: int) -> Stream[CounterState]:
        """Start a producer stream of *limit* batches of (n, n_squared)."""
        counter = CounterState(limit=limit)
        return Stream(output_schema=_COUNTER_SCHEMA, state=counter)

    def scale(self, factor: float) -> Stream[ScaleState]:
        """Start an exchange stream multiplying incoming values by *factor*."""
        scaler = ScaleState(factor=factor)
        # Supplying input_schema is what makes this an exchange stream.
        return Stream(
            input_schema=_SCALE_SCHEMA,
            output_schema=_SCALE_SCHEMA,
            state=scaler,
        )
# ---------------------------------------------------------------------------
# Client usage
# ---------------------------------------------------------------------------
def main() -> None:
    """Exercise both example streams through a service opened with ``serve_pipe``."""
    exchange_inputs = [[1.0, 2.0, 3.0], [100.0, 200.0]]

    with serve_pipe(MathService, MathServiceImpl()) as svc:
        # Producer stream: the server generates batches, the client iterates.
        print("=== Producer stream (count to 5) ===")
        for produced in svc.count(limit=5):
            for record in produced.batch.to_pylist():
                print(f" n={record['n']} n^2={record['n_squared']}")

        # Exchange stream: each batch sent gets one transformed batch back.
        print("\n=== Exchange stream (scale by 10) ===")
        with svc.scale(factor=10.0) as stream:
            for values in exchange_inputs:
                request = AnnotatedBatch(pa.RecordBatch.from_pydict({"value": values}))
                response = stream.exchange(request)
                scaled = response.batch.column("value").to_pylist()
                print(f" input={values} output={scaled}")
# Run the examples only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()