forked from mindspore-Ecosystem/mindspore
!6653 fix stream sync error for mixed precesion on pynative mode
Merge pull request !6653 from chujinjin/fix_stream_sync_error_for_mixed_precision
This commit is contained in:
commit
f72f2c22fb
|
@ -260,7 +260,7 @@ py::object DoAutoCast(const py::object &arg, const TypeId &type_id) {
|
||||||
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
|
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) {
|
||||||
auto tensor = py::cast<tensor::TensorPtr>(obj);
|
auto tensor = py::cast<tensor::TensorPtr>(obj);
|
||||||
auto cast_type = tensor->cast_dtype();
|
auto cast_type = tensor->cast_dtype();
|
||||||
py::object cast_output;
|
py::object cast_output = obj;
|
||||||
if (cast_type != nullptr) {
|
if (cast_type != nullptr) {
|
||||||
auto source_element = tensor->Dtype();
|
auto source_element = tensor->Dtype();
|
||||||
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
|
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
|
||||||
|
@ -282,6 +282,8 @@ py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) {
|
||||||
result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]);
|
result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]);
|
||||||
} else if (py::isinstance<py::tuple>(tuple[i])) {
|
} else if (py::isinstance<py::tuple>(tuple[i])) {
|
||||||
result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]);
|
result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]);
|
||||||
|
} else {
|
||||||
|
result[i] = tuple[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
|
|
|
@ -609,6 +609,13 @@ class FusedBatchNorm(Primitive):
|
||||||
>>> op = P.FusedBatchNorm()
|
>>> op = P.FusedBatchNorm()
|
||||||
>>> output = op(input_x, scale, bias, mean, variance)
|
>>> output = op(input_x, scale, bias, mean, variance)
|
||||||
"""
|
"""
|
||||||
|
__mindspore_signature__ = (
|
||||||
|
sig.make_sig('input_x', dtype=sig.sig_dtype.T2),
|
||||||
|
sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||||
|
sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||||
|
sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||||
|
sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
|
||||||
|
)
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
|
def __init__(self, mode=0, epsilon=1e-5, momentum=0.1):
|
||||||
|
|
|
@ -1394,11 +1394,6 @@ test_case_nn_ops = [
|
||||||
'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
|
'desc_inputs': [[2, 16], [2, 16], [2, 16], [2, 16], [16]],
|
||||||
'desc_bprop': [[2, 16], [16], [16]],
|
'desc_bprop': [[2, 16], [16], [16]],
|
||||||
'skip': ['backward']}),
|
'skip': ['backward']}),
|
||||||
('FusedBatchNorm', {
|
|
||||||
'block': P.FusedBatchNorm(),
|
|
||||||
'desc_inputs': [[128, 64, 32, 64], [64], [64], [64], [64]],
|
|
||||||
'desc_bprop': [[128, 64, 32, 64], [64], [64], [64], [64]],
|
|
||||||
'skip': []}),
|
|
||||||
('FusedBatchNormGrad', {
|
('FusedBatchNormGrad', {
|
||||||
'block': G.FusedBatchNormGrad(),
|
'block': G.FusedBatchNormGrad(),
|
||||||
'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]],
|
'desc_inputs': [[128, 64, 32, 64], [128, 64, 32, 64], [64], [64], [64]],
|
||||||
|
|
Loading…
Reference in New Issue