fix built-in (internal) function problem in the network: replace private MindSpore APIs with public equivalents

wsq3 2021-05-29 16:15:53 +08:00
parent d8c943fd29
commit cd2ea6e435
4 changed files with 31 additions and 12 deletions

View File

@@ -13,7 +13,6 @@
# limitations under the License.
# ============================================================================
"""define loss function for network."""
-from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
@@ -21,7 +20,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn
-class CrossEntropy(_Loss):
+class CrossEntropy(nn.Cell):
    """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
    def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
        super(CrossEntropy, self).__init__()
@@ -43,7 +42,7 @@ class CrossEntropy(_Loss):
        return loss_logit + self.factor*loss_aux
-class CrossEntropy_Val(_Loss):
+class CrossEntropy_Val(nn.Cell):
    """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process"""
    def __init__(self, smooth_factor=0, num_classes=1000):
        super(CrossEntropy_Val, self).__init__()
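
Note: the change above drops the private `_Loss` base class in favor of the public `nn.Cell`. Below is a minimal sketch of a label-smoothing cross-entropy cell built on `nn.Cell`; only the class signatures and `__init__` headers are confirmed by the diff, so the `construct` body and the `CrossEntropySketch` name are illustrative assumptions.

from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
import mindspore.nn as nn

class CrossEntropySketch(nn.Cell):
    """Label-smoothing cross-entropy on the public nn.Cell base (illustrative)."""
    def __init__(self, smooth_factor=0., num_classes=1000):
        super(CrossEntropySketch, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        # one-hot encode the sparse labels, then average the per-sample loss
        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        return self.mean(loss, 0)

Because `nn.Cell` is part of the public API, the loss no longer depends on `_Loss`, which is internal and may change between MindSpore releases.
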

View File

@@ -24,7 +24,7 @@ import mindspore.nn as nn
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
-from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
+from mindspore.train.callback import CheckpointConfig
from mindspore import amp
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed
@@ -113,6 +113,16 @@ def parallel_init(args):
    context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
+class InternalCallbackParam(dict):
+    """Internal callback object's parameters."""
+    def __getattr__(self, key):
+        return self[key]
+    def __setattr__(self, key, value):
+        self[key] = value
def modelarts_pre_process():
    '''modelarts pre process function.'''
    def unzip(zip_file, save_dir):
@@ -231,7 +241,7 @@ def run_train():
        ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                  directory=save_ckpt_path,
                                  prefix='{}'.format(config.rank))
-        cb_params = _InternalCallbackParam()
+        cb_params = InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
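
Note: since `_InternalCallbackParam` is a private MindSpore class, this commit defines a local `InternalCallbackParam`, which is simply a `dict` with attribute-style access. A short self-contained sketch of how it behaves (the values below are illustrative and not taken from the training script):

class InternalCallbackParam(dict):
    """A dict whose keys can also be read and written as attributes."""
    def __getattr__(self, key):
        return self[key]
    def __setattr__(self, key, value):
        self[key] = value

cb_params = InternalCallbackParam()
cb_params.cur_epoch_num = 1     # stored as cb_params["cur_epoch_num"]
cb_params["epoch_num"] = 100    # readable as cb_params.epoch_num
assert cb_params.epoch_num == 100 and cb_params["cur_epoch_num"] == 1

One caveat of this pattern: reading a missing attribute raises KeyError rather than AttributeError, which is acceptable here because the training loop always sets the fields it later reads.
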

View File

@@ -25,7 +25,7 @@ from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext
-from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
+from mindspore.train.callback import CheckpointConfig
import mindspore as ms
from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.common import set_seed
@@ -149,6 +149,16 @@ def build_quant_network(network):
    return network
+class InternalCallbackParam(dict):
+    """Internal callback object's parameters."""
+    def __getattr__(self, key):
+        return self[key]
+    def __setattr__(self, key, value):
+        self[key] = value
def train():
    """Train function."""
    args = parse_args()
@@ -215,7 +225,7 @@ def train():
                                       keep_checkpoint_max=ckpt_max_num)
        save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
        ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=save_ckpt_path, prefix='{}'.format(args.rank))
-        cb_params = _InternalCallbackParam()
+        cb_params = InternalCallbackParam()
        cb_params.train_network = network
        cb_params.epoch_num = ckpt_max_num
        cb_params.cur_epoch_num = 1
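
Note: same substitution as in the previous file. For context, a rough sketch of how a hand-rolled training loop can drive ModelCheckpoint through RunContext with these parameters, which is what the imports in this file suggest; the function name, `network`, `ckpt_cb`, `num_steps`, and the loop body are placeholders, not part of the commit.

from mindspore.train.callback import RunContext

def save_ckpt_manually(network, ckpt_cb, num_steps, ckpt_max_num):
    """Drive a ModelCheckpoint callback by hand, one step at a time."""
    # InternalCallbackParam is the dict-with-attribute-access class added in this commit
    cb_params = InternalCallbackParam()
    cb_params.train_network = network
    cb_params.epoch_num = ckpt_max_num
    cb_params.cur_epoch_num = 1
    run_context = RunContext(cb_params)
    ckpt_cb.begin(run_context)
    for i in range(num_steps):
        # ... run one training step here ...
        cb_params.cur_step_num = i + 1
        cb_params.batch_num = i + 2
        ckpt_cb.step_end(run_context)  # saves when save_checkpoint_steps is reached
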

View File

@@ -28,8 +28,8 @@ from mindspore.nn.metrics import Metric
from mindspore import nn, Tensor, ParameterTuple, Parameter
from mindspore.common.initializer import Uniform, initializer
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
-from mindspore.context import ParallelMode
+from mindspore.context import ParallelMode, get_auto_parallel_context
+from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from src.callback import EvalCallBack, LossCallBack
@@ -276,12 +276,12 @@ class TrainStepWrap(nn.Cell):
        self.reducer_flag = False
        self.grad_reducer = None
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = get_auto_parallel_context("parallel_mode")
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
+            mean = get_auto_parallel_context("gradients_mean")
+            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)

    def construct(self, batch_ids, batch_wts, label):
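
Note: this last hunk swaps the private parallel helpers for their public counterparts: `get_auto_parallel_context` replaces `_get_parallel_mode`/`_get_gradients_mean`, and `get_group_size` replaces `_get_device_num`. A small sketch of the same decision logic pulled out into a helper; the function name and the `optimizer` argument are illustrative, and `init()` must have been called before `get_group_size()` in a distributed run.

from mindspore.context import ParallelMode, get_auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

def build_grad_reducer(optimizer):
    """Return a DistributedGradReducer for data/hybrid parallel modes, else None."""
    parallel_mode = get_auto_parallel_context("parallel_mode")
    if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
        mean = get_auto_parallel_context("gradients_mean")  # True if gradients are averaged across devices
        degree = get_group_size()                           # number of devices in the communication group
        return DistributedGradReducer(optimizer.parameters, mean, degree)
    return None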