@@ -832,9 +832,14 @@ Status DataServiceDispatcherImpl::GetOrCreateJob(
832832 GetOrCreateJobRequest::kNumConsumers ) {
833833 num_consumers = request->num_consumers ();
834834 }
835+
836+ absl::flat_hash_set<std::string> local_workers;
837+ local_workers.insert (request->local_workers ().cbegin (),
838+ request->local_workers ().cend ());
839+
835840 TF_RETURN_IF_ERROR (CreateJob (request->dataset_id (),
836841 requested_processing_mode, key, num_consumers,
837- job));
842+ job, local_workers ));
838843 int64 job_client_id;
839844 TF_RETURN_IF_ERROR (AcquireJobClientId (job, job_client_id));
840845 response->set_job_client_id (job_client_id);
@@ -943,7 +948,9 @@ Status DataServiceDispatcherImpl::ValidateMatchingJob(
943948Status DataServiceDispatcherImpl::CreateJob (
944949 int64 dataset_id, ProcessingMode processing_mode,
945950 absl::optional<NamedJobKey> named_job_key,
946- absl::optional<int64> num_consumers, std::shared_ptr<const Job>& job)
951+ absl::optional<int64> num_consumers,
952+ std::shared_ptr<const Job>& job,
953+ absl::flat_hash_set<std::string> local_workers)
947954 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
948955 switch (processing_mode) {
949956 case ProcessingMode::PARALLEL_EPOCHS:
@@ -1005,24 +1012,23 @@ Status DataServiceDispatcherImpl::CreateJob(
10051012
10061013 bool should_use_local_workers; // Do we have enough throughput to decide to use local workers to save network bandwidth?
10071014 TF_RETURN_IF_ERROR (service::easl::local_workers_utils::ShouldUseLocalWorkers (
1008- config_, metadata_store_, compute_dataset_key , should_use_local_workers
1015+ config_, metadata_store_, dataset_fingerprint , should_use_local_workers
10091016 ));
10101017
1011- if (should_use_local_workers && request. local_workers () .size () >= 1 ) {
1018+ if (should_use_local_workers && local_workers.size () >= 1 ) {
10121019 target_remote_workers = suggested_worker_count - 1 ;
10131020 target_local_workers = 1 ;
10141021 } else {
10151022 target_remote_workers = suggested_worker_count;
10161023 target_local_workers = 0 ;
10171024 }
10181025 } else if (config_.scaling_policy () == 2 ) { // Use all available workers
1019- target_remote_workers = total_workers - request. local_workers () .size ();
1020- target_local_workers = request. local_workers () .size ();
1026+ target_remote_workers = total_workers - local_workers.size ();
1027+ target_local_workers = local_workers.size ();
10211028 } else if (config_.scaling_policy () == 3 ) { // Grid search over local and remote workers
10221029 TF_RETURN_IF_ERROR (service::easl::local_workers_utils::DecideTargetWorkersGridSearch (
1023- config_, metadata_store_, compute_dataset_key,
1024- total_workers - request.local_workers ().size (), request.local_workers ().size (),
1025- target_remote_workers, target_local_workers
1030+ total_workers - local_workers.size (), local_workers.size (),
1031+ target_remote_workers, target_local_workers // passed by reference
10261032 ));
10271033 }
10281034
@@ -1061,7 +1067,7 @@ Status DataServiceDispatcherImpl::CreateJob(
10611067 create_job->set_target_worker_count (suggested_worker_count);
10621068 create_job->set_target_local_workers (target_local_workers);
10631069 create_job->set_target_remote_workers (target_remote_workers);
1064- *create_job->mutable_local_workers () = {request. local_workers () .begin (), request. local_workers () .end ()};
1070+ *create_job->mutable_local_workers () = {local_workers.begin (), local_workers.end ()};
10651071 if (named_job_key.has_value ()) {
10661072 NamedJobKeyDef* key = create_job->mutable_named_job_key ();
10671073 key->set_name (named_job_key->name );
@@ -1120,7 +1126,7 @@ Status DataServiceDispatcherImpl::CreateTasksForJob(
11201126 std::vector<std::shared_ptr<const Task>>& tasks)
11211127 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
11221128 std::vector<std::shared_ptr<const Worker>> workers = state_.ReserveWorkers (
1123- job->job_id , job->target_worker_count , job-> target_remote_workers , job->target_local_workers , job->local_workers );
1129+ job->job_id , job->target_remote_workers , job->target_local_workers , job->local_workers );
11241130 if (workers.size () < job->target_worker_count ){
11251131 VLOG (0 )
11261132 << " EASL - Not enough workers for job. Elasticity policy requires "
@@ -1415,7 +1421,6 @@ Status DataServiceDispatcherImpl::ClientHeartbeat(
14151421 }
14161422 response->set_job_finished (job->finished );
14171423 response->set_target_local_workers (job->target_local_workers );
1418- response->set_target_remote_workers (job->target_remote_workers );
14191424 VLOG (4 ) << " Found " << response->task_info_size ()
14201425 << " tasks for job client id " << request->job_client_id ();
14211426
0 commit comments