feat: introduce GGMLBlock and implement SVD(Broken) (#159)

* introduce GGMLBlock and implement SVD(Broken)

* add sdxl vae warning
update_ggml master-b636886
leejet 2024-02-24 20:06:39 +08:00 committed by GitHub
parent 349439f239
commit b6368868d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 4137 additions and 3818 deletions

View File

@ -60,7 +60,8 @@ endif()
set(SD_LIB stable-diffusion)
add_library(${SD_LIB} stable-diffusion.h stable-diffusion.cpp model.h model.cpp util.h util.cpp upscaler.cpp
ggml_extend.hpp clip.hpp common.hpp unet.hpp tae.hpp esrgan.hpp lora.hpp denoiser.hpp rng.hpp rng_philox.hpp)
ggml_extend.hpp clip.hpp common.hpp unet.hpp tae.hpp esrgan.hpp lora.hpp denoiser.hpp rng.hpp rng_philox.hpp
control.hpp preprocessing.hpp)
if(BUILD_SHARED_LIBS)
message("Build shared library")

View File

@ -329,6 +329,7 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
- [stable-diffusion](https://github.com/CompVis/stable-diffusion)
- [stable-diffusion-stability-ai](https://github.com/Stability-AI/stablediffusion)
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
- [ComfyUI](https://github.com/comfyanonymous/ComfyUI)
- [k-diffusion](https://github.com/crowsonkb/k-diffusion)
- [latent-consistency-model](https://github.com/luosiallen/latent-consistency-model)
- [generative-models](https://github.com/Stability-AI/generative-models/)

1024
clip.hpp

File diff suppressed because it is too large Load Diff

View File

@ -3,541 +3,527 @@
#include "ggml_extend.hpp"
struct DownSample {
// hparams
class DownSampleBlock : public GGMLBlock {
protected:
int channels;
int out_channels;
bool vae_downsample;
// conv2d params
struct ggml_tensor* op_w; // [out_channels, channels, 3, 3]
struct ggml_tensor* op_b; // [out_channels,]
bool vae_downsample = false;
size_t calculate_mem_size(ggml_type wtype) {
size_t mem_size = 0;
mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 3 * 3); // op_w
mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // op_b
return mem_size;
}
void init_params(struct ggml_context* ctx, ggml_type wtype) {
op_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, out_channels);
op_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
}
void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
public:
DownSampleBlock(int channels,
int out_channels,
bool vae_downsample = false)
: channels(channels),
out_channels(out_channels),
vae_downsample(vae_downsample) {
if (vae_downsample) {
tensors[prefix + "conv.weight"] = op_w;
tensors[prefix + "conv.bias"] = op_b;
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {0, 0}));
} else {
tensors[prefix + "op.weight"] = op_w;
tensors[prefix + "op.bias"] = op_b;
blocks["op"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {2, 2}, {1, 1}));
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
struct ggml_tensor* c = NULL;
if (vae_downsample) {
c = ggml_pad(ctx, x, 1, 1, 0, 0);
c = ggml_nn_conv_2d(ctx, c, op_w, op_b, 2, 2, 0, 0);
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
x = ggml_pad(ctx, x, 1, 1, 0, 0);
x = conv->forward(ctx, x);
} else {
c = ggml_nn_conv_2d(ctx, x, op_w, op_b, 2, 2, 1, 1);
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["op"]);
x = conv->forward(ctx, x);
}
return c; // [N, out_channels, h/2, w/2]
return x; // [N, out_channels, h/2, w/2]
}
};
struct UpSample {
// hparams
class UpSampleBlock : public GGMLBlock {
protected:
int channels;
int out_channels;
// conv2d params
struct ggml_tensor* conv_w; // [out_channels, channels, 3, 3]
struct ggml_tensor* conv_b; // [out_channels,]
size_t calculate_mem_size(ggml_type wtype) {
size_t mem_size = 0;
mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 3 * 3); // op_w
mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // op_b
return mem_size;
}
void init_params(struct ggml_context* ctx, ggml_type wtype) {
conv_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, out_channels);
conv_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
}
void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
tensors[prefix + "conv.weight"] = conv_w;
tensors[prefix + "conv.bias"] = conv_b;
public:
UpSampleBlock(int channels,
int out_channels)
: channels(channels),
out_channels(out_channels) {
blocks["conv"] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2]
x = ggml_nn_conv_2d(ctx, x, conv_w, conv_b, 1, 1, 1, 1); // [N, out_channels, h*2, w*2]
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2]
x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2]
return x;
}
};
struct ResBlock {
class ResBlock : public GGMLBlock {
protected:
// network hparams
int channels; // model_channels * (1, 1, 1, 2, 2, 4, 4, 4)
int emb_channels; // time_embed_dim
int out_channels; // mult * model_channels
int64_t channels; // model_channels * (1, 1, 1, 2, 2, 4, 4, 4)
int64_t emb_channels; // time_embed_dim
int64_t out_channels; // mult * model_channels
std::pair<int, int> kernel_size;
int dims;
bool skip_t_emb;
bool exchange_temb_dims;
// network params
// in_layers
struct ggml_tensor* in_layer_0_w; // [channels, ]
struct ggml_tensor* in_layer_0_b; // [channels, ]
// in_layer_1 is nn.SILU()
struct ggml_tensor* in_layer_2_w; // [out_channels, channels, 3, 3]
struct ggml_tensor* in_layer_2_b; // [out_channels, ]
// emb_layers
// emb_layer_0 is nn.SILU()
struct ggml_tensor* emb_layer_1_w; // [out_channels, emb_channels]
struct ggml_tensor* emb_layer_1_b; // [out_channels, ]
// out_layers
struct ggml_tensor* out_layer_0_w; // [out_channels, ]
struct ggml_tensor* out_layer_0_b; // [out_channels, ]
// out_layer_1 is nn.SILU()
// out_layer_2 is nn.Dropout(), p = 0 for inference
struct ggml_tensor* out_layer_3_w; // [out_channels, out_channels, 3, 3]
struct ggml_tensor* out_layer_3_b; // [out_channels, ]
// skip connection, only if out_channels != channels
struct ggml_tensor* skip_w; // [out_channels, channels, 1, 1]
struct ggml_tensor* skip_b; // [out_channels, ]
size_t calculate_mem_size(ggml_type wtype) {
size_t mem_size = 0;
mem_size += 2 * ggml_row_size(GGML_TYPE_F32, channels); // in_layer_0_w/b
mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 3 * 3); // in_layer_2_w
mem_size += 5 * ggml_row_size(GGML_TYPE_F32, out_channels); // in_layer_2_b/emb_layer_1_b/out_layer_0_w/out_layer_0_b/out_layer_3_b
mem_size += ggml_row_size(wtype, out_channels * emb_channels); // emb_layer_1_w
mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * out_channels * 3 * 3); // out_layer_3_w
if (out_channels != channels) {
mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 1 * 1); // skip_w
mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // skip_b
}
return mem_size;
}
void init_params(struct ggml_context* ctx, ggml_type wtype) {
in_layer_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels);
in_layer_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels);
in_layer_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, out_channels);
in_layer_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
emb_layer_1_w = ggml_new_tensor_2d(ctx, wtype, emb_channels, out_channels);
emb_layer_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
out_layer_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
out_layer_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
out_layer_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, out_channels);
out_layer_3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
if (out_channels != channels) {
skip_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, channels, out_channels);
skip_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
std::shared_ptr<GGMLBlock> conv_nd(int dims,
int64_t in_channels,
int64_t out_channels,
std::pair<int, int> kernel_size,
std::pair<int, int> padding) {
GGML_ASSERT(dims == 2 || dims == 3);
if (dims == 3) {
return std::shared_ptr<GGMLBlock>(new Conv3dnx1x1(in_channels, out_channels, kernel_size.first, 1, padding.first));
} else {
return std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, out_channels, kernel_size, {1, 1}, padding));
}
}
void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
tensors[prefix + "in_layers.0.weight"] = in_layer_0_w;
tensors[prefix + "in_layers.0.bias"] = in_layer_0_b;
tensors[prefix + "in_layers.2.weight"] = in_layer_2_w;
tensors[prefix + "in_layers.2.bias"] = in_layer_2_b;
public:
ResBlock(int64_t channels,
int64_t emb_channels,
int64_t out_channels,
std::pair<int, int> kernel_size = {3, 3},
int dims = 2,
bool exchange_temb_dims = false,
bool skip_t_emb = false)
: channels(channels),
emb_channels(emb_channels),
out_channels(out_channels),
kernel_size(kernel_size),
dims(dims),
skip_t_emb(skip_t_emb),
exchange_temb_dims(exchange_temb_dims) {
std::pair<int, int> padding = {kernel_size.first / 2, kernel_size.second / 2};
blocks["in_layers.0"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(channels));
// in_layer_1 is nn.SILU()
blocks["in_layers.2"] = conv_nd(dims, channels, out_channels, kernel_size, padding);
tensors[prefix + "emb_layers.1.weight"] = emb_layer_1_w;
tensors[prefix + "emb_layers.1.bias"] = emb_layer_1_b;
if (!skip_t_emb) {
// emb_layer_0 is nn.SILU()
blocks["emb_layers.1"] = std::shared_ptr<GGMLBlock>(new Linear(emb_channels, out_channels));
}
tensors[prefix + "out_layers.0.weight"] = out_layer_0_w;
tensors[prefix + "out_layers.0.bias"] = out_layer_0_b;
tensors[prefix + "out_layers.3.weight"] = out_layer_3_w;
tensors[prefix + "out_layers.3.bias"] = out_layer_3_b;
blocks["out_layers.0"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(out_channels));
// out_layer_1 is nn.SILU()
// out_layer_2 is nn.Dropout(), skip for inference
blocks["out_layers.3"] = conv_nd(dims, out_channels, out_channels, kernel_size, padding);
if (out_channels != channels) {
tensors[prefix + "skip_connection.weight"] = skip_w;
tensors[prefix + "skip_connection.bias"] = skip_b;
blocks["skip_connection"] = conv_nd(dims, channels, out_channels, {1, 1}, {0, 0});
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb) {
// x: [N, channels, h, w]
// emb: [N, emb_channels]
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb = NULL) {
// For dims==3, we reduce dimension from 5d to 4d by merging h and w, in order not to change ggml
// [N, c, t, h, w] => [N, c, t, h * w]
// x: [N, channels, h, w] if dims == 2 else [N, channels, t, h, w]
// emb: [N, emb_channels] if dims == 2 else [N, t, emb_channels]
auto in_layers_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["in_layers.0"]);
auto in_layers_2 = std::dynamic_pointer_cast<UnaryBlock>(blocks["in_layers.2"]);
auto out_layers_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["out_layers.0"]);
auto out_layers_3 = std::dynamic_pointer_cast<UnaryBlock>(blocks["out_layers.3"]);
if (emb == NULL) {
GGML_ASSERT(skip_t_emb);
}
// in_layers
auto h = ggml_nn_group_norm(ctx, x, in_layer_0_w, in_layer_0_b);
auto h = in_layers_0->forward(ctx, x);
h = ggml_silu_inplace(ctx, h);
h = ggml_nn_conv_2d(ctx, h, in_layer_2_w, in_layer_2_b, 1, 1, 1, 1); // [N, out_channels, h, w]
h = in_layers_2->forward(ctx, h); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
// emb_layers
auto emb_out = ggml_silu(ctx, emb);
emb_out = ggml_nn_linear(ctx, emb_out, emb_layer_1_w, emb_layer_1_b); // [N, out_channels]
emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1]
if (!skip_t_emb) {
auto emb_layer_1 = std::dynamic_pointer_cast<Linear>(blocks["emb_layers.1"]);
auto emb_out = ggml_silu(ctx, emb);
emb_out = emb_layer_1->forward(ctx, emb_out); // [N, out_channels] if dims == 2 else [N, t, out_channels]
if (dims == 2) {
emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1]
} else {
emb_out = ggml_reshape_4d(ctx, emb_out, 1, emb_out->ne[0], emb_out->ne[1], emb_out->ne[2]); // [N, t, out_channels, 1]
if (exchange_temb_dims) {
// emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
emb_out = ggml_cont(ctx, ggml_permute(ctx, emb_out, 0, 2, 1, 3)); // [N, out_channels, t, 1]
}
}
h = ggml_add(ctx, h, emb_out); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
}
// out_layers
h = ggml_add(ctx, h, emb_out);
h = ggml_nn_group_norm(ctx, h, out_layer_0_w, out_layer_0_b);
h = out_layers_0->forward(ctx, h);
h = ggml_silu_inplace(ctx, h);
// dropout, skip for inference
h = ggml_nn_conv_2d(ctx, h, out_layer_3_w, out_layer_3_b, 1, 1, 1, 1); // [N, out_channels, h, w]
h = out_layers_3->forward(ctx, h);
// skip connection
if (out_channels != channels) {
x = ggml_nn_conv_2d(ctx, x, skip_w, skip_b); // [N, out_channels, h, w]
auto skip_connection = std::dynamic_pointer_cast<UnaryBlock>(blocks["skip_connection"]);
x = skip_connection->forward(ctx, x); // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
}
h = ggml_add(ctx, h, x);
return h; // [N, out_channels, h, w]
return h; // [N, out_channels, h, w] if dims == 2 else [N, out_channels, t, h, w]
}
};
struct SpatialTransformer {
int in_channels; // mult * model_channels
int n_head; // num_heads
int d_head; // in_channels // n_heads
int depth = 1; // 1
int context_dim = 768; // hidden_size, 1024 for VERSION_2_x
class GEGLU : public GGMLBlock {
protected:
int64_t dim_in;
int64_t dim_out;
// group norm
struct ggml_tensor* norm_w; // [in_channels,]
struct ggml_tensor* norm_b; // [in_channels,]
// proj_in
struct ggml_tensor* proj_in_w; // [in_channels, in_channels, 1, 1]
struct ggml_tensor* proj_in_b; // [in_channels,]
// transformer
struct Transformer {
// layer norm 1
struct ggml_tensor* norm1_w; // [in_channels, ]
struct ggml_tensor* norm1_b; // [in_channels, ]
// attn1
struct ggml_tensor* attn1_q_w; // [in_channels, in_channels]
struct ggml_tensor* attn1_k_w; // [in_channels, in_channels]
struct ggml_tensor* attn1_v_w; // [in_channels, in_channels]
struct ggml_tensor* attn1_out_w; // [in_channels, in_channels]
struct ggml_tensor* attn1_out_b; // [in_channels, ]
// layer norm 2
struct ggml_tensor* norm2_w; // [in_channels, ]
struct ggml_tensor* norm2_b; // [in_channels, ]
// attn2
struct ggml_tensor* attn2_q_w; // [in_channels, in_channels]
struct ggml_tensor* attn2_k_w; // [in_channels, context_dim]
struct ggml_tensor* attn2_v_w; // [in_channels, context_dim]
struct ggml_tensor* attn2_out_w; // [in_channels, in_channels]
struct ggml_tensor* attn2_out_b; // [in_channels, ]
// layer norm 3
struct ggml_tensor* norm3_w; // [in_channels, ]
struct ggml_tensor* norm3_b; // [in_channels, ]
// ff
struct ggml_tensor* ff_0_proj_w; // [in_channels * 4 * 2, in_channels]
struct ggml_tensor* ff_0_proj_b; // [in_channels * 4 * 2]
struct ggml_tensor* ff_2_w; // [in_channels, in_channels * 4]
struct ggml_tensor* ff_2_b; // [in_channels,]
};
std::vector<Transformer> transformers;
// proj_out
struct ggml_tensor* proj_out_w; // [in_channels, in_channels, 1, 1]
struct ggml_tensor* proj_out_b; // [in_channels,]
SpatialTransformer(int depth = 1)
: depth(depth) {
transformers.resize(depth);
void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
params["proj.bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim_out * 2);
}
int get_num_tensors() {
return depth * 20 + 7;
public:
GEGLU(int64_t dim_in, int64_t dim_out)
: dim_in(dim_in), dim_out(dim_out) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [ne3, ne2, ne1, dim_in]
// return: [ne3, ne2, ne1, dim_out]
struct ggml_tensor* w = params["proj.weight"];
struct ggml_tensor* b = params["proj.bias"];
auto x_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in]
auto x_b = ggml_view_1d(ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in]
auto gate_w = ggml_view_2d(ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ]
auto gate_b = ggml_view_1d(ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ]
auto x_in = x;
x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out]
auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]
gate = ggml_gelu_inplace(ctx, gate);
x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, dim_out]
return x;
}
};
class FeedForward : public GGMLBlock {
public:
FeedForward(int64_t dim,
int64_t dim_out,
int64_t mult = 4) {
int64_t inner_dim = dim * mult;
blocks["net.0"] = std::shared_ptr<GGMLBlock>(new GEGLU(dim, inner_dim));
// net_1 is nn.Dropout(), skip for inference
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out));
}
size_t calculate_mem_size(ggml_type wtype) {
size_t mem_size = 0;
mem_size += 2 * ggml_row_size(GGML_TYPE_F32, in_channels); // norm_w/norm_b
mem_size += 2 * ggml_row_size(GGML_TYPE_F16, in_channels * in_channels * 1 * 1); // proj_in_w/proj_out_w
mem_size += 2 * ggml_row_size(GGML_TYPE_F32, in_channels); // proj_in_b/proj_out_b
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [ne3, ne2, ne1, dim]
// return: [ne3, ne2, ne1, dim_out]
// transformer
for (auto& transformer : transformers) {
mem_size += 6 * ggml_row_size(GGML_TYPE_F32, in_channels); // norm1-3_w/b
mem_size += 6 * ggml_row_size(wtype, in_channels * in_channels); // attn1_q/k/v/out_w attn2_q/out_w
mem_size += 2 * ggml_row_size(wtype, in_channels * context_dim); // attn2_k/v_w
mem_size += ggml_row_size(wtype, in_channels * 4 * 2 * in_channels); // ff_0_proj_w
mem_size += ggml_row_size(GGML_TYPE_F32, in_channels * 4 * 2); // ff_0_proj_b
mem_size += ggml_row_size(wtype, in_channels * 4 * in_channels); // ff_2_w
mem_size += ggml_row_size(GGML_TYPE_F32, in_channels); // ff_2_b
}
return mem_size;
auto net_0 = std::dynamic_pointer_cast<GEGLU>(blocks["net.0"]);
auto net_2 = std::dynamic_pointer_cast<Linear>(blocks["net.2"]);
x = net_0->forward(ctx, x); // [ne3, ne2, ne1, inner_dim]
x = net_2->forward(ctx, x); // [ne3, ne2, ne1, dim_out]
return x;
}
};
void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) {
norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
proj_in_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels);
proj_in_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
class CrossAttention : public GGMLBlock {
protected:
int64_t query_dim;
int64_t context_dim;
int64_t n_head;
int64_t d_head;
proj_out_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels);
proj_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
public:
CrossAttention(int64_t query_dim,
int64_t context_dim,
int64_t n_head,
int64_t d_head)
: n_head(n_head),
d_head(d_head),
query_dim(query_dim),
context_dim(context_dim) {
int64_t inner_dim = d_head * n_head;
// transformer
for (auto& transformer : transformers) {
transformer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
transformer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
blocks["to_q"] = std::shared_ptr<GGMLBlock>(new Linear(query_dim, inner_dim, false));
blocks["to_k"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
blocks["to_v"] = std::shared_ptr<GGMLBlock>(new Linear(context_dim, inner_dim, false));
transformer.attn1_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
transformer.attn1_k_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
transformer.attn1_v_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
transformer.attn1_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
transformer.attn1_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
transformer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
transformer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
transformer.attn2_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
transformer.attn2_k_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels);
transformer.attn2_v_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels);
transformer.attn2_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels);
transformer.attn2_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
transformer.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
transformer.norm3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
transformer.ff_0_proj_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels * 4 * 2);
transformer.ff_0_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels * 4 * 2);
transformer.ff_2_w = ggml_new_tensor_2d(ctx, wtype, in_channels * 4, in_channels);
transformer.ff_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
}
}
void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
tensors[prefix + "norm.weight"] = norm_w;
tensors[prefix + "norm.bias"] = norm_b;
tensors[prefix + "proj_in.weight"] = proj_in_w;
tensors[prefix + "proj_in.bias"] = proj_in_b;
// transformer
for (int i = 0; i < transformers.size(); i++) {
auto& transformer = transformers[i];
std::string transformer_prefix = prefix + "transformer_blocks." + std::to_string(i) + ".";
tensors[transformer_prefix + "attn1.to_q.weight"] = transformer.attn1_q_w;
tensors[transformer_prefix + "attn1.to_k.weight"] = transformer.attn1_k_w;
tensors[transformer_prefix + "attn1.to_v.weight"] = transformer.attn1_v_w;
tensors[transformer_prefix + "attn1.to_out.0.weight"] = transformer.attn1_out_w;
tensors[transformer_prefix + "attn1.to_out.0.bias"] = transformer.attn1_out_b;
tensors[transformer_prefix + "ff.net.0.proj.weight"] = transformer.ff_0_proj_w;
tensors[transformer_prefix + "ff.net.0.proj.bias"] = transformer.ff_0_proj_b;
tensors[transformer_prefix + "ff.net.2.weight"] = transformer.ff_2_w;
tensors[transformer_prefix + "ff.net.2.bias"] = transformer.ff_2_b;
tensors[transformer_prefix + "attn2.to_q.weight"] = transformer.attn2_q_w;
tensors[transformer_prefix + "attn2.to_k.weight"] = transformer.attn2_k_w;
tensors[transformer_prefix + "attn2.to_v.weight"] = transformer.attn2_v_w;
tensors[transformer_prefix + "attn2.to_out.0.weight"] = transformer.attn2_out_w;
tensors[transformer_prefix + "attn2.to_out.0.bias"] = transformer.attn2_out_b;
tensors[transformer_prefix + "norm1.weight"] = transformer.norm1_w;
tensors[transformer_prefix + "norm1.bias"] = transformer.norm1_b;
tensors[transformer_prefix + "norm2.weight"] = transformer.norm2_w;
tensors[transformer_prefix + "norm2.bias"] = transformer.norm2_b;
tensors[transformer_prefix + "norm3.weight"] = transformer.norm3_w;
tensors[transformer_prefix + "norm3.bias"] = transformer.norm3_b;
}
tensors[prefix + "proj_out.weight"] = proj_out_w;
tensors[prefix + "proj_out.bias"] = proj_out_b;
blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, query_dim));
// to_out_1 is nn.Dropout(), skip for inference
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
// x: [N, in_channels, h, w]
// context: [N, max_position, hidden_size(aka context_dim)]
auto x_in = x;
x = ggml_nn_group_norm(ctx, x, norm_w, norm_b);
// proj_in
x = ggml_nn_conv_2d(ctx, x, proj_in_w, proj_in_b); // [N, in_channels, h, w]
// x: [N, n_token, query_dim]
// context: [N, n_context, context_dim]
// return: [N, n_token, query_dim]
auto to_q = std::dynamic_pointer_cast<Linear>(blocks["to_q"]);
auto to_k = std::dynamic_pointer_cast<Linear>(blocks["to_k"]);
auto to_v = std::dynamic_pointer_cast<Linear>(blocks["to_v"]);
auto to_out_0 = std::dynamic_pointer_cast<Linear>(blocks["to_out.0"]);
// transformer
const int64_t n = x->ne[3];
const int64_t c = x->ne[2];
const int64_t h = x->ne[1];
const int64_t w = x->ne[0];
const int64_t max_position = context->ne[1];
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, in_channels]
int64_t n = x->ne[2];
int64_t n_token = x->ne[1];
int64_t n_context = context->ne[1];
int64_t inner_dim = d_head * n_head;
for (auto& transformer : transformers) {
auto r = x;
// layer norm 1
x = ggml_reshape_2d(ctx, x, c, w * h * n);
x = ggml_nn_layer_norm(ctx, x, transformer.norm1_w, transformer.norm1_b);
auto q = to_q->forward(ctx, x); // [N, n_token, inner_dim]
q = ggml_reshape_4d(ctx, q, d_head, n_head, n_token, n); // [N, n_token, n_head, d_head]
q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, n_token, d_head]
q = ggml_reshape_3d(ctx, q, d_head, n_token, n_head * n); // [N * n_head, n_token, d_head]
// self-attention
{
x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels]
struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn1_q_w, x); // [N * h * w, in_channels]
#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL)
q = ggml_scale_inplace(ctx, q, 1.0f / sqrt((float)d_head));
#endif
q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head]
q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, h * w, d_head]
q = ggml_reshape_3d(ctx, q, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head]
auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim]
k = ggml_reshape_4d(ctx, k, d_head, n_head, n_context, n); // [N, n_context, n_head, d_head]
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_context, d_head]
k = ggml_reshape_3d(ctx, k, d_head, n_context, n_head * n); // [N * n_head, n_context, d_head]
struct ggml_tensor* k = ggml_mul_mat(ctx, transformer.attn1_k_w, x); // [N * h * w, in_channels]
k = ggml_reshape_4d(ctx, k, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head]
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, h * w, d_head]
k = ggml_reshape_3d(ctx, k, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head]
auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim]
v = ggml_reshape_4d(ctx, v, d_head, n_head, n_context, n); // [N, n_context, n_head, d_head]
v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, n_context]
v = ggml_reshape_3d(ctx, v, n_context, d_head, n_head * n); // [N * n_head, d_head, n_context]
struct ggml_tensor* v = ggml_mul_mat(ctx, transformer.attn1_v_w, x); // [N * h * w, in_channels]
v = ggml_reshape_4d(ctx, v, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head]
v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, h * w]
v = ggml_reshape_3d(ctx, v, h * w, d_head, n_head * n); // [N * n_head, d_head, h * w]
auto kqv = ggml_nn_attention(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
kqv = ggml_reshape_4d(ctx, kqv, d_head, n_token, n_head, n);
kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head]
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL)
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head]
#else
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, h * w]
// kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
kq = ggml_soft_max_inplace(ctx, kq);
x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, n); // [N, n_token, inner_dim]
struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, h * w, d_head]
#endif
kqv = ggml_reshape_4d(ctx, kqv, d_head, h * w, n_head, n);
kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, h * w, n_head, d_head]
x = to_out_0->forward(ctx, x); // [N, n_token, query_dim]
return x;
}
};
// x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n));
x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n);
class BasicTransformerBlock : public GGMLBlock {
protected:
int64_t n_head;
int64_t d_head;
bool ff_in;
x = ggml_nn_linear(ctx, x, transformer.attn1_out_w, transformer.attn1_out_b);
public:
BasicTransformerBlock(int64_t dim,
int64_t n_head,
int64_t d_head,
int64_t context_dim,
bool ff_in = false)
: n_head(n_head), d_head(d_head), ff_in(ff_in) {
// disable_self_attn is always False
// disable_temporal_crossattention is always False
// switch_temporal_ca_to_sa is always False
// inner_dim is always None or equal to dim
// gated_ff is always True
blocks["attn1"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, dim, n_head, d_head));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new CrossAttention(dim, context_dim, n_head, d_head));
blocks["ff"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
x = ggml_reshape_4d(ctx, x, c, w, h, n);
}
if (ff_in) {
blocks["norm_in"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim));
blocks["ff_in"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim));
}
}
x = ggml_add(ctx, x, r);
r = x;
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
// x: [N, n_token, query_dim]
// context: [N, n_context, context_dim]
// return: [N, n_token, query_dim]
// layer norm 2
x = ggml_nn_layer_norm(ctx, x, transformer.norm2_w, transformer.norm2_b);
auto attn1 = std::dynamic_pointer_cast<CrossAttention>(blocks["attn1"]);
auto attn2 = std::dynamic_pointer_cast<CrossAttention>(blocks["attn2"]);
auto ff = std::dynamic_pointer_cast<FeedForward>(blocks["ff"]);
auto norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm1"]);
auto norm2 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm2"]);
auto norm3 = std::dynamic_pointer_cast<LayerNorm>(blocks["norm3"]);
// cross-attention
{
x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels]
context = ggml_reshape_2d(ctx, context, context->ne[0], context->ne[1] * context->ne[2]); // [N * max_position, hidden_size]
struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn2_q_w, x); // [N * h * w, in_channels]
#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL)
q = ggml_scale_inplace(ctx, q, 1.0f / sqrt((float)d_head));
#endif
q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head]
q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, h * w, d_head]
q = ggml_reshape_3d(ctx, q, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head]
if (ff_in) {
auto norm_in = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_in"]);
auto ff_in = std::dynamic_pointer_cast<FeedForward>(blocks["ff_in"]);
struct ggml_tensor* k = ggml_mul_mat(ctx, transformer.attn2_k_w, context); // [N * max_position, in_channels]
k = ggml_reshape_4d(ctx, k, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head]
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, max_position, d_head]
k = ggml_reshape_3d(ctx, k, d_head, max_position, n_head * n); // [N * n_head, max_position, d_head]
struct ggml_tensor* v = ggml_mul_mat(ctx, transformer.attn2_v_w, context); // [N * max_position, in_channels]
v = ggml_reshape_4d(ctx, v, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head]
v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, max_position]
v = ggml_reshape_3d(ctx, v, max_position, d_head, n_head * n); // [N * n_head, d_head, max_position]
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL)
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head]
#else
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position]
// kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
kq = ggml_soft_max_inplace(ctx, kq);
struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, h * w, d_head]
#endif
kqv = ggml_reshape_4d(ctx, kqv, d_head, h * w, n_head, n);
kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3));
// x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); // [N * h * w, in_channels]
x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); // [N * h * w, in_channels]
x = ggml_nn_linear(ctx, x, transformer.attn2_out_w, transformer.attn2_out_b);
x = ggml_reshape_4d(ctx, x, c, w, h, n);
}
x = ggml_add(ctx, x, r);
r = x;
// layer norm 3
x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels]
x = ggml_nn_layer_norm(ctx, x, transformer.norm3_w, transformer.norm3_b);
// ff
{
// GEGLU
auto x_w = ggml_view_2d(ctx,
transformer.ff_0_proj_w,
transformer.ff_0_proj_w->ne[0],
transformer.ff_0_proj_w->ne[1] / 2,
transformer.ff_0_proj_w->nb[1],
0); // [in_channels * 4, in_channels]
auto x_b = ggml_view_1d(ctx,
transformer.ff_0_proj_b,
transformer.ff_0_proj_b->ne[0] / 2,
0); // [in_channels * 4, in_channels]
auto gate_w = ggml_view_2d(ctx,
transformer.ff_0_proj_w,
transformer.ff_0_proj_w->ne[0],
transformer.ff_0_proj_w->ne[1] / 2,
transformer.ff_0_proj_w->nb[1],
transformer.ff_0_proj_w->nb[1] * transformer.ff_0_proj_w->ne[1] / 2); // [in_channels * 4, ]
auto gate_b = ggml_view_1d(ctx,
transformer.ff_0_proj_b,
transformer.ff_0_proj_b->ne[0] / 2,
transformer.ff_0_proj_b->nb[0] * transformer.ff_0_proj_b->ne[0] / 2); // [in_channels * 4, ]
x = ggml_reshape_2d(ctx, x, c, w * h * n);
auto x_in = x;
x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [N * h * w, in_channels * 4]
auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [N * h * w, in_channels * 4]
gate = ggml_gelu_inplace(ctx, gate);
x = ggml_mul(ctx, x, gate); // [N * h * w, in_channels * 4]
// fc
x = ggml_nn_linear(ctx, x, transformer.ff_2_w, transformer.ff_2_b); // [N * h * w, in_channels]
}
x = ggml_reshape_4d(ctx, x, c, w, h, n); // [N, h, w, in_channels]
// residual
x = ggml_add(ctx, x, r);
auto x_skip = x;
x = norm_in->forward(ctx, x);
x = ff_in->forward(ctx, x);
// self.is_res is always True
x = ggml_add(ctx, x, x_skip);
}
x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, in_channels, h, w]
auto r = x;
x = norm1->forward(ctx, x);
x = attn1->forward(ctx, x, x); // self-attention
x = ggml_add(ctx, x, r);
r = x;
x = norm2->forward(ctx, x);
x = attn2->forward(ctx, x, context); // cross-attention
x = ggml_add(ctx, x, r);
r = x;
x = norm3->forward(ctx, x);
x = ff->forward(ctx, x);
x = ggml_add(ctx, x, r);
return x;
}
};
class SpatialTransformer : public GGMLBlock {
protected:
int64_t in_channels; // mult * model_channels
int64_t n_head;
int64_t d_head;
int64_t depth = 1; // 1
int64_t context_dim = 768; // hidden_size, 1024 for VERSION_2_x
public:
SpatialTransformer(int64_t in_channels,
int64_t n_head,
int64_t d_head,
int64_t depth,
int64_t context_dim)
: in_channels(in_channels),
n_head(n_head),
d_head(d_head),
depth(depth),
context_dim(context_dim) {
// We will convert unet transformer linear to conv2d 1x1 when loading the weights, so use_linear is always False
// disable_self_attn is always False
int64_t inner_dim = n_head * d_head; // in_channels
blocks["norm"] = std::shared_ptr<GGMLBlock>(new GroupNorm32(in_channels));
blocks["proj_in"] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, inner_dim, {1, 1}));
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim));
}
blocks["proj_out"] = std::shared_ptr<GGMLBlock>(new Conv2d(inner_dim, in_channels, {1, 1}));
}
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) {
// x: [N, in_channels, h, w]
// context: [N, max_position(aka n_token), hidden_size(aka context_dim)]
auto norm = std::dynamic_pointer_cast<GroupNorm32>(blocks["norm"]);
auto proj_in = std::dynamic_pointer_cast<Conv2d>(blocks["proj_in"]);
auto proj_out = std::dynamic_pointer_cast<Conv2d>(blocks["proj_out"]);
auto x_in = x;
int64_t n = x->ne[3];
int64_t h = x->ne[1];
int64_t w = x->ne[0];
int64_t inner_dim = n_head * d_head;
x = norm->forward(ctx, x);
x = proj_in->forward(ctx, x); // [N, inner_dim, h, w]
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
for (int i = 0; i < depth; i++) {
std::string name = "transformer_blocks." + std::to_string(i);
auto transformer_block = std::dynamic_pointer_cast<BasicTransformerBlock>(blocks[name]);
x = transformer_block->forward(ctx, x, context);
}
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 0, 2, 3)); // [N, inner_dim, h * w]
x = ggml_reshape_4d(ctx, x, w, h, inner_dim, n); // [N, inner_dim, h, w]
// proj_out
x = ggml_nn_conv_2d(ctx, x, proj_out_w, proj_out_b); // [N, in_channels, h, w]
x = proj_out->forward(ctx, x); // [N, in_channels, h, w]
x = ggml_add(ctx, x, x_in);
return x;
}
};
class AlphaBlender : public GGMLBlock {
protected:
void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["mix_factor"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
}
float get_alpha() {
// image_only_indicator is always tensor([0.]) and since mix_factor.shape is [1,]
// so learned_with_images is same as learned
float alpha = ggml_backend_tensor_get_f32(params["mix_factor"]);
return sigmoid(alpha);
}
public:
AlphaBlender() {
// merge_strategy is always learned_with_images
// for inference, we don't need to set alpha
// since mix_factor.shape is [1,], we don't need rearrange using rearrange_pattern
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x_spatial,
struct ggml_tensor* x_temporal) {
// image_only_indicator is always tensor([0.])
float alpha = get_alpha();
auto x = ggml_add(ctx,
ggml_scale(ctx, x_spatial, alpha),
ggml_scale(ctx, x_temporal, 1.0f - alpha));
return x;
}
};
class VideoResBlock : public ResBlock {
public:
VideoResBlock(int channels,
int emb_channels,
int out_channels,
std::pair<int, int> kernel_size = {3, 3},
int64_t video_kernel_size = 3,
int dims = 2) // always 2
: ResBlock(channels, emb_channels, out_channels, kernel_size, dims) {
blocks["time_stack"] = std::shared_ptr<GGMLBlock>(new ResBlock(out_channels, emb_channels, out_channels, kernel_size, 3, true));
blocks["time_mixer"] = std::shared_ptr<GGMLBlock>(new AlphaBlender());
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* emb,
int num_video_frames) {
// x: [N, channels, h, w] aka [b*t, channels, h, w]
// emb: [N, emb_channels] aka [b*t, emb_channels]
// image_only_indicator is always tensor([0.])
auto time_stack = std::dynamic_pointer_cast<ResBlock>(blocks["time_stack"]);
auto time_mixer = std::dynamic_pointer_cast<AlphaBlender>(blocks["time_mixer"]);
x = ResBlock::forward(ctx, x, emb);
int64_t T = num_video_frames;
int64_t B = x->ne[3] / T;
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
auto x_mix = x;
emb = ggml_reshape_4d(ctx, emb, emb->ne[0], T, B, emb->ne[3]); // (b t) ... -> b t ...
x = time_stack->forward(ctx, x, emb); // b t c (h w)
x = time_mixer->forward(ctx, x_mix, x); // b t c (h w)
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x;
}
};
#endif // __COMMON_HPP__

File diff suppressed because it is too large Load Diff

View File

@ -12,279 +12,160 @@
*/
struct ResidualDenseBlock {
int num_features;
class ResidualDenseBlock : public GGMLBlock {
protected:
int num_feat;
int num_grow_ch;
ggml_tensor* conv1_w; // [num_grow_ch, num_features, 3, 3]
ggml_tensor* conv1_b; // [num_grow_ch]
ggml_tensor* conv2_w; // [num_grow_ch, num_features + num_grow_ch, 3, 3]
ggml_tensor* conv2_b; // [num_grow_ch]
ggml_tensor* conv3_w; // [num_grow_ch, num_features + 2 * num_grow_ch, 3, 3]
ggml_tensor* conv3_b; // [num_grow_ch]
ggml_tensor* conv4_w; // [num_grow_ch, num_features + 3 * num_grow_ch, 3, 3]
ggml_tensor* conv4_b; // [num_grow_ch]
ggml_tensor* conv5_w; // [num_features, num_features + 4 * num_grow_ch, 3, 3]
ggml_tensor* conv5_b; // [num_features]
ResidualDenseBlock() {}
ResidualDenseBlock(int num_feat, int n_grow_ch) {
num_features = num_feat;
num_grow_ch = n_grow_ch;
public:
ResidualDenseBlock(int num_feat = 64, int num_grow_ch = 32)
: num_feat(num_feat), num_grow_ch(num_grow_ch) {
blocks["conv1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_grow_ch, {3, 3}, {1, 1}, {1, 1}));
blocks["conv2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1}));
blocks["conv3"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1}));
blocks["conv4"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, {3, 3}, {1, 1}, {1, 1}));
blocks["conv5"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat + 4 * num_grow_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
}
size_t calculate_mem_size() {
size_t mem_size = num_features * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv1_w
mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv1_b
mem_size += (num_features + num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv2_w
mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv2_b
mem_size += (num_features + 2 * num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv3_w
mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv3_w
mem_size += (num_features + 3 * num_grow_ch) * num_grow_ch * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv4_w
mem_size += num_grow_ch * ggml_type_size(GGML_TYPE_F32); // conv4_w
mem_size += (num_features + 4 * num_grow_ch) * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv5_w
mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv5_w
return mem_size;
struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
return ggml_leaky_relu(ctx, x, 0.2f, true);
}
int get_num_tensors() {
int num_tensors = 10;
return num_tensors;
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [n, num_feat, h, w]
// return: [n, num_feat, h, w]
void init_params(ggml_context* ctx) {
conv1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features, num_grow_ch);
conv1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch);
conv2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + num_grow_ch, num_grow_ch);
conv2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch);
conv3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 2 * num_grow_ch, num_grow_ch);
conv3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch);
conv4_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 3 * num_grow_ch, num_grow_ch);
conv4_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_grow_ch);
conv5_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, num_features + 4 * num_grow_ch, num_features);
conv5_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_features);
}
auto conv1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv1"]);
auto conv2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv2"]);
auto conv3 = std::dynamic_pointer_cast<Conv2d>(blocks["conv3"]);
auto conv4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv4"]);
auto conv5 = std::dynamic_pointer_cast<Conv2d>(blocks["conv5"]);
void map_by_name(std::map<std::string, ggml_tensor*>& tensors, std::string prefix) {
tensors[prefix + "conv1.weight"] = conv1_w;
tensors[prefix + "conv1.bias"] = conv1_b;
auto x1 = lrelu(ctx, conv1->forward(ctx, x));
auto x_cat = ggml_concat(ctx, x, x1);
auto x2 = lrelu(ctx, conv2->forward(ctx, x_cat));
x_cat = ggml_concat(ctx, x_cat, x2);
auto x3 = lrelu(ctx, conv3->forward(ctx, x_cat));
x_cat = ggml_concat(ctx, x_cat, x3);
auto x4 = lrelu(ctx, conv4->forward(ctx, x_cat));
x_cat = ggml_concat(ctx, x_cat, x4);
auto x5 = conv5->forward(ctx, x_cat);
tensors[prefix + "conv2.weight"] = conv2_w;
tensors[prefix + "conv2.bias"] = conv2_b;
tensors[prefix + "conv3.weight"] = conv3_w;
tensors[prefix + "conv3.bias"] = conv3_b;
tensors[prefix + "conv4.weight"] = conv4_w;
tensors[prefix + "conv4.bias"] = conv4_b;
tensors[prefix + "conv5.weight"] = conv5_w;
tensors[prefix + "conv5.bias"] = conv5_b;
}
ggml_tensor* forward(ggml_context* ctx, float out_scale, ggml_tensor* x /* feat */) {
// x1 = self.lrelu(self.conv1(x))
ggml_tensor* x1 = ggml_nn_conv_2d(ctx, x, conv1_w, conv1_b, 1, 1, 1, 1);
x1 = ggml_leaky_relu(ctx, x1, 0.2f, true);
// x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
ggml_tensor* x_cat = ggml_concat(ctx, x, x1);
ggml_tensor* x2 = ggml_nn_conv_2d(ctx, x_cat, conv2_w, conv2_b, 1, 1, 1, 1);
x2 = ggml_leaky_relu(ctx, x2, 0.2f, true);
// x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x_cat = ggml_concat(ctx, x_cat, x2);
ggml_tensor* x3 = ggml_nn_conv_2d(ctx, x_cat, conv3_w, conv3_b, 1, 1, 1, 1);
x3 = ggml_leaky_relu(ctx, x3, 0.2f, true);
// x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x_cat = ggml_concat(ctx, x_cat, x3);
ggml_tensor* x4 = ggml_nn_conv_2d(ctx, x_cat, conv4_w, conv4_b, 1, 1, 1, 1);
x4 = ggml_leaky_relu(ctx, x4, 0.2f, true);
// self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
x_cat = ggml_concat(ctx, x_cat, x4);
ggml_tensor* x5 = ggml_nn_conv_2d(ctx, x_cat, conv5_w, conv5_b, 1, 1, 1, 1);
// return x5 * 0.2 + x
x5 = ggml_add(ctx, ggml_scale(ctx, x5, out_scale), x);
x5 = ggml_add(ctx, ggml_scale(ctx, x5, 0.2f), x);
return x5;
}
};
struct EsrganBlock {
ResidualDenseBlock rd_blocks[3];
int num_residual_blocks = 3;
EsrganBlock() {}
EsrganBlock(int num_feat, int num_grow_ch) {
for (int i = 0; i < num_residual_blocks; i++) {
rd_blocks[i] = ResidualDenseBlock(num_feat, num_grow_ch);
}
class RRDB : public GGMLBlock {
public:
RRDB(int num_feat, int num_grow_ch = 32) {
blocks["rdb1"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch));
blocks["rdb2"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch));
blocks["rdb3"] = std::shared_ptr<GGMLBlock>(new ResidualDenseBlock(num_feat, num_grow_ch));
}
int get_num_tensors() {
int num_tensors = 0;
for (int i = 0; i < num_residual_blocks; i++) {
num_tensors += rd_blocks[i].get_num_tensors();
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [n, num_feat, h, w]
// return: [n, num_feat, h, w]
auto rdb1 = std::dynamic_pointer_cast<ResidualDenseBlock>(blocks["rdb1"]);
auto rdb2 = std::dynamic_pointer_cast<ResidualDenseBlock>(blocks["rdb2"]);
auto rdb3 = std::dynamic_pointer_cast<ResidualDenseBlock>(blocks["rdb3"]);
auto out = rdb1->forward(ctx, x);
out = rdb2->forward(ctx, out);
out = rdb3->forward(ctx, out);
out = ggml_add(ctx, ggml_scale(ctx, out, 0.2f), x);
return out;
}
};
class RRDBNet : public GGMLBlock {
protected:
int scale = 4; // default RealESRGAN_x4plus_anime_6B
int num_block = 6; // default RealESRGAN_x4plus_anime_6B
int num_in_ch = 3;
int num_out_ch = 3;
int num_feat = 64; // default RealESRGAN_x4plus_anime_6B
int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B
public:
RRDBNet() {
blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_in_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
for (int i = 0; i < num_block; i++) {
std::string name = "body." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB(num_feat, num_grow_ch));
}
return num_tensors;
blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
// upsample
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
blocks["conv_hr"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_out_ch, {3, 3}, {1, 1}, {1, 1}));
}
size_t calculate_mem_size() {
size_t mem_size = 0;
for (int i = 0; i < num_residual_blocks; i++) {
mem_size += rd_blocks[i].calculate_mem_size();
}
return mem_size;
struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
return ggml_leaky_relu(ctx, x, 0.2f, true);
}
void init_params(ggml_context* ctx) {
for (int i = 0; i < num_residual_blocks; i++) {
rd_blocks[i].init_params(ctx);
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [n, num_in_ch, h, w]
// return: [n, num_out_ch, h*4, w*4]
auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks["conv_first"]);
auto conv_body = std::dynamic_pointer_cast<Conv2d>(blocks["conv_body"]);
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
auto conv_hr = std::dynamic_pointer_cast<Conv2d>(blocks["conv_hr"]);
auto conv_last = std::dynamic_pointer_cast<Conv2d>(blocks["conv_last"]);
void map_by_name(std::map<std::string, ggml_tensor*>& tensors, std::string prefix) {
for (int i = 0; i < num_residual_blocks; i++) {
rd_blocks[i].map_by_name(tensors, prefix + "rdb" + std::to_string(i + 1) + ".");
}
}
auto feat = conv_first->forward(ctx, x);
auto body_feat = feat;
for (int i = 0; i < num_block; i++) {
std::string name = "body." + std::to_string(i);
auto block = std::dynamic_pointer_cast<RRDB>(blocks[name]);
ggml_tensor* forward(ggml_context* ctx, float out_scale, ggml_tensor* x) {
ggml_tensor* out = x;
for (int i = 0; i < num_residual_blocks; i++) {
// out = self.rdb...(x)
out = rd_blocks[i].forward(ctx, out_scale, out);
body_feat = block->forward(ctx, body_feat);
}
// return out * 0.2 + x
out = ggml_add(ctx, ggml_scale(ctx, out, out_scale), x);
body_feat = conv_body->forward(ctx, body_feat);
feat = ggml_add(ctx, feat, body_feat);
// upsample
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2)));
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2)));
auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat)));
return out;
}
};
struct ESRGAN : public GGMLModule {
int scale = 4; // default RealESRGAN_x4plus_anime_6B
int num_blocks = 6; // default RealESRGAN_x4plus_anime_6B
int in_channels = 3;
int out_channels = 3;
int num_features = 64; // default RealESRGAN_x4plus_anime_6B
int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B
int tile_size = 128; // avoid cuda OOM for 4gb VRAM
RRDBNet rrdb_net;
int scale = 4;
int tile_size = 128; // avoid cuda OOM for 4gb VRAM
ggml_tensor* conv_first_w; // [num_features, in_channels, 3, 3]
ggml_tensor* conv_first_b; // [num_features]
EsrganBlock body_blocks[6];
ggml_tensor* conv_body_w; // [num_features, num_features, 3, 3]
ggml_tensor* conv_body_b; // [num_features]
// upsample
ggml_tensor* conv_up1_w; // [num_features, num_features, 3, 3]
ggml_tensor* conv_up1_b; // [num_features]
ggml_tensor* conv_up2_w; // [num_features, num_features, 3, 3]
ggml_tensor* conv_up2_b; // [num_features]
ggml_tensor* conv_hr_w; // [num_features, num_features, 3, 3]
ggml_tensor* conv_hr_b; // [num_features]
ggml_tensor* conv_last_w; // [out_channels, num_features, 3, 3]
ggml_tensor* conv_last_b; // [out_channels]
bool decode_only = false;
ESRGAN() {
name = "esrgan";
for (int i = 0; i < num_blocks; i++) {
body_blocks[i] = EsrganBlock(num_features, num_grow_ch);
}
ESRGAN(ggml_backend_t backend,
ggml_type wtype)
: GGMLModule(backend, wtype) {
rrdb_net.init(params_ctx, wtype);
}
size_t calculate_mem_size() {
size_t mem_size = num_features * in_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_first_w
mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_first_b
for (int i = 0; i < num_blocks; i++) {
mem_size += body_blocks[i].calculate_mem_size();
}
mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_body_w
mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_body_w
// upsample
mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_up1_w
mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_up1_b
mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_up2_w
mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_up2_b
mem_size += num_features * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_hr_w
mem_size += num_features * ggml_type_size(GGML_TYPE_F32); // conv_hr_b
mem_size += out_channels * num_features * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_last_w
mem_size += out_channels * ggml_type_size(GGML_TYPE_F32); // conv_last_b
return mem_size;
std::string get_desc() {
return "esrgan";
}
size_t get_num_tensors() {
size_t num_tensors = 12;
for (int i = 0; i < num_blocks; i++) {
num_tensors += body_blocks[i].get_num_tensors();
}
return num_tensors;
size_t get_params_mem_size() {
return rrdb_net.get_params_mem_size();
}
void init_params() {
ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);
conv_first_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, in_channels, num_features);
conv_first_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, num_features);
conv_body_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, num_features, num_features);
conv_body_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, num_features);
conv_up1_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, num_features, num_features);
conv_up1_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, num_features);
conv_up2_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, num_features, num_features);
conv_up2_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, num_features);
conv_hr_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, num_features, num_features);
conv_hr_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, num_features);
conv_last_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, num_features, out_channels);
conv_last_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, out_channels);
for (int i = 0; i < num_blocks; i++) {
body_blocks[i].init_params(params_ctx);
}
// alloc all tensors linked to this context
for (struct ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != NULL; t = ggml_get_next_tensor(params_ctx, t)) {
if (t->data == NULL) {
ggml_allocr_alloc(alloc, t);
}
}
ggml_allocr_free(alloc);
size_t get_params_num() {
return rrdb_net.get_params_num();
}
bool load_from_file(const std::string& file_path, ggml_backend_t backend) {
bool load_from_file(const std::string& file_path) {
LOG_INFO("loading esrgan from '%s'", file_path.c_str());
if (!alloc_params_buffer(backend)) {
return false;
}
alloc_params_buffer();
std::map<std::string, ggml_tensor*> esrgan_tensors;
// prepare memory for the weights
{
init_params();
map_by_name(esrgan_tensors);
}
rrdb_net.get_param_tensors(esrgan_tensors);
ModelLoader model_loader;
if (!model_loader.init_from_file(file_path)) {
@ -303,115 +184,22 @@ struct ESRGAN : public GGMLModule {
return success;
}
void map_by_name(std::map<std::string, ggml_tensor*>& tensors) {
tensors["conv_first.weight"] = conv_first_w;
tensors["conv_first.bias"] = conv_first_b;
for (int i = 0; i < num_blocks; i++) {
body_blocks[i].map_by_name(tensors, "body." + std::to_string(i) + ".");
}
tensors["conv_body.weight"] = conv_body_w;
tensors["conv_body.bias"] = conv_body_b;
tensors["conv_up1.weight"] = conv_up1_w;
tensors["conv_up1.bias"] = conv_up1_b;
tensors["conv_up2.weight"] = conv_up2_w;
tensors["conv_up2.bias"] = conv_up2_b;
tensors["conv_hr.weight"] = conv_hr_w;
tensors["conv_hr.bias"] = conv_hr_b;
tensors["conv_last.weight"] = conv_last_w;
tensors["conv_last.bias"] = conv_last_b;
}
ggml_tensor* forward(ggml_context* ctx0, float out_scale, ggml_tensor* x /* feat */) {
// feat = self.conv_first(feat)
auto h = ggml_nn_conv_2d(ctx0, x, conv_first_w, conv_first_b, 1, 1, 1, 1);
auto body_h = h;
// self.body(feat)
for (int i = 0; i < num_blocks; i++) {
body_h = body_blocks[i].forward(ctx0, out_scale, body_h);
}
// body_feat = self.conv_body(self.body(feat))
body_h = ggml_nn_conv_2d(ctx0, body_h, conv_body_w, conv_body_b, 1, 1, 1, 1);
// feat = feat + body_feat
h = ggml_add(ctx0, h, body_h);
// upsample
// feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
h = ggml_upscale(ctx0, h, 2);
h = ggml_nn_conv_2d(ctx0, h, conv_up1_w, conv_up1_b, 1, 1, 1, 1);
h = ggml_leaky_relu(ctx0, h, 0.2f, true);
// feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
h = ggml_upscale(ctx0, h, 2);
h = ggml_nn_conv_2d(ctx0, h, conv_up2_w, conv_up2_b, 1, 1, 1, 1);
h = ggml_leaky_relu(ctx0, h, 0.2f, true);
// out = self.conv_last(self.lrelu(self.conv_hr(feat)))
h = ggml_nn_conv_2d(ctx0, h, conv_hr_w, conv_hr_b, 1, 1, 1, 1);
h = ggml_leaky_relu(ctx0, h, 0.2f, true);
h = ggml_nn_conv_2d(ctx0, h, conv_last_w, conv_last_b, 1, 1, 1, 1);
return h;
}
struct ggml_cgraph* build_graph(struct ggml_tensor* x) {
// since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
static std::vector<uint8_t> buf(buf_size);
struct ggml_init_params params = {
/*.mem_size =*/buf_size,
/*.mem_buffer =*/buf.data(),
/*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
};
struct ggml_context* ctx0 = ggml_init(params);
struct ggml_cgraph* gf = ggml_new_graph(ctx0);
struct ggml_tensor* x_ = NULL;
float out_scale = 0.2f;
// it's performing a compute, check if backend isn't cpu
if (!ggml_backend_is_cpu(backend)) {
// pass input tensors to gpu memory
x_ = ggml_dup_tensor(ctx0, x);
ggml_allocr_alloc(compute_allocr, x_);
// pass data to device backend
if (!ggml_allocr_is_measure(compute_allocr)) {
ggml_backend_tensor_set(x_, x->data, 0, ggml_nbytes(x));
}
} else {
x_ = x;
}
struct ggml_tensor* out = forward(ctx0, out_scale, x);
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
x = to_backend(x);
struct ggml_tensor* out = rrdb_net.forward(compute_ctx, x);
ggml_build_forward_expand(gf, out);
ggml_free(ctx0);
return gf;
}
void alloc_compute_buffer(struct ggml_tensor* x) {
void compute(const int n_threads,
struct ggml_tensor* x,
ggml_tensor** output,
ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x);
};
GGMLModule::alloc_compute_buffer(get_graph);
}
void compute(struct ggml_tensor* work_result, const int n_threads, struct ggml_tensor* x) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x);
};
GGMLModule::compute(get_graph, n_threads, work_result);
GGMLModule::compute(get_graph, n_threads, false, output, output_ctx);
}
};

View File

@ -43,12 +43,14 @@ const char* schedule_str[] = {
const char* modes_str[] = {
"txt2img",
"img2img",
"img2vid",
"convert",
};
enum SDMode {
TXT2IMG,
IMG2IMG,
IMG2VID,
CONVERT,
MODE_COUNT
};
@ -71,12 +73,18 @@ struct SDParams {
std::string prompt;
std::string negative_prompt;
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
int height = 512;
int batch_count = 1;
int video_frames = 6;
int motion_bucket_id = 127;
int fps = 6;
float augmentation_level = 0.f;
sample_method_t sample_method = EULER_A;
schedule_t schedule = DEFAULT;
int sample_steps = 20;
@ -108,6 +116,7 @@ void print_params(SDParams params) {
printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" min_cfg: %.2f\n", params.min_cfg);
printf(" cfg_scale: %.2f\n", params.cfg_scale);
printf(" clip_skip: %d\n", params.clip_skip);
printf(" width: %d\n", params.width);
@ -190,7 +199,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
}
}
if (mode_found == -1) {
fprintf(stderr, "error: invalid mode %s, must be one of [txt2img, img2img]\n",
fprintf(stderr,
"error: invalid mode %s, must be one of [txt2img, img2img, img2vid, convert]\n",
mode_selected);
exit(1);
}
@ -420,7 +430,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.n_threads = get_num_physical_cores();
}
if (params.mode != CONVERT && params.prompt.length() == 0) {
if (params.mode != CONVERT && params.mode != IMG2VID && params.prompt.length() == 0) {
fprintf(stderr, "error: the following arguments are required: prompt\n");
print_usage(argc, argv);
exit(1);
@ -432,7 +442,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(1);
}
if (params.mode == IMG2IMG && params.input_path.length() == 0) {
if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path.length() == 0) {
fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n");
print_usage(argc, argv);
exit(1);
@ -539,9 +549,14 @@ int main(int argc, const char* argv[]) {
}
}
if (params.mode == IMG2VID) {
fprintf(stderr, "SVD support is broken, do not use it!!!\n");
return 1;
}
bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL;
if (params.mode == IMG2IMG) {
if (params.mode == IMG2IMG || params.mode == IMG2VID) {
vae_decode_only = false;
int c = 0;
@ -625,19 +640,57 @@ int main(int argc, const char* argv[]) {
3,
input_image_buffer};
results = img2img(sd_ctx,
input_image,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.strength,
params.seed,
params.batch_count);
if (params.mode == IMG2VID) {
results = img2vid(sd_ctx,
input_image,
params.width,
params.height,
params.video_frames,
params.motion_bucket_id,
params.fps,
params.augmentation_level,
params.min_cfg,
params.cfg_scale,
params.sample_method,
params.sample_steps,
params.strength,
params.seed);
if (results == NULL) {
printf("generate failed\n");
free_sd_ctx(sd_ctx);
return 1;
}
size_t last = params.output_path.find_last_of(".");
std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path;
for (int i = 0; i < params.video_frames; i++) {
if (results[i].data == NULL) {
continue;
}
std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png";
stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(params, params.seed + i).c_str());
printf("save result image to '%s'\n", final_image_path.c_str());
free(results[i].data);
results[i].data = NULL;
}
free(results);
free_sd_ctx(sd_ctx);
return 0;
} else {
results = img2img(sd_ctx,
input_image,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.strength,
params.seed,
params.batch_count);
}
}
if (results == NULL) {

2
ggml

@ -1 +1 @@
Subproject commit 2f3b12fbd6cf4cb41ad4c8fdfd65e937f5c92093
Subproject commit 9a5ce3002474b3ac1dc2441e5c6b95ccef02cc78

View File

@ -11,6 +11,7 @@
#include <iostream>
#include <iterator>
#include <map>
#include <memory>
#include <random>
#include <regex>
#include <set>
@ -65,9 +66,11 @@ __STATIC_INLINE__ void ggml_tensor_set_f32(struct ggml_tensor* tensor, float val
}
__STATIC_INLINE__ float ggml_tensor_get_f32(const ggml_tensor* tensor, int l, int k = 0, int j = 0, int i = 0) {
// float value;
// ggml_backend_tensor_get(tensor, &value, i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0], sizeof(float));
// return value;
if (tensor->buffer != NULL) {
float value;
ggml_backend_tensor_get(tensor, &value, i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0], sizeof(float));
return value;
}
GGML_ASSERT(tensor->nb[0] == sizeof(float));
return *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]);
}
@ -183,7 +186,7 @@ __STATIC_INLINE__ void copy_ggml_tensor(struct ggml_tensor* dst, struct ggml_ten
LOG_ERROR("ggml_init() failed");
return;
}
ggml_tensor* final = ggml_cpy_inplace(ctx, src, dst);
ggml_tensor* final = ggml_cpy(ctx, src, dst);
struct ggml_cgraph* graph = ggml_new_graph(ctx);
ggml_build_forward_expand(graph, final);
@ -191,6 +194,10 @@ __STATIC_INLINE__ void copy_ggml_tensor(struct ggml_tensor* dst, struct ggml_ten
ggml_free(ctx);
}
__STATIC_INLINE__ float sigmoid(float x) {
return 1 / (1.0f + expf(-x));
}
// SPECIAL OPERATIONS WITH TENSORS
__STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input) {
@ -211,7 +218,8 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input) {
}
__STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data,
struct ggml_tensor* output) {
struct ggml_tensor* output,
bool scale = true) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
@ -219,8 +227,31 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data,
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
int value = *(image_data + iy * width * channels + ix * channels + k);
ggml_tensor_set_f32(output, value / 255.0f, ix, iy, k);
float value = *(image_data + iy * width * channels + ix * channels + k);
if (scale) {
value /= 255.f;
}
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
}
}
__STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data,
struct ggml_tensor* output,
bool scale = true) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
float value = *(image_data + iy * width * channels + ix * channels + k);
if (scale) {
value /= 255.f;
}
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
}
@ -407,7 +438,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
struct ggml_tensor* w,
struct ggml_tensor* b) {
x = ggml_mul_mat(ctx, w, x);
x = ggml_add(ctx, x, b);
if (b != NULL) {
x = ggml_add(ctx, x, b);
}
return x;
}
@ -428,19 +461,103 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1);
if (b != NULL) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
// b = ggml_repeat(ctx, b, x);
x = ggml_add(ctx, x, b);
}
return x;
}
// w: [OCIC, KD, 1 * 1]
// x: [N, IC, IH, IW]
// b: [OC,]
// result: [N, OC, OH, OW]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d_nx1x1_bak(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b,
int s2 = 1,
int p2 = 1,
int d2 = 1) {
GGML_ASSERT(w->ne[0] == 1);
// timesteps = x.shape[0]
// x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
// x = conv3d(x)
// return rearrange(x, "b c t h w -> (b t) c h w")
int64_t T = x->ne[3];
int64_t B = x->ne[3] / T;
int64_t C = x->ne[2];
int64_t H = x->ne[1];
int64_t W = x->ne[0];
x = ggml_reshape_4d(ctx, x, W * H, C, T, B); // (b t) c h w -> b t c (h w)
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b t c (h w) -> b c t (h w)
x = ggml_conv_2d(ctx, w, x, 1, s2, 0, p2, 1, d2); // [B, OC, T, OH * OW]
if (b != NULL) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
x = ggml_add(ctx, x, b);
}
x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); // b c t (h w) -> b t c (h w)
x = ggml_reshape_4d(ctx, x, W, H, C, T * B); // b t c (h w) -> (b t) c h w
return x; // [B*T, OC, OH, OW]
}
// w: [OCIC, KD, 1 * 1]
// x: [N, IC, ID, IH*IW]
// b: [OC,]
// result: [N, OC, OD, OH*OW]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_3d_nx1x1(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b,
int s2 = 1,
int p2 = 1,
int d2 = 1) {
x = ggml_conv_2d(ctx, w, x, 1, s2, 0, p2, 1, d2); // [N, OC, T, OH * OW]
if (b != NULL) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
x = ggml_add(ctx, x, b);
}
return x; // [N, OC, T, OH * OW]
}
// q: [N * n_head, n_token, d_head]
// k: [N * n_head, n_k, d_head]
// v: [N * n_head, d_head, n_k]
// return: [N * n_head, n_token, d_head]
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx,
struct ggml_tensor* q,
struct ggml_tensor* k,
struct ggml_tensor* v,
bool mask = false) {
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL)
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
#else
float d_head = (float)q->ne[0];
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k]
kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
if (mask) {
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
}
kq = ggml_soft_max_inplace(ctx, kq);
struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head]
#endif
return kqv;
}
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b,
float eps = EPS) {
x = ggml_norm(ctx, x, eps);
x = ggml_mul(ctx, x, w);
x = ggml_add(ctx, x, b);
if (w != NULL) {
x = ggml_mul(ctx, x, w);
if (b != NULL) {
x = ggml_add(ctx, x, b);
}
}
return x;
}
@ -449,14 +566,17 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct
struct ggml_tensor* w,
struct ggml_tensor* b,
int num_groups = 32) {
if (ggml_n_dims(x) >= 3) {
if (ggml_n_dims(x) >= 3 && w != NULL && b != NULL) {
w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], 1);
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
}
x = ggml_group_norm(ctx, x, num_groups);
x = ggml_mul(ctx, x, w);
x = ggml_add(ctx, x, b);
if (w != NULL && b != NULL) {
x = ggml_mul(ctx, x, w);
// b = ggml_repeat(ctx, b, x);
x = ggml_add(ctx, x, b);
}
return x;
}
@ -486,130 +606,249 @@ __STATIC_INLINE__ float ggml_backend_tensor_get_f32(ggml_tensor* tensor) {
return value;
}
__STATIC_INLINE__ std::vector<float> arange(float start, float end, float step = 1.f) {
std::vector<float> result;
for (float value = start; value < end; value += step) {
result.push_back(value);
}
return result;
}
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
__STATIC_INLINE__ void set_timestep_embedding(struct ggml_tensor* timesteps, struct ggml_tensor* embedding, int dim, int max_period = 10000) {
__STATIC_INLINE__ std::vector<float> timestep_embedding(std::vector<float> timesteps,
int dim,
int max_period = 10000) {
// timesteps: [N,]
// embedding: [dim, N]
// embedding: [N, dim]
size_t N = timesteps.size();
int acutual_dim = dim;
if (dim % 2 != 0) {
acutual_dim = dim + 1;
}
std::vector<float> embedding(N * acutual_dim, 0.f);
int half = dim / 2;
std::vector<float> freqs(half);
for (int i = 0; i < half; ++i) {
freqs[i] = (float)std::exp(-std::log(max_period) * i / half);
}
for (int i = 0; i < timesteps->ne[0]; ++i) {
for (int i = 0; i < N; ++i) {
for (int j = 0; j < half; ++j) {
float arg = ggml_get_f32_1d(timesteps, i) * freqs[j];
ggml_tensor_set_f32(embedding, std::cos(arg), j, i);
ggml_tensor_set_f32(embedding, std::sin(arg), j + half, i);
float arg = timesteps[i] * freqs[j];
embedding[i * acutual_dim + j] = std::cos(arg);
embedding[i * acutual_dim + j + half] = std::sin(arg);
}
if (dim % 2 != 0) {
*(float*)((char*)embedding->data + i * embedding->nb[1] + dim * embedding->nb[0]) = 0;
}
}
}
__STATIC_INLINE__ struct ggml_tensor* new_timestep_embedding(struct ggml_context* ctx,
struct ggml_allocr* allocr,
struct ggml_tensor* timesteps,
int dim,
int max_period = 10000) {
// timesteps: [N,]
// embedding: [dim, N]
int acutual_dim = dim;
if (dim % 2 != 0) {
acutual_dim = dim + 1;
}
struct ggml_tensor* embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, acutual_dim, timesteps->ne[0]);
if (allocr != NULL) {
ggml_allocr_alloc(allocr, embedding);
}
if (allocr != NULL && !ggml_allocr_is_measure(allocr)) {
set_timestep_embedding(timesteps, embedding, dim, max_period);
}
return embedding;
}
__STATIC_INLINE__ void set_timestep_embedding(std::vector<float> timesteps,
struct ggml_tensor* embedding,
int dim,
int max_period = 10000) {
std::vector<float> embedding_vec = timestep_embedding(timesteps, dim, max_period);
memcpy(((char*)embedding->data), ((char*)embedding_vec.data()), ggml_nbytes(embedding));
}
__STATIC_INLINE__ struct ggml_tensor* new_timestep_embedding(struct ggml_context* ctx,
struct ggml_allocr* allocr,
std::vector<float> timesteps,
int dim,
int max_period = 10000) {
// timesteps: [N,]
// embedding: [N, dim]
std::vector<float> embedding_vec = timestep_embedding(timesteps, dim, max_period);
int acutual_dim = dim;
if (dim % 2 != 0) {
acutual_dim = dim + 1;
}
struct ggml_tensor* embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, acutual_dim, timesteps.size());
if (allocr != NULL) {
ggml_allocr_alloc(allocr, embedding);
if (!ggml_allocr_is_measure(allocr)) {
ggml_backend_tensor_set(embedding, embedding_vec.data(), 0, ggml_nbytes(embedding));
}
} else {
memcpy(((char*)embedding->data), ((char*)embedding_vec.data()), ggml_nbytes(embedding));
}
return embedding;
}
// struct GGMLComputeGraph {
// virtual void init(struct ggml_context* ctx, ggml_type wtype) = 0;
// virtual std::string get_desc() = 0;
// virtual size_t get_params_mem_size() = 0;
// virtual size_t get_params_num() = 0;
// virtual struct ggml_cgraph* get_ggml_cgraph() = 0;
// };
#define MAX_PARAMS_TENSOR_NUM 10240
#define MAX_GRAPH_SIZE 10240
struct GGMLModule {
protected:
typedef std::function<struct ggml_cgraph*()> get_graph_cb_t;
std::string name = "ggml module";
struct ggml_context* params_ctx = NULL;
size_t params_buffer_size = 0;
size_t compute_buffer_size = 0;
ggml_backend_buffer_t params_buffer = NULL;
struct ggml_context* params_ctx = NULL;
ggml_backend_buffer_t params_buffer = NULL;
struct ggml_context* compute_ctx = NULL;
ggml_backend_buffer_t compute_buffer = NULL; // for compute
struct ggml_allocr* compute_allocr = NULL;
ggml_type wtype = GGML_TYPE_F32;
ggml_backend_t backend = NULL;
virtual size_t calculate_mem_size() = 0;
virtual size_t get_num_tensors() = 0;
bool alloc_params_buffer(ggml_backend_t backend_, ggml_type wtype_ = GGML_TYPE_F32) {
backend = backend_;
wtype = wtype_;
params_buffer_size = 4 * 1024 * 1024; // 10 MB, for padding
params_buffer_size += calculate_mem_size();
size_t num_tensors = get_num_tensors();
LOG_DEBUG("%s params backend buffer size = % 6.2f MB (%i tensors)",
name.c_str(), params_buffer_size / (1024.0 * 1024.0), num_tensors);
void alloc_params_ctx() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(num_tensors * ggml_tensor_overhead()) + 1 * 1024 * 1024;
params.mem_size = static_cast<size_t>(MAX_PARAMS_TENSOR_NUM * ggml_tensor_overhead());
params.mem_buffer = NULL;
params.no_alloc = true;
// LOG_DEBUG("mem_size %u ", params.mem_size);
params_ctx = ggml_init(params);
if (!params_ctx) {
LOG_ERROR("ggml_init() failed");
return false;
}
params_buffer = ggml_backend_alloc_buffer(backend, params_buffer_size);
return true;
GGML_ASSERT(params_ctx != NULL);
}
void free_params_buffer() {
void free_params_ctx() {
if (params_ctx != NULL) {
ggml_free(params_ctx);
params_ctx = NULL;
}
}
void alloc_compute_ctx() {
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(ggml_tensor_overhead() * MAX_GRAPH_SIZE + ggml_graph_overhead());
params.mem_buffer = NULL;
params.no_alloc = true;
compute_ctx = ggml_init(params);
GGML_ASSERT(compute_ctx != NULL);
}
void free_compute_ctx() {
if (compute_ctx != NULL) {
ggml_free(compute_ctx);
compute_ctx = NULL;
}
}
void alloc_compute_buffer(get_graph_cb_t get_graph) {
// alignment required by the backend
compute_allocr = ggml_allocr_new_measure_from_backend(backend);
reset_compute_ctx();
struct ggml_cgraph* gf = get_graph();
// compute the required memory
size_t compute_buffer_size = ggml_allocr_alloc_graph(compute_allocr, gf) + 1024 * 1024;
// recreate the allocator with the required memory
ggml_allocr_free(compute_allocr);
LOG_DEBUG("%s compute buffer size: %.2f MB", get_desc().c_str(), compute_buffer_size / 1024.0 / 1024.0);
compute_buffer = ggml_backend_alloc_buffer(backend, compute_buffer_size);
compute_allocr = ggml_allocr_new_from_buffer(compute_buffer);
}
public:
virtual size_t get_params_mem_size() = 0;
virtual size_t get_params_num() = 0;
virtual std::string get_desc() = 0;
GGMLModule(ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32)
: backend(backend), wtype(wtype) {
alloc_params_ctx();
}
virtual ~GGMLModule() {
free_params_buffer();
free_compute_buffer();
free_params_ctx();
free_compute_ctx();
}
void reset_compute_ctx() {
free_compute_ctx();
alloc_compute_ctx();
}
void reset_compute_allocr(get_graph_cb_t get_graph) {
if (compute_allocr != NULL) {
ggml_allocr_reset(compute_allocr);
} else {
alloc_compute_buffer(get_graph);
}
}
bool alloc_params_buffer() {
size_t params_buffer_size = 10 * 1024 * 1024; // 10 MB, for padding
params_buffer_size += get_params_mem_size();
size_t num_tensors = get_params_num();
LOG_DEBUG("%s params backend buffer size = % 6.2f MB (%i tensors)",
get_desc().c_str(), params_buffer_size / (1024.0 * 1024.0), num_tensors);
params_buffer = ggml_backend_alloc_buffer(backend, params_buffer_size);
ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);
// alloc all tensors linked to params_ctx
for (struct ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != NULL; t = ggml_get_next_tensor(params_ctx, t)) {
if (t->data == NULL) {
ggml_allocr_alloc(alloc, t);
}
}
ggml_allocr_free(alloc);
return true;
}
void free_params_buffer() {
if (params_buffer != NULL) {
ggml_backend_buffer_free(params_buffer);
params_buffer = NULL;
}
}
~GGMLModule() {
free_params_buffer();
}
void alloc_compute_buffer(get_graph_cb_t get_graph) {
if (compute_buffer_size == 0) {
// alignment required by the backend
compute_allocr = ggml_allocr_new_measure_from_backend(backend);
struct ggml_cgraph* gf = get_graph();
// compute the required memory
compute_buffer_size = ggml_allocr_alloc_graph(compute_allocr, gf) + 1024 * 1024;
// recreate the allocator with the required memory
void free_compute_buffer() {
if (compute_allocr != NULL) {
ggml_allocr_free(compute_allocr);
LOG_DEBUG("%s compute buffer size: %.2f MB", name.c_str(), compute_buffer_size / 1024.0 / 1024.0);
compute_allocr = NULL;
}
if (compute_buffer != NULL) {
ggml_backend_buffer_free(compute_buffer);
compute_buffer = NULL;
}
compute_buffer = ggml_backend_alloc_buffer(backend, compute_buffer_size);
compute_allocr = ggml_allocr_new_from_buffer(compute_buffer);
}
void compute(get_graph_cb_t get_graph, int n_threads, struct ggml_tensor* output = NULL) {
ggml_allocr_reset(compute_allocr);
struct ggml_tensor* to_backend(struct ggml_tensor* tensor) {
GGML_ASSERT(compute_ctx != NULL);
if (tensor == NULL) {
return NULL;
}
// it's performing a compute, check if backend isn't cpu
if (!ggml_backend_is_cpu(backend)) {
// pass input tensors to gpu memory
auto backend_tensor = ggml_dup_tensor(compute_ctx, tensor);
ggml_allocr_alloc(compute_allocr, backend_tensor);
// pass data to device backend
if (!ggml_allocr_is_measure(compute_allocr)) {
ggml_backend_tensor_set(backend_tensor, tensor->data, 0, ggml_nbytes(tensor));
}
return backend_tensor;
} else {
return tensor;
}
}
void compute(get_graph_cb_t get_graph,
int n_threads,
bool free_compute_buffer_immediately = true,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
reset_compute_allocr(get_graph);
reset_compute_ctx();
struct ggml_cgraph* gf = get_graph();
ggml_allocr_alloc_graph(compute_allocr, gf);
@ -631,15 +870,368 @@ struct GGMLModule {
#endif
if (output != NULL) {
ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], output->data, 0, ggml_nbytes(output));
auto result = gf->nodes[gf->n_nodes - 1];
if (*output == NULL && output_ctx != NULL) {
*output = ggml_dup_tensor(output_ctx, result);
}
if (*output != NULL) {
ggml_backend_tensor_get_and_sync(backend, result, (*output)->data, 0, ggml_nbytes(*output));
}
}
if (free_compute_buffer_immediately) {
free_compute_buffer();
}
}
};
class GGMLBlock {
private:
static char temp_buffer[1024 * 1024 * 10];
ggml_context* get_temp_ctx() {
struct ggml_init_params params;
params.mem_size = sizeof(temp_buffer);
params.mem_buffer = temp_buffer;
params.no_alloc = true;
ggml_context* temp_ctx = ggml_init(params);
GGML_ASSERT(temp_ctx != NULL);
return temp_ctx;
}
protected:
typedef std::unordered_map<std::string, struct ggml_tensor*> ParameterMap;
typedef std::unordered_map<std::string, std::shared_ptr<GGMLBlock>> GGMLBlockMap;
GGMLBlockMap blocks;
ParameterMap params;
void init_blocks(struct ggml_context* ctx, ggml_type wtype) {
for (auto& pair : blocks) {
auto& block = pair.second;
block->init(ctx, wtype);
}
}
void free_compute_buffer() {
ggml_allocr_free(compute_allocr);
ggml_backend_buffer_free(compute_buffer);
compute_allocr = NULL;
compute_buffer_size = 0;
virtual void init_params(struct ggml_context* ctx, ggml_type wtype) {}
public:
void init(struct ggml_context* ctx, ggml_type wtype) {
init_blocks(ctx, wtype);
init_params(ctx, wtype);
}
std::tuple<size_t, size_t> get_params_info(ggml_type wtype) {
ggml_context* temp_ctx = get_temp_ctx();
init(temp_ctx, wtype);
size_t num_tensors = get_params_num();
size_t mem_size = get_params_mem_size();
return {num_tensors, mem_size};
}
size_t get_params_num() {
size_t num_tensors = params.size();
for (auto& pair : blocks) {
auto& block = pair.second;
num_tensors += block->get_params_num();
}
return num_tensors;
};
size_t get_params_mem_size() {
size_t mem_size = 0;
for (auto& pair : blocks) {
auto& block = pair.second;
mem_size += block->get_params_mem_size();
}
for (auto& pair : params) {
mem_size += ggml_nbytes(pair.second);
}
return mem_size;
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, std::string prefix = "") {
if (prefix.size() > 0) {
prefix = prefix + ".";
}
for (auto& pair : blocks) {
auto& block = pair.second;
block->get_param_tensors(tensors, prefix + pair.first);
}
for (auto& pair : params) {
struct ggml_tensor* param = pair.second;
tensors[prefix + pair.first] = pair.second;
}
}
};
class UnaryBlock : public GGMLBlock {
public:
virtual struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) = 0;
};
class Linear : public UnaryBlock {
protected:
int64_t in_features;
int64_t out_features;
bool bias;
void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
if (bias) {
params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_features);
}
}
public:
Linear(int64_t in_features,
int64_t out_features,
bool bias = true)
: in_features(in_features),
out_features(out_features),
bias(bias) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = NULL;
if (bias) {
b = params["bias"];
}
return ggml_nn_linear(ctx, x, w, b);
}
};
class Conv2d : public UnaryBlock {
protected:
int64_t in_channels;
int64_t out_channels;
std::pair<int, int> kernel_size;
std::pair<int, int> stride;
std::pair<int, int> padding;
std::pair<int, int> dilation;
bool bias;
void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kernel_size.second, kernel_size.first, in_channels, out_channels);
if (bias) {
params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
}
}
public:
Conv2d(int64_t in_channels,
int64_t out_channels,
std::pair<int, int> kernel_size,
std::pair<int, int> stride = {1, 1},
std::pair<int, int> padding = {0, 0},
std::pair<int, int> dilation = {1, 1},
bool bias = true)
: in_channels(in_channels),
out_channels(out_channels),
kernel_size(kernel_size),
stride(stride),
padding(padding),
dilation(dilation),
bias(bias) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = NULL;
if (bias) {
b = params["bias"];
}
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
}
};
class Conv3dnx1x1 : public UnaryBlock {
protected:
int64_t in_channels;
int64_t out_channels;
int64_t kernel_size;
int64_t stride;
int64_t padding;
int64_t dilation;
bool bias;
void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, kernel_size, in_channels, out_channels); // 5d => 4d
if (bias) {
params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
}
}
public:
Conv3dnx1x1(int64_t in_channels,
int64_t out_channels,
int64_t kernel_size,
int64_t stride = 1,
int64_t padding = 0,
int64_t dilation = 1,
bool bias = true)
: in_channels(in_channels),
out_channels(out_channels),
kernel_size(kernel_size),
stride(stride),
padding(padding),
dilation(dilation),
bias(bias) {}
// x: [N, IC, ID, IH*IW]
// result: [N, OC, OD, OH*OW]
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = NULL;
if (bias) {
b = params["bias"];
}
return ggml_nn_conv_3d_nx1x1(ctx, x, w, b, stride, padding, dilation);
}
};
class LayerNorm : public UnaryBlock {
protected:
int64_t normalized_shape;
float eps;
bool elementwise_affine;
bool bias;
void init_params(struct ggml_context* ctx, ggml_type wtype) {
if (elementwise_affine) {
params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, normalized_shape);
if (bias) {
params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, normalized_shape);
}
}
}
public:
LayerNorm(int64_t normalized_shape,
float eps = 1e-05f,
bool elementwise_affine = true,
bool bias = true)
: normalized_shape(normalized_shape),
eps(eps),
elementwise_affine(elementwise_affine),
bias(bias) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = NULL;
struct ggml_tensor* b = NULL;
if (elementwise_affine) {
w = params["weight"];
if (bias) {
b = params["bias"];
}
}
return ggml_nn_layer_norm(ctx, x, w, b, eps);
}
};
class GroupNorm : public GGMLBlock {
protected:
int64_t num_groups;
int64_t num_channels;
float eps;
bool affine;
void init_params(struct ggml_context* ctx, ggml_type wtype) {
if (affine) {
params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_channels);
params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_channels);
}
}
public:
GroupNorm(int64_t num_groups,
int64_t num_channels,
float eps = 1e-05f,
bool affine = true)
: num_groups(num_groups),
num_channels(num_channels),
eps(eps),
affine(affine) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = NULL;
struct ggml_tensor* b = NULL;
if (affine) {
w = params["weight"];
b = params["bias"];
}
return ggml_nn_group_norm(ctx, x, w, b, num_groups);
}
};
class GroupNorm32 : public GroupNorm {
public:
GroupNorm32(int64_t num_channels)
: GroupNorm(32, num_channels, 1e-06f) {}
};
class MultiheadAttention : public GGMLBlock {
protected:
int64_t embed_dim;
int64_t n_head;
bool bias;
bool mask;
public:
MultiheadAttention(int64_t embed_dim,
int64_t n_head,
bool bias = true)
: embed_dim(embed_dim),
n_head(n_head),
bias(bias) {
blocks["q_proj"] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, bias));
blocks["k_proj"] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, bias));
blocks["v_proj"] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, bias));
blocks["out_proj"] = std::shared_ptr<GGMLBlock>(new Linear(embed_dim, embed_dim, bias));
}
// x: [N, n_token, embed_dim]
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, bool mask = false) {
auto q_proj = std::dynamic_pointer_cast<Linear>(blocks["q_proj"]);
auto k_proj = std::dynamic_pointer_cast<Linear>(blocks["k_proj"]);
auto v_proj = std::dynamic_pointer_cast<Linear>(blocks["v_proj"]);
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks["out_proj"]);
int64_t N = x->ne[2];
int64_t n_token = x->ne[1];
int64_t d_head = embed_dim / n_head;
struct ggml_tensor* q = q_proj->forward(ctx, x);
q = ggml_reshape_4d(ctx, q, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head]
q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, n_token, d_head]
q = ggml_reshape_3d(ctx, q, d_head, n_token, n_head * N); // [N * n_head, n_token, d_head]
struct ggml_tensor* k = k_proj->forward(ctx, x);
k = ggml_reshape_4d(ctx, k, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head]
k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_token, d_head]
k = ggml_reshape_3d(ctx, k, d_head, n_token, n_head); // [N * n_head, n_token, d_head]
struct ggml_tensor* v = v_proj->forward(ctx, x);
v = ggml_reshape_4d(ctx, v, d_head, n_head, n_token, N); // [N, n_token, n_head, d_head]
v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, n_token]
v = ggml_reshape_3d(ctx, v, n_token, d_head, n_head * N); // [N * n_head, d_head, n_token]
struct ggml_tensor* kqv = ggml_nn_attention(ctx, q, k, v, mask); // [N * n_head, n_token, d_head]
kqv = ggml_reshape_4d(ctx, kqv, d_head, n_token, n_head, N);
kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head]
x = ggml_reshape_2d(ctx, kqv, d_head * n_head, n_token * N); // [N * n_token, d_head * n_head]
x = out_proj->forward(ctx, x);
return x;
}
};

View File

@ -12,32 +12,35 @@ struct LoraModel : public GGMLModule {
ModelLoader model_loader;
bool load_failed = false;
LoraModel(const std::string file_path = "")
: file_path(file_path) {
name = "lora";
LoraModel(ggml_backend_t backend,
ggml_type wtype,
const std::string file_path = "")
: file_path(file_path), GGMLModule(backend, wtype) {
if (!model_loader.init_from_file(file_path)) {
load_failed = true;
}
}
size_t get_num_tensors() {
std::string get_desc() {
return "lora";
}
size_t get_params_num() {
return LORA_GRAPH_SIZE;
}
size_t calculate_mem_size() {
return model_loader.cal_mem_size(NULL);
size_t get_params_mem_size() {
return model_loader.get_params_mem_size(NULL);
}
bool load_from_file(ggml_backend_t backend) {
if (!alloc_params_buffer(backend)) {
return false;
}
bool load_from_file() {
LOG_INFO("loading LoRA from '%s'", file_path.c_str());
if (load_failed) {
LOG_ERROR("init lora model loader from file failed: '%s'", file_path.c_str());
return false;
}
alloc_params_buffer();
ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);
@ -61,20 +64,7 @@ struct LoraModel : public GGMLModule {
}
struct ggml_cgraph* build_graph(std::map<std::string, struct ggml_tensor*> model_tensors) {
// make a graph to compute all lora, expected lora and models tensors are in the same backend
// since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
static size_t buf_size = ggml_tensor_overhead() * LORA_GRAPH_SIZE + ggml_graph_overhead();
static std::vector<uint8_t> buf(buf_size);
struct ggml_init_params params = {
/*.mem_size =*/buf_size,
/*.mem_buffer =*/buf.data(),
/*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
};
// LOG_DEBUG("mem_size %u ", params.mem_size);
struct ggml_context* ctx0 = ggml_init(params);
struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, LORA_GRAPH_SIZE, false);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false);
std::set<std::string> applied_lora_tensors;
for (auto it : model_tensors) {
@ -125,27 +115,27 @@ struct LoraModel : public GGMLModule {
// flat lora tensors to multiply it
int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1];
lora_up = ggml_reshape_2d(ctx0, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
lora_up = ggml_reshape_2d(compute_ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows);
int64_t lora_down_rows = lora_down->ne[ggml_n_dims(lora_down) - 1];
lora_down = ggml_reshape_2d(ctx0, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
lora_down = ggml_reshape_2d(compute_ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows);
// ggml_mul_mat requires tensor b transposed
lora_down = ggml_cont(ctx0, ggml_transpose(ctx0, lora_down));
struct ggml_tensor* updown = ggml_mul_mat(ctx0, lora_up, lora_down);
updown = ggml_cont(ctx0, ggml_transpose(ctx0, updown));
updown = ggml_reshape(ctx0, updown, weight);
lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down));
struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down);
updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown));
updown = ggml_reshape(compute_ctx, updown, weight);
GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
updown = ggml_scale_inplace(ctx0, updown, scale_value);
updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
ggml_tensor* final_weight;
// if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
// final_weight = ggml_new_tensor(ctx0, GGML_TYPE_F32, weight->n_dims, weight->ne);
// final_weight = ggml_cpy_inplace(ctx0, weight, final_weight);
// final_weight = ggml_add_inplace(ctx0, final_weight, updown);
// final_weight = ggml_cpy_inplace(ctx0, final_weight, weight);
// final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, weight->n_dims, weight->ne);
// final_weight = ggml_cpy_inplace(compute_ctx, weight, final_weight);
// final_weight = ggml_add_inplace(compute_ctx, final_weight, updown);
// final_weight = ggml_cpy_inplace(compute_ctx, final_weight, weight);
// } else {
// final_weight = ggml_add_inplace(ctx0, weight, updown);
// final_weight = ggml_add_inplace(compute_ctx, weight, updown);
// }
final_weight = ggml_add_inplace(ctx0, weight, updown); // apply directly
final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly
ggml_build_forward_expand(gf, final_weight);
}
@ -158,20 +148,11 @@ struct LoraModel : public GGMLModule {
return gf;
}
void alloc_compute_buffer(std::map<std::string, struct ggml_tensor*> model_tensors) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(model_tensors);
};
GGMLModule::alloc_compute_buffer(get_graph);
}
void apply(std::map<std::string, struct ggml_tensor*> model_tensors, int n_threads) {
alloc_compute_buffer(model_tensors);
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(model_tensors);
};
GGMLModule::compute(get_graph, n_threads);
GGMLModule::compute(get_graph, n_threads, true);
}
};

View File

@ -108,6 +108,14 @@ std::unordered_map<std::string, std::string> open_clip_to_hf_clip_model = {
{"model.positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"},
{"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"},
{"model.text_projection", "transformer.text_model.text_projection"},
{"model.visual.class_embedding", "transformer.visual_model.embeddings.class_embedding"},
{"model.visual.conv1.weight", "transformer.visual_model.embeddings.patch_embedding.weight"},
{"model.visual.ln_post.bias", "transformer.visual_model.post_layernorm.bias"},
{"model.visual.ln_post.weight", "transformer.visual_model.post_layernorm.weight"},
{"model.visual.ln_pre.bias", "transformer.visual_model.pre_layernorm.bias"},
{"model.visual.ln_pre.weight", "transformer.visual_model.pre_layernorm.weight"},
{"model.visual.positional_embedding", "transformer.visual_model.embeddings.position_embedding.weight"},
{"model.visual.proj", "transformer.visual_model.visual_projection"},
};
std::unordered_map<std::string, std::string> open_clip_to_hk_clip_resblock = {
@ -137,7 +145,10 @@ std::unordered_map<std::string, std::string> vae_decoder_name_map = {
std::string convert_open_clip_to_hf_clip(const std::string& name) {
std::string new_name = name;
std::string prefix;
if (starts_with(new_name, "conditioner.embedders.0.")) {
if (starts_with(new_name, "conditioner.embedders.0.open_clip.")) {
prefix = "cond_stage_model.";
new_name = new_name.substr(strlen("conditioner.embedders.0.open_clip."));
} else if (starts_with(new_name, "conditioner.embedders.0.")) {
prefix = "cond_stage_model.";
new_name = new_name.substr(strlen("conditioner.embedders.0."));
} else if (starts_with(new_name, "conditioner.embedders.1.")) {
@ -149,25 +160,35 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) {
} else {
return new_name;
}
std::string open_clip_resblock_prefix = "model.transformer.resblocks.";
std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers.";
if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) {
new_name = open_clip_to_hf_clip_model[new_name];
}
if (new_name.find(open_clip_resblock_prefix) == 0) {
std::string remain = new_name.substr(open_clip_resblock_prefix.length());
std::string idx = remain.substr(0, remain.find("."));
std::string suffix = remain.substr(idx.length() + 1);
std::string open_clip_resblock_prefix = "model.transformer.resblocks.";
std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers.";
if (suffix == "attn.in_proj_weight" || suffix == "attn.in_proj_bias") {
new_name = hf_clip_resblock_prefix + idx + "." + suffix;
} else if (open_clip_to_hk_clip_resblock.find(suffix) != open_clip_to_hk_clip_resblock.end()) {
std::string new_suffix = open_clip_to_hk_clip_resblock[suffix];
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix;
auto replace_suffix = [&]() {
if (new_name.find(open_clip_resblock_prefix) == 0) {
std::string remain = new_name.substr(open_clip_resblock_prefix.length());
std::string idx = remain.substr(0, remain.find("."));
std::string suffix = remain.substr(idx.length() + 1);
if (suffix == "attn.in_proj_weight" || suffix == "attn.in_proj_bias") {
new_name = hf_clip_resblock_prefix + idx + "." + suffix;
} else if (open_clip_to_hk_clip_resblock.find(suffix) != open_clip_to_hk_clip_resblock.end()) {
std::string new_suffix = open_clip_to_hk_clip_resblock[suffix];
new_name = hf_clip_resblock_prefix + idx + "." + new_suffix;
}
}
}
};
replace_suffix();
open_clip_resblock_prefix = "model.visual.transformer.resblocks.";
hf_clip_resblock_prefix = "transformer.visual_model.encoder.layers.";
replace_suffix();
return prefix + new_name;
}
@ -437,7 +458,7 @@ void preprocess_tensor(TensorStorage tensor_storage,
tensor_storage.name = new_name;
if (new_name.find("transformer.text_model.encoder.layers.") != std::string::npos &&
if (new_name.find("cond_stage_model") != std::string::npos &&
ends_with(new_name, "attn.in_proj_weight")) {
size_t prefix_size = new_name.find("attn.in_proj_weight");
std::string prefix = new_name.substr(0, prefix_size);
@ -449,7 +470,7 @@ void preprocess_tensor(TensorStorage tensor_storage,
processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end());
} else if (new_name.find("transformer.text_model.encoder.layers.") != std::string::npos &&
} else if (new_name.find("cond_stage_model") != std::string::npos &&
ends_with(new_name, "attn.in_proj_bias")) {
size_t prefix_size = new_name.find("attn.in_proj_bias");
std::string prefix = new_name.substr(0, prefix_size);
@ -778,17 +799,26 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
return false;
}
if (shape.size() > 4) {
if (shape.size() > SD_MAX_DIMS) {
LOG_ERROR("invalid tensor '%s'", name.c_str());
return false;
}
int n_dims = (int)shape.size();
int64_t ne[4] = {1, 1, 1, 1};
int n_dims = (int)shape.size();
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
for (int i = 0; i < n_dims; i++) {
ne[i] = shape[i].get<int64_t>();
}
if (n_dims == 5) {
if (ne[3] == 1 && ne[4] == 1) {
n_dims = 4;
} else {
LOG_ERROR("invalid tensor '%s'", name.c_str());
return false;
}
}
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
tensor_storage.reverse_ne();
@ -803,6 +833,8 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
}
tensor_storages.push_back(tensor_storage);
// LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
}
return true;
@ -946,7 +978,7 @@ struct PickleTensorReader {
phase = READ_NAME;
}
} else if (phase == READ_DIMENS) {
if (tensor_storage.n_dims + 1 > 4) { // too many dimens
if (tensor_storage.n_dims + 1 > SD_MAX_DIMS) { // too many dimens
phase = READ_NAME;
tensor_storage.n_dims = 0;
}
@ -1181,7 +1213,6 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
}
SDVersion ModelLoader::get_sd_version() {
// return VERSION_1_x;
TensorStorage token_embedding_weight;
for (auto& tensor_storage : tensor_storages) {
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
@ -1190,6 +1221,10 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
return VERSION_XL;
}
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
return VERSION_SVD;
}
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@ -1317,7 +1352,6 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
if (tensor_storage.file_index != file_index) {
continue;
}
// LOG_DEBUG("%s", tensor_storage.name.c_str());
ggml_tensor* dst_tensor = NULL;
@ -1395,15 +1429,19 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
std::set<std::string> tensor_names_in_file;
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
const std::string& name = tensor_storage.name;
// LOG_DEBUG("%s", tensor_storage.to_string().c_str());
tensor_names_in_file.insert(name);
struct ggml_tensor* real;
if (tensors.find(name) != tensors.end()) {
real = tensors[name];
} else {
if (ignore_tensors.find(name) == ignore_tensors.end()) {
LOG_WARN("unknown tensor '%s' in model file", name.c_str());
for (auto& ignore_tensor : ignore_tensors) {
if (starts_with(name, ignore_tensor)) {
return true;
}
}
LOG_INFO("unknown tensor '%s' in model file", tensor_storage.to_string().c_str());
return true;
}
@ -1462,7 +1500,7 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
auto backend = ggml_backend_cpu_init();
size_t mem_size = 1 * 1024 * 1024; // for padding
mem_size += tensor_storages.size() * ggml_tensor_overhead();
mem_size += cal_mem_size(backend, type);
mem_size += get_params_mem_size(backend, type);
LOG_INFO("model tensors mem size: %.2fMB", mem_size / 1024.f / 1024.f);
ggml_context* ggml_ctx = ggml_init({mem_size, NULL, false});
@ -1512,7 +1550,7 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type
return success;
}
int64_t ModelLoader::cal_mem_size(ggml_backend_t backend, ggml_type type) {
int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) {
size_t alignment = 128;
if (backend != NULL) {
alignment = ggml_backend_get_alignment(backend);

41
model.h
View File

@ -5,6 +5,7 @@
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <vector>
@ -13,19 +14,22 @@
#include "json.hpp"
#include "zip.h"
#define SD_MAX_DIMS 5
enum SDVersion {
VERSION_1_x,
VERSION_2_x,
VERSION_XL,
VERSION_SVD,
VERSION_COUNT,
};
struct TensorStorage {
std::string name;
ggml_type type = GGML_TYPE_F32;
bool is_bf16 = false;
int64_t ne[4] = {1, 1, 1, 1};
int n_dims = 0;
ggml_type type = GGML_TYPE_F32;
bool is_bf16 = false;
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
int n_dims = 0;
size_t file_index = 0;
int index_in_zip = -1; // >= means stored in a zip file
@ -41,7 +45,11 @@ struct TensorStorage {
}
int64_t nelements() const {
return ne[0] * ne[1] * ne[2] * ne[3];
int64_t n = 1;
for (int i = 0; i < SD_MAX_DIMS; i++) {
n *= ne[i];
}
return n;
}
int64_t nbytes() const {
@ -69,6 +77,7 @@ struct TensorStorage {
std::vector<TensorStorage> chunk(size_t n) {
std::vector<TensorStorage> chunks;
size_t chunk_size = nbytes_to_read() / n;
// printf("%d/%d\n", chunk_size, nbytes_to_read());
reverse_ne();
for (int i = 0; i < n; i++) {
TensorStorage chunk_i = *this;
@ -82,7 +91,7 @@ struct TensorStorage {
}
void reverse_ne() {
int64_t new_ne[4] = {1, 1, 1, 1};
int64_t new_ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
for (int i = 0; i < n_dims; i++) {
new_ne[i] = ne[n_dims - 1 - i];
}
@ -90,6 +99,24 @@ struct TensorStorage {
ne[i] = new_ne[i];
}
}
std::string to_string() const {
std::stringstream ss;
const char* type_name = ggml_type_name(type);
if (is_bf16) {
type_name = "bf16";
}
ss << name << " | " << type_name << " | ";
ss << n_dims << " [";
for (int i = 0; i < SD_MAX_DIMS; i++) {
ss << ne[i];
if (i != SD_MAX_DIMS - 1) {
ss << ", ";
}
}
ss << "]";
return ss.str();
}
};
typedef std::function<bool(const TensorStorage&, ggml_tensor**)> on_new_tensor_cb_t;
@ -121,7 +148,7 @@ public:
ggml_backend_t backend,
std::set<std::string> ignore_tensors = {});
bool save_to_gguf_file(const std::string& file_path, ggml_type type);
int64_t cal_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
~ModelLoader() = default;
};
#endif // __MODEL_H__

View File

@ -19,6 +19,7 @@ const char* model_version_to_str[] = {
"1.x",
"2.x",
"XL",
"SVD",
};
const char* sampling_methods_str[] = {
@ -32,6 +33,8 @@ const char* sampling_methods_str[] = {
"LCM",
};
char GGMLBlock::temp_buffer[1024 * 1024 * 10];
/*================================================== Helper Functions ================================================*/
void calculate_alphas_cumprod(float* alphas_cumprod,
@ -53,6 +56,9 @@ void calculate_alphas_cumprod(float* alphas_cumprod,
class StableDiffusionGGML {
public:
ggml_backend_t backend = NULL; // general backend
ggml_type model_data_type = GGML_TYPE_COUNT;
SDVersion version;
bool vae_decode_only = false;
bool free_params_immediately = false;
@ -61,9 +67,14 @@ public:
int n_threads = -1;
float scale_factor = 0.18215f;
FrozenCLIPEmbedderWithCustomWords cond_stage_model;
UNetModel diffusion_model;
AutoEncoderKL first_stage_model;
std::shared_ptr<FrozenCLIPEmbedderWithCustomWords> cond_stage_model;
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd
std::shared_ptr<UNetModel> diffusion_model;
std::shared_ptr<AutoEncoderKL> first_stage_model;
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
std::shared_ptr<ControlNet> control_net;
std::string taesd_path;
bool use_tiny_autoencoder = false;
bool vae_tiling = false;
@ -72,16 +83,8 @@ public:
std::string lora_model_dir;
// lora_name => multiplier
std::unordered_map<std::string, float> curr_lora_state;
std::map<std::string, LoraModel> loras;
std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
ggml_backend_t backend = NULL; // general backend
ggml_type model_data_type = GGML_TYPE_COUNT;
TinyAutoEncoder tae_first_stage;
std::string taesd_path;
ControlNet control_net;
StableDiffusionGGML() = default;
@ -94,8 +97,6 @@ public:
vae_decode_only(vae_decode_only),
free_params_immediately(free_params_immediately),
lora_model_dir(lora_model_dir) {
first_stage_model.decode_only = vae_decode_only;
tae_first_stage.decode_only = vae_decode_only;
if (rng_type == STD_DEFAULT_RNG) {
rng = std::make_shared<STDDefaultRNG>();
} else if (rng_type == CUDA_RNG) {
@ -160,12 +161,6 @@ public:
LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str());
return false;
}
if (version == VERSION_XL) {
scale_factor = 0.13025f;
}
cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version);
diffusion_model = UNetModel(version);
LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]);
if (wtype == GGML_TYPE_COUNT) {
model_data_type = model_loader.get_sd_wtype();
@ -173,52 +168,73 @@ public:
model_data_type = wtype;
}
LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type));
LOG_DEBUG("loading vocab");
std::string merges_utf8_str = model_loader.load_merges();
if (merges_utf8_str.size() == 0) {
LOG_ERROR("get merges failed: '%s'", model_path.c_str());
return false;
}
cond_stage_model.tokenizer.load_from_merges(merges_utf8_str);
// create the ggml context for network params
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));
if (
!cond_stage_model.alloc_params_buffer(backend, model_data_type) ||
!diffusion_model.alloc_params_buffer(backend, model_data_type)) {
return false;
}
cond_stage_model.text_model.embd_dir = embeddings_path;
ggml_type vae_type = model_data_type;
if (version == VERSION_XL) {
vae_type = GGML_TYPE_F32; // avoid nan, not work...
scale_factor = 0.13025f;
if (vae_path.size() == 0 && taesd_path.size() == 0) {
LOG_WARN("!!!It looks like you are using SDXL model. "
"If you find that the generated images are completely black, "
"try specifying SDXL VAE FP16 Fix with the --vae parameter. "
"You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors");
}
}
if (!use_tiny_autoencoder && !first_stage_model.alloc_params_buffer(backend, vae_type)) {
return false;
}
if (version == VERSION_SVD) {
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend, model_data_type);
clip_vision->alloc_params_buffer();
clip_vision->get_param_tensors(tensors, "cond_stage_model.");
LOG_DEBUG("preparing memory for the weights");
// prepare memory for the weights
{
// cond_stage_model(FrozenCLIPEmbedder)
cond_stage_model.init_params();
cond_stage_model.map_by_name(tensors, "cond_stage_model.");
diffusion_model = std::make_shared<UNetModel>(backend, model_data_type, version);
diffusion_model->alloc_params_buffer();
diffusion_model->get_param_tensors(tensors, "model.diffusion_model");
// diffusion_model(UNetModel)
diffusion_model.init_params();
diffusion_model.map_by_name(tensors, "model.diffusion_model.");
first_stage_model = std::make_shared<AutoEncoderKL>(backend, model_data_type, vae_decode_only, true);
LOG_DEBUG("vae_decode_only %d", vae_decode_only);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(backend, model_data_type, version);
cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors, "cond_stage_model.");
cond_stage_model->embd_dir = embeddings_path;
diffusion_model = std::make_shared<UNetModel>(backend, model_data_type, version);
diffusion_model->alloc_params_buffer();
diffusion_model->get_param_tensors(tensors, "model.diffusion_model");
ggml_type vae_type = model_data_type;
if (version == VERSION_XL) {
vae_type = GGML_TYPE_F32; // avoid nan, not work...
}
if (!use_tiny_autoencoder) {
// firest_stage_model(AutoEncoderKL)
first_stage_model.init_params();
first_stage_model = std::make_shared<AutoEncoderKL>(backend, vae_type, vae_decode_only);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else {
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_data_type, vae_decode_only);
}
first_stage_model.map_by_name(tensors, "first_stage_model.");
if (control_net_path.size() > 0) {
ggml_backend_t cn_backend = NULL;
if (control_net_cpu && !ggml_backend_is_cpu(backend)) {
LOG_DEBUG("ControlNet: Using CPU backend");
cn_backend = ggml_backend_cpu_init();
} else {
cn_backend = backend;
}
control_net = std::make_shared<ControlNet>(cn_backend, model_data_type, version);
}
LOG_DEBUG("loading vocab");
std::string merges_utf8_str = model_loader.load_merges();
if (merges_utf8_str.size() == 0) {
LOG_ERROR("get merges failed: '%s'", model_path.c_str());
return false;
}
cond_stage_model->tokenizer.load_from_merges(merges_utf8_str);
}
struct ggml_init_params params;
@ -227,10 +243,7 @@ public:
params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size);
struct ggml_context* ctx = ggml_init(params); // for alphas_cumprod and is_using_v_parameterization check
if (!ctx) {
LOG_ERROR("ggml_init() failed");
return false;
}
GGML_ASSERT(ctx != NULL);
ggml_tensor* alphas_cumprod_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, TIMESTEPS);
calculate_alphas_cumprod((float*)alphas_cumprod_tensor->data);
@ -238,25 +251,19 @@ public:
LOG_DEBUG("loading weights");
int64_t t0 = ggml_time_ms();
std::map<std::string, struct ggml_tensor*> tensors_need_to_load;
std::set<std::string> ignore_tensors;
tensors_need_to_load["alphas_cumprod"] = alphas_cumprod_tensor;
for (auto& pair : tensors) {
const std::string& name = pair.first;
if (use_tiny_autoencoder && starts_with(name, "first_stage_model.")) {
ignore_tensors.insert(name);
continue;
}
if (vae_decode_only && (starts_with(name, "first_stage_model.encoder") || starts_with(name, "first_stage_model.quant"))) {
ignore_tensors.insert(name);
continue;
}
tensors_need_to_load.insert(pair);
tensors["alphas_cumprod"] = alphas_cumprod_tensor;
if (use_tiny_autoencoder) {
ignore_tensors.insert("first_stage_model.");
}
bool success = model_loader.load_tensors(tensors_need_to_load, backend, ignore_tensors);
if (vae_decode_only) {
ignore_tensors.insert("first_stage_model.encoder");
ignore_tensors.insert("first_stage_model.quant");
}
if (version == VERSION_SVD) {
ignore_tensors.insert("conditioner.embedders.3");
}
bool success = model_loader.load_tensors(tensors, backend, ignore_tensors);
if (!success) {
LOG_ERROR("load tensors from model loader failed");
ggml_free(ctx);
@ -265,15 +272,39 @@ public:
// LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
size_t total_params_size =
cond_stage_model.params_buffer_size +
diffusion_model.params_buffer_size +
first_stage_model.params_buffer_size;
LOG_INFO("total memory buffer size = %.2fMB (clip %.2fMB, unet %.2fMB, vae %.2fMB)",
total_params_size / 1024.0 / 1024.0,
cond_stage_model.params_buffer_size / 1024.0 / 1024.0,
diffusion_model.params_buffer_size / 1024.0 / 1024.0,
first_stage_model.params_buffer_size / 1024.0 / 1024.0);
if (version == VERSION_SVD) {
// diffusion_model->test();
// first_stage_model->test();
// return false;
} else {
size_t clip_params_mem_size = cond_stage_model->get_params_mem_size();
size_t unet_params_mem_size = diffusion_model->get_params_mem_size();
size_t vae_params_mem_size = 0;
if (!use_tiny_autoencoder) {
vae_params_mem_size = first_stage_model->get_params_mem_size();
} else {
if (!tae_first_stage->load_from_file(taesd_path)) {
return false;
}
vae_params_mem_size = tae_first_stage->get_params_mem_size();
}
size_t control_net_params_mem_size = 0;
if (control_net) {
if (!control_net->load_from_file(control_net_path)) {
return false;
}
control_net_params_mem_size = control_net->get_params_mem_size();
}
size_t total_params_size = clip_params_mem_size + clip_params_mem_size + clip_params_mem_size + control_net_params_mem_size;
LOG_INFO("total params memory size = %.2fMB (clip %.2fMB, unet %.2fMB, vae %.2fMB, controlnet %.2fMB)",
total_params_size / 1024.0 / 1024.0,
clip_params_mem_size / 1024.0 / 1024.0,
unet_params_mem_size / 1024.0 / 1024.0,
vae_params_mem_size / 1024.0 / 1024.0,
control_net_params_mem_size / 1024.0 / 1024.0);
}
int64_t t1 = ggml_time_ms();
LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000);
@ -283,6 +314,9 @@ public:
if (is_using_v_parameterization_for_sd2(ctx)) {
is_using_v_parameterization = true;
}
} else if (version == VERSION_SVD) {
// TODO: V_PREDICTION_EDM
is_using_v_parameterization = true;
}
if (is_using_v_parameterization) {
@ -319,23 +353,6 @@ public:
LOG_DEBUG("finished loaded file");
ggml_free(ctx);
if (control_net_path.size() > 0) {
ggml_backend_t cn_backend = NULL;
if (control_net_cpu && !ggml_backend_is_cpu(backend)) {
LOG_DEBUG("ControlNet: Using CPU backend");
cn_backend = ggml_backend_cpu_init();
} else {
cn_backend = backend;
}
if (!control_net.load_from_file(control_net_path, cn_backend, GGML_TYPE_F16 /* just f16 controlnet models */)) {
return false;
}
}
if (use_tiny_autoencoder) {
return tae_first_stage.load_from_file(taesd_path, backend);
}
return true;
}
@ -345,17 +362,11 @@ public:
struct ggml_tensor* c = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 1024, 2, 1, 1);
ggml_set_f32(c, 0.5);
struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1); // [N, ]
struct ggml_tensor* t_emb = new_timestep_embedding(work_ctx, NULL, timesteps, diffusion_model.model_channels); // [N, model_channels]
int64_t t0 = ggml_time_ms();
ggml_set_f32(timesteps, 999);
set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
std::vector<struct ggml_tensor*> controls;
diffusion_model.alloc_compute_buffer(x_t, c, controls, t_emb);
diffusion_model.compute(out, n_threads, x_t, NULL, c, controls, 1.0f, t_emb);
diffusion_model.free_compute_buffer();
std::vector<float> timesteps = {999.f}; // [N, ]
int64_t t0 = ggml_time_ms();
struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, -1, {}, 0.f, &out);
diffusion_model->free_compute_buffer();
double result = 0.f;
{
@ -387,15 +398,14 @@ public:
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
return;
}
LoraModel lora(file_path);
if (!lora.load_from_file(backend)) {
LoraModel lora(backend, model_data_type, file_path);
if (!lora.load_from_file()) {
LOG_WARN("load lora tensors from %s failed", file_path.c_str());
return;
}
lora.multiplier = multiplier;
lora.apply(tensors, n_threads);
loras[lora_name] = lora;
lora.free_params_buffer();
int64_t t1 = ggml_time_ms();
@ -438,24 +448,17 @@ public:
int width,
int height,
bool force_zero_embeddings = false) {
cond_stage_model.set_clip_skip(clip_skip);
auto tokens_and_weights = cond_stage_model.tokenize(text, true);
std::vector<int>& tokens = tokens_and_weights.first;
std::vector<float>& weights = tokens_and_weights.second;
int64_t t0 = ggml_time_ms();
struct ggml_tensor* pooled = NULL;
size_t total_hidden_size = cond_stage_model.text_model.hidden_size;
cond_stage_model->set_clip_skip(clip_skip);
auto tokens_and_weights = cond_stage_model->tokenize(text, true);
std::vector<int>& tokens = tokens_and_weights.first;
std::vector<float>& weights = tokens_and_weights.second;
int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size]
struct ggml_tensor* pooled = NULL;
cond_stage_model->compute(n_threads, tokens, false, &hidden_states, work_ctx);
if (version == VERSION_XL) {
total_hidden_size += cond_stage_model.text_model2.hidden_size;
pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, cond_stage_model.text_model2.projection_dim);
cond_stage_model->compute(n_threads, tokens, true, &pooled, work_ctx);
}
struct ggml_tensor* hidden_states = ggml_new_tensor_2d(work_ctx,
GGML_TYPE_F32,
total_hidden_size,
cond_stage_model.text_model.max_position_embeddings); // [N, n_token, hidden_size]
cond_stage_model.alloc_compute_buffer(work_ctx, (int)tokens.size());
cond_stage_model.compute(n_threads, tokens, hidden_states, pooled);
cond_stage_model.free_compute_buffer();
// if (pooled != NULL) {
// print_ggml_tensor(hidden_states);
// print_ggml_tensor(pooled);
@ -488,18 +491,17 @@ public:
ggml_tensor* vec = NULL;
if (version == VERSION_XL) {
int out_dim = 256;
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, diffusion_model.adm_in_channels);
vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, diffusion_model->unet.adm_in_channels);
// [0:1280]
size_t offset = 0;
memcpy(vec->data, pooled->data, ggml_nbytes(pooled));
offset += ggml_nbytes(pooled);
struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 2);
// original_size_as_tuple
float orig_width = (float)width;
float orig_height = (float)height;
ggml_tensor_set_f32(timesteps, orig_height, 0);
ggml_tensor_set_f32(timesteps, orig_width, 1);
float orig_width = (float)width;
float orig_height = (float)height;
std::vector<float> timesteps = {orig_height, orig_width};
ggml_tensor* embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
offset += ggml_nbytes(embed_view);
set_timestep_embedding(timesteps, embed_view, out_dim);
@ -507,18 +509,16 @@ public:
// crop_coords_top_left
float crop_coord_top = 0.f;
float crop_coord_left = 0.f;
ggml_tensor_set_f32(timesteps, crop_coord_top, 0);
ggml_tensor_set_f32(timesteps, crop_coord_left, 1);
embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
timesteps = {crop_coord_top, crop_coord_left};
embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
offset += ggml_nbytes(embed_view);
set_timestep_embedding(timesteps, embed_view, out_dim);
// print_ggml_tensor(ggml_reshape_1d(work_ctx, embed_view, out_dim * 2));
// target_size_as_tuple
float target_width = (float)width;
float target_height = (float)height;
ggml_tensor_set_f32(timesteps, target_height, 0);
ggml_tensor_set_f32(timesteps, target_width, 1);
embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
timesteps = {target_height, target_width};
embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
offset += ggml_nbytes(embed_view);
set_timestep_embedding(timesteps, embed_view, out_dim);
// print_ggml_tensor(ggml_reshape_1d(work_ctx, embed_view, out_dim * 2));
@ -528,18 +528,103 @@ public:
return {result, vec};
}
std::tuple<ggml_tensor*, ggml_tensor*, ggml_tensor*> get_svd_condition(ggml_context* work_ctx,
sd_image_t init_image,
int width,
int height,
int fps = 6,
int motion_bucket_id = 127,
float augmentation_level = 0.f,
bool force_zero_embeddings = false) {
// c_crossattn
int64_t t0 = ggml_time_ms();
struct ggml_tensor* c_crossattn = NULL;
{
if (force_zero_embeddings) {
c_crossattn = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, clip_vision->vision_model.projection_dim);
ggml_set_f32(c_crossattn, 0.f);
} else {
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(init_image);
sd_image_f32_t resized_image = clip_preprocess(image, clip_vision->vision_model.image_size);
free(image.data);
image.data = NULL;
ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
sd_image_f32_to_tensor(resized_image.data, pixel_values, false);
free(resized_image.data);
resized_image.data = NULL;
// print_ggml_tensor(pixel_values);
clip_vision->compute(n_threads, pixel_values, &c_crossattn, work_ctx);
// print_ggml_tensor(c_crossattn);
}
}
// c_concat
struct ggml_tensor* c_concat = NULL;
{
if (force_zero_embeddings) {
c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 4, 1);
ggml_set_f32(c_concat, 0.f);
} else {
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
if (width != init_image.width || height != init_image.height) {
sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(init_image);
sd_image_f32_t resized_image = resize_sd_image_f32_t(image, width, height);
free(image.data);
image.data = NULL;
sd_image_f32_to_tensor(resized_image.data, init_img, false);
free(resized_image.data);
resized_image.data = NULL;
} else {
sd_image_to_tensor(init_image.data, init_img);
}
if (augmentation_level > 0.f) {
struct ggml_tensor* noise = ggml_dup_tensor(work_ctx, init_img);
ggml_tensor_set_f32_randn(noise, rng);
// encode_pixels += torch.randn_like(pixels) * augmentation_level
ggml_tensor_scale(noise, augmentation_level);
ggml_tensor_add(init_img, noise);
}
print_ggml_tensor(init_img);
ggml_tensor* moments = encode_first_stage(work_ctx, init_img);
print_ggml_tensor(moments);
c_concat = get_first_stage_encoding(work_ctx, moments);
}
print_ggml_tensor(c_concat);
}
// y
struct ggml_tensor* y = NULL;
{
y = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, diffusion_model->unet.adm_in_channels);
int out_dim = 256;
int fps_id = fps - 1;
std::vector<float> timesteps = {(float)fps_id, (float)motion_bucket_id, augmentation_level};
set_timestep_embedding(timesteps, y, out_dim);
print_ggml_tensor(y);
}
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing svd condition graph completed, taking %" PRId64 " ms", t1 - t0);
return {c_crossattn, c_concat, y};
}
ggml_tensor* sample(ggml_context* work_ctx,
ggml_tensor* x_t,
ggml_tensor* noise,
ggml_tensor* c,
ggml_tensor* c_concat,
ggml_tensor* c_vector,
ggml_tensor* uc,
ggml_tensor* uc_concat,
ggml_tensor* uc_vector,
ggml_tensor* control_hint,
float control_strength,
float min_cfg,
float cfg_scale,
sample_method_t method,
const std::vector<float>& sigmas,
float control_strength) {
const std::vector<float>& sigmas) {
size_t steps = sigmas.size() - 1;
// x_t = load_tensor_from_file(work_ctx, "./rand0.bin");
// print_ggml_tensor(x_t);
@ -547,16 +632,7 @@ public:
copy_ggml_tensor(x, x_t);
struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x_t);
struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1); // [N, ]
struct ggml_tensor* t_emb = new_timestep_embedding(work_ctx, NULL, timesteps, diffusion_model.model_channels); // [N, model_channels]
struct ggml_tensor* guided_hint = NULL;
if (control_hint != NULL) {
guided_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, noised_input->ne[0], noised_input->ne[1], diffusion_model.model_channels, 1);
control_net.process_hint(guided_hint, n_threads, control_hint);
control_net.alloc_compute_buffer(noised_input, guided_hint, c, t_emb);
}
diffusion_model.alloc_compute_buffer(noised_input, c, control_net.controls, t_emb, c_vector);
bool has_unconditioned = cfg_scale != 1.0 && uc != NULL;
@ -598,27 +674,50 @@ public:
}
float t = denoiser->schedule->sigma_to_t(sigma);
ggml_set_f32(timesteps, t);
set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels);
std::vector<float> timesteps(x->ne[3], t); // [N, ]
copy_ggml_tensor(noised_input, input);
// noised_input = noised_input * c_in
ggml_tensor_scale(noised_input, c_in);
// cond
std::vector<struct ggml_tensor*> controls;
if (control_hint != NULL) {
control_net.compute(n_threads, noised_input, guided_hint, c, t_emb);
control_net->compute(n_threads, noised_input, control_hint, timesteps, c, c_vector);
controls = control_net->controls;
// print_ggml_tensor(controls[12]);
// GGML_ASSERT(0);
}
diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, control_net.controls, control_strength, t_emb, c_vector);
// cond
diffusion_model->compute(n_threads,
noised_input,
timesteps,
c,
c_concat,
c_vector,
-1,
controls,
control_strength,
&out_cond);
float* negative_data = NULL;
if (has_unconditioned) {
// uncond
if (control_hint != NULL) {
control_net.compute(n_threads, noised_input, guided_hint, uc, t_emb);
control_net->compute(n_threads, noised_input, control_hint, timesteps, uc, uc_vector);
controls = control_net->controls;
}
diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, control_net.controls, control_strength, t_emb, uc_vector);
diffusion_model->compute(n_threads,
noised_input,
timesteps,
uc,
uc_concat,
uc_vector,
-1,
controls,
control_strength,
&out_uncond);
negative_data = (float*)out_uncond->data;
}
float* vec_denoised = (float*)denoised->data;
@ -629,7 +728,13 @@ public:
float latent_result = positive_data[i];
if (has_unconditioned) {
// out_uncond + cfg_scale * (out_cond - out_uncond)
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
int64_t ne3 = out_cond->ne[3];
if (min_cfg != cfg_scale && ne3 != 1) {
int64_t i3 = i / out_cond->ne[0] * out_cond->ne[1] * out_cond->ne[2];
float scale = min_cfg + (cfg_scale - min_cfg) * (i3 * 1.0f / ne3);
} else {
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
}
}
// v = latent_result, eps = latent_result
// denoised = (v * c_out + input * c_skip) or (input + eps * c_out)
@ -1027,8 +1132,11 @@ public:
LOG_ERROR("Attempting to sample with nonexisting sample method %i", method);
abort();
}
control_net.free_compute_buffer();
diffusion_model.free_compute_buffer();
if (control_net) {
control_net->free_control_ctx();
control_net->free_compute_buffer();
}
diffusion_model->free_compute_buffer();
return x;
}
@ -1067,10 +1175,11 @@ public:
ggml_tensor* compute_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];
ggml_tensor* result = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32,
decode ? (W * 8) : (W / 8), // width
decode ? (H * 8) : (H / 8), // height
decode ? 3 : (use_tiny_autoencoder ? 4 : 8)); // channels
ggml_tensor* result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
decode ? (W * 8) : (W / 8), // width
decode ? (H * 8) : (H / 8), // height
decode ? 3 : (use_tiny_autoencoder ? 4 : 8),
x->ne[3]); // channels
int64_t t0 = ggml_time_ms();
if (!use_tiny_autoencoder) {
if (decode) {
@ -1081,18 +1190,13 @@ public:
if (vae_tiling && decode) { // TODO: support tiling vae encode
// split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
if (init) {
first_stage_model.alloc_compute_buffer(in, decode);
} else {
first_stage_model.compute(out, n_threads, in, decode);
}
first_stage_model->compute(n_threads, in, decode, &out);
};
sd_tiling(x, result, 8, 32, 0.5f, on_tiling);
} else {
first_stage_model.alloc_compute_buffer(x, decode);
first_stage_model.compute(result, n_threads, x, decode);
first_stage_model->compute(n_threads, x, decode, &result);
}
first_stage_model.free_compute_buffer();
first_stage_model->free_compute_buffer();
if (decode) {
ggml_tensor_scale_output(result);
}
@ -1100,19 +1204,15 @@ public:
if (vae_tiling && decode) { // TODO: support tiling vae encode
// split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
if (init) {
tae_first_stage.alloc_compute_buffer(in, decode);
} else {
tae_first_stage.compute(out, n_threads, in, decode);
}
tae_first_stage->compute(n_threads, in, decode, &out);
};
sd_tiling(x, result, 8, 64, 0.5f, on_tiling);
} else {
tae_first_stage.alloc_compute_buffer(x, decode);
tae_first_stage.compute(result, n_threads, x, decode);
tae_first_stage->compute(n_threads, x, decode, &result);
}
tae_first_stage.free_compute_buffer();
tae_first_stage->free_compute_buffer();
}
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing vae [mode: %s] graph completed, taking %.2fs", decode ? "DECODE" : "ENCODE", (t1 - t0) * 1.0f / 1000);
if (decode) {
@ -1272,7 +1372,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->cond_stage_model.free_params_buffer();
sd_ctx->sd->cond_stage_model->free_params_buffer();
}
struct ggml_tensor* image_hint = NULL;
@ -1297,7 +1397,21 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps);
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, x_t, NULL, c, c_vector, uc, uc_vector, image_hint, cfg_scale, sample_method, sigmas, control_strength);
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
x_t,
NULL,
c,
NULL,
c_vector,
uc,
NULL,
uc_vector,
image_hint,
control_strength,
cfg_scale,
cfg_scale,
sample_method,
sigmas);
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
// print_ggml_tensor(x_0);
int64_t sampling_end = ggml_time_ms();
@ -1306,7 +1420,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
}
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model.free_params_buffer();
sd_ctx->sd->diffusion_model->free_params_buffer();
}
int64_t t3 = ggml_time_ms();
LOG_INFO("generating %" PRId64 " latent images completed, taking %.2fs", final_latents.size(), (t3 - t1) * 1.0f / 1000);
@ -1327,7 +1441,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int64_t t4 = ggml_time_ms();
LOG_INFO("decode_first_stage completed, taking %.2fs", (t4 - t3) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
sd_ctx->sd->first_stage_model.free_params_buffer();
sd_ctx->sd->first_stage_model->free_params_buffer();
}
sd_image_t* result_images = (sd_image_t*)calloc(batch_count, sizeof(sd_image_t));
if (result_images == NULL) {
@ -1442,7 +1556,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
int64_t t2 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t2 - t1);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->cond_stage_model.free_params_buffer();
sd_ctx->sd->cond_stage_model->free_params_buffer();
}
sd_ctx->sd->rng->manual_seed(seed);
@ -1450,19 +1564,32 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng);
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, init_latent, noise, c, c_vector, uc,
uc_vector, NULL, cfg_scale, sample_method, sigma_sched, 1.0f);
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
init_latent,
noise,
c,
NULL,
c_vector,
uc,
NULL,
uc_vector,
{},
0.f,
cfg_scale,
cfg_scale,
sample_method,
sigma_sched);
// struct ggml_tensor *x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
// print_ggml_tensor(x_0);
int64_t t3 = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (t3 - t2) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model.free_params_buffer();
sd_ctx->sd->diffusion_model->free_params_buffer();
}
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, x_0);
if (sd_ctx->sd->free_params_immediately && !sd_ctx->sd->use_tiny_autoencoder) {
sd_ctx->sd->first_stage_model.free_params_buffer();
sd_ctx->sd->first_stage_model->free_params_buffer();
}
if (img == NULL) {
ggml_free(work_ctx);
@ -1490,3 +1617,139 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
return result_images;
}
SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
sd_image_t init_image,
int width,
int height,
int video_frames,
int motion_bucket_id,
int fps,
float augmentation_level,
float min_cfg,
float cfg_scale,
enum sample_method_t sample_method,
int sample_steps,
float strength,
int64_t seed) {
if (sd_ctx == NULL) {
return NULL;
}
LOG_INFO("img2vid %dx%d", width, height);
std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps);
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024) * 1024; // 10 MB
params.mem_size += width * height * 3 * sizeof(float) * video_frames;
params.mem_buffer = NULL;
params.no_alloc = false;
// LOG_DEBUG("mem_size %u ", params.mem_size);
// draft context
struct ggml_context* work_ctx = ggml_init(params);
if (!work_ctx) {
LOG_ERROR("ggml_init() failed");
return NULL;
}
if (seed < 0) {
seed = (int)time(NULL);
}
sd_ctx->sd->rng->manual_seed(seed);
int64_t t0 = ggml_time_ms();
ggml_tensor* c_crossattn = NULL;
ggml_tensor* c_concat = NULL;
ggml_tensor* c_vector = NULL;
ggml_tensor* uc_crossattn = NULL;
ggml_tensor* uc_concat = NULL;
ggml_tensor* uc_vector = NULL;
std::tie(c_crossattn, c_concat, c_vector) = sd_ctx->sd->get_svd_condition(work_ctx,
init_image,
width,
height,
fps,
motion_bucket_id,
augmentation_level);
uc_crossattn = ggml_dup_tensor(work_ctx, c_crossattn);
ggml_set_f32(uc_crossattn, 0.f);
uc_concat = ggml_dup_tensor(work_ctx, c_concat);
ggml_set_f32(uc_concat, 0.f);
uc_vector = ggml_dup_tensor(work_ctx, c_vector);
int64_t t1 = ggml_time_ms();
LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->clip_vision->free_params_buffer();
}
sd_ctx->sd->rng->manual_seed(seed);
int C = 4;
int W = width / 8;
int H = height / 8;
struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, video_frames);
ggml_tensor_set_f32_randn(x_t, sd_ctx->sd->rng);
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
x_t,
NULL,
c_crossattn,
c_concat,
c_vector,
uc_crossattn,
uc_concat,
uc_vector,
{},
0.f,
min_cfg,
cfg_scale,
sample_method,
sigmas);
int64_t t2 = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (t2 - t1) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->diffusion_model->free_params_buffer();
}
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, x_0);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->first_stage_model->free_params_buffer();
}
if (img == NULL) {
ggml_free(work_ctx);
return NULL;
}
sd_image_t* result_images = (sd_image_t*)calloc(video_frames, sizeof(sd_image_t));
if (result_images == NULL) {
ggml_free(work_ctx);
return NULL;
}
for (size_t i = 0; i < video_frames; i++) {
auto img_i = ggml_view_3d(work_ctx, img, img->ne[0], img->ne[1], img->ne[2], img->nb[1], img->nb[2], img->nb[3] * i);
result_images[i].width = width;
result_images[i].height = height;
result_images[i].channel = 3;
result_images[i].data = sd_tensor_to_image(img_i);
}
ggml_free(work_ctx);
int64_t t3 = ggml_time_ms();
LOG_INFO("img2vid completed in %.2fs", (t3 - t0) * 1.0f / 1000);
return result_images;
}

View File

@ -72,6 +72,7 @@ enum sd_type_t {
SD_TYPE_Q6_K = 14,
SD_TYPE_Q8_K = 15,
SD_TYPE_IQ2_XXS = 16,
SD_TYPE_IQ2_XS = 17,
SD_TYPE_I8,
SD_TYPE_I16,
SD_TYPE_I32,
@ -147,6 +148,21 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
int64_t seed,
int batch_count);
SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
sd_image_t init_image,
int width,
int height,
int video_frames,
int motion_bucket_id,
int fps,
float augmentation_level,
float min_cfg,
float cfg_scale,
enum sample_method_t sample_method,
int sample_steps,
float strength,
int64_t seed);
typedef struct upscaler_ctx_t upscaler_ctx_t;
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,

622
tae.hpp
View File

@ -8,88 +8,45 @@
/*
=================================== TinyAutoEncoder ===================================
References:
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoder_tiny.py
https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/vae.py
https://github.com/madebyollin/taesd/blob/main/taesd.py
*/
struct TAEBlock {
int in_channels;
int out_channels;
// conv
ggml_tensor* conv_0_w; // [in_channels, out_channels, 3, 3]
ggml_tensor* conv_0_b; // [in_channels]
ggml_tensor* conv_1_w; // [out_channels, out_channels, 3, 3]
ggml_tensor* conv_1_b; // [out_channels]
ggml_tensor* conv_2_w; // [out_channels, out_channels, 3, 3]
ggml_tensor* conv_2_b; // [out_channels]
class TAEBlock : public UnaryBlock {
protected:
int n_in;
int n_out;
// skip
ggml_tensor* conv_skip_w; // [in_channels, out_channels, 1, 1]
size_t calculate_mem_size() {
size_t mem_size = in_channels * out_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_0_w
mem_size += in_channels * ggml_type_size(GGML_TYPE_F32); // conv_0_b
mem_size += out_channels * out_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_1_w
mem_size += out_channels * ggml_type_size(GGML_TYPE_F32); // conv_1_b
mem_size += out_channels * out_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_1_w
mem_size += out_channels * ggml_type_size(GGML_TYPE_F32); // conv_1_b
mem_size += out_channels * out_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_2_w
mem_size += out_channels * ggml_type_size(GGML_TYPE_F32); // conv_2_b
if (in_channels != out_channels) {
mem_size += in_channels * out_channels * ggml_type_size(GGML_TYPE_F16); // conv_skip_w
}
return mem_size;
}
int get_num_tensors() {
return 6 + (in_channels != out_channels ? 1 : 0);
}
void init_params(ggml_context* ctx) {
conv_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, in_channels);
conv_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
conv_1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, out_channels);
conv_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
conv_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, out_channels);
conv_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
if (in_channels != out_channels) {
conv_skip_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, out_channels, in_channels);
public:
TAEBlock(int n_in, int n_out)
: n_in(n_in), n_out(n_out) {
blocks["conv.0"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {3, 3}, {1, 1}, {1, 1}));
blocks["conv.2"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
blocks["conv.4"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_out, n_out, {3, 3}, {1, 1}, {1, 1}));
if (n_in != n_out) {
blocks["skip"] = std::shared_ptr<GGMLBlock>(new Conv2d(n_in, n_out, {1, 1}, {1, 1}, {1, 1}, {1, 1}, false));
}
}
void map_by_name(std::map<std::string, ggml_tensor*>& tensors, std::string prefix) {
tensors[prefix + "conv.0.weight"] = conv_0_w;
tensors[prefix + "conv.0.bias"] = conv_0_b;
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [n, n_in, h, w]
// return: [n, n_out, h, w]
tensors[prefix + "conv.2.weight"] = conv_1_w;
tensors[prefix + "conv.2.bias"] = conv_1_b;
auto conv_0 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.0"]);
auto conv_2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.2"]);
auto conv_4 = std::dynamic_pointer_cast<Conv2d>(blocks["conv.4"]);
tensors[prefix + "conv.4.weight"] = conv_2_w;
tensors[prefix + "conv.4.bias"] = conv_2_b;
auto h = conv_0->forward(ctx, x);
h = ggml_relu_inplace(ctx, h);
h = conv_2->forward(ctx, h);
h = ggml_relu_inplace(ctx, h);
h = conv_4->forward(ctx, h);
if (in_channels != out_channels) {
tensors[prefix + "skip.weight"] = conv_skip_w;
}
}
ggml_tensor* forward(ggml_context* ctx, ggml_tensor* x) {
// conv(n_in, n_out)
ggml_tensor* h;
h = ggml_nn_conv_2d(ctx, x, conv_0_w, conv_0_b, 1, 1, 1, 1);
h = ggml_relu_inplace(ctx, h);
h = ggml_nn_conv_2d(ctx, h, conv_1_w, conv_1_b, 1, 1, 1, 1);
h = ggml_relu_inplace(ctx, h);
h = ggml_nn_conv_2d(ctx, h, conv_2_w, conv_2_b, 1, 1, 1, 1);
// skip connection
if (in_channels != out_channels) {
// skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
x = ggml_nn_conv_2d(ctx, x, conv_skip_w, NULL, 1, 1, 1, 1);
if (n_in != n_out) {
auto skip = std::dynamic_pointer_cast<Conv2d>(blocks["skip"]);
LOG_DEBUG("skip");
x = skip->forward(ctx, x);
}
h = ggml_add(ctx, h, x);
@ -98,412 +55,167 @@ struct TAEBlock {
}
};
struct TinyEncoder {
class TinyEncoder : public UnaryBlock {
int in_channels = 3;
int z_channels = 4;
int channels = 64;
int z_channels = 4;
int num_blocks = 3;
// input
ggml_tensor* conv_input_w; // [channels, in_channels, 3, 3]
ggml_tensor* conv_input_b; // [channels]
TAEBlock initial_block;
ggml_tensor* conv_1_w; // [channels, channels, 3, 3]
TAEBlock input_blocks[3];
// middle
ggml_tensor* conv_2_w; // [channels, channels, 3, 3]
TAEBlock middle_blocks[3];
// output
ggml_tensor* conv_3_w; // [channels, channels, 3, 3]
TAEBlock output_blocks[3];
// final
ggml_tensor* conv_final_w; // [z_channels, channels, 3, 3]
ggml_tensor* conv_final_b; // [z_channels]
public:
TinyEncoder() {
int index = 0;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1}));
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
for (int i = 0; i < num_blocks; i++) {
input_blocks[i].in_channels = channels;
input_blocks[i].out_channels = channels;
middle_blocks[i].in_channels = channels;
middle_blocks[i].out_channels = channels;
output_blocks[i].in_channels = channels;
output_blocks[i].out_channels = channels;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
}
initial_block.in_channels = channels;
initial_block.out_channels = channels;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
for (int i = 0; i < num_blocks; i++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
}
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {2, 2}, {1, 1}, {1, 1}, false));
for (int i = 0; i < num_blocks; i++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
}
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, z_channels, {3, 3}, {1, 1}, {1, 1}));
}
size_t calculate_mem_size() {
size_t mem_size = channels * in_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_input_w
mem_size += channels * ggml_type_size(GGML_TYPE_F32); // conv_input_b
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [n, in_channels, h, w]
// return: [n, z_channels, h/8, w/8]
mem_size += initial_block.calculate_mem_size();
for (int i = 0; i < num_blocks * 3 + 6; i++) {
auto block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(i)]);
mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_1_w
mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_2_w
mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_3_w
for (int i = 0; i < num_blocks; i++) {
mem_size += input_blocks[i].calculate_mem_size();
mem_size += middle_blocks[i].calculate_mem_size();
mem_size += output_blocks[i].calculate_mem_size();
}
mem_size += z_channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_input_w
mem_size += z_channels * ggml_type_size(GGML_TYPE_F32); // conv_input_b
return mem_size;
}
int get_num_tensors() {
int num_tensors = 7;
for (int i = 0; i < num_blocks; i++) {
num_tensors += input_blocks[i].get_num_tensors();
num_tensors += middle_blocks[i].get_num_tensors();
num_tensors += output_blocks[i].get_num_tensors();
}
num_tensors += initial_block.get_num_tensors();
return num_tensors;
}
void init_params(ggml_context* ctx) {
conv_input_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, channels);
conv_input_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels);
initial_block.init_params(ctx);
conv_1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);
conv_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);
conv_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);
conv_final_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, z_channels);
conv_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_channels);
for (int i = 0; i < num_blocks; i++) {
input_blocks[i].init_params(ctx);
middle_blocks[i].init_params(ctx);
output_blocks[i].init_params(ctx);
}
}
void map_by_name(std::map<std::string, ggml_tensor*>& tensors, std::string prefix) {
tensors[prefix + "0.weight"] = conv_input_w;
tensors[prefix + "0.bias"] = conv_input_b;
initial_block.map_by_name(tensors, prefix + "1.");
tensors[prefix + "2.weight"] = conv_1_w;
for (int i = 0; i < num_blocks; i++) {
input_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 3) + ".");
x = block->forward(ctx, x);
}
tensors[prefix + "6.weight"] = conv_2_w;
for (int i = 0; i < num_blocks; i++) {
middle_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 7) + ".");
}
tensors[prefix + "10.weight"] = conv_3_w;
for (int i = 0; i < num_blocks; i++) {
output_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 11) + ".");
}
tensors[prefix + "14.weight"] = conv_final_w;
tensors[prefix + "14.bias"] = conv_final_b;
}
ggml_tensor* forward(ggml_context* ctx, ggml_tensor* x) {
// conv(3, 64)
auto z = ggml_nn_conv_2d(ctx, x, conv_input_w, conv_input_b, 1, 1, 1, 1);
// Block(64, 64)
z = initial_block.forward(ctx, z);
// conv(64, 64, stride=2, bias=False)
z = ggml_nn_conv_2d(ctx, z, conv_1_w, NULL, 2, 2, 1, 1);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
z = input_blocks[i].forward(ctx, z);
}
// conv(64, 64, stride=2, bias=False)
z = ggml_nn_conv_2d(ctx, z, conv_2_w, NULL, 2, 2, 1, 1);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
z = middle_blocks[i].forward(ctx, z);
}
// conv(64, 64, stride=2, bias=False)
z = ggml_nn_conv_2d(ctx, z, conv_3_w, NULL, 2, 2, 1, 1);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
z = output_blocks[i].forward(ctx, z);
}
// conv(64, 4)
z = ggml_nn_conv_2d(ctx, z, conv_final_w, conv_final_b, 1, 1, 1, 1);
return z;
return x;
}
};
struct TinyDecoder {
int z_channels = 4;
int channels = 64;
int output_channels = 3;
int num_blocks = 3;
class TinyDecoder : public UnaryBlock {
int z_channels = 4;
int channels = 64;
int out_channels = 3;
int num_blocks = 3;
// input
ggml_tensor* conv_input_w; // [channels, z_channels, 3, 3]
ggml_tensor* conv_input_b; // [channels]
TAEBlock input_blocks[3];
ggml_tensor* conv_1_w; // [channels, channels, 3, 3]
public:
TinyDecoder(int index = 0) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1}));
index++; // nn.ReLU()
// middle
TAEBlock middle_blocks[3];
ggml_tensor* conv_2_w; // [channels, channels, 3, 3]
// output
TAEBlock output_blocks[3];
ggml_tensor* conv_3_w; // [channels, channels, 3, 3]
// final
TAEBlock final_block;
ggml_tensor* conv_final_w; // [output_channels, channels, 3, 3]
ggml_tensor* conv_final_b; // [output_channels]
TinyDecoder() {
for (int i = 0; i < num_blocks; i++) {
input_blocks[i].in_channels = channels;
input_blocks[i].out_channels = channels;
middle_blocks[i].in_channels = channels;
middle_blocks[i].out_channels = channels;
output_blocks[i].in_channels = channels;
output_blocks[i].out_channels = channels;
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
}
index++; // nn.Upsample()
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
final_block.in_channels = channels;
final_block.out_channels = channels;
for (int i = 0; i < num_blocks; i++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
}
index++; // nn.Upsample()
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
for (int i = 0; i < num_blocks; i++) {
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
}
index++; // nn.Upsample()
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, channels, {3, 3}, {1, 1}, {1, 1}, {1, 1}, false));
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new TAEBlock(channels, channels));
blocks[std::to_string(index++)] = std::shared_ptr<GGMLBlock>(new Conv2d(channels, out_channels, {3, 3}, {1, 1}, {1, 1}));
}
size_t calculate_mem_size() {
size_t mem_size = channels * z_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_input_w
mem_size += channels * ggml_type_size(GGML_TYPE_F32); // conv_input_b
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) {
// z: [n, z_channels, h, w]
// return: [n, out_channels, h*8, w*8]
for (int i = 0; i < num_blocks; i++) {
mem_size += input_blocks[i].calculate_mem_size();
}
mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_1_w
for (int i = 0; i < num_blocks; i++) {
mem_size += middle_blocks[i].calculate_mem_size();
}
mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_2_w
for (int i = 0; i < num_blocks; i++) {
mem_size += output_blocks[i].calculate_mem_size();
}
mem_size += channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_3_w
mem_size += final_block.calculate_mem_size();
mem_size += output_channels * channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_input_w
mem_size += output_channels * ggml_type_size(GGML_TYPE_F32); // conv_input_b
return mem_size;
}
int get_num_tensors() {
int num_tensors = 9;
for (int i = 0; i < num_blocks; i++) {
num_tensors += input_blocks[i].get_num_tensors();
num_tensors += middle_blocks[i].get_num_tensors();
num_tensors += output_blocks[i].get_num_tensors();
}
num_tensors += final_block.get_num_tensors();
return num_tensors;
}
void init_params(ggml_allocr* alloc, ggml_context* ctx) {
conv_input_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, z_channels, channels);
conv_input_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels);
conv_1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);
conv_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);
conv_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, channels);
conv_final_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, output_channels);
conv_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, output_channels);
for (int i = 0; i < num_blocks; i++) {
input_blocks[i].init_params(ctx);
middle_blocks[i].init_params(ctx);
output_blocks[i].init_params(ctx);
}
final_block.init_params(ctx);
}
void map_by_name(std::map<std::string, ggml_tensor*>& tensors, std::string prefix) {
tensors[prefix + "0.weight"] = conv_input_w;
tensors[prefix + "0.bias"] = conv_input_b;
for (int i = 0; i < num_blocks; i++) {
input_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 2) + ".");
}
tensors[prefix + "6.weight"] = conv_1_w;
for (int i = 0; i < num_blocks; i++) {
middle_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 7) + ".");
}
tensors[prefix + "11.weight"] = conv_2_w;
for (int i = 0; i < num_blocks; i++) {
output_blocks[i].map_by_name(tensors, prefix + std::to_string(i + 12) + ".");
}
tensors[prefix + "16.weight"] = conv_3_w;
final_block.map_by_name(tensors, prefix + "17.");
tensors[prefix + "18.weight"] = conv_final_w;
tensors[prefix + "18.bias"] = conv_final_b;
}
ggml_tensor* forward(ggml_context* ctx, ggml_tensor* z) {
// torch.tanh(x / 3) * 3
auto h = ggml_scale(ctx, z, 1.0f / 3.0f);
h = ggml_tanh_inplace(ctx, h);
h = ggml_scale(ctx, h, 3.0f);
// conv(4, 64)
h = ggml_nn_conv_2d(ctx, h, conv_input_w, conv_input_b, 1, 1, 1, 1);
for (int i = 0; i < num_blocks * 3 + 10; i++) {
if (blocks.find(std::to_string(i)) == blocks.end()) {
if (i == 1) {
h = ggml_relu_inplace(ctx, h);
} else {
h = ggml_upscale(ctx, h, 2);
}
continue;
}
auto block = std::dynamic_pointer_cast<UnaryBlock>(blocks[std::to_string(i)]);
// nn.ReLU()
h = ggml_relu_inplace(ctx, h);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
h = input_blocks[i].forward(ctx, h);
h = block->forward(ctx, h);
}
// nn.Upsample(scale_factor=2)
h = ggml_upscale(ctx, h, 2);
// conv(64, 64, bias=False)
h = ggml_nn_conv_2d(ctx, h, conv_1_w, NULL, 1, 1, 1, 1);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
h = middle_blocks[i].forward(ctx, h);
}
// nn.Upsample(scale_factor=2)
h = ggml_upscale(ctx, h, 2);
// conv(64, 64, bias=False)
h = ggml_nn_conv_2d(ctx, h, conv_2_w, NULL, 1, 1, 1, 1);
// Block(64, 64), Block(64, 64), Block(64, 64)
for (int i = 0; i < num_blocks; i++) {
h = output_blocks[i].forward(ctx, h);
}
// nn.Upsample(scale_factor=2)
h = ggml_upscale(ctx, h, 2);
// conv(64, 64, bias=False)
h = ggml_nn_conv_2d(ctx, h, conv_3_w, NULL, 1, 1, 1, 1);
// Block(64, 64)
h = final_block.forward(ctx, h);
// conv(64, 3)
h = ggml_nn_conv_2d(ctx, h, conv_final_w, conv_final_b, 1, 1, 1, 1);
return h;
}
};
class TAESD : public GGMLBlock {
protected:
bool decode_only;
public:
TAESD(bool decode_only = true)
: decode_only(decode_only) {
blocks["decoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyDecoder());
if (!decode_only) {
blocks["encoder.layers"] = std::shared_ptr<GGMLBlock>(new TinyEncoder());
}
}
struct ggml_tensor* decode(struct ggml_context* ctx, struct ggml_tensor* z) {
auto decoder = std::dynamic_pointer_cast<TinyDecoder>(blocks["decoder.layers"]);
return decoder->forward(ctx, z);
}
struct ggml_tensor* encode(struct ggml_context* ctx, struct ggml_tensor* x) {
auto encoder = std::dynamic_pointer_cast<TinyEncoder>(blocks["encoder.layers"]);
return encoder->forward(ctx, x);
}
};
struct TinyAutoEncoder : public GGMLModule {
TinyEncoder encoder;
TinyDecoder decoder;
TAESD taesd;
bool decode_only = false;
TinyAutoEncoder(bool decoder_only_ = true)
: decode_only(decoder_only_) {
name = "tae";
TinyAutoEncoder(ggml_backend_t backend,
ggml_type wtype,
bool decoder_only = true)
: decode_only(decoder_only),
taesd(decode_only),
GGMLModule(backend, wtype) {
taesd.init(params_ctx, wtype);
}
size_t calculate_mem_size() {
size_t mem_size = decoder.calculate_mem_size();
if (!decode_only) {
mem_size += encoder.calculate_mem_size();
}
mem_size += 1024; // padding
return mem_size;
std::string get_desc() {
return "taesd";
}
size_t get_num_tensors() {
size_t num_tensors = decoder.get_num_tensors();
if (!decode_only) {
num_tensors += encoder.get_num_tensors();
}
return num_tensors;
size_t get_params_mem_size() {
return taesd.get_params_mem_size();
}
void init_params() {
ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);
decoder.init_params(alloc, params_ctx);
if (!decode_only) {
encoder.init_params(params_ctx);
}
// alloc all tensors linked to this context
for (struct ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != NULL; t = ggml_get_next_tensor(params_ctx, t)) {
if (t->data == NULL) {
ggml_allocr_alloc(alloc, t);
}
}
ggml_allocr_free(alloc);
size_t get_params_num() {
return taesd.get_params_num();
}
void map_by_name(std::map<std::string, ggml_tensor*>& tensors) {
decoder.map_by_name(tensors, "decoder.layers.");
encoder.map_by_name(tensors, "encoder.layers.");
}
bool load_from_file(const std::string& file_path, ggml_backend_t backend) {
bool load_from_file(const std::string& file_path) {
LOG_INFO("loading taesd from '%s'", file_path.c_str());
if (!alloc_params_buffer(backend)) {
return false;
}
alloc_params_buffer();
std::map<std::string, ggml_tensor*> taesd_tensors;
// prepare memory for the weights
{
init_params();
map_by_name(taesd_tensors);
}
std::map<std::string, struct ggml_tensor*> tensors_need_to_load;
taesd.get_param_tensors(taesd_tensors);
std::set<std::string> ignore_tensors;
for (auto& pair : taesd_tensors) {
const std::string& name = pair.first;
if (decode_only && starts_with(name, "encoder")) {
ignore_tensors.insert(name);
continue;
}
tensors_need_to_load.insert(pair);
if (decode_only) {
ignore_tensors.insert("encoder.");
}
ModelLoader model_loader;
@ -512,7 +224,7 @@ struct TinyAutoEncoder : public GGMLModule {
return false;
}
bool success = model_loader.load_tensors(tensors_need_to_load, backend, ignore_tensors);
bool success = model_loader.load_tensors(taesd_tensors, backend, ignore_tensors);
if (!success) {
LOG_ERROR("load tae tensors from model loader failed");
@ -524,57 +236,23 @@ struct TinyAutoEncoder : public GGMLModule {
}
struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
// since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
static std::vector<uint8_t> buf(buf_size);
struct ggml_init_params params = {
/*.mem_size =*/buf_size,
/*.mem_buffer =*/buf.data(),
/*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph()
};
// LOG_DEBUG("mem_size %u ", params.mem_size);
struct ggml_context* ctx0 = ggml_init(params);
struct ggml_cgraph* gf = ggml_new_graph(ctx0);
struct ggml_tensor* z_ = NULL;
// it's performing a compute, check if backend isn't cpu
if (!ggml_backend_is_cpu(backend)) {
// pass input tensors to gpu memory
z_ = ggml_dup_tensor(ctx0, z);
ggml_allocr_alloc(compute_allocr, z_);
// pass data to device backend
if (!ggml_allocr_is_measure(compute_allocr)) {
ggml_backend_tensor_set(z_, z->data, 0, ggml_nbytes(z));
}
} else {
z_ = z;
}
struct ggml_tensor* out = decode_graph ? decoder.forward(ctx0, z_) : encoder.forward(ctx0, z_);
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
z = to_backend(z);
struct ggml_tensor* out = decode_graph ? taesd.decode(compute_ctx, z) : taesd.encode(compute_ctx, z);
ggml_build_forward_expand(gf, out);
ggml_free(ctx0);
return gf;
}
void alloc_compute_buffer(struct ggml_tensor* x, bool decode) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, decode);
};
GGMLModule::alloc_compute_buffer(get_graph);
}
void compute(struct ggml_tensor* work_result, int n_threads, struct ggml_tensor* z, bool decode_graph) {
void compute(const int n_threads,
struct ggml_tensor* z,
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(z, decode_graph);
};
GGMLModule::compute(get_graph, n_threads, work_result);
GGMLModule::compute(get_graph, n_threads, false, output, output_ctx);
}
};

968
unet.hpp

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,7 @@
struct UpscalerGGML {
ggml_backend_t backend = NULL; // general backend
ggml_type model_data_type = GGML_TYPE_F16;
ESRGAN esrgan_upscaler;
std::shared_ptr<ESRGAN> esrgan_upscaler;
std::string esrgan_path;
int n_threads;
@ -30,7 +30,8 @@ struct UpscalerGGML {
backend = ggml_backend_cpu_init();
}
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
if (!esrgan_upscaler.load_from_file(esrgan_path, backend)) {
esrgan_upscaler = std::make_shared<ESRGAN>(backend, model_data_type);
if (!esrgan_upscaler->load_from_file(esrgan_path)) {
return false;
}
return true;
@ -39,8 +40,8 @@ struct UpscalerGGML {
sd_image_t upscale(sd_image_t input_image, uint32_t upscale_factor) {
// upscale_factor, unused for RealESRGAN_x4plus_anime_6B.pth
sd_image_t upscaled_image = {0, 0, 0, NULL};
int output_width = (int)input_image.width * esrgan_upscaler.scale;
int output_height = (int)input_image.height * esrgan_upscaler.scale;
int output_width = (int)input_image.width * esrgan_upscaler->scale;
int output_height = (int)input_image.height * esrgan_upscaler->scale;
LOG_INFO("upscaling from (%i x %i) to (%i x %i)",
input_image.width, input_image.height, output_width, output_height);
@ -62,15 +63,11 @@ struct UpscalerGGML {
ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, 3, 1);
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
if (init) {
esrgan_upscaler.alloc_compute_buffer(in);
} else {
esrgan_upscaler.compute(out, n_threads, in);
}
esrgan_upscaler->compute(n_threads, in, &out);
};
int64_t t0 = ggml_time_ms();
sd_tiling(input_image_tensor, upscaled, esrgan_upscaler.scale, esrgan_upscaler.tile_size, 0.25f, on_tiling);
esrgan_upscaler.free_compute_buffer();
sd_tiling(input_image_tensor, upscaled, esrgan_upscaler->scale, esrgan_upscaler->tile_size, 0.25f, on_tiling);
esrgan_upscaler->free_compute_buffer();
ggml_tensor_clamp(upscaled, 0.f, 1.f);
uint8_t* upscaled_data = sd_tensor_to_image(upscaled);
ggml_free(upscale_ctx);

153
util.cpp
View File

@ -1,6 +1,7 @@
#include "util.h"
#include <stdarg.h>
#include <algorithm>
#include <cmath>
#include <codecvt>
#include <fstream>
#include <locale>
@ -203,6 +204,9 @@ std::string path_join(const std::string& p1, const std::string& p2) {
}
void pretty_progress(int step, int steps, float time) {
if (step == 0) {
return;
}
std::string progress = " |";
int max_progress = 50;
int32_t current = (int32_t)(step * 1.f * max_progress / steps);
@ -307,3 +311,152 @@ const char* sd_get_system_info() {
const char* sd_type_name(enum sd_type_t type) {
return ggml_type_name((ggml_type)type);
}
sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image) {
sd_image_f32_t converted_image;
converted_image.width = image.width;
converted_image.height = image.height;
converted_image.channel = image.channel;
// Allocate memory for float data
converted_image.data = (float*)malloc(image.width * image.height * image.channel * sizeof(float));
for (int i = 0; i < image.width * image.height * image.channel; i++) {
// Convert uint8_t to float
converted_image.data[i] = (float)image.data[i];
}
return converted_image;
}
// Function to perform double linear interpolation
float interpolate(float v1, float v2, float v3, float v4, float x_ratio, float y_ratio) {
return v1 * (1 - x_ratio) * (1 - y_ratio) + v2 * x_ratio * (1 - y_ratio) + v3 * (1 - x_ratio) * y_ratio + v4 * x_ratio * y_ratio;
}
sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int target_height) {
sd_image_f32_t resized_image;
resized_image.width = target_width;
resized_image.height = target_height;
resized_image.channel = image.channel;
// Allocate memory for resized float data
resized_image.data = (float*)malloc(target_width * target_height * image.channel * sizeof(float));
for (int y = 0; y < target_height; y++) {
for (int x = 0; x < target_width; x++) {
float original_x = (float)x * image.width / target_width;
float original_y = (float)y * image.height / target_height;
int x1 = (int)original_x;
int y1 = (int)original_y;
int x2 = x1 + 1;
int y2 = y1 + 1;
for (int k = 0; k < image.channel; k++) {
float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
float v2 = *(image.data + y1 * image.width * image.channel + x2 * image.channel + k);
float v3 = *(image.data + y2 * image.width * image.channel + x1 * image.channel + k);
float v4 = *(image.data + y2 * image.width * image.channel + x2 * image.channel + k);
float x_ratio = original_x - x1;
float y_ratio = original_y - y1;
float value = interpolate(v1, v2, v3, v4, x_ratio, y_ratio);
*(resized_image.data + y * target_width * image.channel + x * image.channel + k) = value;
}
}
}
return resized_image;
}
void normalize_sd_image_f32_t(sd_image_f32_t image, float means[3], float stds[3]) {
for (int y = 0; y < image.height; y++) {
for (int x = 0; x < image.width; x++) {
for (int k = 0; k < image.channel; k++) {
int index = (y * image.width + x) * image.channel + k;
image.data[index] = (image.data[index] - means[k]) / stds[k];
}
}
}
}
// Constants for means and std
float means[3] = {0.48145466, 0.4578275, 0.40821073};
float stds[3] = {0.26862954, 0.26130258, 0.27577711};
// Function to clip and preprocess sd_image_f32_t
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) {
float scale = (float)size / fmin(image.width, image.height);
// Interpolation
int new_width = (int)(scale * image.width);
int new_height = (int)(scale * image.height);
float* resized_data = (float*)malloc(new_width * new_height * image.channel * sizeof(float));
for (int y = 0; y < new_height; y++) {
for (int x = 0; x < new_width; x++) {
float original_x = (float)x * image.width / new_width;
float original_y = (float)y * image.height / new_height;
int x1 = (int)original_x;
int y1 = (int)original_y;
int x2 = x1 + 1;
int y2 = y1 + 1;
for (int k = 0; k < image.channel; k++) {
float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
float v2 = *(image.data + y1 * image.width * image.channel + x2 * image.channel + k);
float v3 = *(image.data + y2 * image.width * image.channel + x1 * image.channel + k);
float v4 = *(image.data + y2 * image.width * image.channel + x2 * image.channel + k);
float x_ratio = original_x - x1;
float y_ratio = original_y - y1;
float value = interpolate(v1, v2, v3, v4, x_ratio, y_ratio);
*(resized_data + y * new_width * image.channel + x * image.channel + k) = value;
}
}
}
// Clip and preprocess
int h = (new_height - size) / 2;
int w = (new_width - size) / 2;
sd_image_f32_t result;
result.width = size;
result.height = size;
result.channel = image.channel;
result.data = (float*)malloc(size * size * image.channel * sizeof(float));
for (int k = 0; k < image.channel; k++) {
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
*(result.data + i * size * image.channel + j * image.channel + k) =
fmin(fmax(*(resized_data + (i + h) * new_width * image.channel + (j + w) * image.channel + k), 0.0f), 255.0f) / 255.0f;
}
}
}
// Free allocated memory
free(resized_data);
// Normalize
for (int k = 0; k < image.channel; k++) {
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
// *(result.data + i * size * image.channel + j * image.channel + k) = 0.5f;
int offset = i * size * image.channel + j * image.channel + k;
float value = *(result.data + offset);
value = (value - means[k]) / stds[k];
// value = 0.5f;
*(result.data + offset) = value;
}
}
}
return result;
}

15
util.h
View File

@ -23,6 +23,21 @@ std::u32string unicode_value_to_utf32(int unicode_value);
std::string sd_basename(const std::string& path);
typedef struct {
uint32_t width;
uint32_t height;
uint32_t channel;
float* data;
} sd_image_f32_t;
void normalize_sd_image_f32_t(sd_image_f32_t image, float means[3], float stds[3]);
sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image);
sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int target_height);
sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size);
std::string path_join(const std::string& p1, const std::string& p2);
void pretty_progress(int step, int steps, float time);

1081
vae.hpp

File diff suppressed because it is too large Load Diff