diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 02dc8c12dbd8c..c442a64924303 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2192,6 +2192,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_SILU: ggml_cuda_op_silu(ctx, dst); break; + case GGML_UNARY_OP_GELU_ERF: + ggml_cuda_op_gelu_erf(ctx, dst); + break; case GGML_UNARY_OP_GELU_QUICK: ggml_cuda_op_gelu_quick(ctx, dst); break; @@ -2977,6 +2980,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index ec5773e01637e..2c0375fbe3cf6 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) { return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); } +static __device__ __forceinline__ float op_gelu_erf(float x) { + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + + return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); +} + static __device__ __forceinline__ float op_gelu_quick(float x) { const float GELU_QUICK_COEF = -1.702f; @@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } +void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 940a1feed9a9c..6686fc17e9193 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -30,6 +30,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);