forked from mindspore-Ecosystem/mindspore
fix bn cast bug
This commit is contained in:
parent
00f25c8409
commit
a6959c2a13
|
@ -75,7 +75,7 @@ PrimitivePy::~PrimitivePy() {
|
|||
void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; }
|
||||
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
|
||||
signatures_ = signatures;
|
||||
set_has_signature(true);
|
||||
set_has_signature(!signatures.empty());
|
||||
}
|
||||
|
||||
py::function PrimitivePy::GetBpropFunction() {
|
||||
|
|
|
@ -1303,8 +1303,18 @@ class BatchNorm(PrimitiveWithInfer):
|
|||
[ 1.00000000e+00, 1.00000000e+00]))
|
||||
"""
|
||||
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('input_x', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3),
|
||||
sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3)
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
|
||||
if is_training is False:
|
||||
self.set_signatures(tuple())
|
||||
validator.check_value_type('is_training', is_training, (bool,), self.name)
|
||||
validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
|
||||
validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
|
||||
|
|
|
@ -129,7 +129,7 @@ def test_sit_auto_mix_precision_model_o0():
|
|||
model.train(1, dataset1, dataset_sink_mode=False)
|
||||
contend = read_validateir_file('./test_amp_o0')
|
||||
castnum = re.findall("Cast", contend)
|
||||
assert len(castnum) == 17
|
||||
assert len(castnum) == 5
|
||||
model.predict(Tensor(input_data))
|
||||
contend = read_validateir_file('./test_amp_o0')
|
||||
castnum = re.findall("Cast", contend)
|
||||
|
|
|
@ -109,8 +109,8 @@ class FusedBatchNorm(nn.Cell):
|
|||
self.bn_train(x,
|
||||
self.gamma,
|
||||
self.beta,
|
||||
None,
|
||||
None)
|
||||
self.moving_mean,
|
||||
self.moving_variance)
|
||||
|
||||
mean_sub = self.sub_mean(self.moving_mean, batch_mean)
|
||||
temp_mean = self.mul_mean(mean_sub, self.momentum)
|
||||
|
|
Loading…
Reference in New Issue