Initial commit

pull/2/head
leejet 2023-08-13 15:38:16 +08:00
commit 3aca342e60
23 changed files with 5312 additions and 0 deletions

.gitignore vendored 100644 (+2 lines)

@@ -0,0 +1,2 @@
build*/
test/

.gitmodules vendored 100644 (+3 lines)

@@ -0,0 +1,3 @@
[submodule "ggml"]
path = ggml
url = https://github.com/leejet/ggml.git

CMakeLists.txt 100644 (+17 lines)

@@ -0,0 +1,17 @@
cmake_minimum_required(VERSION 3.12)
project(stable-diffusion)
set(SD_LIB stable-diffusion)
set(SD_TARGET sd)
add_subdirectory(ggml)
add_library(${SD_LIB} stable-diffusion.h stable-diffusion.cpp)
add_executable(${SD_TARGET} main.cpp stb_image_write.h)
target_link_libraries(${SD_LIB} PUBLIC ggml)
target_link_libraries(${SD_TARGET} ${SD_LIB})
target_compile_features(${SD_TARGET} PUBLIC cxx_std_11)
# setting CMAKE_CXX_STANDARD here would not affect the targets created above,
# so request C++11 on the library via target_compile_features as well
target_compile_features(${SD_LIB} PUBLIC cxx_std_11)

LICENSE 100644 (+21 lines)

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 leejet
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md 100644 (+138 lines)

@@ -0,0 +1,138 @@
<p align="center">
<img src="./assets/a%20lovely%20cat.png" width="256">
</p>
# stable-diffusion.cpp
Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in pure C/C++
## Features
- Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp)
- 16-bit, 32-bit float support
- 4-bit, 5-bit and 8-bit integer quantization support
- Accelerated memory-efficient CPU inference
- AVX, AVX2 and AVX512 support for x86 architectures
- Original `txt2img` mode
- Negative prompt
- Sampling method
- `Euler A`
- Supported platforms
- Linux
- macOS
- Windows
### TODO
- [ ] Original `img2img` mode
- [ ] More sampling methods
- [ ] GPU support
- [ ] Make inference faster
- The current implementation of ggml_conv_2d is slow and has high memory usage
- [ ] Continue to reduce memory usage (e.g. by quantizing the weights of ggml_conv_2d)
- [ ] [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer
- [ ] LoRA support
## Usage
### Get the Code
```shell
git clone --recursive https://github.com/leejet/stable-diffusion.cpp
cd stable-diffusion.cpp
```
### Convert weights
- Download the original weights (.ckpt or .safetensors). For example:
- Stable Diffusion v1.4 from https://huggingface.co/CompVis/stable-diffusion-v-1-4-original
- Stable Diffusion v1.5 from https://huggingface.co/runwayml/stable-diffusion-v1-5
```shell
curl -L -O https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt
# curl -L -O https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors
```
- Convert the weights to the ggml model format:
```shell
cd models
pip install -r requirements.txt
python convert.py [path to weights] --out_type [output precision]
# For example, python convert.py sd-v1-4.ckpt --out_type f16
```
### Quantization
You can specify the output model format using the `--out_type` parameter (an example follows the list):
- `f16` for 16-bit floating-point
- `f32` for 32-bit floating-point
- `q8_0` for 8-bit integer quantization
- `q5_0` or `q5_1` for 5-bit integer quantization
- `q4_0` or `q4_1` for 4-bit integer quantization
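For example, a 4-bit conversion might look like this (the output file name below follows the script's default naming when no `--out_file` is given):
```shell
python convert.py sd-v1-4.ckpt --out_type q4_0
# writes sd-v1-4-ggml-model-q4_0.bin to the current working directory
```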
### Build
```shell
mkdir build
cd build
cmake ..
cmake --build . --config Release
```
#### Using OpenBLAS
```shell
cmake .. -DGGML_OPENBLAS=ON
cmake --build . --config Release
```
### Run
```
usage: ./sd [arguments]
arguments:
-h, --help show this help message and exit
-t, --threads N number of threads to use during computation (default: -1).
If threads <= 0, then threads will be set to the number of CPU cores
-m, --model [MODEL] path to model
-o, --output OUTPUT path to write result image to (default: .\output.png)
-p, --prompt [PROMPT] the prompt to render
-n, --negative-prompt PROMPT the negative prompt (default: "")
--cfg-scale SCALE unconditional guidance scale (default: 7.0)
-H, --height H image height, in pixel space (default: 512)
-W, --width W image width, in pixel space (default: 512)
--sample-method SAMPLE_METHOD sample method (default: "euler a")
--steps STEPS number of sample steps (default: 20)
-s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
-v, --verbose print extra info
```
For example:
```shell
./sd -m ../models/sd-v1-4-ggml-model-f16.bin -p "a lovely cat"
```
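A fuller invocation using the flags documented above (file names and values are illustrative):
```shell
./sd -m ../models/sd-v1-4-ggml-model-f16.bin \
  -p "a lovely cat" -n "ugly, deformed" \
  --steps 30 --cfg-scale 7.5 -s 1234 -o cat.png
```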
Using formats of different precisions will yield results of varying quality.
| f32 | f16 | q8_0 | q5_0 | q5_1 | q4_0 | q4_1 |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| ![](./assets/f32.png) | ![](./assets/f16.png) | ![](./assets/q8_0.png) | ![](./assets/q5_0.png) | ![](./assets/q5_1.png) | ![](./assets/q4_0.png) | ![](./assets/q4_1.png) |
## Memory/Disk Requirements
| precision | f32 | f16 | q8_0 | q5_0 | q5_1 | q4_0 | q4_1 |
| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
| **Disk** | 2.8G | 2.0G | 1.7G | 1.6G | 1.6G | 1.5G | 1.5G |
| **Memory** (txt2img, 512 x 512) | ~4.9G | ~4.1G | ~3.8G | ~3.7G | ~3.7G | ~3.6G | ~3.6G |
## References
- [ggml](https://github.com/ggerganov/ggml)
- [stable-diffusion](https://github.com/CompVis/stable-diffusion)
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui)
- [k-diffusion](https://github.com/crowsonkb/k-diffusion)

BIN assets/a lovely cat.png 100644 (binary file not shown; 679 KiB)
BIN assets/f16.png 100644 (binary file not shown; 679 KiB)
BIN assets/f32.png 100644 (binary file not shown; 679 KiB)
BIN assets/q4_0.png 100644 (binary file not shown; 649 KiB)
BIN assets/q4_1.png 100644 (binary file not shown; 665 KiB)
BIN assets/q5_0.png 100644 (binary file not shown; 666 KiB)
BIN assets/q5_1.png 100644 (binary file not shown; 658 KiB)
BIN assets/q8_0.png 100644 (binary file not shown; 658 KiB)

ggml 160000 (+1 line)

@@ -0,0 +1 @@
Subproject commit 9519dcd3ce463d323926017f12f9bf69e2b5d459

main.cpp 100644 (+213 lines)

@@ -0,0 +1,213 @@
#include <stdio.h>
#include <fstream>
#include <iostream>
#include <random>
#include <string>
#include <thread>
#include "stable-diffusion.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"
struct Option {
int n_threads = -1;
std::string model_path;
std::string output_path = "output.png";
std::string prompt;
std::string negative_prompt;
float cfg_scale = 7.0f;
int w = 512;
int h = 512;
SampleMethod sample_method = EULER_A;
int sample_steps = 20;
int seed = 42;
bool verbose = false;
void print() {
printf("Option: \n");
printf(" n_threads: %d\n", n_threads);
printf(" model_path: %s\n", model_path.c_str());
printf(" output_path: %s\n", output_path.c_str());
printf(" prompt: %s\n", prompt.c_str());
printf(" negative_prompt: %s\n", negative_prompt.c_str());
printf(" cfg_scale: %.2f\n", cfg_scale);
printf(" width: %d\n", w);
printf(" height: %d\n", h);
printf(" sample_method: %s\n", "eular a");
printf(" sample_steps: %d\n", sample_steps);
printf(" seed: %d\n", seed);
}
};
void print_usage(int argc, const char* argv[]) {
printf("usage: %s [arguments]\n", argv[0]);
printf("\n");
printf("arguments:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -t, --threads N number of threads to use during computation (default: -1).\n");
printf(" If threads <= 0, then threads will be set to the number of CPU cores\n");
printf(" -m, --model [MODEL] path to model\n");
printf(" -o, --output OUTPUT path to write result image to (default: .\\output.png)\n");
printf(" -p, --prompt [PROMPT] the prompt to render\n");
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
printf(" -H, --height H image height, in pixel space (default: 512)\n");
printf(" -W, --width W image width, in pixel space (default: 512)\n");
printf(" --sample-method SAMPLE_METHOD sample method (default: \"eular a\")\n");
printf(" --steps STEPS number of sample steps (default: 20)\n");
printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n");
printf(" -v, --verbose print extra info\n");
}
void parse_args(int argc, const char* argv[], Option* opt) {
bool invalid_arg = false;
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->n_threads = std::stoi(argv[i]);
} else if (arg == "-m" || arg == "--model") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->model_path = argv[i];
} else if (arg == "-o" || arg == "--output") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->output_path = argv[i];
} else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->prompt = argv[i];
} else if (arg == "-n" || arg == "--negative-prompt") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->negative_prompt = argv[i];
} else if (arg == "--cfg-scale") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->cfg_scale = std::stof(argv[i]);
} else if (arg == "-H" || arg == "--height") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->h = std::stoi(argv[i]);
} else if (arg == "-W" || arg == "--width") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->w = std::stoi(argv[i]);
} else if (arg == "--steps") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->sample_steps = std::stoi(argv[i]);
} else if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->seed = std::stoi(argv[i]);
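} else if (arg == "--sample-method") {
// accept the --sample-method flag documented in the usage text;
// "euler a" is the only sampler implemented, so any value selects EULER_A
if (++i >= argc) {
invalid_arg = true;
break;
}
opt->sample_method = EULER_A;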
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv);
exit(0);
} else if (arg == "-v" || arg == "--verbose") {
opt->verbose = true;
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv);
exit(1);
}
if (invalid_arg) {
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
print_usage(argc, argv);
exit(1);
}
}
if (opt->n_threads <= 0) {
opt->n_threads = std::thread::hardware_concurrency();
}
if (opt->prompt.length() == 0) {
fprintf(stderr, "error: the following arguments are required: prompt\n");
print_usage(argc, argv);
exit(1);
}
if (opt->model_path.length() == 0) {
fprintf(stderr, "error: the following arguments are required: model_path\n");
print_usage(argc, argv);
exit(1);
}
if (opt->output_path.length() == 0) {
fprintf(stderr, "error: the following arguments are required: output_path\n");
print_usage(argc, argv);
exit(1);
}
if (opt->w <= 0 || opt->w % 32 != 0) {
fprintf(stderr, "error: the width must be a multiple of 32\n");
exit(1);
}
if (opt->h <= 0 || opt->h % 32 != 0) {
fprintf(stderr, "error: the height must be a multiple of 32\n");
exit(1);
}
if (opt->sample_steps <= 0) {
fprintf(stderr, "error: the sample_steps must be greater than 0\n");
exit(1);
}
}
int main(int argc, const char* argv[]) {
Option opt;
parse_args(argc, argv, &opt);
if (opt.verbose) {
opt.print();
printf("%s", sd_get_system_info().c_str());
set_sd_log_level(SDLogLevel::DEBUG);
}
StableDiffusion sd(opt.n_threads);
if (!sd.load_from_file(opt.model_path)) {
return 1;
}
std::vector<uint8_t> img = sd.txt2img(opt.prompt,
opt.negative_prompt,
opt.cfg_scale,
opt.w,
opt.h,
opt.sample_method,
opt.sample_steps,
opt.seed);
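// write the RGB result; stride 0 lets stb compute the packed row size (width * 3 bytes)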
stbi_write_png(opt.output_path.c_str(), opt.w, opt.h, 3, img.data(), 0);
printf("save result image to '%s'\n", opt.output_path.c_str());
return 0;
}

models/.gitignore vendored 100644 (+4 lines)

@@ -0,0 +1,4 @@
*.bin
*.ckpt
*.safetensors
*.log

models/README.md 100644 (+26 lines)

@@ -0,0 +1,26 @@
# Model Convert Script
## Requirements
- `vocab.json`, from https://huggingface.co/openai/clip-vit-large-patch14/raw/main/vocab.json (see the download command below)
```shell
pip install -r requirements.txt
```
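If you don't have `vocab.json` yet, one way to fetch it into this directory:
```shell
curl -L -O https://huggingface.co/openai/clip-vit-large-patch14/raw/main/vocab.json
```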
## Usage
```
usage: convert.py [-h] [--out_type {f32,f16,q4_0,q4_1,q5_0,q5_1,q8_0}] [--out_file OUT_FILE] model_path
Convert Stable Diffusion model to GGML compatible file format
positional arguments:
model_path model file path (*.pth, *.pt, *.ckpt, *.safetensors)
options:
-h, --help show this help message and exit
--out_type {f32,f16,q4_0,q4_1,q5_0,q5_1,q8_0}
output format (default: based on input)
--out_file OUT_FILE path to write to; default: based on input and current working directory
```
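For example (checkpoint file name taken from the download step in the top-level README):
```shell
python convert.py sd-v1-4.ckpt --out_type f16
# writes sd-v1-4-ggml-model-f16.bin to the current working directory
```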

models/convert.py 100644 (+264 lines)

@@ -0,0 +1,264 @@
import struct
import json
import os
import numpy as np
import torch
import safetensors.torch
this_file_dir = os.path.dirname(__file__)
vocab_dir = this_file_dir
ggml_ftype_str_to_int = {
"f32": 0,
"f16": 1,
"q4_0": 2,
"q4_1": 3,
"q5_0": 8,
"q5_1": 9,
"q8_0": 7
}
ggml_ttype_str_to_int = {
"f32": 0,
"f16": 1,
"q4_0": 2,
"q4_1": 3,
"q5_0": 6,
"q5_1": 7,
"q8_0": 8
}
QK4_0 = 32
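# q4_0 block layout (as in ggml): one fp16 scale d followed by QK4_0 4-bit
# values packed two per byte; dequantization is x ≈ d * (q - 8)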
def quantize_q4_0(x):
assert x.shape[-1] % QK4_0 == 0
x = x.reshape(-1, QK4_0)
max = np.take_along_axis(x, np.argmax(np.abs(x), axis=-1)[:, np.newaxis], axis=-1)
d = max / -8
qs = ((x / d) + 8).round().clip(min=0, max=15).astype(np.int8)
half = QK4_0 // 2
qs = qs[:, :half] | (qs[:, half:] << 4)
d = d.astype(np.float16).view(np.int8)
y = np.concatenate((d, qs), axis=-1)
return y
QK4_1 = 32
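# q4_1 stores a fp16 scale d and a fp16 per-block minimum m; x ≈ d * q + m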
def quantize_q4_1(x):
assert x.shape[-1] % QK4_1 == 0
x = x.reshape(-1, QK4_1)
min = np.min(x, axis=-1, keepdims=True)
max = np.max(x, axis=-1, keepdims=True)
d = (max - min) / ((1 << 4) - 1)
qs = ((x - min) / d).round().clip(min=0, max=15).astype(np.int8)
half = QK4_1 // 2
qs = qs[:, :half] | (qs[:, half:] << 4)
d = d.astype(np.float16).view(np.int8)
m = min.astype(np.float16).view(np.int8)
y = np.concatenate((d, m, qs), axis=-1)
return y
QK5_0 = 32
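# q5_0 stores 5-bit values: the low 4 bits are packed two per byte in qs,
# and the 32 fifth bits are collected into a single uint32 qh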
def quantize_q5_0(x):
assert x.shape[-1] % QK5_0 == 0
x = x.reshape(-1, QK5_0)
max = np.take_along_axis(x, np.argmax(np.abs(x), axis=-1)[:, np.newaxis], axis=-1)
d = max / -16
xi = ((x / d) + 16).round().clip(min=0, max=31).astype(np.int8)
half = QK5_0 // 2
qs = (xi[:, :half] & 0x0F) | (xi[:, half:] << 4)
qh = np.zeros(qs.shape[:-1], dtype=np.int32)
for i in range(QK5_0):
qh |= ((xi[:, i] & 0x10) >> 4).astype(np.int32) << i
d = d.astype(np.float16).view(np.int8)
qh = qh[..., np.newaxis].view(np.int8)
y = np.concatenate((d, qh, qs), axis=-1)
return y
QK5_1 = 32
def quantize_q5_1(x):
assert x.shape[-1] % QK5_1 == 0
x = x.reshape(-1, QK5_1)
min = np.min(x, axis=-1, keepdims=True)
max = np.max(x, axis=-1, keepdims=True)
d = (max - min) / ((1 << 5) - 1)
xi = ((x - min) / d).round().clip(min=0, max=31).astype(np.int8)
half = QK5_1 // 2
qs = (xi[:, :half] & 0x0F) | (xi[:, half:] << 4)
qh = np.zeros(xi.shape[:-1], dtype=np.int32)
for i in range(QK5_1):
qh |= ((xi[:, i] & 0x10) >> 4).astype(np.int32) << i
d = d.astype(np.float16).view(np.int8)
m = min.astype(np.float16).view(np.int8)
qh = qh[..., np.newaxis].view(np.int8)
ndarray = np.concatenate((d, m, qh, qs), axis=-1)
return ndarray
QK8_0 = 32
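# q8_0 keeps one signed byte per value plus a fp16 scale; x ≈ d * q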
def quantize_q8_0(x):
assert x.shape[-1] % QK8_0 == 0
x = x.reshape(-1, QK8_0)
amax = np.max(np.abs(x), axis=-1, keepdims=True)
d = amax / ((1 << 7) - 1)
qs = (x / d).round().clip(min=-128, max=127).astype(np.int8)
d = d.astype(np.float16).view(np.int8)
y = np.concatenate((d, qs), axis=-1)
return y
# copied from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py#L16
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def load_model_from_file(model_path):
print("loading model from {}".format(model_path))
if model_path.lower().endswith(".safetensors"):
pl_sd = safetensors.torch.load_file(model_path, device="cpu")
else:
pl_sd = torch.load(model_path, map_location="cpu")
state_dict = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
print("loading model from {} completed".format(model_path))
return state_dict
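# reconstruct the standard DDPM noise schedule (the cumulative product of the
# alphas) for checkpoints that do not store alphas_cumprod themselves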
def get_alphas_cumprod(linear_start=0.00085, linear_end=0.0120, timesteps=1000):
betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=torch.float32) ** 2
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas.numpy(), axis=0)
return torch.tensor(alphas_cumprod)
unused_tensors = [
"betas",
"alphas_cumprod_prev",
"sqrt_alphas_cumprod",
"sqrt_one_minus_alphas_cumprod",
"log_one_minus_alphas_cumprod",
"sqrt_recip_alphas_cumprod",
"sqrt_recipm1_alphas_cumprod",
"posterior_variance",
"posterior_log_variance_clipped",
"posterior_mean_coef1",
"posterior_mean_coef2",
"cond_stage_model.transformer.text_model.embeddings.position_ids",
"model_ema.decay",
"model_ema.num_updates"
]
def convert(model_path, out_type = None, out_file=None):
# load model
with open(os.path.join(vocab_dir, "vocab.json"), encoding="utf-8") as f:
clip_vocab = json.load(f)
state_dict = load_model_from_file(model_path)
alphas_cumprod = state_dict.get("alphas_cumprod")
if alphas_cumprod is None:
    print("no alphas_cumprod in file, generating a new one")
    alphas_cumprod = get_alphas_cumprod()
state_dict["alphas_cumprod"] = alphas_cumprod
# output option
if out_type is None:
weight = state_dict["cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight"].numpy()
if weight.dtype == np.float32:
out_type = "f32"
elif weight.dtype == np.float16:
out_type = "f16"
if out_file is None:
out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
out_file = os.path.join(os.getcwd(), out_file)
print(f"Saving GGML compatible file to {out_file}")
# convert and save
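# output layout written below: int32 magic 0x67676D6C ("ggml"), int32 ftype,
# int32 vocab size, then (int32 length, bytes) per token, and for each tensor:
# int32 n_dims / name length / ttype, the dims in reverse order, the name
# bytes, and finally the raw tensor data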
with open(out_file, "wb") as file:
# magic: ggml in hex
file.write(struct.pack("i", 0x67676D6C))
# out type
file.write(struct.pack("i", ggml_ftype_str_to_int[out_type]))
# vocab
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}
file.write(struct.pack("i", len(clip_vocab)))
for key in clip_vocab:
text = bytearray([byte_decoder[c] for c in key])
file.write(struct.pack("i", len(text)))
file.write(text)
# weights
for name in state_dict.keys():
if not isinstance(state_dict[name], torch.Tensor):
continue
if name in unused_tensors:
continue
data = state_dict[name].numpy()
n_dims = len(data.shape)
shape = data.shape
old_type = data.dtype
ttype = "f32"
if n_dims == 4:
data = data.astype(np.float16)
ttype = "f16"
elif n_dims == 2 and name.endswith(".weight"):
if out_type == "f32":
data = data.astype(np.float32)
elif out_type == "f16":
data = data.astype(np.float16)
elif out_type == "q4_0":
data = quantize_q4_0(data)
elif out_type == "q4_1":
data = quantize_q4_1(data)
elif out_type == "q5_0":
data = quantize_q5_0(data)
elif out_type == "q5_1":
data = quantize_q5_1(data)
elif out_type == "q8_0":
data = quantize_q8_0(data)
else:
raise Exception("invalid out_type {}".format(out_type))
ttype = out_type
else:
data = data.astype(np.float32)
ttype = "f32"
print("Processing tensor: {} with shape {}, {} -> {}".format(name, data.shape, old_type, ttype))
# header
name_bytes = name.encode("utf-8")
file.write(struct.pack("iii", n_dims, len(name_bytes), ggml_ttype_str_to_int[ttype]))
for i in range(n_dims):
file.write(struct.pack("i", shape[n_dims - 1 - i]))
file.write(name_bytes)
# data
data.tofile(file)
print("Convert done")
print(f"Saved GGML compatible file to {out_file}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Convert Stable Diffusion model to GGML compatible file format")
parser.add_argument("--out_type", choices=["f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0"], help="output format (default: based on input)")
parser.add_argument("--out_file", help="path to write to; default: based on input and current working directory")
parser.add_argument("model_path", help="model file path (*.pth, *.pt, *.ckpt, *.safetensors)")
args = parser.parse_args()
convert(args.model_path, args.out_type, args.out_file)

models/requirements.txt 100644 (+4 lines)

@@ -0,0 +1,4 @@
numpy
torch
safetensors
pytorch_lightning

models/vocab.json 100644

File diff suppressed because one or more lines are too long

stable-diffusion.cpp 100644 (+2853 lines)

File diff suppressed because it is too large

stable-diffusion.h 100644 (+41 lines)

@@ -0,0 +1,41 @@
#ifndef STABLE_DIFFUSION_H
#define STABLE_DIFFUSION_H
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
enum class SDLogLevel {
DEBUG,
INFO,
WARN,
ERROR
};
enum SampleMethod {
EULER_A,
};
class StableDiffusionGGML;
class StableDiffusion {
private:
std::shared_ptr<StableDiffusionGGML> sd;
public:
StableDiffusion(int n_threads = -1);
bool load_from_file(const std::string& file_path);
std::vector<uint8_t> txt2img(
const std::string& prompt,
const std::string& negative_prompt,
float cfg_scale,
int width,
int height,
SampleMethod sample_method,
int sample_steps,
int seed);
};
void set_sd_log_level(SDLogLevel level);
std::string sd_get_system_info();
#endif // STABLE_DIFFUSION_H

stb_image_write.h 100644 (+1724 lines)

File diff suppressed because it is too large