Skip to content

Commit 50ce052

Browse files
add kthvalue into blacklist
1 parent 0f2f7ee commit 50ce052

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

torch_ipex/csrc/autocast_mode.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,24 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m){
215215
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool, bool),
216216
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool, bool),
217217
&ADD_NS(topk)>::type::call)));
218+
218219
m.impl(TORCH_SELECTIVE_NAME("aten::sort"),
219220
TORCH_FN((&CPU_WrapFunction<DtypeCastPolicy::fp32,
220221
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, bool),
221222
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, bool),
222223
&ADD_NS(sort)>::type::call)));
223-
224+
225+
m.impl(TORCH_SELECTIVE_NAME("aten::kthvalue"),
226+
TORCH_FN((&CPU_WrapFunction<DtypeCastPolicy::fp32,
227+
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool),
228+
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, int64_t, bool),
229+
&ADD_NS(kthvalue)>::type::call)));
230+
231+
m.impl(TORCH_SELECTIVE_NAME("aten::kthvalue.dimname"),
232+
TORCH_FN((&CPU_WrapFunction<DtypeCastPolicy::fp32,
233+
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, at::Dimname, bool),
234+
std::tuple<Tensor,Tensor> (const Tensor &, int64_t, at::Dimname, bool),
235+
&ADD_NS(kthvalue)>::type::call)));
224236
// for int8 path
225237
m.impl(TORCH_SELECTIVE_NAME("aten::conv2d"), TORCH_FN((&torch_ipex::autocast::conv2d)));
226238
m.impl(TORCH_SELECTIVE_NAME("aten::conv3d"), TORCH_FN((&torch_ipex::autocast::conv3d)));

0 commit comments

Comments
 (0)