DDE issue:set dictionary element flag

This commit is contained in:
lanzhineng 2022-05-18 17:37:00 +08:00
parent dafae9e76e
commit 56e478e4f9
2 changed files with 50 additions and 4 deletions

View File

@ -730,13 +730,21 @@ void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new
SetSequenceElementsUseFlags(abs, new_flag);
// Check its elements if it's sequence node.
// Check its elements if it's a sequence node.
auto sequence_abs = dyn_cast<abstract::AbstractSequence>(abs);
if (sequence_abs == nullptr) {
if (sequence_abs != nullptr) {
for (auto &element : sequence_abs->elements()) {
SetSequenceElementsUseFlagsRecursively(element, new_flag);
}
return;
}
for (auto &element : sequence_abs->elements()) {
SetSequenceElementsUseFlagsRecursively(element, new_flag);
// Check its elements if it's a dictionary node.
auto dictionary_abs = dyn_cast<abstract::AbstractDictionary>(abs);
if (dictionary_abs != nullptr) {
for (auto &element : dictionary_abs->elements()) {
SetSequenceElementsUseFlagsRecursively(element.second, new_flag);
}
}
}
} // namespace mindspore

View File

@ -18,6 +18,7 @@ import mindspore.context as context
from mindspore import Tensor, dtype
from mindspore.nn import Cell
import mindspore.ops.operations as P
import numpy as np
@pytest.mark.skip(reason='Not support list as parameter in while function yet')
@ -83,3 +84,40 @@ def test_for_list():
x2 = Tensor([[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]], dtype.float32)
net = Net()
print(net(x1, x2))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dictionary_list():
"""
Feature: dictionary list.
Description: Infer list in dictionary.
Expectation: Null.
"""
class D3rNet(Cell):
def __init__(self):
super().__init__()
self.a = Tensor(np.random.randn(300, 9).astype(np.float32))
self.b = Tensor(np.random.randn(300).astype(np.float32))
self.c = Tensor(np.ones([300]).astype(np.int32))
def construct(self, a, b, c):
a_o = a * self.a
b_o = b * self.b
c_o = c * self.c
pts = [[a_o, b_o, c_o]]
bbox = []
for i in pts:
bbox.append({"ptx:": i})
return bbox
a = Tensor(np.random.randn(300, 9).astype(np.float32))
b = Tensor(np.random.randn(300).astype(np.float32))
c = Tensor(np.ones([300]).astype(np.int32))
net = D3rNet()
bbox = net(a, b, c)
print(bbox)