@@ -1293,7 +1293,7 @@ void ggml_metal_graph_compute(
1293
1293
[encoder setBytes: &pnb3 length: sizeof (pnb3) atIndex: 26 ];
1294
1294
[encoder setBytes: &offs length: sizeof (offs) atIndex: 27 ];
1295
1295
1296
- const int nth = MIN (1024 , ne0 );
1296
+ const int nth = MIN (( int ) ctx-> pipeline_add . maxTotalThreadsPerThreadgroup , ne00 );
1297
1297
1298
1298
[encoder dispatchThreadgroups: MTLSizeMake (ne11, ne12, ne13) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
1299
1299
} break ;
@@ -1793,8 +1793,9 @@ void ggml_metal_graph_compute(
1793
1793
[encoder setBytes: &r3 length: sizeof (r3) atIndex: 17 ];
1794
1794
[encoder setBytes: &idx length: sizeof (idx) atIndex: 18 ];
1795
1795
// TODO: how to make this an array? read Metal docs
1796
- for (int j = 0 ; j < n_as; ++j) {
1797
- struct ggml_tensor * src_cur = dst->src [2 + j];
1796
+ for (int j = 0 ; j < 8 ; ++j) {
1797
+ // NOTE: always loop over all 8 slots, wrapping the index with (j % n_as), to avoid uninitialized kernel arguments when n_as < 8
1798
+ struct ggml_tensor * src_cur = dst->src [2 + (j % n_as)];
1798
1799
1799
1800
size_t offs_src_cur = 0 ;
1800
1801
id <MTLBuffer > id_src_cur = ggml_metal_get_buffer (ctx, src_cur, &offs_src_cur);
@@ -1917,8 +1918,9 @@ void ggml_metal_graph_compute(
1917
1918
[encoder setBytes: &r3 length: sizeof (r3) atIndex: 21 ];
1918
1919
[encoder setBytes: &idx length: sizeof (idx) atIndex: 22 ];
1919
1920
// TODO: how to make this an array? read Metal docs
1920
- for (int j = 0 ; j < n_as; ++j) {
1921
- struct ggml_tensor * src_cur = dst->src [2 + j];
1921
+ for (int j = 0 ; j < 8 ; ++j) {
1922
+ // NOTE: always loop over all 8 slots, wrapping the index with (j % n_as), to avoid uninitialized kernel arguments when n_as < 8
1923
+ struct ggml_tensor * src_cur = dst->src [2 + (j % n_as)];
1922
1924
1923
1925
size_t offs_src_cur = 0 ;
1924
1926
id <MTLBuffer > id_src_cur = ggml_metal_get_buffer (ctx, src_cur, &offs_src_cur);
0 commit comments