Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IP-Adapter support for StableDiffusionXLControlNetInpaintPipeline #6941

Merged
merged 3 commits into from
Feb 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
Expand Down Expand Up @@ -195,6 +201,8 @@ def __init__(
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
feature_extractor: Optional[CLIPImageProcessor] = None,
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
):
super().__init__()

Expand All @@ -210,6 +218,8 @@ def __init__(
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
Expand Down Expand Up @@ -497,6 +507,66 @@ def encode_prompt(

return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype

if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values

image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)

return image_embeds, uncond_image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
return image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
Expand Down Expand Up @@ -566,6 +636,8 @@ def check_inputs(
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
Expand Down Expand Up @@ -752,6 +824,11 @@ def check_inputs(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)

def prepare_control_image(
self,
image,
Expand Down Expand Up @@ -1100,6 +1177,8 @@ def __call__(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
Expand Down Expand Up @@ -1194,6 +1273,10 @@ def __call__(
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
Expand Down Expand Up @@ -1326,6 +1409,8 @@ def __call__(
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
Expand Down Expand Up @@ -1378,6 +1463,12 @@ def __call__(
clip_skip=self.clip_skip,
)

# 3.1 Encode ip_adapter_image
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
)

# 4. set timesteps
def denoising_value_valid(dnv):
return isinstance(denoising_end, float) and 0 < dnv < 1
Expand Down Expand Up @@ -1649,6 +1740,9 @@ def denoising_value_valid(dnv):
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])

if ip_adapter_image is not None:
added_cond_kwargs["image_embeds"] = image_embeds

if num_channels_unet == 9:
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

Expand Down
Loading