
RuntimeError: grad_input must be contiguous when tensor size is large #175

Closed
@xingjian-zhang

Description

Reproduce

import intel_pytorch_extension as ipex
import torch
import torch.nn.functional as F

# Small input: backward() runs successfully.

out = torch.randn(10, 10, requires_grad=True, device=ipex.DEVICE)
mask = torch.randint(5, (10,), dtype=torch.long, device=ipex.DEVICE)

loss = F.cross_entropy(out, mask, ignore_index=1)
loss.backward()

# Large input: backward() raises "RuntimeError: grad_input must be contiguous".

out = torch.randn(10, 10, 500, 1000, requires_grad=True, device=ipex.DEVICE)
mask = torch.randint(5, (10, 500, 1000,), dtype=torch.long, device=ipex.DEVICE)

loss = F.cross_entropy(out, mask, ignore_index=1)
loss.backward()

Traceback

Traceback (most recent call last):
  File "reproduce.py", line 19, in <module>
    loss.backward()
  File "/opt/conda/envs/torch_env/lib/python3.7/site-packages/torch/tensor.py", line 245, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/opt/conda/envs/torch_env/lib/python3.7/site-packages/torch/autograd/__init__.py", line 147, in backward
    allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
RuntimeError: grad_input must be contiguous
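For context on what the error demands: in PyTorch terms, a tensor is contiguous when its strides match the row-major (C-order) strides for its shape, with size-1 dimensions ignored. The pure-Python sketch below illustrates that definition; the helper names `row_major_strides` and `is_contiguous` are hypothetical and this is not the check IPEX actually runs internally.

```python
def row_major_strides(shape):
    """Strides (in elements) of a freshly allocated row-major tensor."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """Illustrative contiguity check: strides must equal row-major strides,
    except that dimensions of size 1 may have any stride."""
    if any(size == 0 for size in shape):
        return True  # an empty tensor is trivially contiguous
    expected = 1
    for size, stride in zip(reversed(shape), reversed(strides)):
        if size != 1 and stride != expected:
            return False
        expected *= size
    return True


# The failing input in the reproducer: 10 * 10 * 500 * 1000 = 50,000,000 elements.
shape = (10, 10, 500, 1000)
print(row_major_strides(shape))  # (5000000, 500000, 1000, 1)

# A channels-last style view of the same storage (N, d1, d2, C) is not contiguous.
print(is_contiguous((10, 500, 1000, 10), (5000000, 1000, 1, 500000)))  # False
```

In the reproducer, `out` is created directly by `torch.randn`, so the input itself should already be contiguous; the message therefore seems to concern a gradient buffer produced inside the backward pass for the large case, though the report does not confirm where.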
