You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Add accelerate API support for Word Language Model example (#1345)
* Add accelerate API support for Word Language Model example
* Update README for Word Language Model example
* Update word_language_model/generate.py for consistency
Co-authored-by: Dmitry Rogozhkin <[email protected]>
* Update word_language_model/README.md
Co-authored-by: Dmitry Rogozhkin <[email protected]>
* Update README to change wording on acceleration devices
* Remove cuda conditional
* Fix flags for word_language_model in ci script
---------
Co-authored-by: Dmitry Rogozhkin <[email protected]>
python generate.py # Generate samples from the default model checkpoint.
13
+
python generate.py --accel# Generate samples from the default model checkpoint.
14
14
```
15
15
16
+
> [!NOTE]
17
+
> Example supports running on acceleration devices (CUDA, MPS, XPU)
18
+
16
19
The model uses the `nn.RNN` module (and its sister modules `nn.GRU` and `nn.LSTM`) or Transformer module (`nn.TransformerEncoder` and `nn.TransformerEncoderLayer`) which will automatically use the cuDNN backend if run on CUDA with cuDNN installed.
17
20
18
21
During training, if a keyboard interrupt (Ctrl-C) is received, training is stopped and the current model is evaluated against the test dataset.
@@ -35,8 +38,7 @@ optional arguments:
35
38
--dropout DROPOUT dropout applied to layers (0 = no dropout)
36
39
--tied tie the word embedding and softmax weights
37
40
--seed SEED random seed
38
-
--cuda use CUDA
39
-
--mps enable GPU on macOS
41
+
--accel use accelerator
40
42
--log-interval N report interval
41
43
--save SAVE path to save the final model
42
44
--onnx-export ONNX_EXPORT
@@ -49,8 +51,8 @@ With these arguments, a variety of models can be tested.
49
51
As an example, the following arguments produce slower but better models:
0 commit comments