Skip to content

Commit 2e689d0

Browse files
committed
final
1 parent c599323 commit 2e689d0

2 files changed

Lines changed: 13 additions & 5 deletions

File tree

madgraph_plugin/template_files/matrix_method_python.inc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ class Matrix_%(process_string)s(object):
7777
def __str__(self):
7878
return "%(process_string)s"
7979

80+
@tf.function(input_signature=[tf.TensorSpec(shape=[None, %(nexternal)d], dtype=DTYPE)]+smatrix_signature)
81+
def wrapper(self, all_hel, all_ps, %(params)s):
82+
nevts = tf.shape(all_ps, out_type=DTYPEINT)[0]
83+
ans = tf.zeros(nevts, dtype=DTYPE)
84+
for hel in self.helicities:
85+
ans += self.matrix(all_ps,hel,%(params)s)
86+
return ans
87+
8088
@tf.function(input_signature=smatrix_signature)
8189
def smatrix(self,all_ps,%(params)s):
8290
#
@@ -96,12 +104,8 @@ class Matrix_%(process_string)s(object):
96104
# ----------
97105
# BEGIN CODE
98106
# ----------
99-
nevts = tf.shape(all_ps, out_type=DTYPEINT)[0]
100-
ans = tf.zeros(nevts, dtype=DTYPECOMPLEX)
101-
for hel in self.helicities:
102-
ans += self.matrix(all_ps,hel,%(params)s)
103107

104-
return ans/self.denominator
108+
return self.wrapper(self.helicities, all_ps, %(params)s)/self.denominator
105109

106110
@tf.function(input_signature=matrix_signature)
107111
def matrix(self,all_ps,hel,%(params)s):

python_package/madflow/phasespace.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# Helpers for rambo
2121
events_signature = tf.TensorSpec(shape=[None, 1], dtype=DTYPE)
22+
events_signature_clean = tf.TensorSpec(shape=[None], dtype=DTYPE)
2223
p_signature = tf.TensorSpec(shape=[None, 4], dtype=DTYPE)
2324
ps_signature = tf.TensorSpec(shape=[None, None, 4], dtype=DTYPE)
2425

@@ -118,6 +119,7 @@ def _conformal_transformation(input_q, bquad):
118119
return tf.concat([pnrg, pvec], axis=1) # (n_events, 4)
119120

120121

122+
@tf.function(input_signature=[p_signature])
121123
def _gen_unconstrained_momenta(xrand):
122124
"""
123125
Generates unconstrained 4-momenta
@@ -212,6 +214,7 @@ def rambo(xrand, n_particles, sqrts, masses=None, check_physical=False):
212214
return massive_p, wt
213215

214216

217+
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=DTYPE)] + 2*[tf.TensorSpec(shape=[], dtype=DTYPE)])
215218
def _get_x1x2(xarr, shat_min, s_in):
216219
"""Receives two random numbers and return the
217220
value of the invariant mass of the center of mass
@@ -291,6 +294,7 @@ def ramboflow(xrand, nparticles, com_sqrts, masses=None):
291294
return final_p, wgt, x1, x2
292295

293296

297+
@tf.function(input_signature=[ps_signature] + [events_signature_clean]*2)
294298
def _boost_to_lab(p_com, x1, x2):
295299
"""Boost the momenta back from the COM frame of the initial partons
296300
to the lab frame

0 commit comments

Comments
 (0)