forked from mindspore-Ecosystem/mindspore
!2631 Add ExpandDims to whitelist
Merge pull request !2631 from amongo/FixExpandDimsOps
This commit is contained in:
commit
bc30576ac9
|
@ -74,6 +74,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) {
|
|||
{prim::kPrimApplyRMSProp, {6, 7, 8}},
|
||||
{prim::kPrimCumSum, {2}},
|
||||
{prim::kPrimTile, {2}},
|
||||
{prim::kPrimExpandDims, {2}},
|
||||
{prim::kPrimHistogramSummary, {1}}});
|
||||
for (auto &item : white_list) {
|
||||
auto matched = std::any_of(item.second.begin(), item.second.end(), [&item, &node, &index](size_t idx) {
|
||||
|
|
|
@ -30,6 +30,8 @@ class ControlDepend(Primitive):
|
|||
tells the engine that the destination operations should depend on the source operation which means the source
|
||||
operations should be executed before the destination.
|
||||
|
||||
Note:
|
||||
This operation does not work in `PYNATIVE_MODE`.
|
||||
Args:
|
||||
depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0.
|
||||
|
||||
|
|
|
@ -19,6 +19,8 @@ import pytest
|
|||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common import dtype as ms
|
||||
from mindspore.common.api import _executor
|
||||
|
||||
|
||||
|
@ -116,3 +118,28 @@ def test_parser_map_0002():
|
|||
net = NetMap0002()
|
||||
with pytest.raises(TypeError):
|
||||
net(input_me_x)
|
||||
|
||||
|
||||
def test_fix_expanddims_loss_scale():
|
||||
class ControlOneIfOneScaleOneScale(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.op = P.ExpandDims()
|
||||
|
||||
def construct(self, x, y, data):
|
||||
if x > y:
|
||||
out = 1
|
||||
else:
|
||||
out = 2
|
||||
if x > y:
|
||||
out = self.op(data, out)
|
||||
else:
|
||||
out = self.op(data, out)
|
||||
return out
|
||||
net = ControlOneIfOneScaleOneScale()
|
||||
x = Tensor(1, ms.float32)
|
||||
y = Tensor(0, ms.float32)
|
||||
input_shape = (1024, 512, 7, 7)
|
||||
input_data = np.random.randn(*input_shape).astype(np.float32)
|
||||
net = ControlOneIfOneScaleOneScale()
|
||||
net(x, y, Tensor(input_data))
|
||||
|
|
Loading…
Reference in New Issue