2
2
from torch .autograd import Function
3
3
import torch .nn .functional as F
4
4
import _torch_ipex as core
5
+ from torch .nn .modules .utils import _single
5
6
6
- F_adaptive_avg_pool2d = F .adaptive_avg_pool2d
7
+ torch_adaptive_avg_pool2d = torch . _C . _nn .adaptive_avg_pool2d
7
8
torch_max_pool2d = torch .max_pool2d
8
9
torch_max_pool3d = torch .max_pool3d
9
10
10
11
class AdaptiveAvgPool2dFunction (Function ):
11
12
@staticmethod
12
13
def forward (ctx , input , output_size ):
13
- _output_size = _list_with_default (output_size , input .size ())
14
- output = core .adaptive_avg_pool2d (input , _output_size )
14
+ output = core .adaptive_avg_pool2d (input , _single (output_size ))
15
15
ctx .save_for_backward (input )
16
16
return output
17
17
@@ -25,44 +25,46 @@ def backward(ctx, grad_output):
25
25
class MaxPoolingFunction (Function ):
26
26
@staticmethod
27
27
def forward (ctx , input , kernel_size , stride , padding , dilation , ceil_mode ):
28
- output = core .max_pooling (input , (kernel_size ,), (stride ,), (padding ,), (dilation ,), ceil_mode )
29
- ctx .save_for_backward (output , input )
30
- ctx .kernel_size = kernel_size
31
- ctx .stride = stride
32
- ctx .padding = padding
33
- ctx .dilation = dilation
28
+ ctx .kernel_size = _single (kernel_size )
29
+ ctx .stride = _single (stride )
30
+ ctx .padding = _single (padding )
31
+ ctx .dilation = _single (dilation )
34
32
ctx .ceil_mode = ceil_mode
33
+ output = core .max_pooling (input , ctx .kernel_size , ctx .stride , ctx .padding , ctx .dilation , ctx .ceil_mode )
34
+ ctx .save_for_backward (output , input )
35
35
return output
36
36
37
37
@staticmethod
38
38
def backward (ctx , grad_output ):
39
39
output , input = ctx .saved_tensors
40
40
grad_output = grad_output .contiguous ()
41
- grad_input = core .max_pooling_backward (grad_output , output , input , ( ctx .kernel_size ,), ( ctx .stride ,), ( ctx .padding ,), ( ctx .dilation ,) , ctx .ceil_mode )
41
+ grad_input = core .max_pooling_backward (grad_output , output , input , ctx .kernel_size , ctx .stride , ctx .padding , ctx .dilation , ctx .ceil_mode )
42
42
return (grad_input , None , None , None , None , None )
43
43
44
- def _list_with_default (out_size , defaults ):
45
- if isinstance (out_size , int ):
46
- return (out_size ,)
47
- if len (defaults ) <= len (out_size ):
48
- raise ValueError ('Input dimension should be at least {}' .format (len (out_size ) + 1 ))
49
- return [v if v is not None else d for v , d in zip (out_size , defaults [- len (out_size ):])]
50
-
51
44
def adaptive_avg_pool2d (input , output_size ):
52
- if input .device .type == 'dpcpp' :
53
- return AdaptiveAvgPool2dFunction .apply (input , output_size )
54
- return F_adaptive_avg_pool2d (input , output_size )
45
+ try :
46
+ if input .device .type == 'dpcpp' and core .get_auto_dnnl ():
47
+ return AdaptiveAvgPool2dFunction .apply (input , output_size )
48
+ except RuntimeError :
49
+ pass
50
+ return torch_adaptive_avg_pool2d (input , output_size )
55
51
56
52
def max_pool2d (input , kernel_size , stride , padding , dilation , ceil_mode ):
57
- if input .device .type == 'dpcpp' :
58
- return MaxPoolingFunction .apply (input , kernel_size , stride , padding , dilation , ceil_mode )
53
+ try :
54
+ if input .device .type == 'dpcpp' and core .get_auto_dnnl ():
55
+ return MaxPoolingFunction .apply (input , kernel_size , stride , padding , dilation , ceil_mode )
56
+ except RuntimeError :
57
+ pass
59
58
return torch_max_pool2d (input , kernel_size , stride , padding , dilation , ceil_mode )
60
59
61
60
def max_pool3d (input , kernel_size , stride , padding , dilation , ceil_mode ):
62
- if input .device .type == 'dpcpp' :
63
- return MaxPoolingFunction .apply (input , kernel_size , stride , padding , dilation , ceil_mode )
61
+ try :
62
+ if input .device .type == 'dpcpp' and core .get_auto_dnnl ():
63
+ return MaxPoolingFunction .apply (input , kernel_size , stride , padding , dilation , ceil_mode )
64
+ except RuntimeError :
65
+ pass
64
66
return torch_max_pool3d (input , kernel_size , stride , padding , dilation , ceil_mode )
65
67
66
- F .adaptive_avg_pool2d = adaptive_avg_pool2d
68
+ torch . _C . _nn .adaptive_avg_pool2d = adaptive_avg_pool2d
67
69
torch .max_pool2d = max_pool2d
68
70
torch .max_pool3d = max_pool3d
0 commit comments