forked from mindspore-Ecosystem/mindspore
convert dtype of weight in net from float32 to float16 to avoid overflow of some operations
This commit is contained in:
parent
356a49b98d
commit
5a5bb1a883
|
@ -18,6 +18,11 @@ import os
|
|||
import time
|
||||
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.Deeptext.deeptext_vgg16 import Deeptext_VGG16
|
||||
from src.dataset import data_to_mindrecord_byte_image, create_deeptext_dataset
|
||||
from src.utils import metrics
|
||||
|
@ -26,9 +31,6 @@ from model_utils.config import config
|
|||
from model_utils.moxing_adapter import moxing_wrapper
|
||||
from model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
@ -49,6 +51,11 @@ def deeptext_eval_test(dataset_path='', ckpt_path=''):
|
|||
|
||||
print("\n========================================\n", flush=True)
|
||||
print("Processing, please wait a moment.", flush=True)
|
||||
|
||||
device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
|
||||
if device_type == "Ascend":
|
||||
net.to_float(mstype.float16)
|
||||
|
||||
max_num = 32
|
||||
|
||||
pred_data = []
|
||||
|
|
|
@ -18,6 +18,7 @@ import numpy as np
|
|||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import functional as F
|
||||
|
@ -192,6 +193,7 @@ class Deeptext_VGG16(nn.Cell):
|
|||
self.concat1 = P.Concat(axis=1)
|
||||
self.roi_align_fuse = _conv(in_channels=1024, out_channels=512, kernel_size=1, padding=0, stride=1)
|
||||
self.vgg16_feature_extractor = VGG16FeatureExtraction()
|
||||
self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
|
||||
|
||||
def construct(self, img_data, img_metas, gt_bboxes, gt_labels, gt_valids):
|
||||
_, _, _, f4, f5 = self.vgg16_feature_extractor(img_data)
|
||||
|
@ -260,6 +262,8 @@ class Deeptext_VGG16(nn.Cell):
|
|||
bboxes_all = self.concat(bboxes_tuple)
|
||||
else:
|
||||
bboxes_all = bboxes_tuple[0]
|
||||
if self.device_type == "Ascend":
|
||||
bboxes_all = self.cast(bboxes_all, mstype.float16)
|
||||
rois = self.concat_1((self.roi_align_index_test_tensor, bboxes_all))
|
||||
|
||||
rois = self.cast(rois, mstype.float32)
|
||||
|
|
|
@ -40,7 +40,7 @@ class DenseNoTranpose(nn.Cell):
|
|||
def construct(self, x):
|
||||
x = self.cast(x, mstype.float16)
|
||||
weight = self.cast(self.weight, mstype.float16)
|
||||
output = self.bias_add(self.cast(self.matmul(x, weight), mstype.float32), self.bias)
|
||||
output = self.bias_add(self.matmul(x, weight), self.bias)
|
||||
return output
|
||||
|
||||
|
||||
|
@ -165,7 +165,7 @@ class Rcnn(nn.Cell):
|
|||
|
||||
weights = self.cast(weights, mstype.float32)
|
||||
loss_cls = loss_cls * weights
|
||||
loss_cls = self.sum_loss(loss_cls, (0,)) / self.sum_loss(weights, (0,))
|
||||
loss_cls = self.sum_loss(loss_cls, (0,)) / (self.sum_loss(weights, (0,)) + 1e-5)
|
||||
|
||||
bbox_weights = self.cast(self.onehot(bbox_weights, self.num_classes, self.on_value, self.off_value),
|
||||
mstype.float32)
|
||||
|
@ -175,7 +175,7 @@ class Rcnn(nn.Cell):
|
|||
loss_reg = self.loss_bbox(pos_bbox_pred, bbox_targets)
|
||||
loss_reg = self.sum_loss(loss_reg, (2,))
|
||||
loss_reg = loss_reg * bbox_weights
|
||||
loss_reg = loss_reg / self.sum_loss(weights, (0,))
|
||||
loss_reg = loss_reg / (self.sum_loss(weights, (0,)) + 1e-5)
|
||||
loss_reg = self.sum_loss(loss_reg, (0, 1))
|
||||
|
||||
loss = self.rcnn_loss_cls_weight * loss_cls + self.rcnn_loss_reg_weight * loss_reg
|
||||
|
|
|
@ -17,7 +17,7 @@ import numpy as np
|
|||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.initializer import initializer
|
||||
from .bbox_assign_sample import BboxAssignSample
|
||||
|
@ -127,6 +127,8 @@ class RPN(nn.Cell):
|
|||
cls_out_channels):
|
||||
super(RPN, self).__init__()
|
||||
cfg_rpn = config
|
||||
self.ms_type = mstype.float32
|
||||
self.device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
|
||||
self.num_bboxes = cfg_rpn.num_bboxes
|
||||
self.slice_index = ()
|
||||
self.feature_anchor_shape = ()
|
||||
|
@ -204,9 +206,12 @@ class RPN(nn.Cell):
|
|||
weight_reg = initializer('Normal', shape=shp_weight_reg, dtype=mstype.float32).to_tensor()
|
||||
bias_reg = initializer(0, shape=shp_bias_reg, dtype=mstype.float32).to_tensor()
|
||||
|
||||
rpn_layer.append(RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
|
||||
weight_conv, bias_conv, weight_cls, \
|
||||
bias_cls, weight_reg, bias_reg))
|
||||
rpn_reg_cls_block = RpnRegClsBlock(in_channels, feat_channels, num_anchors, cls_out_channels, \
|
||||
weight_conv, bias_conv, weight_cls, \
|
||||
bias_cls, weight_reg, bias_reg)
|
||||
if self.device_type == "Ascend":
|
||||
rpn_reg_cls_block.to_float(mstype.float16)
|
||||
rpn_layer.append(rpn_reg_cls_block)
|
||||
|
||||
rpn_layer[0].rpn_conv.weight = rpn_layer[0].rpn_conv.weight
|
||||
rpn_layer[0].rpn_cls.weight = rpn_layer[0].rpn_cls.weight
|
||||
|
@ -271,6 +276,7 @@ class RPN(nn.Cell):
|
|||
mstype.bool_),
|
||||
anchor_using_list, gt_valids_i)
|
||||
|
||||
bbox_target = self.cast(bbox_target, mstype.float32)
|
||||
bbox_weight = self.cast(bbox_weight, mstype.float32)
|
||||
label = self.cast(label, mstype.float32)
|
||||
label_weight = self.cast(label_weight, mstype.float32)
|
||||
|
@ -305,8 +311,8 @@ class RPN(nn.Cell):
|
|||
label_ = F.stop_gradient(label_with_batchsize)
|
||||
label_weight_ = F.stop_gradient(label_weight_with_batchsize)
|
||||
|
||||
cls_score_i = rpn_cls_score[0]
|
||||
reg_score_i = rpn_bbox_pred[0]
|
||||
cls_score_i = self.cast(rpn_cls_score[0], self.ms_type)
|
||||
reg_score_i = self.cast(rpn_bbox_pred[0], self.ms_type)
|
||||
|
||||
loss_cls = self.loss_cls(cls_score_i, label_)
|
||||
loss_cls_item = loss_cls * label_weight_
|
||||
|
|
|
@ -153,6 +153,10 @@ def run_train():
|
|||
param_dict = load_checkpoint(load_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
device_type = "Ascend" if context.get_context("device_target") == "Ascend" else "Others"
|
||||
if device_type == "Ascend":
|
||||
net.to_float(mstype.float16)
|
||||
|
||||
loss = LossNet()
|
||||
lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32)
|
||||
|
||||
|
|
Loading…
Reference in New Issue