get keys and values from dictionary & set tuple to dictionary

This commit is contained in:
simson 2020-10-29 15:13:49 +08:00
parent d2b1e783e7
commit 3b21822824
8 changed files with 140 additions and 0 deletions

View File

@ -304,6 +304,13 @@ AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
return inputs[2];
}
AnfNodePtr EraseDictGetValues(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &inputs = node->inputs();
MS_ASSERT(inputs.size() == 2 && "DictGetValues should have two inputs");
return inputs[1];
}
AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
const auto &inputs = node->inputs();
@ -374,6 +381,8 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
new_node = ConvertDictGetItemToTupleGetItem(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimDictSetItem)) {
new_node = ConvertDictSetItemToTupleSetItem(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimDictGetValues)) {
new_node = EraseDictGetValues(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeDict)) {
new_node = EraseMakeDictNode(cnode);
} else if (IsPrimitiveCNode(node, prim::kPrimMakeKeywordArg)) {

View File

@ -141,6 +141,8 @@ BuiltInTypeMap &GetMethodMap() {
{"__len__", prim::kPrimDictLen}, // P.dict_len
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
{"__bool__", std::string("dict_bool")} // C.dict_bool
}},
{kObjectTypeTensorType,

View File

@ -131,6 +131,10 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -249,6 +249,32 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
return std::make_shared<AbstractDictionary>(dict_elems);
}
AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
AbstractBasePtrList keys;
std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(keys),
[](const AbstractAttribute &item) { return std::make_shared<AbstractScalar>(item.first); });
return std::make_shared<AbstractTuple>(keys);
}
AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a dict.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
AbstractBasePtrList values;
std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(values),
[](const AbstractAttribute &item) { return item.second; });
return std::make_shared<AbstractTuple>(values);
}
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a list and an object of a subclass of AbstractBase.

View File

@ -72,6 +72,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimListSetItem, {InferImplListSetItem, true}},
{prim::kPrimDictGetItem, {InferImplDictGetItem, true}},
{prim::kPrimDictSetItem, {InferImplDictSetItem, true}},
{prim::kPrimDictGetKeys, {InferImplDictGetKeys, true}},
{prim::kPrimDictGetValues, {InferImplDictGetValues, true}},
{prim::kPrimListAppend, {InferImplListAppend, true}},
{prim::kPrimTupleLen, {InferImplTupleLen, true}},
{prim::kPrimListLen, {InferImplListLen, true}},

View File

@ -279,6 +279,8 @@ inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_g
inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
inline const PrimitivePtr kPrimDictGetKeys = std::make_shared<Primitive>("dict_getkeys");
inline const PrimitivePtr kPrimDictGetValues = std::make_shared<Primitive>("dict_getvalues");
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");

View File

@ -132,6 +132,20 @@ def _dict_setitem_with_number(data, key, value):
"""
return F.dict_setitem(data, key, value)
@setitem.register("Dictionary", "String", "Tuple")
def _dict_setitem_with_tuple(data, key, value):
"""
Assigns value to dictionary.
Inputs:
data (dict): Data of type dict.
key (str): Key of the data.
value (Tuple): Value given.
Outputs:
dict, type is as same as the element type of data.
"""
return F.dict_setitem(data, key, value)
@setitem.register("Tensor", "Tensor", "Tensor")
def _tensor_setitem_by_tensor_with_tensor(data, index, value_tensor):

View File

@ -0,0 +1,81 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_dictionary """
import numpy as np
from mindspore import Tensor
from mindspore.nn import Cell
class Net1(Cell):
def __init__(self):
super().__init__()
def construct(self, x):
dic = {'x': 0, 'y': 1}
output = []
for i in dic.keys():
output.append(i)
for j in dic.values():
output.append(j)
return output
class Net2(Cell):
def __init__(self):
super().__init__()
def construct(self, x):
dic = {'x': x, 'y': 1}
output = []
for i in dic.keys():
output.append(i)
for j in dic.values():
output.append(j)
return output
class Net3(Cell):
def __init__(self):
super().__init__()
def construct(self, x):
dic = {'x': 0}
dic['y'] = (0, 1)
output = []
for i in dic.keys():
output.append(i)
for j in dic.values():
output.append(j)
return output
def test_dict1():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
net = Net1()
out_me = net(input_me)
assert out_me == ('x', 'y', 0, 1)
def test_dict2():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
net = Net2()
net(input_me)
def test_dict3():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
net = Net3()
out_me = net(input_me)
assert out_me == ('x', 'y', 0, (0, 1))