Skip to content

Commit b942406

Browse files
authored
Add Intel Gaudi as a supported device. (#2888)
1 parent b22a3ae commit b942406

File tree

6 files changed

+36
-7
lines changed

6 files changed

+36
-7
lines changed

recipes/eleuther_eval.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,9 @@ def __init__(
318318
self._batch_size = batch_size
319319
self._dtype = dtype
320320
self._enable_kv_cache = enable_kv_cache
321+
# Set device explicitely here since HPU is not included in
322+
# `device_list` in `HFLM` class
323+
self._device = torch.device(device)
321324

322325
@property
323326
def model(self):

torchtune/training/_distributed.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
_DISTRIBUTED_STATE_DICT_API_IS_AVAILABLE = False
5656

5757
# Valid backends for logging memory stats
58-
VALID_BACKENDS_FOR_MEMORY_STATS = ("cuda", "xpu", "npu")
58+
VALID_BACKENDS_FOR_MEMORY_STATS = ("cuda", "xpu", "npu", "hpu")
5959

6060

6161
@dataclass
@@ -221,7 +221,9 @@ def _broadcast_tensor(tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
221221
elif dist.get_backend() == "xccl":
222222
tensor = tensor.to(get_device("xpu"))
223223
elif dist.get_backend() == "hccl":
224-
tensor = tensor.to(get_device("npu"))
224+
# Since NPU and HPU both have same backend names
225+
# infer device based on environment here.
226+
tensor = tensor.to(get_device())
225227
dist.broadcast(tensor, src=src, group=None)
226228
return tensor.to(device)
227229
else:

torchtune/training/_profiler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def setup_torch_profiler(
180180
cpu: bool = True,
181181
cuda: bool = True,
182182
xpu: bool = True,
183+
hpu: bool = False,
183184
profile_memory: bool = DEFAULT_TRACE_OPTS["profile_memory"],
184185
with_stack: bool = DEFAULT_TRACE_OPTS["with_stack"],
185186
record_shapes: bool = DEFAULT_TRACE_OPTS["record_shapes"],
@@ -248,6 +249,7 @@ def setup_torch_profiler(
248249
cpu (bool): Enable cpu profiling. Default is True.
249250
cuda (bool): Enable cuda profiling. Default is True.
250251
xpu (bool): Enable xpu profiling. Default is True.
252+
hpu (bool): Enable hpu profiling. Default is False.
251253
profile_memory (bool): Profile memory usage. Default is False.
252254
with_stack (bool): Profile stack. Default is False.
253255
record_shapes (bool): Record shapes. Default is True.
@@ -274,6 +276,8 @@ def setup_torch_profiler(
274276
activities.append(torch.profiler.ProfilerActivity.CUDA)
275277
if xpu:
276278
activities.append(torch.profiler.ProfilerActivity.XPU)
279+
if hpu:
280+
activities.append(torch.profiler.ProfilerActivity.HPU)
277281
if len(activities) == 0:
278282
_warn("No activities specified, defaulting to CPU + CUDA")
279283
activities = DEFAULT_PROFILER_ACTIVITIES
@@ -371,6 +375,7 @@ def setup_torch_profiler(
371375
"cpu": cpu,
372376
"cuda": cuda,
373377
"xpu": xpu,
378+
"hpu": hpu,
374379
"profile_memory": profile_memory,
375380
"with_stack": with_stack,
376381
"record_shapes": record_shapes,

torchtune/training/memory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,10 @@ def cleanup_before_training() -> None:
4848
Call gc collect, empty device cache, and reset peak memory stats.
4949
"""
5050
gc.collect()
51-
get_torch_device_namespace().empty_cache()
51+
from torchtune.utils._device import is_hpu_available
52+
53+
if not is_hpu_available:
54+
get_torch_device_namespace().empty_cache()
5255
get_torch_device_namespace().reset_peak_memory_stats()
5356

5457

torchtune/training/precision.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111

1212
from torchtune.utils import get_logger
13-
from torchtune.utils._device import is_npu_available
13+
from torchtune.utils._device import is_hpu_available, is_npu_available
1414

1515
log = get_logger()
1616

@@ -69,7 +69,8 @@ def verify_bf16_support() -> bool:
6969
mps_support = torch.backends.mps.is_available() and torch.backends.mps.is_built()
7070
npu_support = is_npu_available and torch.npu.is_bf16_supported()
7171
xpu_support = torch.xpu.is_available() and torch.xpu.is_bf16_supported()
72-
return cuda_support or mps_support or npu_support or xpu_support
72+
hpu_support = is_hpu_available and torch.hpu.is_bf16_supported()
73+
return cuda_support or mps_support or npu_support or xpu_support or hpu_support
7374

7475

7576
def get_dtype(

torchtune/utils/_device.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ def is_torch_npu_available() -> bool:
4747
is_npu_available = is_torch_npu_available()
4848

4949

50+
def is_torch_hpu_available() -> bool:
51+
"""Check the availability of HPU"""
52+
try:
53+
import habana_frameworks.torch # noqa: F401
54+
55+
return torch.hpu.is_available()
56+
except ImportError:
57+
return False
58+
59+
60+
is_hpu_available = is_torch_hpu_available()
61+
62+
5063
def _get_local_rank() -> Optional[int]:
5164
"""Function that gets the local rank from the environment.
5265
@@ -78,7 +91,6 @@ def _setup_device(device: torch.device) -> torch.device:
7891
device_type = device_support.device_type
7992
device_name = device_support.device_name
8093
torch_device = get_torch_device_namespace()
81-
8294
if device.index is None:
8395
device = torch.device(type=device_type, index=local_rank)
8496

@@ -107,6 +119,8 @@ def _get_device_type_from_env() -> str:
107119
device = "cuda"
108120
elif is_npu_available:
109121
device = "npu"
122+
elif is_hpu_available:
123+
device = "hpu"
110124
elif torch.xpu.is_available():
111125
device = "xpu"
112126
elif torch.mps.is_available():
@@ -171,7 +185,7 @@ def get_device(device: Optional[str] = None) -> torch.device:
171185
if device is None:
172186
device = _get_device_type_from_env()
173187
device = torch.device(device)
174-
if device.type in ["cuda", "npu", "xpu"]:
188+
if device.type in ["cuda", "npu", "xpu", "hpu"]:
175189
device = _setup_device(device)
176190
_validate_device_from_env(device)
177191
return device
@@ -220,6 +234,7 @@ class DeviceSupport(Enum):
220234
NPU = ("npu", "NPU", "hccl")
221235
XPU = ("xpu", "XPU", "ccl")
222236
MPS = ("mps", "MPS", "gloo")
237+
HPU = ("hpu", "HPU", "hccl")
223238

224239
def __init__(
225240
self,

0 commit comments

Comments
 (0)