From fc775366f17a97ee5b5c7f48af2736971361a572 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 23 Feb 2024 12:18:30 +0200 Subject: [PATCH 1/4] llama : switch to floating-point token positions ggml-ci --- examples/baby-llama/baby-llama.cpp | 4 +-- examples/finetune/finetune.cpp | 4 +-- examples/llava/llava.cpp | 2 +- examples/server/server.cpp | 2 +- .../train-text-from-scratch.cpp | 4 +-- ggml-cuda.cu | 30 +++++++++---------- ggml-metal.m | 8 ++++- ggml-metal.metal | 10 +++---- ggml.c | 12 ++++---- llama.cpp | 20 ++++++------- llama.h | 4 +-- tests/test-backend-ops.cpp | 9 +++--- tests/test-grad0.cpp | 8 ++--- tests/test-rope.cpp | 12 ++++---- 14 files changed, 68 insertions(+), 61 deletions(-) diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 65bb238a0d565..c3e483292c6c0 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -1015,9 +1015,9 @@ static struct ggml_tensor * forward_lora( struct ggml_tensor * kc = kv_self.k; struct ggml_tensor * vc = kv_self.v; - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, N); { - int * data = (int *) KQ_pos->data; + float * data = (float *) KQ_pos->data; for (int i = 0; i < N; ++i) { data[i] = n_past + i; } diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 98bf5a07a7ed1..2944ce8bd1bcd 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -554,7 +554,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( }; // KQ_pos - contains the positions - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N); ggml_set_input(KQ_pos); // rope has so much parameters that we make a custom function for it @@ -743,7 +743,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs( // set KQ_pos { - int * data = (int *) KQ_pos->data; + float * data = (float *) KQ_pos->data; for (int i = 0; i < N; ++i) { data[i] = n_past + i; } diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 1a1cf7c78bf34..8f2cf29a5e17b 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -338,7 +338,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ if (n_eval > n_batch) { n_eval = n_batch; } - llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, }; + llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, (float) *n_past, 1, 0, }; if (llama_decode(ctx_llama, batch)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 369121e885b27..a12e39bd0e7f0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1281,7 +1281,7 @@ struct llama_server_context } const int n_embd = llama_n_embd(model); - llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, }; + llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, (float) slot.n_past, 1, 0, }; if (llama_decode(ctx, batch_img)) { LOG_TEE("%s : failed to eval image\n", __func__); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp 
b/examples/train-text-from-scratch/train-text-from-scratch.cpp index e78ab185d89f3..4de3d9123cafe 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -291,7 +291,7 @@ static struct ggml_tensor * llama_build_train_graphs( }; // KQ_pos - contains the positions - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N); + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N); ggml_set_input(KQ_pos); // rope has so much parameters that we make a custom function for it @@ -419,7 +419,7 @@ static struct ggml_tensor * llama_build_train_graphs( ggml_gallocr_alloc_graph(alloc, gb); if (!measure_only) { - int * data = (int *) KQ_pos->data; + float * data = (float *) KQ_pos->data; for (int i = 0; i < N; ++i) { data[i] = n_past + i; } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b0e454e025ec4..37bfb39cb3955 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6040,7 +6040,7 @@ static __device__ void rope_yarn( // rope == RoPE == rotary positional embedding template static __global__ void rope( - const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, + const T * x, T * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims ) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -6053,7 +6053,7 @@ static __global__ void rope( const int i = row*ncols + col; const int i2 = row/p_delta_rows; - const int p = has_pos ? pos[i2] : 0; + const float p = has_pos ? pos[i2] : 0.0f; const float theta_base = p*powf(freq_base, -float(col)/ncols); float cos_theta, sin_theta; @@ -6068,7 +6068,7 @@ static __global__ void rope( template static __global__ void rope_neox( - const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, + const T * x, T * dst, int ncols, int n_dims, const float * pos, float freq_scale, int p_delta_rows, float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims ) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y); @@ -6095,7 +6095,7 @@ static __global__ void rope_neox( float cur_rot = inv_ndims * ic - ib; - const int p = has_pos ? pos[i2] : 0; + const float p = has_pos ? pos[i2] : 0.0f; const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f); float cos_theta, sin_theta; @@ -6109,7 +6109,7 @@ static __global__ void rope_neox( } static __global__ void rope_glm_f32( - const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base, + const float * x, float * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base, int n_ctx ) { const int col = blockDim.x*blockIdx.x + threadIdx.x; @@ -6124,10 +6124,10 @@ static __global__ void rope_glm_f32( const int i2 = row/p_delta_rows; const float col_theta_scale = powf(freq_base, -2.0f*col/ncols); - // FIXME: this is likely wrong - const int p = pos != nullptr ? pos[i2] : 0; - const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale; + const float p = pos != nullptr ? 
pos[i2] : 0.0f; + + const float theta = min(p, (float) n_ctx - 2)*freq_scale*col_theta_scale; const float sin_theta = sinf(theta); const float cos_theta = cosf(theta); @@ -6137,7 +6137,7 @@ static __global__ void rope_glm_f32( dst[i + 0] = x0*cos_theta - x1*sin_theta; dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta; - const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale; + const float block_theta = max(p - n_ctx - 2, 0.0f)*col_theta_scale; const float sin_block_theta = sinf(block_theta); const float cos_block_theta = cosf(block_theta); @@ -7688,7 +7688,7 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const template static void rope_cuda( - const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, + const T * x, T * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream ) { GGML_ASSERT(ncols % 2 == 0); @@ -7708,7 +7708,7 @@ static void rope_cuda( template static void rope_neox_cuda( - const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, + const T * x, T * dst, int ncols, int n_dims, int nrows, const float * pos, float freq_scale, int p_delta_rows, float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream ) { GGML_ASSERT(ncols % 2 == 0); @@ -7733,7 +7733,7 @@ static void rope_neox_cuda( } static void rope_glm_f32_cuda( - const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows, + const float * x, float * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows, float freq_base, int n_ctx, cudaStream_t stream ) { GGML_ASSERT(ncols % 4 == 0); @@ -9035,11 +9035,11 @@ static void ggml_cuda_op_rope( memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - const int32_t * pos = nullptr; + const float * pos = nullptr; if ((mode & 1) == 0) { - GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->ne[0] == ne2); - pos = (const int32_t *) src1_dd; + pos = (const float *) src1_dd; } const bool is_neox = mode & 2; diff --git a/ggml-metal.m b/ggml-metal.m index 0d4aa43093739..e183a56c20562 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2057,7 +2057,13 @@ static bool ggml_metal_graph_compute( // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); diff --git a/ggml-metal.metal b/ggml-metal.metal index c223a981c246a..09ec33d06173e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1674,7 +1674,7 @@ static void rope_yarn_corr_dims( typedef void (rope_t)( device const void * src0, - device const int32_t * src1, + device const float * src1, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1709,7 +1709,7 @@ typedef void (rope_t)( template kernel void kernel_rope( device const void * src0, - device const 
int32_t * src1, + device const float * src1, device float * dst, constant int64_t & ne00, constant int64_t & ne01, @@ -1749,11 +1749,11 @@ kernel void kernel_rope( float corr_dims[2]; rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); - device const int32_t * pos = src1; + device const float * pos = src1; - const int64_t p = pos[i2]; + const float p = pos[i2]; - const float theta_0 = (float)p; + const float theta_0 = p; const float inv_ndims = -1.f/n_dims; if (!is_neox) { diff --git a/ggml.c b/ggml.c index d710fe702ddbd..6fc1fc1aa412d 100644 --- a/ggml.c +++ b/ggml.c @@ -5254,7 +5254,7 @@ static struct ggml_tensor * ggml_rope_impl( bool xpos_down, bool inplace) { GGML_ASSERT(ggml_is_vector(b)); - GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(b->type == GGML_TYPE_F32); GGML_ASSERT(a->ne[2] == b->ne[0]); bool is_node = false; @@ -5377,7 +5377,7 @@ struct ggml_tensor * ggml_rope_back( float xpos_base, bool xpos_down) { GGML_ASSERT(ggml_is_vector(b)); - GGML_ASSERT(b->type == GGML_TYPE_I32); + GGML_ASSERT(b->type == GGML_TYPE_F32); GGML_ASSERT(a->ne[2] == b->ne[0]); GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet"); @@ -12352,11 +12352,11 @@ static void ggml_compute_forward_rope_f32( // this essentially just switches the sign of sin. const float sin_sign = forward ? 1.0f : -1.0f; - const int32_t * pos = (const int32_t *) src1->data; + const float * pos = (const float *) src1->data; for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; + const float p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox @@ -12523,11 +12523,11 @@ static void ggml_compute_forward_rope_f16( // this essentially just switches the sign of sin. const float sin_sign = forward ? 
1.0f : -1.0f; - const int32_t * pos = (const int32_t *) src1->data; + const float * pos = (const float *) src1->data; for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; + const float p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; if (!is_glm && !is_neox) { // TODO: cache sin/cos for glm, neox diff --git a/llama.cpp b/llama.cpp index 37477e6ef3c44..788bf3fbc2bd2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1699,8 +1699,8 @@ struct llama_layer { }; struct llama_kv_cell { - llama_pos pos = -1; - llama_pos delta = 0; + float pos = -1.0f; + float delta = 0.0f; std::set seq_id; @@ -1939,10 +1939,10 @@ struct llama_context { ggml_context * ctx_input = nullptr; struct ggml_tensor * inp_tokens; // I32 [n_batch] struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch] - struct ggml_tensor * inp_pos; // I32 [n_batch] + struct ggml_tensor * inp_pos; // F32 [n_batch] struct ggml_tensor * inp_KQ_mask; // F32 [n_ctx, n_batch] struct ggml_tensor * inp_KQ_pos; // F32 [n_ctx] - struct ggml_tensor * inp_K_shift; // I32 [n_ctx] + struct ggml_tensor * inp_K_shift; // F32 [n_ctx] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] @@ -2222,7 +2222,7 @@ static void llama_kv_cache_seq_div( llama_seq_id seq_id, llama_pos p0, llama_pos p1, - int d) { + float d) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); @@ -7744,7 +7744,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); - int32_t * data = (int32_t *) lctx.inp_K_shift->data; + float * data = (float *) lctx.inp_K_shift->data; for (int i = 0; i < n_ctx; ++i) { data[i] = lctx.kv_self.cells[i].delta; @@ -11690,10 +11690,10 @@ struct llama_context * llama_new_context_with_model( ctx->inp_tokens = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); ctx->inp_embd = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch); - ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); + ctx->inp_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch); ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch); ctx->inp_KQ_pos = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx); - ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx); + ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx); ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch); ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch); @@ -12046,7 +12046,7 @@ void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, l llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); } -void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { +void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, float d) { if (d == 1) { return; } @@ -12461,7 +12461,7 @@ int llama_eval_embd( int32_t n_past) { llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); - llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; + llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, (float) n_past, 1, 0, }; const 
int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { diff --git a/llama.h b/llama.h index 84f196b3bb625..b8c97c0891784 100644 --- a/llama.h +++ b/llama.h @@ -54,7 +54,7 @@ extern "C" { struct llama_model; struct llama_context; - typedef int32_t llama_pos; + typedef float llama_pos; typedef int32_t llama_token; typedef int32_t llama_seq_id; @@ -531,7 +531,7 @@ extern "C" { llama_seq_id seq_id, llama_pos p0, llama_pos p1, - int d); + float d); // // State / sessions diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 55db42bf6e851..3d50ca12a2547 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1134,14 +1134,15 @@ struct test_rope : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]); + ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[2]); + ggml_set_name(pos, "pos"); ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx); return out; } void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - if (t->type == GGML_TYPE_I32) { + if (strcmp(ggml_get_name(t), "pos") == 0) { // pos std::vector data(ne[2]); for (int i = 0; i < ne[2]; i++) { @@ -1703,7 +1704,7 @@ struct test_llama : public test_llm { inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens); // inp_pos - contains the positions - struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); @@ -1825,7 +1826,7 @@ struct test_falcon : public test_llm { inpL = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hp.n_embd, hp.n_tokens); // inp_pos - contains the positions - struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 8ff76c8910c49..fcce73d704f72 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -1449,9 +1449,9 @@ int main(int argc, const char ** argv) { for (int n_past = 1; n_past < ne2[2]; ++n_past) { x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); + struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne2[2]); for (int i = 0; i < ne2[2]; ++i) { - ((int32_t *) p->data)[i] = n_past + i; + ((float *) p->data)[i] = n_past + i; } ggml_set_param(ctx0, x[0]); @@ -1489,9 +1489,9 @@ int main(int argc, const char ** argv) { for (int n_past = 1; n_past < ne2[2]; ++n_past) { x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f); - struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); + struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne2[2]); for (int i = 0; i < ne2[2]; ++i) { - ((int32_t *) p->data)[i] = n_past + i; + ((float *) p->data)[i] = n_past + i; } ggml_set_param(ctx0, x[0]); diff --git a/tests/test-rope.cpp b/tests/test-rope.cpp index 
26c1f42dc0e95..6d8c2704cdb8d 100644 --- a/tests/test-rope.cpp +++ b/tests/test-rope.cpp @@ -146,14 +146,14 @@ int main(int /*argc*/, const char ** /*argv*/) { const int n_past_0 = 100; const int n_past_2 = 33; - struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); - struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); - struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]); + struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne[2]); + struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne[2]); + struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ne[2]); for (int i = 0; i < ne[2]; ++i) { - ((int32_t *) p0->data)[i] = n_past_0 + i; - ((int32_t *) p1->data)[i] = n_past_2 - n_past_0; - ((int32_t *) p2->data)[i] = n_past_2 + i; + ((float *) p0->data)[i] = n_past_0 + i; + ((float *) p1->data)[i] = n_past_2 - n_past_0; + ((float *) p2->data)[i] = n_past_2 + i; } // test mode 0, 2, 4 (standard, GPT-NeoX, GLM) From 8772658b117ff4921e33235362c5bac49789d8eb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 23 Feb 2024 14:14:49 +0200 Subject: [PATCH 2/4] ggml : add I32 <-> F32 conversion ggml-ci --- ggml.c | 73 ++++++++++++++++++++++-------------- llama.cpp | 8 ++-- tests/test-quantize-fns.cpp | 6 +-- tests/test-quantize-perf.cpp | 2 +- 4 files changed, 53 insertions(+), 36 deletions(-) diff --git a/ggml.c b/ggml.c index 6fc1fc1aa412d..4fd7a4141c3d6 100644 --- a/ggml.c +++ b/ggml.c @@ -355,6 +355,18 @@ void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n) { } } +static void ggml_i32_to_f32_row(const int32_t * x, float * y, int n) { + for (int i = 0; i < n; i++) { + y[i] = (float) x[i]; + } +} + +static void ggml_f32_to_i32_row(const float * x, int32_t * y, int n) { + for (int i = 0; i < n; i++) { + y[i] = (int32_t) x[i]; + } +} + // // timing // @@ -454,6 +466,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = 1, .type_size = sizeof(int32_t), .is_quantized = false, + .to_float = (ggml_to_float_t) ggml_i32_to_f32_row, + .from_float = (ggml_from_float_t) ggml_f32_to_i32_row, + .from_float_reference = (ggml_from_float_t) ggml_f32_to_i32_row, }, [GGML_TYPE_F32] = { .type_name = "f32", @@ -469,10 +484,10 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = 1, .type_size = sizeof(ggml_fp16_t), .is_quantized = false, - .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, + .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row, .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row, - .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16, .vec_dot_type = GGML_TYPE_F16, .nrows = 1, }, @@ -481,8 +496,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK4_0, .type_size = sizeof(block_q4_0), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q4_0, - .from_float = quantize_row_q4_0, + .to_float = (ggml_to_float_t) dequantize_row_q4_0, + .from_float = (ggml_from_float_t) quantize_row_q4_0, .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference, .vec_dot = ggml_vec_dot_q4_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, @@ -497,8 +512,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK4_1, .type_size = sizeof(block_q4_1), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q4_1, - .from_float = quantize_row_q4_1, + 
.to_float = (ggml_to_float_t) dequantize_row_q4_1, + .from_float = (ggml_from_float_t) quantize_row_q4_1, .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference, .vec_dot = ggml_vec_dot_q4_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, @@ -537,8 +552,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK5_0, .type_size = sizeof(block_q5_0), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q5_0, - .from_float = quantize_row_q5_0, + .to_float = (ggml_to_float_t) dequantize_row_q5_0, + .from_float = (ggml_from_float_t) quantize_row_q5_0, .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference, .vec_dot = ggml_vec_dot_q5_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, @@ -549,8 +564,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK5_1, .type_size = sizeof(block_q5_1), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q5_1, - .from_float = quantize_row_q5_1, + .to_float = (ggml_to_float_t) dequantize_row_q5_1, + .from_float = (ggml_from_float_t) quantize_row_q5_1, .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference, .vec_dot = ggml_vec_dot_q5_1_q8_1, .vec_dot_type = GGML_TYPE_Q8_1, @@ -561,8 +576,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK8_0, .type_size = sizeof(block_q8_0), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q8_0, - .from_float = quantize_row_q8_0, + .to_float = (ggml_to_float_t) dequantize_row_q8_0, + .from_float = (ggml_from_float_t) quantize_row_q8_0, .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference, .vec_dot = ggml_vec_dot_q8_0_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, @@ -577,7 +592,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK8_1, .type_size = sizeof(block_q8_1), .is_quantized = true, - .from_float = quantize_row_q8_1, + .from_float = (ggml_from_float_t) quantize_row_q8_1, .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference, .vec_dot_type = GGML_TYPE_Q8_1, .nrows = 1, @@ -587,8 +602,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK_K, .type_size = sizeof(block_q2_K), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q2_K, - .from_float = quantize_row_q2_K, + .to_float = (ggml_to_float_t) dequantize_row_q2_K, + .from_float = (ggml_from_float_t) quantize_row_q2_K, .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference, .vec_dot = ggml_vec_dot_q2_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, @@ -599,8 +614,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK_K, .type_size = sizeof(block_q3_K), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q3_K, - .from_float = quantize_row_q3_K, + .to_float = (ggml_to_float_t) dequantize_row_q3_K, + .from_float = (ggml_from_float_t) quantize_row_q3_K, .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference, .vec_dot = ggml_vec_dot_q3_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, @@ -611,8 +626,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK_K, .type_size = sizeof(block_q4_K), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q4_K, - .from_float = quantize_row_q4_K, + .to_float = (ggml_to_float_t) dequantize_row_q4_K, + .from_float = (ggml_from_float_t) quantize_row_q4_K, .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference, .vec_dot = ggml_vec_dot_q4_K_q8_K, 
.vec_dot_type = GGML_TYPE_Q8_K, @@ -623,8 +638,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK_K, .type_size = sizeof(block_q5_K), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q5_K, - .from_float = quantize_row_q5_K, + .to_float = (ggml_to_float_t) dequantize_row_q5_K, + .from_float = (ggml_from_float_t) quantize_row_q5_K, .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference, .vec_dot = ggml_vec_dot_q5_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, @@ -635,8 +650,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK_K, .type_size = sizeof(block_q6_K), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q6_K, - .from_float = quantize_row_q6_K, + .to_float = (ggml_to_float_t) dequantize_row_q6_K, + .from_float = (ggml_from_float_t) quantize_row_q6_K, .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference, .vec_dot = ggml_vec_dot_q6_K_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, @@ -671,9 +686,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK_K, .type_size = sizeof(block_iq3_xxs), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs, - .from_float = quantize_row_iq3_xxs, - .from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference, + .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs, + .from_float = (ggml_from_float_t) quantize_row_iq3_xxs, + .from_float_reference = (ggml_from_float_t) quantize_row_iq3_xxs_reference, .vec_dot = ggml_vec_dot_iq3_xxs_q8_K, .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, @@ -695,9 +710,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .blck_size = QK4_NL, .type_size = sizeof(block_iq4_nl), .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_iq4_nl, - .from_float = quantize_row_iq4_nl, - .from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference, + .to_float = (ggml_to_float_t) dequantize_row_iq4_nl, + .from_float = (ggml_from_float_t) quantize_row_iq4_nl, + .from_float_reference = (ggml_from_float_t) quantize_row_iq4_nl_reference, .vec_dot = ggml_vec_dot_iq4_nl_q8_0, .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, diff --git a/llama.cpp b/llama.cpp index 788bf3fbc2bd2..174ecf899fc3b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5928,9 +5928,10 @@ struct llm_build_context { // get input vectors with right size const size_t stride1 = n_tokens * ggml_type_size(lctx.inp_tokens->type); - struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); + + struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0); struct ggml_tensor * inp_mean = ggml_view_2d(ctx0, lctx.inp_mean, n_tokens, n_tokens, stride1, 0); - struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0); + struct ggml_tensor * inp_cls = ggml_view_1d(ctx0, lctx.inp_cls, n_tokens, 0); // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb); @@ -5938,8 +5939,9 @@ struct llm_build_context { // token types are hardcoded to zero ("Sentence A") struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); inpL = ggml_add(ctx0, inpL, type_row0); + if (model.arch == LLM_ARCH_BERT) { - inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); + inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, ggml_cast(ctx0, inp_pos, GGML_TYPE_I32)), inpL); } cb(inpL, "inp_embd", -1); diff 
--git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index 5e92d5742a3cc..0b90e560c7fa9 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -143,10 +143,10 @@ int main(int argc, char * argv[]) { continue; } - printf("Testing %s\n", ggml_type_name((ggml_type) i)); - ggml_quantize_init(ei); + if (qfns.from_float && qfns.to_float && qfns.vec_dot) { + printf("Testing %s\n", ggml_type_name((ggml_type) i)); + ggml_quantize_init(ei); - if (qfns.from_float && qfns.to_float) { const float total_error = total_quantization_error(qfns, test_size, test_data.data()); const float max_quantization_error = type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp index 48d9fae3dc06e..ca4c156f3172c 100644 --- a/tests/test-quantize-perf.cpp +++ b/tests/test-quantize-perf.cpp @@ -275,7 +275,7 @@ int main(int argc, char * argv[]) { continue; } - if (qfns.from_float && qfns.to_float) { + if (qfns.from_float && qfns.to_float && qfns.vec_dot) { printf("%s\n", ggml_type_name(type)); ggml_quantize_init(type); From fff1e8a54a18d52bb742b656e13b67ec008167e8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 23 Feb 2024 16:15:37 +0200 Subject: [PATCH 3/4] batched.swift : fix build ggml-ci --- examples/batched.swift/Sources/main.swift | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index d75c503d58311..2e1671ca09fc8 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -79,7 +79,7 @@ batch.n_tokens = Int32(tokens.count) for (i, token) in tokens.enumerated() { batch.token[i] = token - batch.pos[i] = Int32(i) + batch.pos[i] = Float(i) batch.n_seq_id[i] = 1 // batch.seq_id[i][0] = 0 // TODO: is this the proper way to do this? 
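
The Swift example mirrors what any C/C++ caller that fills a llama_batch by hand now has to do: the pos array carries floating-point positions, since llama_pos is a float after this series. Below is a minimal C++ sketch of the same loop; the struct is a cut-down stand-in for llama_batch limited to the fields touched here (not the real header definition), and the token IDs are arbitrary placeholders.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Cut-down stand-in for llama_batch, only the fields this loop touches.
    struct toy_batch {
        std::vector<int32_t> token;
        std::vector<float>   pos;       // llama_pos is a float after this series
        std::vector<int32_t> n_seq_id;
    };

    // Fill positions n_past, n_past+1, ... for a prompt of tokens.size() tokens.
    static toy_batch make_batch(const std::vector<int32_t> & tokens, int n_past) {
        toy_batch b;
        b.token = tokens;
        b.pos.resize(tokens.size());
        b.n_seq_id.assign(tokens.size(), 1);
        for (size_t i = 0; i < tokens.size(); ++i) {
            b.pos[i] = (float) n_past + (float) i;   // was: int32_t position n_past + i
        }
        return b;
    }

    int main() {
        const toy_batch b = make_batch({ 101, 102, 103 }, /*n_past=*/0);
        for (size_t i = 0; i < b.pos.size(); ++i) {
            std::printf("token %d at pos %.1f\n", (int) b.token[i], b.pos[i]);
        }
        return 0;
    }
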
@@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 { } for i in 1 ..< n_parallel { - llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) + llama_kv_cache_seq_cp(context, 0, Int32(i), 0, Float(batch.n_tokens)) } if n_parallel > 1 { @@ -125,8 +125,8 @@ while n_cur <= n_len { continue } - var n_vocab = llama_n_vocab(model) - var logits = llama_get_logits_ith(context, i_batch[i]) + let n_vocab = llama_n_vocab(model) + let logits = llama_get_logits_ith(context, i_batch[i]) var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab)) @@ -173,7 +173,7 @@ while n_cur <= n_len { // push this new token for next evaluation batch.token[Int(batch.n_tokens)] = new_token_id - batch.pos[Int(batch.n_tokens)] = n_cur + batch.pos[Int(batch.n_tokens)] = Float(n_cur) batch.n_seq_id[Int(batch.n_tokens)] = 1 if let seq_id = batch.seq_id[Int(batch.n_tokens)] { seq_id[0] = Int32(i) From 608f4498802d80d11e5d5f6ac05f4e23f49a747b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 23 Feb 2024 19:02:09 +0200 Subject: [PATCH 4/4] swift : fix build ggml-ci --- examples/batched.swift/Sources/main.swift | 6 +++--- examples/llama.swiftui/llama.cpp.swift/LibLlama.swift | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 2e1671ca09fc8..a761a1ba30291 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -79,7 +79,7 @@ batch.n_tokens = Int32(tokens.count) for (i, token) in tokens.enumerated() { batch.token[i] = token - batch.pos[i] = Float(i) + batch.pos[i] = llama_pos(i) batch.n_seq_id[i] = 1 // batch.seq_id[i][0] = 0 // TODO: is this the proper way to do this? @@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 { } for i in 1 ..< n_parallel { - llama_kv_cache_seq_cp(context, 0, Int32(i), 0, Float(batch.n_tokens)) + llama_kv_cache_seq_cp(context, 0, Int32(i), 0, llama_pos(batch.n_tokens)) } if n_parallel > 1 { @@ -173,7 +173,7 @@ while n_cur <= n_len { // push this new token for next evaluation batch.token[Int(batch.n_tokens)] = new_token_id - batch.pos[Int(batch.n_tokens)] = Float(n_cur) + batch.pos[Int(batch.n_tokens)] = llama_pos(n_cur) batch.n_seq_id[Int(batch.n_tokens)] = 1 if let seq_id = batch.seq_id[Int(batch.n_tokens)] { seq_id[0] = Int32(i) diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 58fcf40c6fb69..e59d642dd8601 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -129,7 +129,7 @@ actor LlamaContext { for i1 in 0..
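
Beyond the Swift fixes, the second patch registers to_float/from_float handlers for GGML_TYPE_I32, which is what lets the BERT path cast the float positions back to I32 for the pos_embd row lookup (the ggml_cast call in llm_build_context). The standalone sketch below mirrors those two row converters and does a round trip; the truncation toward zero comes from the plain C cast, matching the helpers added to ggml.c.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Same shape as the helpers added to ggml.c: plain element-wise casts.
    static void i32_to_f32_row(const int32_t * x, float * y, int n) {
        for (int i = 0; i < n; i++) y[i] = (float) x[i];
    }

    static void f32_to_i32_row(const float * x, int32_t * y, int n) {
        for (int i = 0; i < n; i++) y[i] = (int32_t) x[i];  // truncates toward zero
    }

    int main() {
        std::vector<float>   pos  = { 0.0f, 1.0f, 2.5f, 100.0f };  // fractional positions are now representable
        std::vector<int32_t> idx(pos.size());
        std::vector<float>   back(pos.size());

        f32_to_i32_row(pos.data(), idx.data(), (int) pos.size()); // e.g. for a ggml_get_rows-style lookup
        i32_to_f32_row(idx.data(), back.data(), (int) idx.size());

        for (size_t i = 0; i < pos.size(); ++i) {
            std::printf("pos %.2f -> row %d -> %.2f\n", pos[i], (int) idx[i], back[i]);
        }
        return 0;
    }
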
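
What the float positions buy is that the RoPE angle no longer has to be derived from an integer index: shifted or fractionally scaled positions feed directly into the same rotation. The sketch below shows only the basic rope path for a single row (contiguous pairs, no NeoX interleaving, no YaRN correction or attention-factor scaling), using theta_i = freq_scale * p * freq_base^(-i/n_dims) as in ggml_compute_forward_rope_f32; the position 2.5 is purely illustrative.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Rotate one row of size n_dims in place for a (possibly fractional) position p.
    static void rope_row(std::vector<float> & x, float p, float freq_base, float freq_scale) {
        const int n_dims = (int) x.size();
        for (int i = 0; i < n_dims; i += 2) {
            const float theta = freq_scale * p * std::pow(freq_base, -(float) i / n_dims);
            const float c = std::cos(theta);
            const float s = std::sin(theta);
            const float x0 = x[i + 0];
            const float x1 = x[i + 1];
            x[i + 0] = x0 * c - x1 * s;
            x[i + 1] = x0 * s + x1 * c;
        }
    }

    int main() {
        std::vector<float> row = { 1.0f, 0.0f, 1.0f, 0.0f };
        // With float positions, p = 2.5f is as valid an input as p = 2 or p = 3.
        rope_row(row, /*p=*/2.5f, /*freq_base=*/10000.0f, /*freq_scale=*/1.0f);
        for (float v : row) std::printf("%.4f ", v);
        std::printf("\n");
        return 0;
    }
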