@@ -204,7 +204,7 @@ namespace Flux {
204
204
// return: [ModulationOut, ModulationOut]
205
205
auto lin = std::dynamic_pointer_cast<Linear>(blocks[" lin" ]);
206
206
207
- auto out = ggml_silu (ctx, vec);
207
+ auto out = ggml_silu_inplace (ctx, vec);
208
208
out = lin->forward (ctx, out); // [N, multiplier*dim]
209
209
210
210
auto m = ggml_reshape_3d (ctx, out, vec->ne [0 ], multiplier, vec->ne [1 ]); // [N, multiplier, dim]
@@ -235,8 +235,8 @@ namespace Flux {
235
235
// shift: [N, C]
236
236
scale = ggml_reshape_3d (ctx, scale, scale->ne [0 ], 1 , scale->ne [1 ]); // [N, 1, C]
237
237
shift = ggml_reshape_3d (ctx, shift, shift->ne [0 ], 1 , shift->ne [1 ]); // [N, 1, C]
238
- x = ggml_add (ctx, x, ggml_mul (ctx, x, scale));
239
- x = ggml_add (ctx, x, shift);
238
+ x = ggml_add_inplace (ctx, x, ggml_mul (ctx, x, scale));
239
+ x = ggml_add_inplace (ctx, x, shift);
240
240
return x;
241
241
}
242
242
@@ -346,22 +346,22 @@ namespace Flux {
346
346
img_attn_out = ggml_cont (ctx, ggml_permute (ctx, img_attn_out, 0 , 2 , 1 , 3 )); // [N, n_img_token, hidden_size]
347
347
348
348
// calculate the img blocks
349
- img = ggml_add (ctx, img, ggml_mul (ctx, img_attn->post_attention (ctx, img_attn_out), img_mod1.gate ));
349
+ img = ggml_add_inplace (ctx, img, ggml_mul (ctx, img_attn->post_attention (ctx, img_attn_out), img_mod1.gate ));
350
350
351
351
auto img_mlp_out = img_mlp_0->forward (ctx, Flux::modulate (ctx, img_norm2->forward (ctx, img), img_mod2.shift , img_mod2.scale ));
352
352
img_mlp_out = ggml_gelu_inplace (ctx, img_mlp_out);
353
353
img_mlp_out = img_mlp_2->forward (ctx, img_mlp_out);
354
354
355
- img = ggml_add (ctx, img, ggml_mul (ctx, img_mlp_out, img_mod2.gate ));
355
+ img = ggml_add_inplace (ctx, img, ggml_mul (ctx, img_mlp_out, img_mod2.gate ));
356
356
357
357
// calculate the txt blocks
358
- txt = ggml_add (ctx, txt, ggml_mul (ctx, txt_attn->post_attention (ctx, txt_attn_out), txt_mod1.gate ));
358
+ txt = ggml_add_inplace (ctx, txt, ggml_mul (ctx, txt_attn->post_attention (ctx, txt_attn_out), txt_mod1.gate ));
359
359
360
360
auto txt_mlp_out = txt_mlp_0->forward (ctx, Flux::modulate (ctx, txt_norm2->forward (ctx, txt), txt_mod2.shift , txt_mod2.scale ));
361
361
txt_mlp_out = ggml_gelu_inplace (ctx, txt_mlp_out);
362
362
txt_mlp_out = txt_mlp_2->forward (ctx, txt_mlp_out);
363
363
364
- txt = ggml_add (ctx, txt, ggml_mul (ctx, txt_mlp_out, txt_mod2.gate ));
364
+ txt = ggml_add_inplace (ctx, txt, ggml_mul (ctx, txt_mlp_out, txt_mod2.gate ));
365
365
366
366
return {img, txt};
367
367
}
@@ -448,7 +448,7 @@ namespace Flux {
448
448
auto attn_mlp = ggml_concat (ctx, attn, ggml_gelu_inplace (ctx, mlp), 0 ); // [N, n_token, hidden_size + mlp_hidden_dim]
449
449
auto output = linear2->forward (ctx, attn_mlp); // [N, n_token, hidden_size]
450
450
451
- output = ggml_add (ctx, x, ggml_mul (ctx, output, mod.gate ));
451
+ output = ggml_add_inplace (ctx, x, ggml_mul (ctx, output, mod.gate ));
452
452
return output;
453
453
}
454
454
};
@@ -473,7 +473,7 @@ namespace Flux {
473
473
auto linear = std::dynamic_pointer_cast<Linear>(blocks[" linear" ]);
474
474
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks[" adaLN_modulation.1" ]);
475
475
476
- auto m = adaLN_modulation_1->forward (ctx, ggml_silu (ctx, c)); // [N, 2 * hidden_size]
476
+ auto m = adaLN_modulation_1->forward (ctx, ggml_silu_inplace (ctx, c)); // [N, 2 * hidden_size]
477
477
m = ggml_reshape_3d (ctx, m, c->ne [0 ], 2 , c->ne [1 ]); // [N, 2, hidden_size]
478
478
m = ggml_cont (ctx, ggml_permute (ctx, m, 0 , 2 , 1 , 3 )); // [2, N, hidden_size]
479
479
@@ -741,10 +741,10 @@ namespace Flux {
741
741
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks[" guidance_in" ]);
742
742
// bf16 and fp16 result is different
743
743
auto g_in = ggml_nn_timestep_embedding (ctx, guidance, 256 , 10000 , 1000 .f );
744
- vec = ggml_add (ctx, vec, guidance_in->forward (ctx, g_in));
744
+ vec = ggml_add_inplace (ctx, vec, guidance_in->forward (ctx, g_in));
745
745
}
746
746
747
- vec = ggml_add (ctx, vec, vector_in->forward (ctx, y));
747
+ vec = ggml_add_inplace (ctx, vec, vector_in->forward (ctx, y));
748
748
txt = txt_in->forward (ctx, txt);
749
749
750
750
for (int i = 0 ; i < params.depth ; i++) {
0 commit comments