Skip to content

Add hashing for Topology #3045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions torchrec/distributed/planner/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ParameterConstraints,
Shard,
ShardingOption,
Topology,
)
from torchrec.distributed.types import (
BoundsCheckMode,
Expand Down Expand Up @@ -214,6 +215,54 @@ def test_module_pooled_mch_ec(self) -> None:
self.assertEqual(sharding_option.is_pooled, False)


class TestTopologyHash(unittest.TestCase):
def test_hash_equality(self) -> None:
# Create two identical Topology instances
topology1 = Topology(
world_size=2,
compute_device="cuda",
hbm_cap=1024 * 1024 * 2,
local_world_size=2,
)

topology2 = Topology(
world_size=2,
compute_device="cuda",
hbm_cap=1024 * 1024 * 2,
local_world_size=2,
)

# Verify that the hash values are equal
self.assertEqual(
topology1._hash(),
topology2._hash(),
"Hashes should be equal for identical Topology instances",
)

def test_hash_inequality(self) -> None:
# Create two different Topology instances
topology1 = Topology(
world_size=2,
compute_device="cuda",
hbm_cap=1024 * 1024 * 2,
local_world_size=2,
)

topology2 = Topology(
world_size=4, # Different world_size
compute_device="cuda",
hbm_cap=1024 * 1024 * 2,
local_world_size=2,
)

# Verify that the hash values are different
self.assertNotEqual(
topology1._hash(),
topology2._hash(),
"Hashes should be different for different Topology instances",
)


class TestParameterConstraintsHash(unittest.TestCase):

def test_hash_equality(self) -> None:
Expand Down
39 changes: 39 additions & 0 deletions torchrec/distributed/planner/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# pyre-strict

import abc
import hashlib
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
Expand Down Expand Up @@ -248,6 +249,10 @@ def get_bw(


class Topology:
"""
Representation of a network of devices in a cluster.
"""

def __init__(
self,
world_size: int,
Expand Down Expand Up @@ -396,6 +401,40 @@ def __repr__(self) -> str:
topology_repr += str(self._comms_bandwidths) + "\n"
return topology_repr

def _hash(self) -> str:
"""
Compute a consistent hash value for this Topology instance.

Returns:
str: A hash value for this Topology instance.
"""

# Compute hbms and ddrs from the decives
hbms = [device.storage.hbm for device in self._devices]
ddrs = [device.storage.ddr for device in self._devices]

# Combine all attributes into a hashable tuple
hashable_list = [
self._world_size,
self._compute_device,
hbms,
ddrs,
self._local_world_size,
self._hbm_mem_bw,
self._ddr_mem_bw,
self._hbm_to_ddr_mem_bw,
self._comms_bandwidths.intra_host_bw,
self._comms_bandwidths.inter_host_bw,
self._bwd_compute_multiplier,
self._weighted_feature_bwd_compute_multiplier,
self._uneven_sharding_perf_multiplier,
]

serialized_list = str(hashable_list).encode("utf-8")
hash_object = hashlib.sha256(serialized_list)
hash_digest = hash_object.hexdigest()
return hash_digest


# ---- INPUT / OUTPUT ----- #

Expand Down
Loading