Skip to content

Commit cf4cb59

Browse files
authored
ggml : add ggml_gelu_erf() (#13667)
* ggml : add ggml_gelu_na (not approximated) * fix naming order * rename na --> erf * apply review suggestions * revert naming order
1 parent 0d5c742 commit cf4cb59

File tree

7 files changed

+213
-2
lines changed

7 files changed

+213
-2
lines changed

ggml/include/ggml.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,14 +528,15 @@ extern "C" {
528528
GGML_UNARY_OP_STEP,
529529
GGML_UNARY_OP_TANH,
530530
GGML_UNARY_OP_ELU,
531-
GGML_UNARY_OP_RELU,
532531
GGML_UNARY_OP_SIGMOID,
533532
GGML_UNARY_OP_GELU,
533+
GGML_UNARY_OP_GELU_ERF,
534534
GGML_UNARY_OP_GELU_QUICK,
535535
GGML_UNARY_OP_SILU,
536536
GGML_UNARY_OP_HARDSWISH,
537537
GGML_UNARY_OP_HARDSIGMOID,
538538
GGML_UNARY_OP_EXP,
539+
GGML_UNARY_OP_RELU,
539540

540541
GGML_UNARY_OP_COUNT,
541542
};
@@ -1024,6 +1025,16 @@ extern "C" {
10241025
struct ggml_context * ctx,
10251026
struct ggml_tensor * a);
10261027

1028+
// GELU using erf (error function) when possible
1029+
// some backends may fall back to an approximation based on the Abramowitz and Stegun formula
1030+
GGML_API struct ggml_tensor * ggml_gelu_erf(
1031+
struct ggml_context * ctx,
1032+
struct ggml_tensor * a);
1033+
1034+
GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(
1035+
struct ggml_context * ctx,
1036+
struct ggml_tensor * a);
1037+
10271038
GGML_API struct ggml_tensor * ggml_gelu_quick(
10281039
struct ggml_context * ctx,
10291040
struct ggml_tensor * a);

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2202,6 +2202,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
22022202
} break;
22032203

22042204
case GGML_UNARY_OP_GELU:
2205+
case GGML_UNARY_OP_GELU_ERF:
22052206
case GGML_UNARY_OP_GELU_QUICK:
22062207
case GGML_UNARY_OP_SILU:
22072208
{

ggml/src/ggml-cpu/ops.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2691,6 +2691,109 @@ static void ggml_compute_forward_gelu(
26912691
}
26922692
}
26932693

2694+
// ggml_compute_forward_gelu_erf
2695+
2696+
static void ggml_compute_forward_gelu_erf_f32(
2697+
const ggml_compute_params * params,
2698+
ggml_tensor * dst) {
2699+
2700+
const ggml_tensor * src0 = dst->src[0];
2701+
2702+
assert(ggml_is_contiguous_1(src0));
2703+
assert(ggml_is_contiguous_1(dst));
2704+
assert(ggml_are_same_shape(src0, dst));
2705+
2706+
const int ith = params->ith;
2707+
const int nth = params->nth;
2708+
2709+
const int nc = src0->ne[0];
2710+
const int nr = ggml_nrows(src0);
2711+
2712+
// rows per thread
2713+
const int dr = (nr + nth - 1)/nth;
2714+
2715+
// row range for this thread
2716+
const int ir0 = dr*ith;
2717+
const int ir1 = MIN(ir0 + dr, nr);
2718+
2719+
for (int i1 = ir0; i1 < ir1; i1++) {
2720+
ggml_vec_gelu_erf_f32(nc,
2721+
(float *) ((char *) dst->data + i1*( dst->nb[1])),
2722+
(float *) ((char *) src0->data + i1*(src0->nb[1])));
2723+
2724+
#ifndef NDEBUG
2725+
for (int k = 0; k < nc; k++) {
2726+
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2727+
GGML_UNUSED(x);
2728+
assert(!isnan(x));
2729+
assert(!isinf(x));
2730+
}
2731+
#endif
2732+
}
2733+
}
2734+
2735+
static void ggml_compute_forward_gelu_erf_f16(
2736+
const ggml_compute_params * params,
2737+
ggml_tensor * dst) {
2738+
2739+
const ggml_tensor * src0 = dst->src[0];
2740+
2741+
assert(ggml_is_contiguous_1(src0));
2742+
assert(ggml_is_contiguous_1(dst));
2743+
assert(ggml_are_same_shape(src0, dst));
2744+
2745+
const int ith = params->ith;
2746+
const int nth = params->nth;
2747+
2748+
const int nc = src0->ne[0];
2749+
const int nr = ggml_nrows(src0);
2750+
2751+
// rows per thread
2752+
const int dr = (nr + nth - 1)/nth;
2753+
2754+
// row range for this thread
2755+
const int ir0 = dr*ith;
2756+
const int ir1 = MIN(ir0 + dr, nr);
2757+
2758+
for (int i1 = ir0; i1 < ir1; i1++) {
2759+
ggml_vec_gelu_erf_f16(nc,
2760+
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2761+
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
2762+
2763+
#ifndef NDEBUG
2764+
for (int k = 0; k < nc; k++) {
2765+
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2766+
const float v = GGML_FP16_TO_FP32(x);
2767+
GGML_UNUSED(v);
2768+
assert(!isnan(v));
2769+
assert(!isinf(v));
2770+
}
2771+
#endif
2772+
}
2773+
}
2774+
2775+
static void ggml_compute_forward_gelu_erf(
2776+
const ggml_compute_params * params,
2777+
ggml_tensor * dst) {
2778+
2779+
const ggml_tensor * src0 = dst->src[0];
2780+
2781+
switch (src0->type) {
2782+
case GGML_TYPE_F32:
2783+
{
2784+
ggml_compute_forward_gelu_erf_f32(params, dst);
2785+
} break;
2786+
case GGML_TYPE_F16:
2787+
{
2788+
ggml_compute_forward_gelu_erf_f16(params, dst);
2789+
} break;
2790+
default:
2791+
{
2792+
GGML_ABORT("fatal error");
2793+
}
2794+
}
2795+
}
2796+
26942797
// ggml_compute_forward_gelu_quick
26952798

26962799
static void ggml_compute_forward_gelu_quick_f32(
@@ -7749,6 +7852,10 @@ void ggml_compute_forward_unary(
77497852
{
77507853
ggml_compute_forward_gelu(params, dst);
77517854
} break;
7855+
case GGML_UNARY_OP_GELU_ERF:
7856+
{
7857+
ggml_compute_forward_gelu_erf(params, dst);
7858+
} break;
77527859
case GGML_UNARY_OP_GELU_QUICK:
77537860
{
77547861
ggml_compute_forward_gelu_quick(params, dst);

ggml/src/ggml-cpu/vec.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp
428428
static const float GELU_COEF_A = 0.044715f;
429429
static const float GELU_QUICK_COEF = -1.702f;
430430
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
431+
static const float SQRT_2_INV = 0.70710678118654752440084436210484f;
431432

432433
inline static float ggml_gelu_f32(float x) {
433434
return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
@@ -440,6 +441,14 @@ inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp
440441
}
441442
}
442443

444+
// Exact GELU on a row of n f16 values: y[i] = 0.5*x*(1 + erf(x/sqrt(2))).
// Each element is widened to f32, evaluated with erff, then narrowed back.
inline static void ggml_vec_gelu_erf_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
    for (int i0 = 0; i0 < n; ++i0) {
        const float v = GGML_FP16_TO_FP32(x[i0]);
        y[i0] = GGML_FP32_TO_FP16(0.5f*v*(1.0f + erff(v*SQRT_2_INV)));
    }
}
451+
443452
#ifdef GGML_GELU_FP16
444453
inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
445454
uint16_t t;
@@ -463,6 +472,13 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
463472
}
464473
#endif
465474

475+
inline static void ggml_vec_gelu_erf_f32(const int n, float * y, const float * x) {
476+
for (int i = 0; i < n; ++i) {
477+
float xi = x[i];
478+
y[i] = 0.5f*xi*(1.0f + erff(xi*SQRT_2_INV));
479+
}
480+
}
481+
466482
inline static float ggml_gelu_quick_f32(float x) {
467483
return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x)));
468484
}

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
149149
GGML_METAL_KERNEL_TYPE_SIGMOID,
150150
GGML_METAL_KERNEL_TYPE_GELU,
151151
GGML_METAL_KERNEL_TYPE_GELU_4,
152+
GGML_METAL_KERNEL_TYPE_GELU_ERF,
153+
GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
152154
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
153155
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
154156
GGML_METAL_KERNEL_TYPE_SILU,
@@ -1103,6 +1105,8 @@ @implementation GGMLMetalClass
11031105
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
11041106
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
11051107
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
1108+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
1109+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
11061110
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
11071111
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
11081112
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
@@ -1613,6 +1617,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16131617
case GGML_UNARY_OP_RELU:
16141618
case GGML_UNARY_OP_SIGMOID:
16151619
case GGML_UNARY_OP_GELU:
1620+
case GGML_UNARY_OP_GELU_ERF:
16161621
case GGML_UNARY_OP_GELU_QUICK:
16171622
case GGML_UNARY_OP_SILU:
16181623
case GGML_UNARY_OP_ELU:
@@ -2251,6 +2256,25 @@ static bool ggml_metal_encode_node(
22512256

22522257
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
22532258
} break;
2259+
case GGML_UNARY_OP_GELU_ERF:
2260+
{
2261+
int64_t n = ggml_nelements(dst);
2262+
2263+
id<MTLComputePipelineState> pipeline = nil;
2264+
2265+
if (n % 4 == 0) {
2266+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
2267+
n /= 4;
2268+
} else {
2269+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
2270+
}
2271+
2272+
[encoder setComputePipelineState:pipeline];
2273+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2274+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2275+
2276+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
2277+
} break;
22542278
case GGML_UNARY_OP_GELU_QUICK:
22552279
{
22562280
int64_t n = ggml_nelements(dst);

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,7 @@ kernel void kernel_tanh(
856856
constant float GELU_COEF_A = 0.044715f;
857857
constant float GELU_QUICK_COEF = -1.702f;
858858
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
859+
constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
859860

860861
kernel void kernel_gelu(
861862
device const float * src0,
@@ -897,6 +898,42 @@ kernel void kernel_gelu_quick_4(
897898
dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
898899
}
899900

901+
// Polynomial approximation of erf(x) — Abramowitz & Stegun formula 7.1.26
// (a Hastings-style rational/polynomial fit; max abs error ~1.5e-7).
// ref: https://www.johndcook.com/blog/python_erf/
constant float p_erf  = 0.3275911f;
constant float a1_erf = 0.254829592f;
constant float a2_erf = -0.284496736f;
constant float a3_erf = 1.421413741f;
constant float a4_erf = -1.453152027f;
constant float a5_erf = 1.061405429f;

// Works element-wise for both scalar float and float4 via Metal's
// vectorized sign/fabs/exp builtins.
template<typename T>
T erf_approx(T x) {
    // the fit is for x >= 0; use erf(-x) = -erf(x) for the negative half
    T sgn = sign(x);
    x = fabs(x);
    T t = 1.0f / (1.0f + p_erf * x);
    T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
    return sgn * y;
}
918+
919+
// erf-based GELU, scalar variant: dst = 0.5*x*(1 + erf(x/sqrt(2)))
// (erf evaluated with the Abramowitz & Stegun approximation above)
kernel void kernel_gelu_erf(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float x = src0[tpig];

    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV));
}
927+
928+
// erf-based GELU, 4-wide variant: processes float4 lanes per thread
// (selected by the host when the element count is divisible by 4)
kernel void kernel_gelu_erf_4(
        device const float4 * src0,
        device       float4 * dst,
        uint tpig[[thread_position_in_grid]]) {
    const float4 x = src0[tpig];

    dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV));
}
936+
900937
kernel void kernel_silu(
901938
device const float * src0,
902939
device float * dst,

ggml/src/ggml.c

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,9 +1099,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
10991099
"HARDSWISH",
11001100
"HARDSIGMOID",
11011101
"EXP",
1102+
"GELU_ERF",
11021103
};
11031104

1104-
static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
1105+
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
11051106

11061107

11071108
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -2501,6 +2502,20 @@ struct ggml_tensor * ggml_gelu_inplace(
25012502
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU);
25022503
}
25032504

2505+
// ggml_gelu_erf
2506+
2507+
struct ggml_tensor * ggml_gelu_erf(
2508+
struct ggml_context * ctx,
2509+
struct ggml_tensor * a) {
2510+
return ggml_unary(ctx, a, GGML_UNARY_OP_GELU_ERF);
2511+
}
2512+
2513+
struct ggml_tensor * ggml_gelu_erf_inplace(
2514+
struct ggml_context * ctx,
2515+
struct ggml_tensor * a) {
2516+
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_GELU_ERF);
2517+
}
2518+
25042519
// ggml_gelu_quick
25052520

25062521
struct ggml_tensor * ggml_gelu_quick(

0 commit comments

Comments
 (0)