From cba5e1905e272047481c41033474792f00b6da7a Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 17 Jul 2019 10:37:31 +0800 Subject: [PATCH] fix(image_encoder): enable batching encoding --- gnes/encoder/image/base.py | 2 ++ gnes/encoder/image/inception.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/gnes/encoder/image/base.py b/gnes/encoder/image/base.py index 1ef32725..cd3f8c58 100644 --- a/gnes/encoder/image/base.py +++ b/gnes/encoder/image/base.py @@ -19,6 +19,7 @@ import numpy as np from ..base import BaseImageEncoder +from ...helper import batching class BasePytorchEncoder(BaseImageEncoder): @@ -72,6 +73,7 @@ def forward(self, x): self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") self._model = self._model.to(self._device) + @batching def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray: import torch self._model.eval() diff --git a/gnes/encoder/image/inception.py b/gnes/encoder/image/inception.py index bfc29681..27712987 100644 --- a/gnes/encoder/image/inception.py +++ b/gnes/encoder/image/inception.py @@ -17,6 +17,7 @@ import numpy as np from gnes.helper import batch_iterator from ..base import BaseImageEncoder +from ...helper import batching from PIL import Image @@ -59,6 +60,7 @@ def post_init(self): self.saver = tf.train.Saver() self.saver.restore(self.sess, self.model_dir) + @batching def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray: ret = [] img = [(np.array(Image.fromarray(im).resize((self.inception_size_x,