11
11
#include < random>
12
12
#include < unordered_map>
13
13
14
+ static int llama_sample_dist (llama_token_data_array * cur_p, std::mt19937 & rng, std::vector<float > & probs) {
15
+ probs.resize (cur_p->size );
16
+ for (size_t i = 0 ; i < cur_p->size ; ++i) {
17
+ probs[i] = cur_p->data [i].p ;
18
+ }
19
+
20
+ std::discrete_distribution<size_t > dist (probs.begin (), probs.end ());
21
+
22
+ return dist (rng);
23
+ }
24
+
14
25
static void llama_log_softmax (float * array, size_t size) {
15
26
float max_l = *std::max_element (array, array + size);
16
27
float sum = 0 .f ;
@@ -456,22 +467,16 @@ struct llama_sampler_context_dist {
456
467
const uint32_t seed;
457
468
458
469
std::mt19937 rng;
470
+
471
+ std::vector<float > probs; // work array
459
472
};
460
473
461
474
static struct llama_sampler_i llama_sampler_dist_i = {
462
475
/* .name = */ [](const struct llama_sampler * /* smpl*/ ) { return " dist" ; },
463
476
/* .accept = */ nullptr ,
464
477
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
465
478
auto * ctx = (llama_sampler_context_dist *) smpl->ctx ;
466
- std::vector<float > probs;
467
- probs.reserve (cur_p->size );
468
- for (size_t i = 0 ; i < cur_p->size ; ++i) {
469
- probs.push_back (cur_p->data [i].p );
470
- }
471
-
472
- std::discrete_distribution<size_t > dist (probs.begin (), probs.end ());
473
-
474
- cur_p->selected = dist (ctx->rng );
479
+ cur_p->selected = llama_sample_dist (cur_p, ctx->rng , ctx->probs );
475
480
},
476
481
/* .reset = */ nullptr ,
477
482
/* .clone = */ [](const struct llama_sampler * smpl) {
@@ -489,6 +494,7 @@ struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) {
489
494
/* .ctx = */ new llama_sampler_context_dist {
490
495
/* .seed = */ seed,
491
496
/* .rng = */ std::mt19937 (seed),
497
+ /* .probs = */ {},
492
498
},
493
499
};
494
500
}
@@ -761,35 +767,23 @@ struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta,
761
767
// State kept by the mirostat sampler between calls.
//
// The members are initialized positionally by llama_sampler_init_mirostat_impl,
// so their order must not change.
struct llama_sampler_context_mirostat {
    // vocabulary; its n_vocab is read by .apply when estimating k
    const struct llama_vocab * vocab;

    // seed the rng is constructed from, kept so that .reset can re-create the rng
    const uint32_t seed;

    const float tau; // target surprise: the error is (observed surprise - tau)
    const float eta; // learning rate applied to that error when mu is updated

    const int32_t m;

    float mu; // starts at 2*tau and is adjusted after every sampled token

    std::mt19937 rng;

    std::vector<float> probs; // work array handed to llama_sample_dist
};
773
783
774
784
static struct llama_sampler_i llama_sampler_mirostat_i = {
775
785
/* .name = */ [](const struct llama_sampler * /* smpl*/ ) { return " mirostat" ; },
776
- /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
777
- auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx ;
778
-
779
- int32_t idx = -1 ;
780
- for (size_t i = 0 ; i < ctx->cur .size (); ++i) {
781
- if (ctx->cur [i].id == token) {
782
- idx = i;
783
- break ;
784
- }
785
- }
786
-
787
- float observed_surprise = -log2f (ctx->cur [idx].p );
788
- float e = observed_surprise - ctx->tau ;
789
-
790
- // Update mu using the learning rate and error
791
- ctx->mu = ctx->mu - ctx->eta * e;
792
- },
786
+ /* .accept = */ nullptr ,
793
787
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
794
788
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx ;
795
789
@@ -812,70 +806,66 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
812
806
float k = powf ((epsilon_hat * powf (2 , ctx->mu )) / (1 - powf (ctx->vocab ->n_vocab , -epsilon_hat)), 1 / s_hat);
813
807
814
808
llama_sampler_top_k_impl (cur_p, std::max (int (k), 1 ));
809
+ llama_sampler_softmax_impl (cur_p);
815
810
816
- // remember the order to be able to compute the distance later when accepting the token
817
- ctx->cur .resize (cur_p->size );
818
- for (size_t i = 0 ; i < cur_p->size ; ++i) {
819
- ctx->cur [i] = cur_p->data [i];
820
- }
811
+ const int idx = llama_sample_dist (cur_p, ctx->rng , ctx->probs );
812
+
813
+ cur_p->selected = idx;
814
+
815
+ float observed_surprise = -log2f (cur_p->data [idx].p );
816
+ float e = observed_surprise - ctx->tau ;
817
+
818
+ // Update mu using the learning rate and error
819
+ ctx->mu = ctx->mu - ctx->eta * e;
821
820
},
822
821
/* .reset = */ [](struct llama_sampler * smpl) {
823
822
auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx ;
824
823
ctx->mu = 2 .0f *ctx->tau ;
824
+ ctx->rng = std::mt19937 (ctx->seed );
825
825
},
826
826
/* .clone = */ [](const struct llama_sampler * smpl) {
827
827
const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx ;
828
- return llama_sampler_init_mirostat_impl (*ctx->vocab , ctx->tau , ctx->eta , ctx->m );
828
+ return llama_sampler_init_mirostat_impl (*ctx->vocab , ctx->seed , ctx-> tau , ctx->eta , ctx->m );
829
829
},
830
830
/* .free = */ [](struct llama_sampler * smpl) {
831
831
delete (llama_sampler_context_mirostat *) smpl->ctx ;
832
832
},
833
833
};
834
834
835
- struct llama_sampler * llama_sampler_init_mirostat_impl (const struct llama_vocab & vocab, float tau, float eta, int32_t m) {
835
+ struct llama_sampler * llama_sampler_init_mirostat_impl (const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) {
836
836
return new llama_sampler {
837
837
/* .iface = */ &llama_sampler_mirostat_i,
838
838
/* .ctx = */ new llama_sampler_context_mirostat {
839
839
/* .vocab = */ &vocab,
840
+ /* .seed = */ seed,
840
841
/* .tau = */ tau,
841
842
/* .eta = */ eta,
842
843
/* .m = */ m,
843
844
/* .mu = */ 2 .0f *tau,
844
- /* .cur = */ {},
845
+ /* .rng = */ std::mt19937 (seed),
846
+ /* .probs = */ {},
845
847
},
846
848
};
847
849
}
848
850
849
851
// mirostat v2
850
852
851
853
// State kept by the mirostat v2 sampler between calls.
//
// The members are initialized positionally by llama_sampler_init_mirostat_v2_impl,
// so their order must not change.
struct llama_sampler_context_mirostat_v2 {
    // seed the rng is constructed from, kept so that .reset can re-create the rng
    const uint32_t seed;

    const float tau; // target surprise: the error is (observed surprise - tau)
    const float eta; // learning rate applied to that error when mu is updated

    float mu; // starts at 2*tau and is adjusted after every sampled token

    std::mt19937 rng;

    std::vector<float> probs; // work array handed to llama_sample_dist
};
859
865
860
866
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
861
867
/* .name = */ [](const struct llama_sampler * /* smpl*/ ) { return " mirostat-v2" ; },
862
- /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
863
- auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx ;
864
-
865
- int32_t idx = -1 ;
866
- for (size_t i = 0 ; i < ctx->cur .size (); ++i) {
867
- if (ctx->cur [i].id == token) {
868
- idx = i;
869
- break ;
870
- }
871
- }
872
-
873
- float observed_surprise = -log2f (ctx->cur [idx].p );
874
- float e = observed_surprise - ctx->tau ;
875
-
876
- // Update mu using the learning rate and error
877
- ctx->mu = ctx->mu - ctx->eta * e;
878
- },
868
+ /* .accept = */ nullptr ,
879
869
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
880
870
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx ;
881
871
@@ -893,33 +883,40 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
893
883
// Normalize the probabilities of the remaining words
894
884
llama_sampler_softmax_impl (cur_p);
895
885
896
- // remember the order to be able to compute the distance later when accepting the token
897
- ctx->cur .resize (cur_p->size );
898
- for (size_t i = 0 ; i < cur_p->size ; ++i) {
899
- ctx->cur [i] = cur_p->data [i];
900
- }
886
+ const int idx = llama_sample_dist (cur_p, ctx->rng , ctx->probs );
887
+
888
+ cur_p->selected = idx;
889
+
890
+ float observed_surprise = -log2f (cur_p->data [idx].p );
891
+ float e = observed_surprise - ctx->tau ;
892
+
893
+ // Update mu using the learning rate and error
894
+ ctx->mu = ctx->mu - ctx->eta * e;
901
895
},
902
896
/* .reset = */ [](struct llama_sampler * smpl) {
903
897
auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx ;
904
898
ctx->mu = 2 .0f *ctx->tau ;
899
+ ctx->rng = std::mt19937 (ctx->seed );
905
900
},
906
901
/* .clone = */ [](const struct llama_sampler * smpl) {
907
902
const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx ;
908
- return llama_sampler_init_mirostat_v2_impl (ctx->tau , ctx->eta );
903
+ return llama_sampler_init_mirostat_v2_impl (ctx->seed , ctx-> tau , ctx->eta );
909
904
},
910
905
/* .free = */ [](struct llama_sampler * smpl) {
911
906
delete (llama_sampler_context_mirostat_v2 *) smpl->ctx ;
912
907
},
913
908
};
914
909
915
- struct llama_sampler * llama_sampler_init_mirostat_v2_impl (float tau, float eta) {
910
+ struct llama_sampler * llama_sampler_init_mirostat_v2_impl (uint32_t seed, float tau, float eta) {
916
911
return new llama_sampler {
917
912
/* .iface = */ &llama_sampler_mirostat_v2_i,
918
913
/* .ctx = */ new llama_sampler_context_mirostat_v2 {
919
- /* .tau = */ tau,
920
- /* .eta = */ eta,
921
- /* .mu = */ 2 .0f *tau,
922
- /* .cur = */ {},
914
+ /* .seed = */ seed,
915
+ /* .tau = */ tau,
916
+ /* .eta = */ eta,
917
+ /* .mu = */ 2 .0f *tau,
918
+ /* .rng = */ std::mt19937 (seed),
919
+ /* .probs = */ {},
923
920
},
924
921
};
925
922
}
@@ -1154,9 +1151,15 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl(
1154
1151
1155
1152
static struct llama_sampler_i llama_sampler_chain_i = {
1156
1153
/* .name = */ [](const struct llama_sampler * /* smpl*/ ) { return " chain" ; },
1157
- /* .accept = */ [](struct llama_sampler * smpl, llama_token /* token*/ ) {
1154
+ /* .accept = */ [](struct llama_sampler * smpl, llama_token token) {
1158
1155
auto * chain = (llama_sampler_chain *) smpl->ctx ;
1159
1156
1157
+ time_meas tm (chain->t_sample_us , chain->params .no_timing );
1158
+
1159
+ for (auto * smpl : chain->samplers ) {
1160
+ llama_sampler_accept_impl (*smpl, token);
1161
+ }
1162
+
1160
1163
chain->n_sample ++;
1161
1164
},
1162
1165
/* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
0 commit comments