forked from mindspore-Ecosystem/mindspore
fix ctpn BNTrainingUpdateGrad inconsistent data type
This commit is contained in:
parent
c84e09cf45
commit
2c999fcae5
|
@ -114,13 +114,13 @@ pretraining_dataset_file: ""
|
|||
finetune_dataset_file: ""
|
||||
|
||||
# pretrain lr
|
||||
pre_base_lr: 0.0009
|
||||
pre_base_lr: 0.009
|
||||
pre_warmup_step: 30000
|
||||
pre_warmup_ratio: 1/3
|
||||
pre_total_epoch: 100
|
||||
|
||||
# finetune lr
|
||||
fine_base_lr: 0.0005
|
||||
fine_base_lr: 0.005
|
||||
fine_warmup_step: 300
|
||||
fine_warmup_ratio: 1/3
|
||||
fine_total_epoch: 50
|
||||
|
|
|
@ -92,8 +92,8 @@ class CTPN(nn.Cell):
|
|||
self.num_step = config.num_step
|
||||
self.input_size = config.input_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.vgg16_feature_extractor = VGG16FeatureExtraction()
|
||||
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same')
|
||||
self.vgg16_feature_extractor = VGG16FeatureExtraction().to_float(mstype.float16)
|
||||
self.conv = nn.Conv2d(512, 512, kernel_size=3, padding=0, pad_mode='same').to_float(mstype.float16)
|
||||
self.rnn = BiLSTM(self.config, batch_size=self.batch_size).to_float(mstype.float16)
|
||||
self.reshape = P.Reshape()
|
||||
self.transpose = P.Transpose()
|
||||
|
|
Loading…
Reference in New Issue