Skip to content

Commit f0cea9e

Browse files
authored
Merge branch 'main' into 2.6-RC-TEST
2 parents 1a3d0e0 + 7262c21 commit f0cea9e

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

advanced_source/cpp_frontend.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -969,15 +969,15 @@ the data loader every epoch and then write the GAN training code:
969969
discriminator->zero_grad();
970970
torch::Tensor real_images = batch.data;
971971
torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
972-
torch::Tensor real_output = discriminator->forward(real_images);
972+
torch::Tensor real_output = discriminator->forward(real_images).reshape(real_labels.sizes());
973973
torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
974974
d_loss_real.backward();
975975
976976
// Train discriminator with fake images.
977977
torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1});
978978
torch::Tensor fake_images = generator->forward(noise);
979979
torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
980-
torch::Tensor fake_output = discriminator->forward(fake_images.detach());
980+
torch::Tensor fake_output = discriminator->forward(fake_images.detach()).reshape(fake_labels.sizes());
981981
torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
982982
d_loss_fake.backward();
983983
@@ -987,7 +987,7 @@ the data loader every epoch and then write the GAN training code:
987987
// Train generator.
988988
generator->zero_grad();
989989
fake_labels.fill_(1);
990-
fake_output = discriminator->forward(fake_images);
990+
fake_output = discriminator->forward(fake_images).reshape(fake_labels.sizes());
991991
torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
992992
g_loss.backward();
993993
generator_optimizer.step();

prototype_source/flight_recorder_tutorial.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ For demonstration purposes, we named this program ``crash.py``.
214214
complexities.
215215

216216
.. code:: python
217-
:caption: A crashing example
218217
219218
import torch
220219
import torch.distributed as dist

0 commit comments

Comments
 (0)