Skip to content

Commit 64f426c

Browse files
committed
fix bug
1 parent fa8b86d commit 64f426c

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

parallelformers/policies/gptj.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoBlock
16-
from transformers.models.gptj.modeling_gptj import GPTJBlock
17-
1815
from parallelformers.policies.base import Layer, Policy
1916
from parallelformers.utils import AllReduceLinear
2017

@@ -26,7 +23,7 @@ def replace_arguments(config, world_size):
2623
# 1. reduce hidden size
2724
"attn.embed_dim": config.hidden_size // world_size,
2825
# 2. reduce number of heads
29-
"attn.num_heads": config.n_head // world_size,
26+
"attn.num_attention_heads": config.n_head // world_size,
3027
}
3128

3229
@staticmethod

0 commit comments

Comments
 (0)