Commit a79e4e8

llama : initial Mamba-2 support

1 parent 8062650
7 files changed: +495 −87 lines

convert_hf_to_gguf.py

Lines changed: 67 additions & 0 deletions
@@ -2788,6 +2788,73 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(new_name, data_torch)]


+@Model.register("Mamba2ForCausalLM")
+class Mamba2Model(Model):
+    model_arch = gguf.MODEL_ARCH.MAMBA2
+
+    def set_vocab(self):
+        vocab_size = self.hparams["vocab_size"]
+        # Round vocab size to next multiple of 16
+        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
+        # pad using ceiling division
+        # ref: https://stackoverflow.com/a/17511341/22827863
+        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
+        self.hparams["vocab_size"] = vocab_size
+
+        if (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        elif (self.dir_model / "tokenizer.model").is_file():
+            self._set_vocab_sentencepiece()
+        elif (self.dir_model / "tokenizer.model.v3").is_file():
+            # mamba-codestral
+            raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}")
+        else:
+            # Use the GPT-NeoX tokenizer when no tokenizer files are present
+            self._set_vocab_builtin("gpt-neox", vocab_size)
+
+    def set_gguf_parameters(self):
+        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+
+        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+
+        # Fail early for models which don't have a block expansion factor of 2
+        # TODO: does this really matter?
+        assert d_inner == 2 * d_model
+        assert d_inner % head_dim == 0
+
+        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
+        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_ssm_conv_kernel(d_conv)
+        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_state_size(d_state)
+        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.endswith(".dt_bias"):
+            name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
+
+        new_name = self.map_tensor_name(name)
+
+        if name.endswith(".A_log"):
+            logger.debug("A_log --> A ==> " + new_name)
+            data_torch = -torch.exp(data_torch)
+
+        yield (new_name, data_torch)
+
+
 @Model.register("CohereForCausalLM")
 class CommandR2Model(Model):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
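
Two conversion details above are worth a closer look: `set_vocab` pads the vocabulary with negated floor division as ceiling division, and `modify_tensors` bakes the sign and exponential into the stored decay tensor (`A = -exp(A_log)`) so the runtime kernel never has to. A standalone sketch of the padding arithmetic (illustrative names, not part of the converter):

```python
def pad_vocab_size(vocab_size: int, multiple: int = 16) -> int:
    # -(a // -b) == ceil(a / b) for positive integers
    # ref: https://stackoverflow.com/a/17511341/22827863
    return -(vocab_size // -multiple) * multiple

assert pad_vocab_size(50277) == 50288  # rounded up to the next multiple of 16
assert pad_vocab_size(32000) == 32000  # already a multiple; unchanged
```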

ggml/include/ggml.h

Lines changed: 2 additions & 1 deletion
@@ -1787,7 +1787,8 @@ extern "C" {
             struct ggml_tensor  * dt,
             struct ggml_tensor  * A,
             struct ggml_tensor  * B,
-            struct ggml_tensor  * C);
+            struct ggml_tensor  * C,
+            struct ggml_tensor  * D);

     // partition into non-overlapping windows with padding if needed
     // example:
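
Threading `D` through `ggml_ssm_scan` lets the op apply the skip connection internally rather than leaving it to a separate multiply-add at the call site. Per head h, inner channel i, and group g(h), the fused output is

    y[h,i] = dot(state[h,i,:], C[g(h),:]) + D[h] * x[h,i]

which corresponds to the `y[i] = sumf + x[i] * D[h];` line in the CPU kernel below.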

ggml/src/ggml.c

Lines changed: 125 additions & 68 deletions
@@ -7270,32 +7270,48 @@ struct ggml_tensor * ggml_ssm_scan(
         struct ggml_tensor  * dt,
         struct ggml_tensor  * A,
         struct ggml_tensor  * B,
-        struct ggml_tensor  * C) {
+        struct ggml_tensor  * C,
+        struct ggml_tensor  * D) {
     GGML_ASSERT(ggml_is_contiguous(s));
-    GGML_ASSERT(ggml_is_contiguous(x));
     GGML_ASSERT(ggml_is_contiguous(dt));
     GGML_ASSERT(ggml_is_contiguous(A));
-    GGML_ASSERT(ggml_is_matrix(A));
-    GGML_ASSERT(ggml_is_3d(B));
-    GGML_ASSERT(ggml_is_3d(s));
+    GGML_ASSERT(x->nb[0] == ggml_type_size(x->type));
     GGML_ASSERT(B->nb[0] == ggml_type_size(B->type));
     GGML_ASSERT(C->nb[0] == ggml_type_size(C->type));
-    GGML_ASSERT(ggml_are_same_shape(x, dt));
+    GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]);
+    GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]);
+    GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]);
     GGML_ASSERT(ggml_are_same_shape(B, C));

     {
         const int64_t d_state      = s->ne[0];
-        const int64_t d_inner      = s->ne[1];
-        const int64_t n_seq_tokens = x->ne[1];
-        const int64_t n_seqs       = x->ne[2];
-
-        GGML_ASSERT(s->ne[2] == n_seqs);
-        GGML_ASSERT(x->ne[0] == d_inner);
-        GGML_ASSERT(A->ne[0] == d_state);
-        GGML_ASSERT(A->ne[1] == d_inner);
+        const int64_t head_dim     = x->ne[0];
+        const int64_t n_head       = x->ne[1];
+        const int64_t n_seq_tokens = x->ne[2];
+        const int64_t n_seqs       = x->ne[3];
+
+        GGML_ASSERT(dt->ne[0] == n_head);
+        GGML_ASSERT(dt->ne[1] == n_seq_tokens);
+        GGML_ASSERT(dt->ne[2] == n_seqs);
+        GGML_ASSERT(ggml_is_3d(dt));
+        GGML_ASSERT(s->ne[1] == head_dim);
+        GGML_ASSERT(s->ne[2] == n_head);
+        GGML_ASSERT(s->ne[3] == n_seqs);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_seq_tokens);
-        GGML_ASSERT(B->ne[2] == n_seqs);
+        GGML_ASSERT(B->ne[2] == n_seq_tokens);
+        GGML_ASSERT(B->ne[3] == n_seqs);
+        GGML_ASSERT(D->ne[0] == n_head);
+        GGML_ASSERT(ggml_is_vector(D));
+
+        if (ggml_is_vector(A)) {
+            // Mamba-2
+            GGML_ASSERT(A->ne[0] == n_head);
+        } else {
+            // Mamba-1
+            GGML_ASSERT(A->ne[0] == d_state);
+            GGML_ASSERT(A->ne[1] == n_head);
+            GGML_ASSERT(ggml_is_matrix(A));
+        }
     }

     bool is_node = false;
@@ -7316,6 +7332,7 @@ struct ggml_tensor * ggml_ssm_scan(
     result->src[3] = A;
     result->src[4] = B;
     result->src[5] = C;
+    result->src[6] = D;

     return result;
 }
@@ -15840,20 +15857,25 @@ static void ggml_compute_forward_ssm_conv(
 static void ggml_compute_forward_ssm_scan_f32(
         const struct ggml_compute_params * params,
               struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0]; // s
-    const struct ggml_tensor * src1 = dst->src[1]; // x
-    const struct ggml_tensor * src2 = dst->src[2]; // dt
-    const struct ggml_tensor * src3 = dst->src[3]; // A
-    const struct ggml_tensor * src4 = dst->src[4]; // B
-    const struct ggml_tensor * src5 = dst->src[5]; // C
+    const struct ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs}
+    const struct ggml_tensor * src1 = dst->src[1]; // x  {dim, n_head, n_seq_tokens, n_seqs}
+    const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
+    const struct ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {n_head}
+    const struct ggml_tensor * src4 = dst->src[4]; // B  {d_state, n_group, n_seq_tokens, n_seqs}
+    const struct ggml_tensor * src5 = dst->src[5]; // C  {d_state, n_group, n_seq_tokens, n_seqs}
+    const struct ggml_tensor * src6 = dst->src[6]; // D  {n_head}

     const int ith = params->ith;
     const int nth = params->nth;

-    const int64_t nc  = src0->ne[0]; // d_state
-    const int64_t nr  = src0->ne[1]; // d_inner
-    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
-    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+    const int64_t nc = src0->ne[0]; // d_state
+    const int64_t nr = src0->ne[1]; // dim
+    const int64_t nh = src1->ne[1]; // n_head
+    const int64_t ng = src4->ne[1];
+    const int64_t nt = src1->ne[2]; // number of tokens per sequence
+    const int64_t ns = src0->ne[3]; // number of sequences in the batch
+
+    const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1);

     GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
@@ -15862,51 +15884,86 @@ static void ggml_compute_forward_ssm_scan_f32(
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C
-    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float));
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
-
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-    const int ir  = ir1 - ir0;
-
-    for (int i3 = 0; i3 < n_s; ++i3) {
-        for (int i2 = 0; i2 < n_t; ++i2) {
-            const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s}
-            const float * x  = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s}
-            const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
-            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
-
-            // use the output as the source for the next token-wise iterations
+    GGML_ASSERT(src6->nb[0] == sizeof(float));
+    // allows optimizing the modulo since n_group should be a power of 2
+    GGML_ASSERT((ng & -ng) == ng);
+
+    // heads per thread
+    const int dh = (nh + nth - 1)/nth;
+
+    // head range for this thread
+    const int ih0 = dh*ith;
+    const int ih1 = MIN(ih0 + dh, nh);
+
+    for (int i3 = 0; i3 < ns; ++i3) {
+        for (int i2 = 0; i2 < nt; ++i2) {
+            const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns}
+            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
+            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
+            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh}
+            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
+            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
+            const float * D  = (const float *) ((const char *) src6->data); // {nh}
+            float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
+            float * s = (float *) ((char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
+
+            // use the output as the source when it's not the first token-wise iteration
             if (i2 > 0) { s0 = s; }

-            // d_inner
-            for (int i1 = 0; i1 < ir; ++i1) {
-                // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78
-                float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1];
-                float x_dt = x[i1] * dt_soft_plus;
-                float sumf = 0.0f;
-                // d_state
-                for (int i0 = 0; i0 < nc; ++i0) {
-                    int i = i0 + i1*nc;
-                    // state = prev_state * dA + dB * x
-                    float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-                    // y = rowwise_dotprod(state, C)
-                    sumf += state * C[i0];
-                    s[i] = state;
+            if (ggml_is_vector(src3)) {
+                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const float dA = expf(dt_soft_plus * A[h]);
+
+                    // TODO: SIMD implementation
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int i = i1 + h*nr;
+                        const float x_dt = x[i] * dt_soft_plus;
+                        float sumf = 0.0f;
+                        // d_state
+                        for (int i0 = 0; i0 < nc; ++i0) {
+                            const int ii = i0 + i*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[ii] * dA) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[ii] = state;
+                        }
+                        y[i] = sumf + x[i] * D[h];
+                    }
+                }
+            } else {
+                // Mamba-1 has an element-wise decay factor for the states
+
+                // n_head
+                for (int h = ih0; h < ih1; ++h) {
+                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
+                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+
+                    // dim
+                    for (int i1 = 0; i1 < nr; ++i1) {
+                        const int i = i1 + h*nr;
+                        const float x_dt = x[i] * dt_soft_plus;
+                        float sumf = 0.0f;
+                        // d_state
+                        for (int i0 = 0; i0 < nc; ++i0) {
+                            const int ii = i0 + i*nc;
+                            const int ig = i0 + (h & (ng - 1))*nc;
+                            // state = prev_state * dA + dB * x
+                            const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
+                            // y = rowwise_dotprod(state, C)
+                            sumf += state * C[ig];
+                            s[ii] = state;
+                        }
+                        y[i] = sumf + x[i] * D[h];
+                    }
                 }
-                y[i1] = sumf;
             }
         }
     }
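
To make the kernel's two branches concrete: `ggml_is_vector(src3)` dispatches on the shape of `A` (a `{n_head}` vector for Mamba-2's scalar per-head decay, a `{d_state, n_head}` matrix for Mamba-1). Below is a hedged pure-Python reference of the Mamba-2 branch for a single token of a single sequence; the function name and list-of-lists layout are illustrative, not the ggml API:

```python
import math

def softplus(v: float) -> float:
    # same numerically-safe softplus as the kernel (identity past v > 20)
    return math.log1p(math.exp(v)) if v <= 20.0 else v

def ssm_scan_step(s, x, dt, A, B, C, D, n_group):
    """One token-wise step of the Mamba-2 branch.

    s:  recurrent state s[h][i1][i0], shape {n_head, dim, d_state}, updated in place
    x:  input x[h][i1], shape {n_head, dim}
    dt: per-head time step, shape {n_head}
    A:  per-head scalar decay (already stored as -exp(A_log)), shape {n_head}
    B, C: grouped projections B[g][i0], shape {n_group, d_state}
    D:  per-head skip weight, shape {n_head}
    """
    n_head, dim, d_state = len(x), len(x[0]), len(B[0])
    y = [[0.0] * dim for _ in range(n_head)]
    for h in range(n_head):
        dt_sp = softplus(dt[h])
        dA = math.exp(dt_sp * A[h])  # scalar per head: hoisted out of both inner loops
        g = h % n_group              # the kernel uses h & (ng - 1); same thing for power-of-2 ng
        for i1 in range(dim):
            x_dt = x[h][i1] * dt_sp
            sumf = 0.0
            for i0 in range(d_state):
                # state = prev_state * dA + dB * x
                state = s[h][i1][i0] * dA + B[g][i0] * x_dt
                # y = rowwise_dotprod(state, C)
                sumf += state * C[g][i0]
                s[h][i1][i0] = state
            y[h][i1] = sumf + x[h][i1] * D[h]  # skip connection fused via D
    return y

# tiny smoke test: one head, dim 2, d_state 4, one group
s = [[[0.0] * 4 for _ in range(2)]]
y = ssm_scan_step(s, x=[[1.0, -1.0]], dt=[0.5], A=[-1.0],
                  B=[[0.1] * 4], C=[[1.0] * 4], D=[0.05], n_group=1)
```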

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions
@@ -130,6 +130,7 @@ class SSM:
         INNER_SIZE     = "{arch}.ssm.inner_size"
         STATE_SIZE     = "{arch}.ssm.state_size"
         TIME_STEP_RANK = "{arch}.ssm.time_step_rank"
+        GROUP_COUNT    = "{arch}.ssm.group_count"
         DT_B_C_RMS     = "{arch}.ssm.dt_b_c_rms"

     class Tokenizer:
@@ -208,6 +209,7 @@ class MODEL_ARCH(IntEnum):
     GEMMA2     = auto()
     STARCODER2 = auto()
     MAMBA      = auto()
+    MAMBA2     = auto()
     XVERSE     = auto()
     COMMAND_R  = auto()
     DBRX       = auto()
@@ -269,6 +271,7 @@ class MODEL_TENSOR(IntEnum):
     SSM_DT   = auto()
     SSM_A    = auto()
     SSM_D    = auto()
+    SSM_NORM = auto()
     SSM_OUT  = auto()
     ATTN_Q_A = auto()
     ATTN_Q_B = auto()
@@ -338,6 +341,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GEMMA2:     "gemma2",
     MODEL_ARCH.STARCODER2: "starcoder2",
     MODEL_ARCH.MAMBA:      "mamba",
+    MODEL_ARCH.MAMBA2:     "mamba2",
     MODEL_ARCH.XVERSE:     "xverse",
     MODEL_ARCH.COMMAND_R:  "command-r",
     MODEL_ARCH.DBRX:       "dbrx",
@@ -399,6 +403,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.SSM_DT:   "blk.{bid}.ssm_dt",
     MODEL_TENSOR.SSM_A:    "blk.{bid}.ssm_a",
     MODEL_TENSOR.SSM_D:    "blk.{bid}.ssm_d",
+    MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
     MODEL_TENSOR.SSM_OUT:  "blk.{bid}.ssm_out",
     MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
     MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
@@ -869,6 +874,19 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.SSM_D,
         MODEL_TENSOR.SSM_OUT,
     ],
+    MODEL_ARCH.MAMBA2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.SSM_IN,
+        MODEL_TENSOR.SSM_CONV1D,
+        MODEL_TENSOR.SSM_DT,
+        MODEL_TENSOR.SSM_A,
+        MODEL_TENSOR.SSM_D,
+        MODEL_TENSOR.SSM_NORM,
+        MODEL_TENSOR.SSM_OUT,
+    ],
     MODEL_ARCH.XVERSE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -1373,6 +1391,7 @@ def get_type(val: Any) -> GGUFValueType:
 KEY_SSM_INNER_SIZE     = Keys.SSM.INNER_SIZE
 KEY_SSM_STATE_SIZE     = Keys.SSM.STATE_SIZE
 KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
+KEY_SSM_GROUP_COUNT    = Keys.SSM.GROUP_COUNT
 KEY_SSM_DT_B_C_RMS     = Keys.SSM.DT_B_C_RMS

 # tokenization
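
A quick way to see what the new entries resolve to (a hedged sketch; it assumes the `gguf-py` package from this tree is importable and keeps its `TENSOR_NAMES` mapping):

```python
from gguf.constants import Keys, MODEL_TENSOR, TENSOR_NAMES

# the new per-architecture metadata key, formatted for Mamba-2
print(Keys.SSM.GROUP_COUNT.format(arch="mamba2"))         # mamba2.ssm.group_count
# the new tensor name, formatted for block 0
print(TENSOR_NAMES[MODEL_TENSOR.SSM_NORM].format(bid=0))  # blk.0.ssm_norm
```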

gguf-py/gguf/gguf_writer.py

Lines changed: 3 additions & 0 deletions
@@ -730,6 +730,9 @@ def add_ssm_state_size(self, value: int) -> None:
     def add_ssm_time_step_rank(self, value: int) -> None:
         self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value)

+    def add_ssm_group_count(self, value: int) -> None:
+        self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value)
+
     def add_ssm_dt_b_c_rms(self, value: bool) -> None:
         self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)
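
And a hedged usage sketch of the new writer method, with a made-up output path and the group count the converter falls back to when `n_groups` is absent:

```python
from gguf import GGUFWriter

writer = GGUFWriter("mamba2-model.gguf", "mamba2")  # hypothetical output file
writer.add_ssm_group_count(1)  # n_groups from the HF config; default is 1
```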
