Skip to content

Commit 9ab4c21

Browse files
committed
Add Megatron bert
1 parent 6353c5f commit 9ab4c21

File tree

2 files changed

+121
-1
lines changed

2 files changed

+121
-1
lines changed

parallelformers/policies/base/auto.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from torch import nn
1919

2020
from parallelformers.policies.base import Policy
21-
from parallelformers.policies.gptj import GPTJPolicy
2221

2322

2423
class AutoPolicy:
@@ -641,11 +640,20 @@ def __init__(self):
641640

642641
with suppress(Exception):
643642
from transformers.models.gptj.modeling_gptj import GPTJPreTrainedModel
643+
from parallelformers.policies.gptj import GPTJPolicy
644644

645645
self.builtin_policies[GPTJPreTrainedModel] = [
646646
GPTJPolicy,
647647
]
648648

649+
with suppress(Exception):
650+
from transformers.models.megatron_bert import MegatronBertPreTrainedModel
651+
from parallelformers.policies.megatron_bert import MegatronBertPolicy
652+
653+
self.builtin_policies[MegatronBertPreTrainedModel] = [
654+
MegatronBertPolicy,
655+
]
656+
649657
def get_policy(self, model: nn.Module) -> Union[List[Policy], None]:
650658
"""
651659
Find appropriate policies for the current model
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright 2021 TUNiB inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from transformers.models.megatron_bert.modeling_megatron_bert import \
16+
MegatronBertLayer
17+
18+
from parallelformers.policies.base import Layer, Policy
19+
from parallelformers.transformers.modeling_bert import BertEmbeddings_
20+
from parallelformers.utils import AllReduceLinear
21+
22+
23+
class MegatronBertPolicy(Policy):
    """Tensor-parallelism policy for MegatronBERT.

    Describes, for one ``MegatronBertLayer``, which parameters are sliced
    column-wise across ranks (Q/K/V and the MLP input projection) and which
    are sliced row-wise with an all-reduce on their output (the attention
    output and MLP output projections), plus the per-layer attribute
    overrides needed once the weights have been partitioned.
    """

    @staticmethod
    def replace_arguments(config, world_size):
        """Attribute overrides applied to each layer after slicing.

        Both the self-attention and (optional) cross-attention modules get
        their head count and flattened head size divided by ``world_size``
        so they match their rank-local weight shards.
        """
        local_head_size = config.hidden_size // world_size
        local_num_heads = config.num_attention_heads // world_size
        return {
            # 1. reduce hidden size
            "attention.self.all_head_size": local_head_size,
            "crossattention.self.all_head_size": local_head_size,
            # 2. reduce number of heads
            "attention.self.num_attention_heads": local_num_heads,
            "crossattention.self.num_attention_heads": local_num_heads,
        }

    @staticmethod
    def replace_modules():
        """Swap the stock embedding class for the parallel-aware variant."""
        return {"BertEmbeddings": BertEmbeddings_}

    @staticmethod
    def attn_qkv():
        """Column-parallel Q/K/V projections.

        Self-attention entries are mandatory; cross-attention entries set
        ``ignore_checker=True`` so models without cross-attention still load.
        """
        entries = []
        for module, optional in (("attention", False), ("crossattention", True)):
            for proj in ("query", "key", "value"):
                kwargs = {
                    "weight": f"{module}.self.{proj}.weight",
                    "bias": f"{module}.self.{proj}.bias",
                }
                if optional:
                    kwargs["ignore_checker"] = True
                entries.append(Layer(**kwargs))
        return entries

    @staticmethod
    def attn_out():
        """Row-parallel attention output projections (all-reduced)."""
        self_out = Layer(
            weight="attention.output.dense.weight",
            bias="attention.output.dense.bias",
            replace=AllReduceLinear,
        )
        # ignore_checker: cross-attention may be absent from the model.
        cross_out = Layer(
            weight="crossattention.output.dense.weight",
            bias="crossattention.output.dense.bias",
            replace=AllReduceLinear,
            ignore_checker=True,
        )
        return [self_out, cross_out]

    @staticmethod
    def mlp_in():
        """Column-parallel MLP input projection."""
        return [
            Layer(
                weight="intermediate.dense.weight",
                bias="intermediate.dense.bias",
            ),
        ]

    @staticmethod
    def mlp_out():
        """Row-parallel MLP output projection (all-reduced)."""
        return [
            Layer(
                weight="output.dense.weight",
                bias="output.dense.bias",
                replace=AllReduceLinear,
            ),
        ]

    @staticmethod
    def original_layer_class():
        """The transformer layer class this policy applies to."""
        return MegatronBertLayer

0 commit comments

Comments
 (0)