
make rms_norm_eps a parameter #2374


Merged: 4 commits, Jul 24, 2023

20 changes: 11 additions & 9 deletions examples/baby-llama/baby-llama.cpp
@@ -8,6 +8,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

static const float rms_norm_eps = 1e-6f;

float frand() {
return (float)rand()/(float)RAND_MAX;
}
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpL);
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -685,7 +687,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpFF);
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);

// cur = ffn_norm*cur
// cur shape [n_embd,N,1,1]
@@ -729,7 +731,7 @@ struct ggml_tensor * forward(
{

// inpL shape [n_embd,N,1,1]
inpL = ggml_rms_norm(ctx0, inpL);
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// inpL = norm*inpL
// inpL shape [n_embd,N,1,1]
@@ -817,7 +819,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpL);
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -981,7 +983,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpFF);
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -1034,7 +1036,7 @@ struct ggml_tensor * forward_batch(
{

// inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_rms_norm(ctx0, inpL);
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -1104,7 +1106,7 @@ struct ggml_tensor * forward_lora(
// norm
{
// cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpL);
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -1251,7 +1253,7 @@ struct ggml_tensor * forward_lora(
// norm
{
// cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpFF);
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);

// cur = ffn_norm*cur
// cur shape [n_embd,N,1,1]
@@ -1295,7 +1297,7 @@ struct ggml_tensor * forward_lora(
{

// inpL shape [n_embd,N,1,1]
inpL = ggml_rms_norm(ctx0, inpL);
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// inpL = norm*inpL
// inpL shape [n_embd,N,1,1]
8 changes: 8 additions & 0 deletions examples/common.cpp
@@ -177,6 +177,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_gqa = std::stoi(argv[i]);
} else if (arg == "-eps" || arg == "--rms-norm-eps") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rms_norm_eps = std::stof(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
@@ -519,6 +525,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -615,6 +622,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
lparams.n_gqa = params.n_gqa;
lparams.rms_norm_eps = params.rms_norm_eps;
lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
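
Note: for reference, a minimal sketch of how the new value travels from the command line into the context parameters. The fake_* types below are stand-ins for the real gpt_params / llama_context_params, and the snippet is illustrative rather than a copy of the actual code; with the real flag in place, passing something like `-eps 1e-5` (the value suggested above for LLaMAv2) overrides the 1e-6 default.

```cpp
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Stand-ins for the real gpt_params / llama_context_params (illustrative only).
struct fake_gpt_params { float rms_norm_eps = 1e-6f; };
struct fake_ctx_params { float rms_norm_eps = 1e-6f; };

int main(int argc, char ** argv) {
    fake_gpt_params params;

    // Same pattern as gpt_params_parse above: accept -eps / --rms-norm-eps.
    for (int i = 1; i < argc; i++) {
        if (std::strcmp(argv[i], "-eps") == 0 || std::strcmp(argv[i], "--rms-norm-eps") == 0) {
            if (++i >= argc) {
                std::fprintf(stderr, "missing value for -eps\n");
                return 1;
            }
            params.rms_norm_eps = std::strtof(argv[i], nullptr);
        }
    }

    // Forwarded the same way llama_context_params_from_gpt_params does it.
    fake_ctx_params lparams;
    lparams.rms_norm_eps = params.rms_norm_eps;

    std::printf("rms_norm_eps = %g\n", lparams.rms_norm_eps);
    return 0;
}
```
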
23 changes: 12 additions & 11 deletions examples/common.h
@@ -22,18 +22,19 @@
int32_t get_num_physical_cores();

struct gpt_params {
uint32_t seed = -1; // RNG seed
uint32_t seed = -1; // RNG seed
int32_t n_threads = get_num_physical_cores();
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
float rms_norm_eps = 1e-6; // rms norm epsilon
float rope_freq_base = 10000.0f; // RoPE base frequency
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor

32 changes: 17 additions & 15 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -16,6 +16,8 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

static const float rms_norm_eps = 1e-6f;

struct random_normal_distribution {
std::mt19937 gen;
std::normal_distribution<float> rd;
@@ -439,7 +441,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpL);
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// cur = attention_norm*cur
cur = ggml_mul(ctx0,
@@ -562,7 +564,7 @@ struct ggml_tensor * forward(
// norm
{
// cur shape [n_embd,N,1,1]
cur = ggml_rms_norm(ctx0, inpFF);
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);

// cur = ffn_norm*cur
// cur shape [n_embd,N,1,1]
@@ -606,7 +608,7 @@ struct ggml_tensor * forward(
{

// inpL shape [n_embd,N,1,1]
inpL = ggml_rms_norm(ctx0, inpL);
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);

// inpL = norm*inpL
// inpL shape [n_embd,N,1,1]
@@ -694,7 +696,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpL);
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -857,7 +859,7 @@ struct ggml_tensor * forward_batch(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpFF);
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -910,7 +912,7 @@ struct ggml_tensor * forward_batch(
{

// inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_rms_norm(ctx0, inpL);
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -979,7 +981,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpL);
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -1085,7 +1087,7 @@ struct ggml_tensor * forward_batch_wo_cache(
// norm
{
// cur shape [n_embd,N*n_batch,1,1]
cur = ggml_rms_norm(ctx0, inpFF);
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -1138,7 +1140,7 @@ struct ggml_tensor * forward_batch_wo_cache(
{

// inpL shape [n_embd,N*n_batch,1,1]
inpL = ggml_rms_norm(ctx0, inpL);
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -1203,7 +1205,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(

// norm
{
cur = ggml_rms_norm(ctx0, inpL);
cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = attention_norm*cur
@@ -1267,7 +1269,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
{
// norm
{
cur = ggml_rms_norm(ctx0, inpFF);
cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps);
assert_shape_2d(cur, n_embd, N*n_batch);

// cur = ffn_norm*cur
@@ -1311,7 +1313,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
// norm
{

inpL = ggml_rms_norm(ctx0, inpL);
inpL = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
assert_shape_2d(inpL, n_embd, N*n_batch);

// inpL = norm*inpL
@@ -1603,7 +1605,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
struct my_llama_layer & layer = model->layers[il];
// tensors with values necessary for backward pass are in persistent buf(-1)
// other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t02, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
@@ -1623,7 +1625,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21, rms_norm_eps)); assert_shape_2d(t22, n_embd, N*n_batch);
use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
@@ -1666,7 +1668,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
}
clr_buf(0);
use_buf(0);
struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch);
struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur, rms_norm_eps)); assert_shape_2d(t31, n_embd, N*n_batch);
struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
use_buf(-1);
13 changes: 7 additions & 6 deletions ggml-cuda.cu
@@ -332,12 +332,10 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
}
}

static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;

const float eps = 1e-6f;

float tmp = 0.0f; // partial sum for thread in warp

for (int col = tid; col < ncols; col += WARP_SIZE) {
@@ -2122,10 +2120,10 @@ static void norm_f32_cuda(const float * x, float * dst, const int ncols, const i
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}

static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}

static void quantize_row_q8_1_cuda(const float * x, void * vy, const int ndata, const int k, cudaStream_t stream) {
@@ -2876,8 +2874,11 @@ inline void ggml_cuda_op_rms_norm(
const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;

float eps;
memcpy(&eps, dst->op_params, sizeof(float));

// compute
rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, eps, cudaStream_main);

(void) src1;
(void) dst;
3 changes: 2 additions & 1 deletion ggml-metal.m
@@ -812,7 +812,8 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}

const float eps = 1e-6f;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

const int nth = 512;

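
Note: the backend changes above all rely on the same mechanism — ggml_rms_norm_impl stores the epsilon in the tensor's op_params bytes, and each backend (CPU, CUDA, Metal) memcpy's it back out at compute time. A minimal sketch of that round trip, using a stand-in tensor struct rather than the real struct ggml_tensor:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Stand-in for struct ggml_tensor's opaque per-op parameter storage.
struct fake_tensor {
    int32_t op_params[16];
};

// Mirrors the idea behind ggml_set_op_params: copy raw bytes into op_params.
static void set_op_params(fake_tensor & t, const void * data, std::size_t size) {
    std::memcpy(t.op_params, data, size);
}

int main() {
    fake_tensor t = {};

    // Graph-construction side (ggml_rms_norm_impl): store the epsilon.
    const float eps_in = 1e-5f;
    set_op_params(t, &eps_in, sizeof(eps_in));

    // Backend side (CPU / CUDA / Metal kernels): read it back at compute time.
    float eps_out;
    std::memcpy(&eps_out, t.op_params, sizeof(float));
    std::printf("eps = %g\n", eps_out);
    return 0;
}
```
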
16 changes: 10 additions & 6 deletions ggml.c
@@ -5781,6 +5781,7 @@ struct ggml_tensor * ggml_norm_inplace(
static struct ggml_tensor * ggml_rms_norm_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
float eps,
bool inplace) {
bool is_node = false;

@@ -5790,7 +5791,7 @@ static struct ggml_tensor * ggml_rms_norm_impl(

struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

// TODO: maybe store epsilon here?
ggml_set_op_params(result, &eps, sizeof(eps));

result->op = GGML_OP_RMS_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@@ -5801,14 +5802,16 @@

struct ggml_tensor * ggml_rms_norm(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_rms_norm_impl(ctx, a, false);
struct ggml_tensor * a,
float eps) {
return ggml_rms_norm_impl(ctx, a, eps, false);
}

struct ggml_tensor * ggml_rms_norm_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_rms_norm_impl(ctx, a, true);
struct ggml_tensor * a,
float eps) {
return ggml_rms_norm_impl(ctx, a, eps, true);
}

struct ggml_tensor * ggml_rms_norm_back(
@@ -10131,7 +10134,8 @@ static void ggml_compute_forward_rms_norm_f32(

GGML_TENSOR_UNARY_OP_LOCALS;

const float eps = 1e-6f; // TODO: make this a parameter
float eps;
memcpy(&eps, dst->op_params, sizeof(float));

// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
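
Note: for completeness, a reference sketch of what the rewritten kernels compute for a single row — y_i = x_i / sqrt(mean(x^2) + eps) — where eps is now read from op_params instead of being hard-coded to 1e-6f. The multiply by the norm weights happens in a separate ggml_mul, as in the example programs above; this snippet is a plain CPU illustration, not the optimized ggml code.

```cpp
#include <cmath>
#include <cstdio>

// RMS-normalize one row of n floats: y[i] = x[i] / sqrt(mean(x^2) + eps).
static void rms_norm_row(const float * x, float * y, int n, float eps) {
    double sum = 0.0;
    for (int i = 0; i < n; i++) {
        sum += (double)x[i] * (double)x[i];
    }
    const float mean  = (float)(sum / n);
    const float scale = 1.0f / std::sqrt(mean + eps);
    for (int i = 0; i < n; i++) {
        y[i] = x[i] * scale;
    }
}

int main() {
    const float x[4] = {1.0f, -2.0f, 3.0f, -4.0f};
    float y[4];
    rms_norm_row(x, y, 4, 1e-5f);
    std::printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]);
    return 0;
}
```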