Skip to content

Commit 335ce65

Browse files
Added DSL contributions
1 parent e5036a6 commit 335ce65

15 files changed

Lines changed: 377 additions & 44 deletions

File tree

tensorflow/core/data/service/BUILD

Lines changed: 16 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -276,6 +276,7 @@ cc_library(
276276
deps = [
277277
":cache_utils",
278278
":scaling_utils",
279+
":local_workers_utils",
279280
":common_proto_cc",
280281
":credentials_factory",
281282
":data_service",
@@ -901,6 +902,21 @@ cc_library(
901902
],
902903
)
903904

905+
cc_library(
906+
name = "local_workers_utils",
907+
srcs = ["easl/local_workers_utils.cc"],
908+
hdrs = [
909+
"easl/local_workers_utils.h",
910+
],
911+
deps = [
912+
":common_proto_cc",
913+
":dispatcher_state",
914+
":metadata_store",
915+
"//tensorflow/core:lib",
916+
"@com_google_absl//absl/strings",
917+
],
918+
)
919+
904920
cc_library(
905921
name = "cache_model",
906922
srcs = ["easl/cache_model.cc"],

tensorflow/core/data/service/dispatcher.proto

Lines changed: 6 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -119,7 +119,7 @@ message JobKey {
119119
int64 job_name_index = 2;
120120
}
121121

122-
// Next tag: 8
122+
// Next tag: 9
123123
message GetOrCreateJobRequest {
124124
reserved 3, 4;
125125
// The id of the dataset to create a job for.
@@ -134,6 +134,8 @@ message GetOrCreateJobRequest {
134134
oneof optional_num_consumers {
135135
int64 num_consumers = 7;
136136
}
137+
// DSL - to pass the array of available local workers to the dispatcher's job creation logic
138+
repeated string local_workers = 8;
137139
}
138140

139141
// Next tag: 2
@@ -188,7 +190,7 @@ message ClientHeartbeatRequest {
188190
double result_queue_size = 10;
189191
}
190192

191-
// Next tag: 4
193+
// Next tag: 5
192194
message ClientHeartbeatResponse {
193195
// A list of all tasks that the client should read from.
194196
repeated TaskInfo task_info = 1;
@@ -198,6 +200,8 @@ message ClientHeartbeatResponse {
198200
}
199201
// Whether the job has finished.
200202
bool job_finished = 2;
203+
// DSL: to check whether we should use local workers (based on last epoch's metrics)
204+
bool target_local_workers = 4;
201205
}
202206

203207
// Next tag: 3

tensorflow/core/data/service/dispatcher_client.cc

Lines changed: 5 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -198,7 +198,8 @@ Status DataServiceDispatcherClient::RegisterDataset(
198198
Status DataServiceDispatcherClient::GetOrCreateJob(
199199
int64 dataset_id, ProcessingMode processing_mode,
200200
const absl::optional<JobKey>& job_key, absl::optional<int64> num_consumers,
201-
int64& job_client_id) {
201+
int64& job_client_id,
202+
std::vector<std::string> local_workers) {
202203
TF_RETURN_IF_ERROR(EnsureInitialized());
203204
GetOrCreateJobRequest req;
204205
req.set_dataset_id(dataset_id);
@@ -209,6 +210,9 @@ Status DataServiceDispatcherClient::GetOrCreateJob(
209210
if (num_consumers.has_value()) {
210211
req.set_num_consumers(num_consumers.value());
211212
}
213+
// DSL - client sends a list of its local workers to the dispatcher
214+
*req.mutable_local_workers() = {local_workers.begin(), local_workers.end()};
215+
212216
GetOrCreateJobResponse resp;
213217
grpc::ClientContext client_ctx;
214218
grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);

tensorflow/core/data/service/dispatcher_client.h

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -78,7 +78,8 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
7878
Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
7979
const absl::optional<JobKey>& job_key,
8080
absl::optional<int64> num_consumers,
81-
int64& job_client_id);
81+
int64& job_client_id,
82+
std::vector<std::string> local_workers);
8283

8384
// Releases a job client id, indicating that the id will no longer be used to
8485
// read from the job.

tensorflow/core/data/service/dispatcher_impl.cc

Lines changed: 42 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -44,6 +44,7 @@ limitations under the License.
4444
#include "tensorflow/core/data/service/easl/cache_utils.h"
4545
#include "tensorflow/core/data/service/easl/scaling_utils.h"
4646
#include "tensorflow/core/data/service/easl/metadata_store.h"
47+
#include "tensorflow/core/data/service/easl/local_workers_utils.h"
4748
#include "tensorflow/core/data/standalone.h"
4849
#include "tensorflow/core/framework/dataset.h"
4950
#include "tensorflow/core/framework/graph.pb.h"
@@ -955,7 +956,6 @@ Status DataServiceDispatcherImpl::CreateJob(
955956
int64 job_id = state_.NextAvailableJobId();
956957

957958
// EASL - Caching decision: should the job compute, write or read from cache?
958-
int64 worker_count;
959959
std::string job_type;
960960
string job_name = "named_job_key.value().name";
961961
std::shared_ptr<const Dataset> dataset;
@@ -993,28 +993,55 @@ Status DataServiceDispatcherImpl::CreateJob(
993993

994994
std::shared_ptr<easl::JobMetrics> job_metrics;
995995
s = metadata_store_.GetJobMetrics(job_id, job_metrics);
996-
worker_count = job_metrics->target_worker_count_;
997996

998-
if (config_.scaling_policy() == 2) {
999-
worker_count = 100;
997+
// EASL - Scaling decision: how many workers (remote/local) should the job assign?
998+
int64 total_workers = state_.ListWorkers().size();
999+
int64 suggested_worker_count = job_metrics->target_worker_count_;
1000+
1001+
int64 num_worker_remote_target, num_worker_local_target;
1002+
if(config_.scaling_policy() == 1) { // Paper autoscaling, except a discrimination between local and remote workers is now made
1003+
VLOG(0) << "EASL - Scalability decision for dataset_key "
1004+
<< compute_dataset_key << ": " << suggested_worker_count;
1005+
1006+
bool should_use_local_workers; // Do we have enough throughput to decide to use local workers to save network bandwidth?
1007+
TF_RETURN_IF_ERROR(service::easl::local_workers_utils::ShouldUseLocalWorkers(
1008+
config_, metadata_store_, compute_dataset_key, should_use_local_workers
1009+
));
1010+
1011+
if(should_use_local_workers && local_workers.size() >= 1) {
1012+
num_worker_remote_target = suggested_worker_count - 1;
1013+
num_worker_local_target = 1;
1014+
} else {
1015+
num_worker_remote_target = suggested_worker_count;
1016+
num_worker_local_target = 0;
1017+
}
1018+
} else if(config_.scaling_policy() == 2) { // Use all available workers
1019+
num_worker_remote_target = total_workers - local_workers.size();
1020+
num_worker_local_target = local_workers.size();
1021+
} else if(config_.scaling_policy() == 3) { // Grid search over local and remote workers
1022+
TF_RETURN_IF_ERROR(service::easl::local_workers_utils::DecideTargetWorkersGridSearch(
1023+
config_, metadata_store_, compute_dataset_key,
1024+
total_workers - local_workers.size(), local_workers.size(),
1025+
num_worker_remote_target, num_worker_local_target
1026+
));
10001027
}
10011028

10021029
if (job_type == "PUT" || job_type == "PUT_SOURCE") {
10031030
std::shared_ptr<easl::JobMetrics> dataset_fingerprint_metrics;
10041031
s = metadata_store_.GetJobMetricsByDatasetFingerprintAndName(
10051032
dataset_fingerprint, job_name, dataset_fingerprint_metrics);
10061033
if (s.ok()) {
1007-
worker_count = std::ceil(std::max(1.0,
1034+
suggested_worker_count = std::ceil(std::max(1.0,
10081035
dataset_fingerprint_metrics->target_worker_count_ * 1.5));
10091036
}
1010-
job_metrics->target_worker_count_ = worker_count;
1037+
job_metrics->target_worker_count_ = suggested_worker_count;
10111038
}
10121039

10131040
// EASL: Logging stuff
10141041
if (kEnableEventLogging) {
1015-
last_scale_[job_name] = worker_count;
1042+
last_scale_[job_name] = suggested_worker_count;
10161043
RecordEvent(dataset_fingerprint, dataset_id, job_name, job_id,
1017-
"starting_worker_count", std::to_string(worker_count));
1044+
"starting_worker_count", std::to_string(suggested_worker_count));
10181045
}
10191046

10201047
int64 num_split_providers = 0;
@@ -1031,7 +1058,10 @@ Status DataServiceDispatcherImpl::CreateJob(
10311058
create_job->set_processing_mode(ProcessingModeDef(processing_mode));
10321059
create_job->set_job_type(job_type);
10331060
create_job->set_num_split_providers(num_split_providers);
1034-
create_job->set_target_worker_count(worker_count);
1061+
create_job->set_target_worker_count(suggested_worker_count);
1062+
create_job->set_target_local_workers(target_local_workers);
1063+
create_job->set_target_remote_workers(target_remote_workers);
1064+
*create_job->mutable_local_workers() = {request.local_workers().begin(), request.local_workers().end()};
10351065
if (named_job_key.has_value()) {
10361066
NamedJobKeyDef* key = create_job->mutable_named_job_key();
10371067
key->set_name(named_job_key->name);
@@ -1090,7 +1120,7 @@ Status DataServiceDispatcherImpl::CreateTasksForJob(
10901120
std::vector<std::shared_ptr<const Task>>& tasks)
10911121
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
10921122
std::vector<std::shared_ptr<const Worker>> workers = state_.ReserveWorkers(
1093-
job->job_id, job->target_worker_count);
1123+
job->job_id, job->target_worker_count, job->target_remote_workers, job->target_local_workers, job->local_workers);
10941124
if (workers.size() < job->target_worker_count){
10951125
VLOG(0)
10961126
<< "EASL - Not enough workers for job. Elasticity policy requires "
@@ -1384,6 +1414,8 @@ Status DataServiceDispatcherImpl::ClientHeartbeat(
13841414
task_info->set_starting_round(task->starting_round);
13851415
}
13861416
response->set_job_finished(job->finished);
1417+
response->set_target_local_workers(job->target_local_workers);
1418+
response->set_target_remote_workers(job->target_remote_workers);
13871419
VLOG(4) << "Found " << response->task_info_size()
13881420
<< " tasks for job client id " << request->job_client_id();
13891421

tensorflow/core/data/service/dispatcher_state.cc

Lines changed: 52 additions & 13 deletions
Original file line number · Diff line number · Diff line change
@@ -130,12 +130,20 @@ void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
130130
ProcessingMode(create_job.processing_mode()),
131131
create_job.num_split_providers(),
132132
named_job_key, num_consumers,
133-
create_job.job_type(), create_job.target_worker_count());
133+
create_job.job_type(),
134+
create_job.target_worker_count(),
135+
create_job.target_remote_workers(),
136+
create_job.target_local_workers());
134137
DCHECK(!jobs_.contains(job_id));
135138
jobs_[job_id] = job;
136139
tasks_by_job_[job_id] = TasksById();
137140
ending_tasks_by_job_[job_id] = TasksById();
138141

142+
for (auto worker: create_job.local_workers()) {
143+
VLOG(1) << "EASL-DSL (DispatcherState::CreateJob): adding local worker to dispatcher's state: " << worker;
144+
job->local_workers.insert(worker);
145+
}
146+
139147
if (named_job_key.has_value()) {
140148
DCHECK(!named_jobs_.contains(named_job_key.value()) ||
141149
named_jobs_[named_job_key.value()]->garbage_collected);
@@ -373,29 +381,60 @@ DispatcherState::ListAvailableWorkers() const {
373381
return workers;
374382
}
375383

384+
// Reserves a number of available workers for a particular job. If num_workers
385+
// is lower than or equal to 0, then the reserved number of workers is equal
386+
// to all the available workers.
376387
std::vector<std::shared_ptr<const DispatcherState::Worker>>
377388
DispatcherState::ReserveWorkers(
378-
int64 job_id, int64 target_num_workers) {
379-
// DCHECK(num_workers <= avail_workers_.size());
380-
jobs_[job_id]->target_worker_count = target_num_workers;
389+
int64 job_id,
390+
int64 target_remote_workers,
391+
int64 target_local_workers,
392+
const absl::flat_hash_set<std::string> local_workers) {
381393
// If the number of required workers is below those available, we just assign
382394
// as many as there are available at this epoch's scheduling time.
383-
int64 num_workers = target_num_workers <= 0
384-
|| target_num_workers > avail_workers_.size() ? avail_workers_.size()
385-
: target_num_workers;
395+
int64 target_worker_count = target_remote_workers + target_local_workers;
396+
if(target_worker_count <= 0 || target_worker_count > avail_workers_.size()) {
397+
target_remote_workers = avail_workers_.size();
398+
target_local_workers = avail_workers_.size();
399+
}
400+
401+
jobs_[job_id]->target_worker_count = target_worker_count;
402+
386403
std::vector<std::shared_ptr<const Worker>> workers;
387-
workers.reserve(num_workers);
388-
VLOG(0) << "(ReserveWorkers) User got " << num_workers << " workers from "
389-
<< "target " << target_num_workers << " workers";
404+
workers.reserve(avail_workers_.size());
405+
VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers)" << "\n"
406+
<< "Available remote: " << avail_workers_.size() << "\n"
407+
<< "Available local: " << local_workers.size() << "\n"
408+
<< "Target remote: " << target_remote_workers << "\n"
409+
<< "Target local: " << target_local_workers << "\n";
410+
390411
for (auto it = avail_workers_.begin(); it != avail_workers_.end(); ) {
391-
num_workers--;
412+
bool is_local = local_workers.count(it->first);
413+
if (is_local) {
414+
VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers) found local worker " << it->first;
415+
if (target_local_workers <= 0) { // No additional local workers needed
416+
it++;
417+
continue;
418+
} else {
419+
target_local_workers--;
420+
}
421+
} else {
422+
VLOG(0) << "EASL-DSL (DispatcherState::ReserveWorkers) found remote worker " << it->first;
423+
if (target_remote_workers <= 0) { // No additional remote workers needed
424+
it++;
425+
continue;
426+
} else {
427+
target_remote_workers--;
428+
}
429+
}
430+
392431
workers.push_back(it->second);
393432
VLOG(0) << "(ReserveWorkers) Assigning worker at address "
394433
<< it->second->address << " to job " << job_id;
395-
workers_by_job_[job_id][it->second->address] = it->second;
434+
workers_by_job_[job_id].push_back(it->second);
396435
jobs_by_worker_[it->second->address][job_id] = jobs_[job_id];
397436
avail_workers_.erase(it++);
398-
if (num_workers == 0)
437+
if (target_worker_count == 0)
399438
break;
400439
}
401440
VLOG(0) << "(ReserveWorkers) Number of workers for job " << job_id << " is: "

tensorflow/core/data/service/dispatcher_state.h

Lines changed: 16 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -137,14 +137,21 @@ class DispatcherState {
137137
int64 num_split_providers,
138138
absl::optional<NamedJobKey> named_job_key,
139139
absl::optional<int64> num_consumers, const std::string& job_type,
140-
int64 target_worker_count)
140+
int64 target_worker_count,
141+
int64 target_remote_workers,
142+
int64 target_local_workers,
143+
absl::flat_hash_set<std::string> local_workers = {}
144+
)
141145
: job_id(job_id),
142146
dataset_id(dataset_id),
143147
processing_mode(processing_mode),
144148
named_job_key(named_job_key),
145149
num_consumers(num_consumers),
146150
job_type(job_type),
147-
target_worker_count(target_worker_count){
151+
target_worker_count(target_worker_count),
152+
target_remote_workers(target_remote_workers),
153+
target_local_workers(target_local_workers),
154+
local_workers(local_workers){
148155
if (processing_mode == ProcessingMode::DISTRIBUTED_EPOCH) {
149156
distributed_epoch_state = DistributedEpochState(num_split_providers);
150157
}
@@ -176,6 +183,10 @@ class DispatcherState {
176183
const std::string job_type;
177184
int64 target_worker_count; // Non-constant, can be dynamically adjusted.
178185
int64 current_worker_count = 0;
186+
// EASL - DSL
187+
const int64 target_remote_workers; // replaces worker_count as there is a distinction now
188+
const int64 target_local_workers; // replaces worker_count as there is a distinction now
189+
absl::flat_hash_set<std::string> local_workers; // list of local workers in the client
179190
};
180191

181192
struct Task {
@@ -228,7 +239,9 @@ class DispatcherState {
228239
// is lower than or equal to 0, then the reserved number of workers is equal
229240
// to all the available workers.
230241
std::vector<std::shared_ptr<const Worker>> ReserveWorkers(int64 job_id,
231-
int64 num_workers = 0);
242+
int64 target_remote_workers = 0,
243+
int64 target_local_workers = 0,
244+
const absl::flat_hash_set<std::string> local_workers = {});
232245

233246
// Returns the next available job id.
234247
int64 NextAvailableJobId() const;

0 commit comments

Comments (0)