From 815ca788cddb9d427cadb2656bc05e24d48935ee Mon Sep 17 00:00:00 2001 From: Raphael Tony Date: Wed, 6 May 2026 13:17:40 +0530 Subject: [PATCH] fix: harden SQL identifier handling across adapters --- .../adapters/types/bigquery/bigquery.py | 72 +++++++----- .../adapters/types/databricks/databricks.py | 94 +++++++++------- src/intugle/adapters/types/oracle/oracle.py | 32 +++--- .../adapters/types/postgres/postgres.py | 42 ++++--- .../adapters/types/snowflake/snowflake.py | 66 +++++++---- .../adapters/types/sqlserver/sqlserver.py | 50 ++++++--- src/intugle/adapters/utils.py | 102 ++++++++++++++++- src/intugle/data_product.py | 21 ++-- tests/adapters/test_databricks_adapter.py | 5 +- tests/adapters/test_sql_identifier_safety.py | 104 ++++++++++++++++++ 10 files changed, 437 insertions(+), 151 deletions(-) create mode 100644 tests/adapters/test_sql_identifier_safety.py diff --git a/src/intugle/adapters/types/bigquery/bigquery.py b/src/intugle/adapters/types/bigquery/bigquery.py index f121a05..1c84009 100644 --- a/src/intugle/adapters/types/bigquery/bigquery.py +++ b/src/intugle/adapters/types/bigquery/bigquery.py @@ -1,4 +1,5 @@ import time + from typing import TYPE_CHECKING, Any, Optional import numpy as np @@ -8,7 +9,12 @@ from intugle.adapters.factory import AdapterFactory from intugle.adapters.models import ColumnProfile, DataSetData, ProfilingOutput from intugle.adapters.types.bigquery.models import BigQueryConfig, BigQueryConnectionConfig -from intugle.adapters.utils import convert_to_native +from intugle.adapters.utils import ( + convert_to_native, + quote_identifier, + quote_identifier_parts, + split_identifier_path, +) from intugle.core import settings from intugle.core.utilities.processing import string_standardization @@ -97,16 +103,13 @@ def connect(self): def _get_fqn(self, identifier: str) -> str: """Gets the fully qualified name for a table identifier.""" - if "." 
in identifier: - # Already has project or dataset prefix - parts = identifier.split(".") - if len(parts) == 2: - # dataset.table format - return f"`{self._project_id}.{identifier}`" - elif len(parts) == 3: - # project.dataset.table format - return f"`{identifier}`" - return f"`{self._project_id}.{self._dataset_id}.{identifier}`" + parts = split_identifier_path(identifier, max_parts=3) + if len(parts) == 1: + parts = [self._project_id, self._dataset_id, parts[0]] + elif len(parts) == 2: + parts = [self._project_id, parts[0], parts[1]] + + return quote_identifier_parts(parts, quote_char="`", compound=True) @staticmethod def check_data(data: Any) -> BigQueryConfig: @@ -133,20 +136,32 @@ def profile(self, data: BigQueryConfig, table_name: str) -> ProfilingOutput: """Profile a BigQuery table.""" data = self.check_data(data) fqn = self._get_fqn(data.identifier) + identifier_parts = split_identifier_path(data.identifier, max_parts=3) + if len(identifier_parts) == 1: + project_id, dataset_id, table_identifier = self._project_id, self._dataset_id, identifier_parts[0] + elif len(identifier_parts) == 2: + project_id, dataset_id, table_identifier = self._project_id, identifier_parts[0], identifier_parts[1] + else: + project_id, dataset_id, table_identifier = identifier_parts # Get total count count_query = f"SELECT COUNT(*) as count FROM {fqn}" total_count = self._execute_sql(count_query)[0]["count"] # Get column information from INFORMATION_SCHEMA + information_schema = quote_identifier_parts( + [project_id, dataset_id, "INFORMATION_SCHEMA", "COLUMNS"], + quote_char="`", + compound=True, + ) schema_query = f""" SELECT column_name, data_type - FROM `{self._project_id}.{self._dataset_id}.INFORMATION_SCHEMA.COLUMNS` + FROM {information_schema} WHERE table_name = @table_name ORDER BY ordinal_position """ job_config = bigquery.QueryJobConfig( - query_parameters=[bigquery.ScalarQueryParameter("table_name", "STRING", data.identifier)] + query_parameters=[bigquery.ScalarQueryParameter("table_name", "STRING", table_identifier)] ) query_job = self.client.query(schema_query, job_config=job_config) rows = [dict(row) for row in query_job.result()] @@ -172,13 +187,14 @@ def column_profile( """Profile a specific column in a BigQuery table.""" data = self.check_data(data) fqn = self._get_fqn(data.identifier) + safe_column_name = quote_identifier(column_name, quote_char="`") start_ts = time.time() # Null and distinct counts query = f""" SELECT - COUNTIF(`{column_name}` IS NULL) as null_count, - COUNT(DISTINCT `{column_name}`) as distinct_count + COUNTIF({safe_column_name} IS NULL) as null_count, + COUNT(DISTINCT {safe_column_name}) as distinct_count FROM {fqn} """ result = self._execute_sql(query)[0] @@ -188,9 +204,9 @@ def column_profile( # Sampling for distinct values sample_query = f""" - SELECT DISTINCT CAST(`{column_name}` AS STRING) as value + SELECT DISTINCT CAST({safe_column_name} AS STRING) as value FROM {fqn} - WHERE `{column_name}` IS NOT NULL + WHERE {safe_column_name} IS NOT NULL LIMIT {dtype_sample_limit} """ distinct_values_result = self._execute_sql(sample_query) @@ -209,9 +225,9 @@ def column_profile( remaining_sample_size = dtype_sample_limit - len(distinct_values) if remaining_sample_size > 0: additional_samples_query = f""" - SELECT CAST(`{column_name}` AS STRING) as value + SELECT CAST({safe_column_name} AS STRING) as value FROM {fqn} - WHERE `{column_name}` IS NOT NULL + WHERE {safe_column_name} IS NOT NULL ORDER BY RAND() LIMIT {remaining_sample_size} """ @@ -295,18 +311,20 @@ def 
intersect_count( data2 = self.check_data(table2.data) fqn1 = self._get_fqn(data1.identifier) fqn2 = self._get_fqn(data2.identifier) + col1 = quote_identifier(column1_name, quote_char="`") + col2 = quote_identifier(column2_name, quote_char="`") query = f""" SELECT COUNT(*) as count FROM ( - SELECT DISTINCT `{column1_name}` as key + SELECT DISTINCT {col1} as key FROM {fqn1} - WHERE `{column1_name}` IS NOT NULL + WHERE {col1} IS NOT NULL ) t1 INNER JOIN ( - SELECT DISTINCT `{column2_name}` as key + SELECT DISTINCT {col2} as key FROM {fqn2} - WHERE `{column2_name}` IS NOT NULL + WHERE {col2} IS NOT NULL ) t2 ON t1.key = t2.key """ @@ -319,8 +337,7 @@ def get_composite_key_uniqueness( data = self.check_data(dataset_data) fqn = self._get_fqn(data.identifier) - # Build column list with backticks - safe_columns = [f"`{col}`" for col in columns] + safe_columns = [quote_identifier(col, quote_char="`") for col in columns] columns_str = ", ".join(safe_columns) # Build null filter @@ -352,9 +369,8 @@ def intersect_composite_keys_count( fqn1 = self._get_fqn(data1.identifier) fqn2 = self._get_fqn(data2.identifier) - # Build column lists with backticks - safe_columns1 = [f"`{col}`" for col in columns1] - safe_columns2 = [f"`{col}`" for col in columns2] + safe_columns1 = [quote_identifier(col, quote_char="`") for col in columns1] + safe_columns2 = [quote_identifier(col, quote_char="`") for col in columns2] # Subquery for distinct keys from table 1 distinct_cols1 = ", ".join(safe_columns1) diff --git a/src/intugle/adapters/types/databricks/databricks.py b/src/intugle/adapters/types/databricks/databricks.py index 23e48b9..1d554c0 100644 --- a/src/intugle/adapters/types/databricks/databricks.py +++ b/src/intugle/adapters/types/databricks/databricks.py @@ -15,7 +15,13 @@ DatabricksNotebookConfig, DatabricksSQLConnectorConfig, ) -from intugle.adapters.utils import convert_to_native +from intugle.adapters.utils import ( + convert_to_native, + escape_sql_literal, + quote_identifier, + quote_identifier_parts, + split_identifier_path, +) from intugle.core import settings from intugle.core.utilities.processing import string_standardization @@ -136,19 +142,17 @@ def connect(self): def _get_fqn(self, identifier: str) -> str: """Gets the fully qualified name for a table identifier.""" - # An identifier is already fully qualified if it contains a dot. - if "." in identifier: - return identifier - - # Backticks are used to handle reserved keywords and special characters. 
- safe_schema = f"`{self._schema}`" - safe_identifier = f"`{identifier}`" + parts = split_identifier_path(identifier, max_parts=3) + if len(parts) > 1: + return quote_identifier_parts(parts, quote_char="`") + path_parts = [parts[0]] + if self._schema: + path_parts.insert(0, self._schema) if self.catalog: - safe_catalog = f"`{self.catalog}`" - return f"{safe_catalog}.{safe_schema}.{safe_identifier}" - - return f"{safe_schema}.{safe_identifier}" + path_parts.insert(0, self.catalog) + + return quote_identifier_parts(path_parts, quote_char="`") @staticmethod def check_data(data: Any) -> DatabricksConfig: @@ -161,16 +165,16 @@ def check_data(data: Any) -> DatabricksConfig: def _execute_sql(self, query: str) -> list[Any]: if self.spark: if self.catalog: - self.spark.sql(f"USE CATALOG `{self.catalog}`") + self.spark.sql(f"USE CATALOG {quote_identifier(self.catalog, quote_char='`')}") if self._schema: - self.spark.sql(f"USE `{self._schema}`") + self.spark.sql(f"USE {quote_identifier(self._schema, quote_char='`')}") return self.spark.sql(query).collect() elif self.connection: with self.connection.cursor() as cursor: if self.catalog: - cursor.execute(f"USE CATALOG `{self.catalog}`") + cursor.execute(f"USE CATALOG {quote_identifier(self.catalog, quote_char='`')}") if self._schema: - cursor.execute(f"USE `{self._schema}`") + cursor.execute(f"USE {quote_identifier(self._schema, quote_char='`')}") cursor.execute(query) try: return cursor.fetchall() @@ -181,16 +185,16 @@ def _execute_sql(self, query: str) -> list[Any]: def _get_pandas_df(self, query: str) -> pd.DataFrame: if self.spark: if self.catalog: - self.spark.sql(f"USE CATALOG `{self.catalog}`") + self.spark.sql(f"USE CATALOG {quote_identifier(self.catalog, quote_char='`')}") if self._schema: - self.spark.sql(f"USE `{self._schema}`") + self.spark.sql(f"USE {quote_identifier(self._schema, quote_char='`')}") return self.spark.sql(query).toPandas() elif self.connection: with self.connection.cursor() as cursor: if self.catalog: - cursor.execute(f"USE CATALOG `{self.catalog}`") + cursor.execute(f"USE CATALOG {quote_identifier(self.catalog, quote_char='`')}") if self._schema: - cursor.execute(f"USE `{self._schema}`") + cursor.execute(f"USE {quote_identifier(self._schema, quote_char='`')}") cursor.execute(query) data = cursor.fetchall() columns = [column[0] for column in cursor.description] @@ -228,13 +232,14 @@ def column_profile( ) -> Optional[ColumnProfile]: data = self.check_data(data) fqn = self._get_fqn(data.identifier) + safe_column_name = quote_identifier(column_name, quote_char="`") start_ts = time.time() # Null and distinct counts query = f""" SELECT - COUNT(CASE WHEN `{column_name}` IS NULL THEN 1 END) as null_count, - COUNT(DISTINCT `{column_name}`) as distinct_count + COUNT(CASE WHEN {safe_column_name} IS NULL THEN 1 END) as null_count, + COUNT(DISTINCT {safe_column_name}) as distinct_count FROM {fqn} """ result = self._execute_sql(query)[0] @@ -244,7 +249,7 @@ def column_profile( # Sampling sample_query = f""" - SELECT DISTINCT CAST(`{column_name}` AS STRING) FROM {fqn} WHERE `{column_name}` IS NOT NULL LIMIT {dtype_sample_limit} + SELECT DISTINCT CAST({safe_column_name} AS STRING) FROM {fqn} WHERE {safe_column_name} IS NOT NULL LIMIT {dtype_sample_limit} """ distinct_values_result = self._execute_sql(sample_query) distinct_values = [row[0] for row in distinct_values_result] @@ -261,7 +266,7 @@ def column_profile( elif distinct_count > 0 and not_null_count > 0: remaining_sample_size = dtype_sample_limit - distinct_count 
additional_samples_query = f""" - SELECT CAST(`{column_name}` AS STRING) FROM {fqn} WHERE `{column_name}` IS NOT NULL ORDER BY RAND() LIMIT {remaining_sample_size} + SELECT CAST({safe_column_name} AS STRING) FROM {fqn} WHERE {safe_column_name} IS NOT NULL ORDER BY RAND() LIMIT {remaining_sample_size} """ additional_samples_result = self._execute_sql(additional_samples_query) additional_samples = [row[0] for row in additional_samples_result] @@ -346,14 +351,15 @@ def _sync_metadata(self, manifest: "Manifest", sync_glossary: bool, sync_tags: b # Set table comment if sync_glossary and source.table.description: - table_comment = source.table.description.replace("'", "\\'") + table_comment = escape_sql_literal(source.table.description) self._execute_sql(f"COMMENT ON TABLE {fqn} IS '{table_comment}'") # Works for views too # Set column comments and tags for column in source.table.columns: + safe_column_name = quote_identifier(column.name, quote_char="`") if sync_glossary and column.description: - col_comment = column.description.replace("'", "\\'") - self._execute_sql(f"COMMENT ON COLUMN {fqn}.`{column.name}` IS '{col_comment}'") + col_comment = escape_sql_literal(column.description) + self._execute_sql(f"COMMENT ON COLUMN {fqn}.{safe_column_name} IS '{col_comment}'") if sync_tags and column.tags: cleaned_tags = [clean_tag(tag) for tag in column.tags] @@ -361,12 +367,12 @@ def _sync_metadata(self, manifest: "Manifest", sync_glossary: bool, sync_tags: b # FIXME: Need to differentiate between TABLES and VIEWS for setting tags try: - self._execute_sql(f"ALTER TABLE {fqn} ALTER COLUMN `{column.name}` SET TAGS ({tag_assignments})") + self._execute_sql(f"ALTER TABLE {fqn} ALTER COLUMN {safe_column_name} SET TAGS ({tag_assignments})") except Exception: try: - self._execute_sql(f"ALTER VIEW {fqn} ALTER COLUMN `{column.name}` SET TAGS ({tag_assignments})") + self._execute_sql(f"ALTER VIEW {fqn} ALTER COLUMN {safe_column_name} SET TAGS ({tag_assignments})") except Exception as e: - print(f"Could not set tags '{tag_assignments}' on {fqn}.`{column.name}`: {e}") + print(f"Could not set tags '{tag_assignments}' on {fqn}.{safe_column_name}: {e}") print("Metadata sync complete.") @@ -382,13 +388,15 @@ def _set_primary_keys(self, manifest: "Manifest"): fqn = self._get_fqn(source.table.name) pk_columns = source.table.key.columns - constraint_name = f"pk_{source.table.name}" + constraint_name = quote_identifier(clean_name(f"pk_{source.table.name}"), quote_char="`") try: for col in pk_columns: + safe_col = quote_identifier(col, quote_char="`") # First, ensure the column is not nullable - self._execute_sql(f"ALTER TABLE {fqn} ALTER COLUMN `{col}` SET NOT NULL") + self._execute_sql(f"ALTER TABLE {fqn} ALTER COLUMN {safe_col} SET NOT NULL") # Then, add the primary key constraint - self._execute_sql(f"ALTER TABLE {fqn} ADD CONSTRAINT {constraint_name} PRIMARY KEY (`" + "`, `".join(pk_columns) + "`)") + pk_columns_sql = ", ".join(quote_identifier(col, quote_char="`") for col in pk_columns) + self._execute_sql(f"ALTER TABLE {fqn} ADD CONSTRAINT {constraint_name} PRIMARY KEY ({pk_columns_sql})") print(f"Set primary key on {fqn} (`{pk_columns}`)") except Exception as e: print(f"Could not set primary key for {fqn}: {e}") @@ -409,11 +417,13 @@ def _set_foreign_keys(self, manifest: "Manifest"): child_fqn = self._get_fqn(rel.target.table) parent_fqn = self._get_fqn(rel.source.table) constraint_name = f"fk_{rel.name}" - cleaned_constraint_name = clean_name(constraint_name) + cleaned_constraint_name = 
quote_identifier(clean_name(constraint_name), quote_char="`") + child_columns = ", ".join(quote_identifier(col, quote_char="`") for col in rel.target.columns) + parent_columns = ", ".join(quote_identifier(col, quote_char="`") for col in rel.source.columns) self._execute_sql( f"ALTER TABLE {child_fqn} ADD CONSTRAINT {cleaned_constraint_name} " - f"FOREIGN KEY (`{'`, '.join(rel.target.columns)}`) REFERENCES {parent_fqn} (`{'`, '.join(rel.source.columns)}`)" + f"FOREIGN KEY ({child_columns}) REFERENCES {parent_fqn} ({parent_columns})" ) except Exception as e: print(f"Could not set foreign key for relationship {rel.name}: {e}") @@ -425,12 +435,14 @@ def intersect_count(self, table1: "DataSet", column1_name: str, table2: "DataSet fqn1 = self._get_fqn(table1_adapter.identifier) fqn2 = self._get_fqn(table2_adapter.identifier) + col1 = quote_identifier(column1_name, quote_char="`") + col2 = quote_identifier(column2_name, quote_char="`") query = f""" SELECT COUNT(*) FROM ( - SELECT DISTINCT `{column1_name}` FROM {fqn1} WHERE `{column1_name}` IS NOT NULL + SELECT DISTINCT {col1} FROM {fqn1} WHERE {col1} IS NOT NULL INTERSECT - SELECT DISTINCT `{column2_name}` FROM {fqn2} WHERE `{column2_name}` IS NOT NULL + SELECT DISTINCT {col2} FROM {fqn2} WHERE {col2} IS NOT NULL ) """ return self._execute_sql(query)[0][0] @@ -438,7 +450,7 @@ def intersect_count(self, table1: "DataSet", column1_name: str, table2: "DataSet def get_composite_key_uniqueness(self, table_name: str, columns: list[str], dataset_data: DataSetData) -> int: data = self.check_data(dataset_data) fqn = self._get_fqn(data.identifier) - safe_columns = [f"`{col}`" for col in columns] + safe_columns = [quote_identifier(col, quote_char="`") for col in columns] column_list = ", ".join(safe_columns) null_cols_filter = " AND ".join(f"{c} IS NOT NULL" for c in safe_columns) @@ -463,8 +475,8 @@ def intersect_composite_keys_count( fqn1 = self._get_fqn(table1_adapter.identifier) fqn2 = self._get_fqn(table2_adapter.identifier) - safe_columns1 = [f"`{col}`" for col in columns1] - safe_columns2 = [f"`{col}`" for col in columns2] + safe_columns1 = [quote_identifier(col, quote_char="`") for col in columns1] + safe_columns2 = [quote_identifier(col, quote_char="`") for col in columns2] # Subquery for distinct keys from table 1 distinct_cols1 = ", ".join(safe_columns1) @@ -503,4 +515,4 @@ def can_handle_databricks(df: Any) -> bool: def register(factory: AdapterFactory): if DATABRICKS_AVAILABLE: - factory.register("databricks", can_handle_databricks, DatabricksAdapter, DatabricksConfig) \ No newline at end of file + factory.register("databricks", can_handle_databricks, DatabricksAdapter, DatabricksConfig) diff --git a/src/intugle/adapters/types/oracle/oracle.py b/src/intugle/adapters/types/oracle/oracle.py index d3e9ce1..8d8ffae 100644 --- a/src/intugle/adapters/types/oracle/oracle.py +++ b/src/intugle/adapters/types/oracle/oracle.py @@ -9,7 +9,7 @@ from intugle.adapters.factory import AdapterFactory from intugle.adapters.models import ColumnProfile, DataSetData, ProfilingOutput from intugle.adapters.types.oracle.models import OracleConfig, OracleConnectionConfig -from intugle.adapters.utils import convert_to_native +from intugle.adapters.utils import convert_to_native, quote_identifier, quote_identifier_parts, split_identifier_path from intugle.core import settings from intugle.core.utilities.processing import string_standardization @@ -100,12 +100,13 @@ def connect(self): # Set current schema if different from user if params.schema_: with 
self.connection.cursor() as cursor: - cursor.execute(f"ALTER SESSION SET CURRENT_SCHEMA = {params.schema_}") + cursor.execute(f"ALTER SESSION SET CURRENT_SCHEMA = {quote_identifier(params.schema_)}") def _get_fqn(self, identifier: str) -> str: - if "." in identifier: - return identifier.upper() # Oracle identifiers are case-insensitive/upper by default unless quoted - return f'"{self._schema}"."{identifier}"' + parts = split_identifier_path(identifier, max_parts=2) + if len(parts) == 2: + return quote_identifier_parts(parts) + return quote_identifier_parts([self._schema, parts[0]]) @staticmethod def check_data(data: Any) -> OracleConfig: @@ -140,8 +141,9 @@ def _get_pandas_df(self, query: str, params: Optional[list | dict] = None) -> pd def profile(self, data: OracleConfig, table_name: str) -> ProfilingOutput: data = self.check_data(data) - # Assuming identifier is table name - table_upper = data.identifier.upper() + identifier_parts = split_identifier_path(data.identifier, max_parts=2) + schema_name = identifier_parts[0].upper() if len(identifier_parts) == 2 else self._schema + table_upper = identifier_parts[-1].upper() # Count fqn = self._get_fqn(data.identifier) @@ -155,7 +157,7 @@ def profile(self, data: OracleConfig, table_name: str) -> ProfilingOutput: FROM ALL_TAB_COLUMNS WHERE OWNER = :owner AND TABLE_NAME = :table_name """ - rows = self._execute_sql(query, {"owner": self._schema, "table_name": table_upper}) + rows = self._execute_sql(query, {"owner": schema_name, "table_name": table_upper}) columns = [row["COLUMN_NAME"] for row in rows] dtypes = {row["COLUMN_NAME"]: row["DATA_TYPE"] for row in rows} @@ -180,7 +182,7 @@ def column_profile( start_ts = time.time() # Oracle treats empty strings as NULLs sometimes, but we'll stick to standard SQL - safe_col = f'"{column_name}"' + safe_col = quote_identifier(column_name) query = f""" SELECT @@ -312,13 +314,15 @@ def intersect_count(self, table1: "DataSet", column1_name: str, table2: "DataSet fqn1 = self._get_fqn(table1_adapter.identifier) fqn2 = self._get_fqn(table2_adapter.identifier) + col1 = quote_identifier(column1_name) + col2 = quote_identifier(column2_name) # Use INTERSECT query = f""" SELECT COUNT(*) as CNT FROM ( - SELECT DISTINCT "{column1_name}" FROM {fqn1} WHERE "{column1_name}" IS NOT NULL + SELECT DISTINCT {col1} FROM {fqn1} WHERE {col1} IS NOT NULL INTERSECT - SELECT DISTINCT "{column2_name}" FROM {fqn2} WHERE "{column2_name}" IS NOT NULL + SELECT DISTINCT {col2} FROM {fqn2} WHERE {col2} IS NOT NULL ) """ return self._execute_sql(query)[0]["CNT"] @@ -326,7 +330,7 @@ def intersect_count(self, table1: "DataSet", column1_name: str, table2: "DataSet def get_composite_key_uniqueness(self, table_name: str, columns: list[str], dataset_data: DataSetData) -> int: data = self.check_data(dataset_data) fqn = self._get_fqn(data.identifier) - safe_columns = [f'"{col}"' for col in columns] + safe_columns = [quote_identifier(col) for col in columns] column_list = ", ".join(safe_columns) null_cols_filter = " AND ".join(f"{c} IS NOT NULL" for c in safe_columns) @@ -351,8 +355,8 @@ def intersect_composite_keys_count( fqn1 = self._get_fqn(table1_adapter.identifier) fqn2 = self._get_fqn(table2_adapter.identifier) - safe_columns1 = [f'"{col}"' for col in columns1] - safe_columns2 = [f'"{col}"' for col in columns2] + safe_columns1 = [quote_identifier(col) for col in columns1] + safe_columns2 = [quote_identifier(col) for col in columns2] # Subquery for distinct keys from table 1 distinct_cols1 = ", ".join(safe_columns1) diff --git 
a/src/intugle/adapters/types/postgres/postgres.py b/src/intugle/adapters/types/postgres/postgres.py index 57ee0eb..b010823 100644 --- a/src/intugle/adapters/types/postgres/postgres.py +++ b/src/intugle/adapters/types/postgres/postgres.py @@ -11,7 +11,12 @@ from intugle.adapters.factory import AdapterFactory from intugle.adapters.models import ColumnProfile, DataSetData, ProfilingOutput from intugle.adapters.types.postgres.models import PostgresConfig, PostgresConnectionConfig -from intugle.adapters.utils import convert_to_native +from intugle.adapters.utils import ( + convert_to_native, + quote_identifier, + quote_identifier_parts, + split_identifier_path, +) from intugle.core import settings from intugle.core.utilities.processing import string_standardization @@ -145,13 +150,14 @@ async def _connect_async(self): port=params.port, database=params.database, ) - await self.connection.execute(f"SET search_path TO {self._schema}") + await self.connection.execute(f"SET search_path TO {quote_identifier(self._schema)}") def _get_fqn(self, identifier: str) -> str: """Gets the fully qualified name for a table identifier.""" - if "." in identifier: - return identifier - return f'"{self._schema}"."{identifier}"' + parts = split_identifier_path(identifier, max_parts=2) + if len(parts) == 2: + return quote_identifier_parts(parts) + return quote_identifier_parts([self._schema, parts[0]]) @staticmethod def check_data(data: Any) -> PostgresConfig: @@ -174,6 +180,9 @@ def _get_pandas_df(self, query: str, *args) -> pd.DataFrame: def profile(self, data: PostgresConfig, table_name: str) -> ProfilingOutput: data = self.check_data(data) fqn = self._get_fqn(data.identifier) + identifier_parts = split_identifier_path(data.identifier, max_parts=2) + schema_name = identifier_parts[0] if len(identifier_parts) == 2 else self._schema + table_identifier = identifier_parts[-1] total_count = self._execute_sql(f"SELECT COUNT(*) FROM {fqn}")[0][0] @@ -182,7 +191,7 @@ def profile(self, data: PostgresConfig, table_name: str) -> ProfilingOutput: FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2 """ - rows = self._execute_sql(query, self._schema, data.identifier) + rows = self._execute_sql(query, schema_name, table_identifier) columns = [row["column_name"] for row in rows] dtypes = {row["column_name"]: row["data_type"] for row in rows} @@ -203,13 +212,14 @@ def column_profile( ) -> Optional[ColumnProfile]: data = self.check_data(data) fqn = self._get_fqn(data.identifier) + safe_column_name = quote_identifier(column_name) start_ts = time.time() # Null and distinct counts query = f""" SELECT - COUNT(*) FILTER (WHERE "{column_name}" IS NULL) as null_count, - COUNT(DISTINCT "{column_name}") as distinct_count + COUNT(*) FILTER (WHERE {safe_column_name} IS NULL) as null_count, + COUNT(DISTINCT {safe_column_name}) as distinct_count FROM {fqn} """ result = self._execute_sql(query)[0] @@ -219,7 +229,7 @@ def column_profile( # Sampling sample_query = f""" - SELECT DISTINCT CAST("{column_name}" AS VARCHAR) FROM {fqn} WHERE "{column_name}" IS NOT NULL LIMIT {dtype_sample_limit} + SELECT DISTINCT CAST({safe_column_name} AS VARCHAR) FROM {fqn} WHERE {safe_column_name} IS NOT NULL LIMIT {dtype_sample_limit} """ distinct_values_result = self._execute_sql(sample_query) distinct_values = [row[0] for row in distinct_values_result] @@ -236,7 +246,7 @@ def column_profile( elif distinct_count > 0 and not_null_count > 0: remaining_sample_size = dtype_sample_limit - distinct_count additional_samples_query = f""" - SELECT 
CAST("{column_name}" AS VARCHAR) FROM {fqn} WHERE "{column_name}" IS NOT NULL ORDER BY RANDOM() LIMIT {remaining_sample_size} + SELECT CAST({safe_column_name} AS VARCHAR) FROM {fqn} WHERE {safe_column_name} IS NOT NULL ORDER BY RANDOM() LIMIT {remaining_sample_size} """ additional_samples_result = self._execute_sql(additional_samples_query) additional_samples = [row[0] for row in additional_samples_result] @@ -301,12 +311,14 @@ def intersect_count(self, table1: "DataSet", column1_name: str, table2: "DataSet fqn1 = self._get_fqn(table1_adapter.identifier) fqn2 = self._get_fqn(table2_adapter.identifier) + col1 = quote_identifier(column1_name) + col2 = quote_identifier(column2_name) query = f""" SELECT COUNT(*) FROM ( - SELECT DISTINCT "{column1_name}" FROM {fqn1} WHERE "{column1_name}" IS NOT NULL + SELECT DISTINCT {col1} FROM {fqn1} WHERE {col1} IS NOT NULL INTERSECT - SELECT DISTINCT "{column2_name}" FROM {fqn2} WHERE "{column2_name}" IS NOT NULL + SELECT DISTINCT {col2} FROM {fqn2} WHERE {col2} IS NOT NULL ) as t """ return self._execute_sql(query)[0][0] @@ -314,7 +326,7 @@ def intersect_count(self, table1: "DataSet", column1_name: str, table2: "DataSet def get_composite_key_uniqueness(self, table_name: str, columns: list[str], dataset_data: DataSetData) -> int: data = self.check_data(dataset_data) fqn = self._get_fqn(data.identifier) - safe_columns = [f'"{col}"' for col in columns] + safe_columns = [quote_identifier(col) for col in columns] column_list = ", ".join(safe_columns) null_cols_filter = " AND ".join(f"{c} IS NOT NULL" for c in safe_columns) @@ -339,8 +351,8 @@ def intersect_composite_keys_count( fqn1 = self._get_fqn(table1_adapter.identifier) fqn2 = self._get_fqn(table2_adapter.identifier) - safe_columns1 = [f'"{col}"' for col in columns1] - safe_columns2 = [f'"{col}"' for col in columns2] + safe_columns1 = [quote_identifier(col) for col in columns1] + safe_columns2 = [quote_identifier(col) for col in columns2] # Subquery for distinct keys from table 1 distinct_cols1 = ", ".join(safe_columns1) diff --git a/src/intugle/adapters/types/snowflake/snowflake.py b/src/intugle/adapters/types/snowflake/snowflake.py index e01c43e..033a81c 100644 --- a/src/intugle/adapters/types/snowflake/snowflake.py +++ b/src/intugle/adapters/types/snowflake/snowflake.py @@ -18,9 +18,16 @@ ProfilingOutput, ) from intugle.adapters.types.snowflake.models import SnowflakeConfig, SnowflakeConnectionConfig -from intugle.adapters.utils import convert_to_native +from intugle.adapters.utils import ( + convert_to_native, + escape_sql_literal, + quote_identifier, + quote_identifier_parts, + split_identifier_path, +) from intugle.core import settings -from intugle.exporters.snowflake import clean_name, quote_identifier +from intugle.exporters.snowflake import clean_name +from intugle.exporters.snowflake import quote_identifier as snowflake_quote_identifier try: import snowflake.snowpark.functions as F @@ -107,9 +114,21 @@ def check_data(data: Any) -> SnowflakeConfig: raise TypeError("Input must be a snowflake config.") return data + def _get_fqn(self, identifier: str) -> str: + parts = split_identifier_path(identifier, max_parts=3) + if len(parts) > 1: + return quote_identifier_parts(parts) + + path_parts = [parts[0]] + if self._schema: + path_parts.insert(0, self._schema) + if self._database: + path_parts.insert(0, self._database) + return quote_identifier_parts(path_parts) + def profile(self, data: SnowflakeConfig, table_name: str) -> ProfilingOutput: data = self.check_data(data) - table = 
self.session.table(data.identifier) + table = self.session.table(self._get_fqn(data.identifier)) total_count = table.count() columns = table.columns dtypes = {field.name: str(field.datatype) for field in table.schema.fields} @@ -129,7 +148,7 @@ def column_profile( dtype_sample_limit: int = 10000, ) -> Optional[ColumnProfile]: data = self.check_data(data) - table = self.session.table(data.identifier) + table = self.session.table(self._get_fqn(data.identifier)) start_ts = time.time() @@ -198,7 +217,7 @@ def execute(self, query: str): def to_df(self, data: SnowflakeConfig, table_name: str): data = self.check_data(data) - df = self.session.table(data.identifier).to_pandas() + df = self.session.table(self._get_fqn(data.identifier)).to_pandas() df.columns = [col.strip('"') for col in df.columns] return df @@ -213,13 +232,14 @@ def _clean_column_quotes(sql: str) -> str: return re.sub(r'""(.*?)""', r'"\1"', sql) query = _clean_column_quotes(query) + fqn = self._get_fqn(table_name) if materialize == "table": self.session.sql( - f"CREATE OR REPLACE TABLE {table_name} AS {query}" + f"CREATE OR REPLACE TABLE {fqn} AS {query}" ).collect() else: self.session.sql( - f"CREATE OR REPLACE VIEW {table_name} AS {query}" + f"CREATE OR REPLACE VIEW {fqn} AS {query}" ).collect() return query @@ -253,20 +273,20 @@ def _sync_metadata(self, manifest: "Manifest"): # Apply comments and tags to tables and columns for source in manifest.sources.values(): # Construct the fully qualified table name using details from profiles.yml - full_table_name = f"{database}.{schema}.{source.table.name}" + full_table_name = quote_identifier_parts([database, schema, source.table.name]) # Set table comment if source.table.description: - table_comment = source.table.description.replace("'", "''") + table_comment = escape_sql_literal(source.table.description) self.session.sql(f"ALTER TABLE {full_table_name} SET COMMENT = '{table_comment}'").collect() # Set column comments and tags for column in source.table.columns: - comment = (column.description or "").replace("'", "''") + comment = escape_sql_literal(column.description or "") # Set column comment self.session.sql( - f"ALTER TABLE {full_table_name} MODIFY COLUMN {quote_identifier(column.name)} COMMENT '{comment}'" + f"ALTER TABLE {full_table_name} MODIFY COLUMN {snowflake_quote_identifier(column.name)} COMMENT '{comment}'" ).collect() # Set column tags @@ -298,13 +318,13 @@ def deploy_semantic_model(self, manifest: "Manifest", **kwargs): table_clauses = [] for source in manifest.sources.values(): table_alias = clean_name(source.table.name) - full_table_name = f"{database}.{schema}.{source.table.name}" + full_table_name = quote_identifier_parts([database, schema, source.table.name]) clause = f"{table_alias} AS {full_table_name}" if source.table.key: clause += ' PRIMARY KEY ("' + '", "'.join(source.table.key.columns) + '")' if source.table.description: - comment = source.table.description.replace("'", "''") + comment = escape_sql_literal(source.table.description) clause += f" COMMENT = '{comment}'" table_clauses.append(clause) @@ -313,10 +333,10 @@ def deploy_semantic_model(self, manifest: "Manifest", **kwargs): for rel in manifest.relationships.values(): # The table with the FK is the "referencing" table - table_alias = rel.target.table + table_alias = clean_name(rel.target.table) column = '"' + '", "'.join(rel.target.columns) + '"' # The table with the PK is the "referenced" table - ref_table_alias = rel.source.table + ref_table_alias = clean_name(rel.source.table) ref_column = '"' + 
'", "'.join(rel.source.columns) + '"' clause = f"{clean_name(rel.name)} AS {table_alias}({column}) REFERENCES {ref_table_alias}({ref_column})" @@ -329,9 +349,9 @@ def deploy_semantic_model(self, manifest: "Manifest", **kwargs): table_alias = clean_name(source.table.name) for column in source.table.columns: col_alias = clean_name(column.name) - expr = f"{table_alias}.{col_alias} AS {quote_identifier(column.name)}" + expr = f"{table_alias}.{col_alias} AS {snowflake_quote_identifier(column.name)}" if column.description: - comment = column.description.replace("'", "''") + comment = escape_sql_literal(column.description) expr += f" COMMENT = '{comment}'" if column.category == "measure": @@ -340,7 +360,7 @@ def deploy_semantic_model(self, manifest: "Manifest", **kwargs): dimension_clauses.append(expr) # -- Assemble the final SQL statement -- - sql = f"CREATE OR REPLACE SEMANTIC VIEW {model_name}\n" + sql = f"CREATE OR REPLACE SEMANTIC VIEW {quote_identifier(model_name)}\n" sql += f" TABLES ({', '.join(table_clauses)})\n" if relationship_clauses: sql += f" RELATIONSHIPS ({', '.join(relationship_clauses)})\n" @@ -358,15 +378,15 @@ def intersect_count(self, table1: "DataSet", column1_name: str, table2: "DataSet table1_adapter = self.check_data(table1.data) table2_adapter = self.check_data(table2.data) - table1_df = self.session.table(table1_adapter.identifier) - table2_df = self.session.table(table2_adapter.identifier) + table1_df = self.session.table(self._get_fqn(table1_adapter.identifier)) + table2_df = self.session.table(self._get_fqn(table2_adapter.identifier)) intersect_df = table1_df.select(column1_name).intersect(table2_df.select(column2_name)) return intersect_df.count() def get_composite_key_uniqueness(self, table_name: str, columns: list[str], dataset_data: DataSetData) -> int: data = self.check_data(dataset_data) - table = self.session.table(data.identifier) + table = self.session.table(self._get_fqn(data.identifier)) # Drop rows where any of the key columns have null values and count distinct distinct_count = table.dropna(subset=columns).select(columns).distinct().count() @@ -382,8 +402,8 @@ def intersect_composite_keys_count( table1_adapter = self.check_data(table1.data) table2_adapter = self.check_data(table2.data) - df1 = self.session.table(table1_adapter.identifier) - df2 = self.session.table(table2_adapter.identifier) + df1 = self.session.table(self._get_fqn(table1_adapter.identifier)) + df2 = self.session.table(self._get_fqn(table2_adapter.identifier)) # Get unique combinations of composite keys, dropping nulls df1_unique_keys = df1.dropna(subset=columns1).select(columns1).distinct() diff --git a/src/intugle/adapters/types/sqlserver/sqlserver.py b/src/intugle/adapters/types/sqlserver/sqlserver.py index 71816a9..3de926c 100644 --- a/src/intugle/adapters/types/sqlserver/sqlserver.py +++ b/src/intugle/adapters/types/sqlserver/sqlserver.py @@ -12,7 +12,13 @@ SQLServerConfig, SQLServerConnectionConfig, ) -from intugle.adapters.utils import convert_to_native +from intugle.adapters.utils import ( + convert_to_native, + escape_sql_literal, + quote_identifier, + quote_identifier_parts, + split_identifier_path, +) from intugle.core import settings from intugle.core.utilities.processing import string_standardization @@ -111,9 +117,10 @@ def connect(self): def _get_fqn(self, identifier: str) -> str: """Gets the fully qualified name for a table identifier.""" - if "." 
in identifier: - return identifier - return f'[{self._schema}].[{identifier}]' + parts = split_identifier_path(identifier, max_parts=2) + if len(parts) == 2: + return quote_identifier_parts(parts, quote_char="[") + return quote_identifier_parts([self._schema, parts[0]], quote_char="[") @staticmethod def check_data(data: Any) -> SQLServerConfig: @@ -141,6 +148,9 @@ def _get_pandas_df(self, query: str, *args) -> pd.DataFrame: def profile(self, data: SQLServerConfig, table_name: str) -> ProfilingOutput: data = self.check_data(data) fqn = self._get_fqn(data.identifier) + identifier_parts = split_identifier_path(data.identifier, max_parts=2) + schema_name = identifier_parts[0] if len(identifier_parts) == 2 else self._schema + table_identifier = identifier_parts[-1] total_count = self._execute_sql(f"SELECT COUNT(*) FROM {fqn}")[0][0] @@ -149,7 +159,7 @@ def profile(self, data: SQLServerConfig, table_name: str) -> ProfilingOutput: FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ? """ - rows = self._execute_sql(query, self._schema, data.identifier) + rows = self._execute_sql(query, schema_name, table_identifier) columns = [row.COLUMN_NAME for row in rows] dtypes = {row.COLUMN_NAME: row.DATA_TYPE for row in rows} @@ -170,13 +180,14 @@ def column_profile( ) -> Optional[ColumnProfile]: data = self.check_data(data) fqn = self._get_fqn(data.identifier) + safe_column_name = quote_identifier(column_name, quote_char="[") start_ts = time.time() # Null and distinct counts query = f""" SELECT - SUM(CASE WHEN [{column_name}] IS NULL THEN 1 ELSE 0 END) as null_count, - COUNT(DISTINCT [{column_name}]) as distinct_count + SUM(CASE WHEN {safe_column_name} IS NULL THEN 1 ELSE 0 END) as null_count, + COUNT(DISTINCT {safe_column_name}) as distinct_count FROM {fqn} """ result = self._execute_sql(query)[0] @@ -186,7 +197,7 @@ def column_profile( # Sampling sample_query = f""" - SELECT DISTINCT TOP ({dtype_sample_limit}) [{column_name}] FROM {fqn} WHERE [{column_name}] IS NOT NULL + SELECT DISTINCT TOP ({dtype_sample_limit}) {safe_column_name} FROM {fqn} WHERE {safe_column_name} IS NOT NULL """ distinct_values_result = self._execute_sql(sample_query) distinct_values = [str(row[0]) for row in distinct_values_result] @@ -205,9 +216,9 @@ def column_profile( elif distinct_count > 0 and not_null_count > 0: remaining_sample_size = dtype_sample_limit - distinct_count additional_samples_query = f""" - SELECT TOP {remaining_sample_size} [{column_name}] + SELECT TOP {remaining_sample_size} {safe_column_name} FROM {fqn} - WHERE [{column_name}] IS NOT NULL + WHERE {safe_column_name} IS NOT NULL ORDER BY NEWID() """ additional_samples_result = self._execute_sql(additional_samples_query) @@ -254,13 +265,14 @@ def create_table_from_query( ) -> str: fqn = self._get_fqn(table_name) transpiled_sql = transpile(query, write="tsql")[0] + escaped_fqn_literal = escape_sql_literal(fqn) # Drop existing object if materialize == "view": - self._execute_sql(f"IF OBJECT_ID('{fqn}', 'V') IS NOT NULL DROP VIEW {fqn}") + self._execute_sql(f"IF OBJECT_ID('{escaped_fqn_literal}', 'V') IS NOT NULL DROP VIEW {fqn}") self._execute_sql(f"CREATE VIEW {fqn} AS {transpiled_sql}") else: # table - self._execute_sql(f"IF OBJECT_ID('{fqn}', 'U') IS NOT NULL DROP TABLE {fqn}") + self._execute_sql(f"IF OBJECT_ID('{escaped_fqn_literal}', 'U') IS NOT NULL DROP TABLE {fqn}") self._execute_sql(f"SELECT * INTO {fqn} FROM ({transpiled_sql}) as tmp") self.connection.commit() @@ -277,18 +289,20 @@ def intersect_count( fqn1 = 
self._get_fqn(table1_adapter.identifier) fqn2 = self._get_fqn(table2_adapter.identifier) + col1 = quote_identifier(column1_name, quote_char="[") + col2 = quote_identifier(column2_name, quote_char="[") query = f""" - SELECT COUNT(DISTINCT t1.[{column1_name}]) + SELECT COUNT(DISTINCT t1.{col1}) FROM {fqn1} AS t1 - INNER JOIN {fqn2} AS t2 ON t1.[{column1_name}] = t2.[{column2_name}] + INNER JOIN {fqn2} AS t2 ON t1.{col1} = t2.{col2} """ return self._execute_sql(query)[0][0] def get_composite_key_uniqueness(self, table_name: str, columns: list[str], dataset_data: DataSetData) -> int: data = self.check_data(dataset_data) fqn = self._get_fqn(data.identifier) - safe_columns = [f"[{col}]" for col in columns] + safe_columns = [quote_identifier(col, quote_char="[") for col in columns] column_list = ", ".join(safe_columns) null_cols_filter = " AND ".join(f"{c} IS NOT NULL" for c in safe_columns) @@ -313,8 +327,8 @@ def intersect_composite_keys_count( fqn1 = self._get_fqn(table1_adapter.identifier) fqn2 = self._get_fqn(table2_adapter.identifier) - safe_columns1 = [f"[{col}]" for col in columns1] - safe_columns2 = [f"[{col}]" for col in columns2] + safe_columns1 = [quote_identifier(col, quote_char="[") for col in columns1] + safe_columns2 = [quote_identifier(col, quote_char="[") for col in columns2] # Subquery for distinct keys from table 1 distinct_cols1 = ", ".join(safe_columns1) @@ -355,4 +369,4 @@ def register(factory: AdapterFactory): if SQLSERVER_AVAILABLE: factory.register( "sqlserver", can_handle_sqlserver, SQLServerAdapter, SQLServerConfig - ) \ No newline at end of file + ) diff --git a/src/intugle/adapters/utils.py b/src/intugle/adapters/utils.py index f11673b..f7b3118 100644 --- a/src/intugle/adapters/utils.py +++ b/src/intugle/adapters/utils.py @@ -1,4 +1,7 @@ +import re + +from collections.abc import Sequence from typing import Any import numpy as np @@ -10,4 +13,101 @@ def convert_to_native(value: Any) -> Any: return value.item() if isinstance(value, (list, tuple)): return [convert_to_native(v) for v in value] - return value \ No newline at end of file + return value + + +def split_identifier_path(identifier: str, max_parts: int | None = None) -> list[str]: + """Split a dotted SQL identifier path into validated parts.""" + if not isinstance(identifier, str): + raise TypeError("SQL identifier must be a string.") + + parts = [part.strip() for part in identifier.split(".")] + if not parts or any(not part for part in parts): + raise ValueError(f"Invalid SQL identifier: {identifier!r}") + + if max_parts is not None and len(parts) > max_parts: + raise ValueError( + f"Invalid SQL identifier {identifier!r}: expected at most {max_parts} part(s)." 
+ ) + + return parts + + +def quote_identifier(identifier: str, quote_char: str = '"') -> str: + """Quote a single SQL identifier safely for the target dialect.""" + if not isinstance(identifier, str): + raise TypeError("SQL identifier must be a string.") + if identifier == "": + raise ValueError("SQL identifier cannot be empty.") + + if quote_char == "[": + return f"[{identifier.replace(']', ']]')}]" + + escaped = identifier.replace(quote_char, quote_char * 2) + return f"{quote_char}{escaped}{quote_char}" + + +def quote_identifier_parts( + parts: Sequence[str], quote_char: str = '"', compound: bool = False +) -> str: + """Quote an already split SQL identifier path.""" + if not parts: + raise ValueError("SQL identifier path cannot be empty.") + + normalized_parts = [part.strip() for part in parts] + if any(not part for part in normalized_parts): + raise ValueError("SQL identifier path contains an empty part.") + + if quote_char == "[": + return ".".join(quote_identifier(part, quote_char) for part in normalized_parts) + + if compound: + escaped = ".".join(part.replace(quote_char, quote_char * 2) for part in normalized_parts) + return f"{quote_char}{escaped}{quote_char}" + + return ".".join(quote_identifier(part, quote_char) for part in normalized_parts) + + +def quote_identifier_path( + identifier: str, quote_char: str = '"', max_parts: int | None = None, compound: bool = False +) -> str: + """Quote a dotted SQL identifier path safely for the target dialect.""" + parts = split_identifier_path(identifier, max_parts=max_parts) + return quote_identifier_parts(parts, quote_char=quote_char, compound=compound) + + +def escape_sql_literal(value: str) -> str: + """Escape a string for inclusion inside a single-quoted SQL literal.""" + if not isinstance(value, str): + raise TypeError("SQL literal value must be a string.") + return value.replace("'", "''") + + +_PII_TAG_PATTERN = re.compile(r"[\s\-]+") +_PII_TAGS = { + "pii", + "phi", + "sensitive", + "sensitive_data", + "personal_data", + "personal_information", + "confidential", + "restricted", +} + + +def has_pii_tags(tags: Sequence[str] | None) -> bool: + """Return whether column tags indicate PII/sensitive data.""" + if not tags: + return False + + normalized_tags = { + _PII_TAG_PATTERN.sub("_", tag.strip().lower()) + for tag in tags + if isinstance(tag, str) and tag.strip() + } + + return any( + tag in _PII_TAGS or "pii" in tag or "phi" in tag or "sensitive" in tag + for tag in normalized_tags + ) diff --git a/src/intugle/data_product.py b/src/intugle/data_product.py index 71f136b..0d63bc5 100644 --- a/src/intugle/data_product.py +++ b/src/intugle/data_product.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Optional from intugle.adapters.factory import AdapterFactory +from intugle.adapters.utils import has_pii_tags, quote_identifier from intugle.analysis.models import DataSet from intugle.core import settings from intugle.core.conceptual_search.plan import DataProductPlan @@ -180,18 +181,20 @@ def get_all_field_details(self) -> dict[str, FieldDetailsModel]: # iterate through each source and get the field details (all fields / columns) for source in sources.values(): for column in source.table.columns: + table_details = source.table.details or {} + connection_source_name = table_details.get("type", "unknown") field_detail: FieldDetailsModel = FieldDetailsModel( id=f"{source.table.name}.{column.name}", name=column.name, datatype_l1=column.type, datatype_l2=column.category, - sql_code=f"\"{source.table.name}\".\"{column.name}\"", - 
is_pii=False, + sql_code=f"{quote_identifier(source.table.name)}.{quote_identifier(column.name)}", + is_pii=has_pii_tags(column.tags), asset_id=source.table.name, asset_name=source.table.name, - asset_details={}, + asset_details=table_details, connection_id=source.schema_, - connection_source_name="postgresql", + connection_source_name=connection_source_name, connection_credentials={}, ) field_details[field_detail.id] = field_detail @@ -229,18 +232,20 @@ def field_details_fetcher(ids: list[str]): for column in columns: column_detail = column_details[column] + table_details = table_detail.table.details or {} + connection_source_name = table_details.get("type", "unknown") field_detail: FieldDetailsModel = FieldDetailsModel( id=f"{table}.{column}", name=column_detail.name, datatype_l1=column_detail.type, datatype_l2=column_detail.category, - sql_code=f"\"{table}\".\"{column}\"", - is_pii=False, + sql_code=f"{quote_identifier(table)}.{quote_identifier(column)}", + is_pii=has_pii_tags(column_detail.tags), asset_id=table, asset_name=table, - asset_details={}, + asset_details=table_details, connection_id=table_detail.schema, - connection_source_name="postgresql", + connection_source_name=connection_source_name, connection_credentials={}, ) field_details[field_detail.id] = field_detail diff --git a/tests/adapters/test_databricks_adapter.py b/tests/adapters/test_databricks_adapter.py index 0c54ac1..6f507dc 100644 --- a/tests/adapters/test_databricks_adapter.py +++ b/tests/adapters/test_databricks_adapter.py @@ -178,10 +178,10 @@ def test_get_fqn_creates_fully_qualified_name(self, mock_adapter): fqn = mock_adapter._get_fqn("my_table") assert fqn == "`test_catalog`.`test_schema`.`my_table`" - # Already qualified (contains dots) - should return as-is + # Already qualified (contains dots) - should still be safely quoted already_qualified = "custom_cat.custom_schema.my_table" fqn_already = mock_adapter._get_fqn(already_qualified) - assert fqn_already == already_qualified + assert fqn_already == "`custom_cat`.`custom_schema`.`my_table`" def test_connection_property_exists(self, mock_adapter): """Verify adapter has connection for SQL execution.""" @@ -273,4 +273,3 @@ def test_connection_error_provides_context(self, mocker): with pytest.raises(ValueError, match="Could not create Databricks connection"): DatabricksAdapter() - diff --git a/tests/adapters/test_sql_identifier_safety.py b/tests/adapters/test_sql_identifier_safety.py new file mode 100644 index 0000000..ae3caad --- /dev/null +++ b/tests/adapters/test_sql_identifier_safety.py @@ -0,0 +1,104 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +from intugle.adapters.types.bigquery.bigquery import BigQueryAdapter +from intugle.adapters.types.databricks.databricks import DatabricksAdapter +from intugle.adapters.types.postgres.postgres import PostgresAdapter +from intugle.adapters.types.sqlserver import sqlserver as sqlserver_module +from intugle.adapters.types.sqlserver.sqlserver import SQLServerAdapter +from intugle.data_product import DataProduct + + +def test_bigquery_get_fqn_escapes_compound_identifier(): + adapter = BigQueryAdapter.__new__(BigQueryAdapter) + adapter._project_id = "test-project" + adapter._dataset_id = "analytics" + + fqn = adapter._get_fqn("users`; DROP TABLE accounts; --") + + assert fqn == "`test-project.analytics.users``; DROP TABLE accounts; --`" + + +def test_databricks_get_fqn_quotes_each_identifier_segment(): + adapter = DatabricksAdapter.__new__(DatabricksAdapter) + adapter.catalog = "main" + 
adapter._schema = "analytics" + + fqn = adapter._get_fqn("sales.orders`; DROP TABLE users; --") + + assert fqn == "`sales`.`orders``; DROP TABLE users; --`" + + +def test_postgres_column_profile_escapes_column_identifier(): + adapter = PostgresAdapter.__new__(PostgresAdapter) + adapter._schema = "public" + + captured_queries: list[str] = [] + + def fake_execute(query: str, *args): + captured_queries.append(query) + if "COUNT(*) FILTER" in query: + return [{"null_count": 0, "distinct_count": 1}] + return [("sample@example.com",)] + + adapter._execute_sql = fake_execute + + adapter.column_profile( + data={"identifier": 'users"; DROP TABLE accounts; --', "type": "postgres"}, + table_name="users", + column_name='email"; SELECT pg_sleep(1); --', + total_count=1, + sample_limit=1, + dtype_sample_limit=1, + ) + + assert captured_queries + assert '"email""; SELECT pg_sleep(1); --"' in captured_queries[0] + assert 'DROP TABLE accounts; --"' in captured_queries[0] + + +def test_sqlserver_create_table_from_query_escapes_object_id_literal(monkeypatch): + monkeypatch.setattr(sqlserver_module, "transpile", lambda query, write: [query], raising=False) + + adapter = SQLServerAdapter.__new__(SQLServerAdapter) + adapter._schema = "dbo" + adapter.connection = MagicMock() + + executed_queries: list[str] = [] + adapter._execute_sql = lambda query, *args: executed_queries.append(query) or [] + + adapter.create_table_from_query("orders'name", "SELECT 1", materialize="view") + + assert "OBJECT_ID('[dbo].[orders''name]', 'V')" in executed_queries[0] + assert "CREATE VIEW [dbo].[orders'name] AS SELECT 1" == executed_queries[1] + + +def test_data_product_marks_pii_and_escapes_sql_code(): + source = SimpleNamespace( + schema_="public", + schema="public", + table=SimpleNamespace( + name='users"; DROP TABLE accounts; --', + details={"type": "postgres"}, + columns=[ + SimpleNamespace( + name='email"; SELECT 1; --', + type="string", + category="dimension", + tags=["PII", "customer"], + ) + ], + ), + ) + + dp = DataProduct.__new__(DataProduct) + dp.manifest = SimpleNamespace(sources={"users": source}, relationships={}) + + field_details = dp.get_all_field_details() + field_detail = next(iter(field_details.values())) + + assert field_detail.is_pii is True + assert ( + field_detail.sql_code + == '"users""; DROP TABLE accounts; --"."email""; SELECT 1; --"' + )