llama : switch to floating-point token positions #5679

Draft: wants to merge 4 commits into master
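
The patch applies one pattern throughout: every place that allocated or filled token positions as 32-bit integers (GGML_TYPE_I32, int32_t, Int32) now uses single-precision floats (GGML_TYPE_F32, float, llama_pos). A minimal sketch of that pattern, assuming llama_pos is now a float-based position type and that the KQ_pos tensor's data is host-allocated; set_positions() is a hypothetical helper used only for illustration:

    // Sketch only: mirrors the KQ_pos changes in the hunks below.
    // set_positions() is a hypothetical helper, not part of this patch.
    #include "ggml.h"

    static void set_positions(struct ggml_context * ctx, int n_past, int N) {
        // positions now live in an F32 tensor instead of an I32 one
        struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N);

        float * data = (float *) KQ_pos->data;
        for (int i = 0; i < N; ++i) {
            data[i] = (float)(n_past + i); // integer positions are exactly representable in f32
        }
    }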
4 changes: 2 additions & 2 deletions examples/baby-llama/baby-llama.cpp
@@ -1015,9 +1015,9 @@ static struct ggml_tensor * forward_lora(
  struct ggml_tensor * kc = kv_self.k;
  struct ggml_tensor * vc = kv_self.v;

- struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
+ struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, N);
  {
- int * data = (int *) KQ_pos->data;
+ float * data = (float *) KQ_pos->data;
  for (int i = 0; i < N; ++i) {
  data[i] = n_past + i;
  }
10 changes: 5 additions & 5 deletions examples/batched.swift/Sources/main.swift
@@ -79,7 +79,7 @@ batch.n_tokens = Int32(tokens.count)

  for (i, token) in tokens.enumerated() {
  batch.token[i] = token
- batch.pos[i] = Int32(i)
+ batch.pos[i] = llama_pos(i)
  batch.n_seq_id[i] = 1
  // batch.seq_id[i][0] = 0
  // TODO: is this the proper way to do this?
@@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 {
  }

  for i in 1 ..< n_parallel {
- llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
+ llama_kv_cache_seq_cp(context, 0, Int32(i), 0, llama_pos(batch.n_tokens))
  }

  if n_parallel > 1 {
@@ -125,8 +125,8 @@ while n_cur <= n_len {
  continue
  }

- var n_vocab = llama_n_vocab(model)
- var logits = llama_get_logits_ith(context, i_batch[i])
+ let n_vocab = llama_n_vocab(model)
+ let logits = llama_get_logits_ith(context, i_batch[i])

  var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))

@@ -173,7 +173,7 @@ while n_cur <= n_len {

  // push this new token for next evaluation
  batch.token[Int(batch.n_tokens)] = new_token_id
- batch.pos[Int(batch.n_tokens)] = n_cur
+ batch.pos[Int(batch.n_tokens)] = llama_pos(n_cur)
  batch.n_seq_id[Int(batch.n_tokens)] = 1
  if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
  seq_id[0] = Int32(i)
4 changes: 2 additions & 2 deletions examples/finetune/finetune.cpp
@@ -554,7 +554,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
  };

  // KQ_pos - contains the positions
- struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
+ struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N);
  ggml_set_input(KQ_pos);

  // rope has so much parameters that we make a custom function for it
@@ -743,7 +743,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(

  // set KQ_pos
  {
- int * data = (int *) KQ_pos->data;
+ float * data = (float *) KQ_pos->data;
  for (int i = 0; i < N; ++i) {
  data[i] = n_past + i;
  }
8 changes: 4 additions & 4 deletions examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -129,7 +129,7 @@ actor LlamaContext {

  for i1 in 0..<tokens_list.count {
  let i = Int(i1)
- llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
+ llama_batch_add(&batch, tokens_list[i], llama_pos(i), [0], false)
  }
  batch.logits[Int(batch.n_tokens) - 1] = 1 // true

@@ -183,7 +183,7 @@ actor LlamaContext {
  // tokens_list.append(new_token_id)

  llama_batch_clear(&batch)
- llama_batch_add(&batch, new_token_id, n_cur, [0], true)
+ llama_batch_add(&batch, new_token_id, llama_pos(n_cur), [0], true)

  n_decode += 1
  n_cur += 1
@@ -210,7 +210,7 @@ actor LlamaContext {
  let n_tokens = pp

  for i in 0..<n_tokens {
- llama_batch_add(&batch, 0, Int32(i), [0], false)
+ llama_batch_add(&batch, 0, llama_pos(i), [0], false)
  }
  batch.logits[Int(batch.n_tokens) - 1] = 1 // true

@@ -234,7 +234,7 @@ actor LlamaContext {
  llama_batch_clear(&batch)

  for j in 0..<pl {
- llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
+ llama_batch_add(&batch, 0, llama_pos(i), [Int32(j)], true)
  }

  if llama_decode(context, batch) != 0 {
2 changes: 1 addition & 1 deletion examples/llava/llava.cpp
@@ -338,7 +338,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
  if (n_eval > n_batch) {
  n_eval = n_batch;
  }
- llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
+ llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, (float) *n_past, 1, 0, };
  if (llama_decode(ctx_llama, batch)) {
  fprintf(stderr, "%s : failed to eval\n", __func__);
  return false;
2 changes: 1 addition & 1 deletion examples/server/server.cpp
@@ -1281,7 +1281,7 @@ struct llama_server_context
  }

  const int n_embd = llama_n_embd(model);
- llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
+ llama_batch batch_img = { n_eval, nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, nullptr, (float) slot.n_past, 1, 0, };
  if (llama_decode(ctx, batch_img))
  {
  LOG_TEE("%s : failed to eval image\n", __func__);
4 changes: 2 additions & 2 deletions examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -291,7 +291,7 @@ static struct ggml_tensor * llama_build_train_graphs(
  };

  // KQ_pos - contains the positions
- struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
+ struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, N);
  ggml_set_input(KQ_pos);

  // rope has so much parameters that we make a custom function for it
@@ -419,7 +419,7 @@ static struct ggml_tensor * llama_build_train_graphs(
  ggml_gallocr_alloc_graph(alloc, gb);

  if (!measure_only) {
- int * data = (int *) KQ_pos->data;
+ float * data = (float *) KQ_pos->data;
  for (int i = 0; i < N; ++i) {
  data[i] = n_past + i;
  }
30 changes: 15 additions & 15 deletions ggml-cuda.cu
@@ -6040,7 +6040,7 @@ static __device__ void rope_yarn(
  // rope == RoPE == rotary positional embedding
  template<typename T, bool has_pos>
  static __global__ void rope(
- const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+ const T * x, T * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base,
  float ext_factor, float attn_factor, rope_corr_dims corr_dims
  ) {
  const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -6053,7 +6053,7 @@ static __global__ void rope(
  const int i = row*ncols + col;
  const int i2 = row/p_delta_rows;

- const int p = has_pos ? pos[i2] : 0;
+ const float p = has_pos ? pos[i2] : 0.0f;
  const float theta_base = p*powf(freq_base, -float(col)/ncols);

  float cos_theta, sin_theta;
@@ -6068,7 +6068,7 @@

  template<typename T, bool has_pos>
  static __global__ void rope_neox(
- const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+ const T * x, T * dst, int ncols, int n_dims, const float * pos, float freq_scale, int p_delta_rows,
  float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
  ) {
  const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -6095,7 +6095,7 @@ static __global__ void rope_neox(

  float cur_rot = inv_ndims * ic - ib;

- const int p = has_pos ? pos[i2] : 0;
+ const float p = has_pos ? pos[i2] : 0.0f;
  const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);

  float cos_theta, sin_theta;
@@ -6109,7 +6109,7 @@
  }

  static __global__ void rope_glm_f32(
- const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
+ const float * x, float * dst, int ncols, const float * pos, float freq_scale, int p_delta_rows, float freq_base,
  int n_ctx
  ) {
  const int col = blockDim.x*blockIdx.x + threadIdx.x;
@@ -6124,10 +6124,10 @@ static __global__ void rope_glm_f32(
  const int i2 = row/p_delta_rows;

  const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
- // FIXME: this is likely wrong
- const int p = pos != nullptr ? pos[i2] : 0;

- const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
+ const float p = pos != nullptr ? pos[i2] : 0.0f;

+ const float theta = min(p, (float) n_ctx - 2)*freq_scale*col_theta_scale;
  const float sin_theta = sinf(theta);
  const float cos_theta = cosf(theta);

@@ -6137,7 +6137,7 @@
  dst[i + 0] = x0*cos_theta - x1*sin_theta;
  dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;

- const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
+ const float block_theta = max(p - n_ctx - 2, 0.0f)*col_theta_scale;
  const float sin_block_theta = sinf(block_theta);
  const float cos_block_theta = cosf(block_theta);

@@ -7688,7 +7688,7 @@ static void clamp_f32_cuda(const float * x, float * dst, const float min, const

  template<typename T>
  static void rope_cuda(
- const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+ const T * x, T * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows,
  float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
  ) {
  GGML_ASSERT(ncols % 2 == 0);
@@ -7708,7 +7708,7 @@

  template<typename T>
  static void rope_neox_cuda(
- const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+ const T * x, T * dst, int ncols, int n_dims, int nrows, const float * pos, float freq_scale, int p_delta_rows,
  float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
  ) {
  GGML_ASSERT(ncols % 2 == 0);
@@ -7733,7 +7733,7 @@
  }

  static void rope_glm_f32_cuda(
- const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+ const float * x, float * dst, int ncols, int nrows, const float * pos, float freq_scale, int p_delta_rows,
  float freq_base, int n_ctx, cudaStream_t stream
  ) {
  GGML_ASSERT(ncols % 4 == 0);
@@ -9035,11 +9035,11 @@ static void ggml_cuda_op_rope(
  memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

- const int32_t * pos = nullptr;
+ const float * pos = nullptr;
  if ((mode & 1) == 0) {
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
  GGML_ASSERT(src1->ne[0] == ne2);
- pos = (const int32_t *) src1_dd;
+ pos = (const float *) src1_dd;
  }

  const bool is_neox = mode & 2;
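
In the CUDA kernels above (and the Metal kernels below), the position buffer becomes const float * and the rotation angle is computed directly from the float position. A host-side restatement of the theta_base expression from the rope<T, has_pos> kernel, with the YaRN correction terms (ext_factor, attn_factor) left out; rope_theta() is an illustrative name, not an API in the patch:

    // Sketch only: restates theta_base = p * powf(freq_base, -float(col)/ncols)
    // from the basic rope kernel; corrections applied later by rope_yarn are omitted.
    #include <math.h>

    static float rope_theta(float p, int col, int ncols, float freq_base) {
        // p is the (now floating-point) token position of this row
        return p * powf(freq_base, -(float) col / ncols);
    }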
8 changes: 7 additions & 1 deletion ggml-metal.m
@@ -2057,7 +2057,13 @@ static bool ggml_metal_graph_compute(
  // skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
  const int n_orig_ctx = ((int32_t *) dst->op_params)[4];

- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+
  memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
  memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
10 changes: 5 additions & 5 deletions ggml-metal.metal
@@ -1674,7 +1674,7 @@ static void rope_yarn_corr_dims(

  typedef void (rope_t)(
  device const void * src0,
- device const int32_t * src1,
+ device const float * src1,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne01,
@@ -1709,7 +1709,7 @@ typedef void (rope_t)(
  template<typename T>
  kernel void kernel_rope(
  device const void * src0,
- device const int32_t * src1,
+ device const float * src1,
  device float * dst,
  constant int64_t & ne00,
  constant int64_t & ne01,
@@ -1749,11 +1749,11 @@ kernel void kernel_rope(
  float corr_dims[2];
  rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

- device const int32_t * pos = src1;
+ device const float * pos = src1;

- const int64_t p = pos[i2];
+ const float p = pos[i2];

- const float theta_0 = (float)p;
+ const float theta_0 = p;
  const float inv_ndims = -1.f/n_dims;

  if (!is_neox) {