stable-diffusion.cpp/vae.hpp

740 lines
28 KiB
C++

#ifndef __VAE_HPP__
#define __VAE_HPP__
#include "common.hpp"
#include "ggml_extend.hpp"
/*================================================== AutoEncoderKL ===================================================*/
#define VAE_GRAPH_SIZE 10240
// ResnetBlock (ldm.modules.diffusionmodules.model.ResnetBlock, dropout omitted):
//   out = conv2(silu(norm2(conv1(silu(norm1(z)))))) + skip(z)
// where skip(z) is z itself, or the 1x1 conv `nin_shortcut` when the block
// changes the channel count.
struct ResnetBlock {
    // network hparams (assigned by the owning Encoder/Decoder constructor)
    int in_channels;
    int out_channels;

    // network params: conv weights are F16, norms and biases are F32
    struct ggml_tensor* norm1_w;  // [in_channels, ]
    struct ggml_tensor* norm1_b;  // [in_channels, ]
    struct ggml_tensor* conv1_w;  // [out_channels, in_channels, 3, 3]
    struct ggml_tensor* conv1_b;  // [out_channels, ]
    struct ggml_tensor* norm2_w;  // [out_channels, ]
    struct ggml_tensor* norm2_b;  // [out_channels, ]
    struct ggml_tensor* conv2_w;  // [out_channels, out_channels, 3, 3]
    struct ggml_tensor* conv2_b;  // [out_channels, ]

    // skip-path projection; only created and used when the channel count changes
    struct ggml_tensor* nin_shortcut_w;  // [out_channels, in_channels, 1, 1]
    struct ggml_tensor* nin_shortcut_b;  // [out_channels, ]

    // True when the skip path needs nin_shortcut to reach out_channels.
    bool needs_shortcut() const {
        return in_channels != out_channels;
    }

    // Bytes needed for this block's parameters. `wtype` is not consulted:
    // the weight types are fixed (F16 convs, F32 vectors).
    size_t calculate_mem_size(ggml_type wtype) {
        double bytes = 0;
        bytes += 2 * ggml_row_size(GGML_TYPE_F32, in_channels);                      // norm1_w, norm1_b
        bytes += ggml_row_size(GGML_TYPE_F16, out_channels * in_channels * 3 * 3);   // conv1_w
        bytes += 4 * ggml_row_size(GGML_TYPE_F32, out_channels);                     // conv1_b, norm2_w, norm2_b, conv2_b
        bytes += ggml_row_size(GGML_TYPE_F16, out_channels * out_channels * 3 * 3);  // conv2_w
        if (needs_shortcut()) {
            bytes += ggml_row_size(GGML_TYPE_F16, out_channels * in_channels * 1 * 1);  // nin_shortcut_w
            bytes += ggml_row_size(GGML_TYPE_F32, out_channels);                        // nin_shortcut_b
        }
        return static_cast<size_t>(bytes);
    }

    // Creates the parameter tensors in `ctx`. Creation order is kept stable
    // because it determines the order tensors are later walked in the context.
    void init_params(struct ggml_context* ctx, ggml_type wtype) {
        norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
        norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
        conv1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, out_channels);
        conv1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);

        norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
        norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
        conv2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, out_channels);
        conv2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);

        if (needs_shortcut()) {
            nin_shortcut_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, out_channels);
            nin_shortcut_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels);
        }
    }

    // Registers every parameter tensor under `prefix` + its checkpoint name.
    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        tensors[prefix + "norm1.weight"] = norm1_w;
        tensors[prefix + "norm1.bias"]   = norm1_b;
        tensors[prefix + "conv1.weight"] = conv1_w;
        tensors[prefix + "conv1.bias"]   = conv1_b;
        tensors[prefix + "norm2.weight"] = norm2_w;
        tensors[prefix + "norm2.bias"]   = norm2_b;
        tensors[prefix + "conv2.weight"] = conv2_w;
        tensors[prefix + "conv2.bias"]   = conv2_b;
        if (needs_shortcut()) {
            tensors[prefix + "nin_shortcut.weight"] = nin_shortcut_w;
            tensors[prefix + "nin_shortcut.bias"]   = nin_shortcut_b;
        }
    }

    // Builds the block's graph nodes.
    // z: [N, in_channels, h, w] -> returns [N, out_channels, h, w]
    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) {
        auto x = ggml_nn_group_norm(ctx, z, norm1_w, norm1_b);
        x      = ggml_silu_inplace(ctx, x);
        x      = ggml_nn_conv_2d(ctx, x, conv1_w, conv1_b, 1, 1, 1, 1);  // [N, out_channels, h, w]

        x = ggml_nn_group_norm(ctx, x, norm2_w, norm2_b);
        x = ggml_silu_inplace(ctx, x);
        // dropout would go here; it is a no-op for inference
        x = ggml_nn_conv_2d(ctx, x, conv2_w, conv2_b, 1, 1, 1, 1);  // [N, out_channels, h, w]

        auto skip = needs_shortcut()
                        ? ggml_nn_conv_2d(ctx, z, nin_shortcut_w, nin_shortcut_b)  // [N, out_channels, h, w]
                        : z;
        return ggml_add(ctx, x, skip);  // [N, out_channels, h, w]
    }
};
// AttnBlock (ldm.modules.diffusionmodules.model.AttnBlock): single-head
// self-attention over all h*w spatial positions, with q/k/v/proj_out as
// 1x1 convolutions and a residual connection:
//   out = x + proj_out(softmax(q^T k / sqrt(C)) applied to v), q/k/v from norm(x)
struct AttnBlock {
    int in_channels;  // mult * model_channels

    // group norm
    struct ggml_tensor* norm_w;  // [in_channels,]
    struct ggml_tensor* norm_b;  // [in_channels,]

    // q/k/v projections (1x1 convs)
    struct ggml_tensor* q_w;  // [in_channels, in_channels, 1, 1]
    struct ggml_tensor* q_b;  // [in_channels,]
    struct ggml_tensor* k_w;  // [in_channels, in_channels, 1, 1]
    struct ggml_tensor* k_b;  // [in_channels,]
    struct ggml_tensor* v_w;  // [in_channels, in_channels, 1, 1]
    struct ggml_tensor* v_b;  // [in_channels,]

    // output projection (1x1 conv)
    struct ggml_tensor* proj_out_w;  // [in_channels, in_channels, 1, 1]
    struct ggml_tensor* proj_out_b;  // [in_channels,]

    // Bytes needed for this block's parameters. `wtype` is not consulted:
    // the weight types are fixed (F16 convs, F32 vectors).
    size_t calculate_mem_size(ggml_type wtype) {
        const size_t vec_bytes = ggml_row_size(GGML_TYPE_F32, in_channels);
        const size_t mat_bytes = ggml_row_size(GGML_TYPE_F16, in_channels * in_channels * 1 * 1);
        double bytes           = 0;
        bytes += 6 * vec_bytes;  // norm_w, norm_b, q_b, k_b, v_b, proj_out_b
        bytes += 4 * mat_bytes;  // q_w, k_w, v_w, proj_out_w
        return static_cast<size_t>(bytes);
    }

    // Creates the parameter tensors in `ctx`. `alloc` and `wtype` are not
    // used by this block.
    void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) {
        norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
        norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);

        q_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels);
        q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
        k_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels);
        k_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
        v_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels);
        v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);

        proj_out_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels);
        proj_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels);
    }

    // Registers every parameter tensor under `prefix` + its checkpoint name.
    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        tensors[prefix + "norm.weight"]     = norm_w;
        tensors[prefix + "norm.bias"]       = norm_b;
        tensors[prefix + "q.weight"]        = q_w;
        tensors[prefix + "q.bias"]          = q_b;
        tensors[prefix + "k.weight"]        = k_w;
        tensors[prefix + "k.bias"]          = k_b;
        tensors[prefix + "v.weight"]        = v_w;
        tensors[prefix + "v.bias"]          = v_b;
        tensors[prefix + "proj_out.weight"] = proj_out_w;
        tensors[prefix + "proj_out.bias"]   = proj_out_b;
    }

    // Builds the attention graph.
    // x: [N, in_channels, h, w] (ggml ne order [w, h, in_channels, N]) -> same shape.
    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        auto normed = ggml_nn_group_norm(ctx, x, norm_w, norm_b);

        const int64_t batch  = normed->ne[3];
        const int64_t chans  = normed->ne[2];
        const int64_t height = normed->ne[1];
        const int64_t width  = normed->ne[0];
        const int64_t tokens = height * width;

        auto q = ggml_nn_conv_2d(ctx, normed, q_w, q_b);  // [N, in_channels, h, w]
        auto k = ggml_nn_conv_2d(ctx, normed, k_w, k_b);  // [N, in_channels, h, w]
        auto v = ggml_nn_conv_2d(ctx, normed, v_w, v_b);  // [N, in_channels, h, w]

        // move channels innermost, then flatten the spatial dims into one token axis
        q = ggml_reshape_3d(ctx, ggml_cont(ctx, ggml_permute(ctx, q, 1, 2, 0, 3)), chans, tokens, batch);  // [N, h * w, in_channels]
        k = ggml_reshape_3d(ctx, ggml_cont(ctx, ggml_permute(ctx, k, 1, 2, 0, 3)), chans, tokens, batch);  // [N, h * w, in_channels]

        // scaled dot-product scores, normalized along the key axis
        auto scores = ggml_mul_mat(ctx, k, q);  // [N, h * w, h * w]
        scores      = ggml_scale_inplace(ctx, scores, 1.0f / sqrt((float)in_channels));
        scores      = ggml_soft_max_inplace(ctx, scores);

        v             = ggml_reshape_3d(ctx, v, tokens, chans, batch);                  // [N, in_channels, h * w]
        auto attended = ggml_mul_mat(ctx, v, scores);                                    // [N, h * w, in_channels]
        attended      = ggml_cont(ctx, ggml_permute(ctx, attended, 1, 0, 2, 3));         // [N, in_channels, h * w]
        attended      = ggml_reshape_4d(ctx, attended, width, height, chans, batch);     // [N, in_channels, h, w]

        auto projected = ggml_nn_conv_2d(ctx, attended, proj_out_w, proj_out_b);  // [N, in_channels, h, w]
        return ggml_add(ctx, projected, x);
    }
};
// ldm.modules.diffusionmodules.model.Encoder
// Image -> moments: conv_in, then for each level `num_res_blocks` ResnetBlocks
// followed (except on the last level) by a DownSample, then mid (ResnetBlock,
// AttnBlock, ResnetBlock), GroupNorm, SiLU and conv_out to z_channels * 2
// channels.
// NOTE: unlike the Decoder (num_res_blocks + 1 blocks per level), the Encoder
// has exactly num_res_blocks blocks per level, and down_blocks is sized [4][2]
// for the default ch_mult / num_res_blocks. Every loop over down_blocks must
// therefore use `j < num_res_blocks`.
struct Encoder {
    int embed_dim      = 4;
    int ch             = 128;
    int z_channels     = 4;
    int in_channels    = 3;
    int num_res_blocks = 2;
    int ch_mult[4]     = {1, 2, 4, 4};

    struct ggml_tensor* conv_in_w;  // [ch, in_channels, 3, 3]
    struct ggml_tensor* conv_in_b;  // [ch, ]

    ResnetBlock down_blocks[4][2];  // [level][block], num_res_blocks blocks per level
    DownSample down_samples[3];     // one after every level except the last

    struct
    {
        ResnetBlock block_1;
        AttnBlock attn_1;
        ResnetBlock block_2;
    } mid;

    // block_in = ch * ch_mult[len_mults - 1]
    struct ggml_tensor* norm_out_w;  // [block_in, ]
    struct ggml_tensor* norm_out_b;  // [block_in, ]

    struct ggml_tensor* conv_out_w;  // [z_channels*2, block_in, 3, 3]
    struct ggml_tensor* conv_out_b;  // [z_channels*2, ]

    // Derives every sub-block's channel counts from ch / ch_mult / num_res_blocks.
    Encoder() {
        int len_mults = sizeof(ch_mult) / sizeof(int);
        int block_in  = 1;
        for (int i = 0; i < len_mults; i++) {
            if (i == 0) {
                block_in = ch;
            } else {
                block_in = ch * ch_mult[i - 1];
            }
            int block_out = ch * ch_mult[i];
            for (int j = 0; j < num_res_blocks; j++) {
                down_blocks[i][j].in_channels  = block_in;
                down_blocks[i][j].out_channels = block_out;
                block_in                       = block_out;
            }
            if (i != len_mults - 1) {
                down_samples[i].channels       = block_in;
                down_samples[i].out_channels   = block_in;
                down_samples[i].vae_downsample = true;
            }
        }

        mid.block_1.in_channels  = block_in;
        mid.block_1.out_channels = block_in;
        mid.attn_1.in_channels   = block_in;
        mid.block_2.in_channels  = block_in;
        mid.block_2.out_channels = block_in;
    }

    // Upper bound on the number of parameter tensors created by init_params.
    // 10 per ResnetBlock covers its 8 tensors plus the optional nin_shortcut pair.
    size_t get_num_tensors() {
        int num_tensors = 6;  // conv_in, norm_out, conv_out (weight + bias each)

        // mid: block_1, attn_1, block_2
        num_tensors += 10 * 3;

        int len_mults = sizeof(ch_mult) / sizeof(int);
        for (int i = 0; i < len_mults; i++) {
            num_tensors += 10 * num_res_blocks;  // down_blocks[i][0 .. num_res_blocks-1]
            if (i != len_mults - 1) {
                num_tensors += 2;  // down_samples[i]
            }
        }
        return num_tensors;
    }

    // Bytes needed for all encoder parameters. Visits exactly the blocks that
    // init_params creates, so it never indexes down_blocks past num_res_blocks.
    size_t calculate_mem_size(ggml_type wtype) {
        size_t mem_size = 0;
        int len_mults   = sizeof(ch_mult) / sizeof(int);
        int block_in    = ch * ch_mult[len_mults - 1];

        mem_size += ggml_row_size(GGML_TYPE_F16, ch * in_channels * 3 * 3);              // conv_in_w
        mem_size += ggml_row_size(GGML_TYPE_F32, ch);                                    // conv_in_b
        mem_size += 2 * ggml_row_size(GGML_TYPE_F32, block_in);                          // norm_out_w/b
        mem_size += ggml_row_size(GGML_TYPE_F16, z_channels * 2 * block_in * 3 * 3);     // conv_out_w
        mem_size += ggml_row_size(GGML_TYPE_F32, z_channels * 2);                        // conv_out_b

        mem_size += mid.block_1.calculate_mem_size(wtype);
        mem_size += mid.attn_1.calculate_mem_size(wtype);
        mem_size += mid.block_2.calculate_mem_size(wtype);

        for (int i = 0; i < len_mults; i++) {
            for (int j = 0; j < num_res_blocks; j++) {
                mem_size += down_blocks[i][j].calculate_mem_size(wtype);
            }
            if (i != len_mults - 1) {
                mem_size += down_samples[i].calculate_mem_size(wtype);
            }
        }
        return mem_size;
    }

    // Creates all encoder parameter tensors in `ctx`; `alloc` is forwarded to
    // the mid AttnBlock.
    void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) {
        int len_mults = sizeof(ch_mult) / sizeof(int);
        int block_in  = ch * ch_mult[len_mults - 1];

        conv_in_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, ch);
        conv_in_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ch);

        norm_out_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, block_in);
        norm_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, block_in);

        conv_out_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, block_in, z_channels * 2);
        conv_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_channels * 2);

        mid.block_1.init_params(ctx, wtype);
        mid.attn_1.init_params(ctx, alloc, wtype);
        mid.block_2.init_params(ctx, wtype);

        for (int i = 0; i < len_mults; i++) {
            for (int j = 0; j < num_res_blocks; j++) {
                down_blocks[i][j].init_params(ctx, wtype);
            }
            if (i != len_mults - 1) {
                down_samples[i].init_params(ctx, wtype);
            }
        }
    }

    // Registers every encoder parameter under `prefix` + its checkpoint name.
    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        tensors[prefix + "norm_out.weight"] = norm_out_w;
        tensors[prefix + "norm_out.bias"]   = norm_out_b;
        tensors[prefix + "conv_in.weight"]  = conv_in_w;
        tensors[prefix + "conv_in.bias"]    = conv_in_b;
        tensors[prefix + "conv_out.weight"] = conv_out_w;
        tensors[prefix + "conv_out.bias"]   = conv_out_b;

        mid.block_1.map_by_name(tensors, prefix + "mid.block_1.");
        mid.attn_1.map_by_name(tensors, prefix + "mid.attn_1.");
        mid.block_2.map_by_name(tensors, prefix + "mid.block_2.");

        int len_mults = sizeof(ch_mult) / sizeof(int);
        for (int i = 0; i < len_mults; i++) {
            for (int j = 0; j < num_res_blocks; j++) {
                down_blocks[i][j].map_by_name(tensors, prefix + "down." + std::to_string(i) + ".block." + std::to_string(j) + ".");
            }
            if (i != len_mults - 1) {
                down_samples[i].map_by_name(tensors, prefix + "down." + std::to_string(i) + ".downsample.");
            }
        }
    }

    // Builds the encoder graph.
    // x: [N, in_channels, h, w] -> [N, z_channels*2, h', w'] (h', w' reduced by the DownSamples)
    struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
        // conv_in
        auto h = ggml_nn_conv_2d(ctx, x, conv_in_w, conv_in_b, 1, 1, 1, 1);  // [N, ch, h, w]
        ggml_set_name(h, "b-start");

        int len_mults = sizeof(ch_mult) / sizeof(int);
        for (int i = 0; i < len_mults; i++) {
            for (int j = 0; j < num_res_blocks; j++) {
                h = down_blocks[i][j].forward(ctx, h);
            }
            if (i != len_mults - 1) {
                h = down_samples[i].forward(ctx, h);
            }
        }

        h = mid.block_1.forward(ctx, h);
        h = mid.attn_1.forward(ctx, h);
        h = mid.block_2.forward(ctx, h);  // [N, block_in, h, w]

        h = ggml_nn_group_norm(ctx, h, norm_out_w, norm_out_b);
        h = ggml_silu_inplace(ctx, h);

        // conv_out
        h = ggml_nn_conv_2d(ctx, h, conv_out_w, conv_out_b, 1, 1, 1, 1);  // [N, z_channels*2, h, w]
        return h;
    }
};
// ldm.modules.diffusionmodules.model.Decoder
struct Decoder {
int embed_dim = 4;
int ch = 128;
int z_channels = 4;
int out_ch = 3;
int num_res_blocks = 2;
int ch_mult[4] = {1, 2, 4, 4};
// block_in = ch * ch_mult[-1], 512
struct ggml_tensor* conv_in_w; // [block_in, z_channels, 3, 3]
struct ggml_tensor* conv_in_b; // [block_in, ]
struct
{
ResnetBlock block_1;
AttnBlock attn_1;
ResnetBlock block_2;
} mid;
ResnetBlock up_blocks[4][3];
UpSample up_samples[3];
struct ggml_tensor* norm_out_w; // [ch * ch_mult[0], ]
struct ggml_tensor* norm_out_b; // [ch * ch_mult[0], ]
struct ggml_tensor* conv_out_w; // [out_ch, ch * ch_mult[0], 3, 3]
struct ggml_tensor* conv_out_b; // [out_ch, ]
Decoder() {
int len_mults = sizeof(ch_mult) / sizeof(int);
int block_in = ch * ch_mult[len_mults - 1];
mid.block_1.in_channels = block_in;
mid.block_1.out_channels = block_in;
mid.attn_1.in_channels = block_in;
mid.block_2.in_channels = block_in;
mid.block_2.out_channels = block_in;
for (int i = len_mults - 1; i >= 0; i--) {
int mult = ch_mult[i];
int block_out = ch * mult;
for (int j = 0; j < num_res_blocks + 1; j++) {
up_blocks[i][j].in_channels = block_in;
up_blocks[i][j].out_channels = block_out;
block_in = block_out;
}
if (i != 0) {
up_samples[i - 1].channels = block_in;
up_samples[i - 1].out_channels = block_in;
}
}
}
size_t calculate_mem_size(ggml_type wtype) {
double mem_size = 0;
int len_mults = sizeof(ch_mult) / sizeof(int);
int block_in = ch * ch_mult[len_mults - 1];
mem_size += ggml_row_size(GGML_TYPE_F16, block_in * z_channels * 3 * 3); // conv_in_w
mem_size += ggml_row_size(GGML_TYPE_F32, block_in); // conv_in_b
mem_size += 2 * ggml_row_size(GGML_TYPE_F32, (ch * ch_mult[0])); // norm_out_w/b
mem_size += ggml_row_size(GGML_TYPE_F16, (ch * ch_mult[0]) * out_ch * 3 * 3); // conv_out_w
mem_size += ggml_row_size(GGML_TYPE_F32, out_ch); // conv_out_b
mem_size += mid.block_1.calculate_mem_size(wtype);
mem_size += mid.attn_1.calculate_mem_size(wtype);
mem_size += mid.block_2.calculate_mem_size(wtype);
for (int i = len_mults - 1; i >= 0; i--) {
for (int j = 0; j < num_res_blocks + 1; j++) {
mem_size += up_blocks[i][j].calculate_mem_size(wtype);
}
if (i != 0) {
mem_size += up_samples[i - 1].calculate_mem_size(wtype);
}
}
return static_cast<size_t>(mem_size);
}
size_t get_num_tensors() {
int num_tensors = 8;
// mid
num_tensors += 10 * 3;
int len_mults = sizeof(ch_mult) / sizeof(int);
for (int i = len_mults - 1; i >= 0; i--) {
for (int j = 0; j < num_res_blocks + 1; j++) {
num_tensors += 10;
}
if (i != 0) {
num_tensors += 2;
}
}
return num_tensors;
}
void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) {
int len_mults = sizeof(ch_mult) / sizeof(int);
int block_in = ch * ch_mult[len_mults - 1];
norm_out_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ch * ch_mult[0]);
norm_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ch * ch_mult[0]);
conv_in_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, z_channels, block_in);
conv_in_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, block_in);
conv_out_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, ch * ch_mult[0], out_ch);
conv_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_ch);
mid.block_1.init_params(ctx, wtype);
mid.attn_1.init_params(ctx, alloc, wtype);
mid.block_2.init_params(ctx, wtype);
for (int i = len_mults - 1; i >= 0; i--) {
for (int j = 0; j < num_res_blocks + 1; j++) {
up_blocks[i][j].init_params(ctx, wtype);
}
if (i != 0) {
up_samples[i - 1].init_params(ctx, wtype);
}
}
}
void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
tensors[prefix + "norm_out.weight"] = norm_out_w;
tensors[prefix + "norm_out.bias"] = norm_out_b;
tensors[prefix + "conv_in.weight"] = conv_in_w;
tensors[prefix + "conv_in.bias"] = conv_in_b;
tensors[prefix + "conv_out.weight"] = conv_out_w;
tensors[prefix + "conv_out.bias"] = conv_out_b;
mid.block_1.map_by_name(tensors, prefix + "mid.block_1.");
mid.attn_1.map_by_name(tensors, prefix + "mid.attn_1.");
mid.block_2.map_by_name(tensors, prefix + "mid.block_2.");
int len_mults = sizeof(ch_mult) / sizeof(int);
for (int i = len_mults - 1; i >= 0; i--) {
for (int j = 0; j < num_res_blocks + 1; j++) {
up_blocks[i][j].map_by_name(tensors, prefix + "up." + std::to_string(i) + ".block." + std::to_string(j) + ".");
}
if (i != 0) {
up_samples[i - 1].map_by_name(tensors, prefix + "up." + std::to_string(i) + ".upsample.");
}
}
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) {
// z: [N, z_channels, h, w]
// conv_in
auto h = ggml_nn_conv_2d(ctx, z, conv_in_w, conv_in_b, 1, 1, 1, 1); // [N, block_in, h, w]
h = mid.block_1.forward(ctx, h);
h = mid.attn_1.forward(ctx, h);
h = mid.block_2.forward(ctx, h); // [N, block_in, h, w]
int len_mults = sizeof(ch_mult) / sizeof(int);
for (int i = len_mults - 1; i >= 0; i--) {
for (int j = 0; j < num_res_blocks + 1; j++) {
h = up_blocks[i][j].forward(ctx, h);
}
if (i != 0) {
h = up_samples[i - 1].forward(ctx, h);
}
}
// group norm 32
h = ggml_nn_group_norm(ctx, h, norm_out_w, norm_out_b);
h = ggml_silu_inplace(ctx, h);
// conv_out
h = ggml_nn_conv_2d(ctx, h, conv_out_w, conv_out_b, 1, 1, 1, 1); // [N, out_ch, h, w]
return h;
}
};
// ldm.models.autoencoder.AutoencoderKL
// `encode` maps an image to the 2*embed_dim moment tensor (encoder + quant_conv);
// `decode` maps a latent back to an image (post_quant_conv + decoder).
// With decode_only the encoder side (encoder + quant_conv) is never sized,
// created, or registered by map_by_name.
struct AutoEncoderKL : public GGMLModule {
    bool decode_only = true;
    int embed_dim    = 4;
    struct {
        int z_channels     = 4;
        int resolution     = 256;
        int in_channels    = 3;
        int out_ch         = 3;
        int ch             = 128;
        int ch_mult[4]     = {1, 2, 4, 4};
        int num_res_blocks = 2;
    } dd_config;

    // quant_conv_* are only created when !decode_only; they stay NULL otherwise.
    struct ggml_tensor* quant_conv_w      = NULL;  // [2*embed_dim, 2*z_channels, 1, 1]
    struct ggml_tensor* quant_conv_b      = NULL;  // [2*embed_dim, ]
    struct ggml_tensor* post_quant_conv_w = NULL;  // [z_channels, embed_dim, 1, 1]
    struct ggml_tensor* post_quant_conv_b = NULL;  // [z_channels, ]

    Encoder encoder;
    Decoder decoder;

    AutoEncoderKL(bool decode_only = false)
        : decode_only(decode_only) {
        name = "vae";
        assert(sizeof(dd_config.ch_mult) == sizeof(encoder.ch_mult));
        assert(sizeof(dd_config.ch_mult) == sizeof(decoder.ch_mult));

        // NOTE: Encoder()/Decoder() have already derived their per-block channel
        // counts from their own defaults before this body runs. The copies below
        // keep the config fields in sync but do not re-derive those counts.
        encoder.embed_dim = embed_dim;
        decoder.embed_dim = embed_dim;

        encoder.ch = dd_config.ch;
        decoder.ch = dd_config.ch;

        encoder.z_channels = dd_config.z_channels;
        decoder.z_channels = dd_config.z_channels;

        encoder.in_channels = dd_config.in_channels;
        decoder.out_ch      = dd_config.out_ch;

        encoder.num_res_blocks = dd_config.num_res_blocks;
        decoder.num_res_blocks = dd_config.num_res_blocks;

        int len_mults = sizeof(dd_config.ch_mult) / sizeof(int);
        for (int i = 0; i < len_mults; i++) {
            encoder.ch_mult[i] = dd_config.ch_mult[i];
            decoder.ch_mult[i] = dd_config.ch_mult[i];
        }
    }

    // Bytes needed for all parameters; encoder-side tensors only when !decode_only.
    size_t calculate_mem_size() {
        size_t mem_size = 0;

        if (!decode_only) {
            mem_size += ggml_row_size(GGML_TYPE_F16, 2 * embed_dim * 2 * dd_config.z_channels * 1 * 1);  // quant_conv_w
            mem_size += ggml_row_size(GGML_TYPE_F32, 2 * embed_dim);                                     // quant_conv_b
            mem_size += encoder.calculate_mem_size(wtype);
        }

        mem_size += ggml_row_size(GGML_TYPE_F16, dd_config.z_channels * embed_dim * 1 * 1);  // post_quant_conv_w
        mem_size += ggml_row_size(GGML_TYPE_F32, dd_config.z_channels);                      // post_quant_conv_b
        mem_size += decoder.calculate_mem_size(wtype);
        return mem_size;
    }

    // Parameter tensor count estimate; encoder-side tensors only when !decode_only.
    size_t get_num_tensors() {
        size_t num_tensors = decoder.get_num_tensors();
        if (!decode_only) {
            num_tensors += 2;  // quant_conv_w/b
            num_tensors += encoder.get_num_tensors();
        }
        return num_tensors;
    }

    // Creates all parameter tensors in params_ctx, then allocates backing
    // storage in params_buffer for every tensor that does not have data yet.
    void init_params() {
        ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);

        if (!decode_only) {
            quant_conv_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 1, 1, 2 * dd_config.z_channels, 2 * embed_dim);
            quant_conv_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, 2 * embed_dim);
            encoder.init_params(params_ctx, alloc, wtype);
        }

        post_quant_conv_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 1, 1, embed_dim, dd_config.z_channels);
        post_quant_conv_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, dd_config.z_channels);
        decoder.init_params(params_ctx, alloc, wtype);

        // alloc all tensors linked to this context
        for (struct ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != NULL; t = ggml_get_next_tensor(params_ctx, t)) {
            if (t->data == NULL) {
                ggml_allocr_alloc(alloc, t);
            }
        }

        ggml_allocr_free(alloc);
    }

    // Registers parameters under `prefix` + their checkpoint names. Encoder-side
    // tensors are registered only when they exist (!decode_only); in decode_only
    // mode they were never created, so registering them would publish
    // uninitialized pointers to the loader.
    void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
        if (!decode_only) {
            tensors[prefix + "quant_conv.weight"] = quant_conv_w;
            tensors[prefix + "quant_conv.bias"]   = quant_conv_b;
            encoder.map_by_name(tensors, prefix + "encoder.");
        }

        tensors[prefix + "post_quant_conv.weight"] = post_quant_conv_w;
        tensors[prefix + "post_quant_conv.bias"]   = post_quant_conv_b;
        decoder.map_by_name(tensors, prefix + "decoder.");
    }

    // Latent -> image graph. z: [N, z_channels, h, w]
    struct ggml_tensor* decode(struct ggml_context* ctx0, struct ggml_tensor* z) {
        // post_quant_conv
        auto h = ggml_nn_conv_2d(ctx0, z, post_quant_conv_w, post_quant_conv_b);  // [N, z_channels, h, w]
        ggml_set_name(h, "bench-start");
        h = decoder.forward(ctx0, h);
        ggml_set_name(h, "bench-end");
        return h;
    }

    // Image -> moments graph. x: [N, in_channels, h, w]. Requires !decode_only.
    struct ggml_tensor* encode(struct ggml_context* ctx0, struct ggml_tensor* x) {
        auto h = encoder.forward(ctx0, x);  // [N, 2*z_channels, h/8, w/8]
        // quant_conv
        h = ggml_nn_conv_2d(ctx0, h, quant_conv_w, quant_conv_b);  // [N, 2*embed_dim, h/8, w/8]
        ggml_set_name(h, "b-end");
        return h;
    }

    // Builds the decode (decode_graph == true) or encode graph for input `z`.
    struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) {
        // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data
        // NOTE(review): `buf` is function-static, so it is shared by every
        // AutoEncoderKL instance; concurrent build_graph calls would race on it.
        static size_t buf_size = ggml_tensor_overhead() * VAE_GRAPH_SIZE + ggml_graph_overhead();
        static std::vector<uint8_t> buf(buf_size);

        struct ggml_init_params params = {
            /*.mem_size   =*/buf_size,
            /*.mem_buffer =*/buf.data(),
            /*.no_alloc   =*/true,  // the tensors will be allocated later by ggml_allocr_alloc_graph()
        };
        // LOG_DEBUG("mem_size %u ", params.mem_size);

        struct ggml_context* ctx0 = ggml_init(params);

        struct ggml_cgraph* gf = ggml_new_graph(ctx0);

        struct ggml_tensor* z_ = NULL;

        // it's performing a compute, check if backend isn't cpu
        if (!ggml_backend_is_cpu(backend)) {
            // pass input tensors to gpu memory
            z_ = ggml_dup_tensor(ctx0, z);
            ggml_allocr_alloc(compute_allocr, z_);

            // pass data to device backend
            if (!ggml_allocr_is_measure(compute_allocr)) {
                ggml_backend_tensor_set(z_, z->data, 0, ggml_nbytes(z));
            }
        } else {
            z_ = z;
        }

        struct ggml_tensor* out = decode_graph ? decode(ctx0, z_) : encode(ctx0, z_);

        ggml_build_forward_expand(gf, out);
        ggml_free(ctx0);

        return gf;
    }

    // Sizes the compute buffer using the graph for input `x`.
    void alloc_compute_buffer(struct ggml_tensor* x, bool decode) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(x, decode);
        };
        GGMLModule::alloc_compute_buffer(get_graph);
    }

    // Runs the decode or encode graph for `z`; the result goes to `work_result`.
    void compute(struct ggml_tensor* work_result, const int n_threads, struct ggml_tensor* z, bool decode_graph) {
        auto get_graph = [&]() -> struct ggml_cgraph* {
            return build_graph(z, decode_graph);
        };
        GGMLModule::compute(get_graph, n_threads, work_result);
    }
};
#endif