Revert D65103519 (#2700)

Summary:
Pull Request resolved: #2700

This diff reverts D65103519
Depends on D68528333
Need to revert this to fix a lowering import error that is breaking APS tests

Reviewed By: PoojaAg18

Differential Revision: D68528363

fbshipit-source-id: 111a14e8c90e27be2860c852702a8dc8bf3543ff
kausv authored and facebook-github-bot committed Jan 23, 2025
1 parent d5a991b commit 4d7b7ff
Showing 3 changed files with 6 additions and 16 deletions.
16 changes: 4 additions & 12 deletions torchrec/distributed/embeddingbag.py
@@ -27,7 +27,6 @@

 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
-from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharded_tensor import TensorProperties
@@ -95,7 +94,6 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -658,7 +656,9 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            True
+            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            else False
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,16 +1189,8 @@ def _create_inverse_indices_permute_indices(

     # pyre-ignore [14]
     def input_dist(
-        self,
-        ctx: EmbeddingBagCollectionContext,
-        features: Union[KeyedJaggedTensor, TensorDict],
+        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
-        if isinstance(features, TensorDict):
-            feature_keys = list(features.keys())  # pyre-ignore[6]
-            if len(self._features_order) > 0:
-                feature_keys = [feature_keys[i] for i in self._features_order]
-            self._has_features_permute = False  # feature_keys are in order
-            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         ctx.variable_batch_per_feature = features.variable_stride_per_key()
         ctx.inverse_indices = features.inverse_indices_or_none()
         if self._has_uninitialized_input_dist:
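For context, the branch deleted from input_dist converted a TensorDict input into the KeyedJaggedTensor the rest of the method expects, optionally applying the module's feature order first. A minimal standalone sketch of that conversion pattern; the helper name td_to_kjt_sketch and the one-bag-per-key layout are illustrative assumptions, while the real code used the torchrec-internal maybe_td_to_kjt shown in the diff:

import torch
from typing import List
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

def td_to_kjt_sketch(td: TensorDict, keys: List[str]) -> KeyedJaggedTensor:
    # Concatenate the per-key value tensors and record each bag's length,
    # which is the shape any TensorDict -> KJT conversion has to produce.
    values = torch.cat([td[k] for k in keys])
    lengths = torch.tensor([td[k].numel() for k in keys])
    return KeyedJaggedTensor.from_lengths_sync(keys=keys, values=values, lengths=lengths)

td = TensorDict({"f1": torch.tensor([1, 2, 3]), "f2": torch.tensor([4])}, batch_size=[])
kjt = td_to_kjt_sketch(td, ["f1", "f2"])  # batch size 1, one bag per key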
4 changes: 2 additions & 2 deletions
@@ -160,7 +160,7 @@ def main(

     tables = [
         EmbeddingBagConfig(
-            num_embeddings=max(i + 1, 100) * 1000,
+            num_embeddings=(i + 1) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@
     ]
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=max(i + 1, 100) * 1000,
+            num_embeddings=(i + 1) * 1000,
             embedding_dim=dim_emb,
             name="weighted_table_" + str(i),
             feature_names=["weighted_feature_" + str(i)],
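The benchmark change is a large reduction in table size: (i + 1) * 1000 grows linearly from 1,000 rows, while the reverted max(i + 1, 100) * 1000 floored every table at 100,000 rows. A quick check of the first few tables:

# Rows per table under each expression (illustration only).
for i in range(3):
    print(i, max(i + 1, 100) * 1000, (i + 1) * 1000)
# 0 100000 1000
# 1 100000 2000
# 2 100000 3000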
2 changes: 0 additions & 2 deletions torchrec/modules/embedding_modules.py
@@ -19,7 +19,6 @@
     pooling_type_to_str,
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 
 @torch.fx.wrap
@@ -230,7 +229,6 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
             KeyedTensor
         """
         flat_feature_names: List[str] = []
-        features = maybe_td_to_kjt(features, None)
         for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(
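With the maybe_td_to_kjt call gone, forward again accepts only a KeyedJaggedTensor, so callers that had been passing a TensorDict must construct the KJT themselves. A minimal construction via torchrec's public API; the feature names and values here are made up:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features, batch size 2: lengths holds len(keys) * batch_size entries,
# and values concatenates all bags in key-major order.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0", "feature_1"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 1]),
)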
