From 7eb7844b2163c52928d72b0cd8ef085d7af8c863 Mon Sep 17 00:00:00 2001
From: zhangyubo0722
Date: Mon, 18 Sep 2023 11:58:06 +0000
Subject: [PATCH 1/2] add svtr large model

---
 configs/rec/rec_svtrnet_large.yml | 144 ++++++++++++++++++++++++++++++
 1 file changed, 144 insertions(+)
 create mode 100644 configs/rec/rec_svtrnet_large.yml

diff --git a/configs/rec/rec_svtrnet_large.yml b/configs/rec/rec_svtrnet_large.yml
new file mode 100644
index 0000000000..45d7521087
--- /dev/null
+++ b/configs/rec/rec_svtrnet_large.yml
@@ -0,0 +1,144 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 200
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/rec/svtr_large/
+  save_epoch_step: 10
+  # evaluation is run every 2000 iterations after the 0th iteration
+  eval_batch_step: [0, 2000]
+  cal_metric_during_train: true
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: false
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+  max_text_length: &max_text_length 25
+  infer_mode: false
+  use_space_char: true
+  distributed: true
+  save_res_path: ./output/rec/predicts_svtr_large.txt
+
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.99
+  epsilon: 1.0e-08
+  weight_decay: 0.05
+  no_weight_decay_name: norm pos_embed char_node_embed pos_node_embed char_pos_embed vis_pos_embed
+  one_dim_param_no_weight_decay: true
+  lr:
+    name: Cosine
+    learning_rate: 0.00025 # 8gpus 64bs
+    warmup_epoch: 5
+
+
+Architecture:
+  model_type: rec
+  algorithm: SVTR_LCNet
+  Transform: null
+  Backbone:
+    name: SVTRNet
+    img_size:
+    - 48
+    - 320
+    out_char_num: 40
+    out_channels: 512
+    patch_merging: Conv
+    embed_dim: [192, 256, 512]
+    depth: [6, 6, 9]
+    num_heads: [6, 8, 16]
+    mixer: ['Conv','Conv','Conv','Conv','Conv','Conv','Conv','Conv','Conv','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global']
+    local_mixer: [[5, 5], [5, 5], [5, 5]]
+    last_stage: False
+    prenorm: True
+  Head:
+    name: MultiHead
+    use_pool: true
+    use_pos: true
+    head_list:
+      - CTCHead:
+          Neck:
+            name: svtr
+            dims: 256
+            depth: 2
+            hidden_dims: 256
+            kernel_size: [1, 3]
+            use_guide: True
+          Head:
+            fc_decay: 0.00001
+      - NRTRHead:
+          nrtr_dim: 512
+          max_text_length: *max_text_length
+
+Loss:
+  name: MultiLoss
+  loss_config_list:
+    - CTCLoss:
+    - NRTRLoss:
+
+PostProcess:
+  name: CTCLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+  ignore_space: true
+
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/
+    ext_op_transform_idx: 1
+    label_file_list:
+    - ./train_data/train_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - RecAug:
+    - MultiLabelEncode:
+        gtc_encode: NRTRLabelEncode
+    - RecResizeImg:
+        image_shape: [3, 48, 320]
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  loader:
+    shuffle: true
+    batch_size_per_card: 64
+    drop_last: true
+    num_workers: 8
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data
+    label_file_list:
+    - ./train_data/val_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - MultiLabelEncode:
+        gtc_encode: NRTRLabelEncode
+    - RecResizeImg:
+        image_shape: [3, 48, 320]
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 128
+    num_workers: 4
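
A note on the optimizer block in the config above (editorial, outside the diff): `learning_rate: 0.00025 # 8gpus 64bs` ties the base LR to a global batch of 8 * 64 = 512. When training on a different card count or `batch_size_per_card`, a common rule of thumb — an assumption here, not something this patch enforces — is to scale the LR linearly with the global batch:

    # Linear LR scaling sketch; the rule itself is an assumption, only the
    # base values (0.00025, 8 GPUs, batch 64) come from the config above.
    def scaled_lr(num_gpus, batch_per_card,
                  base_lr=0.00025, base_global_batch=8 * 64):
        """Rescale the base learning rate for a different global batch size."""
        return base_lr * (num_gpus * batch_per_card) / base_global_batch

    print(scaled_lr(num_gpus=4, batch_per_card=64))  # 0.000125
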
From 2d8a6ae93ab8609796f98c5880c455eb54ed6bac Mon Sep 17 00:00:00 2001
From: zhangyubo0722
Date: Mon, 25 Sep 2023 11:52:22 +0000
Subject: [PATCH 2/2] [WIP]add svtr large model

---
 .../ch_PP-OCRv4_rec_svtr_large.yml}      |   4 +-
 ppocr/data/imaug/rec_img_aug.py          |   2 +-
 ppocr/modeling/backbones/__init__.py     |   3 +-
 ppocr/modeling/backbones/rec_vit.py      | 258 ++++++++++++++++++
 ppocr/modeling/heads/rec_multi_head.py   |  30 +-
 ppocr/utils/utility.py                   |  34 ++-
 tools/infer_rec.py                       |   6 +-
 7 files changed, 316 insertions(+), 21 deletions(-)
 rename configs/rec/{rec_svtrnet_large.yml => PP-OCRv4/ch_PP-OCRv4_rec_svtr_large.yml} (98%)
 create mode 100644 ppocr/modeling/backbones/rec_vit.py
 mode change 100755 => 100644 ppocr/utils/utility.py

diff --git a/configs/rec/rec_svtrnet_large.yml b/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_svtr_large.yml
similarity index 98%
rename from configs/rec/rec_svtrnet_large.yml
rename to configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_svtr_large.yml
index 45d7521087..525d1c0bb5 100644
--- a/configs/rec/rec_svtrnet_large.yml
+++ b/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_svtr_large.yml
@@ -15,7 +15,7 @@ Global:
   use_visualdl: false
   infer_img: doc/imgs_words/ch/word_1.jpg
   character_dict_path: ppocr/utils/ppocr_keys_v1.txt
-  max_text_length: &max_text_length 25
+  max_text_length: &max_text_length 40
   infer_mode: false
   use_space_char: true
   distributed: true
@@ -128,7 +128,7 @@ Eval:
         channel_first: false
     - MultiLabelEncode:
         gtc_encode: NRTRLabelEncode
-    - RecResizeImg:
+    - SVTRRecResizeImg:
        image_shape: [3, 48, 320]
     - KeepKeys:
        keep_keys:
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 264579c038..0bf15114d5 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -47,7 +47,7 @@ def __call__(self, data):
         if h >= 20 and w >= 20:
             img = tia_distort(img, random.randint(3, 6))
             img = tia_stretch(img, random.randint(3, 6))
-            img = tia_perspective(img)
+            img = tia_perspective(img)  # bda
 
         data['image'] = img
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 60b9daf98a..10839b82b7 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -24,6 +24,7 @@ def build_backbone(config, model_type):
         from .det_pp_lcnet import PPLCNet
         from .rec_lcnetv3 import PPLCNetV3
         from .rec_hgnet import PPHGNet_small
+        from .rec_vit import ViT
         support_dict = [
             "MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST", "PPLCNet",
             "PPLCNetV3", "PPHGNet_small"
@@ -55,7 +56,7 @@ def build_backbone(config, model_type):
             'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
             'ResNet31', 'ResNet45', 'ResNet_ASTER', 'MicroNet',
             'EfficientNetb3_PREN', 'SVTRNet', 'ViTSTR', 'ResNet32', 'ResNetRFL',
-            'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ'
+            'DenseNet', 'ShallowCNN', 'PPLCNetV3', 'PPHGNet_small', 'ViTParseQ', 'ViT'
         ]
     elif model_type == 'e2e':
         from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_vit.py b/ppocr/modeling/backbones/rec_vit.py
new file mode 100644
index 0000000000..b7a55539da
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_vit.py
@@ -0,0 +1,258 @@
+# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle import ParamAttr
+from paddle.nn.initializer import KaimingNormal
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddle.nn.initializer import TruncatedNormal, Constant, Normal
+
+trunc_normal_ = TruncatedNormal(std=.02)
+normal_ = Normal
+zeros_ = Constant(value=0.)
+ones_ = Constant(value=1.)
+
+
+def drop_path(x, drop_prob=0., training=False):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    The original name is misleading, as 'Drop Connect' is a different form of
+    dropout in a separate paper. See discussion:
+    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = paddle.to_tensor(1 - drop_prob)
+    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
+    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
+    random_tensor = paddle.floor(random_tensor)  # binarize
+    output = x.divide(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Layer):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class Identity(nn.Layer):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, input):
+        return input
+
+
+class Mlp(nn.Layer):
+    def __init__(self,
+                 in_features,
+                 hidden_features=None,
+                 out_features=None,
+                 act_layer=nn.GELU,
+                 drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Layer):
+    def __init__(self,
+                 dim,
+                 num_heads=8,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        self.dim = dim
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        qkv = paddle.reshape(self.qkv(x), (0, -1, 3, self.num_heads,
+                                           self.dim // self.num_heads)).transpose((2, 0, 3, 1, 4))
+        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+
+        attn = (q.matmul(k.transpose((0, 1, 3, 2))))
+        attn = nn.functional.softmax(attn, axis=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn.matmul(v)).transpose((0, 2, 1, 3)).reshape((0, -1, self.dim))
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Layer):
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 mlp_ratio=4.,
+                 qkv_bias=False,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer='nn.LayerNorm',
+                 epsilon=1e-6,
+                 prenorm=True):
+        super().__init__()
+        if isinstance(norm_layer, str):
+            self.norm1 = eval(norm_layer)(dim, epsilon=epsilon)
+        else:
+            self.norm1 = norm_layer(dim)
+        self.mixer = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+        if isinstance(norm_layer, str):
+            self.norm2 = eval(norm_layer)(dim, epsilon=epsilon)
+        else:
+            self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp_ratio = mlp_ratio
+        self.mlp = Mlp(in_features=dim,
+                       hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer,
+                       drop=drop)
+        self.prenorm = prenorm
+
+    def forward(self, x):
+        if self.prenorm:
+            x = self.norm1(x + self.drop_path(self.mixer(x)))
+            x = self.norm2(x + self.drop_path(self.mlp(x)))
+        else:
+            x = x + self.drop_path(self.mixer(self.norm1(x)))
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class ViT(nn.Layer):
+    def __init__(
+            self,
+            img_size=[32, 128],
+            patch_size=[4, 4],
+            in_channels=3,
+            embed_dim=384,
+            depth=12,
+            num_heads=6,
+            mlp_ratio=4,
+            qkv_bias=False,
+            qk_scale=None,
+            drop_rate=0.,
+            attn_drop_rate=0.,
+            drop_path_rate=0.1,
+            norm_layer='nn.LayerNorm',
+            epsilon=1e-6,
+            act='nn.GELU',
+            prenorm=False,
+            **kwargs):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.out_channels = embed_dim
+        self.prenorm = prenorm
+        self.patch_embed = nn.Conv2D(
+            in_channels, embed_dim, patch_size, patch_size, padding=(0, 0))
+        self.pos_embed = self.create_parameter(
+            shape=[1, 257, embed_dim], default_initializer=zeros_)
+        self.add_parameter("pos_embed", self.pos_embed)
+        self.pos_drop = nn.Dropout(p=drop_rate)
+        dpr = np.linspace(0, drop_path_rate, depth)
+        self.blocks1 = nn.LayerList([
+            Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                act_layer=eval(act),
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                epsilon=epsilon,
+                prenorm=prenorm) for i in range(depth)
+        ])
+        if not prenorm:
+            self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
+
+        self.avg_pool = nn.AdaptiveAvgPool2D([1, 25])
+        self.last_conv = nn.Conv2D(
+            in_channels=embed_dim,
+            out_channels=self.out_channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias_attr=False)
+        self.hardswish = nn.Hardswish()
+        self.dropout = nn.Dropout(p=0.1, mode="downscale_in_infer")
+
+        trunc_normal_(self.pos_embed)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight)
+            if m.bias is not None:
+                zeros_(m.bias)
+        elif isinstance(m, nn.LayerNorm):
+            zeros_(m.bias)
+            ones_(m.weight)
+
+    def forward(self, x):
+        x = self.patch_embed(x).flatten(2).transpose((0, 2, 1))
+        x = x + self.pos_embed[:, 1:, :]  # [:, :paddle.shape(x)[1], :]
+        x = self.pos_drop(x)
+        for blk in self.blocks1:
+            x = blk(x)
+        if not self.prenorm:
+            x = self.norm(x)
+
+        x = self.avg_pool(x.transpose([0, 2, 1]).reshape(
+            [0, self.embed_dim, -1, 25]))
+        x = self.last_conv(x)
+        x = self.hardswish(x)
+        x = self.dropout(x)
+        return x
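
For reference, a minimal sketch of what the `drop_path` helper in the new `rec_vit.py` computes (editorial, not part of the patch; assumes a working PaddlePaddle install and the repo root on `sys.path`):

    import paddle
    from ppocr.modeling.backbones.rec_vit import drop_path

    paddle.seed(0)
    x = paddle.ones([4, 3, 8])
    y = drop_path(x, drop_prob=0.5, training=True)
    # Each of the 4 samples is either zeroed out entirely or scaled by
    # 1 / keep_prob = 2.0, so the expectation matches the identity map;
    # with training=False (i.e. at eval time) the input is returned unchanged.
    print(y[:, 0, 0].numpy())  # e.g. [2. 0. 2. 2.], depending on the seed
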
diff --git a/ppocr/modeling/heads/rec_multi_head.py b/ppocr/modeling/heads/rec_multi_head.py
index 0b4fa939ee..ae557e7aa0 100644
--- a/ppocr/modeling/heads/rec_multi_head.py
+++ b/ppocr/modeling/heads/rec_multi_head.py
@@ -22,7 +22,7 @@
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR
+from ppocr.modeling.necks.rnn import Im2Seq, EncoderWithRNN, EncoderWithFC, SequenceEncoder, EncoderWithSVTR, trunc_normal_, zeros_
 from .rec_ctc_head import CTCHead
 from .rec_sar_head import SARHead
 from .rec_nrtr_head import Transformer
@@ -41,12 +41,28 @@ def forward(self, x):
         else:
             return self.fc(x.transpose([0, 2, 1]))
 
+class AddPos(nn.Layer):
+    def __init__(self, dim, w):
+        super().__init__()
+        self.dec_pos_embed = self.create_parameter(
+            shape=[1, w, dim], default_initializer=zeros_)
+        self.add_parameter("dec_pos_embed", self.dec_pos_embed)
+        trunc_normal_(self.dec_pos_embed)
+
+    def forward(self, x):
+        x = x + self.dec_pos_embed[:, :paddle.shape(x)[1], :]
+        return x
+
 class MultiHead(nn.Layer):
     def __init__(self, in_channels, out_channels_list, **kwargs):
         super().__init__()
         self.head_list = kwargs.pop('head_list')
-
+        self.use_pool = kwargs.get('use_pool', False)
+        self.use_pos = kwargs.get('use_pos', False)
+        self.in_channels = in_channels
+        if self.use_pool:
+            self.pool = nn.AvgPool2D(kernel_size=[3, 2], stride=[3, 2], padding=0)
         self.gtc_head = 'sar'
         assert len(self.head_list) >= 2
         for idx, head_name in enumerate(self.head_list):
@@ -61,8 +77,13 @@ def __init__(self, in_channels, out_channels_list, **kwargs):
                 max_text_length = gtc_args.get('max_text_length', 25)
                 nrtr_dim = gtc_args.get('nrtr_dim', 256)
                 num_decoder_layers = gtc_args.get('num_decoder_layers', 4)
-                self.before_gtc = nn.Sequential(
+                if self.use_pos:
+                    self.before_gtc = nn.Sequential(
+                        nn.Flatten(2), FCTranspose(in_channels, nrtr_dim),
+                        AddPos(nrtr_dim, 80))
+                else:
+                    self.before_gtc = nn.Sequential(
                     nn.Flatten(2), FCTranspose(in_channels, nrtr_dim))
+
                 self.gtc_head = Transformer(
                     d_model=nrtr_dim,
                     nhead=nrtr_dim // 32,
@@ -88,7 +109,8 @@ def __init__(self, in_channels, out_channels_list, **kwargs):
                 '{} is not supported in MultiHead yet'.format(name))
 
     def forward(self, x, targets=None):
-
+        if self.use_pool:
+            x = self.pool(x.reshape([0, 3, -1, self.in_channels]).transpose([0, 3, 1, 2]))
         ctc_encoder = self.ctc_encoder(x)
         ctc_out = self.ctc_head(ctc_encoder, targets)
         head_out = dict()
diff --git a/ppocr/utils/utility.py b/ppocr/utils/utility.py
old mode 100755
new mode 100644
index 47461d7d5e..688e55698c
--- a/ppocr/utils/utility.py
+++ b/ppocr/utils/utility.py
@@ -57,23 +57,35 @@ def _check_image_file(path):
     return any([path.lower().endswith(e) for e in img_end])
 
 
-def get_image_file_list(img_file):
+def get_image_file_list(img_file, infer_list=None):
     imgs_lists = []
-    if img_file is None or not os.path.exists(img_file):
-        raise Exception("not found any img file in {}".format(img_file))
-
-    if os.path.isfile(img_file) and _check_image_file(img_file):
-        imgs_lists.append(img_file)
-    elif os.path.isdir(img_file):
-        for single_file in os.listdir(img_file):
-            file_path = os.path.join(img_file, single_file)
-            if os.path.isfile(file_path) and _check_image_file(file_path):
-                imgs_lists.append(file_path)
+    if infer_list and not os.path.exists(infer_list):
+        raise Exception("not found infer list {}".format(infer_list))
+    if infer_list:
+        with open(infer_list, "r") as f:
+            lines = f.readlines()
+        for line in lines:
+            image_path = line.strip().split("\t")[0]
+            image_path = os.path.join(img_file, image_path)
+            imgs_lists.append(image_path)
+    else:
+        if img_file is None or not os.path.exists(img_file):
+            raise Exception("not found any img file in {}".format(img_file))
+
+        img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'pdf'}
+        if os.path.isfile(img_file) and _check_image_file(img_file):
+            imgs_lists.append(img_file)
+        elif os.path.isdir(img_file):
+            for single_file in os.listdir(img_file):
+                file_path = os.path.join(img_file, single_file)
+                if os.path.isfile(file_path) and _check_image_file(file_path):
+                    imgs_lists.append(file_path)
     if len(imgs_lists) == 0:
         raise Exception("not found any img file in {}".format(img_file))
     imgs_lists = sorted(imgs_lists)
     return imgs_lists
 
+
 def binarize_img(img):
     if len(img.shape) == 3 and img.shape[2] == 3:
         gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # conversion to grayscale image
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index 80986ccdeb..8a7d599356 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -118,9 +118,11 @@ def main():
             os.makedirs(os.path.dirname(save_res_path))
 
     model.eval()
-
+
+    infer_imgs = config['Global']['infer_img']
+    infer_list = config['Global'].get('infer_list', None)
     with open(save_res_path, "w") as fout:
-        for file in get_image_file_list(config['Global']['infer_img']):
+        for file in get_image_file_list(infer_imgs, infer_list=infer_list):
             logger.info("infer_img: {}".format(file))
             with open(file, 'rb') as f:
                 img = f.read()
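
Putting the `infer_list` changes together, a usage sketch (editorial, not part of the patch; paths are illustrative, mirroring the `./train_data` layout used in the config):

    # CLI route: tools/infer_rec.py now reads Global.infer_list and forwards it,
    # so inference can run over a label file instead of a directory scan, e.g.
    #   python3 tools/infer_rec.py \
    #       -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_svtr_large.yml \
    #       -o Global.infer_img=./train_data Global.infer_list=./train_data/val_list.txt
    from ppocr.utils.utility import get_image_file_list

    # Directory mode, unchanged behaviour: collect every image under a folder.
    imgs = get_image_file_list("./train_data/images")

    # List mode, new in this patch: each line of val_list.txt is
    # "<relative/path>\t<label>"; only the first tab-separated column is used,
    # joined onto the data root passed as img_file.
    imgs = get_image_file_list("./train_data", infer_list="./train_data/val_list.txt")
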