We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 9de3197 + 8abf6c4 commit d2f7bafCopy full SHA for d2f7baf
MaxText/train.py
@@ -948,7 +948,6 @@ def train_loop(config, state=None):
948
def main(argv: Sequence[str]) -> None:
949
pathwaysutils.initialize()
950
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
951
- jax.config.update("jax_enable_compilation_cache", os.environ.get("JAX_ENABLE_COMPILATION_CACHE", True))
952
# TF allocates extraneous GPU memory when using TFDS data
953
# this leads to CUDA OOMs. WAR for now is to hide GPUs from TF
954
tf.config.set_visible_devices([], "GPU")
0 commit comments