diff --git a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json index c537844dc84a..e0266d62f2e0 100644 --- a/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Python_ValidatesRunner_Dataflow.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run", - "modification": 3 + "modification": 4 } diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index fdf2e0ea03e3..4755aea0bc2c 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -519,12 +519,14 @@ def __init__(self, options, root_staging_location=None): client_options = None transport = None if self.google_cloud_options.dataflow_endpoint: - endpoint = self.google_cloud_options.dataflow_endpoint - if 'localhost' in endpoint or 'sandbox' in endpoint: - transport = 'rest' - else: - endpoint = re.sub('^https?://', '', endpoint) - client_options = client_options_lib.ClientOptions(api_endpoint=endpoint) + endpoint = self.google_cloud_options.dataflow_endpoint.strip() + if endpoint: + if 'localhost' in endpoint or 'sandbox' in endpoint: + transport = 'rest' + else: + endpoint = re.sub('^https?://', '', endpoint) + endpoint = endpoint.rstrip('/') + client_options = client_options_lib.ClientOptions(api_endpoint=endpoint) self._jobs_client = dataflow.JobsV1Beta3Client( credentials=gapic_credentials, diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index b34f06b64df4..12c51d305145 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -1298,6 +1298,42 @@ def test_template_file_generation_with_upload_graph(self): self.assertFalse(template_obj.get('steps')) self.assertTrue(template_obj['stepsLocation']) + def test_dataflow_endpoint_clean(self): + endpoints_and_expectations = [ + # (input_endpoint, expected_endpoint, expected_transport) + ('https://dataflow.googleapis.com/', 'dataflow.googleapis.com', None), + ('https://dataflow.googleapis.com ', 'dataflow.googleapis.com', None), + ('dataflow.googleapis.com/', 'dataflow.googleapis.com', None), + ('http://localhost:8080/', 'http://localhost:8080', 'rest'), + ('localhost:8080/', 'localhost:8080', 'rest'), + ] + + for input_ep, expected_ep, expected_transport in endpoints_and_expectations: + pipeline_options = PipelineOptions([ + '--project', + 'test-project', + '--temp_location', + 'gs://test-location/temp', + '--dataflow_endpoint', + input_ep, + '--no_auth', + ]) + with mock.patch('apache_beam.runners.dataflow.internal.apiclient.dataflow' + '.JobsV1Beta3Client') as mock_jobs: + with mock.patch( + 'apache_beam.runners.dataflow.internal.apiclient.dataflow' + '.MessagesV1Beta3Client'): + with mock.patch( + 'apache_beam.runners.dataflow.internal.apiclient.dataflow' + '.MetricsV1Beta3Client'): + apiclient.DataflowApplicationClient(pipeline_options) + mock_jobs.assert_called_once() + called_kwargs = mock_jobs.call_args.kwargs + client_opts = called_kwargs.get('client_options') + self.assertIsNotNone(client_opts) + self.assertEqual(client_opts.api_endpoint, expected_ep) + self.assertEqual(called_kwargs.get('transport'), expected_transport) + def test_stage_resources(self): pipeline_options = PipelineOptions([ '--temp_location',