diff --git a/airbyte_cdk/sources/file_based/file_types/csv_parser.py b/airbyte_cdk/sources/file_based/file_types/csv_parser.py index 3b3dc4a0f..ca900b38e 100644 --- a/airbyte_cdk/sources/file_based/file_types/csv_parser.py +++ b/airbyte_cdk/sources/file_based/file_types/csv_parser.py @@ -70,7 +70,9 @@ def read_data( file, file_read_mode, config_format.encoding, logger ) as fp: try: - headers = self._get_headers(fp, config_format, dialect_name) + headers, n_cols_stripped = self._get_headers( + fp, config_format, dialect_name, logger + ) except UnicodeError: raise AirbyteTracedException( message=f"{FileBasedSourceError.ENCODING_ERROR.value} Expected encoding: {config_format.encoding}", @@ -88,6 +90,9 @@ def read_data( for row in reader: lineno += 1 + if n_cols_stripped and None in row: + row.pop(None) + # The row was not properly parsed if any of the values are None. This will most likely occur if there are more columns # than headers or more headers dans columns if None in row: @@ -116,11 +121,18 @@ def read_data( finally: csv.unregister_dialect(dialect_name) - def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> List[str]: + def _get_headers( + self, + fp: IOBase, + config_format: CsvFormat, + dialect_name: str, + logger: logging.Logger, + ) -> Tuple[List[str], int]: """Assumes the fp is pointing to the beginning of the files and will reset it as such.""" # Note that this method assumes the dialect has already been registered if we're parsing the headers if isinstance(config_format.header_definition, CsvHeaderUserProvided): - return config_format.header_definition.column_names + fp.seek(0) + return config_format.header_definition.column_names, 0 if isinstance(config_format.header_definition, CsvHeaderAutogenerated): self._skip_rows( @@ -132,17 +144,38 @@ def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) self._skip_rows(fp, config_format.skip_rows_before_header) reader = csv.reader(fp, dialect=dialect_name) # type: ignore headers = list(next(reader)) + headers, n_stripped = self._strip_trailing_empty_headers(headers, logger) - empty_count = sum(1 for h in headers if not h or h.isspace()) - if empty_count: + empty_interior = sum(1 for h in headers if not h or h.isspace()) + if empty_interior: raise AirbyteTracedException( - message="CSV header row contains empty column name(s). Remove trailing delimiters or empty columns from the header row.", - internal_message=f"Found {empty_count} empty/whitespace-only column name(s) in header: {headers}", + message="CSV header row contains empty column name(s) in non-trailing positions.", + internal_message=f"Found {empty_interior} empty/whitespace-only column name(s) in non-trailing positions of header: {headers}", failure_type=FailureType.config_error, ) + fp.seek(0) + return headers, n_stripped + fp.seek(0) - return headers + return headers, 0 + + @staticmethod + def _strip_trailing_empty_headers( + headers: List[str], logger: logging.Logger + ) -> Tuple[List[str], int]: + """Strip trailing empty/whitespace-only column names caused by trailing delimiters.""" + original_count = len(headers) + while headers and (not headers[-1] or headers[-1].isspace()): + headers.pop() + n_stripped = original_count - len(headers) + if n_stripped: + logger.warning( + "Ignoring %d trailing empty column name(s) in CSV header. " + "This commonly occurs when the file has trailing delimiters.", + n_stripped, + ) + return headers, n_stripped def _auto_generate_headers(self, fp: IOBase, dialect_name: str) -> List[str]: """ diff --git a/unit_tests/sources/file_based/file_types/test_csv_parser.py b/unit_tests/sources/file_based/file_types/test_csv_parser.py index 3238a96bb..7fa678a59 100644 --- a/unit_tests/sources/file_based/file_types/test_csv_parser.py +++ b/unit_tests/sources/file_based/file_types/test_csv_parser.py @@ -791,17 +791,17 @@ def test_encoding_is_passed_to_stream_reader() -> None: @pytest.mark.parametrize( "header_row, expected_empty_count", [ - pytest.param("col1,col2,col3,,,", 3, id="trailing_empty_columns"), pytest.param("col1,,col3", 1, id="middle_empty_column"), pytest.param(",col2,col3", 1, id="leading_empty_column"), - pytest.param("col1,col2, ", 1, id="whitespace_only_column"), + pytest.param("col1,,col3, , ", 1, id="interior_empty_after_trailing_strip"), ], ) -def test_get_headers_raises_on_empty_column_names( +def test_get_headers_raises_on_non_trailing_empty_column_names( header_row: str, expected_empty_count: int ) -> None: csv_reader = _CsvReader() config_format = CsvFormat() + test_logger = logging.getLogger("test") fp = io.StringIO(header_row) dialect_name = f"test_{uuid4()}" @@ -816,19 +816,34 @@ def test_get_headers_raises_on_empty_column_names( try: with pytest.raises(AirbyteTracedException) as exc_info: - csv_reader._get_headers(fp, config_format, dialect_name) + csv_reader._get_headers(fp, config_format, dialect_name, test_logger) assert exc_info.value.failure_type == FailureType.config_error assert "empty column name" in exc_info.value.message + assert "non-trailing" in exc_info.value.message assert f"{expected_empty_count} empty" in exc_info.value.internal_message finally: csv.unregister_dialect(dialect_name) -def test_get_headers_accepts_valid_headers() -> None: +@pytest.mark.parametrize( + "header_row, expected_headers, expected_stripped", + [ + pytest.param("col1,col2,col3", ["col1", "col2", "col3"], 0, id="no_empty_columns"), + pytest.param("col1,col2,col3,,,", ["col1", "col2", "col3"], 3, id="trailing_empty_columns"), + pytest.param("col1,col2, ", ["col1", "col2"], 1, id="trailing_whitespace_column"), + pytest.param( + "col1,col2,col3,, , ", ["col1", "col2", "col3"], 3, id="trailing_mixed_empty_whitespace" + ), + ], +) +def test_get_headers_strips_trailing_empty_columns( + header_row: str, expected_headers: List[str], expected_stripped: int +) -> None: csv_reader = _CsvReader() config_format = CsvFormat() - fp = io.StringIO("col1,col2,col3") + test_logger = logging.getLogger("test") + fp = io.StringIO(header_row) dialect_name = f"test_{uuid4()}" csv.register_dialect( @@ -841,13 +856,14 @@ def test_get_headers_accepts_valid_headers() -> None: ) try: - headers = csv_reader._get_headers(fp, config_format, dialect_name) - assert headers == ["col1", "col2", "col3"] + headers, n_stripped = csv_reader._get_headers(fp, config_format, dialect_name, test_logger) + assert headers == expected_headers + assert n_stripped == expected_stripped finally: csv.unregister_dialect(dialect_name) -def test_read_data_raises_on_empty_column_names() -> None: +def test_read_data_strips_trailing_empty_columns() -> None: config_format = CsvFormat() config = Mock() config.name = "config_name" @@ -855,26 +871,57 @@ def test_read_data_raises_on_empty_column_names() -> None: file = RemoteFile(uri="test.csv", last_modified=datetime.now()) stream_reader = Mock(spec=AbstractFileBasedStreamReader) - logger = Mock(spec=logging.Logger) + test_logger = Mock(spec=logging.Logger) csv_reader = _CsvReader() stream_reader.open_file.return_value = ( CsvFileBuilder().with_data(["col1,col2,col3,,,", "v1,v2,v3,v4,v5,v6"]).build() ) + rows = list( + csv_reader.read_data( + config, + file, + stream_reader, + test_logger, + FileReadMode.READ, + ) + ) + + assert len(rows) == 1 + assert rows[0] == {"col1": "v1", "col2": "v2", "col3": "v3"} + test_logger.warning.assert_called_once() + assert "trailing empty" in test_logger.warning.call_args[0][0] + + +def test_read_data_raises_on_non_trailing_empty_column_names() -> None: + config_format = CsvFormat() + config = Mock() + config.name = "config_name" + config.format = config_format + + file = RemoteFile(uri="test.csv", last_modified=datetime.now()) + stream_reader = Mock(spec=AbstractFileBasedStreamReader) + test_logger = Mock(spec=logging.Logger) + csv_reader = _CsvReader() + + stream_reader.open_file.return_value = ( + CsvFileBuilder().with_data(["col1,,col3", "v1,v2,v3"]).build() + ) + with pytest.raises(AirbyteTracedException) as exc_info: list( csv_reader.read_data( config, file, stream_reader, - logger, + test_logger, FileReadMode.READ, ) ) assert exc_info.value.failure_type == FailureType.config_error - assert "empty column name" in exc_info.value.message + assert "non-trailing" in exc_info.value.message def test_parse_records_preserves_mismatch_error_detail() -> None: