Currently, if you try to create a prediction trajectory from a model and lens loaded in bfloat16, you get the following error:
294 traj_log_probs.append(
--> 295 logits.log_softmax(dim=-1).squeeze().detach().cpu().numpy()
296 )
298 # Add model predictions
299 traj_log_probs.append(model_log_probs)
TypeError: Got unsupported ScalarType BFloat16
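A likely cause is that NumPy has no bfloat16 dtype, so `Tensor.numpy()` cannot convert a bfloat16 tensor. Below is a minimal sketch that reproduces the error and shows one possible workaround: upcasting the logits to float32 before `log_softmax` and the NumPy conversion. It assumes the precision loss from bfloat16 is acceptable and that a float32 copy fits in memory. The random `logits` tensor is hypothetical stand-in data, not the library's actual trajectory code.

```python
import numpy as np
import torch

# Hypothetical stand-in for the lens/model output: random logits in bfloat16.
logits = torch.randn(1, 10, dtype=torch.bfloat16)

# Converting a bfloat16 tensor directly to NumPy fails,
# because NumPy has no bfloat16 dtype.
try:
    logits.detach().cpu().numpy()
except TypeError as err:
    print(f"Conversion failed: {err}")

# Possible workaround: upcast to float32 first, then take log_softmax
# and convert to NumPy as in the failing line of the traceback.
log_probs = logits.float().log_softmax(dim=-1).squeeze().detach().cpu().numpy()
print(log_probs.dtype, log_probs.shape)
```

Upcasting before `log_softmax` (rather than after) also means the normalization itself is computed in float32 instead of bfloat16.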