forked from mindspore-Ecosystem/mindspore
!1382 support create new item when dictionary set item key not exist
Merge pull request !1382 from amongo/DictionarySupportSetNewItem
This commit is contained in:
commit
906389b3a7
|
@ -24,6 +24,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "./common.h"
|
#include "./common.h"
|
||||||
#include "debug/trace.h"
|
#include "debug/trace.h"
|
||||||
|
#include "operator/composite/composite.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
/* namespace to support opt */
|
/* namespace to support opt */
|
||||||
|
@ -171,8 +172,10 @@ AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
|
||||||
count++;
|
count++;
|
||||||
}
|
}
|
||||||
if (IntToSize(count) >= cmap.size()) {
|
if (IntToSize(count) >= cmap.size()) {
|
||||||
MS_LOG(EXCEPTION) << "dictionary assignment key " << cons_str
|
// for dictionary set, if the key does not exist, we should create a new item
|
||||||
<< " does not exist, can not create new dictionary item for now.";
|
auto tuple_add_op = std::make_shared<prim::TupleAdd>("tuple_add");
|
||||||
|
auto tuple_new_item = node->func_graph()->NewCNode({NewValueNode(prim::kPrimMakeTuple), item_value});
|
||||||
|
return node->func_graph()->NewCNode({NewValueNode(tuple_add_op), data, tuple_new_item});
|
||||||
}
|
}
|
||||||
auto idx_c = NewValueNode(count);
|
auto idx_c = NewValueNode(count);
|
||||||
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
|
AbstractBasePtr aptr = std::make_shared<AbstractScalar>(std::make_shared<Int32Imm>(count));
|
||||||
|
|
|
@ -153,3 +153,19 @@ def test_dict_set_item():
|
||||||
x = Tensor(np.ones([2, 2, 3], np.float32))
|
x = Tensor(np.ones([2, 2, 3], np.float32))
|
||||||
net = DictSetNet()
|
net = DictSetNet()
|
||||||
out = net(x)
|
out = net(x)
|
||||||
|
|
||||||
|
|
||||||
|
# if the dictionary item does not exist, create a new one
|
||||||
|
def test_dict_set_item_create_new():
|
||||||
|
class DictSetNet(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(DictSetNet, self).__init__()
|
||||||
|
self.attrs = ("abc", "edf", "ghi", "jkl")
|
||||||
|
def construct(self, x):
|
||||||
|
my_dict = {"def": x}
|
||||||
|
for i in range(len(self.attrs)):
|
||||||
|
my_dict[self.attrs[i]] = x - i
|
||||||
|
return my_dict
|
||||||
|
x = Tensor(np.ones([2, 2, 3], np.float32))
|
||||||
|
net = DictSetNet()
|
||||||
|
out = net(x)
|
||||||
|
|
Loading…
Reference in New Issue