1 parent fa8b86d · commit 64f426c
parallelformers/policies/gptj.py
@@ -12,9 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoBlock
-from transformers.models.gptj.modeling_gptj import GPTJBlock
-
 from parallelformers.policies.base import Layer, Policy
 from parallelformers.utils import AllReduceLinear

@@ -26,7 +23,7 @@ def replace_arguments(config, world_size):
             # 1. reduce hidden size
             "attn.embed_dim": config.hidden_size // world_size,
             # 2. reduce number of heads
-            "attn.num_heads": config.n_head // world_size,
+            "attn.num_attention_heads": config.n_head // world_size,
         }

     @staticmethod
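The second hunk renames the patched attribute from `attn.num_heads` to `attn.num_attention_heads`, presumably to match the attribute name defined on `GPTJAttention` in `transformers`. To show what these replacement arguments accomplish, here is a minimal sketch (the function `shard_attention_args` is hypothetical and not part of this commit): the policy shrinks the attention module's embedding width and head count by the tensor-parallel `world_size` in lockstep, so the per-head dimension stays the same on every rank.

```python
# Minimal sketch (not part of this commit): how the replaced arguments
# keep a sharded GPT-J attention block self-consistent. The names
# hidden_size, n_head, and world_size mirror the diff above.

def shard_attention_args(hidden_size: int, n_head: int, world_size: int) -> dict:
    """Return the per-rank values the policy writes into the attention module."""
    assert hidden_size % world_size == 0, "hidden size must divide evenly"
    assert n_head % world_size == 0, "head count must divide evenly"
    return {
        # 1. reduce hidden size, as in the diff
        "attn.embed_dim": hidden_size // world_size,
        # 2. reduce number of heads, using the renamed attribute
        "attn.num_attention_heads": n_head // world_size,
    }

# Example: GPT-J-6B (hidden_size=4096, n_head=16) split across 4 GPUs.
args = shard_attention_args(4096, 16, 4)
# The per-head dimension is unchanged: 1024 // 4 == 4096 // 16 == 256.
assert args["attn.embed_dim"] // args["attn.num_attention_heads"] == 4096 // 16
print(args)  # {'attn.embed_dim': 1024, 'attn.num_attention_heads': 4}
```

Because both quantities are divided by the same factor, each rank ends up with an ordinary attention module that simply owns fewer heads; the `AllReduceLinear` layer imported in the first hunk is, as its name suggests, what recombines the per-rank partial outputs.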