@@ -325,6 +325,112 @@ def __call__(self, input_ids: jax.Array, model_mode: str = MODEL_MODE_TRAIN) ->
     return self.embedding(shifted_ids, model_mode=model_mode)
 
 
+# class ShortConv(nnx.Module):
+#   """
+#   Implements a Grouped Depthwise Causal Convolution block.
+
+#   This module applies local temporal mixing (smoothing) to the retrieved embeddings.
+#   - It uses an independent RMSNorm for each branch,
+#   - followed by a 1D convolution. Note it is depthwise: it mixes information
+#     across time steps [t-k, t] without mixing across channels.
+
+#   Shape Legend:
+#     B: Batch Size
+#     L: Sequence Length
+#     G: Number of Branches (hc_mult) - logical grouping of heads
+#     C: Embedding Dimension per Branch (hidden_size)
+
+#   Note on Convolution:
+#     Conv1D treats (G * C) as the total number of input channels.
+#     Depthwise - applies a separate filter to every single dimension in C, for every group G.
+#   """
347+
348+ # def __init__(
349+ # self,
350+ # config,
351+ # hidden_size: int,
352+ # kernel_size: int = 4, # Temporal Window Size
353+ # dilation: int = 1,
354+ # hc_mult: int = 4,
355+ # activation: bool = True,
356+ # rngs: nnx.Rngs = None,
357+ # ):
358+ # self.hc_mult = hc_mult
359+ # self.activation = activation
360+ # total_channels = hidden_size * hc_mult
361+ # self.rngs = rngs
362+
363+ # # Depthwise Convolution:
364+ # # Setting feature_group_count = in_features ensures each channel is convolved
365+ # # independently. This learns local temporal patterns per feature.
366+ # # Padding="CAUSAL" ensures output[t] only depends on input[t-k : t].
367+ # self.conv = nnx.Conv(
368+ # in_features=total_channels,
369+ # out_features=total_channels,
370+ # kernel_size=(kernel_size,),
371+ # feature_group_count=total_channels, # Depthwise
372+ # kernel_dilation=(dilation,),
373+ # padding="CAUSAL", # To match the slice [..., :T] logic
374+ # use_bias=False,
375+ # rngs=rngs,
376+ # )
377+
378+ # # Independent RMSNorm for each branch.
379+ # # TODO(shuningjin): eps, epsilon=config.normalization_layer_epsilon,
380+ # # epsilon=1e-5, # Match PyTorch default
381+ # self.norms = nnx.List(
382+ # [
383+ # RMSNorm(
384+ # num_features=hidden_size,
385+ # dtype=config.dtype,
386+ # weight_dtype=config.weight_dtype,
387+ # kernel_axes=("norm",),
388+ # epsilon=1e-5,
389+ # rngs=self.rngs,
390+ # )
391+ # for _ in range(hc_mult)
392+ # ]
393+ # )
394+
395+ # self.act_fn = jax.nn.silu if activation else lambda x: x
396+
+#   def __call__(self, x: jax.Array) -> jax.Array:
+#     """
+#     y = SiLU(Conv1D(RMSNorm(x)))
+
+#     Args:
+#       x: Input tensor of shape (B, L, G, C)
+#     Returns:
+#       Tensor of shape (B, L, G, C)
+
+#     Note: G = hc_mult
+#     """
+#     B, L, G, C = x.shape
+
+#     # 1. Apply Group-wise Normalization
+#     # We iterate over the 'Groups' dimension (axis 2)
+#     normed_chunks = []
+#     for i in range(G):
+#       norm = self.norms[i]
+#       # shape: (B, L, C)
+#       x_chunk = x[:, :, i, :]
+#       normed_chunks.append(norm(x_chunk))
+
+#     # 2. Flatten Groups for Convolution
+#     # (B, L, C) x G -> (B, L, G * C)
+#     x_flat = jnp.concatenate(normed_chunks, axis=-1)
+
+#     # 3. Apply Depthwise Causal Conv
+#     # Mixes temporal dimension L. Channels remain independent.
+#     # Shape stays: (B, L, G * C)
+#     y = self.conv(x_flat)
+#     y = self.act_fn(y)
+
+#     # 4. Reshape back to Branched Layout
+#     # (B, L, G * C) -> (B, L, G, C)
+#     return y.reshape(B, L, G, C)
+
+
 class ShortConv(nnx.Module):
   """
   Implements a Grouped Depthwise Causal Convolution block.
@@ -356,14 +462,10 @@ def __init__(
       rngs: nnx.Rngs = None,
   ):
     self.hc_mult = hc_mult
-    self.activation = activation
     total_channels = hidden_size * hc_mult
-    self.rngs = rngs
 
-    # Depthwise Convolution:
-    # Setting feature_group_count = in_features ensures each channel is convolved
-    # independently. This learns local temporal patterns per feature.
-    # Padding="CAUSAL" ensures output[t] only depends on input[t-k : t].
+    # A: Single Shared Convolution
+    # Note: feature_group_count=total_channels makes this Depthwise (channels don't mix)
     self.conv = nnx.Conv(
         in_features=total_channels,
         out_features=total_channels,
@@ -375,22 +477,21 @@ def __init__(
         rngs=rngs,
     )
 
-    # Independent RMSNorm for each branch.
-    # TODO(shuningjin): eps, epsilon=config.normalization_layer_epsilon,
-    # epsilon=1e-5,  # Match PyTorch default
-    self.norms = nnx.List(
-        [
-            RMSNorm(
-                num_features=hidden_size,
-                dtype=config.dtype,
-                weight_dtype=config.weight_dtype,
-                kernel_axes=("norm",),
-                epsilon=1e-5,
-                rngs=self.rngs,
-            )
-            for _ in range(hc_mult)
-        ]
-    )
+    # B: Vectorized Norms (one unique module per group)
+    @nnx.split_rngs(splits=hc_mult)
+    @nnx.vmap(in_axes=0, out_axes=0)
+    def create_norms(r):
+      return RMSNorm(
+          num_features=hidden_size,
+          dtype=config.dtype,
+          weight_dtype=config.weight_dtype,
+          kernel_axes=("norm",),
+          epsilon=1e-5,
+          rngs=r,
+      )
+
+    # 'norms' now holds weights of shape (hc_mult, hidden_size)
+    self.norms = create_norms(rngs)
 
     self.act_fn = jax.nn.silu if activation else lambda x: x
 
@@ -407,27 +508,19 @@ def __call__(self, x: jax.Array) -> jax.Array:
     """
     B, L, G, C = x.shape
 
-    # 1. Apply Group-wise Normalization
-    # We iterate over the 'Groups' dimension (axis 2)
-    normed_chunks = []
-    for i in range(G):
-      norm = self.norms[i]
-      # shape: (B, L, C)
-      x_chunk = x[:, :, i, :]
-      normed_chunks.append(norm(x_chunk))
-
-    # 2. Flatten Groups for Convolution
-    # (B, L, C) x G -> (B, L, G * C)
-    x_flat = jnp.concatenate(normed_chunks, axis=-1)
-
-    # 3. Apply Depthwise Causal Conv
-    # Mixes temporal dimension L. Channels remain independent.
-    # Shape stays: (B, L, G * C)
+    # 1. Apply Norms (Vectorized over Group dim)
+    # in_axes=(0, 2): norms is axis 0, x is axis 2
+    # out_axes=2: put the group dim back at axis 2
+    x = nnx.vmap(lambda m, val: m(val), in_axes=(0, 2), out_axes=2)(self.norms, x)
+
+    # 2. Flatten Groups for Conv
+    x_flat = x.reshape(B, L, G * C)
+
+    # 3. Apply Single Conv
     y = self.conv(x_flat)
     y = self.act_fn(y)
 
-    # 4. Reshape back to Branched Layout
-    # (B, L, G * C) -> (B, L, G, C)
+    # 4. Reshape back
     return y.reshape(B, L, G, C)
 
 
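A minimal, self-contained sketch of the vectorized-norms pattern above. It uses flax's stock nnx.RMSNorm in place of the project's RMSNorm (which takes extra dtype/sharding arguments), and the hc_mult=4 / hidden_size=8 values are illustrative assumptions:

import jax.numpy as jnp
from flax import nnx

hc_mult, hidden_size = 4, 8

@nnx.split_rngs(splits=hc_mult)   # give each vmapped copy its own RNG key
@nnx.vmap(in_axes=0, out_axes=0)  # stack the created parameters along axis 0
def create_norms(rngs: nnx.Rngs):
  # stand-in for the project's RMSNorm (assumption)
  return nnx.RMSNorm(num_features=hidden_size, epsilon=1e-5, rngs=rngs)

norms = create_norms(nnx.Rngs(0))
# norms.scale.value has shape (hc_mult, hidden_size): one weight vector per group.

x = jnp.ones((2, 16, hc_mult, hidden_size))  # (B, L, G, C)
# Map norm i over the group slice x[:, :, i, :]; out_axes=2 restores the layout.
y = nnx.vmap(lambda m, v: m(v), in_axes=(0, 2), out_axes=2)(norms, x)
assert y.shape == (2, 16, hc_mult, hidden_size)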
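And a hedged sketch of the depthwise causal convolution in isolation, showing why the sequence length is preserved and channels stay independent; the channel and kernel sizes here are assumptions for illustration:

import jax.numpy as jnp
from flax import nnx

total_channels, kernel_size = 32, 4
conv = nnx.Conv(
    in_features=total_channels,
    out_features=total_channels,
    kernel_size=(kernel_size,),
    feature_group_count=total_channels,  # one filter per channel: depthwise
    padding="CAUSAL",  # left-pad by (kernel_size - 1): output[t] sees only inputs <= t
    use_bias=False,
    rngs=nnx.Rngs(0),
)

x = jnp.ones((2, 16, total_channels))      # (B, L, G * C)
y = conv(x)
assert y.shape == (2, 16, total_channels)  # causal padding keeps L fixed at 16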