!16485 convert the dtype of the net's weights from float32 to float16 to avoid overflow in some operations

From: @zhouneng2
Reviewed-by: @linqingke,@wuxuejian
Signed-off-by: @wuxuejian
Committed by mindspore-ci-bot via Gitee on 2021-05-25 14:53:56 +08:00
commit 678e36006b
5 changed files with 33 additions and 12 deletions

@@ -18,6 +18,11 @@ import os
 import time
 import numpy as np
+import mindspore.common.dtype as mstype
+from mindspore import context
+from mindspore.common import set_seed
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16
 from src.dataset import data_to_mindrecord_byte_image, create_deeptext_dataset
 from src.utils import metrics
@@ -26,9 +31,6 @@ from model_utils.config import config
 from model_utils.moxing_adapter import moxing_wrapper
 from model_utils.device_adapter import get_device_id, get_device_num
-from mindspore import context
-from mindspore.common import set_seed
-from mindspore.train.serialization import load_checkpoint, load_param_into_net
 set_seed(1)
@@ -49,6 +51,11 @@ def deeptext_eval_test(dataset_path='', ckpt_path=''):
     print("\n========================================\n", flush=True)
     print("Processing, please wait a moment.", flush=True)
+    device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
+    if device_type == "Ascend":
+        net.to_float(mstype.float16)
     max_num = 32
     pred_data = []
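
Note (not part of the diff): the hunks above follow the usual MindSpore device-gated mixed-precision pattern for evaluation. A minimal sketch of that pattern, with build_net and ckpt_path as illustrative placeholders rather than names from this repo:

import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net

def prepare_eval_net(build_net, ckpt_path):
    # Build the network and restore its float32 checkpoint.
    net = build_net()
    load_param_into_net(net, load_checkpoint(ckpt_path))
    net.set_train(False)
    # Run float16 compute only on the Ascend backend; other targets keep float32.
    if context.get_context("device_target") == "Ascend":
        net.to_float(mstype.float16)
    return net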

@@ -18,6 +18,7 @@ import numpy as np
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
+from mindspore import context
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore.ops import functional as F
@@ -192,6 +193,7 @@ class Deeptext_VGG16(nn.Cell):
         self.concat1 = P.Concat(axis=1)
         self.roi_align_fuse = _conv(in_channels=1024, out_channels=512, kernel_size=1, padding=0, stride=1)
         self.vgg16_feature_extractor = VGG16FeatureExtraction()
+        self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"

     def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
         _, _, _, f4, f5 = self.vgg16_feature_extractor(img_data)
@@ -260,6 +262,8 @@ class Deeptext_VGG16(nn.Cell):
             bboxes_all = self.concat(bboxes_tuple)
         else:
             bboxes_all = bboxes_tuple[0]
+        if self.device_type == "Ascend":
+            bboxes_all = self.cast(bboxes_all, mstype.float16)
         rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))
         rois = self.cast(rois, mstype.float32)
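
Note (not part of the diff): on Ascend the proposal boxes are cast to float16 before P.Concat, presumably so both concat inputs share a dtype once the surrounding network runs in float16, and the concatenated rois are cast back to float32 afterwards (as the diff itself shows). A simplified, self-contained sketch of that cast pattern; the class and argument names here are illustrative, not the repo's:

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P

class RoiCastExample(nn.Cell):
    def __init__(self):
        super(RoiCastExample, self).__init__()
        self.cast = P.Cast()
        self.concat_1 = P.Concat(axis=1)
        self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"

    def construct(self, roi_index, bboxes_all):
        # roi_index is assumed to already be float16 on Ascend (float32 elsewhere),
        # so both Concat inputs end up with the same dtype.
        if self.device_type == "Ascend":
            bboxes_all = self.cast(bboxes_all, mstype.float16)
        rois = self.concat_1((roi_index, bboxes_all))
        # Cast the result back to float32 for the consumers that follow.
        return self.cast(rois, mstype.float32)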

@@ -40,7 +40,7 @@ class DenseNoTranpose(nn.Cell):
     def construct(self, x):
         x = self.cast(x, mstype.float16)
         weight = self.cast(self.weight, mstype.float16)
-        output = self.bias_add(self.cast(self.matmul(x, weight), mstype.float32), self.bias)
+        output = self.bias_add(self.matmul(x, weight), self.bias)
         return output
@@ -165,7 +165,7 @@ class Rcnn(nn.Cell):
         weights = self.cast(weights, mstype.float32)
         loss_cls = loss_cls * weights
-        loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,))
+        loss_cls = self.sum_loss(loss_cls, (0,)) / (self.sum_loss(weights, (0,)) + 1e-5)
         bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
                                  mstype.float32)
@@ -175,7 +175,7 @@ class Rcnn(nn.Cell):
         loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets)
         loss_reg = self.sum_loss(loss_reg, (2,))
         loss_reg = loss_reg * bbox_weights
-        loss_reg = loss_reg / self.sum_loss(weights, (0,))
+        loss_reg = loss_reg / (self.sum_loss(weights, (0,)) + 1e-5)
         loss_reg = self.sum_loss(loss_reg, (0, 1))
         loss = self.rcnn_loss_cls_weight * loss_cls + self.rcnn_loss_reg_weight * loss_reg
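
Note (not part of the diff): the "+ 1e-5" terms guard the loss normalization against an all-zero weight sum, which would otherwise divide 0 by 0 and produce NaN, a failure that surfaces quickly once parts of the network run in float16. A tiny standalone check of the effect; the zero tensors are made up purely for illustration:

import numpy as np
from mindspore import Tensor, context
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE)

sum_loss = P.ReduceSum()
loss_cls = Tensor(np.zeros((4,), np.float32))
weights = Tensor(np.zeros((4,), np.float32))  # no valid samples in this batch

unguarded = sum_loss(loss_cls, (0,)) / sum_loss(weights, (0,))           # 0 / 0 -> nan
guarded = sum_loss(loss_cls, (0,)) / (sum_loss(weights, (0,)) + 1e-5)    # -> 0.0
print(unguarded, guarded)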

@@ -17,7 +17,7 @@ import numpy as np
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
-from mindspore import Tensor
+from mindspore import context, Tensor
 from mindspore.ops import functional as F
 from mindspore.common.initializer import initializer
 from .bbox_assign_sample import BboxAssignSample
@@ -127,6 +127,8 @@ class RPN(nn.Cell):
                  cls_out_channels):
         super(RPN, self).__init__()
         cfg_rpn = config
+        self.ms_type = mstype.float32
+        self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
         self.num_bboxes = cfg_rpn.num_bboxes
         self.slice_index = ()
         self.feature_anchor_shape = ()
@@ -204,9 +206,12 @@ class RPN(nn.Cell):
             weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float32).to_tensor()
             bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float32).to_tensor()
-            rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
-                                            weight_conv, bias_conv, weight_cls, \
-                                            bias_cls, weight_reg, bias_reg))
+            rpn_reg_cls_block = RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
+                                               weight_conv, bias_conv, weight_cls, \
+                                               bias_cls, weight_reg, bias_reg)
+            if self.device_type == "Ascend":
+                rpn_reg_cls_block.to_float(mstype.float16)
+            rpn_layer.append(rpn_reg_cls_block)
         rpn_layer[0].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
         rpn_layer[0].rpn_cls.weight = rpn_layer[0].rpn_cls.weight
@@ -271,6 +276,7 @@ class RPN(nn.Cell):
                                                                        mstype.bool_),
                                                              anchor_using_list, gt_valids_i)
+                bbox_target = self.cast(bbox_target, mstype.float32)
                 bbox_weight = self.cast(bbox_weight, mstype.float32)
                 label = self.cast(label, mstype.float32)
                 label_weight = self.cast(label_weight, mstype.float32)
@@ -305,8 +311,8 @@ class RPN(nn.Cell):
             label_ = F.stop_gradient(label_with_batchsize)
             label_weight_ = F.stop_gradient(label_weight_with_batchsize)
-            cls_score_i = rpn_cls_score[0]
-            reg_score_i = rpn_bbox_pred[0]
+            cls_score_i = self.cast(rpn_cls_score[0], self.ms_type)
+            reg_score_i = self.cast(rpn_bbox_pred[0], self.ms_type)
             loss_cls = self.loss_cls(cls_score_i, label_)
             loss_cls_item = loss_cls * label_weight_
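
Note (not part of the diff): in the RPN only the reg/cls block is switched to float16 (per-block to_float on Ascend), and its outputs are cast back to self.ms_type (float32) before the losses so the reductions run in full precision. A schematic wrapper showing the shape of that pattern; HeadWrapper and the assumption that the wrapped head returns a (cls_score, bbox_pred) pair are illustrative, not from this repo:

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import context
from mindspore.ops import operations as P

class HeadWrapper(nn.Cell):
    def __init__(self, head):
        super(HeadWrapper, self).__init__()
        self.head = head
        self.cast = P.Cast()
        self.ms_type = mstype.float32
        # Only the wrapped head runs in float16, and only on Ascend.
        if context.get_context("device_target") == "Ascend":
            self.head.to_float(mstype.float16)

    def construct(self, feature):
        cls_score, bbox_pred = self.head(feature)
        # Cast back to float32 before any loss or reduction that follows.
        return self.cast(cls_score, self.ms_type), self.cast(bbox_pred, self.ms_type)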

@@ -153,6 +153,10 @@ def run_train():
         param_dict = load_checkpoint(load_path)
         load_param_into_net(net, param_dict)

+    device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
+    if device_type == "Ascend":
+        net.to_float(mstype.float16)
+
     loss = LossNet()
     lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32)
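
Note (not part of the diff): on the training side the same device gate is applied after the float32 checkpoint has been loaded, while tensors such as the learning-rate schedule stay float32. A minimal sketch under those assumptions; maybe_to_float16 and the linspace schedule below are illustrative stand-ins, not this repo's dynamic_lr:

import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context, Tensor

def maybe_to_float16(net):
    # Enable float16 compute only on Ascend; GPU/CPU targets keep float32.
    if context.get_context("device_target") == "Ascend":
        net.to_float(mstype.float16)
    return net

# The learning-rate tensor remains float32 regardless of the network's compute dtype.
lr = Tensor(np.linspace(1e-3, 1e-5, 1000).astype(np.float32), mstype.float32)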