diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 90dfe7a7fcc00..4f84e56b3d3aa 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -72,7 +72,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
 
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
-        return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
+        return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1));
     }
 
     GGML_ABORT("fatal error");
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index f865cbaea0240..5222eedcfb099 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -104,7 +104,18 @@ struct llama_hparams {
     llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 
     uint32_t n_swa = 0;         // the size of the sliding window (0 - no SWA)
-    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
+    uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA)
+                                // by default n == 1, all layers are dense
+                                // note that if n_swa_pattern == 0, all layers are SWA
+                                // example: n_swa_pattern = 3
+                                //   il == 0: swa
+                                //   il == 1: swa
+                                //   il == 2: dense
+                                //   il == 3: swa
+                                //   il == 4: swa
+                                //   il == 5: dense
+                                //   il == 6: swa
+                                //   etc ...
 
     // for State Space Models
     uint32_t ssm_d_conv = 0;
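
For reference, here is a minimal standalone sketch (not part of the patch) that exercises the patched condition from `llama_hparams::is_swa()`. The free function below simply mirrors that condition; the `n_layer` value and the pattern values it tests are illustrative assumptions, not values taken from the codebase.

```cpp
// Standalone sketch of the SWA layer-classification rule in the patch above.
// n_swa_pattern == 0: every layer is SWA.
// n_swa_pattern == n: every nth layer (il % n == n - 1) is dense, the rest are SWA.
#include <cstdint>
#include <cstdio>

static bool is_swa(uint32_t il, uint32_t n_swa_pattern) {
    // same expression as the new return statement in llama_hparams::is_swa()
    return n_swa_pattern == 0 || (il % n_swa_pattern < (n_swa_pattern - 1));
}

int main() {
    const uint32_t n_layer    = 7;          // illustrative layer count
    const uint32_t patterns[] = {0, 1, 3};  // illustrative pattern values

    for (uint32_t pattern : patterns) {
        printf("n_swa_pattern = %u:\n", pattern);
        for (uint32_t il = 0; il < n_layer; ++il) {
            printf("  il == %u: %s\n", il, is_swa(il, pattern) ? "swa" : "dense");
        }
    }

    return 0;
}
```

With `n_swa_pattern = 3` this prints exactly the swa/dense layout listed in the new header comment (layers 2 and 5 dense, the rest SWA); with `n_swa_pattern = 1` every layer is dense, and with `n_swa_pattern = 0` every layer is SWA.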