Added fp16/bf16 based export and compile support for VLMs #819
quic-hemagnih merged 54 commits into quic:main from
Conversation
  retained_state=True,
  specializations=specializations["lang"],
- convert_to_fp16=True,
+ convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16"),
nit: why is this condition required? Is it for AI200? @quic-rishinr
This condition is required in case the user wants bf16 support, which will come with AI200. I have updated the code to set convert_to_fp16 = True when the passed dtype is either fp16 or fp32.
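The decision above can be sketched as a small dtype-to-string lookup. This is a minimal illustration, not the repo's actual implementation: the map keys are shown as strings (so the sketch stays torch-free) and `should_convert_to_fp16` is a hypothetical helper name.

```python
# Hedged sketch: decide whether custom IO should be converted to fp16.
# The compiler has no fp32 path, so both fp16 and fp32 map to "float16";
# bf16 is preserved for targets (e.g. AI200) that support it.
DTYPE_TO_STRING_MAP = {
    "torch.float16": "float16",
    "torch.bfloat16": "bfloat16",
    "torch.float32": "float16",  # compiler doesn't support fp32
}

def should_convert_to_fp16(torch_dtype: str) -> bool:
    # Mirrors convert_to_fp16=(DTYPE_TO_STRING_MAP[needed_dtype] == "float16")
    return DTYPE_TO_STRING_MAP[torch_dtype] == "float16"

print(should_convert_to_fp16("torch.float32"))   # True: fp32 falls back to fp16
print(should_convert_to_fp16("torch.bfloat16"))  # False: keep bf16
```

With this shape, adding another backend dtype later only touches the map, not the call sites.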
  self.model, transformed = transform.apply(self.model)
  any_transformed = any_transformed or transformed

+ self._normalize_torch_dtype()
nit: does this take care of embedding and ASR models too?
| "allenai/Molmo-7B-D-0924", | ||
| "meta-llama/Llama-3.2-11B-Vision-Instruct", | ||
| ]: | ||
| pytest.skip("Test skipped for this model due to some issues.") |
nit: with our dummy configs, can we run all sample lm models w/this test quickly?
      torch.nn.functional.pad(inputs["input_values"], (0, self.seq_len - input_ids_len), "constant", 0)
  )
  needed_dtype = self.model.config.torch_dtype
  input_values = input_values.astype(CUSTOM_IO_DTYPE_MAP[needed_dtype])
Since the inputs are in NumPy format, we should be using Torch_to_numpy_map here, right?
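The reviewer's point can be illustrated with a minimal torch-dtype-to-NumPy-dtype map. This is a hedged sketch with hypothetical names (keys shown as strings to keep the example torch-free), not the library's actual table. Note that NumPy has no native bfloat16, which is one reason a dedicated map for NumPy-side casting matters.

```python
import numpy as np

# Hedged sketch: map a model's torch dtype to the NumPy dtype used for
# host-side input buffers. bf16 would need special handling (e.g. a
# float16/float32 carrier), since NumPy lacks a native bfloat16.
TORCH_TO_NUMPY_MAP = {
    "torch.float16": np.float16,
    "torch.float32": np.float32,
}

# Inputs start out as default float64 NumPy arrays...
input_values = np.zeros((1, 4))
# ...and are cast to the dtype the compiled model expects.
input_values = input_values.astype(TORCH_TO_NUMPY_MAP["torch.float16"])
print(input_values.dtype)  # float16
```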
  router_logits = self.gate(hidden_states)

- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+ routing_weights = F.softmax(router_logits, dim=1, dtype=self.gate.weight.dtype)
Can we update the softmax to be in the original precision?
Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
… and inference. Almost all LLMs can now be compiled and run inference in fp16. The test_causal_lm_models script uses the following notation for test status: # means the model wasn't tested due to its size, so it is unknown whether it runs through or has an accuracy mismatch; ## means the fp16 outputs match and things worked fine; ### means outputs are produced but don't match the HF tokens properly; #### means the model is quantized and additional effort is needed to enable it. These commits cover almost all currently supported LLMs. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com> Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
…both in bfloat16. Added a patch in cloud infer to map bfloat16 (key type 11) to np.float16 for AI200 inference. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com> Signed-off-by: Asmita Goswami <asmigosw@qti.qualcomm.com>
…e to False for _compile when appropriate params are missing. Signed-off-by: Dhiraj Kumar Sah <dhirajku@qti.qualcomm.com>
  CUSTOM_IO_DTYPE_MAP = {
      torch.float16: "float16",
      torch.bfloat16: "bfloat16",
      torch.float32: "float16",  # Since compiler doesn't support fp32
nit: I did not understand this entry: `torch.float32: "float16", # Since compiler doesn't support fp32`.
Is it to avoid compiling for fp32?
Yes, if the model's config has torch_dtype=torch.float32, it should compile in float16 because the compiler doesn't support fp32 precision. Earlier we were hardcoding fp16 everywhere, so this wouldn't have created a problem; but now that bf16 support is also there, this mapping makes it generic.