From 1853609b86e1a565f4a27d7e7b83950311966047 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B7=A6=E4=BA=91=E9=BE=99?= Date: Sun, 31 Mar 2024 20:07:14 +0800 Subject: [PATCH] add support for mac m1/m2 chips mps gpu need upgrade onnx and protobuf, add setup default device for torch, so user can use command params "--torch-device mps:0" to launch gpu training on mac --- ml-agents/mlagents/torch_utils/torch.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml-agents/mlagents/torch_utils/torch.py b/ml-agents/mlagents/torch_utils/torch.py index 24dc45cca3..ae2752de89 100644 --- a/ml-agents/mlagents/torch_utils/torch.py +++ b/ml-agents/mlagents/torch_utils/torch.py @@ -53,6 +53,8 @@ def set_torch_config(torch_settings: TorchSettings) -> None: if _device.type == "cuda": torch.set_default_tensor_type(torch.cuda.FloatTensor) + elif _device.type == 'mps': + torch.set_default_device(device_str) else: torch.set_default_tensor_type(torch.FloatTensor) logger.debug(f"default Torch device: {_device}")