Skip to content

Commit d7679c9

Browse files
committed
I don't know what I'm doing, but it's working better now
1 parent 26d44ab commit d7679c9

File tree

1 file changed

+28
-17
lines changed

1 file changed

+28
-17
lines changed

conditioner.hpp

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,20 +1087,28 @@ struct FluxCLIPEmbedder : public Conditioner {
10871087
int64_t t0 = ggml_time_ms();
10881088
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096]
10891089
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token*2, 4096]
1090-
struct ggml_tensor* chunk_hidden_states_l = NULL; // [n_token, hidden_size_l]
1091-
struct ggml_tensor* chunk_hidden_states_t5 = NULL; // [n_token, hidden_size_t5]
10921090
struct ggml_tensor* pooled = NULL; // [768,]
10931091
std::vector<float> hidden_states_vec;
10941092

1095-
size_t chunk_len = 77;
1096-
size_t chunk_count = clip_l_tokens.size() / chunk_len;
1093+
size_t chunk_len_l = 77;
1094+
size_t chunk_count_l = clip_l_tokens.size() / chunk_len_l;
1095+
1096+
size_t chunk_len_t5 = 256;
1097+
size_t chunk_count_t5 = t5_tokens.size() / chunk_len_t5;
1098+
1099+
// TODO: I believe chunk_count_l is actually bigger than chunk_count_t5
1100+
// So this ignores some tokens for clip
1101+
size_t chunk_count = chunk_count_t5;
1102+
10971103
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
1104+
struct ggml_tensor* chunk_hidden_states_l = NULL; // [n_token, hidden_size_l]
1105+
struct ggml_tensor* chunk_hidden_states_t5 = NULL; // [n_token, hidden_size_t5]
10981106
// clip_l
1099-
{
1100-
std::vector<int> chunk_tokens(clip_l_tokens.begin() + chunk_idx * chunk_len,
1101-
clip_l_tokens.begin() + (chunk_idx + 1) * chunk_len);
1102-
std::vector<float> chunk_weights(clip_l_weights.begin() + chunk_idx * chunk_len,
1103-
clip_l_weights.begin() + (chunk_idx + 1) * chunk_len);
1107+
if(chunk_idx < chunk_count_l) {
1108+
std::vector<int> chunk_tokens(clip_l_tokens.begin() + chunk_idx * chunk_len_l,
1109+
clip_l_tokens.begin() + (chunk_idx + 1) * chunk_len_l);
1110+
std::vector<float> chunk_weights(clip_l_weights.begin() + chunk_idx * chunk_len_l,
1111+
clip_l_weights.begin() + (chunk_idx + 1) * chunk_len_l);
11041112

11051113
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
11061114
size_t max_token_idx = 0;
@@ -1129,7 +1137,6 @@ struct FluxCLIPEmbedder : public Conditioner {
11291137
ggml_tensor_scale(tensor, (original_mean / new_mean));
11301138
}
11311139
if (chunk_idx == 0) {
1132-
size_t chunk_len_l = 77;
11331140
std::vector<int> chunk_tokens(clip_l_tokens.begin(),
11341141
clip_l_tokens.begin() + chunk_len_l);
11351142
std::vector<float> chunk_weights(clip_l_weights.begin(),
@@ -1157,11 +1164,11 @@ struct FluxCLIPEmbedder : public Conditioner {
11571164
}
11581165

11591166
// t5
1160-
{
1161-
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
1162-
t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
1163-
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
1164-
t5_weights.begin() + (chunk_idx + 1) * chunk_len);
1167+
if(chunk_idx < chunk_count_t5) {
1168+
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len_t5,
1169+
t5_tokens.begin() + (chunk_idx + 1) * chunk_len_t5);
1170+
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len_t5,
1171+
t5_weights.begin() + (chunk_idx + 1) * chunk_len_t5);
11651172

11661173
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
11671174

@@ -1205,8 +1212,12 @@ struct FluxCLIPEmbedder : public Conditioner {
12051212
}
12061213
}
12071214
}
1208-
1209-
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states_l_pad, chunk_hidden_states_t5, 1); // [n_token*2, 4096]
1215+
1216+
if(chunk_hidden_states_t5 == NULL){
1217+
chunk_hidden_states = chunk_hidden_states_l_pad;
1218+
} else {
1219+
chunk_hidden_states = ggml_tensor_concat(work_ctx, chunk_hidden_states_l_pad, chunk_hidden_states_t5, 1); // [n_token*2, 4096]
1220+
}
12101221

12111222

12121223
int64_t t1 = ggml_time_ms();

0 commit comments

Comments
 (0)