Skip to content

Commit 0338183

Browse files
committed
train_iter.py: Added condition to allow for fixed total weight-combination
1 parent cb85ec9 commit 0338183

1 file changed

Lines changed: 9 additions & 1 deletion

File tree

train_iter.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,15 @@ def generate_weight_combinations(step=0.1, fixed_weights=None):
1515
variable_weights = [w for w in weight_names if w not in fixed_weights]
1616
fixed_sum = sum(fixed_weights.values())
1717

18-
if len(variable_weights) == 1:
18+
if len(variable_weights) == 0:
19+
# If all weights are fixed, return that single combination
20+
if abs(fixed_sum - 1.0) < 1e-9: # Allow for floating point rounding
21+
combo = [0, 0, 0, 0, 0]
22+
for weight_name, value in fixed_weights.items():
23+
combo[weight_names.index(weight_name)] = value
24+
combinations.append(tuple(combo))
25+
26+
elif len(variable_weights) == 1:
1927
# If all but one weight is fixed, there's only one possible value
2028
remaining = round(1 - fixed_sum, 2)
2129
if 0 <= remaining <= 1:

0 commit comments

Comments
 (0)