diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 164e99803..12cd4163c 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -138,19 +138,21 @@ def __init__(
             else:
                 # Hugging Face defaults to use_fast to True
                 use_fast = True
-            # Phi model's fast tokenizer does not support adding a BOS token, use_fast
+            # The Phi and Baichuan fast tokenizers do not support adding a BOS token, so use_fast
             # should be False
-            if "phi" in self.cfg.tokenizer_name.lower():
+            tokenizer_name = self.cfg.tokenizer_name.lower()
+            if "phi" in tokenizer_name or "baichuan" in tokenizer_name:
                 use_fast = False
             huggingface_token = os.environ.get("HF_TOKEN", None)
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.cfg.tokenizer_name,
+                add_bos_token=True,
+                trust_remote_code=self.cfg.trust_remote_code,
+                use_fast=use_fast,
+                token=huggingface_token,
+            )
             self.set_tokenizer(
-                AutoTokenizer.from_pretrained(
-                    self.cfg.tokenizer_name,
-                    add_bos_token=True,
-                    trust_remote_code=self.cfg.trust_remote_code,
-                    use_fast=use_fast,
-                    token=huggingface_token,
-                ),
+                tokenizer,
                 default_padding_side=default_padding_side,
             )
         else:
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 130c20743..2d63e2435 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -23,6 +23,7 @@
 import transformer_lens.utils as utils
 from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
 from transformer_lens.pretrained.weight_conversions import (
+    convert_baichuan_weights,
     convert_bert_weights,
     convert_bloom_weights,
     convert_coder_weights,
@@ -218,6 +219,9 @@
     "google-t5/t5-base",
     "google-t5/t5-large",
     "ai-forever/mGPT",
+    "baichuan-inc/Baichuan-7B",
+    "baichuan-inc/Baichuan-13B-Base",
+    "baichuan-inc/Baichuan-13B-Chat",
 ]
 """Official model names for models on HuggingFace."""
@@ -640,6 +644,9 @@
     "google-t5/t5-base": ["t5-base"],
     "google-t5/t5-large": ["t5-large"],
     "ai-forever/mGPT": ["mGPT"],
+    "baichuan-inc/Baichuan-7B": ["Baichuan-7B"],
+    "baichuan-inc/Baichuan-13B-Base": ["Baichuan-13B-Base"],
+    "baichuan-inc/Baichuan-13B-Chat": ["Baichuan-13B-Chat"],
 }
 """Model aliases for models on HuggingFace."""
@@ -1293,6 +1300,28 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "use_attn_scale": False,
             "tie_word_embeddings": hf_config.tie_word_embeddings,
         }
+    elif architecture in ["BaiChuanForCausalLM", "BaichuanForCausalLM"]:
+        # Baichuan-7B reports its architecture as "BaiChuanForCausalLM" and uses
+        # rotary embeddings; Baichuan-13B reports "BaichuanForCausalLM" and uses
+        # ALiBi.
+        cfg_dict = {
+            "d_model": hf_config.hidden_size,
+            "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+            "n_heads": hf_config.num_attention_heads,
+            "d_mlp": hf_config.intermediate_size,
+            "n_layers": hf_config.num_hidden_layers,
+            "n_ctx": 2048,  # Capped due to HF Tokenizer Constraints
+            "d_vocab": hf_config.vocab_size,
+            "eps": hf_config.rms_norm_eps,
+            "trust_remote_code": True,
+            "act_fn": hf_config.hidden_act,
+            "initializer_range": hf_config.initializer_range,
+            "normalization_type": "RMS",
+            "positional_embedding_type": (
+                "rotary" if architecture == "BaiChuanForCausalLM" else "alibi"
+            ),
+            "tie_word_embeddings": hf_config.tie_word_embeddings,
+        }
     else:
         raise NotImplementedError(f"{architecture} is not currently supported.")
     # All of these models use LayerNorm
@@ -1654,6 +1683,8 @@ def get_pretrained_state_dict(
         state_dict = convert_neox_weights(hf_model, cfg)
     elif cfg.original_architecture == "LlamaForCausalLM":
         state_dict = convert_llama_weights(hf_model, cfg)
+    elif cfg.original_architecture in ["BaiChuanForCausalLM", "BaichuanForCausalLM"]:
+        state_dict = convert_baichuan_weights(hf_model, cfg)
     elif cfg.original_architecture == "BertForMaskedLM":
         state_dict = convert_bert_weights(hf_model, cfg)
     elif cfg.original_architecture == "T5ForConditionalGeneration":
diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py
index b13850ee0..6541fde6c 100644
--- a/transformer_lens/pretrained/weight_conversions/__init__.py
+++ b/transformer_lens/pretrained/weight_conversions/__init__.py
@@ -7,6 +7,7 @@
 from .bert import convert_bert_weights
 from .mistral import convert_mistral_weights
 from .mixtral import convert_mixtral_weights
+from .baichuan import convert_baichuan_weights
 from .bloom import convert_bloom_weights
 from .coder import convert_coder_weights
 from .qwen import convert_qwen_weights
diff --git a/transformer_lens/pretrained/weight_conversions/baichuan.py b/transformer_lens/pretrained/weight_conversions/baichuan.py
new file mode 100644
index 000000000..12637c748
--- /dev/null
+++ b/transformer_lens/pretrained/weight_conversions/baichuan.py
@@ -0,0 +1,60 @@
+import einops
+import torch
+
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def convert_baichuan_weights(baichuan, cfg: HookedTransformerConfig):
+    state_dict = {}
+
+    state_dict["embed.W_E"] = baichuan.model.embed_tokens.weight
+
+    assert cfg.d_mlp is not None  # keep mypy happy
+
+    for l in range(cfg.n_layers):
+        state_dict[f"blocks.{l}.ln1.w"] = baichuan.model.layers[l].input_layernorm.weight
+
+        # W_pack stacks the Q, K and V projections along the output dimension
+        # as [W_Q; W_K; W_V], each of shape (d_model, d_model)
+        W = baichuan.model.layers[l].self_attn.W_pack.weight
+        W_Q, W_K, W_V = torch.chunk(W, 3, dim=0)
+        W_Q = einops.rearrange(W_Q, "(n h) m -> n m h", n=cfg.n_heads)
+        W_K = einops.rearrange(W_K, "(n h) m -> n m h", n=cfg.n_heads)
+        W_V = einops.rearrange(W_V, "(n h) m -> n m h", n=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
+        state_dict[f"blocks.{l}.attn.W_K"] = W_K
+        state_dict[f"blocks.{l}.attn.W_V"] = W_V
+
+        # Baichuan's attention projections have no biases, so fill in zeros
+        state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(
+            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
+        )
+        state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(
+            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
+        )
+        state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(
+            cfg.n_heads, cfg.d_head, dtype=cfg.dtype, device=W_Q.device
+        )
+
+        W_O = baichuan.model.layers[l].self_attn.o_proj.weight
+        W_O = einops.rearrange(W_O, "m (n h) -> n h m", n=cfg.n_heads)
+        state_dict[f"blocks.{l}.attn.W_O"] = W_O
+        state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(
+            cfg.d_model, dtype=cfg.dtype, device=W_O.device
+        )
+
+        state_dict[f"blocks.{l}.ln2.w"] = baichuan.model.layers[l].post_attention_layernorm.weight
+
+        # Gated MLP: up_proj -> W_in, gate_proj -> W_gate, down_proj -> W_out
+        state_dict[f"blocks.{l}.mlp.W_in"] = baichuan.model.layers[l].mlp.up_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.W_gate"] = baichuan.model.layers[l].mlp.gate_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
+
+        state_dict[f"blocks.{l}.mlp.W_out"] = baichuan.model.layers[l].mlp.down_proj.weight.T
+        state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+
+    state_dict["ln_final.w"] = baichuan.model.norm.weight
+    state_dict["unembed.W_U"] = baichuan.lm_head.weight.T
+    state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
+
+    return state_dict
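
Smoke test for the new registrations: a minimal sketch, assuming the Baichuan checkpoints are reachable on the HF Hub. The model name, prompt, and printed shapes are illustrative only; `from_pretrained` and `run_with_cache` are existing TransformerLens APIs.

    from transformer_lens import HookedTransformer

    # "baichuan-inc/Baichuan-7B" resolves through the OFFICIAL_MODEL_NAMES /
    # MODEL_ALIASES entries added above; trust_remote_code is set in the
    # Baichuan config branch, so the remote modeling code is allowed to load.
    model = HookedTransformer.from_pretrained("baichuan-inc/Baichuan-7B")

    # One forward pass with activation caching to sanity-check the converted
    # weights. logits has shape [batch, pos, d_vocab].
    logits, cache = model.run_with_cache("The capital of France is")
    print(logits.shape)
    print(cache["blocks.0.attn.hook_pattern"].shape)  # [batch, n_heads, pos, pos]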