jit: enable conv_relu fusion #15
Conversation
If the code is invalid, could we remove it directly?
Force-pushed from ed5e9dd to 0052bd9.
Force-pushed from 6f0987c to f2253a8.
@EikanWang, please help merge it, thanks!
def linear(input, weight, bias: Optional[torch.Tensor] = None):
    if bias is None:
        bias = torch.zeros(weight.size(0))
    return torch.ops.torch_ipex.linear(input, weight, bias)
@XiaobingSuper there's no way to fall back to aten linear if an exception happens here. I suggest we add a try-catch, either here or in NewLinearOp?
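A minimal sketch of the suggested fallback, assuming a try/except around the torch_ipex op is acceptable (the RuntimeError filter and the use of F.linear as the aten fallback are assumptions, not what the PR implements):

import torch
import torch.nn.functional as F
from typing import Optional

def linear(input, weight, bias: Optional[torch.Tensor] = None):
    # Sketch only: try the torch_ipex kernel first, and fall back to the
    # stock aten implementation if it raises (exception type assumed).
    try:
        if bias is None:
            bias = torch.zeros(weight.size(0))
        return torch.ops.torch_ipex.linear(input, weight, bias)
    except RuntimeError:
        return F.linear(input, weight, bias)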
@@ -65,6 +39,11 @@ def max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode):
        pass
    return torch_max_pool3d(input, kernel_size, stride, padding, dilation, ceil_mode)

def max_pool2d(input, kernel_size: Vector, stride: Vector, padding: Vector, dilation: Vector, ceil_mode: bool):
ditto
        pass
    return torch_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

def adaptive_avg_pool2d(input, output_size: Vector):
    return torch.ops.torch_ipex.adaptive_avg_pool2d(input, _pair(output_size))
ditto
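For reference, the diff fragments above suggest the pooling overrides already follow the pattern the reviewer is pointing at: try the torch_ipex op, swallow the failure, and fall back to the saved stock op. A reconstruction of that presumed shape (the exception type and the torch_max_pool2d alias binding are assumptions inferred from the fragments, and the Vector annotations are dropped since that type is defined elsewhere in the repo):

import torch

# Presumed alias to the stock op, captured before the override (assumption).
torch_max_pool2d = torch.max_pool2d

def max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode):
    try:
        # Attempt the torch_ipex kernel first.
        return torch.ops.torch_ipex.max_pool2d(input, kernel_size, stride,
                                               padding, dilation, ceil_mode)
    except RuntimeError:
        pass
    # Fall back to the stock implementation if the custom op fails.
    return torch_max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)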
optimize sparse adagrad compute and update
For now, conv_relu fusion is only enabled for the JIT path.
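To illustrate what "JIT path" means here: a fusion pass can only rewrite a TorchScript graph, so the model has to go through torch.jit.trace or torch.jit.script first; eager-mode execution is untouched. A hypothetical usage sketch (the model and input shape are made up, and the fused op name is not shown because the PR does not spell it out):

import torch
import torch.nn as nn

class ConvRelu(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)

    def forward(self, x):
        # A conv followed by relu is the pattern the fusion pass targets.
        return torch.relu(self.conv(x))

model = ConvRelu().eval()
# Tracing produces the TorchScript graph that the conv_relu fusion
# pass can then rewrite into a single fused op.
traced = torch.jit.trace(model, torch.randn(1, 3, 224, 224))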