Skip to content

Commit b799c92

Browse files
authored
Update flux.hpp _inplace
1 parent 8914625 commit b799c92

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

flux.hpp

+11-11
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ namespace Flux {
204204
// return: [ModulationOut, ModulationOut]
205205
auto lin = std::dynamic_pointer_cast<Linear>(blocks["lin"]);
206206

207-
auto out = ggml_silu(ctx, vec);
207+
auto out = ggml_silu_inplace(ctx, vec);
208208
out = lin->forward(ctx, out); // [N, multiplier*dim]
209209

210210
auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim]
@@ -235,8 +235,8 @@ namespace Flux {
235235
// shift: [N, C]
236236
scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C]
237237
shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C]
238-
x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
239-
x = ggml_add(ctx, x, shift);
238+
x = ggml_add_inplace(ctx, x, ggml_mul(ctx, x, scale));
239+
x = ggml_add_inplace(ctx, x, shift);
240240
return x;
241241
}
242242

@@ -346,22 +346,22 @@ namespace Flux {
346346
img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
347347

348348
// calculate the img blocks
349-
img = ggml_add(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
349+
img = ggml_add_inplace(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
350350

351351
auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
352352
img_mlp_out = ggml_gelu_inplace(ctx, img_mlp_out);
353353
img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);
354354

355-
img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate));
355+
img = ggml_add_inplace(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate));
356356

357357
// calculate the txt blocks
358-
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
358+
txt = ggml_add_inplace(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
359359

360360
auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
361361
txt_mlp_out = ggml_gelu_inplace(ctx, txt_mlp_out);
362362
txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);
363363

364-
txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate));
364+
txt = ggml_add_inplace(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate));
365365

366366
return {img, txt};
367367
}
@@ -448,7 +448,7 @@ namespace Flux {
448448
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
449449
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
450450

451-
output = ggml_add(ctx, x, ggml_mul(ctx, output, mod.gate));
451+
output = ggml_add_inplace(ctx, x, ggml_mul(ctx, output, mod.gate));
452452
return output;
453453
}
454454
};
@@ -473,7 +473,7 @@ namespace Flux {
473473
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
474474
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
475475

476-
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
476+
auto m = adaLN_modulation_1->forward(ctx, ggml_silu_inplace(ctx, c)); // [N, 2 * hidden_size]
477477
m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
478478
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
479479

@@ -741,10 +741,10 @@ namespace Flux {
741741
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
742742
// bf16 and fp16 results differ
743743
auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
744-
vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
744+
vec = ggml_add_inplace(ctx, vec, guidance_in->forward(ctx, g_in));
745745
}
746746

747-
vec = ggml_add(ctx, vec, vector_in->forward(ctx, y));
747+
vec = ggml_add_inplace(ctx, vec, vector_in->forward(ctx, y));
748748
txt = txt_in->forward(ctx, txt);
749749

750750
for (int i = 0; i < params.depth; i++) {

0 commit comments

Comments
 (0)