Skip to content

Commit b7e7135

Browse files
committed
fix bug + style
1 parent 64f426c commit b7e7135

File tree

3 files changed

+15
-5
lines changed

3 files changed

+15
-5
lines changed

parallelformers/policies/base/auto.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -639,16 +639,24 @@ def __init__(self):
639639
]
640640

641641
with suppress(Exception):
642-
from transformers.models.gptj.modeling_gptj import GPTJPreTrainedModel
642+
from transformers.models.gptj.modeling_gptj import (
643+
GPTJPreTrainedModel,
644+
)
645+
643646
from parallelformers.policies.gptj import GPTJPolicy
644647

645648
self.builtin_policies[GPTJPreTrainedModel] = [
646649
GPTJPolicy,
647650
]
648651

649652
with suppress(Exception):
650-
from transformers.models.megatron_bert import MegatronBertPreTrainedModel
651-
from parallelformers.policies.megtron_bert import MegatronBertPolicy
653+
from transformers.models.megatron_bert import (
654+
MegatronBertPreTrainedModel,
655+
)
656+
657+
from parallelformers.policies.megtron_bert import (
658+
MegatronBertPolicy,
659+
)
652660

653661
self.builtin_policies[MegatronBertPreTrainedModel] = [
654662
MegatronBertPolicy,

parallelformers/policies/gptj.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from transformers.models.gptj.modeling_gptj import GPTJBlock
1415

1516
from parallelformers.policies.base import Layer, Policy
1617
from parallelformers.utils import AllReduceLinear

parallelformers/policies/megtron_bert.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from transformers.models.megatron_bert.modeling_megatron_bert import \
16-
MegatronBertLayer
15+
from transformers.models.megatron_bert.modeling_megatron_bert import (
16+
MegatronBertLayer,
17+
)
1718

1819
from parallelformers.policies.base import Layer, Policy
1920
from parallelformers.transformers.modeling_bert import BertEmbeddings_

0 commit comments

Comments
 (0)