|
19 | 19 |
|
20 | 20 | # Helpers for rambo |
21 | 21 | events_signature = tf.TensorSpec(shape=[None, 1], dtype=DTYPE) |
| 22 | +events_signature_clean = tf.TensorSpec(shape=[None], dtype=DTYPE) |
22 | 23 | p_signature = tf.TensorSpec(shape=[None, 4], dtype=DTYPE) |
23 | 24 | ps_signature = tf.TensorSpec(shape=[None, None, 4], dtype=DTYPE) |
24 | 25 |
|
@@ -118,6 +119,7 @@ def _conformal_transformation(input_q, bquad): |
118 | 119 | return tf.concat([pnrg, pvec], axis=1) # (n_events, 4) |
119 | 120 |
|
120 | 121 |
|
| 122 | +@tf.function(input_signature=[p_signature]) |
121 | 123 | def _gen_unconstrained_momenta(xrand): |
122 | 124 | """ |
123 | 125 | Generates unconstrained 4-momenta |
@@ -212,6 +214,7 @@ def rambo(xrand, n_particles, sqrts, masses=None, check_physical=False): |
212 | 214 | return massive_p, wt |
213 | 215 |
|
214 | 216 |
|
| 217 | +@tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=DTYPE)] + 2*[tf.TensorSpec(shape=[], dtype=DTYPE)]) |
215 | 218 | def _get_x1x2(xarr, shat_min, s_in): |
216 | 219 | """Receives two random numbers and return the |
217 | 220 | value of the invariant mass of the center of mass |
@@ -291,6 +294,7 @@ def ramboflow(xrand, nparticles, com_sqrts, masses=None): |
291 | 294 | return final_p, wgt, x1, x2 |
292 | 295 |
|
293 | 296 |
|
| 297 | +@tf.function(input_signature=[ps_signature] + [events_signature_clean]*2) |
294 | 298 | def _boost_to_lab(p_com, x1, x2): |
295 | 299 | """Boost the momenta back from the COM frame of the initial partons |
296 | 300 | to the lab frame |
|
0 commit comments