sync: update ggml (#180)

leejet 2024-02-25 21:11:01 +08:00 committed by GitHub
parent 193fb620b1
commit 730585d515
9 changed files with 197 additions and 182 deletions

clip.hpp

@@ -956,64 +956,32 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
return hidden_states;
}
struct ggml_cgraph* build_graph(struct ggml_allocr* allocr, std::vector<int> tokens, bool return_pooled = false) {
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
struct ggml_tensor* input_ids2 = NULL,
size_t max_token_idx = 0,
bool return_pooled = false) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
struct ggml_tensor* input_ids = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, tokens.size());
ggml_allocr_alloc(allocr, input_ids);
if (!ggml_allocr_is_measure(allocr)) {
ggml_backend_tensor_set(input_ids, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids));
}
struct ggml_tensor* input_ids2 = NULL;
size_t max_token_idx = 0;
if (version == VERSION_XL) {
input_ids2 = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, tokens.size());
ggml_allocr_alloc(allocr, input_ids2);
auto it = std::find(tokens.begin(), tokens.end(), EOS_TOKEN_ID);
if (it != tokens.end()) {
std::fill(std::next(it), tokens.end(), 0);
}
max_token_idx = std::min<size_t>(std::distance(tokens.begin(), it), tokens.size() - 1);
// for (int i = 0; i < tokens.size(); i++) {
// printf("%d ", tokens[i]);
// }
// printf("\n");
if (!ggml_allocr_is_measure(allocr)) {
ggml_backend_tensor_set(input_ids2, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids2));
}
input_ids2 = to_backend(input_ids2);
if (!return_pooled) {
input_ids = to_backend(input_ids);
}
struct ggml_tensor* embeddings = NULL;
if (num_custom_embeddings > 0 && version != VERSION_XL) {
embeddings = ggml_new_tensor_2d(compute_ctx,
wtype,
text_model.hidden_size,
text_model.vocab_size + num_custom_embeddings /* custom placeholder */);
ggml_allocr_alloc(allocr, embeddings);
if (!ggml_allocr_is_measure(allocr)) {
// really bad, there is memory inflexibility (this is for host<->device memory conflicts)
auto token_embed_weight = text_model.get_token_embed_weight();
void* freeze_data = malloc(ggml_nbytes(token_embed_weight));
ggml_backend_tensor_get_and_sync(backend,
token_embed_weight,
freeze_data,
0,
ggml_nbytes(token_embed_weight));
ggml_backend_tensor_set(embeddings, freeze_data, 0, ggml_nbytes(token_embed_weight));
free(freeze_data);
// concatenate custom embeddings
ggml_backend_tensor_set(embeddings,
(const void*)token_embed_custom.data(),
ggml_nbytes(token_embed_weight),
num_custom_embeddings * text_model.hidden_size * ggml_type_size(wtype));
}
auto custom_embeddings = ggml_new_tensor_3d(compute_ctx,
wtype,
text_model.hidden_size,
1,
num_custom_embeddings);
set_backend_tensor_data(custom_embeddings, token_embed_custom.data());
auto token_embed_weight = text_model.get_token_embed_weight();
token_embed_weight = ggml_reshape_3d(compute_ctx, token_embed_weight, token_embed_weight->ne[0], 1, token_embed_weight->ne[1]);
// concatenate custom embeddings
embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings);
embeddings = ggml_reshape_2d(compute_ctx, embeddings, embeddings->ne[0], embeddings->ne[2]);
}
struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, input_ids2, embeddings, max_token_idx, return_pooled);
@@ -1024,12 +992,14 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule {
}
void compute(const int n_threads,
std::vector<int> tokens,
struct ggml_tensor* input_ids,
struct ggml_tensor* input_ids2,
size_t max_token_idx,
bool return_pooled,
ggml_tensor** output,
ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(compute_allocr, tokens, return_pooled);
return build_graph(input_ids, input_ids2, max_token_idx, return_pooled);
};
GGMLModule::compute(get_graph, n_threads, true, output, output_ctx);
}
@@ -1143,8 +1113,7 @@ struct FrozenCLIPVisionEmbedder : public GGMLModule {
vision_model.get_param_tensors(tensors, prefix + "transformer.visual_model");
}
struct ggml_cgraph* build_graph(struct ggml_allocr* allocr,
struct ggml_tensor* pixel_values) {
struct ggml_cgraph* build_graph(struct ggml_tensor* pixel_values) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
pixel_values = to_backend(pixel_values);
@@ -1156,19 +1125,12 @@ struct FrozenCLIPVisionEmbedder : public GGMLModule {
return gf;
}
void alloc_compute_buffer(ggml_context* work_ctx, ggml_tensor* pixel_values) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(compute_allocr, pixel_values);
};
GGMLModule::alloc_compute_buffer(get_graph);
}
void compute(const int n_threads,
ggml_tensor* pixel_values,
ggml_tensor** output,
ggml_context* output_ctx) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(compute_allocr, pixel_values);
return build_graph(pixel_values);
};
GGMLModule::compute(get_graph, n_threads, true, output, output_ctx);
}
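
With this change the text encoder no longer allocates its inputs inside build_graph; callers build the token tensor in a host-side work context and pass it in. A minimal sketch of the new calling convention for the SD1.x/SD2.x case (the token ids between BOS and EOS are illustrative, work_ctx and n_threads stand in for the caller's state, and the SDXL-only arguments are NULL/0):

    std::vector<int> tokens = {49406, 320, 1125, 49407};           // BOS, two illustrative ids, EOS
    auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);  // host-side I32 tensor
    struct ggml_tensor* hidden_states = NULL;
    // input_ids2 and max_token_idx only matter for VERSION_XL's pooled output
    cond_stage_model->compute(n_threads, input_ids, NULL, 0, false, &hidden_states, work_ctx);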

control.hpp

@@ -166,7 +166,6 @@ public:
struct ggml_tensor* resblock_forward(std::string name,
struct ggml_context* ctx,
struct ggml_allocr* allocr,
struct ggml_tensor* x,
struct ggml_tensor* emb) {
auto block = std::dynamic_pointer_cast<ResBlock>(blocks[name]);
@@ -175,7 +174,6 @@ public:
struct ggml_tensor* attention_layer_forward(std::string name,
struct ggml_context* ctx,
struct ggml_allocr* allocr,
struct ggml_tensor* x,
struct ggml_tensor* context) {
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
@@ -201,11 +199,10 @@ public:
}
std::vector<struct ggml_tensor*> forward(struct ggml_context* ctx,
struct ggml_allocr* allocr,
struct ggml_tensor* x,
struct ggml_tensor* hint,
struct ggml_tensor* guided_hint,
std::vector<float> timesteps,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* y = NULL) {
// x: [N, in_channels, h, w] or [N, in_channels/2, h, w]
@@ -231,7 +228,7 @@ public:
auto middle_block_out = std::dynamic_pointer_cast<Conv2d>(blocks["middle_block_out.0"]);
auto t_emb = new_timestep_embedding(ctx, allocr, timesteps, model_channels); // [N, model_channels]
auto t_emb = ggml_nn_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels]
auto emb = time_embed_0->forward(ctx, t_emb);
emb = ggml_silu_inplace(ctx, emb);
@@ -272,10 +269,10 @@ public:
for (int j = 0; j < num_res_blocks; j++) {
input_block_idx += 1;
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".0";
h = resblock_forward(name, ctx, allocr, h, emb); // [N, mult*model_channels, h, w]
h = resblock_forward(name, ctx, h, emb); // [N, mult*model_channels, h, w]
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
h = attention_layer_forward(name, ctx, allocr, h, context); // [N, mult*model_channels, h, w]
h = attention_layer_forward(name, ctx, h, context); // [N, mult*model_channels, h, w]
}
auto zero_conv = std::dynamic_pointer_cast<Conv2d>(blocks["zero_convs." + std::to_string(input_block_idx) + ".0"]);
@@ -299,9 +296,9 @@ public:
// [N, 4*model_channels, h/8, w/8]
// middle_block
h = resblock_forward("middle_block.0", ctx, allocr, h, emb); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, allocr, h, context); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, allocr, h, emb); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.0", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, h, context); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, h, emb); // [N, 4*model_channels, h/8, w/8]
// out
outs.push_back(middle_block_out->forward(ctx, h));
@@ -386,18 +383,22 @@ struct ControlNet : public GGMLModule {
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* hint,
std::vector<float> timesteps,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* y = NULL) {
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, CONTROL_NET_GRAPH_SIZE, false);
x = to_backend(x);
hint = to_backend(hint);
context = to_backend(context);
y = to_backend(y);
x = to_backend(x);
if (guided_hint_cached) {
hint = NULL;
} else {
hint = to_backend(hint);
}
context = to_backend(context);
y = to_backend(y);
timesteps = to_backend(timesteps);
auto outs = control_net.forward(compute_ctx,
compute_allocr,
x,
hint,
guided_hint_cached ? guided_hint : NULL,
@@ -420,7 +421,7 @@ struct ControlNet : public GGMLModule {
void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* hint,
std::vector<float> timesteps,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* y,
struct ggml_tensor** output = NULL,
@@ -434,7 +435,6 @@ struct ControlNet : public GGMLModule {
};
GGMLModule::compute(get_graph, n_threads, false, output, output_ctx);
guided_hint_cached = true;
}
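
Besides taking timesteps as a tensor, ControlNet now caches guided_hint: the first compute() uploads the hint image and sets guided_hint_cached, so subsequent sampling steps pass NULL for hint internally and reuse the cached tensor. A sketch of the intended call pattern (tensor names are placeholders for the caller's state):

    control_net->compute(n_threads, x, hint, timesteps, context, y);  // first step: hint uploaded, guided_hint cached
    control_net->compute(n_threads, x, hint, timesteps, context, y);  // later steps: hint skipped, cache reused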

ggml (submodule)

@@ -1 +1 @@
Subproject commit 9a5ce3002474b3ac1dc2441e5c6b95ccef02cc78
Subproject commit 4212b7570a48e09b16939878314d83e919370a9a

ggml_extend.hpp

@@ -606,6 +606,20 @@ __STATIC_INLINE__ float ggml_backend_tensor_get_f32(ggml_tensor* tensor) {
return value;
}
__STATIC_INLINE__ struct ggml_tensor* vector_to_ggml_tensor(struct ggml_context* ctx,
const std::vector<float>& vec) {
struct ggml_tensor* t = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, vec.size());
memcpy(t->data, (const void*)vec.data(), ggml_nbytes(t));
return t;
}
__STATIC_INLINE__ struct ggml_tensor* vector_to_ggml_tensor_i32(struct ggml_context* ctx,
const std::vector<int>& vec) {
struct ggml_tensor* t = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, vec.size());
memcpy(t->data, (const void*)vec.data(), ggml_nbytes(t));
return t;
}
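
Both helpers memcpy straight into t->data, so they assume the context allocates tensor data on the host (i.e. it was created with no_alloc == false, as the work_ctx used elsewhere in this codebase appears to be). For example:

    std::vector<float> timesteps_vec = {999.f};
    auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);  // 1-D F32 tensor, data on host
    // build_graph() later moves it to the device via to_backend()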
__STATIC_INLINE__ std::vector<float> arange(float start, float end, float step = 1.f) {
std::vector<float> result;
@@ -652,7 +666,6 @@ __STATIC_INLINE__ void set_timestep_embedding(std::vector<float> timesteps,
}
__STATIC_INLINE__ struct ggml_tensor* new_timestep_embedding(struct ggml_context* ctx,
struct ggml_allocr* allocr,
std::vector<float> timesteps,
int dim,
int max_period = 10000) {
@@ -664,17 +677,22 @@ __STATIC_INLINE__ struct ggml_tensor* new_timestep_embedding(struct ggml_context
acutual_dim = dim + 1;
}
struct ggml_tensor* embedding = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, acutual_dim, timesteps.size());
if (allocr != NULL) {
ggml_allocr_alloc(allocr, embedding);
if (!ggml_allocr_is_measure(allocr)) {
ggml_backend_tensor_set(embedding, embedding_vec.data(), 0, ggml_nbytes(embedding));
}
} else {
if (embedding->data != NULL) {
memcpy(((char*)embedding->data), ((char*)embedding_vec.data()), ggml_nbytes(embedding));
} else {
ggml_backend_tensor_set(embedding, embedding_vec.data(), 0, ggml_nbytes(embedding));
}
return embedding;
}
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_timestep_embedding(
struct ggml_context* ctx,
struct ggml_tensor* timesteps,
int dim,
int max_period = 10000) {
return ggml_timestep_embedding(ctx, timesteps, dim, max_period);
}
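
This is a thin wrapper over the ggml_timestep_embedding op available in the synced ggml revision: the embedding becomes a graph node computed on the backend instead of being filled on the host and uploaded. A usage sketch, with 320 as an assumed model_channels value:

    // timesteps: [N] F32 tensor already in the graph; result: [N, 320]
    auto t_emb = ggml_nn_timestep_embedding(compute_ctx, timesteps, 320);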
// struct GGMLComputeGraph {
// virtual void init(struct ggml_context* ctx, ggml_type wtype) = 0;
// virtual std::string get_desc() = 0;
@@ -693,9 +711,10 @@ protected:
struct ggml_context* params_ctx = NULL;
ggml_backend_buffer_t params_buffer = NULL;
struct ggml_context* compute_ctx = NULL;
ggml_backend_buffer_t compute_buffer = NULL; // for compute
struct ggml_allocr* compute_allocr = NULL;
struct ggml_context* compute_ctx = NULL;
struct ggml_gallocr* compute_allocr = NULL;
std::map<struct ggml_tensor*, const void*> backend_tensor_data_map;
ggml_type wtype = GGML_TYPE_F32;
ggml_backend_t backend = NULL;
@@ -734,23 +753,37 @@ protected:
}
}
void alloc_compute_buffer(get_graph_cb_t get_graph) {
// alignment required by the backend
compute_allocr = ggml_allocr_new_measure_from_backend(backend);
bool alloc_compute_buffer(get_graph_cb_t get_graph) {
if (compute_allocr != NULL) {
return true;
}
reset_compute_ctx();
struct ggml_cgraph* gf = get_graph();
backend_tensor_data_map.clear();
compute_allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
if (!ggml_gallocr_reserve(compute_allocr, gf)) {
// failed to allocate the compute buffer
LOG_ERROR("%s: failed to allocate the compute buffer\n", get_desc().c_str());
free_compute_buffer();
return false;
}
// compute the required memory
size_t compute_buffer_size = ggml_allocr_alloc_graph(compute_allocr, gf) + 1024 * 1024;
// recreate the allocator with the required memory
ggml_allocr_free(compute_allocr);
size_t compute_buffer_size = ggml_gallocr_get_buffer_size(compute_allocr, 0);
LOG_DEBUG("%s compute buffer size: %.2f MB", get_desc().c_str(), compute_buffer_size / 1024.0 / 1024.0);
return true;
}
compute_buffer = ggml_backend_alloc_buffer(backend, compute_buffer_size);
compute_allocr = ggml_allocr_new_from_buffer(compute_buffer);
void cpy_data_to_backend_tensor() {
for (auto& kv : backend_tensor_data_map) {
auto tensor = kv.first;
auto data = kv.second;
ggml_backend_tensor_set(tensor, data, 0, ggml_nbytes(tensor));
}
backend_tensor_data_map.clear();
}
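
Taken together, this replaces the old measure-allocator dance with ggml_gallocr plus a deferred host-to-device copy: build_graph records host pointers via set_backend_tensor_data(), the graph allocator assigns device addresses, and only then is the data uploaded. Roughly, a compute() call now proceeds as follows (a condensed sketch of the sequence, not the literal implementation):

    alloc_compute_buffer(get_graph);               // first call only: ggml_gallocr_reserve() sizes the buffer
    reset_compute_ctx();
    struct ggml_cgraph* gf = get_graph();          // nodes reference host data only indirectly
    ggml_gallocr_alloc_graph(compute_allocr, gf);  // tensors receive device addresses
    cpy_data_to_backend_tensor();                  // now ggml_backend_tensor_set() can run
    ggml_backend_graph_compute(backend, gf);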
public:
@@ -775,31 +808,16 @@ public:
alloc_compute_ctx();
}
void reset_compute_allocr(get_graph_cb_t get_graph) {
if (compute_allocr != NULL) {
ggml_allocr_reset(compute_allocr);
} else {
alloc_compute_buffer(get_graph);
}
}
bool alloc_params_buffer() {
size_t params_buffer_size = 10 * 1024 * 1024; // 10 MB, for padding
params_buffer_size += get_params_mem_size();
size_t num_tensors = get_params_num();
params_buffer = ggml_backend_alloc_ctx_tensors(params_ctx, backend);
if (params_buffer == NULL) {
LOG_ERROR("%s alloc params backend buffer failed", get_desc().c_str());
return false;
}
size_t params_buffer_size = ggml_backend_buffer_get_size(params_buffer);
LOG_DEBUG("%s params backend buffer size = % 6.2f MB (%i tensors)",
get_desc().c_str(), params_buffer_size / (1024.0 * 1024.0), num_tensors);
params_buffer = ggml_backend_alloc_buffer(backend, params_buffer_size);
ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);
// alloc all tensors linked to params_ctx
for (struct ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != NULL; t = ggml_get_next_tensor(params_ctx, t)) {
if (t->data == NULL) {
ggml_allocr_alloc(alloc, t);
}
}
ggml_allocr_free(alloc);
return true;
}
@@ -812,13 +830,14 @@ public:
void free_compute_buffer() {
if (compute_allocr != NULL) {
ggml_allocr_free(compute_allocr);
ggml_gallocr_free(compute_allocr);
compute_allocr = NULL;
}
if (compute_buffer != NULL) {
ggml_backend_buffer_free(compute_buffer);
compute_buffer = NULL;
}
}
// do copy after alloc graph
void set_backend_tensor_data(struct ggml_tensor* tensor, const void* data) {
backend_tensor_data_map[tensor] = data;
}
struct ggml_tensor* to_backend(struct ggml_tensor* tensor) {
@@ -827,15 +846,11 @@ public:
return NULL;
}
// it's performing a compute, check if backend isn't cpu
if (!ggml_backend_is_cpu(backend)) {
if (!ggml_backend_is_cpu(backend) && tensor->backend == GGML_BACKEND_CPU) {
// pass input tensors to gpu memory
auto backend_tensor = ggml_dup_tensor(compute_ctx, tensor);
ggml_allocr_alloc(compute_allocr, backend_tensor);
// pass data to device backend
if (!ggml_allocr_is_measure(compute_allocr)) {
ggml_backend_tensor_set(backend_tensor, tensor->data, 0, ggml_nbytes(tensor));
}
set_backend_tensor_data(backend_tensor, tensor->data);
return backend_tensor;
} else {
return tensor;
@@ -847,11 +862,13 @@ public:
bool free_compute_buffer_immediately = true,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
reset_compute_allocr(get_graph);
alloc_compute_buffer(get_graph);
reset_compute_ctx();
struct ggml_cgraph* gf = get_graph();
ggml_allocr_alloc_graph(compute_allocr, gf);
GGML_ASSERT(ggml_gallocr_alloc_graph(compute_allocr, gf));
cpy_data_to_backend_tensor();
if (ggml_backend_is_cpu(backend)) {
ggml_backend_cpu_set_n_threads(backend, n_threads);

lora.hpp

@@ -40,26 +40,31 @@ struct LoraModel : public GGMLModule {
LOG_ERROR("init lora model loader from file failed: '%s'", file_path.c_str());
return false;
}
alloc_params_buffer();
ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer);
bool dry_run = true;
auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
const std::string& name = tensor_storage.name;
struct ggml_tensor* real = ggml_new_tensor(params_ctx, tensor_storage.type, tensor_storage.n_dims, tensor_storage.ne);
ggml_allocr_alloc(alloc, real);
*dst_tensor = real;
lora_tensors[name] = real;
if (dry_run) {
struct ggml_tensor* real = ggml_new_tensor(params_ctx,
tensor_storage.type,
tensor_storage.n_dims,
tensor_storage.ne);
lora_tensors[name] = real;
} else {
auto real = lora_tensors[name];
*dst_tensor = real;
}
return true;
};
model_loader.load_tensors(on_new_tensor_cb, backend);
alloc_params_buffer();
dry_run = false;
model_loader.load_tensors(on_new_tensor_cb, backend);
LOG_DEBUG("finished loaded lora");
ggml_allocr_free(alloc);
return true;
}
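
The two-pass load exists because ggml_allocr_new_from_buffer is gone: parameter tensors are now allocated in one shot by ggml_backend_alloc_ctx_tensors() (see alloc_params_buffer in ggml_extend.hpp above), which requires every tensor to already exist in params_ctx. Pass one therefore creates metadata-only tensors, the buffer is allocated, and pass two re-reads the file to fill them. Condensed:

    dry_run = true;
    model_loader.load_tensors(on_new_tensor_cb, backend);  // pass 1: ggml_new_tensor() only, no data
    alloc_params_buffer();                                 // one ggml_backend_alloc_ctx_tensors() call
    dry_run = false;
    model_loader.load_tensors(on_new_tensor_cb, backend);  // pass 2: data lands in the allocated tensors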

model.cpp

@@ -498,7 +498,13 @@ void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
}
}
void convert_tensor(void* src, ggml_type src_type, void* dst, ggml_type dst_type, int n) {
void convert_tensor(void* src,
ggml_type src_type,
void* dst,
ggml_type dst_type,
int nrows,
int n_per_row) {
int n = nrows * n_per_row;
if (src_type == dst_type) {
size_t nbytes = n * ggml_type_size(src_type) / ggml_blck_size(src_type);
memcpy(((char*)dst), ((char*)src), nbytes);
@@ -507,7 +513,9 @@ void convert_tensor(void* src, ggml_type src_type, void* dst, ggml_type dst_type
ggml_fp32_to_fp16_row((float*)src, (ggml_fp16_t*)dst, n);
} else {
int64_t hist[16];
ggml_quantize_chunk(dst_type, (float*)src, dst, 0, n, hist);
std::vector<float> imatrix(n_per_row, 1.0f); // dummy importance matrix
const float* im = imatrix.data();
ggml_quantize_chunk(dst_type, (float*)src, dst, 0, nrows, n_per_row, hist, im);
}
} else if (dst_type == GGML_TYPE_F32) {
if (src_type == GGML_TYPE_F16) {
@@ -536,7 +544,9 @@ void convert_tensor(void* src, ggml_type src_type, void* dst, ggml_type dst_type
ggml_fp32_to_fp16_row((float*)src_data_f32, (ggml_fp16_t*)dst, n);
} else {
int64_t hist[16];
ggml_quantize_chunk(dst_type, (float*)src_data_f32, dst, 0, n, hist);
std::vector<float> imatrix(n_per_row, 1.0f); // dummy importance matrix
const float* im = imatrix.data();
ggml_quantize_chunk(dst_type, (float*)src_data_f32, dst, 0, nrows, n_per_row, hist, im);
}
}
}
@@ -1387,7 +1397,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
}
convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
dst_tensor->type, (int)tensor_storage.nelements());
dst_tensor->type, (int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
}
} else {
read_buffer.resize(tensor_storage.nbytes());
@@ -1406,7 +1416,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
convert_buffer.resize(ggml_nbytes(dst_tensor));
convert_tensor((void*)read_buffer.data(), tensor_storage.type,
(void*)convert_buffer.data(), dst_tensor->type,
(int)tensor_storage.nelements());
(int)tensor_storage.nelements() / (int)tensor_storage.ne[0], (int)tensor_storage.ne[0]);
ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor));
}
}
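
The extra parameters track the new ggml_quantize_chunk() signature in the synced ggml, which quantizes row by row and accepts an optional per-row importance matrix; an all-ones imatrix reproduces the old unweighted behavior. For a weight tensor the arguments derive from its shape:

    // ne[0] is the row length; the remaining dims collapse into the row count
    int n_per_row = (int)tensor_storage.ne[0];
    int nrows     = (int)tensor_storage.nelements() / n_per_row;
    std::vector<float> imatrix(n_per_row, 1.0f);  // all ones = no importance weighting
    int64_t hist[16];
    ggml_quantize_chunk(dst_type, (float*)src, dst, 0, nrows, n_per_row, hist, imatrix.data());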

stable-diffusion.cpp

@@ -363,9 +363,10 @@ public:
struct ggml_tensor* c = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 1024, 2, 1, 1);
ggml_set_f32(c, 0.5);
std::vector<float> timesteps = {999.f}; // [N, ]
int64_t t0 = ggml_time_ms();
struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1);
ggml_set_f32(timesteps, 999);
int64_t t0 = ggml_time_ms();
struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t);
diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, -1, {}, 0.f, &out);
diffusion_model->free_compute_buffer();
@@ -456,9 +457,29 @@ public:
int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = NULL; // [N, n_token, hidden_size]
struct ggml_tensor* pooled = NULL;
cond_stage_model->compute(n_threads, tokens, false, &hidden_states, work_ctx);
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
struct ggml_tensor* input_ids2 = NULL;
size_t max_token_idx = 0;
if (version == VERSION_XL) {
cond_stage_model->compute(n_threads, tokens, true, &pooled, work_ctx);
auto it = std::find(tokens.begin(), tokens.end(), EOS_TOKEN_ID);
if (it != tokens.end()) {
std::fill(std::next(it), tokens.end(), 0);
}
max_token_idx = std::min<size_t>(std::distance(tokens.begin(), it), tokens.size() - 1);
input_ids2 = vector_to_ggml_tensor_i32(work_ctx, tokens);
// for (int i = 0; i < tokens.size(); i++) {
// printf("%d ", tokens[i]);
// }
// printf("\n");
}
cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, false, &hidden_states, work_ctx);
if (version == VERSION_XL) {
cond_stage_model->compute(n_threads, input_ids, input_ids2, max_token_idx, true, &pooled, work_ctx);
}
// if (pooled != NULL) {
// print_ggml_tensor(hidden_states);
@@ -675,7 +696,8 @@ public:
}
float t = denoiser->schedule->sigma_to_t(sigma);
std::vector<float> timesteps(x->ne[3], t); // [N, ]
std::vector<float> timesteps_vec(x->ne[3], t); // [N, ]
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
copy_ggml_tensor(noised_input, input);
// noised_input = noised_input * c_in
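
With inputs now plain tensors in work_ctx, the denoising loop follows the same convention as the text encoder: wrap host data in a tensor and let build_graph()'s to_backend() handle the upload. Per sampling step this amounts to (condensed from the hunk above):

    float t = denoiser->schedule->sigma_to_t(sigma);
    std::vector<float> timesteps_vec(x->ne[3], t);                    // one timestep per batch item
    auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);  // host tensor; uploaded inside build_graph()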

stable-diffusion.h

@@ -73,6 +73,9 @@ enum sd_type_t {
SD_TYPE_Q8_K = 15,
SD_TYPE_IQ2_XXS = 16,
SD_TYPE_IQ2_XS = 17,
SD_TYPE_IQ3_XXS = 18,
SD_TYPE_IQ1_S = 19,
SD_TYPE_IQ4_NL = 20,
SD_TYPE_I8,
SD_TYPE_I16,
SD_TYPE_I32,

unet.hpp

@@ -61,7 +61,6 @@ public:
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_allocr* allocr,
struct ggml_tensor* x,
struct ggml_tensor* context,
int timesteps) {
@@ -112,9 +111,9 @@ public:
x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, inner_dim]
x = ggml_reshape_3d(ctx, x, inner_dim, w * h, n); // [N, h * w, inner_dim]
std::vector<float> num_frames = arange(0, timesteps);
auto num_frames = ggml_arange(ctx, 0, timesteps, 1);
// since b is 1, no need to do repeat
auto t_emb = new_timestep_embedding(ctx, allocr, num_frames, in_channels, max_time_embed_period); // [N, in_channels]
auto t_emb = ggml_nn_timestep_embedding(ctx, num_frames, in_channels, max_time_embed_period); // [N, in_channels]
auto emb = time_pos_embed_0->forward(ctx, t_emb);
emb = ggml_silu_inplace(ctx, emb);
@@ -342,7 +341,6 @@ public:
struct ggml_tensor* resblock_forward(std::string name,
struct ggml_context* ctx,
struct ggml_allocr* allocr,
struct ggml_tensor* x,
struct ggml_tensor* emb,
int num_video_frames) {
@@ -359,14 +357,13 @@ public:
struct ggml_tensor* attention_layer_forward(std::string name,
struct ggml_context* ctx,
struct ggml_allocr* allocr,
struct ggml_tensor* x,
struct ggml_tensor* context,
int timesteps) {
if (version == VERSION_SVD) {
auto block = std::dynamic_pointer_cast<SpatialVideoTransformer>(blocks[name]);
return block->forward(ctx, allocr, x, context, timesteps);
return block->forward(ctx, x, context, timesteps);
} else {
auto block = std::dynamic_pointer_cast<SpatialTransformer>(blocks[name]);
@@ -375,9 +372,8 @@ public:
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_allocr* allocr,
struct ggml_tensor* x,
std::vector<float> timesteps,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat = NULL,
struct ggml_tensor* y = NULL,
@@ -386,7 +382,6 @@ public:
float control_strength = 0.f) {
// x: [N, in_channels, h, w] or [N, in_channels/2, h, w]
// timesteps: [N,]
// t_emb: [N, model_channels] timestep_embedding(timesteps, model_channels)
// context: [N, max_position, hidden_size] or [1, max_position, hidden_size]. for example, [N, 77, 768]
// c_concat: [N, in_channels, h, w] or [1, in_channels, h, w]
// y: [N, adm_in_channels] or [1, adm_in_channels]
@@ -417,7 +412,7 @@ public:
auto out_0 = std::dynamic_pointer_cast<GroupNorm32>(blocks["out.0"]);
auto out_2 = std::dynamic_pointer_cast<Conv2d>(blocks["out.2"]);
auto t_emb = new_timestep_embedding(ctx, allocr, timesteps, model_channels); // [N, model_channels]
auto t_emb = ggml_nn_timestep_embedding(ctx, timesteps, model_channels); // [N, model_channels]
auto emb = time_embed_0->forward(ctx, t_emb);
emb = ggml_silu_inplace(ctx, emb);
@@ -452,10 +447,10 @@ public:
for (int j = 0; j < num_res_blocks; j++) {
input_block_idx += 1;
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".0";
h = resblock_forward(name, ctx, allocr, h, emb, num_video_frames); // [N, mult*model_channels, h, w]
h = resblock_forward(name, ctx, h, emb, num_video_frames); // [N, mult*model_channels, h, w]
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
std::string name = "input_blocks." + std::to_string(input_block_idx) + ".1";
h = attention_layer_forward(name, ctx, allocr, h, context, num_video_frames); // [N, mult*model_channels, h, w]
h = attention_layer_forward(name, ctx, h, context, num_video_frames); // [N, mult*model_channels, h, w]
}
hs.push_back(h);
}
@@ -473,9 +468,9 @@ public:
// [N, 4*model_channels, h/8, w/8]
// middle_block
h = resblock_forward("middle_block.0", ctx, allocr, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, allocr, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, allocr, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
if (controls.size() > 0) {
auto cs = ggml_scale_inplace(ctx, controls[controls.size() - 1], control_strength);
@@ -500,13 +495,13 @@ public:
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".0";
h = resblock_forward(name, ctx, allocr, h, emb, num_video_frames);
h = resblock_forward(name, ctx, h, emb, num_video_frames);
int up_sample_idx = 1;
if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) {
std::string name = "output_blocks." + std::to_string(output_block_idx) + ".1";
h = attention_layer_forward(name, ctx, allocr, h, context, num_video_frames);
h = attention_layer_forward(name, ctx, h, context, num_video_frames);
up_sample_idx++;
}
@@ -561,7 +556,7 @@ struct UNetModel : public GGMLModule {
}
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
std::vector<float> timesteps,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat = NULL,
struct ggml_tensor* y = NULL,
@@ -574,16 +569,16 @@ struct UNetModel : public GGMLModule {
num_video_frames = x->ne[3];
}
x = to_backend(x);
context = to_backend(context);
y = to_backend(y);
x = to_backend(x);
context = to_backend(context);
y = to_backend(y);
timesteps = to_backend(timesteps);
for (int i = 0; i < controls.size(); i++) {
controls[i] = to_backend(controls[i]);
}
struct ggml_tensor* out = unet.forward(compute_ctx,
compute_allocr,
x,
timesteps,
context,
@@ -600,7 +595,7 @@ struct UNetModel : public GGMLModule {
void compute(int n_threads,
struct ggml_tensor* x,
std::vector<float> timesteps,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
@@ -638,7 +633,8 @@ struct UNetModel : public GGMLModule {
int num_video_frames = 3;
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 8, num_video_frames);
std::vector<float> timesteps(num_video_frames, 999.f);
std::vector<float> timesteps_vec(num_video_frames, 999.f);
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
ggml_set_f32(x, 0.5f);
// print_ggml_tensor(x);
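
The video path follows the same tensor-based convention: frame indices that used to come from a host-side arange() are now produced on-device with ggml_arange() and fed to ggml_nn_timestep_embedding(). A sketch, with names as in the hunks above:

    auto num_frames = ggml_arange(ctx, 0, timesteps, 1);  // [timesteps], built on the backend
    auto t_emb = ggml_nn_timestep_embedding(ctx, num_frames, in_channels, max_time_embed_period);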