fix stream sync error for mixed precision

This commit is contained in:
chujinjin 2020-09-21 17:32:34 +08:00
parent 0037afee74
commit 1cf8f3b777
3 changed files with 10 additions and 6 deletions

View File

@ -260,7 +260,7 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
auto tensor = py::cast<tensor::TensorPtr>(obj);
auto cast_type = tensor->cast_dtype();
py::object cast_output;
py::object cast_output = obj;
if (cast_type != nullptr) {
auto source_element = tensor->Dtype();
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
@ -282,6 +282,8 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]);
} else if (py::isinstance<py::tuple>(tuple[i])) {
result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]);
} else {
result[i] = tuple[i];
}
}
return result;

View File

@ -609,6 +609,13 @@ class FusedBatchNorm(Primitive):
>>> op = P.FusedBatchNorm()
>>> output = op(input_x, scale, bias, mean, variance)
"""
__mindspore_signature__ = (
sig.make_sig('input_x', dtype=sig.sig_dtype.T2),
sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
)
@prim_attr_register
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):

View File

@ -1394,11 +1394,6 @@ test_case_nn_ops = [
'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
'desc_bprop': [[2, 16], [16], [16]],
'skip': ['backward']}),
('FusedBatchNorm', {
'block': P.FusedBatchNorm(),
'desc_inputs': [[128, 64, 32, 64], [64], [64], [64], [64]],
'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
'skip': []}),
('FusedBatchNormGrad', {
'block': G.FusedBatchNormGrad(),
'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]],