
Commit

2025-01-23 nightly release (4d7b7ff)
pytorchbot committed Jan 23, 2025
1 parent 61a849d commit b227f54
Showing 9 changed files with 58 additions and 127 deletions.
4 changes: 4 additions & 0 deletions .github/scripts/install_fbgemm.sh
@@ -15,6 +15,10 @@ if [[ $CU_VERSION = cu* ]]; then
echo "[NOVA] Setting LD_LIBRARY_PATH ..."
conda env config vars set -p ${CONDA_ENV} \
LD_LIBRARY_PATH="/usr/local/lib:${CUDA_HOME}/lib64:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"
else
echo "[NOVA] Setting LD_LIBRARY_PATH ..."
conda env config vars set -p ${CONDA_ENV} \
LD_LIBRARY_PATH="/usr/local/lib:${CONDA_ENV}/lib:${LD_LIBRARY_PATH}"
fi

if [ "$CHANNEL" = "nightly" ]; then
13 changes: 2 additions & 11 deletions torchrec/distributed/embedding.py
@@ -26,7 +26,6 @@
)

import torch
from tensordict import TensorDict
from torch import distributed as dist, nn
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
@@ -91,7 +90,6 @@
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1200,15 +1198,8 @@ def _compute_sequence_vbe_context(
def input_dist(
self,
ctx: EmbeddingCollectionContext,
features: TypeUnion[KeyedJaggedTensor, TensorDict],
features: KeyedJaggedTensor,
) -> Awaitable[Awaitable[KJTList]]:
need_permute: bool = True
if isinstance(features, TensorDict):
feature_keys = list(features.keys()) # pyre-ignore[6]
if self._features_order:
feature_keys = [feature_keys[i] for i in self._features_order]
need_permute = False
features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6]
if self._has_uninitialized_input_dist:
self._create_input_dist(input_feature_names=features.keys())
self._has_uninitialized_input_dist = False
@@ -1218,7 +1209,7 @@ def input_dist(
unpadded_features = features
features = pad_vbe_kjt_lengths(unpadded_features)

if need_permute and self._features_order:
if self._features_order:
features = features.permute(
self._features_order,
# pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
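Note on the embedding.py change above: input_dist now accepts only a KeyedJaggedTensor and always applies the configured feature permutation. Below is a minimal sketch of the KeyedJaggedTensor.permute call it relies on; the keys, values, and permutation indices are illustrative, not taken from the commit.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features, batch size 1: "f1" carries two ids, "f2" carries one.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([10, 11, 20]),
    lengths=torch.tensor([2, 1]),
)

# Reorder features the way input_dist does with self._features_order.
reordered = kjt.permute([1, 0])
print(reordered.keys())  # ['f2', 'f1']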
16 changes: 4 additions & 12 deletions torchrec/distributed/embeddingbag.py
@@ -27,7 +27,6 @@

import torch
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from tensordict import TensorDict
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharded_tensor import TensorProperties
@@ -95,7 +94,6 @@
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -658,7 +656,9 @@ def __init__(
self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
# to support mean pooling callback hook
self._has_mean_pooling_callback: bool = (
PoolingType.MEAN.value in self._pooling_type_to_rs_features
True
if PoolingType.MEAN.value in self._pooling_type_to_rs_features
else False
)
self._dim_per_key: Optional[torch.Tensor] = None
self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,16 +1189,8 @@ def _create_inverse_indices_permute_indices(

# pyre-ignore [14]
def input_dist(
self,
ctx: EmbeddingBagCollectionContext,
features: Union[KeyedJaggedTensor, TensorDict],
self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
) -> Awaitable[Awaitable[KJTList]]:
if isinstance(features, TensorDict):
feature_keys = list(features.keys()) # pyre-ignore[6]
if len(self._features_order) > 0:
feature_keys = [feature_keys[i] for i in self._features_order]
self._has_features_permute = False # feature_keys are in order
features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6]
ctx.variable_batch_per_feature = features.variable_stride_per_key()
ctx.inverse_indices = features.inverse_indices_or_none()
if self._has_uninitialized_input_dist:
35 changes: 9 additions & 26 deletions torchrec/distributed/model_parallel.py
@@ -770,7 +770,7 @@ def _create_process_groups(
) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
"""
Creates process groups for sharding and replication, the process groups
are created in the same exact order on all ranks as per `dist.new_group` API.
are created using the DeviceMesh API.
Args:
global_rank (int): The global rank of the current process.
@@ -781,44 +781,27 @@
Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh,
replication process group, and allreduce process group.
"""
# TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a
peer_matrix = []
sharding_pg, replica_pg = None, None
step = world_size // local_size

my_group_rank = global_rank % step
for group_rank in range(world_size // local_size):
peers = [step * r + group_rank for r in range(local_size)]
backend = dist.get_backend(self._pg)
curr_pg = dist.new_group(backend=backend, ranks=peers)
peer_matrix.append(peers)
if my_group_rank == group_rank:
logger.warning(
f"[Connection] 2D sharding_group: [{global_rank}] -> [{peers}]"
)
sharding_pg = curr_pg
assert sharding_pg is not None, "sharding_pg is not initialized!"
dist.barrier()

my_inter_rank = global_rank // step
for inter_rank in range(local_size):
peers = [inter_rank * step + r for r in range(step)]
backend = dist.get_backend(self._pg)
curr_pg = dist.new_group(backend=backend, ranks=peers)
if my_inter_rank == inter_rank:
logger.warning(
f"[Connection] 2D replica_group: [{global_rank}] -> [{peers}]"
)
replica_pg = curr_pg
assert replica_pg is not None, "replica_pg is not initialized!"
dist.barrier()

mesh = DeviceMesh(
device_type=self._device.type,
mesh=peer_matrix,
mesh_dim_names=("replicate", "shard"),
)
logger.warning(f"[Connection] 2D Device Mesh created: {mesh}")
sharding_pg = mesh.get_group(mesh_dim="shard")
logger.warning(
f"[Connection] 2D sharding_group: [{global_rank}] -> [{mesh['shard']}]"
)
replica_pg = mesh.get_group(mesh_dim="replicate")
logger.warning(
f"[Connection] 2D replica_group: [{global_rank}] -> [{mesh['replicate']}]"
)

return mesh, sharding_pg, replica_pg

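For context on the model_parallel.py rewrite above, here is a minimal sketch of the DeviceMesh-based construction it uses: ranks are laid out as a (replicate, shard) matrix and both process groups are taken from the mesh rather than built with repeated dist.new_group calls. The standalone helper name and the import path are assumptions for illustration; the exact DeviceMesh import location varies across PyTorch versions.

# Assumes torch.distributed has already been initialized across world_size ranks.
from torch.distributed.device_mesh import DeviceMesh  # import path may differ by PyTorch version


def build_2d_mesh(device_type: str, world_size: int, local_size: int):
    step = world_size // local_size
    # Each row of the matrix is one sharding group; each column is one replication group.
    peer_matrix = [
        [step * r + group_rank for r in range(local_size)]
        for group_rank in range(step)
    ]
    mesh = DeviceMesh(
        device_type=device_type,
        mesh=peer_matrix,
        mesh_dim_names=("replicate", "shard"),
    )
    sharding_pg = mesh.get_group(mesh_dim="shard")
    replica_pg = mesh.get_group(mesh_dim="replicate")
    return mesh, sharding_pg, replica_pg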
32 changes: 6 additions & 26 deletions torchrec/distributed/test_utils/test_sharding.py
@@ -147,7 +147,6 @@ def gen_model_and_input(
long_indices: bool = True,
global_constant_batch: bool = False,
num_inputs: int = 1,
input_type: str = "kjt", # "kjt" or "td"
) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
torch.manual_seed(0)
if dedup_feature_names:
@@ -178,9 +177,9 @@
feature_processor_modules=feature_processor_modules,
)
inputs = []
if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
for _ in range(num_inputs):
inputs.append(
for _ in range(num_inputs):
inputs.append(
(
cast(VariableBatchModelInputCallable, generate)(
average_batch_size=batch_size,
world_size=world_size,
@@ -189,26 +188,8 @@
weighted_tables=weighted_tables or [],
global_constant_batch=global_constant_batch,
)
)
elif generate == ModelInput.generate:
for _ in range(num_inputs):
inputs.append(
ModelInput.generate(
world_size=world_size,
tables=tables,
dedup_tables=dedup_tables,
weighted_tables=weighted_tables or [],
num_float_features=num_float_features,
variable_batch_size=variable_batch_size,
batch_size=batch_size,
long_indices=long_indices,
input_type=input_type,
)
)
else:
for _ in range(num_inputs):
inputs.append(
cast(ModelInputCallable, generate)(
if generate == ModelInput.generate_variable_batch_input
else cast(ModelInputCallable, generate)(
world_size=world_size,
tables=tables,
dedup_tables=dedup_tables,
Expand All @@ -219,6 +200,7 @@ def gen_model_and_input(
long_indices=long_indices,
)
)
)
return (model, inputs)


@@ -315,7 +297,6 @@ def sharding_single_rank_test(
global_constant_batch: bool = False,
world_size_2D: Optional[int] = None,
node_group_size: Optional[int] = None,
input_type: str = "kjt", # "kjt" or "td"
) -> None:
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
# Generate model & inputs.
@@ -338,7 +319,6 @@
batch_size=batch_size,
feature_processor_modules=feature_processor_modules,
global_constant_batch=global_constant_batch,
input_type=input_type,
)
global_model = global_model.to(ctx.device)
global_input = inputs[0][0].to(ctx.device)
41 changes: 0 additions & 41 deletions torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -376,44 +376,3 @@ def _test_sharding(
variable_batch_per_feature=variable_batch_per_feature,
global_constant_batch=True,
)


@skip_if_asan_class
class TDSequenceModelParallelTest(SequenceModelParallelTest):

def test_sharding_variable_batch(self) -> None:
pass

def _test_sharding(
self,
sharders: List[TestEmbeddingCollectionSharder],
backend: str = "gloo",
world_size: int = 2,
local_size: Optional[int] = None,
constraints: Optional[Dict[str, ParameterConstraints]] = None,
model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
qcomms_config: Optional[QCommsConfig] = None,
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
] = None,
variable_batch_size: bool = False,
variable_batch_per_feature: bool = False,
) -> None:
self._run_multi_process_test(
callable=sharding_single_rank_test,
world_size=world_size,
local_size=local_size,
model_class=model_class,
tables=self.tables,
embedding_groups=self.embedding_groups,
sharders=sharders,
optim=EmbOptimType.EXACT_SGD,
backend=backend,
constraints=constraints,
qcomms_config=qcomms_config,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
variable_batch_per_feature=variable_batch_per_feature,
global_constant_batch=True,
input_type="td",
)
4 changes: 2 additions & 2 deletions
@@ -160,7 +160,7 @@ def main(

tables = [
EmbeddingBagConfig(
num_embeddings=max(i + 1, 100) * 1000,
num_embeddings=(i + 1) * 1000,
embedding_dim=dim_emb,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=max(i + 1, 100) * 1000,
num_embeddings=(i + 1) * 1000,
embedding_dim=dim_emb,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
10 changes: 2 additions & 8 deletions torchrec/modules/embedding_modules.py
@@ -19,7 +19,6 @@
pooling_type_to_str,
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt


@torch.fx.wrap
@@ -219,10 +218,7 @@ def __init__(
self._feature_names: List[List[str]] = [table.feature_names for table in tables]
self.reset_parameters()

def forward(
self,
features: KeyedJaggedTensor, # can also take TensorDict as input
) -> KeyedTensor:
def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
@@ -233,7 +229,6 @@ def forward(
KeyedTensor
"""
flat_feature_names: List[str] = []
features = maybe_td_to_kjt(features, None)
for names in self._feature_names:
flat_feature_names.extend(names)
inverse_indices = reorder_inverse_indices(
@@ -453,7 +448,7 @@ def __init__( # noqa C901

def forward(
self,
features: KeyedJaggedTensor, # can also take TensorDict as input
features: KeyedJaggedTensor,
) -> Dict[str, JaggedTensor]:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -466,7 +461,6 @@
Dict[str, JaggedTensor]
"""

features = maybe_td_to_kjt(features, None)
feature_embeddings: Dict[str, JaggedTensor] = {}
jt_dict: Dict[str, JaggedTensor] = features.to_dict()
for i, emb_module in enumerate(self.embeddings.values()):
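With the TensorDict input path removed in embedding_modules.py, the documented contract is a KeyedJaggedTensor argument. Below is a minimal usage sketch under that contract; the table size, embedding dimension, and feature values are illustrative.

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            num_embeddings=1000,
            embedding_dim=8,
            name="table_0",
            feature_names=["feature_0"],
        )
    ]
)

# Two bags for "feature_0": the first holds one id, the second holds two.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0"],
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([1, 2]),
)

pooled = ebc(features)            # KeyedTensor keyed by feature name
print(pooled["feature_0"].shape)  # torch.Size([2, 8])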
30 changes: 29 additions & 1 deletion torchrec/schema/utils.py
@@ -8,6 +8,32 @@
# pyre-strict

import inspect
import typing
from typing import Any


def _is_annot_compatible(prev: object, curr: object) -> bool:
if prev == curr:
return True

if not (prev_origin := typing.get_origin(prev)):
return False
if not (curr_origin := typing.get_origin(curr)):
return False

if prev_origin != curr_origin:
return False

prev_args = typing.get_args(prev)
curr_args = typing.get_args(curr)
if len(prev_args) != len(curr_args):
return False

for prev_arg, curr_arg in zip(prev_args, curr_args):
if not _is_annot_compatible(prev_arg, curr_arg):
return False

return True


def is_signature_compatible(
@@ -84,6 +110,8 @@ def is_signature_compatible(
return False

# TODO: Account for Union Types?
if current_signature.return_annotation != previous_signature.return_annotation:
if not _is_annot_compatible(
previous_signature.return_annotation, current_signature.return_annotation
):
return False
return True
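
A quick illustration of what the new _is_annot_compatible helper accepts: identical annotations, or parameterized generics whose origins and type arguments match recursively. The sample annotations are made up for demonstration; the import assumes this commit's torchrec.schema.utils.

from typing import Dict, List, Optional

from torchrec.schema.utils import _is_annot_compatible

assert _is_annot_compatible(List[int], List[int])
assert _is_annot_compatible(Dict[str, List[int]], Dict[str, List[int]])
assert not _is_annot_compatible(List[int], List[str])  # type arguments differ
assert not _is_annot_compatible(Optional[int], int)    # Union origin vs. plain type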
