5 changes: 5 additions & 0 deletions penzai/models/transformer/variants/gpt_neox.py
@@ -412,6 +412,11 @@ def gpt_neox_from_huggingface_model(
      "eos_token_id",
      "_attn_implementation_autoset",
      "head_dim",
      "is_decoder",
      "attention_probs_dropout_prob",
      "hidden_dropout_prob",
      "type_vocab_size",
      "_name_or_path",
  }
  bad_attributes = {}
  for k, v in hf_config_attributes.items():
2 changes: 2 additions & 0 deletions penzai/models/transformer/variants/llama.py
@@ -80,8 +80,10 @@ def llama_from_huggingface_model(
      "architectures",
      "bos_token_id",
      "eos_token_id",
      "pad_token_id",
      "_attn_implementation_autoset",
      "head_dim",
      "_name_or_path",
  }
  bad_attributes = {}
  for k, v in hf_config_attributes.items():
7 changes: 7 additions & 0 deletions penzai/models/transformer/variants/mistral.py
@@ -86,6 +86,13 @@ def mistral_from_huggingface_model(
      "architectures",
      "_attn_implementation_autoset",
      "head_dim",
      "hidden_act",
      "is_decoder",
      "pad_token_id",
      "attention_probs_dropout_prob",
      "hidden_dropout_prob",
      "type_vocab_size",
      "_name_or_path",
  }
  bad_attributes = {}
  for k, v in hf_config_attributes.items():
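All three converter hunks above extend an allow-list of HuggingFace config attributes that the conversion routine deliberately ignores (the exact list differs per variant; the llama list is shown below). The surrounding check is only partially visible in the diff, so the following is a minimal sketch of the presumed pattern, under the assumption that any attribute the converter neither handles nor explicitly ignores is collected into `bad_attributes` and reported; in the real converters the remaining attributes are translated into the penzai config rather than always rejected. Only `bad_attributes` and the loop over `hf_config_attributes` come from the diff; the set name, function name, and error message are illustrative.

# Hypothetical sketch, not the actual penzai implementation.
CHECKED_OR_IGNORED_ATTRIBUTES = {
    "architectures",
    "bos_token_id",
    "eos_token_id",
    "pad_token_id",     # newly ignored for llama and mistral in this PR
    "_name_or_path",    # newly ignored for all three variants in this PR
    "_attn_implementation_autoset",
    "head_dim",
}


def check_config_attributes(hf_config_attributes: dict) -> None:
  """Rejects HuggingFace config attributes the converter cannot translate."""
  bad_attributes = {}
  for k, v in hf_config_attributes.items():
    if k in CHECKED_OR_IGNORED_ATTRIBUTES:
      continue
    # Anything unknown is collected and reported, so unexpected config
    # settings fail loudly instead of being silently dropped.
    bad_attributes[k] = v
  if bad_attributes:
    raise ValueError(f"Unsupported config attributes: {bad_attributes}")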
79 changes: 79 additions & 0 deletions tests/models/transformer_consistency_test.py
@@ -69,6 +69,32 @@ def test_llama_consistency(self, num_attention_heads, num_key_value_heads):
            pz_out, hf_out.order_like(pz_out), atol=1e-6
        )

  def test_llama_consistency_from_pretrained(self):
    model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    hf_model = transformers.LlamaForCausalLM.from_pretrained(model_name)

    tokens = pz.nx.wrap(jnp.tile(jnp.arange(11), 3)[None, :], "batch", "seq")

    hf_arg = torch.tensor(np.array(tokens.unwrap("batch", "seq")))
    hf_out = pz.nx.wrap(hf_model(hf_arg).logits.detach().numpy()).tag(
        "batch", "seq", "vocabulary"
    )

    for layer_stack in (False, True):
      with self.subTest(f"layer_stack={layer_stack}"):
        pz_model = llama.llama_from_huggingface_model(
            hf_model, use_layer_stack=layer_stack
        )

        pz_out = pz_model(
            tokens,
            token_positions=pz.nx.arange("seq", tokens.named_shape["seq"]),
        )

        chex.assert_trees_all_close(
            pz_out, hf_out.order_like(pz_out), atol=1e-6
        )

  @parameterized.named_parameters(
      dict(testcase_name="full", num_attention_heads=4, num_key_value_heads=4),
      dict(testcase_name="mqa", num_attention_heads=4, num_key_value_heads=1),
@@ -108,6 +134,32 @@ def test_mistral_consistency(self, num_attention_heads, num_key_value_heads):
            pz_out, hf_out.order_like(pz_out), atol=1e-6
        )


  def test_mistral_consistency_from_pretrained(self):
    model_name = "hf-internal-testing/tiny-random-MistralForCausalLM"
    hf_model = transformers.MistralForCausalLM.from_pretrained(model_name)

    tokens = pz.nx.wrap(jnp.tile(jnp.arange(11), 3)[None, :], "batch", "seq")

    hf_arg = torch.tensor(np.array(tokens.unwrap("batch", "seq")))
    hf_out = pz.nx.wrap(hf_model(hf_arg).logits.detach().numpy()).tag(
        "batch", "seq", "vocabulary"
    )

    for layer_stack in (False, True):
      with self.subTest(f"layer_stack={layer_stack}"):
        pz_model = mistral.mistral_from_huggingface_model(
            hf_model, use_layer_stack=layer_stack
        )
        pz_out = pz_model(
            tokens,
            token_positions=pz.nx.arange("seq", tokens.named_shape["seq"]),
        )

        chex.assert_trees_all_close(
            pz_out, hf_out.order_like(pz_out), atol=6e-3
        )

  def test_gpt_neox_consistency(self):
    cfg = transformers.GPTNeoXConfig(
        vocab_size=11,
@@ -144,6 +196,33 @@ def test_gpt_neox_consistency(self):
            pz_out, hf_out.order_like(pz_out), rtol=3e-3
        )

  def test_gpt_neox_consistency_from_pretrained(self):
    model_name = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM"
    hf_model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name)

    tokens = pz.nx.wrap(jnp.tile(jnp.arange(11), 3)[None, :], "batch", "seq")

    hf_arg = torch.tensor(np.array(tokens.unwrap("batch", "seq")))
    hf_out = pz.nx.wrap(hf_model(hf_arg).logits.detach().numpy()).tag(
        "batch", "seq", "vocabulary"
    )

    for layer_stack in (False, True):
      with self.subTest(f"layer_stack={layer_stack}"):
        pz_model = gpt_neox.gpt_neox_from_huggingface_model(
            hf_model, use_layer_stack=layer_stack
        )
        pz_out = pz_model(
            tokens,
            token_positions=pz.nx.arange("seq", tokens.named_shape["seq"]),
        )

        chex.assert_trees_all_close(
            pz_out, hf_out.order_like(pz_out), atol=4e-3
        )
        chex.assert_trees_all_close(
            pz_out, hf_out.order_like(pz_out), rtol=9e-3
        )

if __name__ == "__main__":
  absltest.main()
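The three new `*_consistency_from_pretrained` tests above share the same round-trip structure: wrap a small token batch, run the HuggingFace model, convert it with the matching penzai builder, run the converted model, and compare logits. A hypothetical helper distilling that shared structure is sketched below; the name `_assert_pz_matches_hf`, its parameters, and the idea of factoring it out are illustrative and not part of this PR, and it assumes the test module's existing imports (pz, jnp, np, torch, chex).

  def _assert_pz_matches_hf(self, hf_model, converter, atol):
    """Hypothetical helper mirroring the shared structure of the new tests."""
    tokens = pz.nx.wrap(jnp.tile(jnp.arange(11), 3)[None, :], "batch", "seq")
    hf_arg = torch.tensor(np.array(tokens.unwrap("batch", "seq")))
    hf_out = pz.nx.wrap(hf_model(hf_arg).logits.detach().numpy()).tag(
        "batch", "seq", "vocabulary"
    )
    for layer_stack in (False, True):
      with self.subTest(f"layer_stack={layer_stack}"):
        pz_model = converter(hf_model, use_layer_stack=layer_stack)
        pz_out = pz_model(
            tokens,
            token_positions=pz.nx.arange("seq", tokens.named_shape["seq"]),
        )
        chex.assert_trees_all_close(
            pz_out, hf_out.order_like(pz_out), atol=atol
        )

Each test would then reduce to loading the tiny checkpoint and calling, for example, `self._assert_pz_matches_hf(hf_model, llama.llama_from_huggingface_model, atol=1e-6)`; the gpt_neox variant additionally checks a relative tolerance.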