From c6542939fcc26e6565f13302755ae9e93a9bef47 Mon Sep 17 00:00:00 2001 From: Ghassen Date: Wed, 5 Jun 2024 14:50:37 +0200 Subject: [PATCH] Change models and variables to float16 --- src/dot/gpen/face_model/face_gan.py | 4 +++- src/dot/gpen/face_model/model.py | 2 +- src/dot/gpen/retinaface/facemodels/retinaface.py | 2 +- src/dot/gpen/retinaface/retinaface_detection.py | 6 +++++- src/dot/simswap/fs_model.py | 2 ++ src/dot/simswap/models/base_model.py | 1 + src/dot/simswap/option.py | 9 +++++++-- src/dot/simswap/util/reverse2original.py | 2 +- src/dot/simswap/util/util.py | 3 ++- 9 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/dot/gpen/face_model/face_gan.py b/src/dot/gpen/face_model/face_gan.py index 23993ac..3a47391 100644 --- a/src/dot/gpen/face_model/face_gan.py +++ b/src/dot/gpen/face_model/face_gan.py @@ -52,12 +52,14 @@ def load_model(self, channel_multiplier=2, narrow=1, use_gpu=True): self.model.load_state_dict(pretrained_dict) self.model.eval() + self.model = self.model.half() + def process(self, img, use_gpu=True): img = cv2.resize(img, (self.resolution, self.resolution)) img_t = self.img2tensor(img, use_gpu) with torch.no_grad(): - out, __ = self.model(img_t) + out, __ = self.model(img_t.half()) out = self.tensor2img(out) diff --git a/src/dot/gpen/face_model/model.py b/src/dot/gpen/face_model/model.py index 4b6411d..1cee992 100644 --- a/src/dot/gpen/face_model/model.py +++ b/src/dot/gpen/face_model/model.py @@ -120,7 +120,7 @@ def __init__( def forward(self, input): out = F.conv2d( input, - self.weight * self.scale, + (self.weight * self.scale).half(), bias=self.bias, stride=self.stride, padding=self.padding, diff --git a/src/dot/gpen/retinaface/facemodels/retinaface.py b/src/dot/gpen/retinaface/facemodels/retinaface.py index 86a3b49..361e867 100644 --- a/src/dot/gpen/retinaface/facemodels/retinaface.py +++ b/src/dot/gpen/retinaface/facemodels/retinaface.py @@ -81,7 +81,7 @@ def __init__(self, cfg=None, phase="train"): import torchvision.models as models backbone = models.resnet50(pretrained=cfg["pretrain"]) - + # backbone = backbone.half() self.body = _utils.IntermediateLayerGetter(backbone, cfg["return_layers"]) in_channels_stage2 = cfg["in_channel"] in_channels_list = [ diff --git a/src/dot/gpen/retinaface/retinaface_detection.py b/src/dot/gpen/retinaface/retinaface_detection.py index a371003..da9a751 100644 --- a/src/dot/gpen/retinaface/retinaface_detection.py +++ b/src/dot/gpen/retinaface/retinaface_detection.py @@ -35,6 +35,8 @@ def __init__(self, base_dir, network="RetinaFace-R50", use_gpu=True): self.load_model(load_to_cpu=True) self.net = self.net.cpu() + self.net = self.net.half() + def check_keys(self, pretrained_state_dict): ckpt_keys = set(pretrained_state_dict.keys()) model_keys = set(self.net.state_dict().keys()) @@ -71,6 +73,8 @@ def load_model(self, load_to_cpu=False): self.net.load_state_dict(pretrained_dict, strict=False) self.net.eval() + self.net = self.net.half() + def detect( self, img_raw, @@ -96,7 +100,7 @@ def detect( img = img.cpu() scale = scale.cpu() - loc, conf, landms = self.net(img) # forward pass + loc, conf, landms = self.net(img.half()) # forward pass priorbox = PriorBox(self.cfg, image_size=(im_height, im_width)) priors = priorbox.forward() diff --git a/src/dot/simswap/fs_model.py b/src/dot/simswap/fs_model.py index 3a43654..7982aac 100644 --- a/src/dot/simswap/fs_model.py +++ b/src/dot/simswap/fs_model.py @@ -83,9 +83,11 @@ def initialize( self.netArc = netArc_checkpoint self.netArc = self.netArc.to(device) self.netArc.eval() + self.netArc = self.netArc.half() pretrained_path = "" self.load_network(self.netG, "G", opt_which_epoch, pretrained_path) + self.netG = self.netG.half() return def forward(self, img_id, img_att, latent_id, latent_att, for_G=False): diff --git a/src/dot/simswap/models/base_model.py b/src/dot/simswap/models/base_model.py index 2ad552d..739b787 100644 --- a/src/dot/simswap/models/base_model.py +++ b/src/dot/simswap/models/base_model.py @@ -102,6 +102,7 @@ def load_network(self, network, network_label, epoch_label, save_dir=""): print(sorted(not_initialized)) network.load_state_dict(model_dict) + network = network.half() def update_learning_rate(self): pass diff --git a/src/dot/simswap/option.py b/src/dot/simswap/option.py index 8a11969..8dcd5fa 100644 --- a/src/dot/simswap/option.py +++ b/src/dot/simswap/option.py @@ -94,6 +94,7 @@ def create_model( # type: ignore ) self.net.eval() + self.net = self.net.half() else: self.net = None @@ -141,7 +142,7 @@ def change_option(self, image: np.array, **kwargs) -> None: # create latent id img_id_downsample = F.interpolate(img_id, size=(112, 112)) - source_image = self.model.netArc(img_id_downsample) + source_image = self.model.netArc(img_id_downsample.half()) source_image = source_image.detach().to("cpu") source_image = source_image / np.linalg.norm( source_image, axis=1, keepdims=True @@ -186,7 +187,11 @@ def process_image(self, image: np.array, **kwargs) -> np.array: )[None, ...].cpu() swap_result = self.model( - None, frame_align_crop_tenor, self.source_image, None, True + None, + frame_align_crop_tenor.half(), + self.source_image.half(), + None, + True, )[0] swap_result_list.append(swap_result) frame_align_crop_tenor_list.append(frame_align_crop_tenor) diff --git a/src/dot/simswap/util/reverse2original.py b/src/dot/simswap/util/reverse2original.py index e77bc9b..7a9a15e 100644 --- a/src/dot/simswap/util/reverse2original.py +++ b/src/dot/simswap/util/reverse2original.py @@ -143,7 +143,7 @@ def reverse2wholeimage( if use_mask: source_img_norm = norm(source_img, use_gpu=use_gpu) source_img_512 = F.interpolate(source_img_norm, size=(512, 512)) - out = pasring_model(source_img_512)[0] + out = pasring_model(source_img_512.half())[0] parsing = out.squeeze(0).argmax(0) tgt_mask = encode_segmentation_rgb(parsing, device) diff --git a/src/dot/simswap/util/util.py b/src/dot/simswap/util/util.py index 67d03b5..5964b32 100644 --- a/src/dot/simswap/util/util.py +++ b/src/dot/simswap/util/util.py @@ -149,6 +149,7 @@ def load_parsing_model(path, use_mask, use_gpu): net.load_state_dict(torch.load(path, map_location=torch.device("cpu"))) net.eval() + net = net.half() return net else: return None @@ -180,7 +181,7 @@ def crop_align( # create latent id img_id_downsample = F.interpolate(img_id, size=(112, 112)) - id_vector = swap_model.netArc(img_id_downsample) + id_vector = swap_model.netArc(img_id_downsample.half()) id_vector = id_vector.detach().to("cpu") id_vector = id_vector / np.linalg.norm(id_vector, axis=1, keepdims=True)