-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpipeline.py
More file actions
149 lines (125 loc) · 5.49 KB
/
pipeline.py
File metadata and controls
149 lines (125 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# pipeline.py
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Iterator
import itertools
import multiprocessing as mp
from typing import Any
from typing import TypeVar
from typing import overload
from laygo.helpers import PipelineContext
from laygo.helpers import is_context_aware
from laygo.transformers.transformer import Transformer
# Generic element type for pipeline data (used by the module-level alias below;
# the Pipeline class itself uses PEP 695 inline type parameters).
T = TypeVar("T")
# A single-argument callable applied to each pipeline element (see Pipeline.each).
PipelineFunction = Callable[[T], Any]
class Pipeline[T]:
"""
Manages a data source and applies transformers to it.
Always uses a multiprocessing-safe shared context.
"""
def __init__(self, *data: Iterable[T]):
if len(data) == 0:
raise ValueError("At least one data source must be provided to Pipeline.")
self.data_source: Iterable[T] = itertools.chain.from_iterable(data) if len(data) > 1 else data[0]
self.processed_data: Iterator = iter(self.data_source)
# Always create a shared context with multiprocessing manager
self._manager = mp.Manager()
self.ctx = self._manager.dict()
# Add a shared lock to the context for safe concurrent updates
self.ctx["lock"] = self._manager.Lock()
# Store reference to original context for final synchronization
self._original_context_ref: PipelineContext | None = None
def __del__(self):
"""Clean up the multiprocessing manager when the pipeline is destroyed."""
try:
self._sync_context_back()
self._manager.shutdown()
except Exception:
pass # Ignore errors during cleanup
def context(self, ctx: PipelineContext) -> "Pipeline[T]":
"""
Updates the pipeline context and stores a reference to the original context.
When the pipeline finishes processing, the original context will be updated
with the final pipeline context data.
"""
# Store reference to the original context
self._original_context_ref = ctx
# Copy the context data to the pipeline's shared context
self.ctx.update(ctx)
return self
def _sync_context_back(self) -> None:
"""
Synchronize the final pipeline context back to the original context reference.
This is called after processing is complete.
"""
if self._original_context_ref is not None:
# Copy the final context state back to the original context reference
final_context_state = dict(self.ctx)
final_context_state.pop("lock", None) # Remove non-serializable lock
self._original_context_ref.clear()
self._original_context_ref.update(final_context_state)
def transform[U](self, t: Callable[[Transformer[T, T]], Transformer[T, U]]) -> "Pipeline[U]":
"""
Shorthand method to apply a transformation using a lambda function.
Creates a Transformer under the hood and applies it to the pipeline.
Args:
t: A callable that takes a transformer and returns a transformed transformer
Returns:
A new Pipeline with the transformed data
"""
# Create a new transformer and apply the transformation function
transformer = t(Transformer[T, T]())
return self.apply(transformer)
@overload
def apply[U](self, transformer: Transformer[T, U]) -> "Pipeline[U]": ...
@overload
def apply[U](self, transformer: Callable[[Iterable[T]], Iterator[U]]) -> "Pipeline[U]": ...
@overload
def apply[U](self, transformer: Callable[[Iterable[T], PipelineContext], Iterator[U]]) -> "Pipeline[U]": ...
def apply[U](
self,
transformer: Transformer[T, U]
| Callable[[Iterable[T]], Iterator[U]]
| Callable[[Iterable[T], PipelineContext], Iterator[U]],
) -> "Pipeline[U]":
"""
Applies a transformer to the current data source. The pipeline's
managed context is passed down.
"""
match transformer:
case Transformer():
# The transformer is called with self.ctx, which is the
# shared mp.Manager.dict proxy when inside a 'with' block.
self.processed_data = transformer(self.processed_data, self.ctx) # type: ignore
case _ if callable(transformer):
if is_context_aware(transformer):
processed_transformer = transformer
else:
processed_transformer = lambda data, ctx: transformer(data) # type: ignore # noqa: E731
self.processed_data = processed_transformer(self.processed_data, self.ctx) # type: ignore
case _:
raise TypeError("Transformer must be a Transformer instance or a callable function")
return self # type: ignore
# ... The rest of the Pipeline class (transform, __iter__, to_list, etc.) remains unchanged ...
def __iter__(self) -> Iterator[T]:
"""Allows the pipeline to be iterated over."""
yield from self.processed_data
def to_list(self) -> list[T]:
"""Executes the pipeline and returns the results as a list."""
return list(self.processed_data)
def each(self, function: PipelineFunction[T]) -> None:
"""Applies a function to each element (terminal operation)."""
# Context needs to be accessed from the function if it's context-aware,
# but the pipeline itself doesn't own a context. This is a design choice.
# For simplicity, we assume the function is not context-aware here
# or that context is handled within the Transformers.
for item in self.processed_data:
function(item)
def first(self, n: int = 1) -> list[T]:
"""Gets the first n elements of the pipeline (terminal operation)."""
assert n >= 1, "n must be at least 1"
return list(itertools.islice(self.processed_data, n))
def consume(self) -> None:
"""Consumes the pipeline without returning results."""
for _ in self.processed_data:
pass