Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add K-LMS scheduler from k-diffusion #185

Merged
merged 11 commits into from
Aug 16, 2022
8 changes: 7 additions & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .utils import is_inflect_available, is_scipy_available, is_transformers_available, is_unidecode_available


__version__ = "0.1.3"
Expand All @@ -27,11 +27,17 @@
SchedulerMixin,
ScoreSdeVeScheduler,
)


if is_scipy_available():
from .schedulers import LMSDiscreteScheduler

from .training_utils import EMAModel


if is_transformers_available():
from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline


else:
from .utils.dummy_transformers_objects import *
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler


class StableDiffusionPipeline(DiffusionPipeline):
Expand All @@ -18,7 +18,7 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler],
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
):
super().__init__()
scheduler = scheduler.set_format("pt")
Expand Down Expand Up @@ -105,9 +105,16 @@ def __call__(
if accepts_eta:
extra_step_kwargs["eta"] = eta

for t in tqdm(self.scheduler.timesteps):
self.scheduler.set_timesteps(num_inference_steps)
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = latents * self.scheduler.sigmas[0]

for i, t in tqdm(enumerate(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if isinstance(self.scheduler, LMSDiscreteScheduler):
sigma = self.scheduler.sigmas[i]
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
Expand All @@ -118,7 +125,10 @@ def __call__(
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we maybe pass the sigma here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok we need it anyway - good to leave as is for me!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The i is needed inside the step for other calculations too, and we can't reverse the sigma into i, so we'll need it like this until a more full refactor.

else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
Expand Down
5 changes: 5 additions & 0 deletions src/diffusers/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_scipy_available
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_karras_ve import KarrasVeScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin


if is_scipy_available():
from .scheduling_lms_discrete import LMSDiscreteScheduler
134 changes: 134 additions & 0 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

import numpy as np
import torch

from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin


class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
tensor_format="pt",
):
"""
Linear Multistep Scheduler for discrete beta schedules.
Based on the original k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L181
"""

if beta_schedule == "linear":
self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5

# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.derivatives = []

self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)

def get_lms_coefficient(self, order, t, current_order):
"""
Compute a linear multistep coefficient
"""

def lms_derivative(tau):
prod = 1.0
for k in range(order):
if current_order == k:
continue
prod *= (tau - self.sigmas[t - k]) / (self.sigmas[t - current_order] - self.sigmas[t - k])
return prod

integrated_coeff = integrate.quad(lms_derivative, self.sigmas[t], self.sigmas[t + 1], epsrel=1e-4)[0]

return integrated_coeff

def set_timesteps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps
self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)

low_idx = np.floor(self.timesteps).astype(int)
high_idx = np.ceil(self.timesteps).astype(int)
frac = np.mod(self.timesteps, 1.0)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
self.sigmas = np.concatenate([sigmas, [0.0]])

self.derivatives = []

self.set_format(tensor_format=self.tensor_format)

def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4,
):
sigma = self.sigmas[timestep]

# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma * model_output

# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
self.derivatives.append(derivative)
if len(self.derivatives) > order:
self.derivatives.pop(0)

# 3. Compute linear multistep coefficients
order = min(timestep + 1, order)
lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]

# 4. Compute previous sample based on the derivatives path
prev_sample = sample + sum(
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
)

return {"prev_sample": prev_sample}

def add_noise(self, original_samples, noise, timesteps):
alpha_prod = self.alphas_cumprod[timesteps]
alpha_prod = self.match_shape(alpha_prod, original_samples)

noisy_samples = (alpha_prod**0.5) * original_samples + ((1 - alpha_prod) ** 0.5) * noise
return noisy_samples

def __len__(self):
return self.config.num_train_timesteps
19 changes: 19 additions & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,14 @@
_modelcards_available = False


_scipy_available = importlib.util.find_spec("scipy") is not None
try:
_scipy_version = importlib_metadata.version("scipy")
logger.debug(f"Successfully imported transformers version {_scipy_version}")
except importlib_metadata.PackageNotFoundError:
_scipy_available = False


def is_transformers_available():
return _transformers_available

Expand All @@ -85,6 +93,10 @@ def is_modelcards_available():
return _modelcards_available


def is_scipy_available():
return _scipy_available


class RepositoryNotFoundError(HTTPError):
"""
Raised when trying to access a hf.co URL with an invalid repository name, or with a private repo name the user does
Expand Down Expand Up @@ -118,11 +130,18 @@ class RevisionNotFoundError(HTTPError):
"""


SCIPY_IMPORT_ERROR = """
{0} requires the scipy library but it was not found in your environment. You can install it with pip: `pip install
scipy`
"""


BACKENDS_MAPPING = OrderedDict(
[
("transformers", (is_transformers_available, TRANSFORMERS_IMPORT_ERROR)),
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
("inflect", (is_inflect_available, INFLECT_IMPORT_ERROR)),
("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)),
]
)

Expand Down
24 changes: 24 additions & 0 deletions src/diffusers/utils/dummy_scipy_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
# flake8: noqa
from ..utils import DummyObject, requires_backends


class LmsDiscreteScheduler(metaclass=DummyObject):
_backends = ["scipy"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["scipy"])


class LDMTextToImagePipeline(metaclass=DummyObject):
_backends = ["scipy"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["scipy"])


class StableDiffusionPipeline(metaclass=DummyObject):
_backends = ["scipy"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["scipy"])
24 changes: 22 additions & 2 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
KarrasVeScheduler,
LDMPipeline,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMPipeline,
PNDMScheduler,
ScoreSdeVePipeline,
Expand Down Expand Up @@ -841,7 +842,7 @@ def test_ldm_text2img_fast(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU")
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion(self):
# make sure here that pndm scheduler skips prk
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
Expand All @@ -862,7 +863,7 @@ def test_stable_diffusion(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is suppused to run on GPU")
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_stable_diffusion_fast_ddim(self):
sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

Expand Down Expand Up @@ -977,3 +978,22 @@ def test_karras_ve_pipeline(self):
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
def test_lms_stable_diffusion_pipeline(self):
model_id = "CompVis/stable-diffusion-v1-1-diffusers"
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
pipe.scheduler = scheduler

prompt = "a photograph of an astronaut riding a horse"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe([prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy")[
"sample"
]

image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.9077, 0.9254, 0.9181, 0.9227, 0.9213, 0.9367, 0.9399, 0.9406, 0.9024])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2