@@ -215,12 +215,24 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m){
215
215
std::tuple<Tensor,Tensor> (const Tensor &, int64_t , int64_t , bool , bool ),
216
216
std::tuple<Tensor,Tensor> (const Tensor &, int64_t , int64_t , bool , bool ),
217
217
&ADD_NS (topk)>::type::call)));
218
+
218
219
m.impl (TORCH_SELECTIVE_NAME (" aten::sort" ),
219
220
TORCH_FN ((&CPU_WrapFunction<DtypeCastPolicy::fp32,
220
221
std::tuple<Tensor,Tensor> (const Tensor &, int64_t , bool ),
221
222
std::tuple<Tensor,Tensor> (const Tensor &, int64_t , bool ),
222
223
&ADD_NS (sort)>::type::call)));
223
-
224
+
225
+ m.impl (TORCH_SELECTIVE_NAME (" aten::kthvalue" ),
226
+ TORCH_FN ((&CPU_WrapFunction<DtypeCastPolicy::fp32,
227
+ std::tuple<Tensor,Tensor> (const Tensor &, int64_t , int64_t , bool ),
228
+ std::tuple<Tensor,Tensor> (const Tensor &, int64_t , int64_t , bool ),
229
+ &ADD_NS (kthvalue)>::type::call)));
230
+
231
+ m.impl (TORCH_SELECTIVE_NAME (" aten::kthvalue.dimname" ),
232
+ TORCH_FN ((&CPU_WrapFunction<DtypeCastPolicy::fp32,
233
+ std::tuple<Tensor,Tensor> (const Tensor &, int64_t , at::Dimname, bool ),
234
+ std::tuple<Tensor,Tensor> (const Tensor &, int64_t , at::Dimname, bool ),
235
+ &ADD_NS (kthvalue)>::type::call)));
224
236
// for int8 path
225
237
m.impl (TORCH_SELECTIVE_NAME (" aten::conv2d" ), TORCH_FN ((&torch_ipex::autocast::conv2d)));
226
238
m.impl (TORCH_SELECTIVE_NAME (" aten::conv3d" ), TORCH_FN ((&torch_ipex::autocast::conv3d)));
0 commit comments