forked from mindspore-Ecosystem/mindspore
fix some bug in quant debug
This commit is contained in:
parent
3d3b9d5474
commit
2a695cfe24
|
@ -15,7 +15,7 @@
|
|||
|
||||
"""grad impl."""
|
||||
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
|
||||
grad_math_ops, grad_nn_ops, grad_other_ops
|
||||
grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops
|
||||
from .grad_base import get_bprop_fn
|
||||
|
||||
__all__ = ['get_bprop_fn']
|
||||
|
|
|
@ -223,8 +223,8 @@ class BatchNormFold(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
momentum (float): Momentum value should be [0, 1]. Default: 0.1.
|
||||
epsilon (float): A small float number to avoid dividing by 0. 1e-12 if dtype in
|
||||
float32 else 1e-3. Default: 1e-12.
|
||||
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
|
||||
float32 else 1e-3. Default: 1e-5.
|
||||
is_training (bool): In training mode set True, else set False. Default: True.
|
||||
freeze_bn (int): Delay in steps at which computation switches from regular batch
|
||||
norm to frozen mean and std. Default: 0.
|
||||
|
@ -247,7 +247,7 @@ class BatchNormFold(PrimitiveWithInfer):
|
|||
channel = 1
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0):
|
||||
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
"""init batch norm fold layer"""
|
||||
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
|
||||
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
|
||||
|
@ -277,7 +277,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
|
|||
channel = 1
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0):
|
||||
def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
|
||||
"""init BatchNormGrad layer"""
|
||||
self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
|
||||
|
|
|
@ -32,6 +32,7 @@ __all__ = ["build_train_network"]
|
|||
|
||||
class OutputTo16(nn.Cell):
|
||||
"Wrap cell for amp. Cast network output back to float16"
|
||||
|
||||
def __init__(self, op):
|
||||
super(OutputTo16, self).__init__(auto_prefix=False)
|
||||
self._op = op
|
||||
|
@ -53,7 +54,7 @@ def _do_keep_batchnorm_fp32(network):
|
|||
change = True
|
||||
else:
|
||||
_do_keep_batchnorm_fp32(subcell)
|
||||
if isinstance(network, nn.SequentialCell) and change:
|
||||
if isinstance(network, nn.SequentialCell) and change:
|
||||
network.cell_list = list(network.cells())
|
||||
|
||||
|
||||
|
@ -72,7 +73,7 @@ def _check_kwargs(key_words):
|
|||
"""Check kwargs."""
|
||||
for arg in key_words:
|
||||
if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
|
||||
raise ValueError(f"Unsupported arg '{arg}'")
|
||||
raise ValueError(f"Unsupported arg '{arg}'")
|
||||
|
||||
if 'cast_model_type' in key_words:
|
||||
validator.check_type_name('cast_model_type', key_words['cast_model_type'],
|
||||
|
|
|
@ -18,4 +18,16 @@ set -e
|
|||
|
||||
# Usage : get_shape_from_ir.sh ir_file
|
||||
|
||||
cat "$1" | perl -p -e 's/\n/NEWLINE/' | sed 's/NEWLINE :/:/g' | sed 's/Tensor NEWLINEshape//g' | perl -p -e 's/NEWLINE/\n/g' | perl -p -e 's/<Array\[([\d\w]+)\]x\[[\w ]+\](\[[\d, ]*\])>/\2/g' | perl -p -e 's/<Tuple\[([\[\]\d\w\.\*]*)\]>/Tuple/g' | perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' | perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' | tr -d '()' | awk '/subgraph/{p=1;next}{if(p){print}}'| awk '/return/{p=1;next}{if(!p){print}}' | sed '/^$/d' | awk -F'\t' '{print $1"\t"$2"\t"$4"\t"$3}'
|
||||
cat "$1" | perl -p -e 's/\n/NEWLINE/' \
|
||||
| sed 's/NEWLINE :/:/g' \
|
||||
| sed 's/Tensor NEWLINEshape//g' \
|
||||
| perl -p -e 's/NEWLINE/\n/g' \
|
||||
| perl -p -e 's/<Array\[([\d\w]+)\]x\[[\w ]+\](\[[\d, ]*\])>/\2/g' \
|
||||
| perl -p -e 's/<Tuple\[([\[\]\d\w\.\*]*)\]>/Tuple/g' \
|
||||
| perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' \
|
||||
| perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' \
|
||||
| tr -d '()' \
|
||||
| awk '/subgraph/{p=1;next}{if(p){print}}'\
|
||||
| awk '/return/{p=1;next}{if(!p){print}}' \
|
||||
| sed '/^$/d' \
|
||||
| awk -F'\t' '{print $1"\t"$2"\t"$4}'
|
||||
|
|
Loading…
Reference in New Issue