Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ dependencies = [
"deepecho>=0.8.0;python_version>='3.14'",
"rdt>=1.18.2;python_version<'3.14'",
"rdt>=1.20.0;python_version>='3.14'",
"sdmetrics>=0.21.0;python_version<'3.14'",
"sdmetrics>=0.26.0;python_version>='3.14'",
"sdmetrics>=0.28.0",
'platformdirs>=4.0',
'pyyaml>=6.0.1',
]
Expand Down
25 changes: 24 additions & 1 deletion sdv/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@
MODELABLE_SDTYPES = ['categorical', 'numerical', 'datetime', 'boolean']


def _cast_to_iterable(value):
def _cast_to_iterable(value, iterable_type=None):
"""Return a ``list`` if the input object is not a ``list`` or ``tuple``."""
if isinstance(value, (list, tuple)):
if iterable_type:
return iterable_type(value)

return value

return [value]
Expand Down Expand Up @@ -513,3 +516,23 @@ def _validate_correct_synthesizer_loading(synthesizer, cls):
f"but got '{synthesizer_name}'. Please ensure you are loading the correct "
f'synthesizer type.'
)


def _sort_keys(keys):
return sorted(keys, key=lambda key: key if isinstance(key, str) else key[0])


def _get_unreferenced_keys(parent_columns, child_columns):
    """Return the child key rows that have no matching row in the parent keys.

    Args:
        parent_columns (pandas.DataFrame):
            Key column(s) of the parent table.
        child_columns (pandas.DataFrame):
            Key column(s) of the child table. They are paired with the columns of
            ``parent_columns`` in order.

    Returns:
        pandas.DataFrame:
            The rows of ``child_columns`` (child columns only) that found no match in
            ``parent_columns``, excluding rows whose key values are all missing.
    """
    child_keys = list(child_columns.columns)
    parent_keys = list(parent_columns.columns)

    # NOTE(review): ``_create_unique_name`` is defined outside this view; it is presumably
    # there to keep the indicator column from clashing with a real key column.
    indicator = _create_unique_name('_merge', child_keys + parent_keys)
    merged = child_columns.merge(
        parent_columns,
        left_on=child_keys,
        right_on=parent_keys,
        how='left',
        indicator=indicator,
    )

    # After a left merge, 'left_only' marks child rows that matched nothing in the parent.
    unmatched = merged.loc[merged[indicator] == 'left_only', child_keys]

    # A row whose key columns are all missing is not a reference, so it is dropped.
    # One ``dropna`` is enough: the call is idempotent.
    return unmatched.dropna(how='all')
23 changes: 21 additions & 2 deletions sdv/cag/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,30 @@
import numpy as np
import pandas as pd

from sdv._utils import _cast_to_iterable
from sdv.cag._errors import ConstraintNotMetError
from sdv.errors import RefitWarning, SynthesizerInputError, TableNameError
from sdv.metadata import Metadata


def _validate_columns_not_primary_key(table_name, columns, metadata):
"""Validate that none of the columns are in the primary key for the table."""
primary_key = metadata.tables[table_name].primary_key
if metadata.tables[table_name]._primary_key_is_composite:
key_columns = set(primary_key).intersection(set(columns))
if key_columns:
pk_columns = "', '".join(sorted(key_columns))
raise ConstraintNotMetError(
f"Cannot apply constraint because ['{pk_columns}'] are "
f"part of the primary key for table '{table_name}'."
)
elif primary_key in columns:
raise ConstraintNotMetError(
f"Cannot apply constraint because '{primary_key}' is the "
f"primary key of table '{table_name}'."
)


def _validate_columns_in_metadata(table_name, columns, metadata):
"""Validates that the columns are in the metadata.

Expand Down Expand Up @@ -137,9 +156,9 @@ def _remove_columns_from_metadata(metadata, table_name, columns_to_drop):
if isinstance(metadata, Metadata):
metadata = metadata.to_dict()
column_set = set(columns_to_drop)
primary_key = metadata['tables'][table_name].get('primary_key')
primary_key = _cast_to_iterable(metadata['tables'][table_name].get('primary_key'))
for column in column_set:
if primary_key and primary_key == column:
if primary_key and column in primary_key:
raise ValueError('Cannot remove primary key from Metadata')
del metadata['tables'][table_name]['columns'][column]

Expand Down
2 changes: 2 additions & 0 deletions sdv/cag/fixed_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
_get_is_valid_dict,
_is_list_of_type,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_table_and_column_names,
_validate_table_name_if_defined,
)
Expand Down Expand Up @@ -67,6 +68,7 @@ def _validate_constraint_with_metadata(self, metadata):
"""
_validate_table_and_column_names(self.table_name, self.column_names, metadata)
table_name = self._get_single_table_name(metadata)
_validate_columns_not_primary_key(table_name, self.column_names, metadata)
for column in self.column_names:
col_sdtype = metadata.tables[table_name].columns[column]['sdtype']
if col_sdtype not in ['boolean', 'categorical']:
Expand Down
2 changes: 2 additions & 0 deletions sdv/cag/fixed_increments.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sdv.cag._utils import (
_get_is_valid_dict,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_table_and_column_names,
_validate_table_name_if_defined,
)
Expand Down Expand Up @@ -67,6 +68,7 @@ def _validate_constraint_with_metadata(self, metadata):
self.table_name, columns=[self.column_name], metadata=metadata
)
table_name = self._get_single_table_name(metadata)
_validate_columns_not_primary_key(table_name, [self.column_name], metadata)
col_sdtype = metadata.tables[table_name].columns[self.column_name]['sdtype']
if col_sdtype != 'numerical':
raise ConstraintNotMetError(
Expand Down
2 changes: 2 additions & 0 deletions sdv/cag/inequality.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_get_is_valid_dict,
_is_list_of_type,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_table_and_column_names,
_validate_table_name_if_defined,
)
Expand Down Expand Up @@ -93,6 +94,7 @@ def _validate_constraint_with_metadata(self, metadata):
columns = [self._low_column_name, self._high_column_name]
_validate_table_and_column_names(self.table_name, columns, metadata)
table_name = self._get_single_table_name(metadata)
_validate_columns_not_primary_key(table_name, columns, metadata)
for column in columns:
col_sdtype = metadata.tables[table_name].columns[column]['sdtype']
if col_sdtype not in ['numerical', 'datetime']:
Expand Down
3 changes: 3 additions & 0 deletions sdv/cag/one_hot_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_get_is_valid_dict,
_is_list_of_type,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_table_and_column_names,
_validate_table_name_if_defined,
)
Expand Down Expand Up @@ -73,6 +74,8 @@ def _validate_constraint_with_metadata(self, metadata):
If any of the validations fail.
"""
_validate_table_and_column_names(self.table_name, self._column_names, metadata)
table_name = self._get_single_table_name(metadata)
_validate_columns_not_primary_key(table_name, self._column_names, metadata)

def _get_valid_table_data(self, table_data):
one_hot_data = table_data[self._column_names]
Expand Down
2 changes: 2 additions & 0 deletions sdv/cag/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
_get_is_valid_dict,
_is_list_of_type,
_remove_columns_from_metadata,
_validate_columns_not_primary_key,
_validate_table_and_column_names,
_validate_table_name_if_defined,
)
Expand Down Expand Up @@ -126,6 +127,7 @@ def _validate_constraint_with_metadata(self, metadata):
columns = [self._low_column_name, self._middle_column_name, self._high_column_name]
_validate_table_and_column_names(self.table_name, columns, metadata)
table_name = self._get_single_table_name(metadata)
_validate_columns_not_primary_key(table_name, columns, metadata)
for column in columns:
col_sdtype = metadata.tables[table_name].columns[column]['sdtype']
if col_sdtype not in ['numerical', 'datetime']:
Expand Down
9 changes: 8 additions & 1 deletion sdv/metadata/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,14 @@ def add_column_relationship(
super().add_column_relationship(table_name, relationship_type, column_names)

def set_primary_key(self, column_name, table_name=None):
    """Set the primary key of a table.

    Args:
        column_name (str, list[str]):
            Name (or list of names) of the primary key column(s).
        table_name (str):
            Name of the table to set the primary key. The value is passed through
            ``_handle_table_name`` before it is used. Defaults to ``None``.
    """
    resolved_table_name = self._handle_table_name(table_name)
    super().set_primary_key(resolved_table_name, column_name)

Expand Down
99 changes: 61 additions & 38 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@

import pandas as pd

from sdv._utils import _cast_to_iterable, _load_data_from_csv
from sdv._utils import (
_cast_to_iterable,
_format_invalid_values_string,
_get_unreferenced_keys,
_load_data_from_csv,
)
from sdv.errors import InvalidDataError
from sdv.logging import get_sdv_logger
from sdv.metadata.errors import InvalidMetadataError
Expand Down Expand Up @@ -62,27 +67,23 @@ def _validate_missing_relationship_keys(
"Please use 'set_primary_key' in order to set one."
)

missing_keys = set()
parent_primary_key = _cast_to_iterable(parent_primary_key)
table_primary_keys = set(_cast_to_iterable(parent_table.primary_key))
for key in parent_primary_key:
if key not in table_primary_keys:
missing_keys.add(key)

if missing_keys:
if set(parent_primary_key) != table_primary_keys:
raise InvalidMetadataError(
f'Relationship between tables ({parent_table_name}, {child_table_name}) contains '
f'an unknown primary key {missing_keys}.'
f'Relationship between tables ({parent_table_name}, {child_table_name}) '
f'has a mismatched primary key {sorted(parent_primary_key)}.'
)

missing_fk = set()
for key in set(_cast_to_iterable(child_foreign_key)):
if key not in child_table.columns:
missing_keys.add(key)
missing_fk.add(key)

if missing_keys:
if missing_fk:
raise InvalidMetadataError(
f'Relationship between tables ({parent_table_name}, {child_table_name}) '
f'contains an unknown foreign key {missing_keys}.'
f'contains an unknown foreign key {missing_fk}.'
)

@staticmethod
Expand Down Expand Up @@ -173,9 +174,14 @@ def _validate_new_foreign_key_is_not_reused(
and relationship['parent_primary_key'] == parent_primary_key
)
if foreign_key_already_used and not parent_matches:
child_foreign_key = (
f"('{child_foreign_key}')"
if isinstance(child_foreign_key, str)
else f'({child_foreign_key})'
)
raise InvalidMetadataError(
f'Relationship between tables ({parent_table_name}, {child_table_name}) uses '
f"a foreign key column ('{child_foreign_key}') that is already used in another "
f'a foreign key {child_foreign_key} that is already used in another '
'relationship.'
)

Expand All @@ -187,15 +193,23 @@ def _validate_foreign_key_uniqueness_across_relationships(
child_foreign_key,
seen_foreign_keys,
):
key = (child_table_name, child_foreign_key)
key = (
tuple(_cast_to_iterable(child_table_name)),
tuple(_cast_to_iterable(child_foreign_key)),
)
current_relationship = (parent_table_name, parent_primary_key)

if key in seen_foreign_keys:
existing_relationship = seen_foreign_keys[key]
if existing_relationship != current_relationship:
child_foreign_key = (
f"('{child_foreign_key}')"
if isinstance(child_foreign_key, str)
else f'({child_foreign_key})'
)
raise InvalidMetadataError(
f'Relationship between tables ({parent_table_name}, {child_table_name}) uses '
f"a foreign key column ('{child_foreign_key}') that is already used in another "
f'a foreign key {child_foreign_key} that is already used in another '
'relationship.'
)
else:
Expand Down Expand Up @@ -284,10 +298,10 @@ def add_relationship(
A string representing the name of the parent table.
child_table_name (str):
A string representing the name of the child table.
parent_primary_key (str or tuple):
A string or tuple of strings representing the primary key of the parent.
child_foreign_key (str or tuple):
A string or tuple of strings representing the foreign key of the child.
parent_primary_key (str or list[str]):
A string or list of strings representing the primary key of the parent.
child_foreign_key (str or list[str]):
A string or list of strings representing the foreign key of the child.

Raises:
- ``InvalidMetadataError`` if a table is missing.
Expand Down Expand Up @@ -675,8 +689,8 @@ def set_primary_key(self, table_name, column_name):
Args:
table_name (str):
Name of the table to set the primary key.
column_name (str, tulple[str]):
Name (or tuple of names) of the primary key column(s).
column_name (str, list[str]):
Name (or list of names) of the primary key column(s).
"""
self._validate_table_exists(table_name)
self.tables[table_name].set_primary_key(column_name)
Expand Down Expand Up @@ -903,22 +917,21 @@ def _validate_foreign_keys(self, data):
parent_table = data.get(relation['parent_table_name'])

if isinstance(child_table, pd.DataFrame) and isinstance(parent_table, pd.DataFrame):
child_column = child_table[relation['child_foreign_key']]
parent_column = parent_table[relation['parent_primary_key']]
missing_values = child_column[~child_column.isin(parent_column)].unique()
missing_values = missing_values[~pd.isna(missing_values)]

if any(missing_values):
message = ', '.join(missing_values[:5].astype(str))
if len(missing_values) > 5:
message = f'({message}, + more)'
else:
message = f'({message})'

child_columns = child_table[_cast_to_iterable(relation['child_foreign_key'])]
parent_columns = parent_table[_cast_to_iterable(relation['parent_primary_key'])]
missing_values = _get_unreferenced_keys(parent_columns, child_columns)
missing_values = missing_values.drop_duplicates()
if not missing_values.empty:
foreign_key = relation['child_foreign_key']
if not isinstance(foreign_key, list):
foreign_key = f"'{foreign_key}'"

message = f'\n{_format_invalid_values_string(missing_values, 5)}'
errors.append(
f"Error: foreign key column '{relation['child_foreign_key']}' contains "
f'unknown references: {message}. Please use the method'
" 'drop_unknown_references' from sdv.utils to clean the data."
f'Error: foreign key column {foreign_key} contains '
f'unknown references:{message}\n'
"Please use the method 'drop_unknown_references' from sdv.utils "
'to clean the data.'
)

if errors:
Expand Down Expand Up @@ -1223,9 +1236,19 @@ def _set_metadata_dict(self, metadata):
) from error

for relationship in metadata.get('relationships', []):
parent_pk = relationship.get('parent_primary_key')
child_fk = relationship.get('child_foreign_key')
type_safe_pk = (
[str(col) for col in parent_pk] if isinstance(parent_pk, list) else str(parent_pk)
)
# Decide list-vs-scalar from the foreign key itself, not from the primary key. Otherwise a
# list foreign key with a scalar primary key is collapsed into one string, and a string
# foreign key with a list primary key is split into single characters.
type_safe_fk = (
    [str(col) for col in child_fk] if isinstance(child_fk, list) else str(child_fk)
)
type_safe_relationships = {
key: str(value) if not isinstance(value, str) else value
for key, value in relationship.items()
'parent_table_name': str(relationship.get('parent_table_name')),
'child_table_name': str(relationship.get('child_table_name')),
'parent_primary_key': type_safe_pk,
'child_foreign_key': type_safe_fk,
}
self.relationships.append(type_safe_relationships)

Expand Down
Loading
Loading