Prepack conv weights #31

Merged
merged 3 commits into from
Jun 1, 2020

Conversation

pinzhenx
Contributor

No description provided.

@pinzhenx pinzhenx marked this pull request as draft May 28, 2020 04:47
@pinzhenx pinzhenx force-pushed the prepack branch 3 times, most recently from f193736 to 379ddc1 Compare May 28, 2020 08:39
@jgong5

jgong5 commented May 28, 2020

I have a concern related, but not specific, to this PR w.r.t. the in-place changes of input data/weight tensors. PyTorch originally assumes input tensors are constant, so it is safe for re-entrance, e.g. when called by multiple threads from JIT with fork and join. In-place changes to these input tensors might break this assumption. Do we have a design to protect these in-place changes from multi-threaded access?
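
One common way to protect such an in-place reorder, sketched here only to make the concern concrete (is_prepacked and reorder_to_blocked_ are hypothetical helpers, not code from this PR), is to make the reorder idempotent and guard it with a double-checked lock so concurrent JIT fork/join callers cannot race on the same weight:

    import threading

    _prepack_lock = threading.Lock()

    def ensure_prepacked(weight, conv_params):
        # Fast path: weight is already in the blocked layout, nothing to do.
        if is_prepacked(weight):
            return weight
        with _prepack_lock:
            # Re-check under the lock so only one thread performs the
            # in-place reorder when several workers hit the same weight.
            if not is_prepacked(weight):
                reorder_to_blocked_(weight, conv_params)
        return weight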

@pinzhenx
Contributor Author

As of right now, no. Modifying constant tensors on the fly, especially constant parameters, does have some implications we need to take care of.

@pinzhenx
Contributor Author

pinzhenx commented May 28, 2020

BTW, we are also implementing a standalone prepack JIT pass. This PR is more of a workaround to optimize the imperative path, so that we can achieve the same performance as the to_mkldnn(model) approach.

We may revert this design if customers are more inclined to use JIT, or if something goes wrong with it.
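
For reference, the to_mkldnn(model) approach mentioned above is the stock PyTorch utility that converts supported module weights into the MKL-DNN blocked layout ahead of time. Typical usage looks roughly like this (MyConvModel is a placeholder for any convolutional model):

    import torch
    from torch.utils import mkldnn as mkldnn_utils

    model = MyConvModel().eval()
    model = mkldnn_utils.to_mkldnn(model)   # prepack weights into MKL-DNN layout

    with torch.no_grad():
        x = torch.randn(1, 3, 224, 224).to_mkldnn()  # inputs must be MKL-DNN tensors
        y = model(x).to_dense()                       # convert the output back to a dense tensor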

@jgong5

jgong5 commented May 28, 2020

Use cases for your consideration:

  1. Weights shared by multiple in-process inference models, in both imperative and JIT mode.
  2. Data shared by multiple ops.

For 1), pre-packing during model initialization should work fine, but pre-packing on the first model run needs extra care around re-entrance.
For 2), there won't be re-entrance on the imperative path since there is just one model, but there may be a problem if some downstream ops need the blocked layout while others need the plain layout. We also need to take care of re-entrance in JIT mode.

@EikanWang
Contributor

Thanks, Jiong. I discussed this with Pinzhen; for case 1, we can capture module.to and then prepack the weight tensors.
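
The general shape of that idea, as a rough sketch (prepack_conv_weight is a hypothetical helper; the actual hook from this PR is shown in the diff excerpt further down):

    import torch
    import torch.nn as nn

    orig_module_to = nn.Module.to

    def module_to(self, *args, **kwargs):
        m = orig_module_to(self, *args, **kwargs)
        device = torch._C._nn._parse_to(*args, **kwargs)[0]
        if device and device.type == 'dpcpp':
            for sub in m.modules():
                if isinstance(sub, nn.Conv2d):
                    prepack_conv_weight(sub)  # reorder the conv weight to a blocked layout
        return m

    nn.Module.to = module_to  # capture module.to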

@pinzhenx pinzhenx marked this pull request as ready for review May 29, 2020 16:52
@pinzhenx pinzhenx force-pushed the prepack branch 3 times, most recently from a95469e to bf3a18e Compare May 29, 2020 17:00
@pinzhenx
Contributor Author

pinzhenx commented May 29, 2020

We currently prepack conv weights in module.to.

Pros:

  • thread-safe
  • more explicit, compared to implicitly prepacking at runtime

Cons:

  • no input info, meaning that the queried format might not be optimal. This does not manifest for conv, but it could be a problem for linear weights (see the sketch below).
  • unable to re-pack conv weights if they have been reordered back to the plain layout (maybe not even a con)
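
To illustrate the "no input info" point: a reorder done at module.to time can only look at the weight shape and the conv hyper-parameters, so the blocked format it picks may differ from what oneDNN would choose for a concrete input shape. A rough sketch, assuming the stock PyTorch reorder op (the exact API may differ across versions):

    import torch

    def prepack_conv_weight(conv):
        # Only the weight shape and conv hyper-parameters are known here;
        # the eventual input shape is not, so the chosen format is a guess.
        return torch._C._nn.mkldnn_reorder_conv2d_weight(
            conv.weight.data.to_mkldnn(),
            conv.padding, conv.stride, conv.dilation, conv.groups)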

@jgong5

jgong5 commented May 30, 2020

no input info, meaning that the queried format might not be optimal. This does not manifest for conv, but it could be a problem for linear weights.

It will be a problem for conv too if we use Winograd. Just FYI.

m = orig_module_to(self, *args, **kwargs)

device = torch._C._nn._parse_to(*args, **kwargs)[0]
if device and device.type == 'dpcpp':

We need to check auto_dnnl here. If the user disables auto_dnnl, it should go through the original path.
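
A sketch of the suggested change (get_auto_dnnl is a stand-in for whatever accessor the extension actually exposes for this flag, so treat the name as an assumption):

    def module_to(self, *args, **kwargs):
        m = orig_module_to(self, *args, **kwargs)
        device = torch._C._nn._parse_to(*args, **kwargs)[0]
        # Only prepack when targeting dpcpp AND auto_dnnl is enabled;
        # if the user disabled auto_dnnl, keep the original behaviour.
        if device and device.type == 'dpcpp' and get_auto_dnnl():
            for sub in m.modules():
                if isinstance(sub, nn.Conv2d):
                    prepack_conv_weight(sub)  # hypothetical helper, as in the earlier sketches
        return m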

@EikanWang EikanWang merged commit 301bd87 into intel:master Jun 1, 2020
@gzmkl gzmkl mentioned this pull request Feb 1, 2022