Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _export(

# Setup temporary paths
tmp_onnx_dir = export_dir / "onnx_tmp"
tmp_onnx_path = tmp_onnx_dir / f"{self.model_name}.onnx"
# tmp_onnx_path = tmp_onnx_dir / f"{self.model_name}.onnx"
tmp_onnx_dir.mkdir(parents=True, exist_ok=True)

# Create input_names from example_inputs
Expand Down Expand Up @@ -278,7 +278,7 @@ def _export(
torch.onnx.export(
self.model,
(example_inputs,),
str(tmp_onnx_path),
str(onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
Expand All @@ -287,11 +287,11 @@ def _export(
)
logger.info("PyTorch export successful")
_ = self._offload_model_weights(offload_pt_weights)
model = onnx.load(tmp_onnx_path, load_external_data=False)
model = onnx.load(onnx_path, load_external_data=False)

# Clear temporary references
transform_kwargs = {
"onnx_base_dir": str(tmp_onnx_dir),
"onnx_base_dir": None,
"model_name": self.model_name,
}
if onnx_transform_kwargs is not None:
Expand Down
16 changes: 14 additions & 2 deletions QEfficient/base/onnx_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,27 @@ class CustomOpTransform(BaseOnnxTransform):
@classmethod
def apply(cls, model: ModelProto) -> bool:
op_applied = False

# Register with PyTorch ONNX exporter (for export time)
for op_name, (func_class, _) in cls._custom_ops.items():
if hasattr(func_class, "symbolic"):
torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, ONNX_EXPORT_OPSET)

used_op_types = {node.op_type for node in model.graph.node}
for function_proto in model.functions:
used_op_types.update(node.op_type for node in function_proto.node)

# Add function prototypes to model
existing = {f.name for f in model.functions}
for _, onnxscript_func in cls._custom_ops.values():

for func_name, onnxscript_func in cls._custom_ops.values():
proto = onnxscript_func.to_function_proto()
if proto.name not in used_op_types:
continue
if proto.name not in existing:
model.functions.append(proto)
op_applied = True

return op_applied


Expand Down Expand Up @@ -228,7 +239,8 @@ def apply(
do_split = SplitTensorsTransform in requested
fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
file_num_tracker = {"num": 0, "size": 0}
external_data_helper.load_external_data_for_model(model, onnx_base_dir)
if onnx_base_dir is not None:
external_data_helper.load_external_data_for_model(model, onnx_base_dir)

if do_fp16 or do_split:
for tensor in external_data_helper._get_all_tensors(model):
Expand Down
4 changes: 2 additions & 2 deletions QEfficient/transformers/models/codegen/modeling_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def _attn(
head_mask=None,
):
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
query = query.to(torch.float16)
key = key.to(torch.float16)

attn_weights = torch.matmul(query, key.transpose(-1, -2))

Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def forward(
attention_scores = query_layer @ key_layer.transpose(-1, -2)
attention_scores /= math.sqrt(self.head_dim)
attention_scores = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attention_scores
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=self.config.torch_dtype), attention_scores
)
attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype)
# It is unclear why neither dropout nor head_mask is applied here (while it is with alibi).
Expand Down
4 changes: 2 additions & 2 deletions QEfficient/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

Expand Down
4 changes: 2 additions & 2 deletions QEfficient/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
if attention_mask is not None:
# Apply the attention mask
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,11 @@ def eager_attention_forward(
value_states = repeat_kv(value, module.num_key_value_groups)

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

Expand Down
21 changes: 10 additions & 11 deletions QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def forward(self, hidden: torch.Tensor):
up = (hidden @ W_u) + b_u # [T, I]

# Apply GptOss activation with clamping
gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
gate = gate.clamp(min=torch.finfo(self.config.torch_dtype).min, max=self.experts.limit)
up = up.clamp(min=-self.experts.limit, max=self.experts.limit)

# GLU activation
Expand Down Expand Up @@ -135,7 +135,7 @@ def forward(self, hidden: torch.Tensor):
up = (hidden @ W_u) + b_u # [T, I]

# Apply GptOss activation with clamping
gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
gate = gate.clamp(min=torch.finfo(self.config.torch_dtype).min, max=self.experts.limit)
up = up.clamp(min=-self.experts.limit, max=self.experts.limit)

# GLU activation
Expand Down Expand Up @@ -203,7 +203,7 @@ def blocked_ffn_forward(self, hidden: torch.Tensor):
up = (tgb @ W_u) + b_u # [T, I]

# Apply GptOss activation with clamping
gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
gate = gate.clamp(min=torch.finfo(self.config.torch_dtype).min, max=self.experts.limit)
up = up.clamp(min=-self.experts.limit, max=self.experts.limit)

# GLU activation
Expand Down Expand Up @@ -284,7 +284,7 @@ def blocked_ffn_forward_block_weights(self, hidden: torch.Tensor):
cur_gate = (tgb @ W_g[:, i * 128 : (i + 1) * 128]) + b_g[i * 128 : (i + 1) * 128]
cur_up = (tgb @ W_u[:, i * 128 : (i + 1) * 128]) + b_u[i * 128 : (i + 1) * 128]

cur_gate = cur_gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
cur_gate = cur_gate.clamp(min=torch.finfo(self.config.torch_dtype).min, max=self.experts.limit)
cur_up = cur_up.clamp(min=-self.experts.limit, max=self.experts.limit)
cur_glu = cur_gate * torch.sigmoid(cur_gate * self.experts.alpha)
cur_intermediate = (cur_up + 1) * cur_glu
Expand Down Expand Up @@ -367,7 +367,6 @@ def forward(self, hidden_states):
router_logits = F.linear(hidden_states, self.router.weight, self.router.bias)
router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1)
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)

# GATHER - collect weights for selected experts (separate gate and up projections)
gate_proj = self.experts.gate_proj[router_indices.flatten()]
gate_proj_bias = self.experts.gate_proj_bias[router_indices.flatten()]
Expand All @@ -389,7 +388,7 @@ def forward(self, hidden_states):
up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1)

# Apply activation with clamping
gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit)
gate = gate.clamp(min=torch.finfo(self.config.torch_dtype).min, max=self.experts.limit)
up = up.clamp(min=-self.experts.limit, max=self.experts.limit)

# GLU activation
Expand Down Expand Up @@ -603,7 +602,7 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
Expand Down Expand Up @@ -654,14 +653,14 @@ def eager_attention_forward_blocked(
scores = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling
attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :]
curr_attn_weights = torch.where(
attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores
attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), scores
)
sinks = module.sinks.reshape(1, -1, 1, 1).expand(
curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1
)
combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1)
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32)
curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=module.config.torch_dtype)
curr_attn_weights = curr_attn_weights[..., :-1]
out_block = torch.matmul(curr_attn_weights, value_states)
outs.append(out_block)
Expand Down Expand Up @@ -717,14 +716,14 @@ def opt_eager_attention_forward_blocked(

scores = torch.matmul(q_block, k_block.transpose(2, 3)) * scaling
curr_attn_weights = torch.where(
attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores
attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), scores
)
sinks = module.sinks.reshape(1, -1, 1, 1).expand(
curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1
)
combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1)
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32)
curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=module.config.torch_dtype)
curr_attn_weights = curr_attn_weights[..., :-1]
out_block = torch.matmul(curr_attn_weights, v_block)
outs.append(out_block)
Expand Down
6 changes: 3 additions & 3 deletions QEfficient/transformers/models/gptj/modeling_gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ def _attn(
head_mask=None,
):
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
key = key.to(torch.float32)
query = query.to(torch.float16)
key = key.to(torch.float16)

attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = attn_weights / self.scale_attn

if attention_mask is not None:
# Apply the attention mask
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float16), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
Expand Down
5 changes: 2 additions & 3 deletions QEfficient/transformers/models/granite/modeling_granite.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

Expand Down Expand Up @@ -145,7 +145,6 @@ def forward(
kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.dtype), attn_weights
)

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
Expand Down
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/grok_1/modeling_grok1.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def forward(

if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=self.config.torch_dtype), attn_weights
)

attn_weights = F.softmax(attn_weights, dim=-1).to(query_states.dtype)
Expand Down
6 changes: 3 additions & 3 deletions QEfficient/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

Expand Down Expand Up @@ -147,7 +147,7 @@ def eager_attention_forward_blockedKV(
past_seen_tokens = cache_kwargs.get("past_seen_tokens")
position_ids = cache_kwargs.get("position_ids")
block_size = -(-past_seen_tokens // num_kv_blocks)
masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype)

for j in range(num_kv_blocks):
start_index = j * block_size
Expand Down
4 changes: 2 additions & 2 deletions QEfficient/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
Expand Down
Loading
Loading