forked from mindspore-Ecosystem/mindspore

fix some bug in quant debug

parent 3d3b9d5474
commit 2a695cfe24
@@ -15,7 +15,7 @@
 """grad impl."""
 from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
-    grad_math_ops, grad_nn_ops, grad_other_ops
+    grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops
 from .grad_base import get_bprop_fn
 
 __all__ = ['get_bprop_fn']
 
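This hunk (the grad package's __init__.py) pulls the new grad_quant_ops module into the package import, so its gradient registrations run at import time and get_bprop_fn can resolve bprops for the quantization-aware-training primitives. A minimal sketch of the registration pattern the module presumably follows, matching its sibling grad_*_ops modules (the bprop_getters registry name, the import paths, the zeros_like helper, and the bprop signature are assumptions, not confirmed by this diff):

    # Sketch: registering a bprop for a quant primitive.
    from mindspore.ops._grad.grad_base import bprop_getters
    from mindspore.ops.operations import _quant_ops as Q
    from mindspore.ops.composite.multitype_ops.zeros_like_impl import zeros_like

    @bprop_getters.register(Q.BatchNormFold)
    def get_bprop_batchnorm_fold(self):
        """Build the bprop closure for BatchNormFold."""
        grad_op = Q.BatchNormFoldGrad(self.epsilon, self.is_training, self.freeze_bn)

        def bprop(x, mean, variance, global_step, out, dout):
            dx = grad_op(dout[1], dout[2], x, out[1], out[2], global_step)
            return dx, zeros_like(mean), zeros_like(variance), zeros_like(global_step)

        return bprop

Once the module is imported this way, a get_bprop_fn lookup on the primitive resolves to the registered closure.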
@@ -223,8 +223,8 @@ class BatchNormFold(PrimitiveWithInfer):
 
     Args:
         momentum (float): Momentum value should be [0, 1]. Default: 0.1.
-        epsilon (float): A small float number to avoid dividing by 0. 1e-12 if dtype in
-            float32 else 1e-3. Default: 1e-12.
+        epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
+            float32 else 1e-3. Default: 1e-5.
         is_training (bool): In training mode set True, else set False. Default: True.
         freeze_bn (int): Delay in steps at which computation switches from regular batch
             norm to frozen mean and std. Default: 0.
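The docstring fix tracks a real numeric constraint: 1e-12 is below the smallest positive float16 (subnormals bottom out near 6e-8), so in half precision the old epsilon underflowed to exactly zero and gave no protection against division by a near-zero variance. A quick check, using NumPy only for its float16 type:

    import numpy as np

    print(np.float16(1e-12))  # 0.0        -> old default underflows in half precision
    print(np.float16(1e-5))   # ~1.001e-05 -> new default is still representable
    print(np.float16(1e-3))   # ~0.0009995 -> the value documented for non-float32 dtypes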
@@ -247,7 +247,7 @@ class BatchNormFold(PrimitiveWithInfer):
     channel = 1
 
     @prim_attr_register
-    def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0):
+    def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0):
         """init batch norm fold layer"""
         self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
         self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
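With this change the constructor default finally matches the documented 1e-5. A usage sketch (the _quant_ops import path is an assumption based on the package layout of that era):

    from mindspore.ops.operations import _quant_ops as Q

    bn_fold = Q.BatchNormFold()                # epsilon now defaults to 1e-5
    bn_fold_frozen = Q.BatchNormFold(momentum=0.9,
                                     is_training=True,
                                     freeze_bn=10000)  # freeze stats after step 10000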
@@ -277,7 +277,7 @@ class BatchNormFoldGrad(PrimitiveWithInfer):
     channel = 1
 
     @prim_attr_register
-    def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0):
+    def __init__(self, epsilon=1e-5, is_training=True, freeze_bn=0):
         """init BatchNormGrad layer"""
         self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name)
         self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name)
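The validator calls in these two constructors spell out the accepted inputs: check_number_range with Rel.INC_BOTH accepts momentum in the closed interval [0, 1], check_float_positive rejects epsilon <= 0, and check_value_type enforces parameter types. Assuming those semantics, each of the following would raise at construction time:

    Q.BatchNormFold(momentum=1.5)        # momentum outside [0, 1]
    Q.BatchNormFold(epsilon=0.0)         # epsilon must be strictly positive
    Q.BatchNormFoldGrad(freeze_bn=1.0)   # freeze_bn must be an int, not a float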
@@ -32,6 +32,7 @@ __all__ = ["build_train_network"]
 
 class OutputTo16(nn.Cell):
     "Wrap cell for amp. Cast network output back to float16"
 
     def __init__(self, op):
         super(OutputTo16, self).__init__(auto_prefix=False)
         self._op = op
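Only the wrapper's constructor appears in this hunk; such an amp wrapper is typically completed by a construct method that casts the wrapped cell's output, along these lines (a sketch, assuming F and mstype are the usual mindspore.ops.functional and mindspore.common.dtype aliases in amp.py):

    def construct(self, x):
        # Run the wrapped cell, then force its output dtype to float16.
        return F.cast(self._op(x), mstype.float16)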
@@ -53,7 +54,7 @@ def _do_keep_batchnorm_fp32(network):
             change = True
         else:
             _do_keep_batchnorm_fp32(subcell)
     if isinstance(network, nn.SequentialCell) and change:
         network.cell_list = list(network.cells())
 
 
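These context lines are the tail of a recursive walk that keeps batch-norm cells in float32 under mixed precision: batch-norm sub-cells get re-wrapped, everything else is recursed into, and a SequentialCell refreshes its cached cell_list if any child was swapped. A condensed sketch of how the whole function plausibly reads (the _OutputTo32 wrapper and the exact cell types are assumptions):

    def _do_keep_batchnorm_fp32(network):
        """Keep batch-norm sub-cells running in float32 under amp."""
        change = False
        for name, subcell in network.name_cells().items():
            if subcell == network:
                continue                      # guard against self-reference
            if isinstance(subcell, (nn.BatchNorm1d, nn.BatchNorm2d)):
                network._cells[name] = _OutputTo32(subcell.to_float(mstype.float32))
                change = True
            else:
                _do_keep_batchnorm_fp32(subcell)
        # SequentialCell caches its children, so rebuild the cache after swapping.
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())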
@@ -72,7 +73,7 @@ def _check_kwargs(key_words):
     """Check kwargs."""
     for arg in key_words:
         if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
             raise ValueError(f"Unsupported arg '{arg}'")
 
     if 'cast_model_type' in key_words:
         validator.check_type_name('cast_model_type', key_words['cast_model_type'],
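So build_train_network accepts exactly three keyword arguments, and cast_model_type is additionally checked against an allow-list of dtypes. Illustrative calls (direct use of the private helper is for illustration only; the mstype alias is assumed):

    _check_kwargs({'cast_model_type': mstype.float16})  # passes
    _check_kwargs({'loss_scale_manager': None})         # passes
    _check_kwargs({'optimizer_cls': None})              # ValueError: Unsupported arg 'optimizer_cls'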
@@ -18,4 +18,16 @@ set -e
 
 # Usage : get_shape_from_ir.sh ir_file
 
-cat "$1" | perl -p -e 's/\n/NEWLINE/' | sed 's/NEWLINE :/:/g' | sed 's/Tensor NEWLINEshape//g' | perl -p -e 's/NEWLINE/\n/g' | perl -p -e 's/<Array\[([\d\w]+)\]x\[[\w ]+\](\[[\d, ]*\])>/\2/g' | perl -p -e 's/<Tuple\[([\[\]\d\w\.\*]*)\]>/Tuple/g' | perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' | perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' | tr -d '()' | awk '/subgraph/{p=1;next}{if(p){print}}'| awk '/return/{p=1;next}{if(!p){print}}' | sed '/^$/d' | awk -F'\t' '{print $1"\t"$2"\t"$4"\t"$3}'
+cat "$1" | perl -p -e 's/\n/NEWLINE/' \
+    | sed 's/NEWLINE :/:/g' \
+    | sed 's/Tensor NEWLINEshape//g' \
+    | perl -p -e 's/NEWLINE/\n/g' \
+    | perl -p -e 's/<Array\[([\d\w]+)\]x\[[\w ]+\](\[[\d, ]*\])>/\2/g' \
+    | perl -p -e 's/<Tuple\[([\[\]\d\w\.\*]*)\]>/Tuple/g' \
+    | perl -p -e 's/ \%(\d+)\(.*= /\1\t/g' \
+    | perl -p -e 's/\(.*\)( \{.*\})*:/\t\1\t/g' \
+    | tr -d '()' \
+    | awk '/subgraph/{p=1;next}{if(p){print}}' \
+    | awk '/return/{p=1;next}{if(!p){print}}' \
+    | sed '/^$/d' \
+    | awk -F'\t' '{print $1"\t"$2"\t"$4}'
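Besides splitting the pipeline into one stage per line, the last awk stage changes the output format: the old command printed four tab-separated fields in 1-2-4-3 order, while the new one drops field 3 and prints three columns per IR node. A hypothetical consumer of the new three-column output (the column meanings, node id / operator / shape, are assumptions drawn from the script's name and usage comment):

    import subprocess

    out = subprocess.run(["bash", "get_shape_from_ir.sh", "graph.ir"],
                         capture_output=True, text=True, check=True).stdout
    for line in out.splitlines():
        node_id, op_name, shape = line.split("\t")  # three columns after this change
        print(node_id, op_name, shape)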