Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 61 additions & 6 deletions sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,32 @@
_SLEEP_DURATION_BETWEEN_POLLS = 10


def _has_partitioning_load_parameters(additional_parameters):
return (
'timePartitioning' in additional_parameters or
'rangePartitioning' in additional_parameters)


def _add_destination_partitioning_load_parameters(
additional_parameters, destination_table):
if destination_table is None:
return additional_parameters

additional_parameters = dict(additional_parameters)
time_partitioning = getattr(destination_table, 'timePartitioning', None)
range_partitioning = getattr(destination_table, 'rangePartitioning', None)

if ('timePartitioning' not in additional_parameters and
isinstance(time_partitioning, bigquery_tools.bigquery.TimePartitioning)):
additional_parameters['timePartitioning'] = time_partitioning

if ('rangePartitioning' not in additional_parameters and isinstance(
range_partitioning, bigquery_tools.bigquery.RangePartitioning)):
additional_parameters['rangePartitioning'] = range_partitioning

return additional_parameters
Comment on lines +91 to +108


def _generate_job_name(job_name, job_type, step_name):
return bigquery_tools.generate_bq_job_name(
job_name=job_name,
Expand Down Expand Up @@ -688,6 +714,7 @@ def start_bundle(self):
self.bq_io_metadata = create_bigquery_io_metadata(self._step_name)
self.pending_jobs = []
self.schema_cache = {}
self.destination_table_cache = {}

def process(
self,
Expand Down Expand Up @@ -716,6 +743,7 @@ def process(
additional_parameters = self.additional_bq_parameters.get()
else:
additional_parameters = self.additional_bq_parameters
additional_parameters = dict(additional_parameters or {})

table_reference = bigquery_tools.parse_table_reference(destination)
if table_reference.projectId is None:
Expand All @@ -735,19 +763,41 @@ def process(

create_disposition = self.create_disposition
if self.temporary_tables:
destination_table = None
hashed_dest = bigquery_tools.get_hashable_destination(table_reference)
need_schema = schema is None and hashed_dest not in self.schema_cache
need_partitioning = not _has_partitioning_load_parameters(
additional_parameters)
if need_schema or need_partitioning:
try:
if hashed_dest in self.destination_table_cache:
destination_table = self.destination_table_cache[hashed_dest]
else:
destination_table = self.bq_wrapper.get_table(
project_id=table_reference.projectId,
dataset_id=table_reference.datasetId,
table_id=table_reference.tableId)
self.destination_table_cache[hashed_dest] = destination_table
except Exception as e:
Comment on lines +766 to +781
if need_schema:
_LOGGER.warning(
"Input schema is absent and could not fetch the final "
"destination table's schema [%s]. Creating temp table [%s] "
"will likely fail: %s",
hashed_dest,
job_name,
e)
destination_table = None
Comment on lines +766 to +790

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This implementation introduces a significant performance regression and redundant BigQuery API calls:

  1. Bypassing the Schema Cache: Since should_lookup_destination_table is evaluated before checking self.schema_cache, if schema is None initially, should_lookup_destination_table will always be True. This forces a synchronous get_table API call on every single bundle/partition even if the schema is already cached in self.schema_cache.
  2. Redundant Calls per Bundle: If schema is provided but partitioning parameters are not (the default case), get_table is called on every single bundle/partition without any caching.

We can resolve both issues by caching the fetched destination_table in a local cache (e.g., self._destination_table_cache) and only calling get_table if we actually need the schema (and it's not in self.schema_cache) or if we need the partitioning parameters.

      hashed_dest = bigquery_tools.get_hashable_destination(table_reference)
      if not hasattr(self, '_destination_table_cache'):
        self._destination_table_cache = {}
      destination_table = self._destination_table_cache.get(hashed_dest)
      if destination_table is None:
        need_schema = schema is None and hashed_dest not in self.schema_cache
        need_partitioning = not _has_partitioning_load_parameters(additional_parameters)
        if need_schema or need_partitioning:
          try:
            destination_table = self.bq_wrapper.get_table(
                project_id=table_reference.projectId,
                dataset_id=table_reference.datasetId,
                table_id=table_reference.tableId)
            self._destination_table_cache[hashed_dest] = destination_table
          except Exception as e:
            if schema is None and hashed_dest not in self.schema_cache:
              _LOGGER.warning(
                  "Input schema is absent and could not fetch the final "
                  "destination table's schema [%s]. Creating temp table [%s] "
                  "will likely fail: %s",
                  hashed_dest,
                  job_name,
                  e)
            destination_table = None


# we need to create temp tables, so we need a schema.
# if there is no input schema, fetch the destination table's schema
if schema is None:
hashed_dest = bigquery_tools.get_hashable_destination(table_reference)
if hashed_dest in self.schema_cache:
schema = self.schema_cache[hashed_dest]
else:
elif destination_table is not None:
try:
schema = bigquery_tools.table_schema_to_dict(
bigquery_tools.BigQueryWrapper().get_table(
project_id=table_reference.projectId,
dataset_id=table_reference.datasetId,
table_id=table_reference.tableId).schema)
destination_table.schema)
self.schema_cache[hashed_dest] = schema
except Exception as e:
_LOGGER.warning(
Expand All @@ -758,6 +808,11 @@ def process(
job_name,
e)

if (destination_table is not None and
not _has_partitioning_load_parameters(additional_parameters)):
additional_parameters = _add_destination_partitioning_load_parameters(
additional_parameters, destination_table)

# If we are using temporary tables, then we must always create the
# temporary tables, so we replace the create_disposition.
create_disposition = 'CREATE_IF_NEEDED'
Expand Down
151 changes: 151 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,157 @@ def test_one_load_job_failed_after_waiting(self, sleep_mock):

sleep_mock.assert_called_once()

def test_temporary_table_load_inherits_destination_time_partitioning(self):
destination = 'project1:dataset1.table1'
partition = (destination, (0, ['gs://bucket/file1']))
job_reference = bigquery_api.JobReference(
projectId='project1', jobId='job_name1')
destination_table = bigquery_api.Table(
timePartitioning=bigquery_api.TimePartitioning(type='DAY'))

dofn = bqfl.TriggerLoadJobs(
schema=_ELEMENTS_SCHEMA, test_client=mock.Mock(), temporary_tables=True)
dofn.start_bundle()
dofn.bq_wrapper.get_table = mock.Mock(return_value=destination_table)
dofn.bq_wrapper.perform_load_job = mock.Mock(return_value=job_reference)

list(dofn.process(partition, 'test_job', pane_info=mock.Mock(index=0)))

load_call = dofn.bq_wrapper.perform_load_job.call_args.kwargs
self.assertEqual(
load_call['additional_load_parameters']['timePartitioning'],
destination_table.timePartitioning)
dofn.bq_wrapper.get_table.assert_called_once_with(
project_id='project1', dataset_id='dataset1', table_id='table1')

def test_temporary_table_load_inherits_destination_range_partitioning(self):
destination = 'project1:dataset1.table1'
partition = (destination, (0, ['gs://bucket/file1']))
job_reference = bigquery_api.JobReference(
projectId='project1', jobId='job_name1')
destination_table = bigquery_api.Table(
rangePartitioning=bigquery_api.RangePartitioning())

dofn = bqfl.TriggerLoadJobs(
schema=_ELEMENTS_SCHEMA, test_client=mock.Mock(), temporary_tables=True)
dofn.start_bundle()
dofn.bq_wrapper.get_table = mock.Mock(return_value=destination_table)
dofn.bq_wrapper.perform_load_job = mock.Mock(return_value=job_reference)

list(dofn.process(partition, 'test_job', pane_info=mock.Mock(index=0)))

load_call = dofn.bq_wrapper.perform_load_job.call_args.kwargs
self.assertEqual(
load_call['additional_load_parameters']['rangePartitioning'],
destination_table.rangePartitioning)
dofn.bq_wrapper.get_table.assert_called_once_with(
project_id='project1', dataset_id='dataset1', table_id='table1')

def test_temporary_table_load_keeps_explicit_partitioning_parameters(self):
destination = 'project1:dataset1.table1'
partition = (destination, (0, ['gs://bucket/file1']))
explicit_partitioning = {'timePartitioning': {'type': 'DAY'}}
job_reference = bigquery_api.JobReference(
projectId='project1', jobId='job_name1')

dofn = bqfl.TriggerLoadJobs(
schema=_ELEMENTS_SCHEMA,
test_client=mock.Mock(),
temporary_tables=True,
additional_bq_parameters=explicit_partitioning)
dofn.start_bundle()
dofn.bq_wrapper.get_table = mock.Mock()
dofn.bq_wrapper.perform_load_job = mock.Mock(return_value=job_reference)

list(dofn.process(partition, 'test_job', pane_info=mock.Mock(index=0)))

load_call = dofn.bq_wrapper.perform_load_job.call_args.kwargs
self.assertEqual(
load_call['additional_load_parameters'], explicit_partitioning)
dofn.bq_wrapper.get_table.assert_not_called()

def test_temporary_table_load_uses_cached_schema_with_explicit_partitioning(
self):
destination = 'project1:dataset1.table1'
partition = (destination, (0, ['gs://bucket/file1']))
explicit_partitioning = {'timePartitioning': {'type': 'DAY'}}
job_reference = bigquery_api.JobReference(
projectId='project1', jobId='job_name1')
table_reference = bigquery_tools.parse_table_reference(destination)
hashed_dest = bigquery_tools.get_hashable_destination(table_reference)

dofn = bqfl.TriggerLoadJobs(
schema=None,
test_client=mock.Mock(),
temporary_tables=True,
additional_bq_parameters=explicit_partitioning)
dofn.start_bundle()
dofn.schema_cache[hashed_dest] = _ELEMENTS_SCHEMA
dofn.bq_wrapper.get_table = mock.Mock()
dofn.bq_wrapper.perform_load_job = mock.Mock(return_value=job_reference)

list(dofn.process(partition, 'test_job', pane_info=mock.Mock(index=0)))

load_call = dofn.bq_wrapper.perform_load_job.call_args.kwargs
self.assertEqual(load_call['schema'], _ELEMENTS_SCHEMA)
self.assertEqual(
load_call['additional_load_parameters'], explicit_partitioning)
dofn.bq_wrapper.get_table.assert_not_called()

def test_temporary_table_load_caches_destination_table_per_bundle(self):
destination = 'project1:dataset1.table1'
first_partition = (destination, (0, ['gs://bucket/file1']))
second_partition = (destination, (1, ['gs://bucket/file2']))
job_reference = bigquery_api.JobReference(
projectId='project1', jobId='job_name1')
destination_table = bigquery_api.Table(
timePartitioning=bigquery_api.TimePartitioning(type='DAY'))

dofn = bqfl.TriggerLoadJobs(
schema=_ELEMENTS_SCHEMA, test_client=mock.Mock(), temporary_tables=True)
dofn.start_bundle()
dofn.bq_wrapper.get_table = mock.Mock(return_value=destination_table)
dofn.bq_wrapper.perform_load_job = mock.Mock(return_value=job_reference)

list(
dofn.process(first_partition, 'test_job', pane_info=mock.Mock(index=0)))
list(
dofn.process(
second_partition, 'test_job', pane_info=mock.Mock(index=1)))

dofn.bq_wrapper.get_table.assert_called_once_with(
project_id='project1', dataset_id='dataset1', table_id='table1')
load_call = dofn.bq_wrapper.perform_load_job.call_args.kwargs
self.assertEqual(
load_call['additional_load_parameters']['timePartitioning'],
destination_table.timePartitioning)

def test_temporary_table_load_ignores_invalid_mock_partitioning_metadata(
self):
destination = 'project1:dataset1.table1'
partition = (destination, (0, ['gs://bucket/file1']))
job_reference = bigquery_api.JobReference(
projectId='project1', jobId='job_name1')
destination_table = mock.Mock()
destination_table.timePartitioning = mock.Mock()
destination_table.rangePartitioning = mock.Mock()

dofn = bqfl.TriggerLoadJobs(
schema=_ELEMENTS_SCHEMA, test_client=mock.Mock(), temporary_tables=True)
dofn.start_bundle()
dofn.bq_wrapper.get_table = mock.Mock(return_value=destination_table)
dofn.bq_wrapper.perform_load_job = mock.Mock(return_value=job_reference)

list(dofn.process(partition, 'test_job', pane_info=mock.Mock(index=0)))

load_call = dofn.bq_wrapper.perform_load_job.call_args.kwargs
self.assertNotIn(
'timePartitioning', load_call['additional_load_parameters'])
self.assertNotIn(
'rangePartitioning', load_call['additional_load_parameters'])
dofn.bq_wrapper.get_table.assert_called_once_with(
project_id='project1', dataset_id='dataset1', table_id='table1')

def test_multiple_partition_files(self):
destination = 'project1:dataset1.table1'

Expand Down
Loading