From 54c5f523926a06616f66137a46fb5dee6ab74660 Mon Sep 17 00:00:00 2001
From: Vincent
Date: Sun, 11 Feb 2024 23:54:17 +0700
Subject: [PATCH 1/2] add ip-adapter support

---
 .../pipeline_controlnet_inpaint_sd_xl.py      | 78 ++++++++++++++++++-
 1 file changed, 76 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index eacaf8477f24..e69301d9530f 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -19,11 +19,17 @@
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -195,6 +201,8 @@ def __init__(
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
+        feature_extractor: Optional[CLIPImageProcessor] = None,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
     ):
         super().__init__()
 
@@ -210,6 +218,8 @@ def __init__(
             unet=unet,
             controlnet=controlnet,
             scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
@@ -497,6 +507,60 @@ def encode_prompt(
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
+        if not isinstance(ip_adapter_image, list):
+            ip_adapter_image = [ip_adapter_image]
+
+        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+            raise ValueError(
+                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+            )
+
+        image_embeds = []
+        for single_ip_adapter_image, image_proj_layer in zip(
+            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+        ):
+            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+            single_image_embeds, single_negative_image_embeds = self.encode_image(
+                single_ip_adapter_image, device, 1, output_hidden_state
+            )
+            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+
+            if self.do_classifier_free_guidance:
+                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                single_image_embeds = single_image_embeds.to(device)
+
+            image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1100,6 +1164,7 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -1378,6 +1443,12 @@ def __call__(
             clip_skip=self.clip_skip,
         )
 
+        # 3.1 Encode ip_adapter_image
+        if ip_adapter_image is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, device, batch_size * num_images_per_prompt
+            )
+
         # 4. set timesteps
         def denoising_value_valid(dnv):
             return isinstance(denoising_end, float) and 0 < dnv < 1
@@ -1649,6 +1720,9 @@ def denoising_value_valid(dnv):
                     down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
                     mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
 
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+
                 if num_channels_unet == 9:
                     latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
 

From c4a2ee74f17be5266bc98d1ddb5f78fae89fb9ac Mon Sep 17 00:00:00 2001
From: Vincent
Date: Mon, 12 Feb 2024 14:18:01 +0700
Subject: [PATCH 2/2] support ip image embeds

---
 .../pipeline_controlnet_inpaint_sd_xl.py      | 68 ++++++++++++-------
 1 file changed, 44 insertions(+), 24 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index e69301d9530f..ac657986326c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -533,32 +533,38 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
 
         return image_embeds, uncond_image_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
-    def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
-        if not isinstance(ip_adapter_image, list):
-            ip_adapter_image = [ip_adapter_image]
-
-        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
-            raise ValueError(
-                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
-            )
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
 
-        image_embeds = []
-        for single_ip_adapter_image, image_proj_layer in zip(
-            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
-        ):
-            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-            single_image_embeds, single_negative_image_embeds = self.encode_image(
-                single_ip_adapter_image, device, 1, output_hidden_state
-            )
-            single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
-            single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
 
-            if self.do_classifier_free_guidance:
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
-                single_image_embeds = single_image_embeds.to(device)
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
 
-            image_embeds.append(single_image_embeds)
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
         return image_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -630,6 +636,8 @@ def check_inputs(
         negative_prompt_2=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         pooled_prompt_embeds=None,
         negative_pooled_prompt_embeds=None,
         controlnet_conditioning_scale=1.0,
@@ -816,6 +824,11 @@
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
 
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
     def prepare_control_image(
         self,
         image,
@@ -1165,6 +1178,7 @@ def __call__(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -1259,6 +1273,10 @@ def __call__(
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+                `ip_adapter_image` input argument.
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
@@ -1391,6 +1409,8 @@ def __call__(
             negative_prompt_2,
             prompt_embeds,
             negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
             controlnet_conditioning_scale,
@@ -1444,9 +1464,9 @@
         )
 
         # 3.1 Encode ip_adapter_image
-        if ip_adapter_image is not None:
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
             image_embeds = self.prepare_ip_adapter_image_embeds(
-                ip_adapter_image, device, batch_size * num_images_per_prompt
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
             )
 
         # 4. set timesteps
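
Usage sketch (illustrative, not part of the patch series): the snippet below shows one way the `ip_adapter_image` argument added in PATCH 1/2 could be exercised once both commits are applied. It assumes the publicly available SDXL IP-Adapter weights from `h94/IP-Adapter` and the `diffusers/controlnet-canny-sdxl-1.0` ControlNet, that the pipeline exposes `load_ip_adapter` via `IPAdapterMixin` (as it does in current diffusers releases), and that the input image files are local placeholders.

import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from diffusers.utils import load_image
from transformers import CLIPVisionModelWithProjection

# Image encoder and IP-Adapter weights: standard SDXL checkpoints from h94/IP-Adapter (assumption).
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=torch.float16
)
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
).to("cuda")
# load_ip_adapter comes from IPAdapterMixin; assumed to be available on this pipeline class.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

# Placeholder inputs: image to inpaint, its mask, a canny control image, and an IP-Adapter style reference.
init_image = load_image("init.png")
mask_image = load_image("mask.png")
control_image = load_image("canny.png")
ip_image = load_image("style_reference.png")

result = pipe(
    prompt="a photo of a living room",
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    ip_adapter_image=ip_image,  # routed through prepare_ip_adapter_image_embeds added in PATCH 1/2
    num_inference_steps=30,
).images[0]
result.save("out.png")

With PATCH 2/2 applied, a list of pre-computed embedding tensors can be passed as `ip_adapter_image_embeds` in place of `ip_adapter_image`, which skips the image-encoder forward pass inside the pipeline.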