
Commit d2b9572

jd7-tr authored and facebook-github-bot committed
Include stride_per_key_per_rank in KJT's PyTree flatten/unflatten logic (#2903)
Summary:
Pull Request resolved: #2903

# Context

* Currently the torchrec IR serializer can't handle the variable-batch use case.
* `torch.export` captures only the keys, values, lengths, weights, and offsets of a KJT; variable-batch parameters such as `stride_per_key_per_rank` or `inverse_indices` are dropped.
* The test case covering the vb-KJT scenario (an expected failure before this change) verifies that the serialize_deserialize_ebc use case works with KJTs that have variable batch sizes.

# Ref

Differential Revision: D73051959
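For concreteness, a minimal sketch of the scenario the summary describes, mirroring the unit test added in torchrec/sparse/tests/test_keyed_jagged_tensor.py below (tensor values are illustrative; KeyedJaggedTensor is registered as a pytree node in torchrec/sparse/jagged_tensor.py):

import torch
import torch.utils._pytree as pytree
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# A variable-batch (VBE) KJT: per-key batch sizes differ, so it carries
# stride_per_key_per_rank and inverse_indices on top of the usual fields.
vbe_kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank=[[2], [1]],
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

flat, spec = pytree.tree_flatten(vbe_kjt)
rebuilt = pytree.tree_unflatten(flat, spec)

# Before this change the pytree context carried only the keys, so the two
# VBE attributes were silently dropped; now they survive the round trip.
assert rebuilt.stride_per_key_per_rank() == vbe_kjt.stride_per_key_per_rank()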
1 parent 4eca985 commit d2b9572

File tree

4 files changed: +81 -39 lines changed

  torchrec/ir/tests/test_serializer.py
  torchrec/modules/embedding_modules.py
  torchrec/sparse/jagged_tensor.py
  torchrec/sparse/tests/test_keyed_jagged_tensor.py

torchrec/ir/tests/test_serializer.py

Lines changed: 20 additions & 30 deletions
@@ -206,8 +206,14 @@ def forward(
             num_embeddings=10,
             feature_names=["f2"],
         )
+        config3 = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
         ebc = EmbeddingBagCollection(
-            tables=[config1, config2],
+            tables=[config1, config2, config3],
             is_weighted=False,
         )

@@ -292,15 +298,17 @@ def test_serialize_deserialize_ebc(self) -> None:
         self.assertEqual(deserialized.shape, orginal.shape)
         self.assertTrue(torch.allclose(deserialized, orginal))

-    @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
         id_list_features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
-            lengths=torch.tensor([3, 3, 2]),
-            stride_per_key_per_rank=[[2], [1]],
-            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank=[[3], [2], [1]],
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
+            ),
         )

         eager_out = model(id_list_features)
@@ -319,15 +327,16 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         # Run forward on ExportedProgram
         ep_output = ep.module()(id_list_features)

+        self.assertEqual(len(ep_output), len(id_list_features.keys()))
         for i, tensor in enumerate(ep_output):
-            self.assertEqual(eager_out[i].shape, tensor.shape)
+            self.assertEqual(eager_out[i].shape[1], tensor.shape[1])

         # Deserialize EBC
         unflatten_ep = torch.export.unflatten(ep)
         deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)

         # check EBC config
-        for i in range(5):
+        for i in range(1):
             ebc_name = f"ebc{i + 1}"
             self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
@@ -342,29 +351,9 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
             self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
             self.assertEqual(deserialized.feature_names, orginal.feature_names)

-        # check FPEBC config
-        for i in range(2):
-            fpebc_name = f"fpebc{i + 1}"
-            assert isinstance(
-                getattr(deserialized_model, fpebc_name),
-                FeatureProcessedEmbeddingBagCollection,
-            )
-
-            for deserialized, orginal in zip(
-                getattr(
-                    deserialized_model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-                getattr(
-                    model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-            ):
-                self.assertEqual(deserialized.name, orginal.name)
-                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
-                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
-                self.assertEqual(deserialized.feature_names, orginal.feature_names)
-
         # Run forward on deserialized model and compare the output
         deserialized_model.load_state_dict(model.state_dict())
+
         deserialized_out = deserialized_model(id_list_features)

         self.assertEqual(len(deserialized_out), len(eager_out))
@@ -385,6 +374,7 @@ def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
             values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
             offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
         )
+
         eager_out = model(feature2)

         # Serialize EBC

torchrec/modules/embedding_modules.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def reorder_inverse_indices(
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
     feature_names: List[str],
 ) -> torch.Tensor:
-    if inverse_indices is None:
+    if inverse_indices is None or inverse_indices[1].numel() == 0:
         return torch.empty(0)
     index_per_name = {name: i for i, name in enumerate(inverse_indices[0])}
     index = torch.tensor(
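Why the extra guard: the patched `_kjt_flatten` in torchrec/sparse/jagged_tensor.py below substitutes `torch.empty(0)` for absent inverse indices so that torch.export traces a FakeTensor rather than a ConstantArgument, which means an empty tensor reaching this function must mean the same thing as None. A minimal sketch of that convention (hypothetical values):

import torch

# torch.empty(0) is the flatten-time stand-in for "no inverse indices".
sentinel = torch.empty(0)

# Both spellings of "absent" now take the same early-return path:
for candidate in (None, (["f1"], sentinel)):
    absent = candidate is None or candidate[1].numel() == 0
    assert absent  # reorder_inverse_indices returns torch.empty(0) here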

torchrec/sparse/jagged_tensor.py

Lines changed: 43 additions & 8 deletions
@@ -1756,6 +1756,7 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
         "_weights",
         "_lengths",
         "_offsets",
+        "_inverse_indices",
     ]

     def __init__(
@@ -1800,7 +1801,6 @@ def __init__(
         self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
             inverse_indices
         )
-
         # legacy attribute, for backward compatibility
         self._variable_stride_per_key: Optional[bool] = None

@@ -3032,13 +3032,36 @@ def dist_init(

 def _kjt_flatten(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+) -> Tuple[
+    List[Optional[torch.Tensor]],
+    Tuple[List[str], Optional[List[str]], Optional[List[List[int]]]],
+]:
+    field_values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
+    inverse_index_keys: Optional[List[str]] = None
+    # Init to an empty tensor so it will be exported as a FakeTensor by torch.export.
+    # Otherwise it will be exported as a ConstantArgument when the KJT's _inverse_indices is None, which causes an unsupported-data-type error.
+    inverse_indices: torch.Tensor = torch.empty(0)
+
+    if t._inverse_indices is not None:
+        inverse_indices = t._inverse_indices[1]
+        # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
+        inverse_index_keys = t._inverse_indices[0]
+
+    field_values.append(inverse_indices)
+
+    return field_values, (
+        t._keys,
+        inverse_index_keys,
+        t._stride_per_key_per_rank,
+    )


 def _kjt_flatten_with_keys(
     t: KeyedJaggedTensor,
-) -> Tuple[List[Tuple[KeyEntry, Optional[torch.Tensor]]], List[str]]:
+) -> Tuple[
+    List[Tuple[KeyEntry, Optional[torch.Tensor]]],
+    Tuple[List[str], Optional[List[str]], Optional[List[List[int]]]],
+]:
     values, context = _kjt_flatten(t)
     # pyre can't tell that GetAttrKey implements the KeyEntry protocol
     return [  # pyre-ignore[7]
@@ -3047,9 +3070,17 @@ def _kjt_flatten_with_keys(


 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: Tuple[
+        List[str], Optional[List[str]], Optional[List[List[int]]]
+    ],  # context is the (_keys, inverse_index_keys, _stride_per_key_per_rank) tuple
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    return KeyedJaggedTensor(
+        context[0],
+        *values[:-1],
+        stride_per_key_per_rank=context[2],
+        inverse_indices=(context[1], values[-1]) if context[1] is not None else None,
+    )


 def _kjt_flatten_spec(
@@ -3070,7 +3101,10 @@ def _kjt_flatten_spec(

 def flatten_kjt_list(
     kjt_arr: List[KeyedJaggedTensor],
-) -> Tuple[List[Optional[torch.Tensor]], List[List[str]]]:
+) -> Tuple[
+    List[Optional[torch.Tensor]],
+    List[Tuple[List[str], Optional[List[str]], Optional[List[List[int]]]]],
+]:
    _flattened_data = []
    _flattened_context = []
    for t in kjt_arr:
@@ -3081,7 +3115,8 @@

 def unflatten_kjt_list(
-    values: List[Optional[torch.Tensor]], contexts: List[List[str]]
+    values: List[Optional[torch.Tensor]],
+    contexts: List[Tuple[List[str], Optional[List[str]], Optional[List[List[int]]]]],
 ) -> List[KeyedJaggedTensor]:
     num_kjt_fields = len(KeyedJaggedTensor._fields)
     length = len(values)
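Taken together, the flatten context grows from a bare key list to a (keys, inverse_index_keys, stride_per_key_per_rank) tuple, with the inverse-indices tensor appended to the flattened values. A hedged sketch of the non-VBE path through these private helpers, assuming they remain importable as module-level functions (input values borrowed from the test_serializer.py hunk above):

import torch
from torchrec.sparse.jagged_tensor import (
    KeyedJaggedTensor,
    _kjt_flatten,
    _kjt_unflatten,
)

# A plain (non-VBE) KJT: no inverse indices, no per-key strides.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
    offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
)

values, context = _kjt_flatten(kjt)
# values[-1] is torch.empty(0): the FakeTensor-friendly stand-in for
# "no inverse indices"; context is (keys, None, None) in this case.
rebuilt = _kjt_unflatten(values, context)
assert rebuilt._inverse_indices is None  # empty sentinel maps back to None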

torchrec/sparse/tests/test_keyed_jagged_tensor.py

Lines changed: 17 additions & 0 deletions
@@ -1017,6 +1017,23 @@ def test_meta_device_compatibility(self) -> None:
             lengths=torch.tensor([], device=torch.device("meta")),
         )

+    def test_flatten_unflatten_with_vbe(self) -> None:
+        kjt = KeyedJaggedTensor(
+            keys=["f1", "f2"],
+            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
+            lengths=torch.tensor([3, 3, 2]),
+            stride_per_key_per_rank=[[2], [1]],
+            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+        )
+
+        flat_kjt, spec = pytree.tree_flatten(kjt)
+        unflattened_kjt = pytree.tree_unflatten(flat_kjt, spec)
+
+        self.assertEqual(
+            kjt.stride_per_key_per_rank(), unflattened_kjt.stride_per_key_per_rank()
+        )
+        self.assertEqual(kjt.inverse_indices(), unflattened_kjt.inverse_indices())
+

 class TestKeyedJaggedTensorScripting(unittest.TestCase):
     def test_scriptable_forward(self) -> None:
