diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 354f831..31fce1e 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -248,6 +248,12 @@ def test_fp8_mlp_tensor_parallelism_base( x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)]) tp_out = tp_model(x_fp32_tp_input) + assert ( + tp_model.ffn.w1.weight.requires_grad + ), "Expecting gradients to be enabled for TP model." + assert tp_out.requires_grad, "Expecting gradients to be enabled for TP model." + awaited_out = tp_out.wait() + assert awaited_out.requires_grad, "Expecting awaited out to require gradients" tp_out.sum().backward() sp_out = sp_model(x_fp32_sp_input) sp_out.sum().backward()