Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 37 additions & 7 deletions src/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,22 +203,35 @@ def _reset_timeline_state(self, start_index):
self.next_empty_slot = 0
self.baseline_next_empty_slot = 0

# Versioned cache invalidation for pending job stats.
self._queue_backlog_version = 0
self._cached_queue_backlog_version = -1

def _mark_queue_backlog_mutation(self):
"""Invalidate pending-job stats cache after queue/backlog content changes."""
self._queue_backlog_version += 1

def _update_pending_job_stats(self, job_queue_2d):
"""Update summary statistics for all outstanding jobs (queue + backlog)."""
# Fast path: skip recalculation if queue/backlog version is unchanged.
if self._cached_queue_backlog_version == self._queue_backlog_version:
return # Stats unchanged from last step

# Slow path: recalculate pending stats after queue/backlog mutations.
# Collect stats from the main queue
current_backlog_size = len(self.backlog_queue)
active_jobs_mask = job_queue_2d[:, 0] > 0
queue_durations = job_queue_2d[active_jobs_mask, 0]
queue_nodes = job_queue_2d[active_jobs_mask, 2]
queue_cores = job_queue_2d[active_jobs_mask, 3]
queue_count = len(queue_durations)

# Collect stats from the backlog
backlog_count = len(self.backlog_queue)
backlog_count = current_backlog_size
if backlog_count > 0:
backlog_arr = np.array(list(self.backlog_queue))
backlog_durations = backlog_arr[:, 0]
backlog_nodes = backlog_arr[:, 2]
backlog_cores = backlog_arr[:, 3]
backlog_durations = np.array([job[0] for job in self.backlog_queue], dtype=np.int32)
backlog_nodes = np.array([job[2] for job in self.backlog_queue], dtype=np.int32)
backlog_cores = np.array([job[3] for job in self.backlog_queue], dtype=np.int32)
else:
backlog_durations = np.array([], dtype=np.int32)
backlog_nodes = np.array([], dtype=np.int32)
Expand Down Expand Up @@ -247,6 +260,9 @@ def _update_pending_job_stats(self, job_queue_2d):
self.state['pending_max_nodes'][0] = max_nodes
self.state['backlog_size'][0] = backlog_count

# Cache the queue/backlog version for next step.
self._cached_queue_backlog_version = self._queue_backlog_version

def reset(self, seed=None, options=None):
if options is None:
options = {}
Expand Down Expand Up @@ -305,17 +321,22 @@ def step(self, action):

# reshape the 1d job_queue array into 2d for cleaner code
job_queue_2d = self.state['job_queue'].reshape(-1, 4)
queue_backlog_mutated = False

# Decrement booked time for nodes and complete running jobs
self.env_print("[1] Processing ongoing jobs...")
completed_jobs = process_ongoing_jobs(self.state['nodes'], self.cores_available, self.running_jobs)
self.env_print(f"{len(completed_jobs)} jobs completed: [{' '.join(['#' + str(job_id) for job_id in completed_jobs]) if len(completed_jobs) > 0 else ''}]")

# Age helper queues (jobs waiting outside the fixed queue)
age_backlog_queue(self.backlog_queue, self.metrics, _is_baseline=False)
backlog_aged_dropped = age_backlog_queue(self.backlog_queue, self.metrics, _is_baseline=False)
if backlog_aged_dropped > 0:
queue_backlog_mutated = True

# Fill real queue from helper before accepting new jobs
self.next_empty_slot, _ = fill_queue_from_backlog(job_queue_2d, self.backlog_queue, self.next_empty_slot)
self.next_empty_slot, moved_from_backlog = fill_queue_from_backlog(job_queue_2d, self.backlog_queue, self.next_empty_slot)
if moved_from_backlog > 0:
queue_backlog_mutated = True

# Generate new jobs
self.env_print(f"[2] Generating new jobs...")
Expand All @@ -332,6 +353,8 @@ def step(self, action):
job_queue_2d, new_jobs_count, new_jobs_durations,
new_jobs_nodes, new_jobs_cores, self.next_empty_slot, self.backlog_queue
)
if len(new_jobs) > 0:
queue_backlog_mutated = True
if backlog_dropped > 0:
self.metrics.jobs_dropped += backlog_dropped
self.metrics.episode_jobs_dropped += backlog_dropped
Expand All @@ -358,6 +381,8 @@ def step(self, action):
job_queue_2d, self.state['nodes'], self.cores_available, self.running_jobs,
self.next_empty_slot, self.next_job_id, self.metrics, is_baseline=False
)
if num_launched_jobs > 0 or queue_dropped > 0:
queue_backlog_mutated = True
num_dropped_this_step += queue_dropped

self.env_print(f" {num_launched_jobs} jobs launched")
Expand All @@ -366,18 +391,23 @@ def step(self, action):
if do_refill == 1 and len(self.backlog_queue) > 0:
self.next_empty_slot, moved = fill_queue_from_backlog(job_queue_2d, self.backlog_queue, self.next_empty_slot)
if moved > 0:
queue_backlog_mutated = True
self.env_print(f" {moved} jobs moved from backlog to queue")
# Try to assign the newly queued jobs
extra_launched, self.next_empty_slot, extra_dropped, self.next_job_id = assign_jobs_to_available_nodes(
job_queue_2d, self.state['nodes'], self.cores_available, self.running_jobs,
self.next_empty_slot, self.next_job_id, self.metrics, is_baseline=False
)
if extra_launched > 0 or extra_dropped > 0:
queue_backlog_mutated = True
num_launched_jobs += extra_launched
num_dropped_this_step += extra_dropped
if extra_launched > 0:
self.env_print(f" {extra_launched} additional jobs launched from backlog")

# Update summary statistics for all outstanding jobs (queue + backlog)
if queue_backlog_mutated:
self._mark_queue_backlog_mutation()
self._update_pending_job_stats(job_queue_2d)

# Calculate node utilization stats
Expand Down
2 changes: 1 addition & 1 deletion src/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def plot(env, num_hours, max_nodes, save=True, show=True, suffix=""):
lines2, labels2 = ax2.get_legend_handles_labels()
ax1.legend(lines + lines2, labels + labels2, loc='upper left')

prefix = f"e{env.weights.efficiency_weight}_p{env.weights.price_weight}_i{env.weights.idle_weight}_d{env.weights.job_age_weight}"
prefix = f"e{env.weights.efficiency_weight}_p{env.weights.price_weight}_i{env.weights.idle_weight}_a{env.weights.job_age_weight}_d{env.weights.drop_weight}"

if save:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
Expand Down
10 changes: 9 additions & 1 deletion train_iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,15 @@ def generate_weight_combinations(step=0.1, fixed_weights=None):
variable_weights = [w for w in weight_names if w not in fixed_weights]
fixed_sum = sum(fixed_weights.values())

if len(variable_weights) == 1:
if len(variable_weights) == 0:
# If all weights are fixed, return that single combination
if abs(fixed_sum - 1.0) < 1e-9: # Allow for floating point rounding
combo = [0, 0, 0, 0, 0]
for weight_name, value in fixed_weights.items():
combo[weight_names.index(weight_name)] = value
combinations.append(tuple(combo))

elif len(variable_weights) == 1:
# If all but one weight is fixed, there's only one possible value
remaining = round(1 - fixed_sum, 2)
if 0 <= remaining <= 1:
Expand Down