forked from mindspore-Ecosystem/mindspore
Fix the built-in function problem in the network: replace internal MindSpore APIs with public equivalents
parent d8c943fd29
commit cd2ea6e435
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 """define loss function for network."""
-from mindspore.nn.loss.loss import _Loss
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -21,7 +20,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
 
 
-class CrossEntropy(_Loss):
+class CrossEntropy(nn.Cell):
     """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
     def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
         super(CrossEntropy, self).__init__()
@@ -43,7 +42,7 @@ class CrossEntropy(_Loss):
         return loss_logit + self.factor*loss_aux
 
 
-class CrossEntropy_Val(_Loss):
+class CrossEntropy_Val(nn.Cell):
     """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process"""
     def __init__(self, smooth_factor=0, num_classes=1000):
         super(CrossEntropy_Val, self).__init__()
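A minimal, self-contained sketch of the pattern the loss file moves to: a label-smoothing cross entropy subclassing nn.Cell instead of the internal _Loss base. The class name CrossEntropySmooth and its body are illustrative only (the repo's CrossEntropy above additionally mixes in an auxiliary-head loss via `factor`), but every call used here is public MindSpore API.

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F


class CrossEntropySmooth(nn.Cell):
    """Label-smoothing cross entropy built only on public MindSpore APIs (illustrative)."""
    def __init__(self, smooth_factor=0.0, num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logits, label):
        # turn sparse labels into smoothed one-hot targets, then average the per-sample loss
        one_hot_label = self.onehot(label, F.shape(logits)[1], self.on_value, self.off_value)
        loss = self.ce(logits, one_hot_label)
        return self.mean(loss, 0)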
@@ -24,7 +24,7 @@ import mindspore.nn as nn
 from mindspore import context
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.train.callback import ModelCheckpoint, RunContext
-from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
+from mindspore.train.callback import CheckpointConfig
 from mindspore import amp
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.common import set_seed
@@ -113,6 +113,16 @@ def parallel_init(args):
     context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=degree)
 
 
+class InternalCallbackParam(dict):
+    """Internal callback object's parameters."""
+
+    def __getattr__(self, key):
+        return self[key]
+
+    def __setattr__(self, key, value):
+        self[key] = value
+
+
 def modelarts_pre_process():
     '''modelarts pre process function.'''
     def unzip(zip_file, save_dir):
@@ -231,7 +241,7 @@ def run_train():
     ckpt_cb = ModelCheckpoint(config=ckpt_config,
                               directory=save_ckpt_path,
                               prefix='{}'.format(config.rank))
-    cb_params = _InternalCallbackParam()
+    cb_params = InternalCallbackParam()
     cb_params.train_network = network
     cb_params.epoch_num = ckpt_max_num
     cb_params.cur_epoch_num = 1
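The same ten added lines appear in both training scripts: a plain dict subclass with attribute-style access, which is all that ModelCheckpoint and RunContext need from the removed _InternalCallbackParam. A rough usage sketch follows; the network, step/epoch numbers and checkpoint paths below are placeholders, not values from this repo.

import mindspore.nn as nn
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, RunContext


class InternalCallbackParam(dict):
    """Dict with attribute access, standing in for the removed _InternalCallbackParam."""
    def __getattr__(self, key):
        return self[key]

    def __setattr__(self, key, value):
        self[key] = value


network = nn.Dense(4, 2)  # placeholder for the real training network
ckpt_config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10)
ckpt_cb = ModelCheckpoint(config=ckpt_config, directory='./ckpt', prefix='rank0')

cb_params = InternalCallbackParam()
cb_params.train_network = network
cb_params.epoch_num = 10
cb_params.cur_epoch_num = 1
run_context = RunContext(cb_params)  # the checkpoint callback reads these fields as attributes
ckpt_cb.begin(run_context)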
@@ -25,7 +25,7 @@ from mindspore import Tensor
 from mindspore import context
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.train.callback import ModelCheckpoint, RunContext
-from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
+from mindspore.train.callback import CheckpointConfig
 import mindspore as ms
 from mindspore.compression.quant import QuantizationAwareTraining
 from mindspore.common import set_seed
@@ -149,6 +149,16 @@ def build_quant_network(network):
     return network
 
 
+class InternalCallbackParam(dict):
+    """Internal callback object's parameters."""
+
+    def __getattr__(self, key):
+        return self[key]
+
+    def __setattr__(self, key, value):
+        self[key] = value
+
+
 def train():
     """Train function."""
     args = parse_args()
@@ -215,7 +225,7 @@ def train():
                                    keep_checkpoint_max=ckpt_max_num)
     save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(args.rank) + '/')
     ckpt_cb = ModelCheckpoint(config=ckpt_config, directory=save_ckpt_path, prefix='{}'.format(args.rank))
-    cb_params = _InternalCallbackParam()
+    cb_params = InternalCallbackParam()
     cb_params.train_network = network
     cb_params.epoch_num = ckpt_max_num
     cb_params.cur_epoch_num = 1
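Both scripts drive ModelCheckpoint by hand inside their own training loop rather than through Model.train. Continuing the sketch above with placeholder counters, the loop looks roughly like this; cur_step_num and batch_num are the counters the callback uses to decide when to save.

steps_per_epoch = 100  # placeholder; in the scripts this comes from the dataset size
for epoch in range(1, cb_params.epoch_num + 1):
    cb_params.cur_epoch_num = epoch
    for step in range(steps_per_epoch):
        # ... run one training step on cb_params.train_network here ...
        cb_params.cur_step_num = (epoch - 1) * steps_per_epoch + step + 1
        cb_params.batch_num = steps_per_epoch
        ckpt_cb.step_end(run_context)  # saves according to CheckpointConfig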
@@ -28,8 +28,8 @@ from mindspore.nn.metrics import Metric
 from mindspore import nn, Tensor, ParameterTuple, Parameter
 from mindspore.common.initializer import Uniform, initializer
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
-from mindspore.context import ParallelMode
+from mindspore.context import ParallelMode, get_auto_parallel_context
+from mindspore.communication.management import get_group_size
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 
 from src.callback import EvalCallBack, LossCallBack
@@ -276,12 +276,12 @@ class TrainStepWrap(nn.Cell):
 
         self.reducer_flag = False
         self.grad_reducer = None
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = get_auto_parallel_context("parallel_mode")
         if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
-            mean = _get_gradients_mean()
-            degree = _get_device_num()
+            mean = get_auto_parallel_context("gradients_mean")
+            degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
 
     def construct(self, batch_ids, batch_wts, label):
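The last file's change, pulled out on its own: whether gradients get wrapped in a DistributedGradReducer, and with what mean/degree, now comes entirely from public APIs. build_grad_reducer below is a hypothetical helper illustrating that pattern; get_group_size() assumes distributed communication has already been initialised with init().

from mindspore.context import ParallelMode, get_auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer


def build_grad_reducer(optimizer):
    """Return a DistributedGradReducer in data/hybrid parallel mode, otherwise None."""
    parallel_mode = get_auto_parallel_context("parallel_mode")  # replaces _get_parallel_mode()
    if parallel_mode not in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
        return None
    mean = get_auto_parallel_context("gradients_mean")          # replaces _get_gradients_mean()
    degree = get_group_size()                                   # replaces _get_device_num()
    return DistributedGradReducer(optimizer.parameters, mean, degree)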