Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 41 additions & 8 deletions airbyte_cdk/sources/file_based/file_types/csv_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def read_data(
file, file_read_mode, config_format.encoding, logger
) as fp:
try:
headers = self._get_headers(fp, config_format, dialect_name)
headers, n_cols_stripped = self._get_headers(
fp, config_format, dialect_name, logger
)
except UnicodeError:
raise AirbyteTracedException(
message=f"{FileBasedSourceError.ENCODING_ERROR.value} Expected encoding: {config_format.encoding}",
Expand All @@ -88,6 +90,9 @@ def read_data(
for row in reader:
lineno += 1

if n_cols_stripped and None in row:
row.pop(None)

# The row was not properly parsed if any of the values are None. This will most likely occur if there are more columns
# than headers or more headers than columns
if None in row:
Expand Down Expand Up @@ -116,11 +121,18 @@ def read_data(
finally:
csv.unregister_dialect(dialect_name)

def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> List[str]:
def _get_headers(
self,
fp: IOBase,
config_format: CsvFormat,
dialect_name: str,
logger: logging.Logger,
) -> Tuple[List[str], int]:
"""Assumes the fp is pointing to the beginning of the file and will reset it as such."""
# Note that this method assumes the dialect has already been registered if we're parsing the headers
if isinstance(config_format.header_definition, CsvHeaderUserProvided):
return config_format.header_definition.column_names
fp.seek(0)
return config_format.header_definition.column_names, 0

if isinstance(config_format.header_definition, CsvHeaderAutogenerated):
self._skip_rows(
Expand All @@ -132,17 +144,38 @@ def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str)
self._skip_rows(fp, config_format.skip_rows_before_header)
reader = csv.reader(fp, dialect=dialect_name) # type: ignore
headers = list(next(reader))
headers, n_stripped = self._strip_trailing_empty_headers(headers, logger)

empty_count = sum(1 for h in headers if not h or h.isspace())
if empty_count:
empty_interior = sum(1 for h in headers if not h or h.isspace())
if empty_interior:
raise AirbyteTracedException(
message="CSV header row contains empty column name(s). Remove trailing delimiters or empty columns from the header row.",
internal_message=f"Found {empty_count} empty/whitespace-only column name(s) in header: {headers}",
message="CSV header row contains empty column name(s) in non-trailing positions.",
internal_message=f"Found {empty_interior} empty/whitespace-only column name(s) in non-trailing positions of header: {headers}",
failure_type=FailureType.config_error,
)

fp.seek(0)
return headers, n_stripped

fp.seek(0)
return headers
return headers, 0

@staticmethod
def _strip_trailing_empty_headers(
    headers: List[str], logger: logging.Logger
) -> Tuple[List[str], int]:
    """Return the header names with trailing empty/whitespace-only entries removed.

    Only the tail of the list is examined: scanning stops at the first name
    (from the right) that contains a non-whitespace character, so empty names
    in leading or interior positions are left in place for the caller to
    handle.

    The input list is not modified; a new list is returned.

    Args:
        headers: Column names as read from the CSV header row.
        logger: Receives a single warning when at least one trailing name
            is dropped.

    Returns:
        A ``(kept_headers, n_stripped)`` tuple, where ``n_stripped`` is the
        number of trailing names that were removed (``0`` if none).
    """
    # Walk an index back from the end instead of popping, so the caller's
    # list is left untouched.
    keep = len(headers)
    while keep and (not headers[keep - 1] or headers[keep - 1].isspace()):
        keep -= 1

    n_stripped = len(headers) - keep
    if n_stripped:
        logger.warning(
            "Ignoring %d trailing empty column name(s) in CSV header. "
            "This commonly occurs when the file has trailing delimiters.",
            n_stripped,
        )
    return headers[:keep], n_stripped

def _auto_generate_headers(self, fp: IOBase, dialect_name: str) -> List[str]:
"""
Expand Down
71 changes: 59 additions & 12 deletions unit_tests/sources/file_based/file_types/test_csv_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,17 +791,17 @@ def test_encoding_is_passed_to_stream_reader() -> None:
@pytest.mark.parametrize(
"header_row, expected_empty_count",
[
pytest.param("col1,col2,col3,,,", 3, id="trailing_empty_columns"),
pytest.param("col1,,col3", 1, id="middle_empty_column"),
pytest.param(",col2,col3", 1, id="leading_empty_column"),
pytest.param("col1,col2, ", 1, id="whitespace_only_column"),
pytest.param("col1,,col3, , ", 1, id="interior_empty_after_trailing_strip"),
],
)
def test_get_headers_raises_on_empty_column_names(
def test_get_headers_raises_on_non_trailing_empty_column_names(
header_row: str, expected_empty_count: int
) -> None:
csv_reader = _CsvReader()
config_format = CsvFormat()
test_logger = logging.getLogger("test")
fp = io.StringIO(header_row)

dialect_name = f"test_{uuid4()}"
Expand All @@ -816,19 +816,34 @@ def test_get_headers_raises_on_empty_column_names(

try:
with pytest.raises(AirbyteTracedException) as exc_info:
csv_reader._get_headers(fp, config_format, dialect_name)
csv_reader._get_headers(fp, config_format, dialect_name, test_logger)

assert exc_info.value.failure_type == FailureType.config_error
assert "empty column name" in exc_info.value.message
assert "non-trailing" in exc_info.value.message
assert f"{expected_empty_count} empty" in exc_info.value.internal_message
finally:
csv.unregister_dialect(dialect_name)


def test_get_headers_accepts_valid_headers() -> None:
@pytest.mark.parametrize(
"header_row, expected_headers, expected_stripped",
[
pytest.param("col1,col2,col3", ["col1", "col2", "col3"], 0, id="no_empty_columns"),
pytest.param("col1,col2,col3,,,", ["col1", "col2", "col3"], 3, id="trailing_empty_columns"),
pytest.param("col1,col2, ", ["col1", "col2"], 1, id="trailing_whitespace_column"),
pytest.param(
"col1,col2,col3,, , ", ["col1", "col2", "col3"], 3, id="trailing_mixed_empty_whitespace"
),
],
)
def test_get_headers_strips_trailing_empty_columns(
header_row: str, expected_headers: List[str], expected_stripped: int
) -> None:
csv_reader = _CsvReader()
config_format = CsvFormat()
fp = io.StringIO("col1,col2,col3")
test_logger = logging.getLogger("test")
fp = io.StringIO(header_row)

dialect_name = f"test_{uuid4()}"
csv.register_dialect(
Expand All @@ -841,40 +856,72 @@ def test_get_headers_accepts_valid_headers() -> None:
)

try:
headers = csv_reader._get_headers(fp, config_format, dialect_name)
assert headers == ["col1", "col2", "col3"]
headers, n_stripped = csv_reader._get_headers(fp, config_format, dialect_name, test_logger)
assert headers == expected_headers
assert n_stripped == expected_stripped
finally:
csv.unregister_dialect(dialect_name)


def test_read_data_raises_on_empty_column_names() -> None:
def test_read_data_strips_trailing_empty_columns() -> None:
config_format = CsvFormat()
config = Mock()
config.name = "config_name"
config.format = config_format

file = RemoteFile(uri="test.csv", last_modified=datetime.now())
stream_reader = Mock(spec=AbstractFileBasedStreamReader)
logger = Mock(spec=logging.Logger)
test_logger = Mock(spec=logging.Logger)
csv_reader = _CsvReader()

stream_reader.open_file.return_value = (
CsvFileBuilder().with_data(["col1,col2,col3,,,", "v1,v2,v3,v4,v5,v6"]).build()
)

rows = list(
csv_reader.read_data(
config,
file,
stream_reader,
test_logger,
FileReadMode.READ,
)
)

assert len(rows) == 1
assert rows[0] == {"col1": "v1", "col2": "v2", "col3": "v3"}
test_logger.warning.assert_called_once()
assert "trailing empty" in test_logger.warning.call_args[0][0]


def test_read_data_raises_on_non_trailing_empty_column_names() -> None:
config_format = CsvFormat()
config = Mock()
config.name = "config_name"
config.format = config_format

file = RemoteFile(uri="test.csv", last_modified=datetime.now())
stream_reader = Mock(spec=AbstractFileBasedStreamReader)
test_logger = Mock(spec=logging.Logger)
csv_reader = _CsvReader()

stream_reader.open_file.return_value = (
CsvFileBuilder().with_data(["col1,,col3", "v1,v2,v3"]).build()
)

with pytest.raises(AirbyteTracedException) as exc_info:
list(
csv_reader.read_data(
config,
file,
stream_reader,
logger,
test_logger,
FileReadMode.READ,
)
)

assert exc_info.value.failure_type == FailureType.config_error
assert "empty column name" in exc_info.value.message
assert "non-trailing" in exc_info.value.message


def test_parse_records_preserves_mismatch_error_detail() -> None:
Expand Down
Loading