jit: enable conv_relu fusion #15
Merged
Changes from all commits (10 commits):
89541bc  jit: enable conv_relu fusion (XiaobingSuper)
bbd70f8  only jit fusion for extension path (XiaobingSuper)
c08064a  jit: enable conv_sum and conv_sum_relu fusion (XiaobingSuper)
e7ee4b3  make rewritten linear op traceable (XiaobingSuper)
85e7170  make rewritten max_pool2d op traceable (XiaobingSuper)
2cca38f  fix max_pool2d backward floating point exception issue (XiaobingSuper)
2cfb394  make rewritten AdaptiveAvgPool2d op traceable (XiaobingSuper)
f402c30  fix linear issue when bias is None (XiaobingSuper)
746fe44  fix max_pool2d issue with stride=None case (XiaobingSuper)
0fb1060  add prepack_weight API (XiaobingSuper)
First changed file (the Python op wrappers; filename not shown in this capture). The diff drops the AdaptiveAvgPool2dFunction autograd wrapper and the try/except device dispatch, calling the registered torch.ops.torch_ipex kernels directly so the rewritten ops can be traced:

@@ -2,25 +2,12 @@
 from torch.autograd import Function
 import torch.nn.functional as F
 import _torch_ipex as core
-from torch.nn.modules.utils import _single
+from torch.nn.modules.utils import _single, _pair
+from typing import List

 torch_adaptive_avg_pool2d = torch._C._nn.adaptive_avg_pool2d
 torch_max_pool2d = torch.max_pool2d
 torch_max_pool3d = torch.max_pool3d
-
-class AdaptiveAvgPool2dFunction(Function):
-    @staticmethod
-    def forward(ctx, input, output_size):
-        output = core.adaptive_avg_pool2d(input, _single(output_size))
-        ctx.save_for_backward(input)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (input,) = ctx.saved_tensors
-        grad_output = grad_output.contiguous()
-        grad_input = core.adaptive_avg_pool2d_backward(grad_output, input)
-        return (grad_input, None)
+Vector = List[int]

 class MaxPoolingFunction(Function):
     @staticmethod
@@ -41,21 +28,8 @@ def backward(ctx, grad_output):
         grad_input = core.max_pooling_backward(grad_output, output, input, ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation, ctx.ceil_mode)
         return (grad_input, None, None, None, None, None)

-def adaptive_avg_pool2d(input, output_size):
-    try:
-        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
-            return AdaptiveAvgPool2dFunction.apply(input, output_size)
-    except RuntimeError:
-        pass
-    return torch_adaptive_avg_pool2d(input, output_size)
-
-def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode):
-    try:
-        if input.device.type == 'dpcpp' and core.get_auto_dnnl():
-            return MaxPoolingFunction.apply(input, kernel_size, stride, padding, dilation, ceil_mode)
-    except RuntimeError:
-        pass
-    return torch_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
+def adaptive_avg_pool2d(input, output_size: Vector):
+    return torch.ops.torch_ipex.adaptive_avg_pool2d(input, _pair(output_size))

[Review comment on the new adaptive_avg_pool2d: "ditto", presumably the same no-fallback concern raised on linear at the end of this page.]

 def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
     try:
@@ -65,6 +39,11 @@ def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
         pass
     return torch_max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)

+def max_pool2d(input, kernel_size: Vector, stride: Vector, padding: Vector, dilation: Vector, ceil_mode: bool):
+    if not stride:
+        stride = kernel_size
+    return torch.ops.torch_ipex.max_pool2d(input, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation), ceil_mode)

[Review comment on the new max_pool2d: "ditto".]

 torch._C._nn.adaptive_avg_pool2d = adaptive_avg_pool2d
 torch.max_pool2d = max_pool2d
 torch.max_pool3d = max_pool3d
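Why this rewrite enables tracing: the old wrappers mixed Python-level control flow (device checks, try/except) into the op, which the JIT cannot capture, while the new wrappers are a single typed call into a registered custom op. A minimal sketch of the typing requirement, separate from the PR (normalize_stride is a hypothetical helper, not part of this diff):

    from typing import List

    import torch

    Vector = List[int]

    @torch.jit.script
    def normalize_stride(kernel_size: Vector, stride: Vector) -> Vector:
        # Mirrors the diff's "if not stride: stride = kernel_size" handling:
        # an empty stride (the stride=None case) falls back to kernel_size.
        if len(stride) == 0:
            stride = kernel_size
        return stride

    print(normalize_stride([3, 3], []))  # prints [3, 3]

TorchScript needs the explicit List[int] annotations (the Vector alias in the diff); without them the arguments are assumed to be Tensor and scripting the wrapper fails.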
Second changed file, a new test script (filename not shown in this capture); the entire file is added:

@@ -0,0 +1,172 @@
from __future__ import division
from __future__ import print_function

'''
From PyTorch:

Copyright (c) 2016-     Facebook, Inc            (Adam Paszke)
Copyright (c) 2014-     Facebook, Inc            (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies    (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU                      (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006      Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

From Caffe2:

Copyright (c) 2016-present, Facebook Inc. All rights reserved.

All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.

All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.

All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.

All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.

All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.

Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.

All rights reserved.
'''

"""Tests for rn50."""

import math
import random
import unittest
from functools import reduce

import torch
import torch.nn as nn
from torch.jit._recursive import wrap_cpp_module
import copy

import intel_pytorch_extension
from intel_pytorch_extension import core

import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd import gradcheck
from torch.autograd.gradcheck import gradgradcheck
from torch._six import inf, nan

from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
    TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
    IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \
    IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
    skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf

device = 'dpcpp:0'
#device = 'cpu:0'
SIZE = 100

torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)

def test_output(model, x):
    modelName = model.__class__.__name__
    core.disable_jit()

    model = model.to('dpcpp').eval()
    x = x.to('dpcpp')
    with torch.no_grad():
        result = model(x)

    smodel = torch.jit.script(model)
    smodel.eval()
    with torch.no_grad():
        sresult = smodel(x)

    print(f'\nAre {modelName} and Scripted{modelName} outputs the same: ',
          torch.allclose(
              sresult, result, rtol=1e-05, atol=1e-06, equal_nan=False))

    core.enable_jit()
    pmodel = torch.jit.script(model)
    # bn folding
    pmodel = wrap_cpp_module(torch._C._jit_pass_fold_convbn(pmodel._c))
    with torch.no_grad():
        # conv relu fusion, conv sum fusion or conv sum relu fusion
        print(pmodel.graph_for(x))
        presult = pmodel(x)

    # print(result)
    # print(sresult)
    # print(presult)

    print(f'\nWith or without fusion passes, are Scripted{modelName} outputs the same: ',
          torch.allclose(
              sresult, presult, rtol=1e-05, atol=1e-06, equal_nan=False))

class Conv2dRelu_Fixed(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(Conv2dRelu_Fixed, self).__init__()
        seed = 2018
        torch.manual_seed(seed)
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)

    def forward(self, x):
        return F.relu(self.conv(x), inplace=True)

class CascadedConv2dBnSumRelu(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, **kwargs):
        super(CascadedConv2dBnSumRelu, self).__init__()
        torch.manual_seed(2018)
        self.conv = nn.Conv2d(in_channels, mid_channels, bias=False, **kwargs)
        self.conv1 = nn.Conv2d(
            mid_channels, out_channels, bias=False, padding=1, **kwargs)
        self.conv2 = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(mid_channels, eps=0.001)
        self.bn1 = nn.BatchNorm2d(out_channels, eps=0.001)
        self.bn2 = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        a = self.conv(x)
        a = self.bn(a)
        a = F.relu(a, inplace=True)
        a = self.conv1(a)
        a = self.bn1(a)
        b = self.conv2(x)
        b = self.bn2(b)
        return F.relu(a.add_(b), inplace=True)

class Tester(TestCase):
    n = 32
    c = 3
    h = 224
    w = 224
    print('input size: (%d, %d, %d, %d)' % (n, c, h, w))

    def test_output_conv_relu(self):
        test_output(
            Conv2dRelu_Fixed(self.c, 32, kernel_size=3, stride=1),
            torch.rand(self.n, self.c, self.h, self.w))

    def test_output_cascaded_conv2d_bn_sum_relu(self):
        test_output(
            CascadedConv2dBnSumRelu(self.c, 64, 32, kernel_size=3, stride=1),
            torch.rand(self.n, self.c, self.h, self.w))

if __name__ == '__main__':
    core.enable_auto_dnnl()
    test = unittest.main()
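The test prints the optimized graph and an allclose comparison, leaving the fusion check to visual inspection. A hedged sketch of how the fusion could be asserted programmatically instead (the node kind string 'ipex::conv2d_relu' is an assumed name for illustration, not taken from this PR):

    import torch

    def graph_contains(scripted_model, x, node_kind: str) -> bool:
        # graph_for(x) returns the optimized graph the JIT would execute
        # for this input; scan its top-level nodes for the fused op.
        graph = scripted_model.graph_for(x)
        return any(node.kind() == node_kind for node in graph.nodes())

    # Hypothetical usage inside a TestCase:
    # self.assertTrue(graph_contains(pmodel, x, 'ipex::conv2d_relu'))

Fused ops can also end up inside nested subgraphs, in which case the scan would need to recurse through node.blocks(); the top-level scan above is the simplest version.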
[A third changed file is part of this PR, but its diff failed to load in this capture.]
Review comment: @XiaobingSuper there's no way to fall back to aten linear if any exception happens. I suggest we add a try-catch here, or in NewLinearOp?
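A hedged sketch of the fallback the reviewer is asking for, at the Python wrapper level. It assumes a torch.ops.torch_ipex.linear op in the style of the pooling rewrites above; note the PR removed exactly this try/except pattern from the pooling wrappers because the JIT cannot capture it, so the same trade-off would apply here:

    import torch
    import torch.nn.functional as F

    torch_linear = F.linear  # keep a handle to the stock aten implementation

    def linear(input, weight, bias=None):
        # Hypothetical wrapper: try the extension kernel first, and fall
        # back to aten linear if it raises, as the reviewer suggests.
        try:
            return torch.ops.torch_ipex.linear(input, weight, bias)
        except RuntimeError:
            return torch_linear(input, weight, bias)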