Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit f310a3e

Browse files
committed
Add sanity checks to dtensor tests
ghstack-source-id: 0f3887b Pull Request resolved: #302
1 parent 36405a7 commit f310a3e

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

test/test_dtensor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,12 @@ def test_fp8_mlp_tensor_parallelism_base(
248248
x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
249249

250250
tp_out = tp_model(x_fp32_tp_input)
251+
assert (
252+
tp_model.ffn.w1.weight.requires_grad
253+
), "Expecting gradients to be enabled for TP model."
254+
assert tp_out.requires_grad, "Expecting gradients to be enabled for TP model."
255+
awaited_out = tp_out.wait()
256+
assert awaited_out.requires_grad, "Expecting awaited out to require gradients"
251257
tp_out.sum().backward()
252258
sp_out = sp_model(x_fp32_sp_input)
253259
sp_out.sum().backward()

0 commit comments

Comments
 (0)