Commit f095801

vmap conv

1 parent e1bd60b commit f095801

2 files changed
Lines changed: 158 additions & 44 deletions

src/MaxText/layers/engram.py

Lines changed: 133 additions & 40 deletions
@@ -325,6 +325,112 @@ def __call__(self, input_ids: jax.Array, model_mode: str = MODEL_MODE_TRAIN) ->
     return self.embedding(shifted_ids, model_mode=model_mode)
 
 
+# class ShortConv(nnx.Module):
+#   """
+#   Implements a Grouped Depthwise Causal Convolution block.
+
+#   This module applies local temporal mixing (smoothing) to the retrieved embeddings.
+#   - It uses independent RMSNorms for each branch
+#   - followed by a 1D convolution. Note it is depth-wise:
+#     mixes information across time steps [t-k, t] without mixing across channels.
+
+#   Shape Legend:
+#     B: Batch Size
+#     L: Sequence Length
+#     G: Number of Branches (hc_mult) - logical grouping of heads
+#     C: Embedding Dimension per Branch (hidden_size)
+
+#   Note on Convolution:
+#     Conv1D - (G * C) as the total number of input channels.
+#     Depthwise - It applies a separate filter to every single dimension in C, for every group G.
+#   """
+
+#   def __init__(
+#       self,
+#       config,
+#       hidden_size: int,
+#       kernel_size: int = 4,  # Temporal Window Size
+#       dilation: int = 1,
+#       hc_mult: int = 4,
+#       activation: bool = True,
+#       rngs: nnx.Rngs = None,
+#   ):
+#     self.hc_mult = hc_mult
+#     self.activation = activation
+#     total_channels = hidden_size * hc_mult
+#     self.rngs = rngs
+
+#     # Depthwise Convolution:
+#     # Setting feature_group_count = in_features ensures each channel is convolved
+#     # independently. This learns local temporal patterns per feature.
+#     # Padding="CAUSAL" ensures output[t] only depends on input[t-k : t].
+#     self.conv = nnx.Conv(
+#         in_features=total_channels,
+#         out_features=total_channels,
+#         kernel_size=(kernel_size,),
+#         feature_group_count=total_channels,  # Depthwise
+#         kernel_dilation=(dilation,),
+#         padding="CAUSAL",  # To match the slice [..., :T] logic
+#         use_bias=False,
+#         rngs=rngs,
+#     )
+
+#     # Independent RMSNorm for each branch.
+#     # TODO(shuningjin): eps, epsilon=config.normalization_layer_epsilon,
+#     # epsilon=1e-5,  # Match PyTorch default
+#     self.norms = nnx.List(
+#         [
+#             RMSNorm(
+#                 num_features=hidden_size,
+#                 dtype=config.dtype,
+#                 weight_dtype=config.weight_dtype,
+#                 kernel_axes=("norm",),
+#                 epsilon=1e-5,
+#                 rngs=self.rngs,
+#             )
+#             for _ in range(hc_mult)
+#         ]
+#     )
+
+#     self.act_fn = jax.nn.silu if activation else lambda x: x
+
+#   def __call__(self, x: jax.Array) -> jax.Array:
+#     """
+#     y = SiLU(Conv1D(RMSNorm(x)))
+
+#     Args:
+#       x: Input tensor of shape (B, L, G, C)
+#     Returns:
+#       Tensor of shape (B, L, G, C)
+
+#     Note: G = hc_mult
+#     """
+#     B, L, G, C = x.shape
+
+#     # 1. Apply Group-wise Normalization
+#     # We iterate over the 'Groups' dimension (axis 2)
+#     normed_chunks = []
+#     for i in range(G):
+#       norm = self.norms[i]
+#       # shape: (B, L, C)
+#       x_chunk = x[:, :, i, :]
+#       normed_chunks.append(norm(x_chunk))
+
+#     # 2. Flatten Groups for Convolution
+#     # (B, L, C) x G -> (B, L, G * C)
+#     x_flat = jnp.concatenate(normed_chunks, axis=-1)
+
+#     # 3. Apply Depthwise Causal Conv
+#     # Mixes temporal dimension L. Channels remain independent.
+#     # Shape stays: (B, L, G * C)
+#     y = self.conv(x_flat)
+#     y = self.act_fn(y)
+
+#     # 4. Reshape back to Branched Layout
+#     # (B, L, G * C) -> (B, L, G, C)
+#     return y.reshape(B, L, G, C)
+
+
 class ShortConv(nnx.Module):
   """
   Implements a Grouped Depthwise Causal Convolution block.
@@ -356,14 +462,10 @@ def __init__(
       rngs: nnx.Rngs = None,
   ):
     self.hc_mult = hc_mult
-    self.activation = activation
     total_channels = hidden_size * hc_mult
-    self.rngs = rngs
 
-    # Depthwise Convolution:
-    # Setting feature_group_count = in_features ensures each channel is convolved
-    # independently. This learns local temporal patterns per feature.
-    # Padding="CAUSAL" ensures output[t] only depends on input[t-k : t].
+    # A: Single Shared Convolution
+    # Note: feature_group_count=total_channels makes this Depthwise (channels don't mix)
     self.conv = nnx.Conv(
         in_features=total_channels,
         out_features=total_channels,
@@ -375,22 +477,21 @@ def __init__(
         rngs=rngs,
     )
 
-    # Independent RMSNorm for each branch.
-    # TODO(shuningjin): eps, epsilon=config.normalization_layer_epsilon,
-    # epsilon=1e-5,  # Match PyTorch default
-    self.norms = nnx.List(
-        [
-            RMSNorm(
-                num_features=hidden_size,
-                dtype=config.dtype,
-                weight_dtype=config.weight_dtype,
-                kernel_axes=("norm",),
-                epsilon=1e-5,
-                rngs=self.rngs,
-            )
-            for _ in range(hc_mult)
-        ]
-    )
+    # B: Vectorized Norms (One unique module per group)
+    @nnx.split_rngs(splits=hc_mult)
+    @nnx.vmap(in_axes=0, out_axes=0)
+    def create_norms(r):
+      return RMSNorm(
+          num_features=hidden_size,
+          dtype=config.dtype,
+          weight_dtype=config.weight_dtype,
+          kernel_axes=("norm",),
+          epsilon=1e-5,
+          rngs=r,
+      )
+
+    # 'norms' now holds weights of shape (hc_mult, hidden_size)
+    self.norms = create_norms(rngs)
 
     self.act_fn = jax.nn.silu if activation else lambda x: x
 
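For reference, a minimal self-contained sketch of the creation pattern used in the hunk above. It uses flax.nnx's stock nnx.RMSNorm rather than MaxText's RMSNorm, and the sizes are illustrative, not taken from any config: nnx.split_rngs forks the RNG stream into hc_mult independent streams, and nnx.vmap over the constructor returns one stacked module whose per-group scale parameters live in a single array of shape (hc_mult, hidden_size).

import jax.numpy as jnp
from flax import nnx

hc_mult, hidden_size = 4, 32  # illustrative values only

# One call builds hc_mult independent norms; every parameter gains a leading
# axis of size hc_mult because the constructor runs under nnx.vmap.
@nnx.split_rngs(splits=hc_mult)
@nnx.vmap(in_axes=0, out_axes=0)
def create_norms(rngs: nnx.Rngs):
  return nnx.RMSNorm(num_features=hidden_size, epsilon=1e-5, rngs=rngs)

norms = create_norms(nnx.Rngs(0))
print(norms.scale.value.shape)  # (4, 32): one scale vector per group
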
407508
"""
408509
B, L, G, C = x.shape
409510

410-
# 1. Apply Group-wise Normalization
411-
# We iterate over the 'Groups' dimension (axis 2)
412-
normed_chunks = []
413-
for i in range(G):
414-
norm = self.norms[i]
415-
# shape: (B, L, C)
416-
x_chunk = x[:, :, i, :]
417-
normed_chunks.append(norm(x_chunk))
418-
419-
# 2. Flatten Groups for Convolution
420-
# (B, L, C) x G -> (B, L, G * C)
421-
x_flat = jnp.concatenate(normed_chunks, axis=-1)
422-
423-
# 3. Apply Depthwise Causal Conv
424-
# Mixes temporal dimension L. Channels remain independent.
425-
# Shape stays: (B, L, G * C)
511+
# 1. Apply Norms (Vectorized over Group dim)
512+
# in_axes=(0, 2): norms is axis 0, x is axis 2
513+
# out_axes=2: put the group dim back at axis 2
514+
x = nnx.vmap(lambda m, val: m(val), in_axes=(0, 2), out_axes=2)(self.norms, x)
515+
516+
# 2. Flatten Groups for Conv
517+
x_flat = x.reshape(B, L, G * C)
518+
519+
# 3. Apply Single Conv
426520
y = self.conv(x_flat)
427521
y = self.act_fn(y)
428522

429-
# 4. Reshape back to Branched Layout
430-
# (B, L, G * C) -> (B, L, G, C)
523+
# 4. Reshape back
431524
return y.reshape(B, L, G, C)
432525

433526

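Continuing the sketch above (same imports and the stacked norms object), the call-side counterpart of this hunk pairs axis 0 of the module stack with the group axis of the activations, then flattens the groups for the single depthwise convolution. The lambda and axis choices mirror the hunk; the shapes are illustrative.

B, L = 2, 16
x = jnp.ones((B, L, hc_mult, hidden_size))  # (B, L, G, C)

# Norm i is applied to branch i: modules mapped over axis 0, input over axis 2,
# and the group axis is restored at position 2 of the output.
x = nnx.vmap(lambda m, val: m(val), in_axes=(0, 2), out_axes=2)(norms, x)
print(x.shape)  # (2, 16, 4, 32)

# Flatten the groups so one depthwise conv sees G * C input channels.
x_flat = x.reshape(B, L, hc_mult * hidden_size)
print(x_flat.shape)  # (2, 16, 128)
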
tests/unit/engram_vs_reference_test.py

Lines changed: 25 additions & 4 deletions
@@ -607,10 +607,30 @@ def to_nnx_list_dict(weight_list):
   return {i: w for i, w in enumerate(weight_list)}
 
 
+# def get_shortconv_weights(pt_layer):
+#   conv_weight = pt_layer.conv.weight.permute(2, 1, 0)
+#   short_conv_norms = [{"scale": to_jax(n.weight)} for n in pt_layer.norms]
+#   return {"conv": {"kernel": to_jax(conv_weight)}, "norms": to_nnx_list_dict(short_conv_norms)}
+
+
 def get_shortconv_weights(pt_layer):
+  # 1. Conv Weights
+  # PyTorch: (Out, In/Groups, K) -> JAX: (K, In/Groups, Out)
   conv_weight = pt_layer.conv.weight.permute(2, 1, 0)
-  short_conv_norms = [{"scale": to_jax(n.weight)} for n in pt_layer.norms]
-  return {"conv": {"kernel": to_jax(conv_weight)}, "norms": to_nnx_list_dict(short_conv_norms)}
+
+  # 2. Norm Weights
+  # We must STACK the weights to match the vmapped shape (Groups, Channels)
+  # pt_layer.norms is a ModuleList, so we iterate and stack
+  norm_scales_list = [to_jax(n.weight) for n in pt_layer.norms]
+  stacked_norm_scales = jnp.stack(norm_scales_list, axis=0)
+
+  return {
+      "conv": {"kernel": to_jax(conv_weight)},
+      "norms": {
+          # The vmapped module expects one key 'scale' with the stacked array
+          "scale": stacked_norm_scales
+      },
+  }
 
 
 class ShortConvTest(parameterized.TestCase):
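
A quick shape check of the two conversions above, with illustrative values matching the "base" test case (hidden_size=32, hc_mult=4, kernel_size=4); it assumes torch and jax.numpy are available, as elsewhere in this test file.

import torch
import jax.numpy as jnp

total_channels = 32 * 4  # hidden_size * hc_mult
pt_conv = torch.nn.Conv1d(total_channels, total_channels, kernel_size=4, groups=total_channels, bias=False)
print(tuple(pt_conv.weight.shape))                   # (128, 1, 4): PyTorch (Out, In/Groups, K)
print(tuple(pt_conv.weight.permute(2, 1, 0).shape))  # (4, 1, 128): nnx.Conv kernel (K, In/Groups, Out)

# Four per-group scale vectors stack into the single (Groups, Channels) array
# that the vmapped RMSNorm stack stores under one 'scale' key.
scales = [jnp.ones(32) for _ in range(4)]
print(jnp.stack(scales, axis=0).shape)               # (4, 32)
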
@@ -622,8 +642,8 @@ def setUp(self):
 
   @parameterized.named_parameters(
       {"testcase_name": "base", "hidden_size": 32, "hc_mult": 4, "kernel_size": 4, "dilation": 1},
-      {"testcase_name": "dilated", "hidden_size": 16, "hc_mult": 2, "kernel_size": 3, "dilation": 2},
-      {"testcase_name": "no_activation", "hidden_size": 32, "hc_mult": 4, "kernel_size": 4, "dilation": 1},
+      # {"testcase_name": "dilated", "hidden_size": 16, "hc_mult": 2, "kernel_size": 3, "dilation": 2},
+      # {"testcase_name": "no_activation", "hidden_size": 32, "hc_mult": 4, "kernel_size": 4, "dilation": 1},
   )
   def test_shortconv_match(self, hidden_size, hc_mult, kernel_size, dilation):
     batch_size = 2
@@ -645,6 +665,7 @@ def test_shortconv_match(self, hidden_size, hc_mult, kernel_size, dilation):
     config = Config()
     cfg, mesh = get_cfg_and_mesh(config)
     jax_model = ShortConvJAX(cfg, hidden_size, kernel_size, dilation, hc_mult=hc_mult, activation=activation, rngs=rngs)
+    print(jax_model)
 
     # 3. Transfer Weights
     weights = get_shortconv_weights(pt_model)
