diff --git a/torchrec/distributed/planner/tests/test_types.py b/torchrec/distributed/planner/tests/test_types.py index ce84b3721..cd9d2a650 100644 --- a/torchrec/distributed/planner/tests/test_types.py +++ b/torchrec/distributed/planner/tests/test_types.py @@ -18,6 +18,7 @@ ParameterConstraints, Shard, ShardingOption, + Topology, ) from torchrec.distributed.types import ( BoundsCheckMode, @@ -214,6 +215,54 @@ def test_module_pooled_mch_ec(self) -> None: self.assertEqual(sharding_option.is_pooled, False) +class TestTopologyHash(unittest.TestCase): + def test_hash_equality(self) -> None: + # Create two Topology instances from identical constructor arguments + topology1 = Topology( + world_size=2, + compute_device="cuda", + hbm_cap=1024 * 1024 * 2, + local_world_size=2, + ) + + topology2 = Topology( + world_size=2, + compute_device="cuda", + hbm_cap=1024 * 1024 * 2, + local_world_size=2, + ) + + # Verify that the hash values are equal + self.assertEqual( + topology1._hash(), + topology2._hash(), + "Hashes should be equal for identical Topology instances", + ) + + def test_hash_inequality(self) -> None: + # Create two Topology instances that differ only in world_size + topology1 = Topology( + world_size=2, + compute_device="cuda", + hbm_cap=1024 * 1024 * 2, + local_world_size=2, + ) + + topology2 = Topology( + world_size=4, # Different world_size + compute_device="cuda", + hbm_cap=1024 * 1024 * 2, + local_world_size=2, + ) + + # Verify that the hash values are different + self.assertNotEqual( + topology1._hash(), + topology2._hash(), + "Hashes should be different for different Topology instances", + ) + + class TestParameterConstraintsHash(unittest.TestCase): def test_hash_equality(self) -> None: diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 3689c5724..ba8ea59c8 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -8,6 +8,7 @@ # pyre-strict import abc +import hashlib from copy import deepcopy from dataclasses import 
dataclass, field from enum import Enum @@ -248,6 +249,10 @@ def get_bw( class Topology: + """ + Representation of a network of devices in a cluster. + """ + def __init__( self, world_size: int, @@ -396,6 +401,40 @@ def __repr__(self) -> str: topology_repr += str(self._comms_bandwidths) + "\n" return topology_repr + def _hash(self) -> str: + """ + Compute a consistent hash value for this Topology instance. + + Returns: + str: The SHA-256 hex digest of the string form of this Topology's attributes. + """ + + # Collect the per-device hbm and ddr storage values from the devices + hbms = [device.storage.hbm for device in self._devices] + ddrs = [device.storage.ddr for device in self._devices] + + # Combine the attributes below into a list whose string form is hashed + hashable_list = [ + self._world_size, + self._compute_device, + hbms, + ddrs, + self._local_world_size, + self._hbm_mem_bw, + self._ddr_mem_bw, + self._hbm_to_ddr_mem_bw, + self._comms_bandwidths.intra_host_bw, + self._comms_bandwidths.inter_host_bw, + self._bwd_compute_multiplier, + self._weighted_feature_bwd_compute_multiplier, + self._uneven_sharding_perf_multiplier, + ] + + serialized_list = str(hashable_list).encode("utf-8") + hash_object = hashlib.sha256(serialized_list) + hash_digest = hash_object.hexdigest() + return hash_digest + # ---- INPUT / OUTPUT ----- #