[BUG] RuntimeError: CUDA error: an illegal memory access was encountered when using vmap and model ensembling with a custom Function that calls into CUDA #2721

Open
@wuyingxiong

Description

Add Link

https://pytorch.org/tutorials/intermediate/ensembling.html
https://pytorch.org/docs/stable/notes/extending.func.html#defining-the-vmap-staticmethod

Describe the bug

I want to use vmap to vectorize an ensemble of models built on a custom torch.autograd.Function whose forward/backward call into CUDA kernels.

First, I set generate_vmap_rule=True, which tells PyTorch to generate the vmap rule automatically.
Error: RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
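For context, generate_vmap_rule=True only works when forward is composed of regular PyTorch ops, because vmap then runs forward on BatchedTensors, which have no storage; a kernel that reads a raw data pointer cannot handle them. A minimal self-contained sketch of the working case (MyOp and its sine body are illustrative, not the model from this issue):

```python
import torch

class MyOp(torch.autograd.Function):
    # Opt in to the auto-generated vmap rule; valid only because
    # forward below uses ordinary PyTorch ops, not raw CUDA kernels.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return x.sin()

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad):
        x, = ctx.saved_tensors
        return grad * x.cos()

x = torch.randn(4, 3)
out = torch.vmap(MyOp.apply)(x)
```

Note that this requires the PyTorch 2.0-style Function: forward takes no ctx and state is saved in setup_context.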
Because the model calls into CUDA kernels, which need tensors with real storage, I need to write my own vmap staticmethod:

    @staticmethod
    def vmap(info, in_dims, input):
        if in_dims[0] is not None:
            # flatten the batched model dim into the point dim before the CUDA call
            input_B = input.shape[0]
            input = einops.rearrange(input, 'B N C -> (B N) C')
        outputs, _, _ = model.apply(input)
        if in_dims[0] is not None:
            # unflatten the outputs back to (B, N, C)
            outputs = einops.rearrange(outputs, '(B N) C -> B N C', B=input_B)
        return outputs, 0

Error: RuntimeError: CUDA error: an illegal memory access was encountered. CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.

How should I write the vmap rule so that multiple models can process multiple batches of data, when the models call into CUDA to process the data?
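One pattern from the extending-torch.func docs is a vmap staticmethod that moves the batched dim to the front, flattens it into the existing batch dim, runs the op once on a real (unbatched) tensor, and unflattens the result, returning (output, out_dims). A runnable sketch along those lines, with a CPU stand-in (x * 2) in place of the CUDA kernel; MyOp and its body are illustrative:

```python
import torch

class MyOp(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # stand-in for the custom CUDA forward kernel;
        # inside the vmap rule it always sees a plain tensor
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad):
        return grad * 2

    @staticmethod
    def vmap(info, in_dims, x):
        in_dim, = in_dims
        if in_dim is None:
            # the input is not batched: run the op unchanged
            return MyOp.apply(x), None
        # move the batched dim to the front, then fold it into dim 1
        x = x.movedim(in_dim, 0)
        B = x.shape[0]
        out = MyOp.apply(x.reshape(B * x.shape[1], *x.shape[2:]))
        # unfold and report that dim 0 of the output is the batched dim
        out = out.reshape(B, -1, *out.shape[1:])
        return out, 0

x = torch.randn(5, 7, 3)
out = torch.vmap(MyOp.apply)(x)
```

The reshape before the kernel call also makes the input contiguous, which matters if the kernel assumes contiguous memory; a wrong stride assumption in the kernel is a common cause of illegal memory accesses under vmap.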

The code follows; I have simplified the model class.

class model(torch.autograd.Function):
    @staticmethod
    def forward(...):
        # calls the custom CUDA forward kernel
        ...

    @staticmethod
    def backward(...):
        # calls the custom CUDA backward kernel
        ...

    @staticmethod
    def setup_context(...):
        ...

    @staticmethod
    def vmap(...):
        ...

import copy
import torch
from torch.func import stack_module_state, functional_call, vmap

b_p = torch.randn([10, 100, 3]).cuda()

objs = [model() for i in range(10)]
pe_models = [obj.pe for obj in objs]
pe_param, pe_buffer = stack_module_state(pe_models)
base_model = copy.deepcopy(pe_models[0])

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), x)

out = vmap(fmodel)(pe_param, pe_buffer, b_p)
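For comparison, the same ensembling pattern runs as-is on CPU when the stacked modules are ordinary PyTorch modules; the nn.Linear layers here are stand-ins for the real module that wraps the custom Function:

```python
import copy
import torch
import torch.nn as nn
from torch.func import stack_module_state, functional_call, vmap

# ten independently initialized models of identical architecture
models = [nn.Linear(3, 4) for _ in range(10)]
params, buffers = stack_module_state(models)

# a "meta" copy supplies only the module structure, so none of its own
# (stale) parameters can be used by accident
base_model = copy.deepcopy(models[0]).to("meta")

def fmodel(params, buffers, x):
    return functional_call(base_model, (params, buffers), (x,))

x = torch.randn(10, 100, 3)           # one batch of 100 points per model
out = vmap(fmodel)(params, buffers, x)
```

vmap maps over dim 0 of the stacked parameter dict and of x together, so each model sees only its own slice of the input.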

Describe your environment

Versions

PyTorch 2.0
CUDA 11.7
Python 3.8
Ubuntu 20.04
collect_env.py errored; I will update the output later.

cc @albanD


Labels: bug, core (Tutorials of any level of difficulty related to the core pytorch functionality)
