From 1cf8f3b77762e62b64fdae33be4ecb26ca520ffd Mon Sep 17 00:00:00 2001 From: chujinjin Date: Mon, 21 Sep 2020 17:32:34 +0800 Subject: [PATCH] fix stream sync error for mixed precision --- mindspore/ccsrc/pipeline/pynative/pynative_execute.cc | 4 +++- mindspore/ops/operations/nn_ops.py | 7 +++++++ tests/ut/python/ops/test_ops.py | 5 ----- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 50684c5d544..8363be959c2 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -260,7 +260,7 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) { py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) { auto tensor = py::cast(obj); auto cast_type = tensor->cast_dtype(); - py::object cast_output; + py::object cast_output = obj; if (cast_type != nullptr) { auto source_element = tensor->Dtype(); if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { @@ -282,6 +282,8 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) { result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]); } else if (py::isinstance(tuple[i])) { result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]); + } else { + result[i] = tuple[i]; } } return result; diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 1c59328c748..76bb558e014 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -609,6 +609,13 @@ class FusedBatchNorm(Primitive): >>> op = P.FusedBatchNorm() >>> output = op(input_x, scale, bias, mean, variance) """ + __mindspore_signature__ = ( + sig.make_sig('input_x', dtype=sig.sig_dtype.T2), + sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + ) @prim_attr_register def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 767539dbfc9..1182840ddc6 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -1394,11 +1394,6 @@ test_case_nn_ops = [ 'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]], 'desc_bprop': [[2, 16], [16], [16]], 'skip': ['backward']}), - ('FusedBatchNorm', { - 'block': P.FusedBatchNorm(), - 'desc_inputs': [[128, 64, 32, 64], [64], [64], [64], [64]], - 'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]], - 'skip': []}), ('FusedBatchNormGrad', { 'block': G.FusedBatchNormGrad(), 'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]],