
Commit 0f630fb

cuda : ROCm AMD Unified Memory Architecture (UMA) handling (#4449)
* AMD ROCm: handle UMA memory VRAM expansions. This resolves #2797 by allowing ROCm AMD GPU users with a UMA to dynamically expand the VRAM allocated to the GPU. Without this, AMD ROCm users with shared CPU/GPU memory are usually stuck with the BIOS-set (or fixed) framebuffer VRAM, making it impossible to load more than 1-2 layers. Note that the model is duplicated in RAM because it is loaded once for the CPU and then copied into a second set of allocations managed by the HIP UMA system. We can fix this later.
* clarify build process for ROCm on Linux with cmake
* avoid using deprecated ROCm hipMallocHost
* keep simplifying the change required for UMA
* cmake: enable UMA-compatible allocation when LLAMA_HIP_UMA=ON
1 parent 562cf22 commit 0f630fb
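
For readers unfamiliar with HIP managed memory, the following is a minimal standalone sketch (not part of this commit) of the allocation behaviour the UMA path relies on: `hipMallocManaged` returns memory that both the CPU and an integrated GPU can touch, which is why the usable "VRAM" is no longer capped at the BIOS-reserved framebuffer. The file name and the `add_one` kernel are invented for the example; build with `hipcc` on a ROCm system.

```cpp
// uma_demo.cpp -- illustrative only, assuming a ROCm install and hipcc.
// Allocates managed (UMA-friendly) memory, touches it from CPU and GPU.
#include <hip/hip_runtime.h>
#include <cstdio>

#define HIP_CHECK(call)                                                   \
    do {                                                                  \
        hipError_t err_ = (call);                                         \
        if (err_ != hipSuccess) {                                         \
            fprintf(stderr, "HIP error %s at %s:%d\n",                    \
                    hipGetErrorString(err_), __FILE__, __LINE__);         \
            return 1;                                                     \
        }                                                                 \
    } while (0)

__global__ void add_one(float * data, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        data[i] += 1.0f;
    }
}

int main() {
    const int n  = 1 << 20;
    float * buf  = nullptr;

    // Managed allocation: on an APU (UMA) this is backed by ordinary system
    // RAM, so it is not limited by the BIOS-reserved framebuffer size.
    HIP_CHECK(hipMallocManaged((void **) &buf, n * sizeof(float)));

    for (int i = 0; i < n; ++i) {
        buf[i] = 1.0f;                           // written directly by the CPU
    }

    add_one<<<(n + 255) / 256, 256>>>(buf, n);   // read/written by the GPU
    HIP_CHECK(hipDeviceSynchronize());

    printf("buf[0] = %f (expected 2.0)\n", buf[0]);  // read back on the CPU
    HIP_CHECK(hipFree(buf));
    return 0;
}
```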

File tree

3 files changed: +18 −7 lines changed


CMakeLists.txt

Lines changed: 4 additions & 0 deletions
```diff
@@ -91,6 +91,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
 set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                              "llama: max. batch size for using peer access")
 option(LLAMA_HIPBLAS                         "llama: use hipBLAS"                              OFF)
+option(LLAMA_HIP_UMA                         "llama: use HIP unified memory architecture"      OFF)
 option(LLAMA_CLBLAST                         "llama: use CLBlast"                              OFF)
 option(LLAMA_METAL                           "llama: use Metal"                                ${LLAMA_METAL_DEFAULT})
 option(LLAMA_METAL_NDEBUG                    "llama: disable Metal debugging"                  OFF)
@@ -377,6 +378,9 @@ if (LLAMA_HIPBLAS)
     if (${hipblas_FOUND} AND ${hip_FOUND})
         message(STATUS "HIP and hipBLAS found")
         add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+        if (LLAMA_HIP_UMA)
+            add_compile_definitions(GGML_HIP_UMA)
+        endif()
         add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
         if (BUILD_SHARED_LIBS)
             set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
```

README.md

Lines changed: 9 additions & 7 deletions
````diff
@@ -432,14 +432,15 @@ Building the program with BLAS support may lead to some performance improvements
   ```bash
   make LLAMA_HIPBLAS=1
   ```
- - Using `CMake` for Linux:
+ - Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
   ```bash
-  mkdir build
-  cd build
-  CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DLLAMA_HIPBLAS=ON
-  cmake --build .
+  CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ \
+  cmake -H. -Bbuild -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
+  && cmake --build build -- -j 16
   ```
- - Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS):
+ On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DLLAMA_HIP_UMA=ON`.
+ However, this hurts performance for non-integrated GPUs.
+ - Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
   ```bash
   set PATH=%HIP_PATH%\bin;%PATH%
   mkdir build
@@ -448,10 +449,11 @@ Building the program with BLAS support may lead to some performance improvements
   cmake --build .
   ```
   Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
+  Find your GPU version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.
 
 
   The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
-  If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 or 11.0.0 on RDNA3.
+  If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
   The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above):
 
   | Option | Legal values | Default | Description |
````

ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
```diff
@@ -60,8 +60,13 @@
 #define cudaGetDeviceProperties hipGetDeviceProperties
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
+#ifdef GGML_HIP_UMA
+#define cudaMalloc hipMallocManaged
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
+#else
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#endif
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpy2DAsync hipMemcpy2DAsync
 #define cudaMemcpyAsync hipMemcpyAsync
```
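
To see what this remapping means at a call site, here is a small sketch (not from the repository; the file name and build line are assumptions) showing that existing CUDA-style calls need no edits: with `GGML_HIP_UMA` defined, the preprocessor turns `cudaMalloc` into a managed allocation that an integrated GPU can satisfy from system RAM.

```cpp
// uma_macros.cpp -- illustrative only.
// Build e.g. with: hipcc -DGGML_HIP_UMA uma_macros.cpp -o uma_macros
#include <hip/hip_runtime.h>
#include <cstdio>

#ifdef GGML_HIP_UMA
#define cudaMalloc                hipMallocManaged                   // device allocs become managed
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
#else
#define cudaMalloc                hipMalloc                          // plain device (VRAM) allocs
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#endif

int main() {
    void * dev_buf  = nullptr;
    void * host_buf = nullptr;

    // These "CUDA" calls mirror what ggml-cuda.cu already contains.
    // With -DGGML_HIP_UMA the first expands to hipMallocManaged(&dev_buf, 1024),
    // i.e. an allocation backed by system RAM that the integrated GPU can use.
    if (cudaMalloc(&dev_buf, 1024) != hipSuccess ||
        cudaMallocHost(&host_buf, 1024) != hipSuccess) {
        fprintf(stderr, "allocation failed\n");
        return 1;
    }

    printf("allocated 1 KiB device buffer and 1 KiB host buffer\n");

    hipHostFree(host_buf);
    hipFree(dev_buf);
    return 0;
}
```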
