@@ -749,36 +749,74 @@ static const std::string gpu_pipeline =
749
749
" func.func(convert-parallel-loops-to-gpu),"
750
750
// insert-gpu-allocs pass can have client-api = opencl or vulkan args
751
751
" func.func(insert-gpu-allocs{in-regions=1}),"
752
- // ** imex GPU passes
753
- // "drop-regions,"
754
- // "canonicalize,"
755
- // // "normalize-memrefs,"
756
- // // "gpu-decompose-memrefs,"
757
- // "func.func(lower-affine),"
758
- // "gpu-kernel-outlining,"
759
- // "canonicalize,"
760
- // "cse,"
761
- // // The following set-spirv-* passes can have client-api = opencl or
762
- // vulkan
763
- // // args
764
- // "set-spirv-capabilities{client-api=opencl},"
765
- // "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
766
- // "canonicalize,"
767
- // "fold-memref-alias-ops,"
768
- // "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
769
- // "spirv.module(spirv-lower-abi-attrs),"
770
- // "spirv.module(spirv-update-vce),"
771
- // // "func.func(llvm-request-c-wrappers),"
772
- // "serialize-spirv,"
773
- // "expand-strided-metadata,"
774
- // "lower-affine,"
775
- // "convert-gpu-to-gpux,"
776
- // "convert-func-to-llvm,"
777
- // "convert-math-to-llvm,"
778
- // "convert-gpux-to-llvm,"
779
- // "finalize-memref-to-llvm,"
780
- // "reconcile-unrealized-casts";
781
- // ** nv GPU passes
752
+ " drop-regions,"
753
+ " canonicalize,"
754
+ // "normalize-memrefs,"
755
+ // "gpu-decompose-memrefs,"
756
+ " func.func(lower-affine),"
757
+ " gpu-kernel-outlining,"
758
+ " canonicalize,"
759
+ " cse,"
760
+ // The following set-spirv-* passes can have client-api = opencl or vulkan
761
+ // args
762
+ " set-spirv-capabilities{client-api=opencl},"
763
+ " gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
764
+ " canonicalize,"
765
+ " fold-memref-alias-ops,"
766
+ " imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
767
+ " spirv.module(spirv-lower-abi-attrs),"
768
+ " spirv.module(spirv-update-vce),"
769
+ // "func.func(llvm-request-c-wrappers),"
770
+ " serialize-spirv,"
771
+ " expand-strided-metadata,"
772
+ " lower-affine,"
773
+ " convert-gpu-to-gpux,"
774
+ " convert-func-to-llvm,"
775
+ " convert-math-to-llvm,"
776
+ " convert-gpux-to-llvm,"
777
+ " finalize-memref-to-llvm,"
778
+ " reconcile-unrealized-casts" ;
779
+
780
+ static const std::string cuda_pipeline =
781
+ " add-gpu-regions,"
782
+ " canonicalize,"
783
+ " ndarray-dist,"
784
+ " func.func(dist-coalesce),"
785
+ " func.func(dist-infer-elementwise-cores),"
786
+ " convert-dist-to-standard,"
787
+ " canonicalize,"
788
+ " overlap-comm-and-compute,"
789
+ " add-comm-cache-keys,"
790
+ " lower-distruntime-to-idtr,"
791
+ " convert-ndarray-to-linalg,"
792
+ " canonicalize,"
793
+ " func.func(tosa-make-broadcastable),"
794
+ " func.func(tosa-to-linalg),"
795
+ " func.func(tosa-to-tensor),"
796
+ " canonicalize,"
797
+ " linalg-fuse-elementwise-ops,"
798
+ " arith-expand,"
799
+ " memref-expand,"
800
+ " arith-bufferize,"
801
+ " func-bufferize,"
802
+ " func.func(empty-tensor-to-alloc-tensor),"
803
+ " func.func(scf-bufferize),"
804
+ " func.func(tensor-bufferize),"
805
+ " func.func(bufferization-bufferize),"
806
+ " func.func(linalg-bufferize),"
807
+ " func.func(linalg-detensorize),"
808
+ " func.func(tensor-bufferize),"
809
+ " region-bufferize,"
810
+ " canonicalize,"
811
+ " func.func(finalizing-bufferize),"
812
+ " imex-remove-temporaries,"
813
+ " func.func(convert-linalg-to-parallel-loops),"
814
+ " func.func(scf-parallel-loop-fusion),"
815
+ // is add-outer-parallel-loop needed?
816
+ " func.func(imex-add-outer-parallel-loop),"
817
+ " func.func(gpu-map-parallel-loops),"
818
+ " func.func(convert-parallel-loops-to-gpu),"
819
+ " func.func(insert-gpu-allocs{in-regions=1}),"
782
820
" func.func(insert-gpu-copy),"
783
821
" drop-regions,"
784
822
" canonicalize,"
@@ -800,7 +838,9 @@ static const std::string gpu_pipeline =
800
838
801
839
const std::string _passes (get_text_env (" SHARPY_PASSES" ));
802
840
static const std::string &pass_pipeline =
803
- _passes != " " ? _passes : (useGPU () ? gpu_pipeline : cpu_pipeline);
841
+ _passes != " " ? _passes
842
+ : (useGPU () ? (useCUDA () ? cuda_pipeline : gpu_pipeline)
843
+ : cpu_pipeline);
804
844
805
845
JIT::JIT (const std::string &libidtr)
806
846
: _context (::mlir::MLIRContext::Threading::DISABLED), _pm (&_context),
@@ -852,23 +892,24 @@ JIT::JIT(const std::string &libidtr)
852
892
_crunnerlib = mlirRoot + " /lib/libmlir_c_runner_utils.so" ;
853
893
_runnerlib = mlirRoot + " /lib/libmlir_runner_utils.so" ;
854
894
if (!std::ifstream (_crunnerlib)) {
855
- throw std::runtime_error (" Cannot find libmlir_c_runner_utils.so " );
895
+ throw std::runtime_error (" Cannot find lib: " + _crunnerlib );
856
896
}
857
897
if (!std::ifstream (_runnerlib)) {
858
- throw std::runtime_error (" Cannot find libmlir_runner_utils.so " );
898
+ throw std::runtime_error (" Cannot find lib: " + _runnerlib );
859
899
}
860
900
861
901
if (useGPU ()) {
862
902
auto gpuxlibstr = get_text_env (" SHARPY_GPUX_SO" );
863
903
if (!gpuxlibstr.empty ()) {
864
904
_gpulib = std::string (gpuxlibstr);
865
905
} else {
866
- // auto imexRoot = get_text_env("IMEXROOT");
867
- // imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
868
- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
869
- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
870
- // for nv gpu
871
- _gpulib = mlirRoot + " /lib/libmlir_cuda_runtime.so" ;
906
+ if (useCUDA ()) {
907
+ _gpulib = mlirRoot + " /lib/libmlir_cuda_runtime.so" ;
908
+ } else {
909
+ auto imexRoot = get_text_env (" IMEXROOT" );
910
+ imexRoot = !imexRoot.empty () ? imexRoot : std::string (CMAKE_IMEX_ROOT);
911
+ _gpulib = imexRoot + " /lib/liblevel-zero-runtime.so" ;
912
+ }
872
913
if (!std::ifstream (_gpulib)) {
873
914
throw std::runtime_error (" Cannot find lib: " + _gpulib);
874
915
}
0 commit comments