diff --git a/README.md b/README.md index 9d4dad2..7eea4c5 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ ## Metaflow API Docs - [BatchInferencePipeline](docs/metaflow/batch_inference_pipeline.md) +- [create_ownership_registry_view](docs/metaflow/create_ownership_registry_view.md) - [make_pydantic_parser_fn](docs/metaflow/make_pydantic_parser_fn.md) - [publish](docs/metaflow/publish.md) - [publish_pandas](docs/metaflow/publish_pandas.md) diff --git a/docs/metaflow/create_ownership_registry_view.md b/docs/metaflow/create_ownership_registry_view.md new file mode 100644 index 0000000..0dd3e1a --- /dev/null +++ b/docs/metaflow/create_ownership_registry_view.md @@ -0,0 +1,46 @@ +# `create_ownership_registry_view` + +Source: `ds_platform_utils.metaflow.registry.create_ownership_registry_view` + +Creates (or replaces) the central **table-ownership registry view**, +`PATTERN_DB.DATA_SCIENCE.TABLE_OWNERSHIP_REGISTRY`. The view pivots the object tags +applied by [`publish`](publish.md) / [`publish_pandas`](publish_pandas.md) into one row +per table, exposing `owner`, `team`, `domain`, `project`, `status`, `sla` and `contact`. + +This is a one-time admin helper. + +## Signature + +```python +create_ownership_registry_view(conn: SnowflakeConnection | None = None) -> str +``` + +| Parameter | Type | Required | Description | +| --------- | ----------------------------- | -------: | ------------------------------------------------------------------------ | +| `conn` | `SnowflakeConnection \| None` | No | Open Snowflake connection. If omitted, one is created via `get_snowflake_connection()`. | + +**Returns:** the executed `CREATE OR REPLACE VIEW` SQL string. + +## Usage + +```python +from ds_platform_utils.metaflow import create_ownership_registry_view + +create_ownership_registry_view() +``` + +Then query it: + +```sql +SELECT * FROM PATTERN_DB.DATA_SCIENCE.TABLE_OWNERSHIP_REGISTRY +ORDER BY team, table_name; +``` + +## Notes + +- **No refresh needed.** A view is not materialized — it re-runs its query on every read, + so it is always live. +- **~2h lag.** The view reads `SNOWFLAKE.ACCOUNT_USAGE.TAG_REFERENCES`, which itself lags + up to ~2 hours. For the current value of a single table's tag, use + `SYSTEM$GET_TAG('PATTERN_DB.DATA_SCIENCE.TABLE_OWNER', '', 'table')` instead. +- **Adoption-based.** Only tables that have at least one ownership tag appear in the view. diff --git a/docs/metaflow/publish.md b/docs/metaflow/publish.md index 13485d9..89e766f 100644 --- a/docs/metaflow/publish.md +++ b/docs/metaflow/publish.md @@ -14,6 +14,7 @@ publish( ctx: dict[str, Any] | None = None, warehouse: Literal["XS", "MED", "XL"] = None, use_utc: bool = True, + tags: dict[str, str] | None = None, ) -> None ``` @@ -22,6 +23,8 @@ publish( - Reads SQL from a string or `.sql` path. - Runs write/audit/publish operations through Snowflake. - Adds operation details and table links to the Metaflow card when available. +- **Automatically applies ownership object tags to production tables** (see + [Ownership tags](#ownership-tags) below). ## Parameters @@ -33,6 +36,7 @@ publish( | `ctx` | `dict[str, Any] \| None` | No | Optional template substitution context for SQL operations. | | `warehouse` | `Literal["XS", "MED", "XL"] \| None` | No | Snowflake warehouse override for this operation. Supports `XS`/`MED`/`XL` shortcuts or a full warehouse name. | | `use_utc` | `bool` | No | If `True`, uses UTC timezone for Snowflake session. | +| `tags` | `dict[str, str] \| None` | No | Overrides for the ownership object tags applied to the published table. See [Ownership tags](#ownership-tags).| **Returns:** `None` @@ -47,3 +51,47 @@ publish( audits=["SELECT COUNT(*) > 0 FROM PATTERN_DB.{{schema}}.{{table_name}}"], ) ``` + +## Ownership tags + +When publishing to **production**, `publish()` automatically applies the table-ownership +object tags from the table-ownership RFC. The seven tags are: + +| Tag | Source | Always set? | +| --------------- | ------------------------------------------------------- | --------------- | +| `TABLE_OWNER` | Metaflow `current.username` | yes | +| `TABLE_TEAM` | `data-science` | yes | +| `TABLE_DOMAIN` | `ds.domain` Metaflow tag, else `unknown` | yes | +| `TABLE_PROJECT` | `ds.project` Metaflow tag, else `unknown` | yes | +| `TABLE_STATUS` | `active` (override allows `active`/`development`/`testing`/`deprecated`/`archived`/`retired`) | yes | +| `TABLE_SLA` | override only (`streaming`/`realtime`/`hourly`/`daily`/`weekly`/`monthly`/`quarterly`/`ad_hoc`/`on_demand`) | only if given | +| `TABLE_CONTACT` | override only (Slack channel or email) | only if given | + +> **`TABLE_DOMAIN` / `TABLE_PROJECT` depend on flow tags.** These are read from the +> `ds.domain` / `ds.project` Metaflow tags. If a flow runs without them, the value falls +> back to the literal string `unknown` and a warning is printed (the same warning used +> for select.dev cost tracking). Make sure your flow carries `--tag "ds.domain:..."` and +> `--tag "ds.project:..."` — these are applied automatically in CI and the standard `poe` +> run commands in the monorepo — or pass `tags={"domain": ..., "project": ...}` explicitly. + +Pass `tags=` to override any value. Keys may be `owner`/`team`/`domain`/`project`/ +`status`/`sla`/`contact` (optionally `TABLE_`-prefixed): + +```python +publish( + table_name="OUT_OF_STOCK_ADS", + query="sql/create_training_data.sql", + tags={"sla": "daily", "contact": "#ds-recsys", "status": "active"}, +) +``` + +Notes: + +- Tags are applied **only to production tables**. Non-prod (`DATA_SCIENCE_STAGE`) runs + apply no tags. +- The tag *definitions* must first be created once by a Snowflake admin (the RFC + `CREATE TAG` setup). Until then, tagging is **skipped with a warning** — the publish + still succeeds. +- Invalid `status`/`sla` values raise `ValueError` before any data is written. +- Tagged tables surface in the `TABLE_OWNERSHIP_REGISTRY` view (see + `create_ownership_registry_view`). diff --git a/docs/metaflow/publish_pandas.md b/docs/metaflow/publish_pandas.md index b00185d..e41573e 100644 --- a/docs/metaflow/publish_pandas.md +++ b/docs/metaflow/publish_pandas.md @@ -22,6 +22,7 @@ publish_pandas( use_utc: bool = True, use_s3_stage: bool = False, table_definition: list[tuple[str, str]] | None = None, + tags: dict[str, str] | None = None, ) -> None ``` @@ -30,6 +31,8 @@ publish_pandas( - Validates DataFrame input. - Writes directly via `write_pandas` or via S3 stage flow for large data. - Adds a Snowflake table URL to Metaflow card output. +- **Automatically applies ownership object tags to production tables** (see + [Ownership tags](#ownership-tags) below). ## Parameters @@ -49,9 +52,34 @@ publish_pandas( | `use_utc` | `bool` | No | If `True`, uses UTC timezone for Snowflake session. | | `use_s3_stage` | `bool` | No | If `True`, publishes via S3 stage flow; otherwise uses direct `write_pandas`. | | `table_definition` | `list[tuple[str, str]] \| None` | No | Optional Snowflake table schema; used by S3 stage flow when table creation is needed. | +| `tags` | `dict[str, str] \| None` | No | Overrides for the ownership object tags applied to the published table. See [Ownership tags](#ownership-tags).| **Returns:** `None` +## Ownership tags + +When publishing to **production**, `publish_pandas()` automatically applies the same +seven table-ownership object tags as [`publish`](publish.md#ownership-tags): +`TABLE_OWNER`, `TABLE_TEAM`, `TABLE_DOMAIN`, `TABLE_PROJECT`, `TABLE_STATUS` and +(when provided via `tags=`) `TABLE_SLA` / `TABLE_CONTACT`. + +```python +publish_pandas( + table_name="MY_TABLE", + df=df, + tags={"sla": "daily", "contact": "#ds-recsys"}, +) +``` + +- Tags are applied **only to production tables**; non-prod runs apply none. +- `TABLE_DOMAIN` / `TABLE_PROJECT` come from the `ds.domain` / `ds.project` Metaflow tags; + if a flow runs without them they fall back to the literal `unknown` and a warning is + printed. Ensure the flow carries those tags (automatic in CI / standard `poe` commands) + or pass `tags={"domain": ..., "project": ...}`. See [`publish`](publish.md#ownership-tags). +- Tag *definitions* must first be created by a Snowflake admin (RFC `CREATE TAG` setup); + until then tagging is **skipped with a warning** and the publish still succeeds. +- Invalid `status`/`sla` values raise `ValueError` before any data is written. + ## Limitations - When `use_s3_stage=True`, some column data types may not map exactly as expected between pandas/parquet and Snowflake. diff --git a/pyproject.toml b/pyproject.toml index 121b9cd..b2d50ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,12 @@ [project] name = "ds-platform-utils" -version = "0.4.2" +version = "0.5.0" description = "Utility library for Pattern Data Science." readme = "README.md" authors = [ { name = "Amit Vikram Raj", email = "amit.raj@pattern.com" }, - { name = "Eric Riddoch", email = "eric.riddoch@pattern.com" } + { name = "Eric Riddoch", email = "eric.riddoch@pattern.com" }, + { name = "Vinay Shende", email = "vinay.shende@pattern.com" } ] # requires-python = ">=3.7" dependencies = [ diff --git a/src/ds_platform_utils/_snowflake/object_tags.py b/src/ds_platform_utils/_snowflake/object_tags.py new file mode 100644 index 0000000..f575306 --- /dev/null +++ b/src/ds_platform_utils/_snowflake/object_tags.py @@ -0,0 +1,202 @@ +"""Build and apply Snowflake object tags for table ownership / governance. + +Implements the tag schema from the "Snowflake table ownership via object tags" RFC. +Tags are applied only to production tables, so both the tag *definitions* and the +*tables* live in ``PATTERN_DB.DATA_SCIENCE``. + +The tag *definitions* must be created once by a Snowflake admin (see the RFC's +``CREATE TAG`` setup). Until they exist, :func:`apply_table_tags` warns and leaves the +(already successful) table write untouched -- tagging must never break a publish. +""" + +import re +from typing import TYPE_CHECKING, Dict, Optional + +from ds_platform_utils._snowflake.run_query import _execute_sql +from ds_platform_utils.metaflow._consts import PROD_SCHEMA +from ds_platform_utils.sql_utils import get_select_dev_query_tags + +if TYPE_CHECKING: + from snowflake.connector import SnowflakeConnection + +DATABASE = "PATTERN_DB" + +# A Snowflake unquoted identifier: starts with a letter/underscore, then letters/digits/underscores. +# Identifiers (table name, schema, tag names) are interpolated directly into the SET TAG SQL, so we +# reject anything else to avoid malformed SQL or statement injection. (Tag *values* are safely +# single-quoted + escaped via _quote and are not subject to this check.) +_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + +# RFC allowed-value lists for the constrained tags. +TABLE_STATUS_ALLOWED = {"active", "development", "testing", "deprecated", "archived", "retired"} +TABLE_SLA_ALLOWED = { + "streaming", + "realtime", + "hourly", + "daily", + "weekly", + "monthly", + "quarterly", + "ad_hoc", + "on_demand", +} +DEFAULT_TABLE_STATUS = "active" + +# All seven RFC tag names. +TAG_OWNER = "TABLE_OWNER" +TAG_TEAM = "TABLE_TEAM" +TAG_DOMAIN = "TABLE_DOMAIN" +TAG_PROJECT = "TABLE_PROJECT" +TAG_STATUS = "TABLE_STATUS" +TAG_SLA = "TABLE_SLA" +TAG_CONTACT = "TABLE_CONTACT" + +# Maps accepted override keys (case-insensitive, with or without the ``TABLE_`` prefix) +# to the canonical tag name. +_OVERRIDE_ALIASES = { + "owner": TAG_OWNER, + "team": TAG_TEAM, + "domain": TAG_DOMAIN, + "project": TAG_PROJECT, + "status": TAG_STATUS, + "sla": TAG_SLA, + "contact": TAG_CONTACT, +} + + +def _normalize_overrides(tags_override: Optional[Dict[str, str]]) -> Dict[str, str]: + """Normalize caller override keys to canonical tag names. + + Accepts e.g. ``owner``, ``OWNER`` or ``TABLE_OWNER`` -> ``TABLE_OWNER``. + + :param tags_override: Raw override dict supplied by the caller. + :return: Override dict keyed by canonical tag name. + :raises ValueError: If an override key does not map to a known tag. + """ + normalized: Dict[str, str] = {} + for key, value in (tags_override or {}).items(): + canonical = _OVERRIDE_ALIASES.get(key.strip().lower().removeprefix("table_")) + if canonical is None: + raise ValueError( + f"Unknown tag override key {key!r}. Allowed keys: {sorted(_OVERRIDE_ALIASES)} " + f"(optionally prefixed with 'TABLE_')." + ) + normalized[canonical] = value + return normalized + + +def build_table_tags( + tags_override: Optional[Dict[str, str]] = None, + current_obj: Optional[object] = None, +) -> Dict[str, str]: + """Build the final ``{TAG_NAME: value}`` dict to apply to a published table. + + OWNER / TEAM / DOMAIN / PROJECT are derived from the Metaflow run context (reusing + :func:`get_select_dev_query_tags`); STATUS defaults to ``active``. Any value may be + overridden via ``tags_override``. SLA and CONTACT are only included when supplied + via ``tags_override`` (they cannot be inferred). + + :param tags_override: Optional overrides, keyed by ``owner``/``TABLE_OWNER``/etc. + :param current_obj: Optional Metaflow ``current`` stand-in (for testing). + :return: Mapping of canonical tag name to value, ready to apply. + :raises ValueError: If STATUS or SLA is not in its allowed-value list, or an + override key is unknown. + """ + overrides = _normalize_overrides(tags_override) + derived = get_select_dev_query_tags(current_obj=current_obj) + + tags: Dict[str, str] = { + TAG_OWNER: derived["user"], + TAG_TEAM: derived["team"], + TAG_DOMAIN: derived["domain"], + TAG_PROJECT: derived["workload_id"], + TAG_STATUS: DEFAULT_TABLE_STATUS, + } + # SLA / CONTACT are only set when explicitly provided. + tags.update(overrides) + + status = tags[TAG_STATUS] + if status not in TABLE_STATUS_ALLOWED: + raise ValueError(f"TABLE_STATUS must be one of {sorted(TABLE_STATUS_ALLOWED)}, got {status!r}.") + + sla = tags.get(TAG_SLA) + if sla is not None and sla not in TABLE_SLA_ALLOWED: + raise ValueError(f"TABLE_SLA must be one of {sorted(TABLE_SLA_ALLOWED)}, got {sla!r}.") + + # Drop any tags whose value is None/empty so we never emit ``= ''``. + return {name: str(value) for name, value in tags.items() if value is not None and str(value) != ""} + + +def _quote(value: str) -> str: + """Escape a tag value for a single-quoted SQL literal (double embedded quotes).""" + return value.replace("'", "''") + + +def _validate_identifier(value: str, kind: str) -> None: + """Reject anything that isn't a plain unquoted SQL identifier. + + Identifiers are interpolated unquoted into the ``SET TAG`` SQL, so a value containing + e.g. ``;`` or whitespace could produce invalid SQL or statement injection. + + :param value: Identifier to check (table name, schema, or tag name). + :param kind: Human-readable label used in the error message. + :raises ValueError: If ``value`` is not a valid unquoted identifier. + """ + if not _IDENTIFIER_RE.match(value): + raise ValueError(f"Invalid {kind} {value!r}; expected an unquoted identifier (letters/numbers/underscore).") + + +def build_set_tag_sql(table_name: str, tags: Dict[str, str], schema: str = PROD_SCHEMA) -> str: + """Build a single ``ALTER TABLE ... SET TAG`` statement. + + Tag definitions and the table both live in ``schema`` (``DATA_SCIENCE`` for prod). + + :param table_name: Table to tag (upper-cased to match Snowflake's stored identifier). + :param tags: Mapping of tag name to value (e.g. from :func:`build_table_tags`). + :param schema: Schema holding both the table and the tag definitions. + :return: The ``ALTER TABLE`` SQL string. + :raises ValueError: If ``tags`` is empty, or any identifier (table/schema/tag name) is invalid. + """ + if not tags: + raise ValueError("No tags to apply.") + table = table_name.upper() + _validate_identifier(table, "table_name") + _validate_identifier(schema, "schema") + for name in tags: + _validate_identifier(name, "tag name") + assignments = ",\n ".join(f"{DATABASE}.{schema}.{name} = '{_quote(value)}'" for name, value in tags.items()) + return f"ALTER TABLE {DATABASE}.{schema}.{table}\n SET TAG\n {assignments};" + + +def apply_table_tags( + conn: "SnowflakeConnection", + table_name: str, + tags: Dict[str, str], + schema: str = PROD_SCHEMA, +) -> None: + """Apply object tags to a published table, warning (never raising) on failure. + + A failure here most commonly means the tag definitions have not yet been created by + an admin (see the RFC ``CREATE TAG`` setup). Because the table write has already + succeeded by this point, we log a clear warning and return rather than breaking the + publish. + + :param conn: Open Snowflake connection. + :param table_name: Table to tag. + :param tags: Mapping of tag name to value. + :param schema: Schema holding both the table and the tag definitions. + """ + if not tags: + return + try: + # Built inside the try so identifier-validation errors warn-and-skip rather than break publish. + sql = build_set_tag_sql(table_name=table_name, tags=tags, schema=schema) + _execute_sql(conn, sql) + conn.commit() + print(f"Applied ownership tags to {DATABASE}.{schema}.{table_name.upper()}: {sorted(tags)}") + except Exception as exc: # noqa: BLE001 -- tagging must never break a successful publish + print( + f"Warning: failed to apply ownership tags to {DATABASE}.{schema}.{table_name.upper()} " + f"({exc}). The table was published successfully; tags were skipped. This usually means the " + f"tag definitions have not been created yet by a Snowflake admin (see the table-ownership RFC)." + ) diff --git a/src/ds_platform_utils/metaflow/__init__.py b/src/ds_platform_utils/metaflow/__init__.py index 5a88f4f..403fc2d 100644 --- a/src/ds_platform_utils/metaflow/__init__.py +++ b/src/ds_platform_utils/metaflow/__init__.py @@ -1,11 +1,13 @@ from .batch_inference_pipeline import BatchInferencePipeline from .pandas import publish_pandas, query_pandas_from_snowflake +from .registry import create_ownership_registry_view from .restore_step_state import restore_step_state from .validate_config import make_pydantic_parser_fn from .write_audit_publish import publish __all__ = [ "BatchInferencePipeline", + "create_ownership_registry_view", "make_pydantic_parser_fn", "publish", "publish_pandas", diff --git a/src/ds_platform_utils/metaflow/pandas.py b/src/ds_platform_utils/metaflow/pandas.py index 9fbcd58..0d79c81 100644 --- a/src/ds_platform_utils/metaflow/pandas.py +++ b/src/ds_platform_utils/metaflow/pandas.py @@ -44,6 +44,7 @@ def publish_pandas( # noqa: PLR0913 (too many arguments) use_utc: bool = True, use_s3_stage: bool = False, table_definition: Optional[List[Tuple[str, str]]] = None, + tags: Optional[Dict[str, str]] = None, ) -> None: """Store a pandas dataframe as a Snowflake table. @@ -89,13 +90,26 @@ def publish_pandas( # noqa: PLR0913 (too many arguments) :param table_definition: Optional list of tuples specifying the column names and types for the Snowflake table. This is only used when `use_s3_stage` is True, and is required in that case. The list should be in the format: `[(col_name1, col_type1), (col_name2, col_type2), ...]`, where `col_type` is a valid Snowflake data type (e.g., 'STRING', 'NUMBER', 'TIMESTAMP_NTZ', etc.). + + :param tags: Optional overrides for the ownership/governance object tags applied to the published + table (see the table-ownership RFC). Keys may be `owner`/`team`/`domain`/`project`/`status`/`sla`/ + `contact` (optionally `TABLE_`-prefixed). OWNER/TEAM/DOMAIN/PROJECT are derived from the Metaflow + run context when not overridden, STATUS defaults to `active`, and SLA/CONTACT are only applied when + provided here. Tags are only applied to **production** tables; in non-prod runs no tags are applied. + If the tag definitions have not yet been created by a Snowflake admin, tagging is skipped with a + warning (the publish still succeeds). """ + from ds_platform_utils._snowflake.object_tags import apply_table_tags, build_table_tags + if not isinstance(df, pd.DataFrame): raise TypeError("df must be a pandas DataFrame.") if df.empty: raise ValueError("DataFrame is empty.") + # Build/validate tags up front so an invalid status/sla fails fast, before any writes. + table_tags = build_table_tags(tags_override=tags) + if add_created_date: df["created_date"] = datetime.now().astimezone(pytz.utc) @@ -136,6 +150,12 @@ def publish_pandas( # noqa: PLR0913 (too many arguments) use_logical_type=use_logical_type, ) + # Tag the published table (prod only). The S3 path has no open connection, so open one. + if current.is_production: + _tag_table_with_new_connection( + table_name=table_name, tags=table_tags, schema=schema, warehouse=warehouse, use_utc=use_utc + ) + else: conn: SnowflakeConnection = get_snowflake_connection(warehouse=warehouse, use_utc=use_utc) _execute_sql(conn, f"USE SCHEMA PATTERN_DB.{schema};") @@ -154,6 +174,10 @@ def publish_pandas( # noqa: PLR0913 (too many arguments) overwrite=overwrite, use_logical_type=use_logical_type, ) + + # Tag the published table (prod only), reusing the open connection before closing it. + if current.is_production: + apply_table_tags(conn=conn, table_name=table_name, tags=table_tags) conn.close() # Add a link to the table in Snowflake to the card @@ -165,6 +189,35 @@ def publish_pandas( # noqa: PLR0913 (too many arguments) current.card.append(Markdown(f"[View table in Snowflake]({table_url})")) +def _tag_table_with_new_connection( + table_name: str, + tags: Dict[str, str], + schema: str, + warehouse: Optional[Union[Literal["XS", "MED", "XL"], str]], + use_utc: bool, +) -> None: + """Open a short-lived connection and tag an already-published table. + + Used by the S3-stage publish path, which has no open connection. Opening the + connection happens outside ``apply_table_tags``' own error handling, so we guard it + here too: tagging must never break an already-successful publish. + """ + from ds_platform_utils._snowflake.object_tags import apply_table_tags + + tag_conn = None + try: + tag_conn = get_snowflake_connection(warehouse=warehouse, use_utc=use_utc) + apply_table_tags(conn=tag_conn, table_name=table_name, tags=tags) + except Exception as exc: # noqa: BLE001 -- tagging must never break a successful publish + print( + f"Warning: failed to open a Snowflake connection to tag PATTERN_DB.{schema}.{table_name} " + f"({exc}). The table was published successfully; tags were skipped." + ) + finally: + if tag_conn is not None: + tag_conn.close() + + def query_pandas_from_snowflake( query: Union[str, Path], warehouse: Optional[Union[Literal["XS", "MED", "XL"], str]] = None, diff --git a/src/ds_platform_utils/metaflow/registry.py b/src/ds_platform_utils/metaflow/registry.py new file mode 100644 index 0000000..a39f4ed --- /dev/null +++ b/src/ds_platform_utils/metaflow/registry.py @@ -0,0 +1,56 @@ +"""Central table-ownership registry view (RFC §6). + +Exposes the ownership tags applied by :func:`publish` / :func:`publish_pandas` as a +single queryable view. The view is *not* materialized -- it is always live at query +time. Its only staleness is the inherent ~2h lag of +``SNOWFLAKE.ACCOUNT_USAGE.TAG_REFERENCES`` (see the RFC risks section); no periodic +refresh is needed. It is adoption-based: only tables that have at least one ownership +tag appear. +""" + +from typing import Optional + +from ds_platform_utils._snowflake.run_query import _execute_sql + +REGISTRY_VIEW_NAME = "PATTERN_DB.DATA_SCIENCE.TABLE_OWNERSHIP_REGISTRY" + +OWNERSHIP_REGISTRY_VIEW_SQL = f""" +CREATE OR REPLACE VIEW {REGISTRY_VIEW_NAME} AS +SELECT + tr.object_name AS table_name, + MAX(CASE WHEN tr.tag_name = 'TABLE_OWNER' THEN tr.tag_value END) AS owner, + MAX(CASE WHEN tr.tag_name = 'TABLE_TEAM' THEN tr.tag_value END) AS team, + MAX(CASE WHEN tr.tag_name = 'TABLE_DOMAIN' THEN tr.tag_value END) AS domain, + MAX(CASE WHEN tr.tag_name = 'TABLE_PROJECT' THEN tr.tag_value END) AS project, + MAX(CASE WHEN tr.tag_name = 'TABLE_STATUS' THEN tr.tag_value END) AS status, + MAX(CASE WHEN tr.tag_name = 'TABLE_SLA' THEN tr.tag_value END) AS sla, + MAX(CASE WHEN tr.tag_name = 'TABLE_CONTACT' THEN tr.tag_value END) AS contact +FROM SNOWFLAKE.ACCOUNT_USAGE.TAG_REFERENCES tr +WHERE tr.object_database = 'PATTERN_DB' + AND tr.object_schema = 'DATA_SCIENCE' + AND tr.domain = 'TABLE' + AND tr.tag_name IN ( + 'TABLE_OWNER', 'TABLE_TEAM', 'TABLE_DOMAIN', 'TABLE_PROJECT', + 'TABLE_STATUS', 'TABLE_SLA', 'TABLE_CONTACT' + ) +GROUP BY tr.object_name; +""" + + +def create_ownership_registry_view(conn: Optional["object"] = None) -> str: + """Create (or replace) the table-ownership registry view. + + Intended as a one-time admin helper. If ``conn`` is omitted, a connection is opened + via :func:`get_snowflake_connection`. + + :param conn: Optional open Snowflake connection. If None, one is created. + :return: The executed ``CREATE OR REPLACE VIEW`` SQL. + """ + if conn is None: + from ds_platform_utils.metaflow.snowflake_connection import get_snowflake_connection + + conn = get_snowflake_connection() + _execute_sql(conn, OWNERSHIP_REGISTRY_VIEW_SQL) + conn.commit() + print(f"Created/replaced view {REGISTRY_VIEW_NAME}.") + return OWNERSHIP_REGISTRY_VIEW_SQL diff --git a/src/ds_platform_utils/metaflow/write_audit_publish.py b/src/ds_platform_utils/metaflow/write_audit_publish.py index 1f1a709..8b20fd0 100644 --- a/src/ds_platform_utils/metaflow/write_audit_publish.py +++ b/src/ds_platform_utils/metaflow/write_audit_publish.py @@ -25,6 +25,7 @@ def publish( # noqa: PLR0913, D417 ctx: Optional[Dict[str, Any]] = None, warehouse: Optional[Union[Literal["XS", "MED", "XL"], str]] = None, use_utc: bool = True, + tags: Optional[Dict[str, str]] = None, ) -> None: """Publish a Snowflake table using the write-audit-publish (WAP) pattern via Metaflow's Snowflake connection. @@ -43,6 +44,13 @@ def publish( # noqa: PLR0913, D417 when running in the Outerbounds **Default** perimeter, and to the `OUTERBOUNDS_DATA_SCIENCE_SHARED_PROD_XS_WH` warehouse, when running in the Outerbounds **PROD** perimeter. :param use_utc: Whether to use UTC timezone for the Snowflake connection (affects timestamp fields). + :param tags: Optional overrides for the ownership/governance object tags applied to the published + table (see the table-ownership RFC). Keys may be ``owner``/``team``/``domain``/``project``/ + ``status``/``sla``/``contact`` (optionally ``TABLE_``-prefixed). OWNER/TEAM/DOMAIN/PROJECT are + derived from the Metaflow run context when not overridden, STATUS defaults to ``active``, and + SLA/CONTACT are only applied when provided here. Tags are only applied to **production** tables; + in non-prod runs no tags are applied. If the tag definitions have not yet been created by a + Snowflake admin, tagging is skipped with a warning (the publish still succeeds). Returns ------- @@ -58,12 +66,17 @@ def publish( # noqa: PLR0913, D417 query="sql/create_training_data.sql", audits=["sql/validate_training_data.sql"], warehouse="OUTERBOUNDS_DATA_SCIENCE_SHARED_DEV_XL_WH", + tags={"sla": "daily", "contact": "#ds-recsys"}, ) ``` """ + from ds_platform_utils._snowflake.object_tags import apply_table_tags, build_table_tags from ds_platform_utils._snowflake.write_audit_publish import write_audit_publish + # Build/validate tags up front so an invalid status/sla fails fast, before any writes. + table_tags = build_table_tags(tags_override=tags) + conn = get_snowflake_connection(warehouse=warehouse, use_utc=use_utc) query = get_query_from_string_or_fpath(query) @@ -86,6 +99,10 @@ def publish( # noqa: PLR0913, D417 ) last_op_was_write = operation.operation_type == "write" + # Tag the final table (prod only). Done after the SWAP so tags land on the live table. + if current.is_production: + apply_table_tags(conn=cur.connection, table_name=table_name, tags=table_tags) + def update_card_with_operation_info( operation: "SQLOperation", diff --git a/tests/unit_tests/snowflake/test__object_tags.py b/tests/unit_tests/snowflake/test__object_tags.py new file mode 100644 index 0000000..2e40a6a --- /dev/null +++ b/tests/unit_tests/snowflake/test__object_tags.py @@ -0,0 +1,197 @@ +import pytest + +from ds_platform_utils._snowflake import object_tags +from ds_platform_utils._snowflake.object_tags import ( + apply_table_tags, + build_set_tag_sql, + build_table_tags, +) + + +class FakeCurrent: + """Stand-in for ``metaflow.current`` used to drive tag derivation in tests.""" + + tags = ["ds.domain:recommendations", "ds.project:two_tower_v2"] + flow_name = "MyFlow" + project_name = "recsys-proj" + step_name = "end" + run_id = "123" + username = "john_doe" + namespace = "user:john" + is_production = True + + +def test_build_table_tags_derives_all_mappings(): + """All four context-derived tags + default STATUS are present; SLA/CONTACT omitted.""" + tags = build_table_tags(current_obj=FakeCurrent()) + + assert tags["TABLE_OWNER"] == "john_doe" + assert tags["TABLE_TEAM"] == "data-science" + assert tags["TABLE_DOMAIN"] == "recommendations" + assert tags["TABLE_PROJECT"] == "two_tower_v2" + assert tags["TABLE_STATUS"] == "active" + assert "TABLE_SLA" not in tags + assert "TABLE_CONTACT" not in tags + + +def test_build_table_tags_overrides_win(): + """Overrides (incl. alias + cased keys) replace derived values and add SLA/CONTACT.""" + tags = build_table_tags( + tags_override={"owner": "jane", "SLA": "daily", "TABLE_CONTACT": "#ds-recsys"}, + current_obj=FakeCurrent(), + ) + + assert tags["TABLE_OWNER"] == "jane" + assert tags["TABLE_SLA"] == "daily" + assert tags["TABLE_CONTACT"] == "#ds-recsys" + # Non-overridden derived values still present. + assert tags["TABLE_DOMAIN"] == "recommendations" + + +@pytest.mark.parametrize("override", [{"status": "bogus"}, {"sla": "every_minute"}]) +def test_build_table_tags_invalid_constrained_value_raises(override): + """Invalid STATUS or SLA values raise ValueError (caller error).""" + with pytest.raises(ValueError): + build_table_tags(tags_override=override, current_obj=FakeCurrent()) + + +def test_build_table_tags_unknown_key_raises(): + """An unrecognized override key raises ValueError.""" + with pytest.raises(ValueError, match="Unknown tag override key"): + build_table_tags(tags_override={"foo": "bar"}, current_obj=FakeCurrent()) + + +def test_build_set_tag_sql_format_and_escaping(): + """SQL targets DATA_SCIENCE for both table and tag, upper-cases the table, escapes quotes.""" + sql = build_set_tag_sql(table_name="my_table", tags={"TABLE_OWNER": "o'brien"}) + + assert "ALTER TABLE PATTERN_DB.DATA_SCIENCE.MY_TABLE" in sql + assert "PATTERN_DB.DATA_SCIENCE.TABLE_OWNER = 'o''brien'" in sql + assert sql.strip().endswith(";") + + +def test_build_set_tag_sql_empty_raises(): + with pytest.raises(ValueError, match="No tags to apply"): + build_set_tag_sql(table_name="t", tags={}) + + +def test_build_set_tag_sql_multiple_tags_joined(): + """Multiple tags are comma-joined under a single SET TAG / single trailing semicolon.""" + sql = build_set_tag_sql( + table_name="my_table", + tags={"TABLE_OWNER": "john_doe", "TABLE_TEAM": "data-science", "TABLE_STATUS": "active"}, + ) + + assert sql.count("SET TAG") == 1 + assert "PATTERN_DB.DATA_SCIENCE.TABLE_OWNER = 'john_doe'," in sql + assert "PATTERN_DB.DATA_SCIENCE.TABLE_TEAM = 'data-science'," in sql + assert "PATTERN_DB.DATA_SCIENCE.TABLE_STATUS = 'active'" in sql + # Exactly one statement terminator, on the last assignment only. + assert sql.count(";") == 1 + assert sql.count("=") == 3 + + +def test_build_table_tags_drops_empty_override_value(): + """An empty-string override is dropped rather than emitted as TAG = ''.""" + tags = build_table_tags(tags_override={"contact": ""}, current_obj=FakeCurrent()) + + assert "TABLE_CONTACT" not in tags + + +@pytest.mark.parametrize("bad_table", ["bad; DROP TABLE x", "has space", "1leading_digit", "", "a-b"]) +def test_build_set_tag_sql_rejects_invalid_table_name(bad_table): + """Non-identifier table names are rejected before reaching SQL.""" + with pytest.raises(ValueError, match="Invalid table_name"): + build_set_tag_sql(table_name=bad_table, tags={"TABLE_OWNER": "john_doe"}) + + +def test_build_set_tag_sql_rejects_invalid_tag_name(): + """Non-identifier tag names are rejected.""" + with pytest.raises(ValueError, match="Invalid tag name"): + build_set_tag_sql(table_name="my_table", tags={"TABLE_OWNER; DROP": "x"}) + + +def test_build_set_tag_sql_rejects_invalid_schema(): + """Non-identifier schema is rejected.""" + with pytest.raises(ValueError, match="Invalid schema"): + build_set_tag_sql(table_name="my_table", tags={"TABLE_OWNER": "x"}, schema="DATA_SCIENCE; DROP") + + +class FakeConn: + def __init__(self): + self.committed = False + + def commit(self): + """Record that commit was called.""" + self.committed = True + + +def test_apply_table_tags_swallows_errors_and_warns(monkeypatch, capsys): + """A failure applying tags must not raise and must not break the publish.""" + + def _boom(*_args, **_kwargs): + raise RuntimeError("tag 'TABLE_OWNER' does not exist") + + monkeypatch.setattr(object_tags, "_execute_sql", _boom) + conn = FakeConn() + + apply_table_tags(conn=conn, table_name="my_table", tags={"TABLE_OWNER": "john_doe"}) + + assert conn.committed is False + assert "Warning: failed to apply ownership tags" in capsys.readouterr().out + + +def test_apply_table_tags_invalid_identifier_warns_not_raises(monkeypatch, capsys): + """An invalid identifier must warn-and-skip, not propagate out of apply_table_tags.""" + executed = False + + def _spy(*_args, **_kwargs): + nonlocal executed + executed = True + + monkeypatch.setattr(object_tags, "_execute_sql", _spy) + conn = FakeConn() + + # A malformed table name would otherwise build invalid/injectable SQL. + apply_table_tags(conn=conn, table_name="bad; DROP TABLE x", tags={"TABLE_OWNER": "john_doe"}) + + assert executed is False # never reached execution + assert conn.committed is False + assert "Warning: failed to apply ownership tags" in capsys.readouterr().out + + +def test_apply_table_tags_success_executes_and_commits(monkeypatch, capsys): + """Happy path: the built SQL is executed against the conn and the change is committed.""" + captured = {} + + def _capture(conn, sql): + captured["conn"] = conn + captured["sql"] = sql + + monkeypatch.setattr(object_tags, "_execute_sql", _capture) + conn = FakeConn() + + apply_table_tags(conn=conn, table_name="my_table", tags={"TABLE_OWNER": "john_doe"}) + + assert captured["conn"] is conn + assert "ALTER TABLE PATTERN_DB.DATA_SCIENCE.MY_TABLE" in captured["sql"] + assert "PATTERN_DB.DATA_SCIENCE.TABLE_OWNER = 'john_doe'" in captured["sql"] + assert conn.committed is True + assert "Applied ownership tags" in capsys.readouterr().out + + +def test_apply_table_tags_noop_on_empty(monkeypatch): + """No tags -> no execution, no commit.""" + called = False + + def _spy(*_args, **_kwargs): + nonlocal called + called = True + + monkeypatch.setattr(object_tags, "_execute_sql", _spy) + conn = FakeConn() + + apply_table_tags(conn=conn, table_name="my_table", tags={}) + + assert called is False + assert conn.committed is False diff --git a/uv.lock b/uv.lock index a099e72..37078e2 100644 --- a/uv.lock +++ b/uv.lock @@ -479,7 +479,7 @@ wheels = [ [[package]] name = "ds-platform-utils" -version = "0.4.2" +version = "0.5.0" source = { editable = "." } dependencies = [ { name = "jinja2" },