Skip to content

Commit 3e959f0

Browse files
imatrix: fix oob writes if src1 is not contiguous (#13286)
1 parent: 36667c8 · commit: 3e959f0

File tree

1 file changed

+8 -6 lines changed

1 file changed

+8 -6 lines changed

tools/imatrix/imatrix.cpp

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ class IMatrixCollector {
4646
common_params m_params;
4747
std::mutex m_mutex;
4848
int m_last_call = 0;
49-
std::vector<float> m_src1_data;
49+
std::vector<char> m_src1_data;
5050
std::vector<char> m_ids; // the expert ids from ggml_mul_mat_id
5151
};
5252

@@ -93,11 +93,13 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
9393
const bool is_host = ggml_backend_buffer_is_host(src1->buffer);
9494

9595
if (!is_host) {
96-
m_src1_data.resize(ggml_nelements(src1));
97-
ggml_backend_tensor_get(src1, m_src1_data.data(), 0, ggml_nbytes(src1));
96+
const size_t src1_nbytes = ggml_nbytes(src1);
97+
m_src1_data.resize(src1_nbytes);
98+
ggml_backend_tensor_get(src1, m_src1_data.data(), 0, src1_nbytes);
9899
}
99100

100-
const float * data = is_host ? (const float *) src1->data : m_src1_data.data();
101+
const char * data = is_host ? (const char *) src1->data : m_src1_data.data();
102+
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
101103

102104
// this has been adapted to the new format of storing merged experts in a single 3d tensor
103105
// ref: https://github.com/ggml-org/llama.cpp/pull/6387
@@ -144,7 +146,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
144146

145147
const int64_t i11 = idx % src1->ne[1];
146148
const int64_t i12 = row;
147-
const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]);
149+
const float * x = (const float *)(data + i11*src1->nb[1] + i12*src1->nb[2]);
148150

149151
for (int j = 0; j < (int)src1->ne[0]; ++j) {
150152
e.values[e_start + j] += x[j]*x[j];
@@ -180,7 +182,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
180182
++e.ncall;
181183
LOG_DBGV(2, "%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
182184
for (int row = 0; row < (int)src1->ne[1]; ++row) {
183-
const float * x = data + row * src1->ne[0];
185+
const float * x = (const float *) (data + row * src1->nb[1]);
184186
for (int j = 0; j < (int)src1->ne[0]; ++j) {
185187
e.values[j] += x[j]*x[j];
186188
e.counts[j]++;

0 commit comments

Comments
 (0)