mirror of
https://github.com/mudler/LocalAI
synced 2026-04-21 13:27:21 +00:00
233 lines
8.2 KiB
Python
233 lines
8.2 KiB
Python
|
|
# Adapted from tinygrad examples/stable_diffusion.py (MIT license).
|
||
|
|
# Upstream: https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py
|
||
|
|
# Copyright (c) 2023- the tinygrad authors
|
||
|
|
# SPDX-License-Identifier: MIT
|
||
|
|
#
|
||
|
|
# Local modifications: removed the MLPerf training branch (pulls
|
||
|
|
# examples/mlperf/initializers which we don't vendor) and the __main__
|
||
|
|
# argparse / fetch / profile blocks. Kept the core classes so the LocalAI
|
||
|
|
# tinygrad backend can instantiate and drive Stable Diffusion v1.x from a
|
||
|
|
# single checkpoint path.
|
||
|
|
from collections import namedtuple
|
||
|
|
from typing import Any, Dict
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
from tinygrad import Tensor, dtypes
|
||
|
|
from tinygrad.nn import Conv2d, GroupNorm
|
||
|
|
|
||
|
|
from . import clip as clip_mod
|
||
|
|
from . import unet as unet_mod
|
||
|
|
from .clip import Closed, Tokenizer
|
||
|
|
from .unet import UNetModel
|
||
|
|
|
||
|
|
|
||
|
|
class AttnBlock:
    """Residual self-attention over the spatial positions of a feature map.

    The input is normalised, projected to query/key/value with 1x1
    convolutions, flattened so every pixel becomes one sequence element, and
    run through scaled dot-product attention. The result is projected back
    and added to the input.

    Attribute names are kept verbatim; they are presumably matched against
    checkpoint keys when weights are loaded (confirm against the loader).
    """

    def __init__(self, in_channels):
        self.norm = GroupNorm(32, in_channels)
        self.q = Conv2d(in_channels, in_channels, 1)
        self.k = Conv2d(in_channels, in_channels, 1)
        self.v = Conv2d(in_channels, in_channels, 1)
        self.proj_out = Conv2d(in_channels, in_channels, 1)

    def __call__(self, x):
        normed = self.norm(x)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)
        batch, channels, height, width = query.shape

        def to_sequence(t):
            # (batch, channels, height, width) -> (batch, height * width, channels)
            return t.reshape(batch, channels, height * width).transpose(1, 2)

        attended = Tensor.scaled_dot_product_attention(
            to_sequence(query), to_sequence(key), to_sequence(value)
        )
        # Undo the flattening so the residual add sees the input layout again.
        attended = attended.transpose(1, 2).reshape(batch, channels, height, width)
        return x + self.proj_out(attended)
|
||
|
|
|
||
|
|
|
||
|
|
class ResnetBlock:
    """Two GroupNorm -> swish -> 3x3 conv stages with a residual shortcut.

    Args:
        in_channels: channel count of the input feature map.
        out_channels: channel count of the output feature map. When omitted
            (``None``) it defaults to ``in_channels``; previously the ``None``
            default was forwarded to the layer constructors as-is.

    When the channel counts differ, the shortcut is a 1x1 convolution;
    otherwise it is the identity. Attribute names are kept verbatim; they are
    presumably matched against checkpoint keys (confirm against the loader).
    """

    def __init__(self, in_channels, out_channels=None):
        if out_channels is None:
            out_channels = in_channels
        self.norm1 = GroupNorm(32, in_channels)
        self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = GroupNorm(32, out_channels)
        self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
        # The residual add needs matching channel counts on both branches.
        self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else (lambda x: x)

    def __call__(self, x):
        h = self.conv1(self.norm1(x).swish())
        h = self.conv2(self.norm2(h).swish())
        return self.nin_shortcut(x) + h
|
||
|
|
|
||
|
|
|
||
|
|
class Mid:
    """Middle stage of the autoencoder: ResnetBlock -> AttnBlock -> ResnetBlock.

    All three layers keep the channel count at ``block_in``.
    """

    def __init__(self, block_in):
        self.block_1 = ResnetBlock(block_in, block_in)
        self.attn_1 = AttnBlock(block_in)
        self.block_2 = ResnetBlock(block_in, block_in)

    def __call__(self, x):
        # Apply the layers in order, feeding each output into the next layer.
        for layer in (self.block_1, self.attn_1, self.block_2):
            x = layer(x)
        return x
|
||
|
|
|
||
|
|
|
||
|
|
class Decoder:
    """Decoder half of the autoencoder: 4-channel input to 3-channel output.

    ``self.up`` is a list of dicts. Each dict has a ``"block"`` list of three
    ResnetBlocks, and every entry except index 0 also has an ``"upsample"``
    dict holding a 3x3 conv. The list is traversed back to front at call
    time, so index 0 is the last stage to run and the only one that does not
    upsample. Attribute names and dict keys are kept verbatim; they are
    presumably matched against checkpoint keys (confirm against the loader).
    """

    def __init__(self):
        # (out_channels, in_channels) for each stage, in storage order.
        stage_channels = [(128, 256), (256, 512), (512, 512), (512, 512)]
        self.conv_in = Conv2d(4, 512, 3, padding=1)
        self.mid = Mid(512)

        stages = []
        for level, (ch_out, ch_in) in enumerate(stage_channels):
            stage = {
                "block": [
                    ResnetBlock(ch_in, ch_out),
                    ResnetBlock(ch_out, ch_out),
                    ResnetBlock(ch_out, ch_out),
                ]
            }
            if level != 0:
                stage["upsample"] = {"conv": Conv2d(ch_out, ch_out, 3, padding=1)}
            stages.append(stage)
        self.up = stages

        self.norm_out = GroupNorm(32, 128)
        self.conv_out = Conv2d(128, 3, 3, padding=1)

    def __call__(self, x):
        x = self.mid(self.conv_in(x))

        for stage in reversed(self.up):
            for resnet in stage["block"]:
                x = resnet(x)
            if "upsample" in stage:
                # 2x upsample by repeating every pixel along both spatial axes.
                n, ch, rows, cols = x.shape
                x = x.reshape(n, ch, rows, 1, cols, 1).expand(n, ch, rows, 2, cols, 2)
                x = x.reshape(n, ch, rows * 2, cols * 2)
                x = stage["upsample"]["conv"](x)
            # Evaluate the pending computation once per stage.
            x.realize()

        return self.conv_out(self.norm_out(x).swish())
|
||
|
|
|
||
|
|
|
||
|
|
class Encoder:
    """Encoder half of the autoencoder: 3-channel input to 8-channel output.

    ``self.down`` is a list of dicts. Each dict has a ``"block"`` list of two
    ResnetBlocks, and every entry except the last also has a ``"downsample"``
    dict holding a stride-2 3x3 conv. Attribute names and dict keys are kept
    verbatim; they are presumably matched against checkpoint keys (confirm
    against the loader).
    """

    def __init__(self):
        # (in_channels, out_channels) for each stage, in execution order.
        stage_channels = [(128, 128), (128, 256), (256, 512), (512, 512)]
        final_level = len(stage_channels) - 1
        self.conv_in = Conv2d(3, 128, 3, padding=1)

        stages = []
        for level, (ch_in, ch_out) in enumerate(stage_channels):
            stage = {"block": [ResnetBlock(ch_in, ch_out), ResnetBlock(ch_out, ch_out)]}
            if level != final_level:
                # Asymmetric padding tuple is passed through to Conv2d unchanged.
                stage["downsample"] = {"conv": Conv2d(ch_out, ch_out, 3, stride=2, padding=(0, 1, 0, 1))}
            stages.append(stage)
        self.down = stages

        self.mid = Mid(512)
        self.norm_out = GroupNorm(32, 512)
        self.conv_out = Conv2d(512, 8, 3, padding=1)

    def __call__(self, x):
        x = self.conv_in(x)

        for stage in self.down:
            for resnet in stage["block"]:
                x = resnet(x)
            if "downsample" in stage:
                x = stage["downsample"]["conv"](x)

        x = self.mid(x)
        return self.conv_out(self.norm_out(x).swish())
|
||
|
|
|
||
|
|
|
||
|
|
class AutoencoderKL:
    """Autoencoder wrapping the Encoder and Decoder with 1x1 quantisation convs.

    Calling the instance runs a full encode -> decode round trip.
    """

    def __init__(self):
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.quant_conv = Conv2d(8, 8, 1)
        self.post_quant_conv = Conv2d(4, 4, 1)

    def __call__(self, x):
        encoded = self.quant_conv(self.encoder(x))
        # Only the first four of the eight channels are decoded
        # (presumably the mean half of the latent distribution; confirm).
        kept = encoded[:, 0:4]
        return self.decoder(self.post_quant_conv(kept))
|
||
|
|
|
||
|
|
|
||
|
|
def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
    """Return the cumulative product of ``1 - beta`` as a 1-D Tensor.

    The betas are the squares of ``n_training_steps`` float32 values spaced
    linearly between ``sqrt(beta_start)`` and ``sqrt(beta_end)``.
    """
    sqrt_betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32)
    betas = sqrt_betas ** 2
    return Tensor(np.cumprod(1.0 - betas, axis=0))
|
||
|
|
|
||
|
|
|
||
|
|
# SD1.x UNet hyperparameters (same as upstream `unet_params`).
# The mapping is unpacked as keyword arguments into `UNetModel`, so the key
# names must match that constructor's parameter names exactly.
UNET_PARAMS_SD1: Dict[str, Any] = {
    "adm_in_ch": None,
    "in_ch": 4,
    "out_ch": 4,
    "model_ch": 320,
    "attention_resolutions": [4, 2, 1],
    "num_res_blocks": 2,
    "channel_mult": [1, 2, 4, 4],
    "n_heads": 8,
    "transformer_depth": [1, 1, 1, 1],
    "ctx_dim": 768,
    "use_linear": False,
}
|
||
|
|
|
||
|
|
|
||
|
|
class StableDiffusion:
    """Stable Diffusion 1.x pipeline, adapted from tinygrad's reference example.

    Drives the native CompVis `sd-v1-*.ckpt` checkpoint format (the only one
    the vendored weight layout handles). For HuggingFace safetensors pipelines
    the caller is expected to download / merge the `.ckpt` equivalent before
    calling LoadModel.

    The attribute nesting (`first_stage_model`, `cond_stage_model.transformer.
    text_model`, `model.diffusion_model`) is built with namedtuples and is
    presumably what checkpoint keys are matched against (confirm against the
    loader), so the names must not change.
    """

    def __init__(self):
        self.alphas_cumprod = get_alphas_cumprod()
        self.first_stage_model = AutoencoderKL()
        self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(
            transformer=namedtuple("Transformer", ["text_model"])(text_model=Closed.ClipTextTransformer())
        )
        self.model = namedtuple("DiffusionModel", ["diffusion_model"])(
            diffusion_model=UNetModel(**UNET_PARAMS_SD1)
        )

    # DDIM update step.
    def _update(self, x, e_t, a_t, a_prev):
        """Move latent ``x`` from cumulative alpha ``a_t`` to ``a_prev``.

        ``e_t`` is the predicted noise. The clean-sample estimate is
        recovered from ``x`` and ``e_t`` and then re-noised to ``a_prev``.
        """
        sqrt_one_minus_at = (1 - a_t).sqrt()
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        dir_xt = (1.0 - a_prev).sqrt() * e_t
        return a_prev.sqrt() * pred_x0 + dir_xt

    def _model_output(self, uncond, cond, latent, timestep, guidance):
        """Return the guided noise prediction for ``latent``.

        The latent is expanded to a batch of two so the unconditional and
        conditional contexts go through the UNet in a single call. The two
        outputs are then blended as ``uncond + guidance * (cond - uncond)``.
        """
        latents = self.model.diffusion_model(latent.expand(2, *latent.shape[1:]), timestep, uncond.cat(cond, dim=0))
        uncond_latent, cond_latent = latents[0:1], latents[1:2]
        return uncond_latent + guidance * (cond_latent - uncond_latent)

    def step(self, uncond, cond, latent, timestep, a_t, a_prev, guidance):
        """Run one guided denoising step and return the realized new latent."""
        e_t = self._model_output(uncond, cond, latent, timestep, guidance)
        return self._update(latent, e_t, a_t, a_prev).realize()

    def decode(self, x):
        """Decode a single latent into a ``(height, width, 3)`` uint8 image tensor.

        The latent is divided by the 0.18215 scale factor, run through
        ``post_quant_conv`` and the decoder, mapped from ``[-1, 1]`` to
        ``[0, 1]``, clipped, and scaled to ``[0, 255]``.

        The spatial size is taken from the decoded tensor rather than being
        fixed at 512x512, so other latent sizes decode too. The batch
        dimension must still be 1, otherwise the reshape fails.
        """
        x = self.first_stage_model.post_quant_conv(1 / 0.18215 * x)
        x = self.first_stage_model.decoder(x)
        x = (x + 1.0) / 2.0
        height, width = x.shape[-2:]
        x = x.reshape(3, height, width).permute(1, 2, 0).clip(0, 1) * 255
        return x.cast(dtypes.uint8)

    def encode_prompt(self, tokenizer, prompt: str):
        """Tokenize ``prompt`` and return the realized text-model output."""
        ids = Tensor([tokenizer.encode(prompt)])
        return self.cond_stage_model.transformer.text_model(ids).realize()
|
||
|
|
|
||
|
|
|
||
|
|
def run_sd15(model: StableDiffusion, prompt: str, negative_prompt: str, steps: int, guidance: float, seed: int):
    """Generate a single 512x512 image. Returns a (512,512,3) uint8 tensor.

    Args:
        model: pipeline instance with weights already loaded.
        prompt: text the image is conditioned on.
        negative_prompt: text used for the unconditional branch of guidance.
        steps: requested sampling density, between 1 and 1000 inclusive. The
            timesteps are ``range(1, 1000, 1000 // steps)``, so the number of
            denoising iterations can exceed ``steps`` when ``steps`` does not
            divide 1000 evenly.
        guidance: classifier-free guidance scale.
        seed: RNG seed for the initial latent; ``None`` leaves the RNG state
            untouched.

    Raises:
        ValueError: if ``steps`` is outside ``1..1000``. Without this check
            zero raised a bare ZeroDivisionError, values above 1000 raised an
            opaque ``range()`` error, and negative values produced an empty
            schedule that decoded pure noise.
    """
    n_train_steps = 1000
    if not 1 <= steps <= n_train_steps:
        raise ValueError(f"steps must be between 1 and {n_train_steps}, got {steps!r}")

    tokenizer = Tokenizer.ClipTokenizer()

    context = model.encode_prompt(tokenizer, prompt)
    uncond = model.encode_prompt(tokenizer, negative_prompt)

    timesteps = list(range(1, n_train_steps, n_train_steps // steps))
    alphas = model.alphas_cumprod[Tensor(timesteps)]
    # The alpha "before" the first timestep is 1.0.
    alphas_prev = Tensor([1.0]).cat(alphas[:-1])

    if seed is not None:
        Tensor.manual_seed(seed)
    latent = Tensor.randn(1, 4, 64, 64)

    # Loop-invariant, so build it once instead of once per step.
    guidance_t = Tensor([guidance])

    # Walk the schedule from the largest timestep down to the smallest.
    for index in range(len(timesteps) - 1, -1, -1):
        timestep = timesteps[index]
        tid = Tensor([index])
        latent = model.step(
            uncond, context, latent,
            Tensor([timestep]),
            alphas[tid], alphas_prev[tid],
            guidance_t,
        )

    return model.decode(latent).realize()
|