
Commit 3b55616

Merge pull request #20 from patterninc/feature/snowflake-s3-stage-v1
Feature/snowflake s3 stage
2 parents ebd0fbf + 313b423 commit 3b55616

32 files changed

Lines changed: 3323 additions & 1543 deletions

.github/workflows/ci-cd-ds-platform-utils.yaml

Lines changed: 7 additions & 1 deletion
```diff
@@ -5,8 +5,14 @@ on:
   push:
     branches:
       - main
+    paths-ignore:
+      - "README.md"
+      - "docs/**"
   pull_request:
     types: [opened, synchronize]
+    paths-ignore:
+      - "README.md"
+      - "docs/**"
 
 jobs:
   check-version:
@@ -110,7 +116,7 @@ jobs:
           uv pip install --group dev
           COVERAGE_DIR="$(python -c 'import ds_platform_utils; print(ds_platform_utils.__path__[0])')"
           poe clean
-          poe test --cov="$COVERAGE_DIR" --no-cov
+          poe test --cov="$COVERAGE_DIR" --no-cov -n auto
 
   tag-version:
     needs: [check-version, code-quality-checks, build-wheel, execute-tests]
```

README.md

Lines changed: 10 additions & 2 deletions
```diff
@@ -1,3 +1,11 @@
-## ds-platform-utils
+# ds-platform-utils
+
+## Metaflow API Docs
+
+- [BatchInferencePipeline](docs/metaflow/batch_inference_pipeline.md)
+- [make_pydantic_parser_fn](docs/metaflow/make_pydantic_parser_fn.md)
+- [publish](docs/metaflow/publish.md)
+- [publish_pandas](docs/metaflow/publish_pandas.md)
+- [query_pandas_from_snowflake](docs/metaflow/query_pandas_from_snowflake.md)
+- [restore_step_state](docs/metaflow/restore_step_state.md)
 
-Utility library to support Pattern's [data-science-projects](https://github.com/patterninc/data-science-projects/).
```

docs/metaflow/batch_inference_pipeline.md

Lines changed: 205 additions & 0 deletions

# `BatchInferencePipeline`

Source: `ds_platform_utils.metaflow.batch_inference_pipeline.BatchInferencePipeline`

Utility class to orchestrate batch inference with Snowflake + S3 in Metaflow steps.

## Main methods

- `query_and_batch(...)`: export source data to S3 and create worker batches.
- `process_batch(...)`: run download → inference → upload for one worker.
- `publish_results(...)`: copy prediction outputs from S3 to Snowflake.
- `run(...)`: convenience method to execute the full flow sequentially.

## Detailed example (Metaflow foreach)

This example shows the intended 3-step pattern in a Metaflow `FlowSpec`:

1. `query_and_batch()` in `start`
2. `process_batch()` in `foreach`
3. `publish_results()` in `join`

```python
from metaflow import FlowSpec, step
import pandas as pd

from ds_platform_utils.metaflow import BatchInferencePipeline


def predict_fn(df: pd.DataFrame) -> pd.DataFrame:
    # Example model logic
    out = pd.DataFrame()
    out["id"] = df["id"]
    out["score"] = (df["feature_1"].fillna(0) * 0.7 + df["feature_2"].fillna(0) * 0.3).round(6)
    out["label"] = (out["score"] >= 0.5).astype(int)
    return out


class BatchPredictFlow(FlowSpec):

    @step
    def start(self):
        self.next(self.query_and_batch)

    @step
    def query_and_batch(self):
        self.pipeline = BatchInferencePipeline()

        # Query can be inline SQL or a file path.
        # {schema} is provided by ds_platform_utils (DEV/PROD selection).
        self.worker_ids = self.pipeline.query_and_batch(
            input_query="""
                SELECT
                    id,
                    feature_1,
                    feature_2
                FROM {{schema}}.model_features
                WHERE ds = '2026-02-26'
            """,
            parallel_workers=8,
            warehouse="MED",
            use_utc=True,
        )

        self.next(self.process_batch, foreach="worker_ids")

    @step
    def process_batch(self):
        # In a foreach step, self.input contains one worker_id.
        self.pipeline.process_batch(
            worker_id=self.input,
            predict_fn=predict_fn,
            batch_size_in_mb=256,
            timeout_per_batch=300,
        )
        self.next(self.publish_results)

    @step
    def publish_results(self, inputs):
        # Reuse one pipeline object from the foreach branches.
        self.pipeline = inputs[0].pipeline

        self.pipeline.publish_results(
            output_table_name="MODEL_PREDICTIONS_DAILY",
            output_table_definition=[
                ("id", "NUMBER"),
                ("score", "FLOAT"),
                ("label", "NUMBER"),
            ],
            auto_create_table=True,
            overwrite=True,
            warehouse="MED",
            use_utc=True,
        )
        self.next(self.end)

    @step
    def end(self):
        print("Batch inference complete")
```

## Detailed example (single-step convenience)

Use `run()` when you do not need Metaflow foreach parallelization:

```python
# This snippet is a single step inside a Metaflow FlowSpec.
from metaflow import step
import pandas as pd

from ds_platform_utils.metaflow import BatchInferencePipeline


@step
def batch_inference_step(self):
    def predict_fn(df: pd.DataFrame) -> pd.DataFrame:
        return pd.DataFrame(
            {
                "id": df["id"],
                "score": (df["feature_1"] * 0.9).fillna(0),
            }
        )

    pipeline = BatchInferencePipeline()
    pipeline.run(
        input_query="""
            SELECT id, feature_1
            FROM {{schema}}.model_features
            WHERE ds = '2026-02-26'
        """,
        output_table_name="MODEL_PREDICTIONS_DAILY",
        predict_fn=predict_fn,
        output_table_definition=[("id", "NUMBER"), ("score", "FLOAT")],
        warehouse="XL",
    )

    self.next(self.end)
```

## Parameters

### `query_and_batch(...)`

| Parameter | Type | Required | Description |
| ------------------ | ------------- | -------: | ------------------------------------------------------------------------------------------------------------------------- |
| `input_query` | `str \| Path` | Yes | SQL query string or SQL file path used to fetch source rows. The `{schema}` placeholder is resolved by `ds_platform_utils`. |
| `ctx` | `dict` | No | Optional substitution map for templated SQL; merged with the internal `{"schema": ...}` mapping before query execution. |
| `warehouse` | `str` | No | Snowflake warehouse used to execute the source query/export. |
| `use_utc` | `bool` | No | If `True`, uses UTC timestamps/paths for partitioning and run metadata. |
| `parallel_workers` | `int` | No | Number of worker partitions to create for downstream processing. |

**Returns:** `list[int]` of `worker_id` values for Metaflow `foreach`.

---

### `process_batch(...)`

| Parameter | Type | Required | Description |
| ------------------- | ---------------------------------------- | -------: | ------------------------------------------------------------------------------------------------------------ |
| `worker_id` | `int` | Yes | Worker partition identifier generated by `query_and_batch()`. |
| `predict_fn` | `Callable[[pd.DataFrame], pd.DataFrame]` | Yes | Inference function applied to each input chunk. Must return a DataFrame matching the expected output schema. |
| `batch_size_in_mb` | `int` | No | Target chunk size for reading/processing batch files. |
| `timeout_per_batch` | `int` | No | Maximum processing time per batch, in seconds (used for queuing operations). |

**Returns:** `None`

**Recommended**: Tune `batch_size_in_mb` for Outerbounds Small tasks (3 CPU, 15 GB memory), which are about 6x more cost-effective than Medium tasks.

## Limitations

- The pipeline uses Snowflake ↔ S3 stage copy operations, so some column data types may be inferred differently than expected.
- For predictable output types, provide an explicit `output_table_definition` in `publish_results(...)` / `run(...)` and cast columns in `predict_fn` as needed, as in the sketch below.
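
A minimal sketch of a `predict_fn` that casts its output to match an explicit table definition; the column names and types here are illustrative, not part of the library API:

```python
import pandas as pd


def predict_fn(df: pd.DataFrame) -> pd.DataFrame:
    out = pd.DataFrame()
    # Cast explicitly so staged parquet types line up with a (hypothetical)
    # output_table_definition of [("id", "NUMBER"), ("score", "FLOAT")].
    out["id"] = df["id"].astype("int64")
    out["score"] = (df["feature_1"].fillna(0) * 0.7).astype("float64")
    return out
```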

### `publish_results(...)`

| Parameter | Type | Required | Description |
| ------------------------- | ------------------------------- | -------: | ------------------------------------------------------------------ |
| `output_table_name` | `str` | Yes | Destination Snowflake table for predictions. |
| `output_table_definition` | `list[tuple[str, str]] \| None` | No | Optional output schema as `(column_name, snowflake_type)` tuples. |
| `auto_create_table` | `bool` | No | If `True`, creates the destination table when missing. |
| `overwrite` | `bool` | No | If `True`, replaces existing table data before loading results. |
| `warehouse` | `str` | No | Snowflake warehouse used for load/publish operations. |
| `use_utc` | `bool` | No | If `True`, uses UTC for load metadata/time handling. |

**Returns:** `None`

---

### `run(...)` (convenience method)

Runs `query_and_batch()` → `process_batch()` → `publish_results()` in a single sequential call.

| Parameter | Type | Required | Description |
| ------------------------- | ---------------------------------------- | -------: | -------------------------------------------------------------------------------------------------------------------------- |
| `input_query` | `str \| Path` | Yes | SQL query string or SQL file path used to fetch source rows. The `{schema}` placeholder is resolved by `ds_platform_utils`. |
| `output_table_name` | `str` | Yes | Destination Snowflake table for predictions. |
| `predict_fn` | `Callable[[pd.DataFrame], pd.DataFrame]` | Yes | Inference function applied to each input chunk. Must return a DataFrame matching the expected output schema. |
| `ctx` | `dict` | No | Optional substitution map for templated SQL; merged with the internal `{"schema": ...}` mapping before query execution. |
| `output_table_definition` | `list[tuple[str, str]] \| None` | No | Optional output schema as `(column_name, snowflake_type)` tuples. |
| `batch_size_in_mb` | `int` | No | Target chunk size for reading/processing batch files. |
| `timeout_per_batch` | `int` | No | Maximum processing time per batch, in seconds (used for queuing operations). |
| `auto_create_table` | `bool` | No | If `True`, creates the destination table when missing. |
| `overwrite` | `bool` | No | If `True`, replaces existing table data before loading results. |
| `warehouse` | `str` | No | Snowflake warehouse used for load/publish operations. |
| `use_utc` | `bool` | No | If `True`, uses UTC for load metadata/time handling. |

**Returns:** `None`

**Recommended**: Tune `batch_size_in_mb` for Outerbounds Small tasks (3 CPU, 15 GB memory), which are about 6x more cost-effective than Medium tasks.

docs/metaflow/make_pydantic_parser_fn.md

Lines changed: 37 additions & 0 deletions

# `make_pydantic_parser_fn`

Source: `ds_platform_utils.metaflow.validate_config.make_pydantic_parser_fn`

Creates a Metaflow `Config(..., parser=...)` parser backed by a Pydantic model.

## Signature

```python
make_pydantic_parser_fn(
    pydantic_model: type[BaseModel],
) -> Callable[[str], dict]
```

## What it does

- Parses config content as JSON, TOML, or YAML.
- Validates and normalizes with Pydantic.
- Returns a dict with defaults applied from the model.

## Parameters

| Parameter | Type | Required | Description |
| ---------------- | ----------------- | -------: | -------------------------------------------------------------------- |
| `pydantic_model` | `type[BaseModel]` | Yes | Pydantic model class used to validate and normalize config content. |

**Returns:** `Callable[[str], dict]` parser function for Metaflow `Config(..., parser=...)`.

## Typical usage

```python
config: MyConfig = Config(
    name="config",
    default="./configs/default.yaml",
    parser=make_pydantic_parser_fn(MyConfig),
)
```
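
A fuller sketch of the same pattern inside a flow; `TrainConfig`, its fields, and the YAML path below are hypothetical, and attribute access on the parsed config is assumed:

```python
from metaflow import Config, FlowSpec, step
from pydantic import BaseModel

from ds_platform_utils.metaflow import make_pydantic_parser_fn


class TrainConfig(BaseModel):  # hypothetical config model
    table_name: str
    learning_rate: float = 0.01  # default applied when missing from the file


class TrainFlow(FlowSpec):
    # The YAML/JSON/TOML file is parsed, validated against TrainConfig,
    # and exposed on the flow with model defaults applied.
    config = Config(
        name="config",
        default="./configs/default.yaml",
        parser=make_pydantic_parser_fn(TrainConfig),
    )

    @step
    def start(self):
        print(self.config.table_name, self.config.learning_rate)
        self.next(self.end)

    @step
    def end(self):
        pass


if __name__ == "__main__":
    TrainFlow()
```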

docs/metaflow/publish.md

Lines changed: 49 additions & 0 deletions

# `publish`

Source: `ds_platform_utils.metaflow.write_audit_publish.publish`

Publishes data to a Snowflake table using the write-audit-publish (WAP) pattern.

## Signature

```python
publish(
    table_name: str,
    query: str | Path,
    audits: list[str | Path] | None = None,
    ctx: dict[str, Any] | None = None,
    warehouse: Literal["XS", "MED", "XL"] = None,
    use_utc: bool = True,
) -> None
```

## What it does

- Reads SQL from a string or `.sql` path.
- Runs write/audit/publish operations through Snowflake.
- Adds operation details and table links to the Metaflow card when available.

## Parameters

| Parameter | Type | Required | Description |
| ------------ | ------------------------------------ | -------: | --------------------------------------------------------------------------------------------------------------- |
| `table_name` | `str` | Yes | Destination Snowflake table name for the publish operation. |
| `query` | `str \| Path` | Yes | SQL query text or path to SQL file that produces the table data. |
| `audits` | `list[str \| Path] \| None` | No | Optional SQL audits (strings or file paths) executed as validation checks. |
| `ctx` | `dict[str, Any] \| None` | No | Optional template substitution context for SQL operations. |
| `warehouse` | `Literal["XS", "MED", "XL"] \| None` | No | Snowflake warehouse override for this operation. Supports `XS`/`MED`/`XL` shortcuts or a full warehouse name. |
| `use_utc` | `bool` | No | If `True`, uses UTC timezone for Snowflake session. |

**Returns:** `None`

## Typical usage

```python
from ds_platform_utils.metaflow import publish

publish(
    table_name="MY_TABLE",
    query="SELECT * FROM PATTERN_DB.{{schema}}.SOURCE",
    audits=["SELECT COUNT(*) > 0 FROM PATTERN_DB.{{schema}}.{{table_name}}"],
)
```
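
A sketch of the same call using SQL files and a template context instead of inline strings; the file paths and the `run_date` key are hypothetical values passed through `ctx`:

```python
from pathlib import Path

from ds_platform_utils.metaflow import publish

publish(
    table_name="MY_TABLE",
    # .sql files are read from disk; {{run_date}} is a hypothetical template
    # key resolved from ctx alongside the built-in {{schema}} mapping.
    query=Path("sql/build_my_table.sql"),
    audits=[Path("sql/audits/row_count.sql")],
    ctx={"run_date": "2026-02-26"},
    warehouse="MED",
)
```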

docs/metaflow/publish_pandas.md

Lines changed: 58 additions & 0 deletions

# `publish_pandas`

Source: `ds_platform_utils.metaflow.pandas.publish_pandas`

Writes a pandas DataFrame to Snowflake.

## Signature

```python
publish_pandas(
    table_name: str,
    df: pd.DataFrame,
    add_created_date: bool = False,
    chunk_size: int | None = None,
    compression: Literal["snappy", "gzip"] = "snappy",
    warehouse: Literal["XS", "MED", "XL"] = None,
    parallel: int = 4,
    quote_identifiers: bool = False,
    auto_create_table: bool = False,
    overwrite: bool = False,
    use_logical_type: bool = True,
    use_utc: bool = True,
    use_s3_stage: bool = False,
    table_definition: list[tuple[str, str]] | None = None,
) -> None
```

## What it does

- Validates the DataFrame input.
- Writes directly via `write_pandas`, or via the S3 stage flow for large data.
- Adds a Snowflake table URL to the Metaflow card output.

## Parameters

| Parameter | Type | Required | Description |
| ------------------- | ------------------------------- | -------: | ----------------------------------------------------------------------------------------------------------------- |
| `table_name` | `str` | Yes | Destination Snowflake table name. |
| `df` | `pd.DataFrame` | Yes | DataFrame to publish. |
| `add_created_date` | `bool` | No | If `True`, adds a `created_date` UTC timestamp column before publishing. |
| `chunk_size` | `int \| None` | No | Number of rows per uploaded chunk. If not provided, it is calculated from the DataFrame size. |
| `compression` | `Literal["snappy", "gzip"]` | No | Compression codec used for staged parquet files. |
| `warehouse` | `str \| None` | No | Snowflake warehouse override for this operation. Supports `XS`/`MED`/`XL` shortcuts or a full warehouse name. |
| `parallel` | `int` | No | Number of upload threads used by the `write_pandas` path. |
| `quote_identifiers` | `bool` | No | If `False`, passes identifiers unquoted so Snowflake applies uppercase coercion. |
| `auto_create_table` | `bool` | No | If `True`, creates the destination table when missing. |
| `overwrite` | `bool` | No | If `True`, replaces existing table contents. |
| `use_logical_type` | `bool` | No | Controls parquet logical type handling when loading data. |
| `use_utc` | `bool` | No | If `True`, uses UTC timezone for the Snowflake session. |
| `use_s3_stage` | `bool` | No | If `True`, publishes via the S3 stage flow; otherwise uses direct `write_pandas`. |
| `table_definition` | `list[tuple[str, str]] \| None` | No | Optional Snowflake table schema; used by the S3 stage flow when table creation is needed. |

**Returns:** `None`

## Limitations

- When `use_s3_stage=True`, some column data types may not map exactly as expected between pandas/parquet and Snowflake.
- If needed, provide an explicit `table_definition` and/or cast columns before publishing to avoid data type mismatches (see the sketch below).
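
A minimal usage sketch; the table and column names are illustrative, and the second call shows the S3 stage path with an explicit `table_definition` to pin column types:

```python
import pandas as pd

from ds_platform_utils.metaflow import publish_pandas

df = pd.DataFrame({"id": [1, 2, 3], "score": [0.12, 0.87, 0.55]})

# Direct write_pandas path for small/medium DataFrames.
publish_pandas(
    table_name="MY_PREDICTIONS",
    df=df,
    add_created_date=True,
    auto_create_table=True,
    overwrite=True,
    warehouse="XS",
)

# S3 stage path with explicit casts and table definition.
publish_pandas(
    table_name="MY_PREDICTIONS",
    df=df.astype({"id": "int64", "score": "float64"}),
    use_s3_stage=True,
    table_definition=[("id", "NUMBER"), ("score", "FLOAT")],
    auto_create_table=True,
    overwrite=True,
    warehouse="MED",
)
```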
