-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconf.py
121 lines (104 loc) · 3.13 KB
/
conf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Copyright 2020-present, Pietro Buzzega, Matteo Boschini, Angelo Porrello, Davide Abati, Simone Calderara.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import torch
import numpy as np
from abc import abstractmethod
from argparse import Namespace
from torch import nn as nn
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from typing import Tuple
from torchvision import datasets
import numpy as np
def get_device() -> torch.device:
    """
    Choose the torch device that tensors and models should live on.

    :return: ``torch.device("cuda:0")`` when CUDA reports itself available,
             otherwise ``torch.device("cpu")``
    """
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
def base_path() -> str:
    """
    Returns the base path where accuracies and tensorboard data are logged.

    :return: the relative directory ``'./data/'``; being relative, it is
             resolved against the caller's current working directory
    """
    return './data/'
def set_random_seed(seed: int) -> None:
    """
    Seed every random number generator this module relies on.

    The same value is fed, in order, to Python's ``random`` module, NumPy's
    global generator, PyTorch's CPU generator and the generators of all
    CUDA devices, so that subsequent draws from them are repeatable.

    :param seed: the seed value applied to every generator
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
class ContinualDataset:
    """
    Continual learning evaluation setting.

    Base class meant to be subclassed by concrete datasets, which provide the
    class-level constants and implement the factory methods below.

    NOTE(review): the methods are decorated with ``@abstractmethod``, but the
    class derives from neither ``abc.ABC`` nor ``abc.ABCMeta``. Python
    therefore does not enforce them. A subclass that omits one can still be
    instantiated, and calling the inherited stub returns ``None``. Deriving
    from ``abc.ABC`` would enforce the contract, but it would break any
    existing subclass that relies on a stub.
    """
    # Class-level placeholders, all ``None`` here. They are presumably
    # overridden by each concrete dataset (e.g. NAME a string, N_TASKS an
    # int) -- TODO confirm against the subclasses.
    NAME = None
    SETTING = None
    N_CLASSES_PER_TASK = None
    N_TASKS = None
    TRANSFORM = None

    def __init__(self, args: Namespace) -> None:
        """
        Initializes the (initially empty) train loader and list of test loaders.

        :param args: the arguments that contain the hyperparameters
        """
        self.train_loader = None
        self.test_loaders = []
        # Counter starting at 0; presumably the current task index advanced
        # by subclasses' get_data_loaders -- TODO confirm.
        self.i = 0
        self.args = args

    @abstractmethod
    def get_data_loaders(self) -> Tuple[DataLoader, DataLoader]:
        """
        Creates and returns the training and test loaders for the current task.
        The current training loader and all test loaders are stored in self.

        :return: the current training and test loaders
        """
        pass

    @abstractmethod
    def not_aug_dataloader(self, batch_size: int) -> DataLoader:
        """
        Returns the dataloader of the current task,
        without applying data augmentation.

        :param batch_size: the batch size of the loader
        :return: the current training loader
        """
        pass

    @staticmethod
    @abstractmethod
    def get_backbone() -> nn.Module:
        """
        Returns the backbone to be used for the current dataset.
        """
        pass

    @staticmethod
    @abstractmethod
    def get_transform() -> transforms:
        """
        Returns the transform to be used for the current dataset.
        """
        pass

    @staticmethod
    @abstractmethod
    def get_loss() -> nn.functional:
        """
        Returns the loss to be used for the current dataset.
        """
        pass

    @staticmethod
    @abstractmethod
    def get_normalization_transform() -> transforms:
        """
        Returns the transform used for normalizing the current dataset.
        """
        pass

    @staticmethod
    @abstractmethod
    def get_denormalization_transform() -> transforms:
        """
        Returns the transform used for denormalizing the current dataset.
        """
        pass