
Commit bd40678

HIP: Add support for RDNA4 targets (#12372)

1 parent b3298fa · commit bd40678
7 files changed: +25 -17 lines

docs/build.md

Lines changed: 1 addition & 1 deletion

@@ -191,7 +191,7 @@ The following compilation options are also available to tweak performance:

 | Option | Legal values | Default | Description |
 |-------------------------------|------------------|---------|-------------|
-| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. |
+| GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, CDNA and RDNA3+). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. |
 | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models |
 | GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
 | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |

ggml/src/ggml-cuda/common.cuh

Lines changed: 10 additions & 8 deletions

@@ -52,24 +52,26 @@

 #define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)

 // AMD
-// GCN/CNDA, wave size is 64
+// GCN/CDNA, wave size is 64
 #define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue
 #define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 0x906)  // MI50/Radeon VII, minimum for dp4a
 #define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers
 #define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910)  // MI210, minimum acc register renameing
 #define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x942)  // MI300

-// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
+// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
 #define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
 #define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA4      (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000

 #define GGML_CUDA_CC_IS_AMD(cc)   (cc >= GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_RDNA(cc)  (cc >= GGML_CUDA_CC_RDNA1)
 #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
 #define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
 #define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
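Note on this hunk: the old RDNA3 predicate was open-ended (anything at or above the RDNA3 offset), so a gfx12xx device would have been classified as RDNA3; the new upper bound plus the dedicated RDNA4 predicate make the ranges disjoint. A minimal host-side sketch of the classification, with the range macros copied from the hunk above — the value of GGML_CUDA_CC_OFFSET_AMD is an illustrative assumption here, the real constant is defined earlier in common.cuh:

#include <cstdio>

#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 // assumed value, for illustration only
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000

#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)

int main() {
    const int cc_gfx11 = GGML_CUDA_CC_OFFSET_AMD + 0x1100; // an RDNA3-range value
    const int cc_gfx12 = GGML_CUDA_CC_OFFSET_AMD + 0x1201; // an RDNA4-range value
    // With the pre-commit open-ended check, the second line would print RDNA3=1.
    printf("gfx11xx: RDNA3=%d RDNA4=%d\n", GGML_CUDA_CC_IS_RDNA3(cc_gfx11), GGML_CUDA_CC_IS_RDNA4(cc_gfx11));
    printf("gfx12xx: RDNA3=%d RDNA4=%d\n", GGML_CUDA_CC_IS_RDNA3(cc_gfx12), GGML_CUDA_CC_IS_RDNA4(cc_gfx12));
    return 0;
}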

@@ -209,9 +211,9 @@ typedef float2 dfloat2;

 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

-#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
 #define FP16_MMA_AVAILABLE
-#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))

 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE

@@ -244,14 +246,14 @@ static bool fp16_mma_available(const int cc) {

     return false;
 #else
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }

 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }

 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.

@@ -409,7 +411,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i

 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(RDNA3)
+#elif defined(RDNA3) || defined(RDNA4)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
 #elif defined(RDNA1) || defined(__gfx900__)
     int tmp1;
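All the branches in ggml_cuda_dp4a implement the same contract: interpret the two 32-bit operands as four packed signed 8-bit lanes, take their dot product, and add it to the 32-bit accumulator. The sudot4 builtin now used for RDNA3 and RDNA4 takes explicit signedness flags; with both set to true it behaves as a signed-signed dot product. A portable host-side reference of that contract, as a sketch rather than the code the diff touches:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Reference semantics of dp4a: lane-wise int8 multiply, summed into c.
static int dp4a_ref(int a, int b, int c) {
    int8_t a8[4], b8[4];
    std::memcpy(a8, &a, 4); // memcpy avoids strict-aliasing issues
    std::memcpy(b8, &b, 4);
    for (int i = 0; i < 4; ++i) {
        c += a8[i] * b8[i];
    }
    return c;
}

int main() {
    // On a little-endian host, 0x01020304 packs the lanes (4, 3, 2, 1).
    printf("%d\n", dp4a_ref(0x01020304, 0x01010101, 100)); // 100 + 4+3+2+1 = 110
    return 0;
}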

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 3 deletions

@@ -1216,7 +1216,7 @@ static void ggml_cuda_op_mul_mat_cublas(

     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));

-    if (GGML_CUDA_CC_IS_CDNA(cc)) {
+    if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         const float alpha = 1.0f;
         const float beta = 0.0f;
         CUBLAS_CHECK(

@@ -1759,7 +1759,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co

         beta = &beta_f32;
     }

-    if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
+    int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
         alpha = &alpha_f32;
         beta = &beta_f32;

@@ -1836,7 +1838,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co

     }
 #endif

-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
     }
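Two related changes here: RDNA4 joins CDNA in running the (cu/hip)BLAS GEMM with FP32 compute type and FP32 alpha/beta even for FP16 inputs, and the post-GEMM conversion of the FP16 staging buffer back to FP32 is now additionally guarded by cu_data_type == CUDA_R_16F, so it only runs when the GEMM output buffer really holds FP16 data. A condensed host-only sketch of the guard — the enum values are stand-ins, the real types come from the cuBLAS/hipBLAS headers and GGML_PREC_DEFAULT from ggml.h:

#include <cstdio>

enum cu_dtype  { R_16F, R_32F };           // stand-ins for CUDA_R_16F / CUDA_R_32F
enum ggml_prec { PREC_DEFAULT, PREC_F32 }; // stand-ins for GGML_PREC_DEFAULT / GGML_PREC_F32

// The FP16 -> FP32 copy must only run when the GEMM actually produced FP16
// output; on the CDNA/RDNA4 path the result is already FP32.
static bool needs_fp16_to_fp32_copy(ggml_prec prec, cu_dtype dt) {
    return prec == PREC_DEFAULT && dt == R_16F;
}

int main() {
    printf("%d\n", needs_fp16_to_fp32_copy(PREC_DEFAULT, R_16F)); // 1: convert
    printf("%d\n", needs_fp16_to_fp32_copy(PREC_DEFAULT, R_32F)); // 0: leave as-is
    return 0;
}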

ggml/src/ggml-cuda/mmq.cu

Lines changed: 1 addition & 1 deletion

@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {

         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }

-    return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
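The logic: on architectures with a faster library GEMM for large batches (CDNA, RDNA3, and now RDNA4), the dp4a-based MMQ kernels are only picked below the batch-size cutoff; everywhere else they are always used. A sketch of that final return — should_use_mmq_dp4a is a hypothetical name, it models only the last branch of ggml_cuda_should_use_mmq, and the cutoff value is an assumption (MMQ_DP4A_MAX_BATCH_SIZE is defined in mmq.cuh):

#include <cstdio>

constexpr long MMQ_DP4A_MAX_BATCH_SIZE = 64; // assumed value, for illustration

static bool should_use_mmq_dp4a(bool is_cdna, bool is_rdna3, bool is_rdna4, long ne11) {
    // On CDNA / RDNA3 / RDNA4, fall back to the library GEMM above the cutoff.
    return (!is_rdna4 && !is_rdna3 && !is_cdna) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

int main() {
    printf("%d\n", should_use_mmq_dp4a(false, false, true, 512)); // RDNA4, large batch: 0
    printf("%d\n", should_use_mmq_dp4a(false, false, true, 16));  // RDNA4, small batch: 1
    return 0;
}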

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 2 additions & 2 deletions

@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(

 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
     __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 #else
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     __launch_bounds__(WARP_SIZE*nwarps, 1)
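__launch_bounds__(max_threads_per_block, min_blocks_per_multiprocessor) is an occupancy hint: the second argument asks the compiler to budget registers so that at least that many blocks can be resident per SM/CU at once. The diff simply extends the AMD arch list that gets the hint to RDNA4. An illustrative CUDA/HIP kernel, not from the diff — WARP_SIZE is hard-coded to 32 here, whereas ggml uses 64 on GCN/CDNA:

#include <cuda_runtime.h>

#define WARP_SIZE 32     // illustration only; wave size is 64 on GCN/CDNA
constexpr int nwarps = 8;

// At most WARP_SIZE*nwarps threads per block, and register pressure low
// enough that at least 2 blocks fit on one SM/CU.
__global__ void __launch_bounds__(WARP_SIZE*nwarps, 2)
fill_ones(float * out) {
    out[blockIdx.x*blockDim.x + threadIdx.x] = 1.0f;
}

int main() {
    float * d_out;
    cudaMalloc(&d_out, WARP_SIZE*nwarps*sizeof(float));
    fill_ones<<<1, WARP_SIZE*nwarps>>>(d_out);
    cudaDeviceSynchronize();
    cudaFree(d_out);
    return 0;
}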

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 2 additions & 2 deletions

@@ -54,7 +54,7 @@ enum mmvq_parameter_table_id {

 };

 static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
-#if defined(RDNA2) || defined(RDNA3)
+#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
     return MMVQ_PARAMETERS_RDNA2;
 #elif defined(GCN) || defined(CDNA)
     return MMVQ_PARAMETERS_GCN;

@@ -64,7 +64,7 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {

 }

 static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
-    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         return MMVQ_PARAMETERS_RDNA2;
     }
     if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 4 additions & 0 deletions

@@ -151,6 +151,10 @@

 #define CDNA
 #endif

+#if defined(__GFX12__)
+#define RDNA4
+#endif
+
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
     defined(__gfx1150__) || defined(__gfx1151__)
 #define RDNA3
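Unlike the RDNA3 block just below it, which enumerates each gfx11xx target, the RDNA4 check leans on the family-wide __GFX12__ macro provided by the HIP toolchain for gfx12xx ISAs, so new gfx12 chips should be covered without touching this header. A host-runnable mock-up of the macro chain — __GFX12__ is force-defined by hand here, since only the device compiler sets it:

#include <cstdio>

#define __GFX12__ // normally predefined by the HIP device compiler for gfx12xx targets

#if defined(__GFX12__)
#define RDNA4
#endif

int main() {
#if defined(RDNA4)
    puts("RDNA4 paths enabled: sudot4 dp4a, rocWMMA flash attention, FP32 cuBLAS compute");
#else
    puts("RDNA4 paths disabled");
#endif
    return 0;
}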
