Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mobilenet itag config #276

Merged
merged 5 commits into from
Feb 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions configs/classification/itag/mobilenetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# MobileNetV2 classification config for iTAG-labelled data.
# Inherits common classification settings (runner, log, dist) from the base file.
_base_ = '../imagenet/common/classification_base.py'

# Uncomment and fill in to read manifests/images from OSS:
# oss_io_config = dict(ak_id='', # your oss ak id
#                      ak_secret='', # your oss ak secret
#                      hosts='', # your oss hosts
#                      buckets=[]) # your oss bucket name

class_list = ['label1', 'label2',
              'label3']  # replace with your true labels of itag manifest file
# Must match len(class_list); also used as the classifier output dimension.
num_classes = 3
# model settings
model = dict(
    type='Classification',
    backbone=dict(type='MobileNetV2'),
    head=dict(
        type='ClsHead',
        with_avg_pool=True,
        # 1280 matches MobileNetV2's final feature-map channel count.
        in_channels=1280,
        loss_config=dict(
            type='CrossEntropyLossWithLabelSmooth',
            # label_smooth=0 disables smoothing (plain cross-entropy).
            label_smooth=0,
        ),
        num_classes=num_classes))

data_train_list = '/your/itag/train/file.manifest'  # or oss://your_bucket/data/train.manifest
data_test_list = '/your/itag/test/file.manifest'  # oss://your_bucket/data/test.manifest

# Standard ImageNet-style eval protocol: resize the short side to
# 256/224 * crop_size, then center-crop to crop_size.
image_size2 = 224  # final crop size fed to the network
image_size1 = int((256 / 224) * image_size2)  # resize size before center crop
data_source_type = 'ClsSourceItag'
dataset_type = 'ClsDataset'
# ImageNet channel statistics (RGB mean/std) used for input normalization.
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
    dict(type='RandomResizedCrop', size=image_size2),
    dict(type='RandomHorizontalFlip'),
    dict(type='ToTensor'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Collect', keys=['img', 'gt_labels'])
]
test_pipeline = [
    dict(type='Resize', size=image_size1),
    dict(type='CenterCrop', size=image_size2),
    dict(type='ToTensor'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Collect', keys=['img', 'gt_labels'])
]

data = dict(
    imgs_per_gpu=32,  # total 256 assuming 8 GPUs
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_source=dict(
            type=data_source_type,
            list_file=data_train_list,
            # Passing class_list pins the label -> id mapping; without it the
            # data source collects labels automatically from the manifest.
            class_list=class_list,
        ),
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_source=dict(
            type=data_source_type,
            list_file=data_test_list,
            class_list=class_list),
        pipeline=test_pipeline))

eval_config = dict(initial=False, interval=1, gpu_collect=True)
eval_pipelines = [
    dict(
        mode='test',
        data=data['val'],
        dist_eval=True,
        # Top-1 accuracy only; extend topk (e.g. (1, 5)) for more metrics.
        evaluators=[dict(type='ClsEvaluator', topk=(1, ))],
    )
]

# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)

# learning policy: step decay at epochs 30/60/90 (classic ImageNet schedule)
lr_config = dict(policy='step', step=[30, 60, 90])
checkpoint_config = dict(interval=5)

# runtime settings
total_epochs = 100
# Sync checkpoints for export (e.g. to OSS) after training.
checkpoint_sync_export = True
export = dict(export_neck=True)
120 changes: 39 additions & 81 deletions easycv/datasets/classification/data_sources/image_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image
from easycv.framework.errors import TypeError
from easycv.framework.errors import TypeError, ValueError
from easycv.utils.dist_utils import dist_zero_exec
from .utils import split_listfile_byrank

Expand All @@ -28,7 +28,6 @@ class ClsSourceImageList(object):
If split, data list will be split to each rank.
split_label_balance: if `split_huge_listfile_byrank` is true, whether split with label balance
cache_path: if `split_huge_listfile_byrank` is true, cache list_file will be saved to cache_path.
max_try: int, max try numbers of reading image
"""

def __init__(self,
Expand All @@ -37,13 +36,9 @@ def __init__(self,
delimeter=' ',
split_huge_listfile_byrank=False,
split_label_balance=False,
cache_path='data/',
max_try=20):
cache_path='data/'):

ImageFile.LOAD_TRUNCATED_IMAGES = True

self.max_try = max_try

# DistributedMPSampler need this attr
self.has_labels = True

Expand Down Expand Up @@ -124,77 +119,39 @@ class ClsSourceItag(ClsSourceImageList):
list_file : str / list(str), str means a input image list file path,
this file contains records as `image_path label` in list_file
list(str) means multi image list, each one contains some records as `image_path label`
root: str / list(str), root path for image_path, each list_file will need a root,
if len(root) < len(list_file), we will use root[-1] to fill root list.
delimeter: str, delimeter of each line in the `list_file`
split_huge_listfile_byrank: Adapt to the situation that the memory cannot fully load a huge amount of data list.
If split, data list will be split to each rank.
split_label_balance: if `split_huge_listfile_byrank` is true, whether split with label balance
cache_path: if `split_huge_listfile_byrank` is true, cache list_file will be saved to cache_path.
max_try: int, max try numbers of reading image
"""

def __init__(self,
list_file,
root='',
delimeter=' ',
split_huge_listfile_byrank=False,
split_label_balance=False,
cache_path='data/',
max_try=20):

def __init__(self, list_file, root='', class_list=None):
assert root is None or len(
root) < 1, 'The "root" param is not used and will be removed soon!'
ImageFile.LOAD_TRUNCATED_IMAGES = True

self.max_try = max_try

# DistributedMPSampler need this attr
self.has_labels = True

if isinstance(list_file, str):
assert isinstance(root, str), 'list_file is str, root must be str'
list_file = [list_file]
root = [root]
self.class_list = class_list
if self.class_list is None:
logging.warning(
'It is recommended to specify the ``class_list`` parameter!')
self._auto_collect_labels = True
self.label_dict = {}
else:
assert isinstance(list_file, list), \
'list_file should be str or list(str)'
root = [root] if isinstance(root, str) else root
if not isinstance(root, list):
raise TypeError('root must be str or list(str), but get %s' %
type(root))

if len(root) < len(list_file):
logging.warning(
'len(root) < len(list_file), fill root with root last!')
root = root + [root[-1]] * (len(list_file) - len(root))

# TODO: support return list, donot save split file
# TODO: support loading list_file that have already been split
if split_huge_listfile_byrank:
with dist_zero_exec():
list_file = split_listfile_byrank(
list_file=list_file,
label_balance=split_label_balance,
save_path=cache_path)

self.fns = []
self.labels = []
label_dict = dict()
for l, r in zip(list_file, root):
fns, labels, label_dict = self.parse_list_file(l, label_dict)
self.fns += fns
self.labels += labels
self.label_dict = dict(
zip(self.class_list, range(len(self.class_list))))
self._auto_collect_labels = False
self.fns, self.labels, self.label_dict = self.parse_list_file(
list_file, self.label_dict, self._auto_collect_labels)

@staticmethod
def parse_list_file(list_file, label_dict):
with open(list_file, 'r', encoding='utf-8') as f:
data = f.readlines()
def parse_list_file(list_file, label_dict, auto_collect_labels=True):
with io.open(list_file, 'r') as f:
rows = f.read().splitlines()

fns = []
labels = []
for i in range(len(data)):
data_i = json.loads(data[i])
labels_id = []

for row_str in rows:
data_i = json.loads(row_str.strip())
img_path = data_i['data']['source']
label = []
label_id = []

priority = 2
for k in data_i.keys():
Expand All @@ -206,26 +163,27 @@ def parse_list_file(list_file, label_dict):

for k, v in data_i.items():
if 'label' in k:
label = []
label_id = []
result_list = v['results']
for j in range(len(result_list)):
anno_list = result_list[j]['data']
if 'labels' in anno_list:
if anno_list['labels']['单选'] not in label_dict:
label_dict[anno_list['labels']['单选']] = len(
label_dict)
label.append(label_dict[anno_list['labels']['单选']])
else:
if anno_list not in label_dict:
label_dict[anno_list] = len(label_dict)
label.append(label_dict[anno_list])
label = result_list[j]['data']
if 'labels' in label:
label = label['labels']['单选']
if label not in label_dict:
if auto_collect_labels:
label_dict[label] = len(label_dict)
else:
raise ValueError(
f'Not find label "{label}" in label dict: {label_dict}'
)
label_id.append(label_dict[label])
if 'verify' in k:
break
elif 'check' in k and priority == 1:
break

fns.append(img_path)
labels.append(
label[0]) if len(label) == 1 else labels.append(label)
labels_id.append(label_id[0]) if len(
label_id) == 1 else labels_id.append(label_id)

return fns, labels, label_dict
return fns, labels_id, label_dict
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from tests.ut_config import CLS_DATA_ITAG_OSS

from easycv.datasets.builder import build_datasource
from easycv.framework.errors import ValueError


class ClsSourceImageListTest(unittest.TestCase):
    """Tests for the ``ClsSourceItag`` data source reading an iTAG manifest.

    NOTE(review): the class name says ``ImageList`` but every test builds a
    ``ClsSourceItag`` source — consider renaming for clarity.
    All tests read a fixture manifest from OSS (``CLS_DATA_ITAG_OSS``), so
    they require valid OSS credentials via ``io.access_oss()``.
    """

    def setUp(self):
        # Log which test is running; helps when scanning CI output.
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def test_default(self):
        """No ``class_list`` given: labels are auto-collected from the manifest.

        The resulting label ids follow first-seen order in the file.
        """
        from easycv.file import io
        io.access_oss()

        cfg = dict(type='ClsSourceItag', list_file=CLS_DATA_ITAG_OSS)
        data_source = build_datasource(cfg)

        # Spot-check the first few samples: decoded RGB image + valid label id.
        index_list = list(range(5))
        for idx in index_list:
            results = data_source[idx]
            img = results['img']
            label = results['gt_labels']
            self.assertEqual(img.mode, 'RGB')
            self.assertIn(label, list(range(3)))
            img.close()

        # Fixture manifest contains 11 records and 3 distinct labels.
        self.assertEqual(len(data_source), 11)
        self.assertDictEqual(data_source.label_dict, {
            'ng': 0,
            'ok': 1,
            '中文': 2
        })

    def test_with_class_list(self):
        """Explicit ``class_list``: label ids follow the given list order."""
        from easycv.file import io
        io.access_oss()

        cfg = dict(
            type='ClsSourceItag',
            class_list=['中文', 'ng', 'ok'],
            list_file=CLS_DATA_ITAG_OSS)
        data_source = build_datasource(cfg)

        index_list = list(range(5))
        for idx in index_list:
            results = data_source[idx]
            img = results['img']
            label = results['gt_labels']
            self.assertEqual(img.mode, 'RGB')
            self.assertIn(label, list(range(3)))
            img.close()

        self.assertEqual(len(data_source), 11)
        # Mapping order is fixed by class_list, not by manifest order.
        self.assertDictEqual(data_source.label_dict, {
            '中文': 0,
            'ng': 1,
            'ok': 2
        })

    def test_with_fault_class_list(self):
        """A ``class_list`` missing a label present in the manifest must raise.

        The ValueError is expected at ``build_datasource`` time (manifest is
        parsed in the constructor), so the iteration below never runs.
        """
        from easycv.file import io
        io.access_oss()

        with self.assertRaises(ValueError) as cm:
            cfg = dict(
                type='ClsSourceItag',
                class_list=['error', 'ng', 'ok'],
                list_file=CLS_DATA_ITAG_OSS)

            data_source = build_datasource(cfg)
            index_list = list(range(5))
            for idx in index_list:
                results = data_source[idx]
                img = results['img']
                label = results['gt_labels']
                self.assertEqual(img.mode, 'RGB')
                self.assertIn(label, list(range(3)))
                img.close()

        # The error message names the unknown label and the effective dict.
        exception = cm.exception
        self.assertEqual(
            exception.message,
            "Not find label \"中文\" in label dict: {'error': 0, 'ng': 1, 'ok': 2}"
        )


if __name__ == '__main__':
unittest.main()
4 changes: 4 additions & 0 deletions tests/ut_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
TMP_DIR_OSS = os.path.join(BASE_OSS_PATH, 'tmp')
TMP_DIR_LOCAL = os.path.join(BASE_LOCAL_PATH, 'tmp')

# OSS path to the iTAG-format classification test manifest (fixture for
# ClsSourceItag unit tests).
CLS_DATA_ITAG_OSS = os.path.join(
    BASE_OSS_PATH,
    'local_backup/easycv_nfs/data/classification/cls_itagtest/cls_itagtest.manifest'
)
CLS_DATA_NPY_LOCAL = os.path.join(BASE_LOCAL_PATH, 'data/classification/npy/')
SMALL_IMAGENET_RAW_LOCAL = os.path.join(
BASE_LOCAL_PATH, 'data/classification/small_imagenet_raw')
Expand Down
4 changes: 2 additions & 2 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ def main():
else:
cfg.oss_work_dir = None

if args.resume_from is not None:
if args.resume_from is not None and len(args.resume_from) > 0:
cfg.resume_from = args.resume_from
if args.load_from is not None:
if args.load_from is not None and len(args.load_from) > 0:
cfg.load_from = args.load_from

# dynamic adapt mmdet models
Expand Down