forked from mindspore-Ecosystem/mindspore
!6653 fix stream sync error for mixed precesion on pynative mode
Merge pull request !6653 from chujinjin/fix_stream_sync_error_for_mixed_precision
This commit is contained in:
commit
f72f2c22fb
|
@ -260,7 +260,7 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
|
|||
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
|
||||
auto tensor = py::cast<tensor::TensorPtr>(obj);
|
||||
auto cast_type = tensor->cast_dtype();
|
||||
py::object cast_output;
|
||||
py::object cast_output = obj;
|
||||
if (cast_type != nullptr) {
|
||||
auto source_element = tensor->Dtype();
|
||||
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
|
||||
|
@ -282,6 +282,8 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
|
|||
result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]);
|
||||
} else if (py::isinstance<py::tuple>(tuple[i])) {
|
||||
result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]);
|
||||
} else {
|
||||
result[i] = tuple[i];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
|
|
@ -609,6 +609,13 @@ class FusedBatchNorm(Primitive):
|
|||
>>> op = P.FusedBatchNorm()
|
||||
>>> output = op(input_x, scale, bias, mean, variance)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('input_x', dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
|
||||
|
|
|
@ -1394,11 +1394,6 @@ test_case_nn_ops = [
|
|||
'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
|
||||
'desc_bprop': [[2, 16], [16], [16]],
|
||||
'skip': ['backward']}),
|
||||
('FusedBatchNorm', {
|
||||
'block': P.FusedBatchNorm(),
|
||||
'desc_inputs': [[128, 64, 32, 64], [64], [64], [64], [64]],
|
||||
'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
|
||||
'skip': []}),
|
||||
('FusedBatchNormGrad', {
|
||||
'block': G.FusedBatchNormGrad(),
|
||||
'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]],
|
||||
|
|
Loading…
Reference in New Issue