Skip to content

Commit 733cd26

Browse files
committed
add SHARPY_USE_CUDA boolean to activate cuda pipeline
1 parent 9ba5d4f commit 733cd26

File tree

2 files changed

+82
-39
lines changed

2 files changed

+82
-39
lines changed

src/include/sharpy/UtilsAndTypes.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,5 @@ inline bool useGPU() {
7070
auto device = get_text_env("SHARPY_DEVICE");
7171
return !(device.empty() || device == "host" || device == "cpu");
7272
}
73+
74+
inline bool useCUDA() { return get_bool_env("SHARPY_USE_CUDA"); }

src/jit/mlir.cpp

Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -749,36 +749,74 @@ static const std::string gpu_pipeline =
749749
"func.func(convert-parallel-loops-to-gpu),"
750750
// insert-gpu-allocs pass can have client-api = opencl or vulkan args
751751
"func.func(insert-gpu-allocs{in-regions=1}),"
752-
// ** imex GPU passes
753-
// "drop-regions,"
754-
// "canonicalize,"
755-
// // "normalize-memrefs,"
756-
// // "gpu-decompose-memrefs,"
757-
// "func.func(lower-affine),"
758-
// "gpu-kernel-outlining,"
759-
// "canonicalize,"
760-
// "cse,"
761-
// // The following set-spirv-* passes can have client-api = opencl or
762-
// vulkan
763-
// // args
764-
// "set-spirv-capabilities{client-api=opencl},"
765-
// "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
766-
// "canonicalize,"
767-
// "fold-memref-alias-ops,"
768-
// "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
769-
// "spirv.module(spirv-lower-abi-attrs),"
770-
// "spirv.module(spirv-update-vce),"
771-
// // "func.func(llvm-request-c-wrappers),"
772-
// "serialize-spirv,"
773-
// "expand-strided-metadata,"
774-
// "lower-affine,"
775-
// "convert-gpu-to-gpux,"
776-
// "convert-func-to-llvm,"
777-
// "convert-math-to-llvm,"
778-
// "convert-gpux-to-llvm,"
779-
// "finalize-memref-to-llvm,"
780-
// "reconcile-unrealized-casts";
781-
// ** nv GPU passes
752+
"drop-regions,"
753+
"canonicalize,"
754+
// "normalize-memrefs,"
755+
// "gpu-decompose-memrefs,"
756+
"func.func(lower-affine),"
757+
"gpu-kernel-outlining,"
758+
"canonicalize,"
759+
"cse,"
760+
// The following set-spirv-* passes can have client-api = opencl or vulkan
761+
// args
762+
"set-spirv-capabilities{client-api=opencl},"
763+
"gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
764+
"canonicalize,"
765+
"fold-memref-alias-ops,"
766+
"imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
767+
"spirv.module(spirv-lower-abi-attrs),"
768+
"spirv.module(spirv-update-vce),"
769+
// "func.func(llvm-request-c-wrappers),"
770+
"serialize-spirv,"
771+
"expand-strided-metadata,"
772+
"lower-affine,"
773+
"convert-gpu-to-gpux,"
774+
"convert-func-to-llvm,"
775+
"convert-math-to-llvm,"
776+
"convert-gpux-to-llvm,"
777+
"finalize-memref-to-llvm,"
778+
"reconcile-unrealized-casts";
779+
780+
static const std::string cuda_pipeline =
781+
"add-gpu-regions,"
782+
"canonicalize,"
783+
"ndarray-dist,"
784+
"func.func(dist-coalesce),"
785+
"func.func(dist-infer-elementwise-cores),"
786+
"convert-dist-to-standard,"
787+
"canonicalize,"
788+
"overlap-comm-and-compute,"
789+
"add-comm-cache-keys,"
790+
"lower-distruntime-to-idtr,"
791+
"convert-ndarray-to-linalg,"
792+
"canonicalize,"
793+
"func.func(tosa-make-broadcastable),"
794+
"func.func(tosa-to-linalg),"
795+
"func.func(tosa-to-tensor),"
796+
"canonicalize,"
797+
"linalg-fuse-elementwise-ops,"
798+
"arith-expand,"
799+
"memref-expand,"
800+
"arith-bufferize,"
801+
"func-bufferize,"
802+
"func.func(empty-tensor-to-alloc-tensor),"
803+
"func.func(scf-bufferize),"
804+
"func.func(tensor-bufferize),"
805+
"func.func(bufferization-bufferize),"
806+
"func.func(linalg-bufferize),"
807+
"func.func(linalg-detensorize),"
808+
"func.func(tensor-bufferize),"
809+
"region-bufferize,"
810+
"canonicalize,"
811+
"func.func(finalizing-bufferize),"
812+
"imex-remove-temporaries,"
813+
"func.func(convert-linalg-to-parallel-loops),"
814+
"func.func(scf-parallel-loop-fusion),"
815+
// is add-outer-parallel-loop needed?
816+
"func.func(imex-add-outer-parallel-loop),"
817+
"func.func(gpu-map-parallel-loops),"
818+
"func.func(convert-parallel-loops-to-gpu),"
819+
"func.func(insert-gpu-allocs{in-regions=1}),"
782820
"func.func(insert-gpu-copy),"
783821
"drop-regions,"
784822
"canonicalize,"
@@ -800,7 +838,9 @@ static const std::string gpu_pipeline =
800838

801839
const std::string _passes(get_text_env("SHARPY_PASSES"));
802840
static const std::string &pass_pipeline =
803-
_passes != "" ? _passes : (useGPU() ? gpu_pipeline : cpu_pipeline);
841+
_passes != "" ? _passes
842+
: (useGPU() ? (useCUDA() ? cuda_pipeline : gpu_pipeline)
843+
: cpu_pipeline);
804844

805845
JIT::JIT(const std::string &libidtr)
806846
: _context(::mlir::MLIRContext::Threading::DISABLED), _pm(&_context),
@@ -852,23 +892,24 @@ JIT::JIT(const std::string &libidtr)
852892
_crunnerlib = mlirRoot + "/lib/libmlir_c_runner_utils.so";
853893
_runnerlib = mlirRoot + "/lib/libmlir_runner_utils.so";
854894
if (!std::ifstream(_crunnerlib)) {
855-
throw std::runtime_error("Cannot find libmlir_c_runner_utils.so");
895+
throw std::runtime_error("Cannot find lib: " + _crunnerlib);
856896
}
857897
if (!std::ifstream(_runnerlib)) {
858-
throw std::runtime_error("Cannot find libmlir_runner_utils.so");
898+
throw std::runtime_error("Cannot find lib: " + _runnerlib);
859899
}
860900

861901
if (useGPU()) {
862902
auto gpuxlibstr = get_text_env("SHARPY_GPUX_SO");
863903
if (!gpuxlibstr.empty()) {
864904
_gpulib = std::string(gpuxlibstr);
865905
} else {
866-
// auto imexRoot = get_text_env("IMEXROOT");
867-
// imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
868-
// _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
869-
// _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
870-
// for nv gpu
871-
_gpulib = mlirRoot + "/lib/libmlir_cuda_runtime.so";
906+
if (useCUDA()) {
907+
_gpulib = mlirRoot + "/lib/libmlir_cuda_runtime.so";
908+
} else {
909+
auto imexRoot = get_text_env("IMEXROOT");
910+
imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
911+
_gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
912+
}
872913
if (!std::ifstream(_gpulib)) {
873914
throw std::runtime_error("Cannot find lib: " + _gpulib);
874915
}

0 commit comments

Comments
 (0)