We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 804df7e commit 8c65687Copy full SHA for 8c65687
1 file changed
meshmode/array_context.py
@@ -602,6 +602,27 @@ def _get_fake_numpy_namespace(self):
602
def transform_dag(self, dag):
603
from pytato.array import Einsum
604
605
+ # {{{ face_mass: materialize einsum args
606
+
607
+ def materialize_face_mass_vec(expr):
608
+ if isinstance(expr, pt.Einsum):
609
+ my_tag, = expr.tags_of_type(pt.tags.EinsumInfo)
610
+ if my_tag.spec == "ifj,fej,fej->ei":
611
+ mat, jac, vec = expr.args
612
+ return pt.einsum("ifj,fej,fej->ei",
613
+ mat,
614
+ jac,
615
+ vec.tagged(pt.tags
616
+ .ImplementAs(pt.tags.ImplStored())))
617
+ else:
618
+ return expr
619
620
621
622
+ dag = pt.transform.map_and_copy(dag, materialize_face_mass_vec)
623
624
+ # }}}
625
626
# {{{ materialize
627
628
nusers = pt.analysis.get_nusers(dag)
0 commit comments