!2631 Add ExpandDims to whitelist

Merge pull request !2631 from amongo/FixExpandDimsOps
This commit is contained in:
mindspore-ci-bot 2020-06-28 17:40:20 +08:00 committed by Gitee
commit bc30576ac9
3 changed files with 30 additions and 0 deletions

View File

@ -74,6 +74,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
{prim::kPrimApplyRMSProp, {6, 7, 8}},
{prim::kPrimCumSum, {2}},
{prim::kPrimTile, {2}},
{prim::kPrimExpandDims, {2}},
{prim::kPrimHistogramSummary, {1}}});
for (auto &item : white_list) {
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {

View File

@ -30,6 +30,8 @@ class ControlDepend(Primitive):
tells the engine that the destination operations should depend on the source operation which means the source
operations should be executed before the destination.
Note:
This operation does not work in `PYNATIVE_MODE`.
Args:
depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0.

View File

@ -19,6 +19,8 @@ import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as ms
from mindspore.common.api import _executor
@ -116,3 +118,28 @@ def test_parser_map_0002():
net = NetMap0002()
with pytest.raises(TypeError):
net(input_me_x)
def test_fix_expanddims_loss_scale():
class ControlOneIfOneScaleOneScale(nn.Cell):
def __init__(self):
super().__init__()
self.op = P.ExpandDims()
def construct(self, x, y, data):
if x > y:
out = 1
else:
out = 2
if x > y:
out = self.op(data, out)
else:
out = self.op(data, out)
return out
net = ControlOneIfOneScaleOneScale()
x = Tensor(1, ms.float32)
y = Tensor(0, ms.float32)
input_shape = (1024, 512, 7, 7)
input_data = np.random.randn(*input_shape).astype(np.float32)
net = ControlOneIfOneScaleOneScale()
net(x, y, Tensor(input_data))