fix bn cast bug

This commit is contained in:
caifubi 2021-03-02 14:16:45 +08:00
parent 00f25c8409
commit a6959c2a13
4 changed files with 14 additions and 4 deletions

View File

@ -75,7 +75,7 @@ PrimitivePy::~PrimitivePy() {
void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; }
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
signatures_ = signatures;
set_has_signature(true);
set_has_signature(!signatures.empty());
}
py::function PrimitivePy::GetBpropFunction() {

View File

@ -1303,8 +1303,18 @@ class BatchNorm(PrimitiveWithInfer):
[ 1.00000000e+00, 1.00000000e+00]))
"""
__mindspore_signature__ = (
sig.make_sig('input_x', dtype=sig.sig_dtype.T1),
sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3),
sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3)
)
@prim_attr_register
def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
if is_training is False:
self.set_signatures(tuple())
validator.check_value_type('is_training', is_training, (bool,), self.name)
validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)

View File

@ -129,7 +129,7 @@ def test_sit_auto_mix_precision_model_o0():
model.train(1, dataset1, dataset_sink_mode=False)
contend = read_validateir_file('./test_amp_o0')
castnum = re.findall("Cast", contend)
assert len(castnum) == 17
assert len(castnum) == 5
model.predict(Tensor(input_data))
contend = read_validateir_file('./test_amp_o0')
castnum = re.findall("Cast", contend)

View File

@ -109,8 +109,8 @@ class FusedBatchNorm(nn.Cell):
self.bn_train(x,
self.gamma,
self.beta,
None,
None)
self.moving_mean,
self.moving_variance)
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
temp_mean = self.mul_mean(mean_sub, self.momentum)