@@ -24,12 +24,12 @@ set -u # fail and exit on any undefined variable reference
2424DIR=$( cd $( dirname " ${BASH_SOURCE[0]} " ) && pwd)
2525
2626# Make sure the environment variables are set.
27- if [ -z " ${SHARD} " ]; then
27+ if [ -z " ${SHARD+x } " ]; then
2828 echo " SHARD is unset."
2929 exit -1
3030fi
3131
32- if [ -z " ${NUM_SHARDS} " ]; then
32+ if [ -z " ${NUM_SHARDS+x } " ]; then
3333 echo " NUM_SHARDS is unset."
3434 exit -1
3535fi
@@ -59,16 +59,27 @@ install_python_packages() {
5959which bazel || install_bazel
6060install_python_packages
6161
62- # You can alter this test_target to some smaller subset of TFP tests in case
63- # you need to reproduce something on the CI workers.
64- test_target=" //tensorflow_probability/..."
65- test_tags_to_skip=" (gpu|requires-gpu-nvidia|notap|no-oss-ci|tfp_jax|tf2-broken|tf2-kokoro-broken)"
62+ changed_py_files=" $( ${DIR} /get_github_changed_py_files.sh | \
63+ sed -r ' s#(.*)/([^/]+).py#//\1:\2.py#' ) "
64+
65+ if [[ -n " ${changed_py_files} " ]]; then
66+ test_targets=$( bazel query --universe_scope=//tensorflow_probability/... \
67+ " tests(allrdeps(set(${changed_py_files} )))" )
68+ else
69+ # For pushes, test all targets.
70+ test_targets=$( bazel query ' tests(//tensorflow_probability/...)' )
71+ fi
72+
73+ test_targets=$( echo " ${test_targets} " | tr -s ' \n' ' ' )
74+ test_targets=" $( echo " ${test_targets} " | sed -r ' s#(.*) #\1#' ) "
75+ test_tags_to_skip=" (gpu|requires-gpu-nvidia|notap|no-oss-ci|tfp_jax|\
76+ tfp_numpy|tf2-broken|tf2-kokoro-broken)"
6677
6778# Given a test size (small, medium, large), a number of shards and a shard ID,
6879# query and print a list of tests of the given size to run in the given shard.
6980query_and_shard_tests_by_size () {
7081 size=$1
71- bazel_query=" attr(size, ${size} , tests (${test_target } )) \
82+ bazel_query=" attr(size, ${size} , set (${test_targets } )) \
7283 except \
7384 attr(tags, \" ${test_tags_to_skip} \" , \
7485 tests(//tensorflow_probability/...))"
0 commit comments