forked from mindspore-Ecosystem/mindspore
DDE issue:set dictionary element flag
This commit is contained in:
parent
dafae9e76e
commit
56e478e4f9
|
@ -730,13 +730,21 @@ void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new
|
|||
|
||||
SetSequenceElementsUseFlags(abs, new_flag);
|
||||
|
||||
// Check its elements if it's sequence node.
|
||||
// Check its elements if it's a sequence node.
|
||||
auto sequence_abs = dyn_cast<abstract::AbstractSequence>(abs);
|
||||
if (sequence_abs == nullptr) {
|
||||
if (sequence_abs != nullptr) {
|
||||
for (auto &element : sequence_abs->elements()) {
|
||||
SetSequenceElementsUseFlagsRecursively(element, new_flag);
|
||||
}
|
||||
return;
|
||||
}
|
||||
for (auto &element : sequence_abs->elements()) {
|
||||
SetSequenceElementsUseFlagsRecursively(element, new_flag);
|
||||
|
||||
// Check its elements if it's a dictionary node.
|
||||
auto dictionary_abs = dyn_cast<abstract::AbstractDictionary>(abs);
|
||||
if (dictionary_abs != nullptr) {
|
||||
for (auto &element : dictionary_abs->elements()) {
|
||||
SetSequenceElementsUseFlagsRecursively(element.second, new_flag);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,7 @@ import mindspore.context as context
|
|||
from mindspore import Tensor, dtype
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops.operations as P
|
||||
import numpy as np
|
||||
|
||||
|
||||
@pytest.mark.skip(reason='Not support list as parameter in while function yet')
|
||||
|
@ -83,3 +84,40 @@ def test_for_list():
|
|||
x2 = Tensor([[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]], dtype.float32)
|
||||
net = Net()
|
||||
print(net(x1, x2))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dictionary_list():
|
||||
"""
|
||||
Feature: dictionary list.
|
||||
Description: Infer list in dictionary.
|
||||
Expectation: Null.
|
||||
"""
|
||||
|
||||
class D3rNet(Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.a = Tensor(np.random.randn(300, 9).astype(np.float32))
|
||||
self.b = Tensor(np.random.randn(300).astype(np.float32))
|
||||
self.c = Tensor(np.ones([300]).astype(np.int32))
|
||||
|
||||
def construct(self, a, b, c):
|
||||
a_o = a * self.a
|
||||
b_o = b * self.b
|
||||
c_o = c * self.c
|
||||
pts = [[a_o, b_o, c_o]]
|
||||
bbox = []
|
||||
for i in pts:
|
||||
bbox.append({"ptx:": i})
|
||||
return bbox
|
||||
|
||||
a = Tensor(np.random.randn(300, 9).astype(np.float32))
|
||||
b = Tensor(np.random.randn(300).astype(np.float32))
|
||||
c = Tensor(np.ones([300]).astype(np.int32))
|
||||
net = D3rNet()
|
||||
bbox = net(a, b, c)
|
||||
print(bbox)
|
||||
|
|
Loading…
Reference in New Issue