fix: avoid double free and fix sdxl lora naming conversion

* Fixed a double-free when running multiple backends on the CPU (e.g. CLIP
and the primary backend): the *_backend pointers would both end up pointing
at the same backend, so the StableDiffusionGGML destructor freed it twice
and segfaulted. (See the sketch below.)
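
A minimal sketch of the failure mode and the fix. Only `ggml_backend_t` and
`ggml_backend_free` are the real ggml API; the struct and its fields are
illustrative stand-ins for the members of StableDiffusionGGML:

```cpp
#include "ggml-backend.h"

// Illustrative only: mirrors the aliasing that happens when CLIP is kept
// on the CPU alongside a CPU primary backend.
struct Backends {
    ggml_backend_t backend      = nullptr;  // primary backend
    ggml_backend_t clip_backend = nullptr;  // may alias `backend`

    ~Backends() {
        // The fix: only free clip_backend when it is a distinct backend.
        // Freeing both pointers unconditionally double-frees the shared
        // object and segfaults.
        if (clip_backend != backend) {
            ggml_backend_free(clip_backend);
        }
        ggml_backend_free(backend);
    }
};
```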

* Improved logging to allow a color switch on the command line interface.
Changed the base log_printf function so it no longer bakes the log level
directly into the log buffer; that information is already passed to the log
callback via the level parameter, and it is easier to add it there than to
strip it out. (A standalone sketch of the coloring follows.)
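
For reference, the coloring is plain ANSI SGR escape codes wrapped around the
level tag. A self-contained sketch (not project code; the messages are made
up) that prints two colored tags:

```cpp
#include <cstdio>

int main() {
    // "\033[<n>;1m" selects a bold foreground color, "\033[0m" resets it.
    // The color codes match the switch in the diff below:
    // 37 DEBUG, 34 INFO, 35 WARN, 31 ERROR.
    std::fprintf(stdout, "\033[34;1m[%-5s]\033[0m %s\n", "INFO", "sampling started");
    std::fprintf(stderr, "\033[31;1m[%-5s]\033[0m %s\n", "ERROR", "something failed");
    return 0;
}
```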

* Added a fix for certain SDXL LoRAs that do not follow the expected naming
convention, converting the tensor names during LoRA model loading (sketched
below). Added logging of useful LoRA loading information. Also increased the
base size of the GGML graph, as the existing size caused an insufficient
graph memory error when using SDXL LoRAs.
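
The conversion itself is a small prefix remap. The sketch below mirrors
convert_sdxl_lora_name from the diff, with std::regex_replace simplified to a
plain prefix substitution (the real code replaces every occurrence of the
pattern, but these patterns only match at the start of a tensor name):

```cpp
#include <string>
#include <utility>

std::string convert_sdxl_lora_name(std::string tensor_name) {
    // SDXL LoRA prefixes -> internal compvis-style prefixes (from the diff).
    const std::pair<std::string, std::string> lookup[] = {
        {"unet", "model_diffusion_model"},
        {"te2", "cond_stage_model_1_transformer"},
        {"te1", "cond_stage_model_transformer"},
    };
    for (const auto& p : lookup) {
        if (tensor_name.compare(0, p.first.size(), p.first) == 0) {
            tensor_name.replace(0, p.first.size(), p.second);
            break;
        }
    }
    // e.g. "unet_down_blocks_0_..." -> "model_diffusion_model_down_blocks_0_..."
    return tensor_name;
}
```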

* Small fixes

---------

Co-authored-by: leejet <leejet714@gmail.com>
Grauho 2024-03-20 10:00:22 -04:00 committed by GitHub
parent a469688e30
commit 48bcce493f
7 changed files with 103 additions and 33 deletions

examples/cli/main.cpp

@@ -103,6 +103,7 @@ struct SDParams {
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool canny_preprocess = false;
+ bool color = false;
int upscale_repeats = 1;
};
@@ -469,6 +470,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(0);
} else if (arg == "-v" || arg == "--verbose") {
params.verbose = true;
+ } else if (arg == "--color") {
+ params.color = true;
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv);
@@ -572,18 +575,47 @@ std::string get_image_params(SDParams params, int64_t seed) {
return parameter_string;
}
+ /* Enables printing the log level tag in color using ANSI escape codes */
void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
SDParams* params = (SDParams*)data;
- if (!params->verbose && level <= SD_LOG_DEBUG) {
+ int tag_color;
+ const char* level_str;
+ FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout;
+ if (!log || (!params->verbose && level <= SD_LOG_DEBUG)) {
return;
}
- if (level <= SD_LOG_INFO) {
- fputs(log, stdout);
- fflush(stdout);
- } else {
- fputs(log, stderr);
- fflush(stderr);
+ switch (level) {
+ case SD_LOG_DEBUG:
+ tag_color = 37;
+ level_str = "DEBUG";
+ break;
+ case SD_LOG_INFO:
+ tag_color = 34;
+ level_str = "INFO";
+ break;
+ case SD_LOG_WARN:
+ tag_color = 35;
+ level_str = "WARN";
+ break;
+ case SD_LOG_ERROR:
+ tag_color = 31;
+ level_str = "ERROR";
+ break;
+ default: /* Potential future-proofing */
+ tag_color = 33;
+ level_str = "?????";
+ break;
+ }
+ if (params->color == true) {
+ fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str);
+ } else {
+ fprintf(out_stream, "[%-5s] ", level_str);
+ }
+ fputs(log, out_stream);
+ fflush(out_stream);
}
int main(int argc, const char* argv[]) {

ggml_extend.hpp

@@ -759,8 +759,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding(
// virtual struct ggml_cgraph* get_ggml_cgraph() = 0;
// };
- #define MAX_PARAMS_TENSOR_NUM 10240
- #define MAX_GRAPH_SIZE 10240
+ /*
+ #define MAX_PARAMS_TENSOR_NUM 10240
+ #define MAX_GRAPH_SIZE 10240
+ */
+ /* SDXL with LoRA requires more space */
+ #define MAX_PARAMS_TENSOR_NUM 15360
+ #define MAX_GRAPH_SIZE 15360
struct GGMLModule {
protected:
@@ -1308,4 +1313,4 @@ public:
}
};
- #endif // __GGML_EXTEND__HPP__
\ No newline at end of file
+ #endif // __GGML_EXTEND__HPP__

lora.hpp

@@ -75,7 +75,7 @@ struct LoraModel : public GGMLModule {
return true;
}
- struct ggml_cgraph* build_graph(std::map<std::string, struct ggml_tensor*> model_tensors) {
+ struct ggml_cgraph* build_lora_graph(std::map<std::string, struct ggml_tensor*> model_tensors) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false);
std::set<std::string> applied_lora_tensors;
@@ -90,7 +90,7 @@ struct LoraModel : public GGMLModule {
k_tensor = k_tensor.substr(0, k_pos);
replace_all_chars(k_tensor, '.', '_');
// LOG_DEBUG("k_tensor %s", k_tensor.c_str());
- if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { // fix for SDXL
+ if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") {  // fix for SDXL
k_tensor = "model_diffusion_model_output_blocks_2_1_conv";
}
std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight";
@@ -155,21 +155,37 @@ struct LoraModel : public GGMLModule {
ggml_build_forward_expand(gf, final_weight);
}
+ size_t total_lora_tensors_count = 0;
+ size_t applied_lora_tensors_count = 0;
for (auto& kv : lora_tensors) {
+ total_lora_tensors_count++;
if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) {
LOG_WARN("unused lora tensor %s", kv.first.c_str());
+ } else {
+ applied_lora_tensors_count++;
+ }
}
+ /* Don't worry if this message shows up twice in the logs per LoRA,
+ * this function is called once to calculate the required buffer size
+ * and then again to actually generate a graph to be used */
+ if (applied_lora_tensors_count != total_lora_tensors_count) {
+ LOG_WARN("Only (%lu / %lu) LoRA tensors have been applied",
+ applied_lora_tensors_count, total_lora_tensors_count);
+ } else {
+ LOG_DEBUG("(%lu / %lu) LoRA tensors applied successfully",
+ applied_lora_tensors_count, total_lora_tensors_count);
+ }
return gf;
}
void apply(std::map<std::string, struct ggml_tensor*> model_tensors, int n_threads) {
auto get_graph = [&]() -> struct ggml_cgraph* {
- return build_graph(model_tensors);
+ return build_lora_graph(model_tensors);
};
GGMLModule::compute(get_graph, n_threads, true);
}
};
- #endif // __LORA_HPP__
\ No newline at end of file
+ #endif // __LORA_HPP__

model.cpp

@@ -204,6 +204,23 @@ std::string convert_vae_decoder_name(const std::string& name) {
return name;
}
+ /* If this is not an SDXL LoRA, the "unet" prefix will already have been
+ * replaced by this point; "te2" and "te1" don't seem to appear in non-SDXL
+ * LoRAs, only "te_" */
+ std::string convert_sdxl_lora_name(std::string tensor_name) {
+ const std::pair<std::string, std::string> sdxl_lora_name_lookup[] = {
+ {"unet", "model_diffusion_model"},
+ {"te2", "cond_stage_model_1_transformer"},
+ {"te1", "cond_stage_model_transformer"},
+ };
+ for (auto& pair_i : sdxl_lora_name_lookup) {
+ if (tensor_name.compare(0, pair_i.first.length(), pair_i.first) == 0) {
+ tensor_name = std::regex_replace(tensor_name, std::regex(pair_i.first), pair_i.second);
+ break;
+ }
+ }
+ return tensor_name;
+ }
std::unordered_map<std::string, std::unordered_map<std::string, std::string>> suffix_conversion_underline = {
{
"attentions",
@@ -415,8 +432,12 @@ std::string convert_tensor_name(const std::string& name) {
if (pos != std::string::npos) {
std::string name_without_network_parts = name.substr(5, pos - 5);
std::string network_part = name.substr(pos + 1);
// LOG_DEBUG("%s %s", name_without_network_parts.c_str(), network_part.c_str());
std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '_');
+ /* For dealing with the new SDXL LoRA tensor naming convention */
+ new_key = convert_sdxl_lora_name(new_key);
if (new_key.empty()) {
new_name = name;
} else {
@@ -1641,4 +1662,4 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
}
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
return success;
- }
\ No newline at end of file
+ }

stable-diffusion.cpp

@@ -122,10 +122,16 @@ public:
}
~StableDiffusionGGML() {
+ if (clip_backend != backend) {
+ ggml_backend_free(clip_backend);
+ }
+ if (control_net_backend != backend) {
+ ggml_backend_free(control_net_backend);
+ }
+ if (vae_backend != backend) {
+ ggml_backend_free(vae_backend);
+ }
ggml_backend_free(backend);
- ggml_backend_free(clip_backend);
- ggml_backend_free(control_net_backend);
- ggml_backend_free(vae_backend);
}
bool load_from_file(const std::string& model_path,
@@ -521,9 +527,7 @@ public:
int64_t t1 = ggml_time_ms();
LOG_INFO("lora '%s' applied, taking %.2fs",
lora_name.c_str(),
(t1 - t0) * 1.0f / 1000);
LOG_INFO("lora '%s' applied, taking %.2fs", lora_name.c_str(), (t1 - t0) * 1.0f / 1000);
}
void apply_loras(const std::unordered_map<std::string, float>& lora_state) {
@@ -546,6 +550,8 @@ public:
}
}
LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size());
for (auto& kv : lora_state_diff) {
apply_lora(kv.first, kv.second);
}
@@ -2109,4 +2115,4 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx,
LOG_INFO("img2vid completed in %.2fs", (t3 - t0) * 1.0f / 1000);
return result_images;
- }
\ No newline at end of file
+ }

stable-diffusion.h

@@ -201,4 +201,4 @@ SD_API uint8_t* preprocess_canny(uint8_t* img,
}
#endif
- #endif // __STABLE_DIFFUSION_H__
\ No newline at end of file
+ #endif // __STABLE_DIFFUSION_H__

util.cpp

@@ -366,18 +366,8 @@ void log_printf(sd_log_level_t level, const char* file, int line, const char* fo
va_list args;
va_start(args, format);
const char* level_str = "DEBUG";
if (level == SD_LOG_INFO) {
level_str = "INFO ";
} else if (level == SD_LOG_WARN) {
level_str = "WARN ";
} else if (level == SD_LOG_ERROR) {
level_str = "ERROR";
}
static char log_buffer[LOG_BUFFER_SIZE + 1];
int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "[%s] %s:%-4d - ", level_str, sd_basename(file).c_str(), line);
int written = snprintf(log_buffer, LOG_BUFFER_SIZE, "%s:%-4d - ", sd_basename(file).c_str(), line);
if (written >= 0 && written < LOG_BUFFER_SIZE) {
vsnprintf(log_buffer + written, LOG_BUFFER_SIZE - written, format, args);
@@ -572,4 +562,4 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size) {
}
return result;
- }
\ No newline at end of file
+ }