diff --git a/.gitignore b/.gitignore index 581e7a7..a30eb31 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ package/ __pycache__/ *.py[cod] *$py.class +.vscode/ diff --git a/devel_requirements.txt b/devel_requirements.txt new file mode 100644 index 0000000..1ae9e07 --- /dev/null +++ b/devel_requirements.txt @@ -0,0 +1,3 @@ +black +pytest +pytest-mock \ No newline at end of file diff --git a/lambda_function.py b/lambda_function.py index 4ad9fa7..f4adf02 100644 --- a/lambda_function.py +++ b/lambda_function.py @@ -27,6 +27,13 @@ ON t.patron_id = {staging_table}.patron_id WHERE patron_count > 1;""" +_COLUMNS_QUERY = """ + SELECT column_name + FROM INFORMATION_SCHEMA.COLUMNS + WHERE table_name = '{table}' + ORDER BY ordinal_position; +""" + def lambda_handler(event, context): logger.info("Starting lambda processing") @@ -40,6 +47,12 @@ def lambda_handler(event, context): kms_client.close() redshift_client.connect() + # Determine all non-id columns in the staging table to construct accurate insert/update queries + staging_columns_response = redshift_client.execute_query( + _COLUMNS_QUERY.format(table=os.environ["STAGING_TABLE"]) + ) + staging_columns = [c[0] for c in staging_columns_response if c[0] != "id"] + logger.info("Checking for duplicate records") raw_duplicates = redshift_client.execute_query( _DUPLICATES_QUERY.format(staging_table=os.environ["STAGING_TABLE"]) @@ -72,8 +85,12 @@ def lambda_handler(event, context): # len(row)-2 because the row contains two extra fields from the join. placeholder_length = len(next(iter(unique_map.values()))) - 2 placeholder = ", ".join(["%s"] * placeholder_length) - insert_query = "INSERT INTO {staging_table} VALUES ({placeholder});".format( - staging_table=os.environ["STAGING_TABLE"], placeholder=placeholder + insert_query = ( + "INSERT INTO {staging_table} ({columns}) VALUES ({placeholder});".format( + staging_table=os.environ["STAGING_TABLE"], + columns=", ".join(staging_columns), + placeholder=placeholder, + ) ) queries.append((insert_query, [v[:-2] for v in unique_map.values()])) redshift_client.execute_transaction(queries) @@ -88,8 +105,9 @@ def lambda_handler(event, context): None, ), ( - "INSERT INTO {main_table} SELECT * FROM {staging_table};".format( + "INSERT INTO {main_table} ({columns}) SELECT {columns} FROM {staging_table};".format( main_table=os.environ["MAIN_TABLE"], + columns=", ".join(staging_columns), staging_table=os.environ["STAGING_TABLE"], ), None, diff --git a/requirements.txt b/requirements.txt index 89e7503..35e7f59 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -nypl-py-utils==1.1.5 +nypl-py-utils[redshift-client,config-helper,kms-client,log_helper]==1.8.0 redshift_connector \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..64abf97 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,23 @@ +import os +import pytest + +TEST_ENV_VARS = { + "REDSHIFT_DB_NAME": "test_db", + "REDSHIFT_DB_HOST": "test_redshift_host", + "REDSHIFT_DB_USER": "test_redshift_user", + "REDSHIFT_DB_PASSWORD": "test_redshift_password", + "STAGING_TABLE": "test_staging_table", + "MAIN_TABLE": "test_main_table", +} + + +@pytest.fixture(scope="session", autouse=True) +def tests_setup_and_teardown(): + # Will be executed before the first test + os.environ.update(TEST_ENV_VARS) + + yield + + # Will execute after final test + for os_config in TEST_ENV_VARS.keys(): + del os.environ[os_config] \ No newline at end of file diff --git a/tests/test_helpers.py b/tests/test_helpers.py deleted file mode 100644 index 0c8105e..0000000 --- a/tests/test_helpers.py +++ /dev/null @@ -1,23 +0,0 @@ -import os - - -class TestHelpers: - ENV_VARS = { - "REDSHIFT_DB_NAME": "test_db", - "REDSHIFT_DB_HOST": "test_redshift_host", - "REDSHIFT_DB_USER": "test_redshift_user", - "REDSHIFT_DB_PASSWORD": "test_redshift_password", - "STAGING_TABLE": "test_staging_table", - "MAIN_TABLE": "test_main_table", - } - - @classmethod - def set_env_vars(cls): - for key, value in cls.ENV_VARS.items(): - os.environ[key] = value - - @classmethod - def clear_env_vars(cls): - for key in cls.ENV_VARS.keys(): - if key in os.environ: - del os.environ[key] diff --git a/tests/test_lambda_function.py b/tests/test_lambda_function.py index f5c18d4..0383c1a 100644 --- a/tests/test_lambda_function.py +++ b/tests/test_lambda_function.py @@ -3,13 +3,13 @@ from copy import deepcopy from lambda_function import ( + _COLUMNS_QUERY, _DUPLICATE_DELETION_QUERY, _DUPLICATES_QUERY, _MAIN_DELETION_QUERY, lambda_handler, ReplaceRedshiftDataError, ) -from tests.test_helpers import TestHelpers _PLACEHOLDER = "%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s" @@ -20,7 +20,7 @@ ), None, ), - ("INSERT INTO test_main_table SELECT * FROM test_staging_table;", None), + ("INSERT INTO test_main_table (patron_id, address_hash) SELECT patron_id, address_hash FROM test_staging_table;", None), ("DELETE FROM test_staging_table;", None), ] @@ -33,15 +33,6 @@ class TestLambdaFunction: - - @classmethod - def setup_class(cls): - TestHelpers.set_env_vars() - - @classmethod - def teardown_class(cls): - TestHelpers.clear_env_vars() - @pytest.fixture def test_instance(self, mocker): mocker.patch("lambda_function.create_log") @@ -49,16 +40,19 @@ def test_instance(self, mocker): mock_kms_client.decrypt.return_value = "decrypted" mocker.patch("lambda_function.KmsClient", return_value=mock_kms_client) - def get_mock_redshift_client(self, mocker, response): + @pytest.fixture + def mock_redshift_client(self, mocker): mock_redshift_client = mocker.MagicMock() - mock_redshift_client.execute_query.return_value = response mocker.patch( "lambda_function.RedshiftClient", return_value=mock_redshift_client ) return mock_redshift_client - def test_lambda_handler_no_duplicates(self, test_instance, mocker): - mock_redshift_client = self.get_mock_redshift_client(mocker, []) + def test_lambda_handler_no_duplicates(self, test_instance, mock_redshift_client, mocker): + mock_redshift_client.execute_query.side_effect = [ + (['id'], ['patron_id'], ['address_hash']), + () + ] assert lambda_handler(None, None) == { "statusCode": 200, @@ -66,19 +60,25 @@ def test_lambda_handler_no_duplicates(self, test_instance, mocker): } mock_redshift_client.connect.assert_called_once() - mock_redshift_client.execute_query.assert_called_once_with( - _DUPLICATES_QUERY.format(staging_table="test_staging_table") + mock_redshift_client.execute_query.assert_has_calls( + [ + mocker.call(_COLUMNS_QUERY.format(table="test_staging_table")), + mocker.call( + _DUPLICATES_QUERY.format(staging_table="test_staging_table") + ) + ] ) mock_redshift_client.execute_transaction.assert_called_once_with( _PRIMARY_REDSHIFT_QUERIES ) mock_redshift_client.close_connection.assert_called_once() - def test_lambda_handler_exact_duplicates(self, test_instance, mocker): + def test_lambda_handler_exact_duplicates(self, test_instance, mock_redshift_client, mocker): EXACT_DUPLICATE_PATRONS = _TEST_PATRONS + _TEST_PATRONS - mock_redshift_client = self.get_mock_redshift_client( - mocker, EXACT_DUPLICATE_PATRONS - ) + mock_redshift_client.execute_query.side_effect = [ + (['id'], ['patron_id'], ['address_hash']), + EXACT_DUPLICATE_PATRONS + ] assert lambda_handler(None, None) == { "statusCode": 200, @@ -86,8 +86,13 @@ def test_lambda_handler_exact_duplicates(self, test_instance, mocker): } mock_redshift_client.connect.assert_called_once() - mock_redshift_client.execute_query.assert_called_once_with( - _DUPLICATES_QUERY.format(staging_table="test_staging_table") + mock_redshift_client.execute_query.assert_has_calls( + [ + mocker.call(_COLUMNS_QUERY.format(table="test_staging_table")), + mocker.call( + _DUPLICATES_QUERY.format(staging_table="test_staging_table") + ) + ] ) mock_redshift_client.execute_transaction.assert_has_calls( [ @@ -101,7 +106,7 @@ def test_lambda_handler_exact_duplicates(self, test_instance, mocker): None, ), ( - f"INSERT INTO test_staging_table VALUES ({_PLACEHOLDER});", + f"INSERT INTO test_staging_table (patron_id, address_hash) VALUES ({_PLACEHOLDER});", [v[:-2] for v in _TEST_PATRONS], ), ] @@ -111,19 +116,25 @@ def test_lambda_handler_exact_duplicates(self, test_instance, mocker): ) mock_redshift_client.close_connection.assert_called_once() - def test_lambda_handler_inexact_duplicates(self, test_instance, mocker): + def test_lambda_handler_inexact_duplicates(self, test_instance, mock_redshift_client, mocker): INEXACT_DUPLICATE_PATRONS = deepcopy(_TEST_PATRONS) + deepcopy(_TEST_PATRONS) INEXACT_DUPLICATE_PATRONS[-1][1] = "different_address" - mock_redshift_client = self.get_mock_redshift_client( - mocker, INEXACT_DUPLICATE_PATRONS - ) + mock_redshift_client.execute_query.side_effect = [ + (['id'], ['patron_id'], ['address_hash']), + INEXACT_DUPLICATE_PATRONS + ] with pytest.raises(ReplaceRedshiftDataError) as e: lambda_handler(None, None) assert "Duplicate patron ids with different values found" in e.value.message mock_redshift_client.connect.assert_called_once() - mock_redshift_client.execute_query.assert_called_once_with( - _DUPLICATES_QUERY.format(staging_table="test_staging_table") + mock_redshift_client.execute_query.assert_has_calls( + [ + mocker.call(_COLUMNS_QUERY.format(table="test_staging_table")), + mocker.call( + _DUPLICATES_QUERY.format(staging_table="test_staging_table") + ) + ] ) mock_redshift_client.execute_transaction.assert_not_called()