feat: add progress callback (#170)

pull/188/head master-7be65fa
fszontagh 2024-03-02 10:28:41 +01:00 committed by GitHub
parent d164236b2a
commit 7be65faa7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 24 additions and 4 deletions

View File

@ -891,6 +891,10 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
LOG_ERROR("embedding '%s' failed", embd_name.c_str());
return false;
}
if (std::find(readed_embeddings.begin(), readed_embeddings.end(), embd_name) != readed_embeddings.end()) {
LOG_DEBUG("embedding already read in: %s", embd_name.c_str());
return true;
}
struct ggml_init_params params;
params.mem_size = 10 * 1024 * 1024; // max for custom embeddings 10 MB
params.mem_buffer = NULL;

View File

@ -33,6 +33,7 @@ struct LoraModel : public GGMLModule {
return model_loader.get_params_mem_size(NULL);
}
bool load_from_file() {
LOG_INFO("loading LoRA from '%s'", file_path.c_str());
@ -55,6 +56,7 @@ struct LoraModel : public GGMLModule {
auto real = lora_tensors[name];
*dst_tensor = real;
}
return true;
};
@ -64,6 +66,7 @@ struct LoraModel : public GGMLModule {
dry_run = false;
model_loader.load_tensors(on_new_tensor_cb, backend);
LOG_DEBUG("finished loaded lora");
return true;
}

View File

@ -92,8 +92,10 @@ enum sd_log_level_t {
};
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
typedef void (*sd_progress_cb_t)(int step,int steps,float time, void* data);
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
SD_API int32_t get_num_physical_cores();
SD_API const char* sd_get_system_info();

View File

@ -161,6 +161,9 @@ int32_t get_num_physical_cores() {
return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
}
static sd_progress_cb_t sd_progress_cb = NULL;
void* sd_progress_cb_data = NULL;
std::u32string utf8_to_utf32(const std::string& utf8_str) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
return converter.from_bytes(utf8_str);
@ -205,6 +208,10 @@ std::string path_join(const std::string& p1, const std::string& p2) {
}
void pretty_progress(int step, int steps, float time) {
if (sd_progress_cb) {
sd_progress_cb(step,steps,time, sd_progress_cb_data);
return;
}
if (step == 0) {
return;
}
@ -248,8 +255,9 @@ std::string trim(const std::string& s) {
return rtrim(ltrim(s));
}
static sd_log_cb_t sd_log_cb = NULL;
void* sd_log_cb_data = NULL;
static sd_log_cb_t sd_log_cb = NULL;
void* sd_log_cb_data = NULL;
#define LOG_BUFFER_SIZE 1024
@ -286,7 +294,10 @@ void sd_set_log_callback(sd_log_cb_t cb, void* data) {
sd_log_cb = cb;
sd_log_cb_data = data;
}
void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
sd_progress_cb = cb;
sd_progress_cb_data = data;
}
const char* sd_get_system_info() {
static char buffer[1024];
std::stringstream ss;

View File

@ -6,7 +6,7 @@
/*================================================== AutoEncoderKL ===================================================*/
#define VAE_GRAPH_SIZE 10240
#define VAE_GRAPH_SIZE 20480
class ResnetBlock : public UnaryBlock {
protected: