Skip to content

Commit eb6545a

Browse files
committed
Achieve an accuracy of 80.xx
1 parent 53c4ff8 commit eb6545a

File tree

6 files changed

+19
-15
lines changed

6 files changed

+19
-15
lines changed

dataset.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def __init__(self, root, opt):
141141
label_key = 'label-%09d'.encode() % index
142142
label = txn.get(label_key).decode('utf-8')
143143

144-
if len(label) >= self.opt.batch_max_length or len(label) == 0:
144+
if len(label) > self.opt.batch_max_length or len(label) == 0:
145145
# print(f'The length of the label is longer than max_length: length \
146146
# {len(label)}, {label} in dataset {self.root}')
147147
continue

model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __init__(self, opt):
9090
elif opt.Prediction == 'Bert_pred':
9191
pass
9292
elif opt.Prediction == 'SRN':
93-
self.Prediction = SRN_Decoder(n_position=opt.position_dim, n_class=opt.alphabet_size)
93+
self.Prediction = SRN_Decoder(n_position=opt.position_dim, N_max_character=opt.batch_max_character + 1, n_class=opt.alphabet_size)
9494
else:
9595
raise Exception('Prediction is neither CTC or Attn')
9696

@@ -103,7 +103,7 @@ def forward(self, input, text, is_train=True):
103103
""" Feature extraction stage """
104104
visual_feature = self.FeatureExtraction(input)
105105
# if self.stages['Feat'] == 'AsterRes' or self.stages['Feat'] == 'ResnetFpn':
106-
if self.stages['Feat'] == 'AsterRes' or self.stages['Feat'] == 'ResnetFpn' or self.stages['Feat'] == 'ResNet':
106+
if self.stages['Feat'] == 'AsterRes' or self.stages['Feat'] == 'ResnetFpn':
107107
b, c, h, w = visual_feature.shape
108108
visual_feature = visual_feature.permute(0, 1, 3, 2)
109109
visual_feature = visual_feature.contiguous().view(b, c, -1)

modules/feature_extraction.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# coding:utf-8
2+
# 2020-05-11
13
import torch.nn as nn
24
import torch.nn.functional as F
35

test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ def test(opt):
234234
parser.add_argument('--eval_data', default='/home/deepblue/deepbluetwo/chenjun/1_OCR/data/data_lmdb_release/evaluation', help='path to evaluation dataset')
235235
parser.add_argument('--benchmark_all_eval', default=True, help='evaluate 10 benchmark evaluation datasets')
236236
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
237-
parser.add_argument('--batch_size', type=int, default=192, help='input batch size')
238-
parser.add_argument('--saved_model', default='./saved_models/None-ResNet-SRN-SRN-Seed666/iter_210000.pth', help="path to saved_model to evaluation")
237+
parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
238+
parser.add_argument('--saved_model', default='./saved_models/None-ResNet-SRN-SRN-Seed666/iter_65000.pth', help="path to saved_model to evaluation")
239239
""" Data processing """
240240
parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
241241
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
@@ -258,6 +258,7 @@ def test(opt):
258258
parser.add_argument('--position_dim', type=int, default=26, help='the length sequence out from cnn encoder,resnet:65;resnetfpn:256')
259259

260260
parser.add_argument('--SRN_PAD', type=int, default=36, help='the pad character for srn')
261+
parser.add_argument('--batch_max_character', type=int, default=25, help='the max sequence length')
261262
opt = parser.parse_args()
262263

263264
""" vocab / character number configuration """

train.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def train(opt):
126126
print("Optimizer:")
127127
print(optimizer)
128128

129-
lrScheduler = lr_scheduler.MultiStepLR(optimizer, [2, 5, 10], gamma=0.1) # 减小学习速率
129+
lrScheduler = lr_scheduler.MultiStepLR(optimizer, [2, 4, 5], gamma=0.1) # 减小学习速率
130130

131131
""" final options """
132132
# print(opt)
@@ -266,8 +266,9 @@ def train(opt):
266266
if i == opt.num_iter:
267267
print('end the training')
268268
sys.exit()
269-
270-
if i > 0 and i % step_per_epoch == 0: # 调整学习速率
269+
270+
if i > 0 and i % int(step_per_epoch) == 0: # 调整学习速率
271+
print('down the learn rate 1/10')
271272
lrScheduler.step()
272273

273274
i += 1
@@ -283,13 +284,13 @@ def train(opt):
283284
parser.add_argument('--train_data', default='/home/deepblue/deepbluetwo/chenjun/1_OCR/data/data_lmdb_release/training', help='path to training dataset')
284285
parser.add_argument('--valid_data', default='/home/deepblue/deepbluetwo/chenjun/1_OCR/data/data_lmdb_release/validation', help='path to validation dataset')
285286
parser.add_argument('--manualSeed', type=int, default=666, help='for random seed setting')
286-
parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)
287-
parser.add_argument('--batch_size', type=int, default=16, help='input batch size')
288-
parser.add_argument('--num_iter', type=int, default=150000, help='number of iterations to train for')
289-
parser.add_argument('--valInterval', type=int, default=50, help='Interval between each validation')
287+
parser.add_argument('--workers', type=int, help='number of data loading workers', default=6)
288+
parser.add_argument('--batch_size', type=int, default=256, help='input batch size')
289+
parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
290+
parser.add_argument('--valInterval', type=int, default=5000, help='Interval between each validation')
290291
parser.add_argument('--saveInterval', type=int, default=5000, help='Interval between each save')
291292
parser.add_argument('--disInterval', type=int, default=5, help='Interval betweet each show')
292-
# parser.add_argument('--continue_model', default = '', help="path to model to continue training")
293+
parser.add_argument('--continue_model', default = '', help="path to model to continue training")
293294
# parser.add_argument('--continue_model', default='./saved_models/None-ResNet-SRN-SRN-Seed666/iter_150000.pth', help="path to model to continue training")
294295
parser.add_argument('--adam', default=True, help='Whether to use adam (default is Adadelta)')
295296
parser.add_argument('--ranger', default=False, help='use RAdam + Lookahead for optimizer')
@@ -320,7 +321,7 @@ def train(opt):
320321
parser.add_argument('--batch_max_character', type=int, default=25, help='the max character of one image')
321322
parser.add_argument('--alphabet_size', type=int, default=None, help='the categry of the string')
322323

323-
parser.add_argument('--select_data', type=str, default='ICDAR2019-ICDAR2019',
324+
parser.add_argument('--select_data', type=str, default='MJ-ST',
324325
help='select training data MJ-ST | MJ-ST-ICDAR2019 | baidu')
325326
parser.add_argument('--batch_ratio', type=str, default='1.0-1.0',
326327
help='assign ratio for each selected data in the batch')

utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def encode(self, text, batch_max_length=25):
175175
"""
176176
length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence.
177177
# additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token.
178-
batch_text = torch.cuda.LongTensor(len(text), batch_max_length).fill_(self.PAD)
178+
batch_text = torch.cuda.LongTensor(len(text), batch_max_length + 1).fill_(self.PAD)
179179
# mask_text = torch.cuda.LongTensor(len(text), batch_max_length).fill_(0)
180180
for i, t in enumerate(text):
181181
t = list(t + self.character[-2])

0 commit comments

Comments (0)