diff --git a/tfx/orchestration/kubeflow/v2/kubeflow_v2_dag_runner_test.py b/tfx/orchestration/kubeflow/v2/kubeflow_v2_dag_runner_test.py index a789e14c3e..4f6dc814de 100644 --- a/tfx/orchestration/kubeflow/v2/kubeflow_v2_dag_runner_test.py +++ b/tfx/orchestration/kubeflow/v2/kubeflow_v2_dag_runner_test.py @@ -25,6 +25,7 @@ from tfx.orchestration import pipeline as tfx_pipeline from tfx.orchestration.kubeflow.v2 import kubeflow_v2_dag_runner from tfx.orchestration.kubeflow.v2 import test_utils +from tfx.utils import version_utils from tfx.utils import telemetry_utils from tfx.utils import test_case_utils import yaml @@ -66,6 +67,16 @@ def _compare_against_testdata( ) expected_json['pipelineSpec']['sdkVersion'] = 'tfx-{}'.format( version.__version__) + # Update expected test data image tags to match the TFX version under test. + executors = expected_json['pipelineSpec']['deploymentSpec']['executors'] + statistics_gen_container = executors['StatisticsGen_executor']['container'] + if (statistics_gen_container['image'] == + 'gcr.io/tfx-oss-public/tfx:latest'): + expected_json['pipelineSpec']['deploymentSpec'][ + 'executors']['StatisticsGen_executor'][ + 'container']['image'] = ( + f'gcr.io/tfx-oss-public/tfx:{version_utils.get_image_version()}' + ) if 'labels' in expected_json: expected_json['labels']['tfx_version'] = telemetry_utils._normalize_label( version.__version__)