We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent cb85ec9 commit 0338183Copy full SHA for 0338183
1 file changed
train_iter.py
@@ -15,7 +15,15 @@ def generate_weight_combinations(step=0.1, fixed_weights=None):
15
variable_weights = [w for w in weight_names if w not in fixed_weights]
16
fixed_sum = sum(fixed_weights.values())
17
18
- if len(variable_weights) == 1:
+ if len(variable_weights) == 0:
19
+ # If all weights are fixed, return that single combination
20
+ if abs(fixed_sum - 1.0) < 1e-9: # Allow for floating point rounding
21
+ combo = [0, 0, 0, 0, 0]
22
+ for weight_name, value in fixed_weights.items():
23
+ combo[weight_names.index(weight_name)] = value
24
+ combinations.append(tuple(combo))
25
+
26
+ elif len(variable_weights) == 1:
27
# If all but one weight is fixed, there's only one possible value
28
remaining = round(1 - fixed_sum, 2)
29
if 0 <= remaining <= 1:
0 commit comments