Commit 3c683af

jd7-tr authored and facebook-github-bot committed
Include stride_per_key_per_rank in KJT's PyTree flatten/unflatten logic (#2903)
Summary:

# Context

* The torchrec IR serializer currently can't handle the variable-batch use case.
* `torch.export` only captures the keys, values, lengths, weights, and offsets of a KJT; variable-batch-related parameters such as `stride_per_rank` or `inverse_indices` are ignored.
* This test case (previously an expected failure) covers the variable-batch KJT scenario, verifying that the serialize/deserialize EBC use case works with KJTs that have a variable batch size.

# Ref

Differential Revision: D73051959
1 parent f35befa commit 3c683af
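
In effect, a variable-batch KJT now survives a pytree round trip. The sketch below mirrors the new unit test in this commit (it assumes a torchrec build that includes this change):

```python
import torch
from torch.utils import _pytree as pytree
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# A variable-batch KJT: per-key strides differ, and inverse_indices maps
# deduplicated rows back to the full batch.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank=[[2], [1]],
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

leaves, spec = pytree.tree_flatten(kjt)
rebuilt = pytree.tree_unflatten(leaves, spec)

# Before this commit the spec context carried only the keys, so the
# attribute below was dropped on unflatten.
assert rebuilt.stride_per_key_per_rank() == kjt.stride_per_key_per_rank()
```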

File tree

4 files changed: +82 -38 lines changed

torchrec/ir/tests/test_serializer.py

Lines changed: 20 additions & 30 deletions
@@ -206,8 +206,14 @@ def forward(
             num_embeddings=10,
             feature_names=["f2"],
         )
+        config3 = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
         ebc = EmbeddingBagCollection(
-            tables=[config1, config2],
+            tables=[config1, config2, config3],
             is_weighted=False,
         )

@@ -292,15 +298,17 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))

-    @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
         id_list_features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
-            lengths=torch.tensor([3, 3, 2]),
-            stride_per_key_per_rank=[[2], [1]],
-            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank=[[3], [2], [1]],
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
+            ),
         )

         eager_out = model(id_list_features)
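
For reference on this VBE construction, here is a sketch of how `stride_per_key_per_rank=[[3], [2], [1]]` partitions `lengths` per key (single rank here):

```python
import torch

lengths = torch.tensor([1, 2, 3, 2, 1, 1])
strides = [3, 2, 1]  # batch size per key (f1, f2, f3) on the single rank
print(torch.split(lengths, strides))
# (tensor([1, 2, 3]), tensor([2, 1]), tensor([1]))
# The per-key sums 6 + 3 + 1 account for all 10 entries in `values`.
```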
@@ -319,15 +327,16 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         # Run forward on ExportedProgram
         ep_output = ep.module()(id_list_features)

+        self.assertEqual(len(ep_output), len(id_list_features.keys()))
         for i, tensor in enumerate(ep_output):
-            self.assertEqual(eager_out[i].shape, tensor.shape)
+            self.assertEqual(eager_out[i].shape[1], tensor.shape[1])

         # Deserialize EBC
         unflatten_ep = torch.export.unflatten(ep)
         deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)

         # check EBC config
-        for i in range(5):
+        for i in range(1):
             ebc_name = f"ebc{i + 1}"
             self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
@@ -342,29 +351,9 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
             self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
             self.assertEqual(deserialized.feature_names, orginal.feature_names)

-        # check FPEBC config
-        for i in range(2):
-            fpebc_name = f"fpebc{i + 1}"
-            assert isinstance(
-                getattr(deserialized_model, fpebc_name),
-                FeatureProcessedEmbeddingBagCollection,
-            )
-
-            for deserialized, orginal in zip(
-                getattr(
-                    deserialized_model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-                getattr(
-                    model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-            ):
-                self.assertEqual(deserialized.name, orginal.name)
-                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
-                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
-                self.assertEqual(deserialized.feature_names, orginal.feature_names)
-
         # Run forward on deserialized model and compare the output
         deserialized_model.load_state_dict(model.state_dict())
+
         deserialized_out = deserialized_model(id_list_features)

         self.assertEqual(len(deserialized_out), len(eager_out))
@@ -385,6 +374,7 @@ def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
             values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
             offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
         )
+
         eager_out = model(feature2)

         # Serialize EBC

torchrec/modules/embedding_modules.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def reorder_inverse_indices(
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
     feature_names: List[str],
 ) -> torch.Tensor:
-    if inverse_indices is None:
+    if inverse_indices is None or inverse_indices[1].numel() == 0:
        return torch.empty(0)
     index_per_name = {name: i for i, name in enumerate(inverse_indices[0])}
     index = torch.tensor(
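
With `_inverse_indices_tensor` defaulting to `torch.empty(0)` (see `jagged_tensor.py` below), a KJT rebuilt from a pytree can report an inverse-indices tensor that is present but empty; the new guard treats that case exactly like `None`. A minimal illustration, using a hypothetical standalone helper that mirrors the guard:

```python
import torch
from typing import List, Optional, Tuple

def has_usable_inverse_indices(
    inverse_indices: Optional[Tuple[List[str], torch.Tensor]]
) -> bool:
    # Hypothetical helper, not part of torchrec: both None and an empty
    # sentinel tensor mean "no variable-batch inverse indices were provided".
    return inverse_indices is not None and inverse_indices[1].numel() > 0

assert not has_usable_inverse_indices(None)
assert not has_usable_inverse_indices((["f1"], torch.empty(0)))
assert has_usable_inverse_indices((["f1"], torch.tensor([[0, 1]])))
```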

torchrec/sparse/jagged_tensor.py

Lines changed: 44 additions & 7 deletions
@@ -8,9 +8,11 @@
 # pyre-strict

 import abc
+import dataclasses
 import logging

 import operator
+from dataclasses import dataclass

 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -1756,6 +1758,7 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
         "_weights",
         "_lengths",
         "_offsets",
+        "_inverse_indices_tensor",
     ]

     def __init__(
@@ -1801,6 +1804,12 @@ def __init__(
             inverse_indices
         )

+        # Init _inverse_indices_tensor to an empty tensor so it will be exported as FakeTensor by torch.export.
+        # Otherwise it will be exported as ConstantArgument when it's None, which causes Unsupported data type error.
+        self._inverse_indices_tensor: Optional[torch.Tensor] = torch.empty(0)
+        if inverse_indices is not None:
+            self._inverse_indices_tensor = inverse_indices[1]
+
         # legacy attribute, for backward compatabilibity
         self._variable_stride_per_key: Optional[bool] = None

@@ -3030,15 +3039,32 @@ def dist_init(
     return kjt.sync()


+@dataclass
+class KjtTreeSpecs:
+    keys: List[str]
+    stride_per_key_per_rank: Optional[List[List[int]]]
+
+    def to_dict(self) -> dict[str, Any]:
+        return {
+            field.name: getattr(self, field.name) for field in dataclasses.fields(self)
+        }
+
+
 def _kjt_flatten(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+) -> Tuple[List[Optional[torch.Tensor]], Tuple[List[str], Optional[List[List[int]]]]]:
+    return [getattr(t, a) for a in KeyedJaggedTensor._fields], (
+        t._keys,
+        t._stride_per_key_per_rank,
+    )


 def _kjt_flatten_with_keys(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
+) -> Tuple[
+    List[Tuple[KeyEntry, Optional[torch.Tensor]]],
+    Tuple[List[str], Optional[List[List[int]]]],
+]:
     values, context = _kjt_flatten(t)
     # pyre can't tell that GetAttrKey implements the KeyEntry protocol
     return [  # pyre-ignore[7]
@@ -3047,9 +3073,17 @@ def _kjt_flatten_with_keys(


 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: Tuple[
+        List[str], Optional[List[List[int]]]
+    ],  # context is the (_keys, _stride_per_key_per_rank) tuple
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    return KeyedJaggedTensor(
+        context[0],
+        *values[:-1],
+        stride_per_key_per_rank=context[1],
+        inverse_indices=(context[0], values[-1]),
+    )


 def _kjt_flatten_spec(
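
The `*values[:-1]` / `values[-1]` split in `_kjt_unflatten` works because `"_inverse_indices_tensor"` was appended as the last entry of `KeyedJaggedTensor._fields`, so it always flattens into the last leaf. A minimal sketch of the split (leaf names are stand-ins for the actual tensors):

```python
# Leaf order mirrors KeyedJaggedTensor._fields after this commit; the
# inverse-indices tensor is always last.
leaves = ["values", "weights", "lengths", "offsets", "inverse_indices_tensor"]

positional, inv = leaves[:-1], leaves[-1]
assert positional == ["values", "weights", "lengths", "offsets"]
assert inv == "inverse_indices_tensor"
```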
@@ -3070,7 +3104,9 @@ def _kjt_flatten_spec(

 def flatten_kjt_list(
     kjt_arr: List[KeyedJaggedTensor],
-) -> Tuple[List[Optional[torch.Tensor]], List[List[str]]]:
+) -> Tuple[
+    List[Optional[torch.Tensor]], List[Tuple[List[str], Optional[List[List[int]]]]]
+]:
     _flattened_data = []
     _flattened_context = []
     for t in kjt_arr:
@@ -3081,7 +3117,8 @@ def flatten_kjt_list(

 def unflatten_kjt_list(
-    values: List[Optional[torch.Tensor]], contexts: List[List[str]]
+    values: List[Optional[torch.Tensor]],
+    contexts: List[Tuple[List[str], Optional[List[List[int]]]]],
 ) -> List[KeyedJaggedTensor]:
     num_kjt_fields = len(KeyedJaggedTensor._fields)
     length = len(values)

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 17 additions & 0 deletions
@@ -1017,6 +1017,23 @@ def test_meta_device_compatibility(self) -> None:
             lengths=torch.tensor([], device=torch.device("meta")),
         )

+    def test_flatten_unflatten_with_vbe(self) -> None:
+        kjt = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
+            lengths=torch.tensor([3, 3, 2]),
+            stride_per_key_per_rank=[[2], [1]],
+            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+        )
+
+        flat_kjt, spec = pytree.tree_flatten(kjt)
+        unflattened_kjt = pytree.tree_unflatten(flat_kjt, spec)
+
+        self.assertEqual(
+            kjt.stride_per_key_per_rank(), unflattened_kjt.stride_per_key_per_rank()
+        )
+        self.assertEqual(kjt.inverse_indices(), unflattened_kjt.inverse_indices())
+

 class TestKeyedJaggedTensorScripting(unittest.TestCase):
     def test_scriptable_forward(self) -> None:
