Skip to content

Commit d525163

Browse files
idostyle authored and ido committed
Support DistillT5
1 parent 10c6501 commit d525163

File tree

1 file changed

+41
-14
lines changed

1 file changed

+41
-14
lines changed

t5.hpp

+41-14
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,25 @@ struct T5Block : public GGMLBlock {
648648
}
649649
};
650650

651+
struct T5Projection : public UnaryBlock {
652+
public:
653+
T5Projection(int64_t model_dim, int64_t projection_dim) {
654+
blocks["0"] = std::shared_ptr<GGMLBlock>(new Linear(model_dim, projection_dim, false));
655+
blocks["3"] = std::shared_ptr<GGMLBlock>(new Linear(projection_dim, projection_dim, false));
656+
}
657+
658+
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
659+
// x: [N, n_token, model_dim]
660+
auto wi = std::dynamic_pointer_cast<Linear>(blocks["0"]);
661+
auto wo = std::dynamic_pointer_cast<Linear>(blocks["3"]);
662+
663+
x = wi->forward(ctx, x);
664+
x = ggml_relu_inplace(ctx, x);
665+
x = wo->forward(ctx, x);
666+
return x;
667+
}
668+
};
669+
651670
struct T5Stack : public GGMLBlock {
652671
int64_t num_layers;
653672

@@ -682,6 +701,7 @@ struct T5Stack : public GGMLBlock {
682701
auto final_layer_norm = std::dynamic_pointer_cast<T5LayerNorm>(blocks["final_layer_norm"]);
683702

684703
x = final_layer_norm->forward(ctx, x);
704+
685705
return x;
686706
}
687707
};
@@ -692,9 +712,11 @@ struct T5 : public GGMLBlock {
692712
int64_t model_dim,
693713
int64_t ff_dim,
694714
int64_t num_heads,
695-
int64_t vocab_size) {
715+
int64_t vocab_size,
716+
int64_t projection_dim) {
696717
blocks["encoder"] = std::shared_ptr<GGMLBlock>(new T5Stack(num_layers, model_dim, model_dim, ff_dim, num_heads));
697718
blocks["shared"] = std::shared_ptr<GGMLBlock>(new Embedding(vocab_size, model_dim));
719+
blocks["final_projection"] = std::shared_ptr<GGMLBlock>(new T5Projection(model_dim, projection_dim));
698720
}
699721

700722
struct ggml_tensor* forward(struct ggml_context* ctx,
@@ -709,6 +731,9 @@ struct T5 : public GGMLBlock {
709731

710732
auto x = shared->forward(ctx, input_ids);
711733
x = encoder->forward(ctx, x, past_bias, attention_mask, relative_position_bucket);
734+
735+
auto final_projection = std::dynamic_pointer_cast<T5Projection>(blocks["final_projection"]);
736+
x = final_projection->forward(ctx, x);
712737
return x;
713738
}
714739
};
@@ -720,12 +745,13 @@ struct T5Runner : public GGMLRunner {
720745
// Builds the T5 model on the given backend and initializes its parameter
// tensors under `prefix`.
// NOTE(review): the defaults (12 layers, 768 hidden, 2048 FF, 12 heads)
// look like a T5-base-sized encoder projected up to 4096 for DistillT5 —
// confirm against the intended checkpoint before relying on them.
T5Runner(ggml_backend_t backend,
         std::map<std::string, enum ggml_type>& tensor_types,
         const std::string prefix,
         int64_t num_layers     = 12,
         int64_t model_dim      = 768,
         int64_t ff_dim         = 2048,
         int64_t num_heads      = 12,
         int64_t vocab_size     = 32128,
         int64_t projection_dim = 4096)
    : GGMLRunner(backend), model(num_layers, model_dim, ff_dim, num_heads, vocab_size, projection_dim) {
    model.init(params_ctx, tensor_types, prefix);
}
731757

@@ -861,12 +887,13 @@ struct T5Embedder {
861887
// Thin constructor: forwards every hyperparameter straight to the wrapped
// T5Runner. Defaults match T5Runner's (DistillT5-sized encoder with a
// 4096-dim final projection).
T5Embedder(ggml_backend_t backend,
           std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
           const std::string prefix                            = "",
           int64_t num_layers                                  = 12,
           int64_t model_dim                                   = 768,
           int64_t ff_dim                                      = 2048,
           int64_t num_heads                                   = 12,
           int64_t vocab_size                                  = 32128,
           int64_t projection_dim                              = 4096)
    : model(backend, tensor_types, prefix, num_layers, model_dim, ff_dim, num_heads, vocab_size, projection_dim) {
}
871898

872899
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
@@ -983,4 +1010,4 @@ struct T5Embedder {
9831010
}
9841011
};
9851012

986-
#endif // __T5_HPP__
1013+
#endif // __T5_HPP__

0 commit comments

Comments
 (0)