@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/core/data/service/easl/cache_utils.h"
 #include "tensorflow/core/data/service/easl/scaling_utils.h"
 #include "tensorflow/core/data/service/easl/metadata_store.h"
+#include "tensorflow/core/data/service/easl/local_workers_utils.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/graph.pb.h"
@@ -955,7 +956,6 @@ Status DataServiceDispatcherImpl::CreateJob(
   int64 job_id = state_.NextAvailableJobId();
 
   // EASL - Caching decision: should the job compute, write or read from cache?
-  int64 worker_count;
   std::string job_type;
   string job_name = "named_job_key.value().name";
   std::shared_ptr<const Dataset> dataset;
@@ -993,28 +993,55 @@ Status DataServiceDispatcherImpl::CreateJob(
 
   std::shared_ptr<easl::JobMetrics> job_metrics;
   s = metadata_store_.GetJobMetrics(job_id, job_metrics);
-  worker_count = job_metrics->target_worker_count_;
 
-  if (config_.scaling_policy() == 2) {
-    worker_count = 100;
+  // EASL - Scaling decision: how many workers (remote/local) should the job be assigned?
+  int64 total_workers = state_.ListWorkers().size();
+  int64 suggested_worker_count = job_metrics->target_worker_count_;
+
+  int64 num_worker_remote_target = suggested_worker_count, num_worker_local_target = 0;  // Defaults in case no policy below matches.
+  if (config_.scaling_policy() == 1) {  // Autoscaling as in the paper, but now distinguishing local from remote workers.
+    VLOG(0) << "EASL - Scalability decision for dataset_key "
+            << compute_dataset_key << ": " << suggested_worker_count;
+
+    bool should_use_local_workers;  // Is throughput high enough to route work to local workers and save network bandwidth?
+    TF_RETURN_IF_ERROR(service::easl::local_workers_utils::ShouldUseLocalWorkers(
+        config_, metadata_store_, compute_dataset_key, should_use_local_workers
+    ));
+
+    if (should_use_local_workers && local_workers.size() >= 1) {
+      num_worker_remote_target = suggested_worker_count - 1;
+      num_worker_local_target = 1;
+    } else {
+      num_worker_remote_target = suggested_worker_count;
+      num_worker_local_target = 0;
+    }
+  } else if (config_.scaling_policy() == 2) {  // Use all available workers.
+    num_worker_remote_target = total_workers - local_workers.size();
+    num_worker_local_target = local_workers.size();
+  } else if (config_.scaling_policy() == 3) {  // Grid search over local and remote worker counts.
+    TF_RETURN_IF_ERROR(service::easl::local_workers_utils::DecideTargetWorkersGridSearch(
+        config_, metadata_store_, compute_dataset_key,
+        total_workers - local_workers.size(), local_workers.size(),
+        num_worker_remote_target, num_worker_local_target
+    ));
   }
 
   if (job_type == "PUT" || job_type == "PUT_SOURCE") {
     std::shared_ptr<easl::JobMetrics> dataset_fingerprint_metrics;
     s = metadata_store_.GetJobMetricsByDatasetFingerprintAndName(
         dataset_fingerprint, job_name, dataset_fingerprint_metrics);
     if (s.ok()) {
-      worker_count = std::ceil(std::max(1.0,
+      suggested_worker_count = std::ceil(std::max(1.0,
           dataset_fingerprint_metrics->target_worker_count_ * 1.5));
     }
-    job_metrics->target_worker_count_ = worker_count;
+    job_metrics->target_worker_count_ = suggested_worker_count;
   }
 
   // EASL: Logging stuff
   if (kEnableEventLogging) {
-    last_scale_[job_name] = worker_count;
+    last_scale_[job_name] = suggested_worker_count;
     RecordEvent(dataset_fingerprint, dataset_id, job_name, job_id,
-                "starting_worker_count", std::to_string(worker_count));
+                "starting_worker_count", std::to_string(suggested_worker_count));
   }
 
   int64 num_split_providers = 0;
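
The three policy branches above reduce to a small piece of arithmetic over the suggested worker count and the pool sizes. Below is a standalone sketch of that decision logic with simplified types; the invented local_fast_enough flag stands in for the result of ShouldUseLocalWorkers, whose implementation is not part of this commit, and policy 3 is omitted because its grid-search helper is likewise not shown.

    #include <cstdint>
    #include <iostream>

    // Illustrative result type; the real code writes into two out-parameters.
    struct ScaleDecision {
      int64_t remote_target;
      int64_t local_target;
    };

    // Mirrors the branches above: policy 1 diverts one worker to the client's
    // machine when throughput allows it; policy 2 takes every registered worker.
    ScaleDecision DecideWorkers(int policy, int64_t suggested, int64_t total,
                                int64_t local_available, bool local_fast_enough) {
      if (policy == 1) {
        if (local_fast_enough && local_available >= 1)
          return {suggested - 1, 1};
        return {suggested, 0};
      }
      if (policy == 2) return {total - local_available, local_available};
      return {suggested, 0};  // Same fallback as the default initialization.
    }

    int main() {
      ScaleDecision d = DecideWorkers(/*policy=*/1, /*suggested=*/4, /*total=*/10,
                                      /*local_available=*/2,
                                      /*local_fast_enough=*/true);
      std::cout << d.remote_target << " remote, " << d.local_target << " local\n";
    }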
@@ -1031,7 +1058,10 @@ Status DataServiceDispatcherImpl::CreateJob(
   create_job->set_processing_mode(ProcessingModeDef(processing_mode));
   create_job->set_job_type(job_type);
   create_job->set_num_split_providers(num_split_providers);
-  create_job->set_target_worker_count(worker_count);
+  create_job->set_target_worker_count(suggested_worker_count);
+  create_job->set_target_local_workers(num_worker_local_target);
+  create_job->set_target_remote_workers(num_worker_remote_target);
+  *create_job->mutable_local_workers() = {request.local_workers().begin(), request.local_workers().end()};
   if (named_job_key.has_value()) {
     NamedJobKeyDef* key = create_job->mutable_named_job_key();
     key->set_name(named_job_key->name);
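
The new setters imply that the journaled CreateJob update and the dispatcher's in-memory job state gain matching fields; those definitions are outside this diff. A minimal stand-in for the per-job state, with field names chosen only to mirror the accessors used in the hunks below:

    #include <cstdint>
    #include <string>
    #include <vector>

    // Illustrative stand-in for the dispatcher's per-job state (the real class
    // lives elsewhere in the dispatcher sources and is not shown here).
    struct Job {
      int64_t job_id = 0;
      int64_t target_worker_count = 0;    // Total workers the job should receive.
      int64_t target_remote_workers = 0;  // Share served by remote workers.
      int64_t target_local_workers = 0;   // Share served by client-local workers.
      std::vector<std::string> local_workers;  // Worker addresses co-located with the client.
    };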
@@ -1090,7 +1120,7 @@ Status DataServiceDispatcherImpl::CreateTasksForJob(
     std::vector<std::shared_ptr<const Task>>& tasks)
     TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
   std::vector<std::shared_ptr<const Worker>> workers = state_.ReserveWorkers(
-      job->job_id, job->target_worker_count);
+      job->job_id, job->target_worker_count, job->target_remote_workers, job->target_local_workers, job->local_workers);
   if (workers.size() < job->target_worker_count) {
     VLOG(0)
         << "EASL - Not enough workers for job. Elasticity policy requires "
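
ReserveWorkers now receives the remote/local split and the list of client-local addresses in addition to the total count. One plausible selection strategy, sketched with simplified types (the actual implementation is not part of this hunk):

    #include <algorithm>
    #include <cstdint>
    #include <string>
    #include <vector>

    struct WorkerInfo {
      std::string address;
      bool reserved = false;
    };

    // A worker is "local" to the job if its address appears in the job's
    // local-worker list.
    static bool IsLocal(const std::vector<std::string>& local_workers,
                        const std::string& address) {
      return std::find(local_workers.begin(), local_workers.end(), address) !=
             local_workers.end();
    }

    // Reserve up to local_target local and remote_target remote workers,
    // consuming the matching budget as each unreserved worker is claimed.
    std::vector<WorkerInfo*> ReserveWorkers(
        std::vector<WorkerInfo>& pool,
        const std::vector<std::string>& local_workers,
        int64_t remote_target, int64_t local_target) {
      std::vector<WorkerInfo*> reserved;
      for (WorkerInfo& w : pool) {
        if (w.reserved) continue;
        int64_t& budget =
            IsLocal(local_workers, w.address) ? local_target : remote_target;
        if (budget <= 0) continue;
        --budget;
        w.reserved = true;
        reserved.push_back(&w);
      }
      return reserved;
    }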
@@ -1384,6 +1414,8 @@ Status DataServiceDispatcherImpl::ClientHeartbeat(
     task_info->set_starting_round(task->starting_round);
   }
   response->set_job_finished(job->finished);
+  response->set_target_local_workers(job->target_local_workers);
+  response->set_target_remote_workers(job->target_remote_workers);
   VLOG(4) << "Found " << response->task_info_size()
           << " tasks for job client id " << request->job_client_id();
 
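Echoing the two targets back in the heartbeat presumably lets the client cap how many of its task connections go to co-located versus remote workers. A toy illustration of that bookkeeping; the actual client-side consumer of these fields is outside this diff:

    #include <cstdint>

    // Targets as reported by a client heartbeat response.
    struct HeartbeatTargets {
      int64_t target_local_workers = 0;
      int64_t target_remote_workers = 0;
    };

    // The client only opens another task connection of a given kind while it
    // is below the corresponding target.
    bool MayAddLocalTask(const HeartbeatTargets& t, int64_t open_local) {
      return open_local < t.target_local_workers;
    }
    bool MayAddRemoteTask(const HeartbeatTargets& t, int64_t open_remote) {
      return open_remote < t.target_remote_workers;
    }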