Skip to content

Commit

Permalink
Raise ModuleNotFoundError when unavailable (#290)
Browse files (browse the repository at this point in the history)
* Raise ModuleNotFoundError when unavailable

* Raise ModuleNotFoundError when unavailable

* Follow pep8

* Raise ModuleNotFoundError when unavailable

* Remove unnecessary line in docstring

* Use importlib

* Remove unnecessary assignment

* Use importlib

* Rename sklearn.utils.shuffle to sk_shuffle for consistency

* Import importlib
  • Loading branch information
akihironitta authored Oct 22, 2020
1 parent c21bfe2 commit eee7684
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 13 deletions.
19 changes: 13 additions & 6 deletions pl_bolts/datasets/imagenet_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gzip
import hashlib
import importlib
import os
import shutil
import tarfile
Expand All @@ -11,9 +12,10 @@
import torch
from torch._six import PY3

try:
from sklearn.utils import shuffle
except ModuleNotFoundError:
_SKLEARN_AVAILABLE = importlib.util.find_spec("sklearn") is not None
if _SKLEARN_AVAILABLE:
from sklearn.utils import shuffle as sk_shuffle
else:
warn('You want to use `sklearn` which is not installed yet,' # pragma: no-cover
' install it with `pip install sklearn`.')

Expand Down Expand Up @@ -72,8 +74,13 @@ def __init__(
super(ImageNet, self).__init__(self.split_folder, **kwargs)
self.root = root

if not _SKLEARN_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `shuffle` function from `scikit-learn` which is not installed yet.'
)

# shuffle images first
self.imgs = shuffle(self.imgs, random_state=1234)
self.imgs = sk_shuffle(self.imgs, random_state=1234)

# partition train set into [train, val]
if split == 'train':
Expand All @@ -98,7 +105,7 @@ def __init__(
# limit the number of classes
if num_classes != -1:
# choose the classes at random (but deterministic)
ok_classes = shuffle(list(range(num_classes)), random_state=1234)
ok_classes = sk_shuffle(list(range(num_classes)), random_state=1234)
ok_classes = ok_classes[:num_classes]
ok_classes = set(ok_classes)

Expand All @@ -110,7 +117,7 @@ def __init__(
self.imgs = clean_imgs

# shuffle again for final exit
self.imgs = shuffle(self.imgs, random_state=1234)
self.imgs = sk_shuffle(self.imgs, random_state=1234)

# list of class_nbs for each image
idcs = [idx for _, idx in self.imgs]
Expand Down
15 changes: 11 additions & 4 deletions pl_bolts/datasets/ssl_amdim_datasets.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from abc import ABC
import importlib
from typing import Callable, Optional
from warnings import warn

import numpy as np

try:
from sklearn.utils import shuffle
except ModuleNotFoundError:
_SKLEARN_AVAILABLE = importlib.util.find_spec("sklearn") is not None
if _SKLEARN_AVAILABLE:
from sklearn.utils import shuffle as sk_shuffle
else:
warn('You want to use `sklearn` which is not installed yet,' # pragma: no-cover
' install it with `pip install sklearn`.')

Expand Down Expand Up @@ -81,9 +83,14 @@ def select_nb_imgs_per_class(cls, examples, labels, nb_imgs_in_val):

@classmethod
def deterministic_shuffle(cls, x, y):
if not _SKLEARN_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `shuffle` function from `scikit-learn` which is not installed yet.'
)

n = len(x)
idxs = list(range(0, n))
idxs = shuffle(idxs, random_state=1234)
idxs = sk_shuffle(idxs, random_state=1234)

x = x[idxs]

Expand Down
12 changes: 9 additions & 3 deletions pl_bolts/utils/semi_supervised.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import importlib
import math
from warnings import warn

import numpy as np
import torch

try:
_SKLEARN_AVAILABLE = importlib.util.find_spec("sklearn") is not None
if _SKLEARN_AVAILABLE:
from sklearn.utils import shuffle as sk_shuffle
except ModuleNotFoundError:
else:
warn('You want to use `sklearn` which is not installed yet,' # pragma: no-cover
' install it with `pip install sklearn`.')

Expand Down Expand Up @@ -36,11 +38,15 @@ def balance_classes(X: np.ndarray, Y: list, batch_size: int):
Perfect balance
Args:
X: input features
Y: mixed labels (ints)
batch_size: the ultimate batch size
"""
if not _SKLEARN_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
'You want to use `shuffle` function from `scikit-learn` which is not installed yet.'
)

nb_classes = len(set(Y))

nb_batches = math.ceil(len(Y) / batch_size)
Expand Down

0 comments on commit eee7684

Please sign in to comment.