Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
fix(image_encoder): enable batching encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
felix committed Jul 17, 2019
1 parent 316c9db commit cba5e19
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
2 changes: 2 additions & 0 deletions gnes/encoder/image/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np

from ..base import BaseImageEncoder
from ...helper import batching


class BasePytorchEncoder(BaseImageEncoder):
Expand Down Expand Up @@ -72,6 +73,7 @@ def forward(self, x):
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._model = self._model.to(self._device)

@batching
def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
import torch
self._model.eval()
Expand Down
2 changes: 2 additions & 0 deletions gnes/encoder/image/inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
from gnes.helper import batch_iterator
from ..base import BaseImageEncoder
from ...helper import batching
from PIL import Image


Expand Down Expand Up @@ -59,6 +60,7 @@ def post_init(self):
self.saver = tf.train.Saver()
self.saver.restore(self.sess, self.model_dir)

@batching
def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
ret = []
img = [(np.array(Image.fromarray(im).resize((self.inception_size_x,
Expand Down

0 comments on commit cba5e19

Please sign in to comment.