
Commit 28e5864

framoncg and dvrogozh authored

Add accelerate API support for Word Language Model example (#1345)

* Add accelerate API support for Word Language Model example
* Update README for Word Language Model example
* Update word_language_model/generate.py for consistency
* Update word_language_model/README.md
* Update README to change wording on acceleration devices
* Remove cuda conditional
* Fix flags for word_language_model in ci script

Co-authored-by: Dmitry Rogozhkin <[email protected]>

1 parent 5a4ca92 · commit 28e5864

File tree

4 files changed (+31, −48 lines)


run_python_examples.sh

Lines changed: 4 additions & 4 deletions

```diff
@@ -155,11 +155,11 @@ function vision_transformer() {
 }

 function word_language_model() {
-  uv run main.py --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
-  uv run generate.py $CUDA_FLAG --mps || error "word_language_model generate failed"
+  uv run main.py --epochs 1 --dry-run $ACCEL_FLAG || error "word_language_model failed"
+  uv run generate.py $ACCEL_FLAG || error "word_language_model generate failed"
   for model in "RNN_TANH" "RNN_RELU" "LSTM" "GRU" "Transformer"; do
-    uv run main.py --model $model --epochs 1 --dry-run $CUDA_FLAG --mps || error "word_language_model failed"
-    uv run generate.py $CUDA_FLAG --mps || error "word_language_model generate failed"
+    uv run main.py --model $model --epochs 1 --dry-run $ACCEL_FLAG || error "word_language_model failed"
+    uv run generate.py $ACCEL_FLAG || error "word_language_model generate failed"
   done
 }
```
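The `$ACCEL_FLAG` variable consumed above is defined elsewhere in the CI setup. As a sketch only, a flag like this could be derived from an environment switch; note that `USE_ACCEL` is a hypothetical name for illustration, not something this commit introduces:

```shell
#!/bin/sh
# Hypothetical sketch: derive an ACCEL_FLAG for the example scripts.
# USE_ACCEL is an assumed environment switch, not part of this commit.
if [ "${USE_ACCEL:-0}" = "1" ]; then
  ACCEL_FLAG="--accel"
else
  ACCEL_FLAG=""
fi
echo "ACCEL_FLAG=${ACCEL_FLAG}"
```

With `USE_ACCEL=1` exported, every `uv run main.py ... $ACCEL_FLAG` invocation would then pass `--accel`; otherwise the examples fall back to CPU.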

word_language_model/README.md

Lines changed: 14 additions & 12 deletions

````diff
@@ -4,15 +4,18 @@ This example trains a multi-layer RNN (Elman, GRU, or LSTM) or Transformer on a
 The trained model can then be used by the generate script to generate new text.

 ```bash
-python main.py --cuda --epochs 6           # Train a LSTM on Wikitext-2 with CUDA.
-python main.py --cuda --epochs 6 --tied    # Train a tied LSTM on Wikitext-2 with CUDA.
-python main.py --cuda --tied               # Train a tied LSTM on Wikitext-2 with CUDA for 40 epochs.
-python main.py --cuda --epochs 6 --model Transformer --lr 5
-                                           # Train a Transformer model on Wikitext-2 with CUDA.
+python main.py --accel --epochs 6          # Train a LSTM on Wikitext-2.
+python main.py --accel --epochs 6 --tied   # Train a tied LSTM on Wikitext-2.
+python main.py --accel --tied              # Train a tied LSTM on Wikitext-2 for 40 epochs.
+python main.py --accel --epochs 6 --model Transformer --lr 5
+                                           # Train a Transformer model on Wikitext-2.

-python generate.py                         # Generate samples from the default model checkpoint.
+python generate.py --accel                 # Generate samples from the default model checkpoint.
 ```

+> [!NOTE]
+> Example supports running on acceleration devices (CUDA, MPS, XPU)
+
 The model uses the `nn.RNN` module (and its sister modules `nn.GRU` and `nn.LSTM`) or Transformer module (`nn.TransformerEncoder` and `nn.TransformerEncoderLayer`) which will automatically use the cuDNN backend if run on CUDA with cuDNN installed.

 During training, if a keyboard interrupt (Ctrl-C) is received, training is stopped and the current model is evaluated against the test dataset.
@@ -35,8 +38,7 @@ optional arguments:
   --dropout DROPOUT     dropout applied to layers (0 = no dropout)
   --tied                tie the word embedding and softmax weights
   --seed SEED           random seed
-  --cuda                use CUDA
-  --mps                 enable GPU on macOS
+  --accel               use accelerator
   --log-interval N      report interval
   --save SAVE           path to save the final model
   --onnx-export ONNX_EXPORT
@@ -49,8 +51,8 @@ With these arguments, a variety of models can be tested.
 As an example, the following arguments produce slower but better models:

 ```bash
-python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40
-python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40 --tied
-python main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40
-python main.py --cuda --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied
+python main.py --accel --emsize 650 --nhid 650 --dropout 0.5 --epochs 40
+python main.py --accel --emsize 650 --nhid 650 --dropout 0.5 --epochs 40 --tied
+python main.py --accel --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40
+python main.py --accel --emsize 1500 --nhid 1500 --dropout 0.65 --epochs 40 --tied
 ```
````

word_language_model/generate.py

Lines changed: 6 additions & 16 deletions

```diff
@@ -22,30 +22,20 @@
                     help='number of words to generate')
 parser.add_argument('--seed', type=int, default=1111,
                     help='random seed')
-parser.add_argument('--cuda', action='store_true',
-                    help='use CUDA')
-parser.add_argument('--mps', action='store_true', default=False,
-                    help='enables macOS GPU training')
 parser.add_argument('--temperature', type=float, default=1.0,
                     help='temperature - higher will increase diversity')
 parser.add_argument('--log-interval', type=int, default=100,
                     help='reporting interval')
+parser.add_argument('--accel', action='store_true', default=False,
+                    help='use accelerator')
 args = parser.parse_args()

 # Set the random seed manually for reproducibility.
 torch.manual_seed(args.seed)
-if torch.cuda.is_available():
-    if not args.cuda:
-        print("WARNING: You have a CUDA device, so you should probably run with --cuda.")
-if torch.backends.mps.is_available():
-    if not args.mps:
-        print("WARNING: You have mps device, to enable macOS GPU run with --mps.")
-
-use_mps = args.mps and torch.backends.mps.is_available()
-if args.cuda:
-    device = torch.device("cuda")
-elif use_mps:
-    device = torch.device("mps")
+
+if args.accel and torch.accelerator.is_available():
+    device = torch.accelerator.current_accelerator()
+
 else:
     device = torch.device("cpu")
```
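The single-flag device selection introduced above can be summarized as a small decision function. This is a sketch for illustration: `select_device` is a made-up name, and in the real code the two inputs correspond to `args.accel` and `torch.accelerator.is_available()` (with the accelerator branch returning `torch.accelerator.current_accelerator()`):

```python
def select_device(accel_requested: bool, accel_available: bool) -> str:
    """Mirror the diff's device-selection logic as a plain function.

    accel_requested stands in for args.accel; accel_available stands in
    for torch.accelerator.is_available().
    """
    if accel_requested and accel_available:
        # real code: device = torch.accelerator.current_accelerator()
        return "accelerator"
    # no --accel flag, or no accelerator present: fall back to CPU
    return "cpu"

print(select_device(True, True))   # accelerator
print(select_device(True, False))  # cpu
print(select_device(False, True))  # cpu
```

One device-agnostic branch replaces the old three-way `--cuda`/`--mps`/CPU cascade, which is why both warning blocks could be deleted outright.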

word_language_model/main.py

Lines changed: 7 additions & 16 deletions

```diff
@@ -37,10 +37,6 @@
                     help='tie the word embedding and softmax weights')
 parser.add_argument('--seed', type=int, default=1111,
                     help='random seed')
-parser.add_argument('--cuda', action='store_true', default=False,
-                    help='use CUDA')
-parser.add_argument('--mps', action='store_true', default=False,
-                    help='enables macOS GPU training')
 parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                     help='report interval')
 parser.add_argument('--save', type=str, default='model.pt',
@@ -51,25 +47,20 @@
                     help='the number of heads in the encoder/decoder of the transformer model')
 parser.add_argument('--dry-run', action='store_true',
                     help='verify the code and the model')
+parser.add_argument('--accel', action='store_true', help='Enables accelerated training')
 args = parser.parse_args()

 # Set the random seed manually for reproducibility.
 torch.manual_seed(args.seed)
-if torch.cuda.is_available():
-    if not args.cuda:
-        print("WARNING: You have a CUDA device, so you should probably run with --cuda.")
-if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-    if not args.mps:
-        print("WARNING: You have mps device, to enable macOS GPU run with --mps.")
-
-use_mps = args.mps and torch.backends.mps.is_available()
-if args.cuda:
-    device = torch.device("cuda")
-elif use_mps:
-    device = torch.device("mps")
+
+if args.accel and torch.accelerator.is_available():
+    device = torch.accelerator.current_accelerator()
+
 else:
     device = torch.device("cpu")

+print("Using device:", device)
+
 ###############################################################################
 # Load data
 ###############################################################################
```
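The consolidated command-line surface can be exercised in isolation. The following is a minimal argparse sketch mirroring the new flag; the `--epochs` default of 40 matches the README's wording, and everything else here is illustrative rather than a copy of `main.py`:

```python
import argparse

# Minimal sketch of the flag surface after this commit: the old --cuda
# and --mps switches collapse into a single --accel switch.
parser = argparse.ArgumentParser(description='word language model (sketch)')
parser.add_argument('--accel', action='store_true',
                    help='Enables accelerated training')
parser.add_argument('--epochs', type=int, default=40,
                    help='upper epoch limit')

# Parsing a sample command line such as "main.py --accel --epochs 6":
args = parser.parse_args(['--accel', '--epochs', '6'])
print(args.accel, args.epochs)  # True 6
```

Because `--accel` is a plain `store_true` flag, CI can toggle acceleration by simply including or omitting one token, which is what the `$ACCEL_FLAG` change in `run_python_examples.sh` relies on.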
