@@ -3809,11 +3809,43 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
3809
3809
}
3810
3810
}
3811
3811
3812
#ifdef GGML_SYCL_GRAPH
// Decides whether `cgraph` may be executed through a SYCL command graph.
//
// Returns false (and logs the reason at INFO level) when:
//   * ggml_sycl_info() reports more than one device, or
//   * any node in the graph has op GGML_OP_CONCAT or GGML_OP_MUL_MAT_ID.
// Returns true otherwise, including for a graph with no nodes.
static bool check_graph_compatibility(ggml_cgraph * cgraph) {
    if (ggml_sycl_info().device_count > 1) {
        // A sycl_ex::command_graph object can only be created for a single device
        GGML_LOG_INFO("%s: disabling SYCL graphs due to multiple devices\n", __func__);
        return false;
    }

    for (int i = 0; i < cgraph->n_nodes; i++) {
        const ggml_op node_op = cgraph->nodes[i]->op;
        switch (node_op) {
            default:
                // Every other op is treated as compatible with graph recording.
                break;
            case GGML_OP_CONCAT:
                // ggml_sycl_op_concat() does a blocking host wait after memcpy operations,
                // but wait() can't be called on the events returned by a queue recording
                // to a graph.
                [[fallthrough]];
            case GGML_OP_MUL_MAT_ID:
                // ggml_sycl_mul_mat_id() does a blocking host wait on the sycl queue after
                // submitting a memcpy operation, but wait() can't be called on a queue that
                // is recording to a graph.
                GGML_LOG_INFO("%s: disabling SYCL graphs due to unsupported node type %s\n", __func__,
                              ggml_op_name(node_op));
                return false;
        }
    }
    return true;
}
#endif
3842
+
3812
3843
static ggml_status ggml_backend_sycl_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
3813
3844
auto * sycl_ctx = static_cast <ggml_backend_sycl_context *>(backend->context );
3814
3845
3815
3846
#ifdef GGML_SYCL_GRAPH
3816
- if (!g_ggml_sycl_disable_graph) {
3847
+ bool use_sycl_graph = !g_ggml_sycl_disable_graph && check_graph_compatibility (cgraph);
3848
+ if (use_sycl_graph) {
3817
3849
const bool graph_support = dpct::get_device (sycl_ctx->device ).has (sycl::aspect::ext_oneapi_limited_graph);
3818
3850
if (!graph_support) {
3819
3851
GGML_SYCL_DEBUG (" [SYCL-GRAPH] can not use graphs on device:%d\n " , sycl_ctx->device );
0 commit comments