diff --git a/mindspore/ops/_grad/__init__.py b/mindspore/ops/_grad/__init__.py index 9cf4104e5a1..c2c837b89c2 100644 --- a/mindspore/ops/_grad/__init__.py +++ b/mindspore/ops/_grad/__init__.py @@ -15,7 +15,7 @@ """grad impl.""" from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \ - grad_math_ops, grad_nn_ops, grad_other_ops + grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops from .grad_base import get_bprop_fn __all__ = ['get_bprop_fn'] diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 4c7d64b581b..2d07549fd05 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -223,8 +223,8 @@ class BatchNormFold(PrimitiveWithInfer): Args: momentum (float): Momentum value should be [0, 1]. Default: 0.1. - epsilon (float): A small float number to avoid dividing by 0. 1e-12 if dtype in - float32 else 1e-3. Default: 1e-12. + epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in + float32 else 1e-3. Default: 1e-5. is_training (bool): In training mode set True, else set False. Default: True. freeze_bn (int): Delay in steps at which computation switches from regular batch norm to frozen mean and std. Default: 0. 
@@ -247,7 +247,7 @@ class BatchNormFold(PrimitiveWithInfer): channel = 1 @prim_attr_register - def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0): + def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0): """init batch norm fold layer""" self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) @@ -277,7 +277,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer): channel = 1 @prim_attr_register - def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0): + def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0): """init BatchNormGrad layer""" self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py index 2e758b0e9dd..f83f9606ead 100644 --- a/mindspore/train/amp.py +++ b/mindspore/train/amp.py @@ -32,6 +32,7 @@ __all__ = ["build_train_network"] class OutputTo16(nn.Cell): "Wrap cell for amp. 
Cast network output back to float16" + def __init__(self, op): super(OutputTo16, self).__init__(auto_prefix=False) self._op = op @@ -53,7 +54,7 @@ def _do_keep_batchnorm_fp32(network): change = True else: _do_keep_batchnorm_fp32(subcell) - if isinstance(network, nn.SequentialCell) and change: + if isinstance(network, nn.SequentialCell) and change: network.cell_list = list(network.cells()) @@ -72,7 +73,7 @@ def _check_kwargs(key_words): """Check kwargs.""" for arg in key_words: if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']: - raise ValueError(f"Unsupported arg '{arg}'") + raise ValueError(f"Unsupported arg '{arg}'") if 'cast_model_type' in key_words: validator.check_type_name('cast_model_type', key_words['cast_model_type'], diff --git a/scripts/get_shape_from_ir.sh b/scripts/get_shape_from_ir.sh index 8e603add4ab..739ffca56e3 100755 --- a/scripts/get_shape_from_ir.sh +++ b/scripts/get_shape_from_ir.sh @@ -18,4 +18,16 @@ set -e # Usage : get_shape_from_ir.sh ir_file -cat "$1" | perl -p -e 's/\n/NEWLINE/' | sed 's/NEWLINE :/:/g' | sed 's/Tensor NEWLINEshape//g' | perl -p -e 's/NEWLINE/\n/g' | perl -p -e 's//\2/g' | perl -p -e 's//Tuple/g' | perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' | perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' | tr -d '()' | awk '/subgraph/{p=1;next}{if(p){print}}'| awk '/return/{p=1;next}{if(!p){print}}' | sed '/^$/d' | awk -F'\t' '{print $1"\t"$2"\t"$4"\t"$3}' +cat "$1" | perl -p -e 's/\n/NEWLINE/' \ | sed 's/NEWLINE :/:/g' \ | sed 's/Tensor NEWLINEshape//g' \ | perl -p -e 's/NEWLINE/\n/g' \ | perl -p -e 's//\2/g' \ | perl -p -e 's//Tuple/g' \ | perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' \ | perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' \ | tr -d '()' \ | awk '/subgraph/{p=1;next}{if(p){print}}'\ | awk '/return/{p=1;next}{if(!p){print}}' \ | sed '/^$/d' \ | awk -F'\t' '{print $1"\t"$2"\t"$4"\t"$3}'