@@ -30,16 +30,42 @@ def from_dict(cls, config: dict[str, Any] | None) -> "Controller":
3030 merged = {key : raw .get (key , default ) for key , default in defaults .items ()}
3131 return cls (** merged )
3232
33- def prepare_block (self , * , block_idx : int , n_trials : int , conditions : list [str ] | None ) -> list [str ]:
33+ def prepare_block (
34+ self ,
35+ * ,
36+ block_idx : int ,
37+ n_trials : int ,
38+ conditions : list [str ] | None ,
39+ condition_weights : list [float ] | None = None ,
40+ ) -> list [str ]:
3441 labels = [str (c ) for c in (conditions or []) if str (c ).strip ()]
3542 if not labels :
3643 labels = ["default" ]
3744
3845 trial_count = max (1 , int (n_trials ))
39- schedule = [labels [i % len (labels )] for i in range (trial_count )]
46+ rng = random .Random (self .seed + int (block_idx ) * 1009 )
47+
48+ if condition_weights is None :
49+ schedule = [labels [i % len (labels )] for i in range (trial_count )]
50+ else :
51+ if len (condition_weights ) != len (labels ):
52+ raise ValueError (
53+ "condition_weights length mismatch for labels "
54+ f"{ labels } : expected { len (labels )} , got { len (condition_weights )} "
55+ )
56+ total_w = sum (float (w ) for w in condition_weights )
57+ raw = [trial_count * float (w ) / total_w for w in condition_weights ]
58+ counts = [int (x ) for x in raw ]
59+ rem = trial_count - sum (counts )
60+ if rem > 0 :
61+ extra = rng .choices (labels , weights = condition_weights , k = rem )
62+ for lbl in extra :
63+ counts [labels .index (lbl )] += 1
64+ schedule = []
65+ for lbl , cnt in zip (labels , counts ):
66+ schedule .extend ([lbl ] * cnt )
4067
4168 if self .shuffle and len (schedule ) > 1 :
42- rng = random .Random (self .seed + int (block_idx ) * 1009 )
4369 rng .shuffle (schedule )
4470
4571 if self .enable_logging :
0 commit comments