diff --git a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/optimizer/irpass/branch_culling.cc index 949b64bfa6e..726f4a28b01 100644 --- a/mindspore/ccsrc/optimizer/irpass/branch_culling.cc +++ b/mindspore/ccsrc/optimizer/irpass/branch_culling.cc @@ -74,6 +74,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { {prim::kPrimApplyRMSProp, {6, 7, 8}}, {prim::kPrimCumSum, {2}}, {prim::kPrimTile, {2}}, + {prim::kPrimExpandDims, {2}}, {prim::kPrimHistogramSummary, {1}}}); for (auto &item : white_list) { auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) { diff --git a/mindspore/ops/operations/control_ops.py b/mindspore/ops/operations/control_ops.py index 3e980ce5703..7ef6da2e913 100644 --- a/mindspore/ops/operations/control_ops.py +++ b/mindspore/ops/operations/control_ops.py @@ -30,6 +30,8 @@ class ControlDepend(Primitive): tells the engine that the destination operations should depend on the source operation which means the source operations should be executed before the destination. + Note: + This operation does not work in `PYNATIVE_MODE`. Args: depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0. diff --git a/tests/ut/python/pipeline/parse/test_fix_bug.py b/tests/ut/python/pipeline/parse/test_fix_bug.py index 9b013f95a4e..59e5fdd5de6 100644 --- a/tests/ut/python/pipeline/parse/test_fix_bug.py +++ b/tests/ut/python/pipeline/parse/test_fix_bug.py @@ -19,6 +19,8 @@ import pytest import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.common import dtype as ms from mindspore.common.api import _executor @@ -116,3 +118,28 @@ def test_parser_map_0002(): net = NetMap0002() with pytest.raises(TypeError): net(input_me_x) + + +def test_fix_expanddims_loss_scale(): + class ControlOneIfOneScaleOneScale(nn.Cell): + def __init__(self): + super().__init__() + self.op = P.ExpandDims() + + def construct(self, x, y, data): + if x > y: + out = 1 + else: + out = 2 + if x > y: + out = self.op(data, out) + else: + out = self.op(data, out) + return out + net = ControlOneIfOneScaleOneScale() + x = Tensor(1, ms.float32) + y = Tensor(0, ms.float32) + input_shape = (1024, 512, 7, 7) + input_data = np.random.randn(*input_shape).astype(np.float32) + net = ControlOneIfOneScaleOneScale() + net(x, y, Tensor(input_data))