3
3
import torch .nn .functional as F
4
4
import _torch_ipex as core
5
5
6
- F_adaptive_avg_pool2d = F .adaptive_avg_pool2d
6
+ torch_adaptive_avg_pool2d = torch . _C . _nn .adaptive_avg_pool2d
7
7
torch_max_pool2d = torch .max_pool2d
8
8
torch_max_pool3d = torch .max_pool3d
9
9
10
10
class AdaptiveAvgPool2dFunction (Function ):
11
11
@staticmethod
12
12
def forward (ctx , input , output_size ):
13
- _output_size = _list_with_default (output_size , input .size ())
14
- output = core .adaptive_avg_pool2d (input , _output_size )
13
+ if type (output_size ) is int :
14
+ output_size = (output_size ,)
15
+ output = core .adaptive_avg_pool2d (input , output_size )
15
16
ctx .save_for_backward (input )
16
17
return output
17
18
@@ -25,7 +26,15 @@ def backward(ctx, grad_output):
25
26
class MaxPoolingFunction (Function ):
26
27
@staticmethod
27
28
def forward (ctx , input , kernel_size , stride , padding , dilation , ceil_mode ):
28
- output = core .max_pooling (input , (kernel_size ,), (stride ,), (padding ,), (dilation ,), ceil_mode )
29
+ if type (kernel_size ) is int :
30
+ kernel_size = (kernel_size ,)
31
+ if type (stride ) is int :
32
+ stride = (stride ,)
33
+ if type (padding ) is int :
34
+ padding = (padding ,)
35
+ if type (dilation ) is int :
36
+ dilation = (dilation ,)
37
+ output = core .max_pooling (input , kernel_size , stride , padding , dilation , ceil_mode )
29
38
ctx .save_for_backward (output , input )
30
39
ctx .kernel_size = kernel_size
31
40
ctx .stride = stride
@@ -38,31 +47,33 @@ def forward(ctx, input, kernel_size, stride, padding, dilation, ceil_mode):
38
47
def backward (ctx , grad_output ):
39
48
output , input = ctx .saved_tensors
40
49
grad_output = grad_output .contiguous ()
41
- grad_input = core .max_pooling_backward (grad_output , output , input , ( ctx .kernel_size ,), ( ctx .stride ,), ( ctx .padding ,), ( ctx .dilation ,) , ctx .ceil_mode )
50
+ grad_input = core .max_pooling_backward (grad_output , output , input , ctx .kernel_size , ctx .stride , ctx .padding , ctx .dilation , ctx .ceil_mode )
42
51
return (grad_input , None , None , None , None , None )
43
52
44
- def _list_with_default (out_size , defaults ):
45
- if isinstance (out_size , int ):
46
- return (out_size ,)
47
- if len (defaults ) <= len (out_size ):
48
- raise ValueError ('Input dimension should be at least {}' .format (len (out_size ) + 1 ))
49
- return [v if v is not None else d for v , d in zip (out_size , defaults [- len (out_size ):])]
50
-
51
53
def adaptive_avg_pool2d (input , output_size ):
52
- if input .device .type == 'dpcpp' :
53
- return AdaptiveAvgPool2dFunction .apply (input , output_size )
54
- return F_adaptive_avg_pool2d (input , output_size )
54
+ try :
55
+ if input .device .type == 'dpcpp' and core .get_auto_dnnl ():
56
+ return AdaptiveAvgPool2dFunction .apply (input , output_size )
57
+ except RuntimeError :
58
+ return torch_adaptive_avg_pool2d (input , output_size )
59
+ return torch_adaptive_avg_pool2d (input , output_size )
55
60
56
61
def max_pool2d (input , kernel_size , stride , padding , dilation , ceil_mode ):
57
- if input .device .type == 'dpcpp' :
58
- return MaxPoolingFunction .apply (input , kernel_size , stride , padding , dilation , ceil_mode )
62
+ try :
63
+ if input .device .type == 'dpcpp' and core .get_auto_dnnl ():
64
+ return MaxPoolingFunction .apply (input , kernel_size , stride , padding , dilation , ceil_mode )
65
+ except RuntimeError :
66
+ return torch_max_pool2d (input , kernel_size , stride , padding , dilation , ceil_mode )
59
67
return torch_max_pool2d (input , kernel_size , stride , padding , dilation , ceil_mode )
60
68
61
69
def max_pool3d (input , kernel_size , stride , padding , dilation , ceil_mode ):
62
- if input .device .type == 'dpcpp' :
63
- return MaxPoolingFunction .apply (input , kernel_size , stride , padding , dilation , ceil_mode )
70
+ try :
71
+ if input .device .type == 'dpcpp' and core .get_auto_dnnl ():
72
+ return MaxPoolingFunction .apply (input , kernel_size , stride , padding , dilation , ceil_mode )
73
+ except RuntimeError :
74
+ return torch_max_pool3d (input , kernel_size , stride , padding , dilation , ceil_mode )
64
75
return torch_max_pool3d (input , kernel_size , stride , padding , dilation , ceil_mode )
65
76
66
- F .adaptive_avg_pool2d = adaptive_avg_pool2d
77
+ torch . _C . _nn .adaptive_avg_pool2d = adaptive_avg_pool2d
67
78
torch .max_pool2d = max_pool2d
68
79
torch .max_pool3d = max_pool3d
0 commit comments