
jit: enable conv_relu fusion #15


Merged
merged 10 commits into from May 29, 2020
4 changes: 3 additions & 1 deletion cmake/CPU.cmake
@@ -136,9 +136,11 @@ include_directories(${DPCPP_THIRD_PARTY_ROOT}/xsmm/include)
set(DPCPP_SRCS)
set(DPCPP_COMMON_SRCS)
set(DPCPP_CPU_SRCS)
set(DPCPP_JIT_SRCS)

add_subdirectory(${DPCPP_ROOT})
add_subdirectory(${DPCPP_ROOT}/cpu)
add_subdirectory(${DPCPP_ROOT}/jit)

# libxsmm
include(${CMAKE_ROOT}/Modules/ExternalProject.cmake)
@@ -153,7 +155,7 @@ ExternalProject_Add(xsmm
INSTALL_COMMAND ""
)
# Compile code with pybind11
set(DPCPP_SRCS ${DPCPP_ATEN_SRCS} ${DPCPP_COMMON_SRCS} ${DPCPP_CPU_SRCS})
set(DPCPP_SRCS ${DPCPP_ATEN_SRCS} ${DPCPP_COMMON_SRCS} ${DPCPP_CPU_SRCS} ${DPCPP_JIT_SRCS})
pybind11_add_module(${PLUGIN_NAME} SHARED ${DPCPP_SRCS})
target_link_libraries(${PLUGIN_NAME} PRIVATE ${DPCPP_THIRD_PARTY_ROOT}/xsmm/lib/libxsmm.a)

29 changes: 5 additions & 24 deletions intel_pytorch_extension_py/ops/linear.py
@@ -2,30 +2,11 @@
from torch.autograd import Function
import torch.nn.functional as F
import _torch_ipex as core
from typing import Optional

F_linear = F.linear

class LinearFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias):
output = core.linear(input, weight, bias)
ctx.save_for_backward(input, weight, bias)
return output

@staticmethod
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
grad_output = grad_output.contiguous()
if bias == None:
output_mask = (input.requires_grad, weight.requires_grad, 0)
else:
output_mask = (input.requires_grad, weight.requires_grad, bias.requires_grad)
grad_input, grad_weight, grad_bias = core.linear_backward(input, grad_output, weight, output_mask)
return (grad_input, grad_weight, grad_bias)

def linear(input, weight, bias=None):
if input.device.type == 'dpcpp' and core.get_auto_dnnl():
return LinearFunction.apply(input, weight, bias)
return F_linear(input, weight, bias)
def linear(input, weight, bias: Optional[torch.Tensor] = None):
if bias is None:
bias = torch.zeros(weight.size(0))
return torch.ops.torch_ipex.linear(input, weight, bias)
Contributor

@pinzhenx commented on Jun 9, 2020


@XiaobingSuper there's no way to fall back to the ATen linear if an exception happens. I suggest we add a try-catch here, or in the NewLinearOp?


F.linear = linear
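
To make the suggestion above concrete, here is a minimal, hypothetical sketch of a try/except fallback around the fused op. It is not part of the diff: the helper name linear_with_fallback is invented, and it assumes the module still keeps the F_linear alias to the original torch.nn.functional.linear shown above.

# Hypothetical sketch of the reviewer's try/except suggestion; only
# torch.ops.torch_ipex.linear and F_linear come from this file.
def linear_with_fallback(input, weight, bias: Optional[torch.Tensor] = None):
    try:
        fused_bias = bias if bias is not None else torch.zeros(weight.size(0))
        return torch.ops.torch_ipex.linear(input, weight, fused_bias)
    except RuntimeError:
        # Fall back to the stock ATen linear if the IPEX op rejects the
        # inputs (e.g. an unsupported dtype or layout).
        return F_linear(input, weight, bias)
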
45 changes: 12 additions & 33 deletions intel_pytorch_extension_py/ops/pooling.py
@@ -2,25 +2,12 @@
from torch.autograd import Function
import torch.nn.functional as F
import _torch_ipex as core
from torch.nn.modules.utils import _single
from torch.nn.modules.utils import _single, _pair
from typing import List

torch_adaptive_avg_pool2d = torch._C._nn.adaptive_avg_pool2d
torch_max_pool2d = torch.max_pool2d
torch_max_pool3d = torch.max_pool3d
Vector = List[int]

class AdaptiveAvgPool2dFunction(Function):
@staticmethod
def forward(ctx, input, output_size):
output = core.adaptive_avg_pool2d(input, _single(output_size))
ctx.save_for_backward(input)
return output

@staticmethod
def backward(ctx, grad_output):
(input,) = ctx.saved_tensors
grad_output = grad_output.contiguous()
grad_input = core.adaptive_avg_pool2d_backward(grad_output, input)
return (grad_input, None)
torch_max_pool3d = torch.max_pool3d

class MaxPoolingFunction(Function):
@staticmethod
@@ -41,21 +28,8 @@ def backward(ctx, grad_output):
grad_input = core.max_pooling_backward(grad_output, output, input, ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation, ctx.ceil_mode)
return (grad_input, None, None, None, None, None)

def adaptive_avg_pool2d(input, output_size):
try:
if input.device.type == 'dpcpp' and core.get_auto_dnnl():
return AdaptiveAvgPool2dFunction.apply(input, output_size)
except RuntimeError:
pass
return torch_adaptive_avg_pool2d(input, output_size)

def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode):
try:
if input.device.type == 'dpcpp' and core.get_auto_dnnl():
return MaxPoolingFunction.apply(input, kernel_size, stride, padding, dilation, ceil_mode)
except RuntimeError:
pass
return torch_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
def adaptive_avg_pool2d(input, output_size: Vector):
return torch.ops.torch_ipex.adaptive_avg_pool2d(input, _pair(output_size))
Contributor


ditto


def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
try:
@@ -65,6 +39,11 @@ def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
pass
return torch_max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)

def max_pool2d(input, kernel_size: Vector, stride: Vector, padding: Vector, dilation: Vector, ceil_mode: bool):
Contributor


ditto

if not stride:
stride = kernel_size
return torch.ops.torch_ipex.max_pool2d(input, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation), ceil_mode)

torch._C._nn.adaptive_avg_pool2d = adaptive_avg_pool2d
torch.max_pool2d = max_pool2d
torch.max_pool3d = max_pool3d
torch.max_pool3d = max_pool3d
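
The two "ditto" replies above raise the same fallback concern for the pooling ops. A comparable hypothetical sketch for max_pool2d follows; it assumes an alias to the original torch.max_pool2d (like the torch_max_pool2d reference kept in the previous version of this file) is still available, and the helper name is invented for illustration.

# Hypothetical fallback for max_pool2d, mirroring the suggestion made for
# linear(): try the IPEX fused op first, fall back to the original
# torch.max_pool2d (assumed saved as torch_max_pool2d) on failure.
def max_pool2d_with_fallback(input, kernel_size: Vector, stride: Vector,
                             padding: Vector, dilation: Vector, ceil_mode: bool):
    if not stride:
        stride = kernel_size
    try:
        return torch.ops.torch_ipex.max_pool2d(
            input, _pair(kernel_size), _pair(stride),
            _pair(padding), _pair(dilation), ceil_mode)
    except RuntimeError:
        return torch_max_pool2d(input, kernel_size, stride, padding,
                                dilation, ceil_mode)
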
172 changes: 172 additions & 0 deletions tests/cpu/test_jit.py
@@ -0,0 +1,172 @@
from __future__ import division
from __future__ import print_function

'''
From PyTorch:

Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)

From Caffe2:

Copyright (c) 2016-present, Facebook Inc. All rights reserved.

All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.

All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.

All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.

All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.

All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.

Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.

All rights reserved.
'''

"""Tests for rn50."""

import math
import random
import unittest
from functools import reduce

import torch
import torch.nn as nn
from torch.jit._recursive import wrap_cpp_module
import copy

import intel_pytorch_extension
from intel_pytorch_extension import core

import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd import gradcheck
from torch.autograd.gradcheck import gradgradcheck
from torch._six import inf, nan

from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, do_test_dtypes, do_test_empty_full, \
IS_SANDCASTLE, load_tests, brute_pdist, brute_cdist, slowTest, \
skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf

device = 'dpcpp:0'
#device = 'cpu:0'
SIZE = 100

torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)

def test_output(model, x):
modelName = model.__class__.__name__
core.disable_jit()

model = model.to('dpcpp').eval()
x = x.to('dpcpp')
with torch.no_grad():
result = model(x)

smodel = torch.jit.script(model)
smodel.eval()
with torch.no_grad():
sresult = smodel(x)

print(f'\nAre {modelName} and Scripted{modelName} outputs the same: ',
torch.allclose(
sresult, result, rtol=1e-05, atol=1e-06, equal_nan=False))

core.enable_jit()
pmodel = torch.jit.script(model)
# bn folding
pmodel = wrap_cpp_module(torch._C._jit_pass_fold_convbn(pmodel._c))
with torch.no_grad():
# conv relu fusion, conv sum fusion or conv sum relu fusion
print(pmodel.graph_for(x))
presult = pmodel(x)

# print(result)
# print(sresult)
# print(presult)

print(f'\nWith or without fusion passes, are Scripted{modelName} outputs the same: ',
torch.allclose(
sresult, presult, rtol=1e-05, atol=1e-06, equal_nan=False))

class Conv2dRelu_Fixed(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(Conv2dRelu_Fixed, self).__init__()
seed = 2018
torch.manual_seed(seed)
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)

def forward(self, x):
return F.relu(self.conv(x), inplace=True)

class CascadedConv2dBnSumRelu(nn.Module):
def __init__(self, in_channels, mid_channels, out_channels, **kwargs):
super(CascadedConv2dBnSumRelu, self).__init__()
torch.manual_seed(2018)
self.conv = nn.Conv2d(in_channels, mid_channels, bias=False, **kwargs)
self.conv1 = nn.Conv2d(
mid_channels, out_channels, bias=False, padding=1, **kwargs)
self.conv2 = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(mid_channels, eps=0.001)
self.bn1 = nn.BatchNorm2d(out_channels, eps=0.001)
self.bn2 = nn.BatchNorm2d(out_channels, eps=0.001)

def forward(self, x):
a = self.conv(x)
a = self.bn(a)
a = F.relu(a, inplace=True)
a = self.conv1(a)
a = self.bn1(a)
b = self.conv2(x)
b = self.bn2(b)
return F.relu(a.add_(b), inplace=True)

class Tester(TestCase):
n = 32
c = 3
h = 224
w = 224
print('input size: (%d, %d, %d, %d)' % (n, c, h, w))

def test_output_conv_relu(self):
test_output(
Conv2dRelu_Fixed(self.c, 32, kernel_size=3, stride=1),
torch.rand(self.n, self.c, self.h, self.w))

def test_output_cascaded_conv2d_bn_sum_relu(self):
test_output(
CascadedConv2dBnSumRelu(self.c, 64, 32, kernel_size=3, stride=1),
torch.rand(self.n, self.c, self.h, self.w))

if __name__ == '__main__':
core.enable_auto_dnnl()
test = unittest.main()
9 changes: 9 additions & 0 deletions torch_ipex/csrc/auto_opt_config.h
@@ -17,6 +17,14 @@ class AutoOptConfig {
return auto_dnnl_;
}

inline void set_jit_fuse(bool jit_fuse) {
jit_fuse_ = jit_fuse;
}

inline bool get_jit_fuse() {
return jit_fuse_;
}

inline void set_mix_bf16_fp32(bool value) {
mix_bf16_fp32_ = value;
}
@@ -39,6 +47,7 @@ class AutoOptConfig {

private:
bool auto_dnnl_;
bool jit_fuse_;
bool mix_bf16_fp32_;
bool pure_bf16_;
};
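
For context, the new jit_fuse_ flag is presumably what the core.enable_jit() / core.disable_jit() helpers used in tests/cpu/test_jit.py toggle; the binding from those helpers to set_jit_fuse is not shown in this diff, so treat the sketch below as an assumption-laden usage example rather than documented API.

# Illustrative only: switch JIT fusion on or off around scripting a model,
# using the enable_jit()/disable_jit() helpers seen in tests/cpu/test_jit.py.
# Whether they ultimately call AutoOptConfig::set_jit_fuse is an assumption.
import torch
import intel_pytorch_extension
from intel_pytorch_extension import core

def script_with_fusion(model, enable_fusion: bool = True):
    if enable_fusion:
        core.enable_jit()   # presumably sets jit_fuse_ = true
    else:
        core.disable_jit()  # presumably sets jit_fuse_ = false
    return torch.jit.script(model.eval())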