feat: add TencentARC PhotoMaker support (#179)

* first efforts at implementing photomaker; lots more to do

* added PhotoMakerIDEncoder model in SD

* fixed some bugs; now photomaker model weights can be loaded into their tensor buffers

* added input id image loading

* added preprocessing of input id images

* finished get_num_tensors

* fixed a bug in remove_duplicates

* add a get_learned_condition_with_trigger function to do photomaker stuff

* add a convert_token_to_id function for photomaker to extract trigger word's token id

* making progress; need to implement tokenizer decoder

* making more progress; finishing vision model forward

* debugging vision_model outputs

* corrected clip vision model output

* continue making progress in id fusion process

* finished stacked id embedding; to be tested

* remove garbage file

* debugging graph compute

* more progress; now alloc buffer failed

* fixed wtype issue; only 1 input image is supported for now because of an issue with the transformer when batch size > 1 (to be investigated)

* added delayed subject conditioning; now photomaker runs and generates images

* fixed start_merge_step

* added photomaker lora model (to be tested)

* reworked pmid lora

* finished applying pmid lora; to be tested

* finalized pmid lora

* add a few print tensor calls; tweak sampling again

* small tweak; still not getting ID faces

* fixed a bug in FuseBlock forward; also removed the diag_mask op for the vision transformer; getting better results

* disable pmid lora apply for now; 1 input image seems to work; > 1 does not

* turn pmid lora apply back on

* fixed a decode bug

* fixed a bug in ggml's conv_2d, and now > 1 input images working

* add style_ratio as a cli param; reworked encode with trigger for attention weights

* merge commit fixing lora free param buffer error

* change default style ratio to 10%

* added an option to offload vae decoder to CPU for mem-limited gpus

* removing the image normalization step seems to make ID fidelity much higher

* revert default style ratio back to 20%

* added an option for normalizing input ID images; cleaned up debugging code

* more clean up

* fixed bugs; now failed with cuda error; likely out-of-mem on GPU

* free pmid model params when required

* photomaker working properly now after merging and adapting to GGMLBlock API

* remove tensor renaming; fix names in the photomaker model file

* updated README.md to include instructions and notes for running PhotoMaker

* a bit of clean up

* remove -DGGML_CUDA_FORCE_MMQ; more clean up and README update

* add input image requirement in README

* bring back freeing pmid lora params buffer; simplify pooled output of CLIPVision

* remove MultiheadAttention2; customized MultiheadAttention

* added a WIN32 get_files_from_dir; turn off PhotoMaker if receiving no input images

* update docs

* fix ci error

* make stable-diffusion.h a pure c header file

This reverts commit 27887b630d.

* fix ci error

* format code

* reuse get_learned_condition

* reuse pad_tokens

* reuse CLIPVisionModel

* reuse LoraModel

* add --clip-on-cpu

* fix lora name conversion for SDXL

---------

Co-authored-by: bssrdf <bssrdf@gmail.com>
Co-authored-by: leejet <leejet714@gmail.com>
bssrdf 2024-03-12 11:15:17 -04:00 committed by GitHub
parent 61980171a1
commit a469688e30
28 changed files with 3935 additions and 186 deletions

README.md

@ -14,6 +14,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors).
- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support.
- 16-bit, 32-bit float support
- 4-bit, 5-bit and 8-bit integer quantization support
- Accelerated memory-efficient CPU inference
@ -151,7 +152,7 @@ cmake --build . --config Release
### Run
```
usage: ./build/bin/sd [arguments]
usage: ./bin/sd [arguments]
arguments:
-h, --help show this help message and exit
@ -163,6 +164,9 @@ arguments:
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
--control-net [CONTROL_PATH] path to control net model
--embd-dir [EMBEDDING_PATH] path to embeddings.
--stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings.
--input-id-images-dir [DIR] path to PHOTOMAKER input id images dir.
--normalize-input normalize PHOTOMAKER input id images
--upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.
--upscale-repeats Run the ESRGAN upscaler this many times (default 1)
--type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)
@ -175,6 +179,7 @@ arguments:
-n, --negative-prompt PROMPT the negative prompt (default: "")
--cfg-scale SCALE unconditional guidance scale: (default: 7.0)
--strength STRENGTH strength for noising/unnoising (default: 0.75)
--style-ratio STYLE-RATIO strength for keeping input identity (default: 20%)
--control-strength STRENGTH strength to apply Control Net (default: 0.9)
1.0 corresponds to full destruction of information in init image
-H, --height H image height, in pixel space (default: 512)
@ -299,6 +304,39 @@ You can use ESRGAN to upscale the generated images. At the moment, only the [Rea
sd -m ../models/v1-5-pruned-emaonly.safetensors -p "a lovely cat" --upscale-model ../models/RealESRGAN_x4plus_anime_6B.pth
```
#### Using PhotoMaker to personalize image generation
You can use [PhotoMaker](https://github.com/TencentARC/PhotoMaker) to personalize generated images with your own ID.
**NOTE**: currently PhotoMaker **ONLY** works with **SDXL** (any SDXL model files will work).
Download the PhotoMaker model file (in safetensors format) [here](https://huggingface.co/bssrdf/PhotoMaker). The official release of the model file (in .bin format) does not work with ```stable-diffusion.cpp```.
- Specify the PhotoMaker model path using the `--stacked-id-embd-dir PATH` parameter.
- Specify the input images path using the `--input-id-images-dir PATH` parameter.
- input images **must** have the same width and height for preprocessing (to be improved)
In the prompt, make sure you have a class word followed by the trigger word ```"img"``` (hard-coded for now). The class word can be one of ```"man, woman, girl, boy"```. If the input ID images contain Asian faces, add ```Asian``` before the class word.
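For example (an illustrative prompt, not from the official docs):
```
-p "a man img, wearing a suit, portrait photo, highly detailed"
```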
Another PhotoMaker-specific parameter:
- ```--style-ratio (0-100)%```: default is 20; 10-20 typically gives good results. A lower ratio means the output follows the input ID more faithfully (not necessarily better image quality); see the sketch below.
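This ratio drives the delayed subject conditioning added in this PR (see the commit log above). A minimal sketch of the idea, with illustrative names rather than the exact implementation:
```cpp
// Hedged sketch: the first style_ratio percent of sampling steps run with the
// plain text conditioning; the ID-fused prompt embedding is merged in afterwards.
int start_merge_step = (int)(style_ratio / 100.f * (float)sample_steps);
// e.g. style_ratio = 20 with 30 steps -> ID embedding merged from step 6 on
```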
Other parameters recommended for running PhotoMaker:
- ```--cfg-scale 5.0```
- ```-H 1024```
- ```-W 1024```
On low-memory GPUs (<= 8 GB), running with the ```--vae-on-cpu``` option is recommended to get artifact-free images.
Example:
```bash
bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png
```
### Docker
#### Building using Docker
@ -345,3 +383,4 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
- [k-diffusion](https://github.com/crowsonkb/k-diffusion)
- [latent-consistency-model](https://github.com/luosiallen/latent-consistency-model)
- [generative-models](https://github.com/Stability-AI/generative-models/)
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker)

(15 new binary image files added, 26 KiB to 1.4 MiB each; previews not shown.)

clip.hpp

@ -75,9 +75,13 @@ class CLIPTokenizer {
private:
SDVersion version = VERSION_1_x;
std::map<int, std::u32string> byte_encoder;
std::map<std::u32string, int> byte_decoder;
std::map<std::u32string, int> encoder;
std::map<int, std::u32string> decoder;
std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
std::regex pat;
int encoder_len;
int bpe_len;
static std::string strip(const std::string& str) {
std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f");
@ -118,7 +122,11 @@ public:
void load_from_merges(const std::string& merges_utf8_str) {
auto byte_unicode_pairs = bytes_to_unicode();
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
// printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size());
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
for (auto& pair : byte_unicode_pairs) {
byte_decoder[pair.second] = pair.first;
}
// for (auto & pair: byte_unicode_pairs) {
// std::cout << pair.first << ": " << pair.second << std::endl;
// }
@ -138,6 +146,8 @@ public:
size_t space_pos = merge.find(' ');
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
// LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
// printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(),
// utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
}
std::vector<std::u32string> vocab;
for (const auto& pair : byte_unicode_pairs) {
@ -154,15 +164,36 @@ public:
LOG_DEBUG("vocab size: %llu", vocab.size());
int i = 0;
for (const auto& token : vocab) {
encoder[token] = i++;
encoder[token] = i;
decoder[i] = token;
i++;
}
encoder_len = i;
auto it = encoder.find(utf8_to_utf32("img</w>"));
if (it != encoder.end()) {
LOG_DEBUG(" trigger word img already in vocab");
} else {
LOG_DEBUG(" trigger word img not in vocab yet");
}
int rank = 0;
for (const auto& merge : merge_pairs) {
bpe_ranks[merge] = rank++;
}
bpe_len = rank;
};
void add_token(const std::string& text) {
std::u32string token = utf8_to_utf32(text);
auto it = encoder.find(token);
if (it != encoder.end()) {
encoder[token] = encoder_len;
decoder[encoder_len] = token;
encoder_len++;
}
}
std::u32string bpe(const std::u32string& token) {
std::vector<std::u32string> word;
@ -243,6 +274,7 @@ public:
size_t max_length = 0,
bool padding = false) {
std::vector<int32_t> tokens = encode(text, on_new_token_cb);
tokens.insert(tokens.begin(), BOS_TOKEN_ID);
if (max_length > 0) {
if (tokens.size() > max_length - 1) {
@ -259,9 +291,34 @@ public:
}
}
}
return tokens;
}
std::string decode(const std::vector<int>& tokens) {
std::string text = "";
for (int t : tokens) {
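// 49406/49407 are CLIP's BOS/EOS token ids; skip them when decoding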
if (t == 49406 || t == 49407)
continue;
std::u32string ts = decoder[t];
// printf("%d, %s \n", t, utf32_to_utf8(ts).c_str());
std::string s = utf32_to_utf8(ts);
if (s.length() >= 4 && ends_with(s, "</w>")) {
text += " " + s.replace(s.length() - 4, s.length() - 1, "");
} else {
text += " " + s;
}
}
// std::vector<unsigned char> bytes;
// for (auto c : text){
// bytes.push_back(byte_decoder[c]);
// }
// std::string s((char *)bytes.data());
// std::string s = "";
return trim(text);
}
std::vector<int> encode(std::string text, on_new_token_cb_t on_new_token_cb) {
std::string original_text = text;
std::vector<int32_t> bpe_tokens;
@ -308,7 +365,8 @@ public:
ss << "\"" << token << "\", ";
}
ss << "]";
LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
return bpe_tokens;
}
};
@ -469,7 +527,8 @@ public:
: d_model(d_model),
n_head(n_head),
intermediate_size(intermediate_size) {
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new MultiheadAttention(d_model, n_head));
blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new MultiheadAttention(d_model, n_head, true));
blocks["layer_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_model));
blocks["layer_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(d_model));
@ -508,7 +567,7 @@ public:
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, int clip_skip = -1, bool mask = true) {
// x: [N, n_token, d_model]
int layer_idx = n_layer - 1;
LOG_DEBUG("clip_skip %d", clip_skip);
// LOG_DEBUG("clip_skip %d", clip_skip);
if (clip_skip > 0) {
layer_idx = n_layer - clip_skip;
}
@ -520,7 +579,7 @@ public:
}
std::string name = "layers." + std::to_string(i);
auto layer = std::dynamic_pointer_cast<CLIPLayer>(blocks[name]);
x = layer->forward(ctx, x); // [N, n_token, d_model]
x = layer->forward(ctx, x, mask); // [N, n_token, d_model]
// LOG_DEBUG("layer %d", i);
}
return x;
@ -703,7 +762,7 @@ public:
auto final_layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["final_layer_norm"]);
auto x = embeddings->forward(ctx, input_ids, tkn_embeddings); // [N, n_token, hidden_size]
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
x = encoder->forward(ctx, x, return_pooled ? -1 : clip_skip, true);
if (return_pooled || with_final_ln) {
x = final_layer_norm->forward(ctx, x);
}
@ -720,11 +779,6 @@ public:
};
class CLIPVisionModel : public GGMLBlock {
protected:
void init_params(struct ggml_context* ctx, ggml_type wtype) {
params["visual_projection"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size);
}
public:
// network hparams
int32_t num_channels = 3;
@ -735,16 +789,14 @@ public:
int32_t intermediate_size = 4096;
int32_t n_head = 16;
int32_t n_layer = 24;
int32_t projection_dim = 768;
public:
CLIPVisionModel(CLIPVersion version = OPEN_CLIP_VIT_H_14) {
CLIPVisionModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14) {
if (version == OPEN_CLIP_VIT_H_14) {
hidden_size = 1280;
intermediate_size = 5120;
n_head = 16;
n_layer = 32;
projection_dim = 1024;
} else if (version == OPEN_CLIP_VIT_BIGG_14) {
hidden_size = 1664;
intermediate_size = 8192;
@ -758,9 +810,8 @@ public:
blocks["post_layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) {
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values, bool return_pooled = true) {
// pixel_values: [N, num_channels, image_size, image_size]
// return: // [N, projection_dim]
auto embeddings = std::dynamic_pointer_cast<CLIPVisionEmbeddings>(blocks["embeddings"]);
auto pre_layernorm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_layernorm"]);
auto encoder = std::dynamic_pointer_cast<CLIPEncoder>(blocks["encoder"]);
@ -768,26 +819,60 @@ public:
auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim]
x = pre_layernorm->forward(ctx, x);
x = encoder->forward(ctx, x, -1, true);
x = encoder->forward(ctx, x, -1, false);
x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size]
GGML_ASSERT(x->ne[2] == 1);
int64_t max_token_idx = 0;
ggml_tensor* pooled = ggml_view_1d(ctx, x, x->ne[0], x->nb[1] * max_token_idx); // assert N == 1
auto visual_projection = params["visual_projection"];
pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, visual_projection)), pooled);
return pooled; // [N, projection_dim]
GGML_ASSERT(x->ne[3] == 1);
if (return_pooled) {
ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0));
return pooled; // [N, hidden_size]
} else {
return x; // [N, n_token, hidden_size]
}
}
};
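// A bias-free linear projection; transpose_weight accommodates checkpoints that
// store the weight transposed (e.g. OpenCLIP's visual projection).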
class CLIPProjection : public UnaryBlock {
protected:
int64_t in_features;
int64_t out_features;
bool transpose_weight;
void init_params(struct ggml_context* ctx, ggml_type wtype) {
if (transpose_weight) {
LOG_ERROR("transpose_weight");
params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features);
} else {
params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features);
}
}
public:
CLIPProjection(int64_t in_features,
int64_t out_features,
bool transpose_weight = false)
: in_features(in_features),
out_features(out_features),
transpose_weight(transpose_weight) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
if (transpose_weight) {
w = ggml_cont(ctx, ggml_transpose(ctx, w));
}
return ggml_nn_linear(ctx, x, w, NULL);
}
};
class CLIPVisionModelProjection : public GGMLBlock {
public:
int32_t hidden_size = 1024;
int32_t projection_dim = 1024;
int32_t projection_dim = 768;
int32_t image_size = 224;
public:
CLIPVisionModelProjection(CLIPVersion version = OPEN_CLIP_VIT_H_14) {
CLIPVisionModelProjection(CLIPVersion version = OPENAI_CLIP_VIT_L_14,
bool transpose_proj_w = false) {
if (version == OPEN_CLIP_VIT_H_14) {
hidden_size = 1280;
projection_dim = 1024;
@ -795,17 +880,17 @@ public:
hidden_size = 1664;
}
blocks["visual_model"] = std::shared_ptr<GGMLBlock>(new CLIPVisionModel(version));
blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, projection_dim, false));
blocks["vision_model"] = std::shared_ptr<GGMLBlock>(new CLIPVisionModel(version));
blocks["visual_projection"] = std::shared_ptr<GGMLBlock>(new CLIPProjection(hidden_size, projection_dim, transpose_proj_w));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* pixel_values) {
// pixel_values: [N, num_channels, image_size, image_size]
// return: [N, num_positions, projection_dim]
auto visual_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["visual_model"]);
auto visual_projection = std::dynamic_pointer_cast<Linear>(blocks["visual_projection"]);
// return: [N, projection_dim]
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
auto x = visual_model->forward(ctx, pixel_values); // [N, embed_dim]
auto x = vision_model->forward(ctx, pixel_values); // [N, hidden_size]
x = visual_projection->forward(ctx, x); // [N, projection_dim]
return x; // [N, projection_dim]
@ -1029,6 +1114,205 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
return tokenize(text, text_model.n_token, padding);
}
std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>>
tokenize_with_trigger_token(std::string text,
int num_input_imgs,
int32_t image_token,
bool padding = false) {
return tokenize_with_trigger_token(text, num_input_imgs, image_token,
text_model.n_token, padding);
}
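// Encode text and return its raw token ids; used to look up the token id of
// the PhotoMaker trigger word ("img").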
std::vector<int> convert_token_to_id(std::string text) {
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
size_t word_end = str.find(",");
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
embd_name = trim(embd_name);
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
}
if (embd_path.size() > 0) {
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
if (word_end != std::string::npos) {
str = str.substr(word_end);
} else {
str = "";
}
return true;
}
}
return false;
};
std::vector<int> curr_tokens = tokenizer.encode(text, on_new_token_cb);
return curr_tokens;
}
std::string decode(const std::vector<int>& tokens) {
return tokenizer.decode(tokens);
}
void pad_tokens(std::vector<int>& tokens,
std::vector<float>& weights,
size_t max_length = 0,
bool padding = false) {
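// Chunk the token stream into max_length windows (max_length - 2 content
// tokens plus BOS/EOS each, matching CLIP's 77-token context), carrying the
// per-token weights along and padding the final window.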
if (max_length > 0 && padding) {
size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2));
if (n == 0) {
n = 1;
}
size_t length = max_length * n;
LOG_DEBUG("token length: %llu", length);
std::vector<int> new_tokens;
std::vector<float> new_weights;
new_tokens.push_back(BOS_TOKEN_ID);
new_weights.push_back(1.0);
int token_idx = 0;
for (int i = 1; i < length; i++) {
if (token_idx >= tokens.size()) {
break;
}
if (i % max_length == 0) {
new_tokens.push_back(BOS_TOKEN_ID);
new_weights.push_back(1.0);
} else if (i % max_length == max_length - 1) {
new_tokens.push_back(EOS_TOKEN_ID);
new_weights.push_back(1.0);
} else {
new_tokens.push_back(tokens[token_idx]);
new_weights.push_back(weights[token_idx]);
token_idx++;
}
}
new_tokens.push_back(EOS_TOKEN_ID);
new_weights.push_back(1.0);
tokens = new_tokens;
weights = new_weights;
if (padding) {
int pad_token_id = PAD_TOKEN_ID;
if (version == VERSION_2_x) {
pad_token_id = 0;
}
tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
weights.insert(weights.end(), length - weights.size(), 1.0);
}
}
}
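// Tokenize a prompt containing the PhotoMaker trigger token: the trigger id is
// dropped and the class token immediately before it is repeated once per input
// ID image; class_token_mask marks those repeated positions. Illustration:
// "a man img" with 2 ID images -> [BOS, a, man, man, EOS, ...].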
std::tuple<std::vector<int>, std::vector<float>, std::vector<bool>>
tokenize_with_trigger_token(std::string text,
int num_input_imgs,
int32_t image_token,
size_t max_length = 0,
bool padding = false) {
auto parsed_attention = parse_prompt_attention(text);
{
std::stringstream ss;
ss << "[";
for (const auto& item : parsed_attention) {
ss << "['" << item.first << "', " << item.second << "], ";
}
ss << "]";
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
}
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
size_t word_end = str.find(",");
std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end);
embd_name = trim(embd_name);
std::string embd_path = get_full_path(embd_dir, embd_name + ".pt");
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".ckpt");
}
if (embd_path.size() == 0) {
embd_path = get_full_path(embd_dir, embd_name + ".safetensors");
}
if (embd_path.size() > 0) {
if (load_embedding(embd_name, embd_path, bpe_tokens)) {
if (word_end != std::string::npos) {
str = str.substr(word_end);
} else {
str = "";
}
return true;
}
}
return false;
};
std::vector<int> tokens;
std::vector<float> weights;
std::vector<bool> class_token_mask;
int32_t class_idx = -1, tokens_acc = 0;
for (const auto& item : parsed_attention) {
std::vector<int> class_token_index;
std::vector<int> clean_input_ids;
const std::string& curr_text = item.first;
float curr_weight = item.second;
// printf(" %s: %f \n", curr_text.c_str(), curr_weight);
std::vector<int> curr_tokens = tokenizer.encode(curr_text, on_new_token_cb);
int32_t clean_index = 0;
for (uint32_t i = 0; i < curr_tokens.size(); i++) {
int token_id = curr_tokens[i];
if (token_id == image_token)
class_token_index.push_back(clean_index - 1);
else {
clean_input_ids.push_back(token_id);
clean_index++;
}
}
// GGML_ASSERT(class_token_index.size() == 1); // PhotoMaker currently does not support multiple
// trigger words in a single prompt.
if (class_token_index.size() == 1) {
// Expand the class word token and corresponding mask
int class_token = clean_input_ids[class_token_index[0]];
class_idx = tokens_acc + class_token_index[0];
std::vector<int> clean_input_ids_tmp;
for (uint32_t i = 0; i < class_token_index[0]; i++)
clean_input_ids_tmp.push_back(clean_input_ids[i]);
for (uint32_t i = 0; i < num_input_imgs; i++)
clean_input_ids_tmp.push_back(class_token);
for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++)
clean_input_ids_tmp.push_back(clean_input_ids[i]);
clean_input_ids.clear();
clean_input_ids = clean_input_ids_tmp;
}
tokens_acc += clean_index;
tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end());
weights.insert(weights.end(), clean_input_ids.size(), curr_weight);
}
tokens.insert(tokens.begin(), BOS_TOKEN_ID);
weights.insert(weights.begin(), 1.0);
pad_tokens(tokens, weights, max_length, padding);
for (uint32_t i = 0; i < tokens.size(); i++) {
if (class_idx + 1 <= i && i < class_idx + 1 + num_input_imgs)
class_token_mask.push_back(true);
else
class_token_mask.push_back(false);
}
// printf("[");
// for (int i = 0; i < tokens.size(); i++) {
// printf("%d, ", class_token_mask[i] ? 1 : 0);
// }
// printf("]\n");
// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", ";
// }
// std::cout << std::endl;
return std::make_tuple(tokens, weights, class_token_mask);
}
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0,
bool padding = false) {
@ -1078,49 +1362,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
}
if (max_length > 0 && padding) {
size_t n = std::ceil(tokens.size() * 1.0 / (max_length - 2));
if (n == 0) {
n = 1;
}
size_t length = max_length * n;
LOG_DEBUG("token length: %llu", length);
std::vector<int> new_tokens;
std::vector<float> new_weights;
new_tokens.push_back(BOS_TOKEN_ID);
new_weights.push_back(1.0);
int token_idx = 0;
for (int i = 1; i < length; i++) {
if (token_idx >= tokens.size()) {
break;
}
if (i % max_length == 0) {
new_tokens.push_back(BOS_TOKEN_ID);
new_weights.push_back(1.0);
} else if (i % max_length == max_length - 1) {
new_tokens.push_back(EOS_TOKEN_ID);
new_weights.push_back(1.0);
} else {
new_tokens.push_back(tokens[token_idx]);
new_weights.push_back(weights[token_idx]);
token_idx++;
}
}
new_tokens.push_back(EOS_TOKEN_ID);
new_weights.push_back(1.0);
tokens = new_tokens;
weights = new_weights;
if (padding) {
int pad_token_id = PAD_TOKEN_ID;
if (version == VERSION_2_x) {
pad_token_id = 0;
}
tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
weights.insert(weights.end(), length - weights.size(), 1.0);
}
}
pad_tokens(tokens, weights, max_length, padding);
// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", ";
@ -1132,10 +1374,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
};
struct FrozenCLIPVisionEmbedder : public GGMLModule {
CLIPVisionModel vision_model;
CLIPVisionModelProjection vision_model;
FrozenCLIPVisionEmbedder(ggml_backend_t backend, ggml_type wtype)
: GGMLModule(backend, wtype) {
: vision_model(OPEN_CLIP_VIT_H_14, true), GGMLModule(backend, wtype) {
vision_model.init(params_ctx, wtype);
}
@ -1152,7 +1394,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLModule {
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
vision_model.get_param_tensors(tensors, prefix + "transformer.visual_model");
vision_model.get_param_tensors(tensors, prefix + "transformer");
}
struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values) {

examples/cli/main.cpp

@ -10,6 +10,7 @@
#include "stable-diffusion.h"
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
@ -65,6 +66,8 @@ struct SDParams {
std::string esrgan_path;
std::string controlnet_path;
std::string embeddings_path;
std::string stacked_id_embeddings_path;
std::string input_id_images_path;
sd_type_t wtype = SD_TYPE_COUNT;
std::string lora_model_dir;
std::string output_path = "output.png";
@ -73,12 +76,13 @@ struct SDParams {
std::string prompt;
std::string negative_prompt;
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
int height = 512;
int batch_count = 1;
float min_cfg = 1.0f;
float cfg_scale = 7.0f;
float style_ratio = 20.f;
int clip_skip = -1; // <= 0 represents unspecified
int width = 512;
int height = 512;
int batch_count = 1;
int video_frames = 6;
int motion_bucket_id = 127;
@ -95,6 +99,9 @@ struct SDParams {
bool verbose = false;
bool vae_tiling = false;
bool control_net_cpu = false;
bool normalize_input = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool canny_preprocess = false;
int upscale_repeats = 1;
};
@ -110,10 +117,16 @@ void print_params(SDParams params) {
printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
printf(" controlnet_path: %s\n", params.controlnet_path.c_str());
printf(" embeddings_path: %s\n", params.embeddings_path.c_str());
printf(" stacked_id_embeddings_path: %s\n", params.stacked_id_embeddings_path.c_str());
printf(" input_id_images_path: %s\n", params.input_id_images_path.c_str());
printf(" style ratio: %.2f\n", params.style_ratio);
printf(" normzalize input image : %s\n", params.normalize_input ? "true" : "false");
printf(" output_path: %s\n", params.output_path.c_str());
printf(" init_img: %s\n", params.input_path.c_str());
printf(" control_image: %s\n", params.control_image_path.c_str());
printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false");
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
@ -146,6 +159,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
printf(" --control-net [CONTROL_PATH] path to control net model\n");
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings.\n");
printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings.\n");
printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir.\n");
printf(" --normalize-input normalize PHOTOMAKER input id images\n");
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n");
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n");
@ -158,6 +174,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n");
printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n");
printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n");
printf(" 1.0 corresponds to full destruction of information in init image\n");
printf(" -H, --height H image height, in pixel space (default: 512)\n");
@ -244,6 +261,18 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.embeddings_path = argv[i];
} else if (arg == "--stacked-id-embd-dir") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.stacked_id_embeddings_path = argv[i];
} else if (arg == "--input-id-images-dir") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.input_id_images_path = argv[i];
} else if (arg == "--type") {
if (++i >= argc) {
invalid_arg = true;
@ -327,6 +356,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.strength = std::stof(argv[i]);
} else if (arg == "--style-ratio") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.style_ratio = std::stof(argv[i]);
} else if (arg == "--control-strength") {
if (++i >= argc) {
invalid_arg = true;
@ -361,6 +396,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.vae_tiling = true;
} else if (arg == "--control-net-cpu") {
params.control_net_cpu = true;
} else if (arg == "--normalize-input") {
params.normalize_input = true;
} else if (arg == "--clip-on-cpu") {
params.clip_on_cpu = true; // will slow down get_learned_condition but necessary for low MEM GPUs
} else if (arg == "--vae-on-cpu") {
params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs
} else if (arg == "--canny") {
params.canny_preprocess = true;
} else if (arg == "-b" || arg == "--batch-count") {
@ -613,6 +654,7 @@ int main(int argc, const char* argv[]) {
params.controlnet_path.c_str(),
params.lora_model_dir.c_str(),
params.embeddings_path.c_str(),
params.stacked_id_embeddings_path.c_str(),
vae_decode_only,
params.vae_tiling,
true,
@ -620,7 +662,9 @@ int main(int argc, const char* argv[]) {
params.wtype,
params.rng_type,
params.schedule,
params.control_net_cpu);
params.clip_on_cpu,
params.control_net_cpu,
params.vae_on_cpu);
if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");
@ -664,7 +708,10 @@ int main(int argc, const char* argv[]) {
params.seed,
params.batch_count,
control_image,
params.control_strength);
params.control_strength,
params.style_ratio,
params.normalize_input,
params.input_id_images_path.c_str());
} else {
sd_image_t input_image = {(uint32_t)params.width,
(uint32_t)params.height,

ggml_extend.hpp

@ -80,8 +80,27 @@ __STATIC_INLINE__ ggml_fp16_t ggml_tensor_get_f16(const ggml_tensor* tensor, int
return *(ggml_fp16_t*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]);
}
__STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_only = false) {
printf("shape(%zu, %zu, %zu, %zu)\n", tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
static struct ggml_tensor* get_tensor_from_graph(struct ggml_cgraph* gf, const char* name) {
struct ggml_tensor* res = NULL;
for (int i = 0; i < gf->n_nodes; i++) {
// printf("%d, %s \n", i, gf->nodes[i]->name);
if (strcmp(ggml_get_name(gf->nodes[i]), name) == 0) {
res = gf->nodes[i];
break;
}
}
for (int i = 0; i < gf->n_leafs; i++) {
// printf("%d, %s \n", i, gf->leafs[i]->name);
if (strcmp(ggml_get_name(gf->leafs[i]), name) == 0) {
res = gf->leafs[i];
break;
}
}
return res;
}
__STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_only = false, const char* mark = "") {
printf("%s (%s): shape(%zu, %zu, %zu, %zu)\n", mark, ggml_type_name(tensor->type), tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
fflush(stdout);
if (shape_only) {
return;
@ -217,6 +236,23 @@ __STATIC_INLINE__ uint8_t* sd_tensor_to_image(struct ggml_tensor* input) {
return image_data;
}
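// Extract image idx from a batched [width, height, 3, N] f32 tensor as packed 8-bit RGB.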
__STATIC_INLINE__ uint8_t* sd_tensor_to_mul_image(struct ggml_tensor* input, int idx) {
int64_t width = input->ne[0];
int64_t height = input->ne[1];
int64_t channels = input->ne[2];
GGML_ASSERT(channels == 3 && input->type == GGML_TYPE_F32);
uint8_t* image_data = (uint8_t*)malloc(width * height * channels);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
float value = ggml_tensor_get_f32(input, ix, iy, k, idx);
*(image_data + iy * width * channels + ix * channels + k) = (uint8_t)(value * 255.0f);
}
}
}
return image_data;
}
__STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data,
struct ggml_tensor* output,
bool scale = true) {
@ -237,6 +273,28 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data,
}
}
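// Copy one 8-bit RGB image into slice idx of a batched f32 tensor, scaling to
// [0, 1] and optionally normalizing with per-channel mean/std.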
__STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data,
struct ggml_tensor* output,
int idx,
float* mean = NULL,
float* std = NULL) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
int value = *(image_data + iy * width * channels + ix * channels + k);
float pixel_val = value / 255.0f;
if (mean != NULL && std != NULL)
pixel_val = (pixel_val - mean[k]) / std[k];
ggml_tensor_set_f32(output, pixel_val, ix, iy, k, idx);
}
}
}
}
__STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data,
struct ggml_tensor* output,
bool scale = true) {
@ -247,7 +305,7 @@ __STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data,
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
float value = *(image_data + iy * width * channels + ix * channels + k);
int value = *(image_data + iy * width * channels + ix * channels + k);
if (scale) {
value /= 255.f;
}
@ -771,7 +829,10 @@ protected:
// compute the required memory
size_t compute_buffer_size = ggml_gallocr_get_buffer_size(compute_allocr, 0);
LOG_DEBUG("%s compute buffer size: %.2f MB", get_desc().c_str(), compute_buffer_size / 1024.0 / 1024.0);
LOG_DEBUG("%s compute buffer size: %.2f MB(%s)",
get_desc().c_str(),
compute_buffer_size / 1024.0 / 1024.0,
ggml_backend_is_cpu(backend) ? "RAM" : "VRAM");
return true;
}
@ -816,8 +877,11 @@ public:
return false;
}
size_t params_buffer_size = ggml_backend_buffer_get_size(params_buffer);
LOG_DEBUG("%s params backend buffer size = % 6.2f MB (%i tensors)",
get_desc().c_str(), params_buffer_size / (1024.0 * 1024.0), num_tensors);
LOG_DEBUG("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)",
get_desc().c_str(),
params_buffer_size / (1024.0 * 1024.0),
ggml_backend_is_cpu(backend) ? "RAM" : "VRAM",
num_tensors);
return true;
}
@ -865,11 +929,8 @@ public:
alloc_compute_buffer(get_graph);
reset_compute_ctx();
struct ggml_cgraph* gf = get_graph();
GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
cpy_data_to_backend_tensor();
if (ggml_backend_is_cpu(backend)) {
ggml_backend_cpu_set_n_threads(backend, n_threads);
}
@ -879,13 +940,11 @@ public:
ggml_backend_metal_set_n_cb(backend, n_threads);
}
#endif
ggml_backend_graph_compute(backend, gf);
#ifdef GGML_PERF
ggml_graph_print(gf);
#endif
if (output != NULL) {
auto result = gf->nodes[gf->n_nodes - 1];
if (*output == NULL && output_ctx != NULL) {
@ -977,13 +1036,11 @@ public:
}
for (auto& pair : blocks) {
auto& block = pair.second;
block->get_param_tensors(tensors, prefix + pair.first);
}
for (auto& pair : params) {
struct ggml_tensor* param = pair.second;
struct ggml_tensor* param = pair.second;
tensors[prefix + pair.first] = pair.second;
}
}
@ -1243,11 +1300,10 @@ public:
struct ggml_tensor* kqv = ggml_nn_attention(ctx, q, k, v, mask); // [N * n_head, n_token, d_head]
kqv = ggml_reshape_4d(ctx, kqv, d_head, n_token, n_head, N);
kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head]
kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, n_token, n_head, d_head]
x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, N); // [N, n_token, d_head * n_head]
x = ggml_reshape_3d(ctx, kqv, d_head * n_head, n_token, N); // [N * n_token, d_head * n_head]
x = out_proj->forward(ctx, x);
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x;
}
};

lora.hpp

@ -14,9 +14,10 @@ struct LoraModel : public GGMLModule {
LoraModel(ggml_backend_t backend,
ggml_type wtype,
const std::string file_path = "")
const std::string& file_path = "",
const std::string& prefix = "")
: file_path(file_path), GGMLModule(backend, wtype) {
if (!model_loader.init_from_file(file_path)) {
if (!model_loader.init_from_file(file_path, prefix)) {
load_failed = true;
}
}
@ -33,8 +34,7 @@ struct LoraModel : public GGMLModule {
return model_loader.get_params_mem_size(NULL);
}
bool load_from_file() {
bool load_from_file(bool filter_tensor = false) {
LOG_INFO("loading LoRA from '%s'", file_path.c_str());
if (load_failed) {
@ -46,6 +46,11 @@ struct LoraModel : public GGMLModule {
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
const std::string& name = tensor_storage.name;
if (filter_tensor && !contains(name, "lora")) {
// LOG_INFO("skipping LoRA tesnor '%s'", name.c_str());
return true;
}
if (dry_run) {
struct ggml_tensor* real = ggml_new_tensor(params_ctx,
tensor_storage.type,
@ -66,7 +71,6 @@ struct LoraModel : public GGMLModule {
dry_run = false;
model_loader.load_tensors(on_new_tensor_cb, backend);
LOG_DEBUG("finished loaded lora");
return true;
}
@ -85,6 +89,10 @@ struct LoraModel : public GGMLModule {
}
k_tensor = k_tensor.substr(0, k_pos);
replace_all_chars(k_tensor, '.', '_');
// LOG_DEBUG("k_tensor %s", k_tensor.c_str());
if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { // fix for SDXL
k_tensor = "model_diffusion_model_output_blocks_2_1_conv";
}
std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight";
std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight";
std::string alpha_name = "lora." + k_tensor + ".alpha";

model.cpp

@ -108,14 +108,14 @@ std::unordered_map<std::string, std::string> open_clip_to_hf_clip_model = {
{"model.positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"},
{"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"},
{"model.text_projection", "transformer.text_model.text_projection"},
{"model.visual.class_embedding", "transformer.visual_model.embeddings.class_embedding"},
{"model.visual.conv1.weight", "transformer.visual_model.embeddings.patch_embedding.weight"},
{"model.visual.ln_post.bias", "transformer.visual_model.post_layernorm.bias"},
{"model.visual.ln_post.weight", "transformer.visual_model.post_layernorm.weight"},
{"model.visual.ln_pre.bias", "transformer.visual_model.pre_layernorm.bias"},
{"model.visual.ln_pre.weight", "transformer.visual_model.pre_layernorm.weight"},
{"model.visual.positional_embedding", "transformer.visual_model.embeddings.position_embedding.weight"},
{"model.visual.proj", "transformer.visual_model.visual_projection"},
{"model.visual.class_embedding", "transformer.vision_model.embeddings.class_embedding"},
{"model.visual.conv1.weight", "transformer.vision_model.embeddings.patch_embedding.weight"},
{"model.visual.ln_post.bias", "transformer.vision_model.post_layernorm.bias"},
{"model.visual.ln_post.weight", "transformer.vision_model.post_layernorm.weight"},
{"model.visual.ln_pre.bias", "transformer.vision_model.pre_layernorm.bias"},
{"model.visual.ln_pre.weight", "transformer.vision_model.pre_layernorm.weight"},
{"model.visual.positional_embedding", "transformer.vision_model.embeddings.position_embedding.weight"},
{"model.visual.proj", "transformer.visual_projection.weight"},
};
std::unordered_map<std::string, std::string> open_clip_to_hk_clip_resblock = {
@ -157,6 +157,10 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) {
} else if (starts_with(new_name, "cond_stage_model.")) {
prefix = "cond_stage_model.";
new_name = new_name.substr(strlen("cond_stage_model."));
} else if (ends_with(new_name, "vision_model.visual_projection.weight")) {
prefix = new_name.substr(0, new_name.size() - strlen("vision_model.visual_projection.weight"));
new_name = prefix + "visual_projection.weight";
return new_name;
} else {
return new_name;
}
@ -186,7 +190,7 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) {
replace_suffix();
open_clip_resblock_prefix = "model.visual.transformer.resblocks.";
hf_clip_resblock_prefix = "transformer.visual_model.encoder.layers.";
hf_clip_resblock_prefix = "transformer.vision_model.encoder.layers.";
replace_suffix();
@ -248,7 +252,7 @@ std::unordered_map<std::string, std::unordered_map<std::string, std::string>> su
},
};
std::string convert_diffusers_name_to_compvis(const std::string& key, char seq) {
std::string convert_diffusers_name_to_compvis(std::string key, char seq) {
std::vector<std::string> m;
auto match = [](std::vector<std::string>& match_list, const std::regex& regex, const std::string& key) {
@ -282,6 +286,11 @@ std::string convert_diffusers_name_to_compvis(const std::string& key, char seq)
return inner_key;
};
// convert attn to out
if (ends_with(key, "to_out")) {
key += format("%c0", seq);
}
// unet
if (match(m, std::regex(format("unet%cconv_in(.*)", seq)), key)) {
return format("model%cdiffusion_model%cinput_blocks%c0%c0", seq, seq, seq, seq) + m[0];
@ -391,8 +400,8 @@ std::string convert_diffusers_name_to_compvis(const std::string& key, char seq)
}
std::string convert_tensor_name(const std::string& name) {
std::string new_name;
if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.")) {
std::string new_name = name;
if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
new_name = convert_open_clip_to_hf_clip(name);
} else if (starts_with(name, "first_stage_model.decoder")) {
new_name = convert_vae_decoder_name(name);
@ -416,6 +425,26 @@ std::string convert_tensor_name(const std::string& name) {
} else {
new_name = name;
}
} else if (contains(name, "lora_up") || contains(name, "lora_down") || contains(name, "lora.up") || contains(name, "lora.down")) {
size_t pos = new_name.find(".processor");
if (pos != std::string::npos) {
new_name.replace(pos, strlen(".processor"), "");
}
pos = new_name.find_last_of('_');
if (pos != std::string::npos) {
std::string name_without_network_parts = new_name.substr(0, pos);
std::string network_part = new_name.substr(pos + 1);
// LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.');
replace_all_chars(new_key, '.', '_');
if (starts_with(network_part, "lora.")) {
network_part = "lora_" + network_part.substr(5);
}
if (new_key.size() > 0) {
new_name = "lora." + new_key + "." + network_part;
}
// LOG_DEBUG("new name: %s", new_name.c_str());
}
} else if (starts_with(name, "unet") || starts_with(name, "vae") || starts_with(name, "te")) { // for diffuser
size_t pos = name.find_last_of('.');
if (pos != std::string::npos) {
@ -830,7 +859,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
}
TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
tensor_storage.reverse_ne();
size_t tensor_data_size = end - begin;
@ -1169,7 +1197,9 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
if (reader.phase == PickleTensorReader::READ_DIMENS) {
reader.tensor_storage.reverse_ne();
reader.tensor_storage.file_index = file_index;
reader.tensor_storage.name = prefix + reader.tensor_storage.name;
// if(strcmp(prefix.c_str(), "scarlett") == 0)
// printf(" got tensor %s \n ", reader.tensor_storage.name.c_str());
reader.tensor_storage.name = prefix + reader.tensor_storage.name;
tensor_storages.push_back(reader.tensor_storage);
// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
// reset
@ -1272,7 +1302,8 @@ std::string ModelLoader::load_merges() {
return merges_utf8_str;
}
void remove_duplicates(std::vector<TensorStorage>& vec) {
std::vector<TensorStorage> remove_duplicates(const std::vector<TensorStorage>& vec) {
std::vector<TensorStorage> res;
std::unordered_map<std::string, size_t> name_to_index_map;
for (size_t i = 0; i < vec.size(); ++i) {
@ -1280,13 +1311,16 @@ void remove_duplicates(std::vector<TensorStorage>& vec) {
auto it = name_to_index_map.find(current_name);
if (it != name_to_index_map.end()) {
vec[it->second] = vec[i];
res[it->second] = vec[i];
} else {
name_to_index_map[current_name] = i;
res.push_back(vec[i]);
}
}
vec.resize(name_to_index_map.size());
// vec.resize(name_to_index_map.size());
return res;
}
bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend) {
@ -1300,7 +1334,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
preprocess_tensor(tensor_storage, processed_tensor_storages);
}
remove_duplicates(processed_tensor_storages);
std::vector<TensorStorage> dedup = remove_duplicates(processed_tensor_storages);
processed_tensor_storages = dedup;
bool success = true;
for (size_t file_index = 0; file_index < file_paths_.size(); file_index++) {
std::string file_path = file_paths_[file_index];
@ -1362,7 +1398,6 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
if (tensor_storage.file_index != file_index) {
continue;
}
ggml_tensor* dst_tensor = NULL;
success = on_new_tensor_cb(tensor_storage, &dst_tensor);

model.h

@ -7,6 +7,7 @@
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>
#include "ggml/ggml-backend.h"

pmid.hpp (new file)

@ -0,0 +1,305 @@
#ifndef __PMI_HPP__
#define __PMI_HPP__
#include "ggml_extend.hpp"
#include "clip.hpp"
#include "lora.hpp"
struct FuseBlock : public GGMLBlock {
// network hparams
int in_dim;
int out_dim;
int hidden_dim;
bool use_residue;
public:
FuseBlock(int i_d, int o_d, int h_d, bool use_residue = true)
: in_dim(i_d), out_dim(o_d), hidden_dim(h_d), use_residue(use_residue) {
blocks["fc1"] = std::shared_ptr<GGMLBlock>(new Linear(in_dim, hidden_dim, true));
blocks["fc2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_dim, out_dim, true));
blocks["layernorm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(in_dim));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [N, channels, h, w]
auto fc1 = std::dynamic_pointer_cast<Linear>(blocks["fc1"]);
auto fc2 = std::dynamic_pointer_cast<Linear>(blocks["fc2"]);
auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layernorm"]);
struct ggml_tensor* r = x;
// x = ggml_nn_layer_norm(ctx, x, ln_w, ln_b);
x = layer_norm->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc1_w, x), fc1_b);
x = fc1->forward(ctx, x);
x = ggml_gelu_inplace(ctx, x);
x = fc2->forward(ctx, x);
// x = ggml_add(ctx, ggml_mul_mat(ctx, fc2_w, x), fc2_b);
if (use_residue)
x = ggml_add(ctx, x, r);
return x;
}
};
struct FuseModule : public GGMLBlock {
// network hparams
int embed_dim;
public:
FuseModule(int imb_d)
: embed_dim(imb_d) {
blocks["mlp1"] = std::shared_ptr<GGMLBlock>(new FuseBlock(imb_d * 2, imb_d, imb_d, false));
blocks["mlp2"] = std::shared_ptr<GGMLBlock>(new FuseBlock(imb_d, imb_d, imb_d, true));
blocks["layer_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(embed_dim));
}
struct ggml_tensor* fuse_fn(struct ggml_context* ctx,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* id_embeds) {
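// Concatenate each image-token embedding with its ID embedding along the
// feature dim, project back to embed_dim via mlp1 with a residual from
// prompt_embeds, then refine with mlp2 and a final layer norm.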
auto mlp1 = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp1"]);
auto mlp2 = std::dynamic_pointer_cast<FuseBlock>(blocks["mlp2"]);
auto layer_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["layer_norm"]);
auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3));
auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3));
// concat is along dim 2
auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0);
stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3));
// stacked_id_embeds = mlp1.forward(ctx, stacked_id_embeds);
// stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds);
// stacked_id_embeds = mlp2.forward(ctx, stacked_id_embeds);
// stacked_id_embeds = ggml_nn_layer_norm(ctx, stacked_id_embeds, ln_w, ln_b);
stacked_id_embeds = mlp1->forward(ctx, stacked_id_embeds);
stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds);
stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds);
stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds);
return stacked_id_embeds;
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* id_embeds,
struct ggml_tensor* class_tokens_mask,
struct ggml_tensor* class_tokens_mask_pos,
struct ggml_tensor* left,
struct ggml_tensor* right) {
// x: [N, channels, h, w]
struct ggml_tensor* valid_id_embeds = id_embeds;
// # slice out the image token embeddings
// print_ggml_tensor(class_tokens_mask_pos, false);
ggml_set_name(class_tokens_mask_pos, "class_tokens_mask_pos");
ggml_set_name(prompt_embeds, "prompt_embeds");
// print_ggml_tensor(valid_id_embeds, true, "valid_id_embeds");
// print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos");
struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos);
ggml_set_name(image_token_embeds, "image_token_embeds");
struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds);
stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3));
if (left && right) {
stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds);
stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right);
} else if (left) {
stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds);
} else if (right) {
stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right);
}
stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3));
class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask));
class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds);
prompt_embeds = ggml_mul(ctx, prompt_embeds, class_tokens_mask);
struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds);
ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds");
return updated_prompt_embeds;
}
};
struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection {
PhotoMakerIDEncoderBlock()
: CLIPVisionModelProjection(OPENAI_CLIP_VIT_L_14) {
blocks["visual_projection_2"] = std::shared_ptr<GGMLBlock>(new Linear(1024, 1280, false));
blocks["fuse_module"] = std::shared_ptr<GGMLBlock>(new FuseModule(2048));
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
struct ggml_tensor* class_tokens_mask,
struct ggml_tensor* class_tokens_mask_pos,
struct ggml_tensor* left,
struct ggml_tensor* right) {
// x: [N, channels, h, w]
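// Run the CLIP vision tower on the ID images, project the pooled output to
// 768 and 1280 dims, concat into a 2048-dim ID embedding, then fuse it into
// the prompt embeddings at the expanded class-token positions.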
auto vision_model = std::dynamic_pointer_cast<CLIPVisionModel>(blocks["vision_model"]);
auto visual_projection = std::dynamic_pointer_cast<CLIPProjection>(blocks["visual_projection"]);
auto visual_projection_2 = std::dynamic_pointer_cast<Linear>(blocks["visual_projection_2"]);
auto fuse_module = std::dynamic_pointer_cast<FuseModule>(blocks["fuse_module"]);
struct ggml_tensor* shared_id_embeds = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size]
struct ggml_tensor* id_embeds = visual_projection->forward(ctx, shared_id_embeds); // [N, proj_dim(768)]
struct ggml_tensor* id_embeds_2 = visual_projection_2->forward(ctx, shared_id_embeds); // [N, 1280]
id_embeds = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3));
id_embeds_2 = ggml_cont(ctx, ggml_permute(ctx, id_embeds_2, 2, 0, 1, 3));
id_embeds = ggml_concat(ctx, id_embeds, id_embeds_2); // [batch_size, seq_length, 1, 2048] check whether concat at dim 2 is right
id_embeds = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 1, 2, 0, 3));
struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx,
prompt_embeds,
id_embeds,
class_tokens_mask,
class_tokens_mask_pos,
left, right);
return updated_prompt_embeds;
}
};
struct PhotoMakerIDEncoder : public GGMLModule {
public:
SDVersion version = VERSION_XL;
PhotoMakerIDEncoderBlock id_encoder;
float style_strength;
std::vector<float> ctm;
std::vector<ggml_fp16_t> ctmf16;
std::vector<int> ctmpos;
std::vector<ggml_fp16_t> zeros_left_16;
std::vector<float> zeros_left;
std::vector<ggml_fp16_t> zeros_right_16;
std::vector<float> zeros_right;
public:
PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_XL, float sty = 20.f)
: GGMLModule(backend, wtype),
version(version),
style_strength(sty) {
id_encoder.init(params_ctx, wtype);
}
std::string get_desc() {
return "pmid";
}
size_t get_params_mem_size() {
size_t params_mem_size = id_encoder.get_params_mem_size();
return params_mem_size;
}
size_t get_params_num() {
size_t params_num = id_encoder.get_params_num();
return params_num;
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) {
id_encoder.get_param_tensors(tensors, prefix);
}
struct ggml_cgraph* build_graph( // struct ggml_allocr* allocr,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
std::vector<bool>& class_tokens_mask) {
ctm.clear();
ctmf16.clear();
ctmpos.clear();
zeros_left.clear();
zeros_left_16.clear();
zeros_right.clear();
zeros_right_16.clear();
ggml_context* ctx0 = compute_ctx;
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
int64_t hidden_size = prompt_embeds->ne[0];
int64_t seq_length = prompt_embeds->ne[1];
ggml_type type = GGML_TYPE_F32;
struct ggml_tensor* class_tokens_mask_d = ggml_new_tensor_1d(ctx0, type, class_tokens_mask.size());
struct ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values);
struct ggml_tensor* prompt_embeds_d = to_backend(prompt_embeds);
struct ggml_tensor* left = NULL;
struct ggml_tensor* right = NULL;
for (int i = 0; i < class_tokens_mask.size(); i++) {
if (class_tokens_mask[i]) {
ctm.push_back(0.f); // here use 0.f instead of 1.f to make a scale mask
ctmf16.push_back(ggml_fp32_to_fp16(0.f)); // here use 0.f instead of 1.f to make a scale mask
ctmpos.push_back(i);
} else {
ctm.push_back(1.f); // here use 1.f instead of 0.f to make a scale mask
ctmf16.push_back(ggml_fp32_to_fp16(1.f)); // here use 1.f instead of 0.f to make a scale mask
}
}
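// Illustrative sketch of the mask layout (not from the original source): for
// class_tokens_mask = [false, false, true, true, false, ...] built from 2
// input ID images, ctm becomes [1, 1, 0, 0, 1, ...] and ctmpos = [2, 3].
// In FuseModule::forward the 0 entries zero out prompt_embeds at the
// class-token positions, so the stacked ID embeddings (zero-padded via the
// `left`/`right` tensors below) are added in their place.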
if (ctmpos[0] > 0) {
left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]);
}
if (ctmpos[ctmpos.size() - 1] < seq_length - 1) {
right = ggml_new_tensor_3d(ctx0, type,
hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1);
}
struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctmpos.size());
{
if (type == GGML_TYPE_F16)
set_backend_tensor_data(class_tokens_mask_d, ctmf16.data());
else
set_backend_tensor_data(class_tokens_mask_d, ctm.data());
set_backend_tensor_data(class_tokens_mask_pos, ctmpos.data());
if (left) {
if (type == GGML_TYPE_F16) {
for (int i = 0; i < ggml_nelements(left); ++i)
zeros_left_16.push_back(ggml_fp32_to_fp16(0.f));
set_backend_tensor_data(left, zeros_left_16.data());
} else {
for (int i = 0; i < ggml_nelements(left); ++i)
zeros_left.push_back(0.f);
set_backend_tensor_data(left, zeros_left.data());
}
}
if (right) {
if (type == GGML_TYPE_F16) {
for (int i = 0; i < ggml_nelements(right); ++i)
zeros_right_16.push_back(ggml_fp32_to_fp16(0.f));
set_backend_tensor_data(right, zeros_right_16.data());
} else {
for (int i = 0; i < ggml_nelements(right); ++i)
zeros_right.push_back(0.f);
set_backend_tensor_data(right, zeros_right.data());
}
}
}
struct ggml_tensor* updated_prompt_embeds = id_encoder.forward(ctx0,
id_pixel_values_d,
prompt_embeds_d,
class_tokens_mask_d,
class_tokens_mask_pos,
left, right);
ggml_build_forward_expand(gf, updated_prompt_embeds);
return gf;
}
void compute(const int n_threads,
struct ggml_tensor* id_pixel_values,
struct ggml_tensor* prompt_embeds,
std::vector<bool>& class_tokens_mask,
struct ggml_tensor** updated_prompt_embeds,
ggml_context* output_ctx) {
auto get_graph = [&]() -> struct ggml_cgraph* {
// return build_graph(compute_allocr, id_pixel_values, prompt_embeds, class_tokens_mask);
return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask);
};
// GGMLModule::compute(get_graph, n_threads, updated_prompt_embeds);
GGMLModule::compute(get_graph, n_threads, true, updated_prompt_embeds, output_ctx);
}
};
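// Hedged usage sketch (illustrative, not part of the original header): how a
// caller can drive PhotoMakerIDEncoder, mirroring StableDiffusionGGML's
// id_encoder() helper below; `backend`, `n_threads`, `work_ctx` and the input
// tensors are assumed to be prepared by the caller.
//
// PhotoMakerIDEncoder pmid(backend, GGML_TYPE_F16, VERSION_XL, 20.f);
// pmid.alloc_params_buffer(); // then load the "pmid." tensors into it
// ggml_tensor* updated_prompt_embeds = NULL;
// pmid.compute(n_threads, id_pixel_values, prompt_embeds,
//              class_tokens_mask, &updated_prompt_embeds, work_ctx);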
#endif // __PMI_HPP__

stable-diffusion.cpp

@@ -11,10 +11,19 @@
#include "denoiser.hpp"
#include "esrgan.hpp"
#include "lora.hpp"
#include "pmid.hpp"
#include "tae.hpp"
#include "unet.hpp"
#include "vae.hpp"
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
#include "stb_image.h"
// #define STB_IMAGE_WRITE_IMPLEMENTATION
// #define STB_IMAGE_WRITE_STATIC
// #include "stb_image_write.h"
const char* model_version_to_str[] = {
"1.x",
"2.x",
@@ -56,8 +65,11 @@ void calculate_alphas_cumprod(float* alphas_cumprod,
class StableDiffusionGGML {
public:
ggml_backend_t backend = NULL; // general backend
ggml_type model_data_type = GGML_TYPE_COUNT;
ggml_backend_t backend = NULL; // general backend
ggml_backend_t clip_backend = NULL;
ggml_backend_t control_net_backend = NULL;
ggml_backend_t vae_backend = NULL;
ggml_type model_data_type = GGML_TYPE_COUNT;
SDVersion version;
bool vae_decode_only = false;
@@ -73,10 +85,13 @@ public:
std::shared_ptr<AutoEncoderKL> first_stage_model;
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
std::shared_ptr<ControlNet> control_net;
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
std::shared_ptr<LoraModel> pmid_lora;
std::string taesd_path;
bool use_tiny_autoencoder = false;
bool vae_tiling = false;
bool stacked_id = false;
std::map<std::string, struct ggml_tensor*> tensors;
@@ -86,6 +101,8 @@ public:
std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();
std::string trigger_word = "img"; // should be user settable
StableDiffusionGGML() = default;
StableDiffusionGGML(int n_threads,
@@ -106,17 +123,23 @@ public:
~StableDiffusionGGML() {
ggml_backend_free(backend);
ggml_backend_free(clip_backend);
ggml_backend_free(control_net_backend);
ggml_backend_free(vae_backend);
}
bool load_from_file(const std::string& model_path,
const std::string& vae_path,
const std::string control_net_path,
const std::string embeddings_path,
const std::string id_embeddings_path,
const std::string& taesd_path,
bool vae_tiling_,
ggml_type wtype,
schedule_t schedule,
bool control_net_cpu) {
bool clip_on_cpu,
bool control_net_cpu,
bool vae_on_cpu) {
use_tiny_autoencoder = taesd_path.size() > 0;
#ifdef SD_USE_CUBLAS
LOG_DEBUG("Using CUDA backend");
@@ -161,6 +184,7 @@ public:
LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str());
return false;
}
LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]);
if (wtype == GGML_TYPE_COUNT) {
model_data_type = model_loader.get_sd_wtype();
@@ -195,7 +219,12 @@ public:
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(backend, model_data_type, version);
clip_backend = backend;
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("CLIP: Using CPU backend");
clip_backend = ggml_backend_cpu_init();
}
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_data_type, version);
cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors, "cond_stage_model.");
@@ -211,24 +240,59 @@ public:
}
if (!use_tiny_autoencoder) {
first_stage_model = std::make_shared<AutoEncoderKL>(backend, vae_type, vae_decode_only);
if (vae_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("VAE Autoencoder: Using CPU backend");
vae_backend = ggml_backend_cpu_init();
} else {
vae_backend = backend;
}
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend, vae_type, vae_decode_only);
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else {
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_data_type, vae_decode_only);
}
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
if (control_net_path.size() > 0) {
ggml_backend_t cn_backend = NULL;
ggml_backend_t controlnet_backend = NULL;
if (control_net_cpu && !ggml_backend_is_cpu(backend)) {
LOG_DEBUG("ControlNet: Using CPU backend");
cn_backend = ggml_backend_cpu_init();
controlnet_backend = ggml_backend_cpu_init();
} else {
cn_backend = backend;
controlnet_backend = backend;
}
control_net = std::make_shared<ControlNet>(cn_backend, model_data_type, version);
control_net = std::make_shared<ControlNet>(controlnet_backend, model_data_type, version);
}
pmid_model = std::make_shared<PhotoMakerIDEncoder>(clip_backend, model_data_type, version);
if (id_embeddings_path.size() > 0) {
pmid_lora = std::make_shared<LoraModel>(backend, model_data_type, id_embeddings_path, "");
if (!pmid_lora->load_from_file(true)) {
LOG_WARN("load photomaker lora tensors from %s failed", id_embeddings_path.c_str());
return false;
}
LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", id_embeddings_path.c_str());
if (!model_loader.init_from_file(id_embeddings_path, "pmid.")) {
LOG_WARN("loading stacked ID embedding from '%s' failed", id_embeddings_path.c_str());
} else {
stacked_id = true;
}
}
if (stacked_id) {
if (!pmid_model->alloc_params_buffer()) {
LOG_ERROR(" pmid model params buffer allocation failed");
return false;
}
// LOG_INFO("pmid param memory buffer size = %.2fMB ",
// pmid_model->params_buffer_size / 1024.0 / 1024.0);
pmid_model->get_param_tensors(tensors, "pmid");
}
// if(stacked_id){
// pmid_model.init_params(GGML_TYPE_F32);
// pmid_model.map_by_name(tensors, "pmid.");
// }
LOG_DEBUG("loading vocab");
std::string merges_utf8_str = model_loader.load_merges();
if (merges_utf8_str.size() == 0) {
@@ -250,6 +314,7 @@ public:
// load weights
LOG_DEBUG("loading weights");
int64_t t0 = ggml_time_ms();
std::set<std::string> ignore_tensors;
@@ -257,6 +322,10 @@ public:
if (use_tiny_autoencoder) {
ignore_tensors.insert("first_stage_model.");
}
if (stacked_id) {
ignore_tensors.insert("lora.");
}
if (vae_decode_only) {
ignore_tensors.insert("first_stage_model.encoder");
ignore_tensors.insert("first_stage_model.quant");
@@ -296,14 +365,54 @@ public:
}
control_net_params_mem_size = control_net->get_params_mem_size();
}
size_t pmid_params_mem_size = 0;
if (stacked_id) {
pmid_params_mem_size = pmid_model->get_params_mem_size();
}
size_t total_params_size = clip_params_mem_size + unet_params_mem_size + vae_params_mem_size + control_net_params_mem_size;
LOG_INFO("total params memory size = %.2fMB (clip %.2fMB, unet %.2fMB, vae %.2fMB, controlnet %.2fMB)",
total_params_size / 1024.0 / 1024.0,
clip_params_mem_size / 1024.0 / 1024.0,
unet_params_mem_size / 1024.0 / 1024.0,
vae_params_mem_size / 1024.0 / 1024.0,
control_net_params_mem_size / 1024.0 / 1024.0);
size_t total_params_ram_size = 0;
size_t total_params_vram_size = 0;
if (ggml_backend_is_cpu(clip_backend)) {
total_params_ram_size += clip_params_mem_size + pmid_params_mem_size;
} else {
total_params_vram_size += clip_params_mem_size + pmid_params_mem_size;
}
if (ggml_backend_is_cpu(backend)) {
total_params_ram_size += unet_params_mem_size;
} else {
total_params_vram_size += unet_params_mem_size;
}
if (ggml_backend_is_cpu(vae_backend)) {
total_params_ram_size += vae_params_mem_size;
} else {
total_params_vram_size += vae_params_mem_size;
}
if (ggml_backend_is_cpu(control_net_backend)) {
total_params_ram_size += control_net_params_mem_size;
} else {
total_params_vram_size += control_net_params_mem_size;
}
size_t total_params_size = total_params_ram_size + total_params_vram_size;
LOG_INFO(
"total params memory size = %.2fMB (VRAM %.2fMB, RAM %.2fMB): "
"clip %.2fMB(%s), unet %.2fMB(%s), vae %.2fMB(%s), controlnet %.2fMB(%s), pmid %.2fMB(%s)",
total_params_size / 1024.0 / 1024.0,
total_params_vram_size / 1024.0 / 1024.0,
total_params_ram_size / 1024.0 / 1024.0,
clip_params_mem_size / 1024.0 / 1024.0,
ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM",
unet_params_mem_size / 1024.0 / 1024.0,
ggml_backend_is_cpu(backend) ? "RAM" : "VRAM",
vae_params_mem_size / 1024.0 / 1024.0,
ggml_backend_is_cpu(vae_backend) ? "RAM" : "VRAM",
control_net_params_mem_size / 1024.0 / 1024.0,
ggml_backend_is_cpu(control_net_backend) ? "RAM" : "VRAM",
pmid_params_mem_size / 1024.0 / 1024.0,
ggml_backend_is_cpu(clip_backend) ? "RAM" : "VRAM");
}
int64_t t1 = ggml_time_ms();
@@ -444,16 +553,80 @@ public:
curr_lora_state = lora_state;
}
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
const std::string& prompt) {
auto image_tokens = cond_stage_model->convert_token_to_id(trigger_word);
GGML_ASSERT(image_tokens.size() == 1);
auto tokens_and_weights = cond_stage_model->tokenize(prompt, false);
std::vector<int>& tokens = tokens_and_weights.first;
auto it = std::find(tokens.begin(), tokens.end(), image_tokens[0]);
GGML_ASSERT(it != tokens.end()); // prompt must have trigger word
tokens.erase(it);
return cond_stage_model->decode(tokens);
}
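// Illustrative example (assumed tokenizer behavior): with trigger_word "img",
// a prompt like "a portrait photo of a man img, wearing a hat" decodes back
// to roughly "a portrait photo of a man, wearing a hat", giving the
// trigger-free prompt used for the delayed, text-only conditioning.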
std::tuple<ggml_tensor*, ggml_tensor*, std::vector<bool>>
get_learned_condition_with_trigger(ggml_context* work_ctx,
const std::string& text,
int clip_skip,
int width,
int height,
int num_input_imgs,
bool force_zero_embeddings = false) {
auto image_tokens = cond_stage_model->convert_token_to_id(trigger_word);
// if(image_tokens.size() == 1){
// printf(" image token id is: %d \n", image_tokens[0]);
// }
GGML_ASSERT(image_tokens.size() == 1);
auto tokens_and_weights = cond_stage_model->tokenize_with_trigger_token(text,
num_input_imgs,
image_tokens[0],
true);
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
std::vector<float>& weights = std::get<1>(tokens_and_weights);
std::vector<bool>& clsm = std::get<2>(tokens_and_weights);
// printf("tokens: \n");
// for(int i = 0; i < tokens.size(); ++i)
// printf("%d ", tokens[i]);
// printf("\n");
// printf("clsm: \n");
// for(int i = 0; i < clsm.size(); ++i)
// printf("%d ", clsm[i]?1:0);
// printf("\n");
auto cond = get_learned_condition_common(work_ctx, tokens, weights, clip_skip, width, height, force_zero_embeddings);
return std::make_tuple(cond.first, cond.second, clsm);
}
ggml_tensor* id_encoder(ggml_context* work_ctx,
ggml_tensor* init_img,
ggml_tensor* prompts_embeds,
std::vector<bool>& class_tokens_mask) {
ggml_tensor* res = NULL;
pmid_model->compute(n_threads, init_img, prompts_embeds, class_tokens_mask, &res, work_ctx);
return res;
}
std::pair<ggml_tensor*, ggml_tensor*> get_learned_condition(ggml_context* work_ctx,
const std::string& text,
int clip_skip,
int width,
int height,
bool force_zero_embeddings = false) {
auto tokens_and_weights = cond_stage_model->tokenize(text, true);
std::vector<int>& tokens = tokens_and_weights.first;
std::vector<float>& weights = tokens_and_weights.second;
return get_learned_condition_common(work_ctx, tokens, weights, clip_skip, width, height, force_zero_embeddings);
}
std::pair<ggml_tensor*, ggml_tensor*> get_learned_condition_common(ggml_context* work_ctx,
std::vector<int>& tokens,
std::vector<float>& weights,
int clip_skip,
int width,
int height,
bool force_zero_embeddings = false) {
cond_stage_model->set_clip_skip(clip_skip);
auto tokens_and_weights = cond_stage_model->tokenize(text, true);
std::vector<int>& tokens = tokens_and_weights.first;
std::vector<float>& weights = tokens_and_weights.second;
int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size]
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, hidden_size]
@@ -466,7 +639,7 @@ public:
std::vector<int> chunk_tokens(tokens.begin() + chunk_idx * chunk_len,
tokens.begin() + (chunk_idx + 1) * chunk_len);
std::vector<float> chunk_weights(weights.begin() + chunk_idx * chunk_len,
weights.begin() + (chunk_idx + 1) * chunk_len);
weights.begin() + (chunk_idx + 1) * chunk_len);
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
struct ggml_tensor* input_ids2 = NULL;
@@ -664,7 +837,10 @@ public:
float min_cfg,
float cfg_scale,
sample_method_t method,
const std::vector<float>& sigmas) {
const std::vector<float>& sigmas,
int start_merge_step,
ggml_tensor* c_id,
ggml_tensor* c_vec_id) {
size_t steps = sigmas.size() - 1;
// x_t = load_tensor_from_file(work_ctx, "./rand0.bin");
// print_ggml_tensor(x_t);
@@ -730,17 +906,30 @@ public:
// GGML_ASSERT(0);
}
// cond
diffusion_model->compute(n_threads,
noised_input,
timesteps,
c,
c_concat,
c_vector,
-1,
controls,
control_strength,
&out_cond);
if (start_merge_step == -1 || step <= start_merge_step) {
// cond
diffusion_model->compute(n_threads,
noised_input,
timesteps,
c,
c_concat,
c_vector,
-1,
controls,
control_strength,
&out_cond);
} else {
diffusion_model->compute(n_threads,
noised_input,
timesteps,
c_id,
c_concat,
c_vec_id,
-1,
controls,
control_strength,
&out_cond);
}
float* negative_data = NULL;
if (has_unconditioned) {
@@ -1283,6 +1472,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
const char* control_net_path_c_str,
const char* lora_model_dir_c_str,
const char* embed_dir_c_str,
const char* id_embed_dir_c_str,
bool vae_decode_only,
bool vae_tiling,
bool free_params_immediately,
@@ -1290,7 +1480,9 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
enum sd_type_t wtype,
enum rng_type_t rng_type,
enum schedule_t s,
bool keep_control_net_cpu) {
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu) {
sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
if (sd_ctx == NULL) {
return NULL;
@@ -1300,6 +1492,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
std::string taesd_path(taesd_path_c_str);
std::string control_net_path(control_net_path_c_str);
std::string embd_path(embed_dir_c_str);
std::string id_embd_path(id_embed_dir_c_str);
std::string lora_model_dir(lora_model_dir_c_str);
sd_ctx->sd = new StableDiffusionGGML(n_threads,
@@ -1315,11 +1508,14 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
vae_path,
control_net_path,
embd_path,
id_embd_path,
taesd_path,
vae_tiling,
(ggml_type)wtype,
s,
keep_control_net_cpu)) {
keep_clip_on_cpu,
keep_control_net_cpu,
keep_vae_on_cpu)) {
delete sd_ctx->sd;
sd_ctx->sd = NULL;
free(sd_ctx);
@@ -1348,7 +1544,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int64_t seed,
int batch_count,
const sd_image_t* control_cond,
float control_strength) {
float control_strength,
float style_ratio,
bool normalize_input,
const char* input_id_images_path_c_str) {
LOG_DEBUG("txt2img %dx%d", width, height);
if (sd_ctx == NULL) {
return NULL;
@@ -1356,6 +1555,35 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
// LOG_DEBUG("%s %s %f %d %d %d", prompt_c_str, negative_prompt_c_str, cfg_scale, sample_steps, seed, batch_count);
std::string prompt(prompt_c_str);
std::string negative_prompt(negative_prompt_c_str);
std::string input_id_images_path(input_id_images_path_c_str);
// preprocess input id images
std::vector<sd_image_t*> input_id_images;
if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) {
std::vector<std::string> img_files = get_files_from_dir(input_id_images_path);
for (std::string img_file : img_files) {
int c = 0;
int width, height;
uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3);
if (input_image_buffer == NULL) {
LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str());
continue;
} else {
LOG_INFO("PhotoMaker loaded image from '%s'", img_file.c_str());
}
sd_image_t* input_image = NULL;
input_image = new sd_image_t{(uint32_t)width,
(uint32_t)height,
3,
input_image_buffer};
input_image = preprocess_id_image(input_image);
if (input_image == NULL) {
LOG_ERROR("preprocess input id image from '%s' failed", img_file.c_str());
continue;
}
input_id_images.push_back(input_image);
}
}
// extract and remove lora
auto result_pair = extract_and_remove_lora(prompt);
@@ -1372,8 +1600,22 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
sd_ctx->sd->apply_loras(lora_f2m);
int64_t t1 = ggml_time_ms();
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->stacked_id) {
t0 = ggml_time_ms();
sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads);
t1 = ggml_time_ms();
LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->pmid_lora->free_params_buffer();
}
}
struct ggml_init_params params;
params.mem_size = static_cast<size_t>(10 * 1024 * 1024); // 10 MB
if (sd_ctx->sd->stacked_id) {
params.mem_size += static_cast<size_t>(10 * 1024 * 1024); // 10 MB
}
params.mem_size += width * height * 3 * sizeof(float);
params.mem_size *= batch_count;
params.mem_buffer = NULL;
@@ -1394,10 +1636,67 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
seed = rand();
}
t0 = ggml_time_ms();
auto cond_pair = sd_ctx->sd->get_learned_condition(work_ctx, prompt, clip_skip, width, height);
ggml_tensor* c = cond_pair.first;
ggml_tensor* c_vector = cond_pair.second; // [adm_in_channels, ]
std::string prompt_text_only;
ggml_tensor* init_img = NULL;
ggml_tensor* prompts_embeds = NULL;
ggml_tensor* pooled_prompts_embeds = NULL;
// ggml_tensor* class_tokens_mask = NULL;
std::vector<bool> class_tokens_mask;
if (sd_ctx->sd->stacked_id) {
if (input_id_images.size() > 0) {
sd_ctx->sd->pmid_model->style_strength = style_ratio;
int32_t w = input_id_images[0]->width;
int32_t h = input_id_images[0]->height;
int32_t channels = input_id_images[0]->channel;
int32_t num_input_images = (int32_t)input_id_images.size();
init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, w, h, channels, num_input_images);
// TODO: move these to somewhere else and be user settable
float mean[] = {0.48145466f, 0.4578275f, 0.40821073f};
float std[] = {0.26862954f, 0.26130258f, 0.27577711f};
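// (these mean/std values are the standard OpenAI CLIP image-normalization constants)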
for (int i = 0; i < num_input_images; i++) {
sd_image_t* init_image = input_id_images[i];
if (normalize_input)
sd_mul_images_to_tensor(init_image->data, init_img, i, mean, std);
else
sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL);
}
t0 = ggml_time_ms();
auto cond_tup = sd_ctx->sd->get_learned_condition_with_trigger(work_ctx, prompt,
clip_skip, width, height, num_input_images);
prompts_embeds = std::get<0>(cond_tup);
pooled_prompts_embeds = std::get<1>(cond_tup); // [adm_in_channels, ]
class_tokens_mask = std::get<2>(cond_tup); //
prompts_embeds = sd_ctx->sd->id_encoder(work_ctx, init_img, prompts_embeds, class_tokens_mask);
t1 = ggml_time_ms();
LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0);
if (sd_ctx->sd->free_params_immediately) {
sd_ctx->sd->pmid_model->free_params_buffer();
}
// Encode input prompt without the trigger word for delayed conditioning
prompt_text_only = sd_ctx->sd->remove_trigger_from_prompt(work_ctx, prompt);
// printf("%s || %s \n", prompt.c_str(), prompt_text_only.c_str());
prompt = prompt_text_only; //
if (sample_steps < 50) {
LOG_INFO("sampling steps increases from %d to 50 for PHOTOMAKER", sample_steps);
sample_steps = 50;
}
} else {
LOG_WARN("Provided PhotoMaker model file, but NO input ID images");
LOG_WARN("Turn off PhotoMaker");
sd_ctx->sd->stacked_id = false;
}
}
for (sd_image_t* img : input_id_images) {
free(img->data);
}
input_id_images.clear();
t0 = ggml_time_ms();
auto cond_pair = sd_ctx->sd->get_learned_condition(work_ctx, prompt, clip_skip, width, height);
ggml_tensor* c = cond_pair.first;
ggml_tensor* c_vector = cond_pair.second; // [adm_in_channels, ]
struct ggml_tensor* uc = NULL;
struct ggml_tensor* uc_vector = NULL;
if (cfg_scale != 1.0) {
@@ -1438,6 +1737,14 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
std::vector<float> sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps);
int start_merge_step = -1;
if (sd_ctx->sd->stacked_id) {
start_merge_step = int(sd_ctx->sd->pmid_model->style_strength / 100.f * sample_steps);
if (start_merge_step > 30)
start_merge_step = 30;
LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step);
}
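// Worked example (illustrative): with the default style_strength of 20 and
// sample_steps = 50, start_merge_step = int(20 / 100.f * 50) = 10; sample()
// above then uses the text-only condition while step <= 10 and switches to
// the ID-fused embeddings afterwards, so a larger style ratio delays the
// switch, trading ID fidelity for prompt style.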
struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx,
x_t,
NULL,
@@ -1452,7 +1759,10 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx,
cfg_scale,
cfg_scale,
sample_method,
sigmas);
sigmas,
start_merge_step,
prompts_embeds,
pooled_prompts_embeds);
// struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
// print_ggml_tensor(x_0);
int64_t sampling_end = ggml_time_ms();
@@ -1619,7 +1929,10 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx,
cfg_scale,
cfg_scale,
sample_method,
sigma_sched);
sigma_sched,
-1,
NULL,
NULL);
// struct ggml_tensor *x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
// print_ggml_tensor(x_0);
int64_t t3 = ggml_time_ms();
@@ -1755,7 +2068,10 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
min_cfg,
cfg_scale,
sample_method,
sigmas);
sigmas,
-1,
NULL,
NULL);
int64_t t2 = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (t2 - t1) * 1.0f / 1000);

stable-diffusion.h

@@ -65,12 +65,12 @@ enum sd_type_t {
SD_TYPE_Q8_0 = 8,
SD_TYPE_Q8_1 = 9,
// k-quantizations
SD_TYPE_Q2_K = 10,
SD_TYPE_Q3_K = 11,
SD_TYPE_Q4_K = 12,
SD_TYPE_Q5_K = 13,
SD_TYPE_Q6_K = 14,
SD_TYPE_Q8_K = 15,
SD_TYPE_Q2_K = 10,
SD_TYPE_Q3_K = 11,
SD_TYPE_Q4_K = 12,
SD_TYPE_Q5_K = 13,
SD_TYPE_Q6_K = 14,
SD_TYPE_Q8_K = 15,
SD_TYPE_IQ2_XXS = 16,
SD_TYPE_IQ2_XS = 17,
SD_TYPE_IQ3_XXS = 18,
@@ -95,7 +95,7 @@ enum sd_log_level_t {
};
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
typedef void (*sd_progress_cb_t)(int step,int steps,float time, void* data);
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
@@ -117,6 +117,7 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
const char* control_net_path_c_str,
const char* lora_model_dir,
const char* embed_dir_c_str,
const char* stacked_id_embed_dir_c_str,
bool vae_decode_only,
bool vae_tiling,
bool free_params_immediately,
@@ -124,7 +125,9 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
enum sd_type_t wtype,
enum rng_type_t rng_type,
enum schedule_t s,
bool keep_control_net_cpu);
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu);
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);
@@ -140,7 +143,10 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx,
int64_t seed,
int batch_count,
const sd_image_t* control_cond,
float control_strength);
float control_strength,
float style_strength,
bool normalize_input,
const char* input_id_images_path);
SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx,
sd_image_t init_image,

thirdparty/stb_image_resize.h (vendored, 2585 changed lines; diff suppressed because the file is too large)

util.cpp

@@ -25,6 +25,9 @@
#include "ggml/ggml.h"
#include "stable-diffusion.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#include "stb_image_resize.h"
bool ends_with(const std::string& str, const std::string& ending) {
if (str.length() >= ending.length()) {
return (str.compare(str.length() - ending.length(), ending.length(), ending) == 0);
@@ -40,6 +43,13 @@ bool starts_with(const std::string& str, const std::string& start) {
return false;
}
bool contains(const std::string& str, const std::string& substr) {
if (str.find(substr) != std::string::npos) {
return true;
}
return false;
}
void replace_all_chars(std::string& str, char target, char replacement) {
for (size_t i = 0; i < str.length(); ++i) {
if (str[i] == target) {
@@ -88,6 +98,43 @@ std::string get_full_path(const std::string& dir, const std::string& filename) {
}
}
std::vector<std::string> get_files_from_dir(const std::string& dir) {
std::vector<std::string> files;
WIN32_FIND_DATA findFileData;
HANDLE hFind;
char currentDirectory[MAX_PATH];
GetCurrentDirectory(MAX_PATH, currentDirectory);
char directoryPath[MAX_PATH]; // this is absolute path
snprintf(directoryPath, MAX_PATH, "%s\\%s\\*", currentDirectory, dir.c_str()); // bounded to avoid overflow
// Find the first file in the directory
hFind = FindFirstFile(directoryPath, &findFileData);
// Check if the directory was found
if (hFind == INVALID_HANDLE_VALUE) {
printf("Unable to find directory.\n");
return files;
}
// Loop through all files in the directory
do {
// Check if the found file is a regular file (not a directory)
if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) {
files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName));
}
} while (FindNextFile(hFind, &findFileData) != 0);
// Close the handle
FindClose(hFind);
sort(files.begin(), files.end());
return files;
}
#else // Unix
#include <dirent.h>
#include <sys/stat.h>
@@ -102,6 +149,7 @@ bool is_directory(const std::string& path) {
return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
}
// TODO: add windows version
std::string get_full_path(const std::string& dir, const std::string& filename) {
DIR* dp = opendir(dir.c_str());
@@ -121,6 +169,27 @@ std::string get_full_path(const std::string& dir, const std::string& filename) {
return "";
}
std::vector<std::string> get_files_from_dir(const std::string& dir) {
std::vector<std::string> files;
DIR* dp = opendir(dir.c_str());
if (dp != nullptr) {
struct dirent* entry;
while ((entry = readdir(dp)) != nullptr) {
std::string fname = dir + "/" + entry->d_name;
if (!is_directory(fname))
files.push_back(fname);
}
closedir(dp);
}
sort(files.begin(), files.end());
return files;
}
#endif
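// Usage sketch (illustrative): both platform implementations above return
// regular files only, in sorted order, so callers can iterate deterministically:
//
// for (const std::string& f : get_files_from_dir("input_id_images")) {
//     printf("%s\n", f.c_str());
// }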
// get_num_physical_cores is copy from
@@ -161,8 +230,8 @@ int32_t get_num_physical_cores() {
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
}
static sd_progress_cb_t sd_progress_cb = NULL;
void* sd_progress_cb_data = NULL;
static sd_progress_cb_t sd_progress_cb = NULL;
void* sd_progress_cb_data = NULL;
std::u32string utf8_to_utf32(const std::string& utf8_str) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
@@ -207,9 +276,42 @@ std::string path_join(const std::string& p1, const std::string& p2) {
return p1 + "/" + p2;
}
sd_image_t* preprocess_id_image(sd_image_t* img) {
int shortest_edge = 224;
int size = shortest_edge;
sd_image_t* resized = NULL;
uint32_t w = img->width;
uint32_t h = img->height;
uint32_t c = img->channel;
// 1. do resize using stb_resize functions
unsigned char* buf = (unsigned char*)malloc(sizeof(unsigned char) * 3 * size * size);
if (!stbir_resize_uint8(img->data, w, h, 0,
buf, size, size, 0,
c)) {
fprintf(stderr, "%s: resize operation failed\n", __func__);
free(buf); // avoid leaking the resize buffer on failure
return resized;
}
// 2. do center crop (likely unnecessary due to step 1)
// 3. do rescale
// 4. do normalize
// 3 and 4 will need to be done in float format.
resized = new sd_image_t{(uint32_t)shortest_edge,
(uint32_t)shortest_edge,
3,
buf};
return resized;
}
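// Usage sketch (illustrative, mirroring txt2img in stable-diffusion.cpp):
//
// sd_image_t* img = new sd_image_t{(uint32_t)w, (uint32_t)h, 3, pixels};
// sd_image_t* ready = preprocess_id_image(img); // 224x224, 3 channels on success
// if (ready == NULL) { /* resize failed; nothing was allocated */ }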
void pretty_progress(int step, int steps, float time) {
if (sd_progress_cb) {
sd_progress_cb(step,steps,time, sd_progress_cb_data);
sd_progress_cb(step, steps, time, sd_progress_cb_data);
return;
}
if (step == 0) {
@@ -255,9 +357,8 @@ std::string trim(const std::string& s) {
return rtrim(ltrim(s));
}
static sd_log_cb_t sd_log_cb = NULL;
void* sd_log_cb_data = NULL;
static sd_log_cb_t sd_log_cb = NULL;
void* sd_log_cb_data = NULL;
#define LOG_BUFFER_SIZE 1024

util.h

@@ -3,11 +3,13 @@
#include <cstdint>
#include <string>
#include <vector>
#include "stable-diffusion.h"
bool ends_with(const std::string& str, const std::string& ending);
bool starts_with(const std::string& str, const std::string& start);
bool contains(const std::string& str, const std::string& substr);
std::string format(const char* fmt, ...);
@@ -17,10 +19,16 @@ bool file_exists(const std::string& filename);
bool is_directory(const std::string& path);
std::string get_full_path(const std::string& dir, const std::string& filename);
std::vector<std::string> get_files_from_dir(const std::string& dir);
std::u32string utf8_to_utf32(const std::string& utf8_str);
std::string utf32_to_utf8(const std::u32string& utf32_str);
std::u32string unicode_value_to_utf32(int unicode_value);
sd_image_t* preprocess_id_image(sd_image_t* img);
// std::string sd_basename(const std::string& path);
typedef struct {
uint32_t width;
uint32_t height;