forked from mindspore-Ecosystem/mindspore
!49589 [1.10] Fix dic.get error
Merge pull request !49589 from huanghui/r1.10-fix-dict-get
This commit is contained in:
commit
cce56aa486
|
@ -0,0 +1,179 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/operator/composite/dict_operation.h"
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
|
||||
#include "abstract/param_validator.h"
|
||||
#include "frontend/optimizer/opt.h"
|
||||
#include "include/common/pybind_api/api_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support composite operators definition
|
||||
namespace prim {
|
||||
FuncGraphPtr DictClear::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||
constexpr size_t dict_clear_args_size = 1;
|
||||
abstract::CheckArgsSize("DictClear", args_list, dict_clear_args_size);
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("clear");
|
||||
(void)ret->add_parameter();
|
||||
|
||||
auto empty_dict = std::vector<std::pair<std::string, ValuePtr>>();
|
||||
ret->set_output(NewValueNode(std::make_shared<ValueDictionary>(empty_dict)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
FuncGraphPtr DictHasKey::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||
constexpr size_t dict_has_key_args_size = 2;
|
||||
abstract::CheckArgsSize("DictHasKey", args_list, dict_has_key_args_size);
|
||||
|
||||
auto dict = dyn_cast<abstract::AbstractDictionary>(args_list[0]);
|
||||
ValuePtr key_value = args_list[1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(dict);
|
||||
MS_EXCEPTION_IF_NULL(key_value);
|
||||
if (!key_value->isa<StringImm>()) {
|
||||
MS_LOG(EXCEPTION) << "The key should be string, but got " << key_value->ToString();
|
||||
}
|
||||
|
||||
auto key_str = GetValue<std::string>(key_value);
|
||||
auto dict_elems = dict->elements();
|
||||
bool has_key = false;
|
||||
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
|
||||
[&key_str](const abstract::AbstractAttribute &item) { return item.first == key_str; });
|
||||
if (it != dict_elems.end()) {
|
||||
has_key = true;
|
||||
}
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("has_key");
|
||||
(void)ret->add_parameter();
|
||||
(void)ret->add_parameter();
|
||||
|
||||
auto out = NewValueNode(MakeValue(has_key));
|
||||
ret->set_output(out);
|
||||
return ret;
|
||||
}
|
||||
|
||||
FuncGraphPtr DictUpdate::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||
constexpr size_t dict_update_args_size = 2;
|
||||
abstract::CheckArgsSize("DictUpdate", args_list, dict_update_args_size);
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("update");
|
||||
|
||||
AnfNodePtrList key_inputs;
|
||||
AnfNodePtrList value_inputs;
|
||||
(void)key_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
(void)value_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
|
||||
std::unordered_map<std::string, size_t> hash_map;
|
||||
AddNodeToLists(args_list[0], ret, &key_inputs, &value_inputs, &hash_map);
|
||||
AddNodeToLists(args_list[1], ret, &key_inputs, &value_inputs, &hash_map);
|
||||
|
||||
ret->set_output(ret->NewCNode(
|
||||
{NewValueNode(prim::kPrimMakeDict), ret->NewCNode(std::move(key_inputs)), ret->NewCNode(std::move(value_inputs))}));
|
||||
return ret;
|
||||
}
|
||||
|
||||
void DictUpdate::AddNodeToLists(const AbstractBasePtr &arg, const FuncGraphPtr &ret, AnfNodePtrList *keys,
|
||||
AnfNodePtrList *values, std::unordered_map<std::string, size_t> *hash_map) {
|
||||
auto dict = dyn_cast<abstract::AbstractDictionary>(arg);
|
||||
MS_EXCEPTION_IF_NULL(dict);
|
||||
auto &dict_elems = dict->elements();
|
||||
auto arg_node = ret->add_parameter();
|
||||
|
||||
for (const auto &elem : dict_elems) {
|
||||
AnfNodePtr dict_value = ret->NewCNode({NewValueNode(prim::kPrimDictGetItem), arg_node, NewValueNode(elem.first)});
|
||||
|
||||
auto map_find = hash_map->find(elem.first);
|
||||
if (map_find == hash_map->end()) {
|
||||
hash_map->insert(std::make_pair(elem.first, values->size()));
|
||||
(void)keys->emplace_back(NewValueNode(elem.first));
|
||||
(void)values->emplace_back(dict_value);
|
||||
} else {
|
||||
values->at(map_find->second) = dict_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr DictFromKeys::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||
constexpr size_t dict_fromkeys_args_size = 3;
|
||||
abstract::CheckArgsSize("DictFromKeys", args_list, dict_fromkeys_args_size);
|
||||
const auto &values = ParseIterableObject(args_list[1]);
|
||||
auto value_node = args_list[2]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value_node);
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("fromkeys");
|
||||
(void)ret->add_parameter();
|
||||
(void)ret->add_parameter();
|
||||
(void)ret->add_parameter();
|
||||
|
||||
std::vector<std::pair<std::string, ValuePtr>> key_values;
|
||||
for (auto &value : values) {
|
||||
auto key = value->BuildValue();
|
||||
if (!key->IsSameTypeId(StringImm::kTypeId)) {
|
||||
MS_LOG(EXCEPTION) << "The key should be string, but got " << key->type_name();
|
||||
}
|
||||
|
||||
std::string key_node = GetValue<std::string>(key);
|
||||
(void)key_values.emplace_back(std::make_pair(key_node, value_node));
|
||||
}
|
||||
|
||||
ret->set_output(NewValueNode(std::make_shared<ValueDictionary>(key_values)));
|
||||
return ret;
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtrList DictFromKeys::ParseIterableObject(const abstract::AbstractBasePtr &arg_key) {
|
||||
auto key_type = arg_key->BuildType();
|
||||
if (key_type->IsSameTypeId(List::kTypeId) || key_type->IsSameTypeId(Tuple::kTypeId)) {
|
||||
abstract::AbstractSequencePtr dict_keys = dyn_cast<abstract::AbstractSequence>(arg_key);
|
||||
MS_EXCEPTION_IF_NULL(dict_keys);
|
||||
return dict_keys->elements();
|
||||
}
|
||||
if (key_type->IsSameTypeId(Dictionary::kTypeId)) {
|
||||
auto dict_keys = dyn_cast<abstract::AbstractDictionary>(arg_key);
|
||||
MS_EXCEPTION_IF_NULL(dict_keys);
|
||||
AbstractBasePtrList keys;
|
||||
auto &dict_elems = dict_keys->elements();
|
||||
(void)std::transform(
|
||||
dict_elems.begin(), dict_elems.end(), std::back_inserter(keys),
|
||||
[](const abstract::AbstractAttribute &item) { return std::make_shared<abstract::AbstractScalar>(item.first); });
|
||||
return keys;
|
||||
}
|
||||
if (key_type->IsSameTypeId(String::kTypeId)) {
|
||||
auto value = arg_key->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
string dict_keys = value->ToString();
|
||||
AbstractBasePtrList keys;
|
||||
(void)std::transform(dict_keys.begin(), dict_keys.end(), std::back_inserter(keys), [](const char &item) {
|
||||
return std::make_shared<abstract::AbstractScalar>(std::string(1, item));
|
||||
});
|
||||
return keys;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << key_type->ToString() << " object is not iterable";
|
||||
}
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,90 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_DICT_OPERATION_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_DICT_OPERATION_H_
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
|
||||
#include "ir/meta_func_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support composite operators definition
|
||||
namespace prim {
|
||||
class DictClear : public MetaFuncGraph {
|
||||
public:
|
||||
explicit DictClear(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~DictClear() override = default;
|
||||
MS_DECLARE_PARENT(DictClear, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
|
||||
friend std::ostream &operator<<(std::ostream &os, const DictClear &dict_clear) {
|
||||
os << dict_clear.name_;
|
||||
return os;
|
||||
}
|
||||
friend bool operator==(const DictClear &lhs, const DictClear &rhs) { return lhs.name_ == rhs.name_; }
|
||||
};
|
||||
using DictClearPtr = std::shared_ptr<DictClear>;
|
||||
|
||||
class DictHasKey : public MetaFuncGraph {
|
||||
public:
|
||||
explicit DictHasKey(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~DictHasKey() override = default;
|
||||
MS_DECLARE_PARENT(DictHasKey, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
|
||||
friend std::ostream &operator<<(std::ostream &os, const DictHasKey &dict_has_key) {
|
||||
os << dict_has_key.name_;
|
||||
return os;
|
||||
}
|
||||
friend bool operator==(const DictHasKey &lhs, const DictHasKey &rhs) { return lhs.name_ == rhs.name_; }
|
||||
};
|
||||
using DictHasKeyPtr = std::shared_ptr<DictHasKey>;
|
||||
|
||||
class DictUpdate : public MetaFuncGraph {
|
||||
public:
|
||||
explicit DictUpdate(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~DictUpdate() override = default;
|
||||
MS_DECLARE_PARENT(DictUpdate, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
|
||||
friend std::ostream &operator<<(std::ostream &os, const DictUpdate &dict_update) {
|
||||
os << dict_update.name_;
|
||||
return os;
|
||||
}
|
||||
friend bool operator==(const DictUpdate &lhs, const DictUpdate &rhs) { return lhs.name_ == rhs.name_; }
|
||||
void AddNodeToLists(const AbstractBasePtr &arg, const FuncGraphPtr &ret, AnfNodePtrList *keys, AnfNodePtrList *values,
|
||||
std::unordered_map<std::string, size_t> *hash_map);
|
||||
};
|
||||
using DictUpdatePtr = std::shared_ptr<DictUpdate>;
|
||||
|
||||
class DictFromKeys : public MetaFuncGraph {
|
||||
public:
|
||||
explicit DictFromKeys(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~DictFromKeys() override = default;
|
||||
MS_DECLARE_PARENT(DictFromKeys, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
|
||||
friend std::ostream &operator<<(std::ostream &os, const DictFromKeys &dict_from_keys) {
|
||||
os << dict_from_keys.name_;
|
||||
return os;
|
||||
}
|
||||
friend bool operator==(const DictFromKeys &lhs, const DictFromKeys &rhs) { return lhs.name_ == rhs.name_; }
|
||||
abstract::AbstractBasePtrList ParseIterableObject(const abstract::AbstractBasePtr &arg_key);
|
||||
};
|
||||
using DictFromKeysPtr = std::shared_ptr<DictFromKeys>;
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_DICT_OPERATION_H_
|
|
@ -18,6 +18,7 @@
|
|||
#include "frontend/operator/composite/composite.h"
|
||||
#include "include/common/pybind_api/api_register.h"
|
||||
#include "frontend/operator/composite/list_operation.h"
|
||||
#include "frontend/operator/composite/dict_operation.h"
|
||||
#include "frontend/operator/composite/map.h"
|
||||
#include "frontend/operator/composite/unpack_call.h"
|
||||
#include "frontend/operator/composite/vmap.h"
|
||||
|
@ -103,6 +104,22 @@ void RegCompositeOpsGroup(const py::module *m) {
|
|||
(void)py::class_<ListCount, MetaFuncGraph, std::shared_ptr<ListCount>>(*m, "ListCount_")
|
||||
.def(py::init<const std::string &>());
|
||||
|
||||
// Reg DictClear
|
||||
(void)py::class_<DictClear, MetaFuncGraph, std::shared_ptr<DictClear>>(*m, "DictClear_")
|
||||
.def(py::init<const std::string &>());
|
||||
|
||||
// Reg DictHasKey
|
||||
(void)py::class_<DictHasKey, MetaFuncGraph, std::shared_ptr<DictHasKey>>(*m, "DictHasKey_")
|
||||
.def(py::init<const std::string &>());
|
||||
|
||||
// Reg DictUpdate
|
||||
(void)py::class_<DictUpdate, MetaFuncGraph, std::shared_ptr<DictUpdate>>(*m, "DictUpdate_")
|
||||
.def(py::init<const std::string &>());
|
||||
|
||||
// Reg DictFromKeys
|
||||
(void)py::class_<DictFromKeys, MetaFuncGraph, std::shared_ptr<DictFromKeys>>(*m, "DictFromKeys_")
|
||||
.def(py::init<const std::string &>());
|
||||
|
||||
// Reg MapPy
|
||||
(void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
|
||||
.def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"), py::arg("ops"))
|
||||
|
|
|
@ -146,15 +146,19 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
}},
|
||||
{kObjectTypeDictionary,
|
||||
{
|
||||
{"__len__", prim::kPrimDictLen}, // P.dict_len
|
||||
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
|
||||
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
|
||||
{"__ms_iter__", prim::kPrimDictGetKeys}, // P.dict_getkeys,
|
||||
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
|
||||
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
|
||||
{"items", prim::kPrimDictItems}, // P.dict_items
|
||||
{"__bool__", std::string("dict_bool")}, // C.dict_bool
|
||||
{"get", std::string("dict_get")} // C.dict_get
|
||||
{"__len__", prim::kPrimDictLen}, // P.dict_len
|
||||
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
|
||||
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
|
||||
{"__ms_iter__", prim::kPrimDictGetKeys}, // P.dict_getkeys,
|
||||
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
|
||||
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
|
||||
{"items", prim::kPrimDictItems}, // P.dict_items
|
||||
{"__bool__", std::string("dict_bool")}, // C.dict_bool
|
||||
{"get", std::string("dict_get")}, // C.dict_get
|
||||
{"has_key", std::string("dict_haskey")}, // C.dict_haskey
|
||||
{"clear", std::string("dict_clear")}, // C.dict_clear
|
||||
{"update", std::string("dict_update")}, // C.dict_update
|
||||
{"fromkeys", std::string("dict_fromkeys")} // C.dict_fromkeys
|
||||
}},
|
||||
{kObjectTypeTensorType,
|
||||
{
|
||||
|
|
|
@ -190,15 +190,13 @@ AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitiveP
|
|||
|
||||
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a dict and a scalar whose value is a string.
|
||||
const std::string op_name = primitive->name();
|
||||
// dict[key] mean the size of args_spec_list is 2.
|
||||
// dict.get('key', default=None) mean the size of args_spec_list is 3.
|
||||
// dict.get('key', default_value=None) mean the size of args_spec_list is 2 too, the key will check in dict_get.
|
||||
constexpr int subscript_args_size = 2;
|
||||
constexpr int dict_get_arg_size = 3;
|
||||
if (args_spec_list.size() != subscript_args_size && args_spec_list.size() != dict_get_arg_size) {
|
||||
MS_LOG(EXCEPTION) << "For '" << op_name << "', the number of input should be " << subscript_args_size << " or "
|
||||
<< dict_get_arg_size << ", but got " << args_spec_list.size();
|
||||
if (args_spec_list.size() != subscript_args_size) {
|
||||
MS_LOG(EXCEPTION) << "For '" << op_name << "', the number of input should be " << subscript_args_size
|
||||
<< ", but got " << args_spec_list.size();
|
||||
}
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
|
@ -214,13 +212,9 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
|
|||
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
|
||||
if (it == dict_elems.end()) {
|
||||
// For dict[key], if key is not exist, will raise a KeyError exception.
|
||||
if (args_spec_list.size() == subscript_args_size) {
|
||||
MS_EXCEPTION(KeyError) << "The key " << key_str
|
||||
<< " does not exist in the dict:" << args_spec_list[0]->ToString();
|
||||
}
|
||||
// For dict.get('key', default=None), if key is not exist, will return the default value.
|
||||
constexpr int default_value_index = 2;
|
||||
return args_spec_list[default_value_index];
|
||||
// For dict.get('key', default=None), if key is not exist, will return the default value during dict_get.
|
||||
MS_EXCEPTION(KeyError) << "The key " << key_value->ToString()
|
||||
<< " does not exist in the dict:" << args_spec_list[0]->BuildValue()->ToString();
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
|
|
@ -93,7 +93,7 @@ SYNTAX_UNSUPPORTED_NAMESPACE = 4 # Unsupported namespace
|
|||
# Process expr statement white list
|
||||
# Add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
|
||||
parse_expr_statement_white_list = (
|
||||
"append", "insert", "clear", "reverse", "extend",
|
||||
"append", "insert", "clear", "reverse", "extend", "update",
|
||||
)
|
||||
|
||||
_builtin_function_or_method_type = type(abs)
|
||||
|
|
|
@ -30,6 +30,8 @@ from ...ops import functional as F
|
|||
from ...ops import operations as P
|
||||
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
|
||||
zeros_like, ones_like, repeat_elements
|
||||
from ...ops.composite.base import _append, _insert, _pop, _list_clear, _reverse, \
|
||||
_count, _extend, _dict_clear, _haskey, _update, _fromkeys
|
||||
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
|
||||
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
|
||||
from ...ops.operations.math_ops import Median
|
||||
|
@ -2823,7 +2825,29 @@ def list_count(self_, value):
|
|||
|
||||
def dict_get(self_, key_index, default_value=None):
|
||||
"""Get value by key from dict"""
|
||||
return F.dict_getitem(self_, key_index, default_value)
|
||||
if not _haskey(self_, key_index):
|
||||
return default_value
|
||||
return F.dict_getitem(self_, key_index)
|
||||
|
||||
|
||||
def dict_clear(self_):
|
||||
"""Clear the dict"""
|
||||
return _dict_clear(self_)
|
||||
|
||||
|
||||
def dict_haskey(self_, key_index):
|
||||
"""Check if key is in dict"""
|
||||
return _haskey(self_, key_index)
|
||||
|
||||
|
||||
def dict_update(self_, dict_obj):
|
||||
"""Update the dict"""
|
||||
return _update(self_, dict_obj)
|
||||
|
||||
|
||||
def dict_fromkeys(self_, seq, value=None):
|
||||
"""Check if key is in dict"""
|
||||
return _fromkeys(self_, seq, value)
|
||||
|
||||
|
||||
#################
|
||||
|
|
|
@ -27,7 +27,8 @@ from mindspore import log as logger
|
|||
from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
|
||||
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
|
||||
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
|
||||
ListClear_, ListReverse_, ListExtend_, ListCount_
|
||||
ListClear_, ListReverse_, ListExtend_, ListCount_, DictClear_, DictHasKey_, DictUpdate_, \
|
||||
DictFromKeys_
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.api import ms_function, _pynative_executor, _wrap_func
|
||||
from mindspore.ops.primitive import Primitive
|
||||
|
@ -1086,6 +1087,82 @@ class _ListCount(ListCount_):
|
|||
_count = _ListCount("count")
|
||||
|
||||
|
||||
class _DictClear(DictClear_):
|
||||
"""
|
||||
A metafuncgraph class that clear the dict.
|
||||
|
||||
Args:
|
||||
name (str): The name of the metafuncgraph object.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _DictClear."""
|
||||
DictClear_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
_dict_clear = _DictClear("clear")
|
||||
|
||||
|
||||
class _DictHasKey(DictHasKey_):
|
||||
"""
|
||||
A metafuncgraph class that Check if key is in dict.
|
||||
|
||||
Args:
|
||||
name (str): The name of the metafuncgraph object.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _DictHasKey."""
|
||||
DictHasKey_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
_haskey = _DictHasKey("has_key")
|
||||
|
||||
|
||||
class _DictUpdate(DictUpdate_):
|
||||
"""
|
||||
A metafuncgraph class that append another dict to the end of the dict.
|
||||
|
||||
Args:
|
||||
name (str): The name of the metafuncgraph object.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _DictUpdate."""
|
||||
DictUpdate_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
_update = _DictUpdate("update")
|
||||
|
||||
|
||||
class _DictFromKeys(DictFromKeys_):
|
||||
"""
|
||||
A metafuncgraph class that creates a new dict from the given sequence and value.
|
||||
|
||||
Args:
|
||||
name (str): The name of the metafuncgraph object.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _DictFromKeys."""
|
||||
DictFromKeys_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
_fromkeys = _DictFromKeys("fromkeys")
|
||||
|
||||
|
||||
class _Tail(Tail_):
|
||||
"""
|
||||
A metafuncgraph class that generates tail elements of the tuple.
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_dict_clear """
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_dict_clear_1():
|
||||
"""
|
||||
Feature: dict clear.
|
||||
Description: support dict clear.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_1():
|
||||
x = {'a': 1, 'b': 2}
|
||||
x.clear()
|
||||
return x
|
||||
out = dict_net_1()
|
||||
assert dict(out) == {}
|
||||
|
||||
|
||||
def test_dict_clear_2():
|
||||
"""
|
||||
Feature: dict clear.
|
||||
Description: support dict clear.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_2():
|
||||
x = {'a': [1, 2, 'aa'], 'b': 2, 'c': Tensor(1)}
|
||||
x.clear()
|
||||
return x
|
||||
out = dict_net_2()
|
||||
assert dict(out) == {}
|
||||
|
||||
|
||||
def test_dict_clear_3():
|
||||
"""
|
||||
Feature: dict clear.
|
||||
Description: support dict clear.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_3():
|
||||
x = {}
|
||||
x.clear()
|
||||
return x
|
||||
out = dict_net_3()
|
||||
assert dict(out) == {}
|
|
@ -0,0 +1,217 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_dict_fromkeys """
|
||||
import ast
|
||||
import pytest
|
||||
from mindspore import ms_function, context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_dict_fromkeys_1():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_1():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = ['1', '2', '3']
|
||||
new_dict = x.fromkeys(y)
|
||||
return str(new_dict)
|
||||
out = dict_net_1()
|
||||
assert ast.literal_eval(out) == {'1': None, '2': None, '3': None}
|
||||
|
||||
|
||||
def test_dict_fromkeys_2():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_2():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = ('1', '2', '3')
|
||||
new_dict = x.fromkeys(y)
|
||||
return str(new_dict)
|
||||
out = dict_net_2()
|
||||
assert ast.literal_eval(out) == {'1': None, '2': None, '3': None}
|
||||
|
||||
|
||||
def test_dict_fromkeys_3():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_3():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
|
||||
new_dict = x.fromkeys(y.keys())
|
||||
return str(new_dict)
|
||||
out = dict_net_3()
|
||||
assert ast.literal_eval(out) == {'a': None, 'b': None, 'c': None, 'd': None}
|
||||
|
||||
|
||||
def test_dict_fromkeys_4():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_4():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = ['1', '2', "3"]
|
||||
new_dict = x.fromkeys(y, 123)
|
||||
return str(new_dict)
|
||||
out = dict_net_4()
|
||||
assert ast.literal_eval(out) == {'1': 123, '2': 123, '3': 123}
|
||||
|
||||
|
||||
def test_dict_fromkeys_5():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_5():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = ('1', '2', '3')
|
||||
new_dict = x.fromkeys(y, 123)
|
||||
return str(new_dict)
|
||||
out = dict_net_5()
|
||||
assert ast.literal_eval(out) == {'1': 123, '2': 123, '3': 123}
|
||||
|
||||
|
||||
def test_dict_fromkeys_6():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_6():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
|
||||
new_dict = x.fromkeys(y.keys(), 123)
|
||||
return str(new_dict)
|
||||
out = dict_net_6()
|
||||
assert ast.literal_eval(out) == {'a': 123, 'b': 123, 'c': 123, 'd': 123}
|
||||
|
||||
|
||||
def test_dict_fromkeys_7():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_7():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
|
||||
new_dict = x.fromkeys(y, 123)
|
||||
return str(new_dict)
|
||||
out = dict_net_7()
|
||||
assert ast.literal_eval(out) == {'a': 123, 'b': 123, 'c': 123, 'd': 123}
|
||||
|
||||
|
||||
def test_dict_fromkeys_8():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_8():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
|
||||
new_dict = x.fromkeys(y)
|
||||
return str(new_dict)
|
||||
out = dict_net_8()
|
||||
assert ast.literal_eval(out) == {'a': None, 'b': None, 'c': None, 'd': None}
|
||||
|
||||
|
||||
def test_dict_fromkeys_9():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_9():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = "abcd"
|
||||
new_dict = x.fromkeys(y)
|
||||
return str(new_dict)
|
||||
out = dict_net_9()
|
||||
assert ast.literal_eval(out) == {'a': None, 'b': None, 'c': None, 'd': None}
|
||||
|
||||
|
||||
def test_dict_fromkeys_10():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_10():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = "abcd"
|
||||
new_dict = x.fromkeys(y, 111)
|
||||
return str(new_dict)
|
||||
out = dict_net_10()
|
||||
assert ast.literal_eval(out) == {'a': 111, 'b': 111, 'c': 111, 'd': 111}
|
||||
|
||||
|
||||
def test_dict_fromkeys_11():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_11():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = 123
|
||||
new_dict = x.fromkeys(y, 111)
|
||||
return str(new_dict)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
out = dict_net_11()
|
||||
print(out)
|
||||
|
||||
|
||||
def test_dict_fromkeys_12():
|
||||
"""
|
||||
Feature: dict fromkeys.
|
||||
Description: support dict fromkeys.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_12():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = ['b', 1, 'c']
|
||||
new_dict = x.fromkeys(y, 111)
|
||||
return str(new_dict)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
out = dict_net_12()
|
||||
print(out)
|
|
@ -126,3 +126,21 @@ def test_dict_get_7():
|
|||
return Tensor(the_value)
|
||||
out = dict_net_7()
|
||||
assert (out.asnumpy() == (3, 4)).all()
|
||||
|
||||
|
||||
def test_dict_get_8():
|
||||
"""
|
||||
Feature: dict get.
|
||||
Description: support dict get set default value.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_8(x, y, z):
|
||||
dict_x = {"1": x, "2": y}
|
||||
default_value = dict_x.get("3", z)
|
||||
return default_value
|
||||
input_x = Tensor(1)
|
||||
input_y = Tensor(2)
|
||||
input_z = Tensor(3)
|
||||
out = dict_net_8(input_x, input_y, input_z)
|
||||
assert out == 3
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_dict_has_key """
|
||||
from mindspore import ms_function, context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_dict_haskey_1():
|
||||
"""
|
||||
Feature: dict has_key.
|
||||
Description: support dict has_key.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_1():
|
||||
x = {'a': 1, 'b': 2}
|
||||
res = x.has_key('a')
|
||||
return res
|
||||
out = dict_net_1()
|
||||
assert out is True
|
||||
|
||||
|
||||
def test_dict_haskey_2():
|
||||
"""
|
||||
Feature: dict has_key.
|
||||
Description: support dict has_key.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_2():
|
||||
x = {'a': [2, 3, "123"], 'b': 2}
|
||||
res = x.has_key('a')
|
||||
return res
|
||||
out = dict_net_2()
|
||||
assert out is True
|
||||
|
||||
|
||||
def test_dict_haskey_3():
|
||||
"""
|
||||
Feature: dict has_key.
|
||||
Description: support dict has_key.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_3():
|
||||
x = {'a': 1, 'b': 2}
|
||||
res = x.has_key('c')
|
||||
return res
|
||||
out = dict_net_3()
|
||||
assert out is False
|
||||
|
||||
|
||||
def test_dict_haskey_4():
|
||||
"""
|
||||
Feature: dict has_key.
|
||||
Description: support dict has_key.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_4():
|
||||
x = {"a": 1, "b": 2, "cd": 3, "c": 4}
|
||||
res = x.has_key('c')
|
||||
return res
|
||||
out = dict_net_4()
|
||||
assert out is True
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_dict_update """
|
||||
import ast
|
||||
from mindspore import ms_function, context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
def test_dict_update_1():
|
||||
"""
|
||||
Feature: dict update.
|
||||
Description: support dict update.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_1():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = {'c': 3, 'b': 4}
|
||||
x.update(y)
|
||||
return str(x)
|
||||
out = dict_net_1()
|
||||
assert ast.literal_eval(out) == {'a': 1, 'b': 4, 'c': 3}
|
||||
|
||||
|
||||
def test_dict_update_2():
|
||||
"""
|
||||
Feature: dict update.
|
||||
Description: support dict update.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_2():
|
||||
x = {'a': 1, 'b': 2, 'aa': 11, 'bb': 22}
|
||||
y = {'dd': {'ab': 12}, 'c': 3, 'b': "aaaa", 'ddd': [1, 2, 3]}
|
||||
x.update(y)
|
||||
return str(x)
|
||||
out = dict_net_2()
|
||||
assert ast.literal_eval(out) == {'a': 1, 'b': 'aaaa', 'aa': 11, 'bb': 22, 'dd': {'ab': 12},
|
||||
'c': 3, 'ddd': [1, 2, 3]}
|
||||
|
||||
|
||||
def test_dict_update_3():
|
||||
"""
|
||||
Feature: dict update.
|
||||
Description: support dict update.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_3():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = {'c': 3}
|
||||
x.update(y)
|
||||
return str(x)
|
||||
out = dict_net_3()
|
||||
assert ast.literal_eval(out) == {'a': 1, 'b': 2, 'c': 3}
|
||||
|
||||
|
||||
def test_dict_update_4():
|
||||
"""
|
||||
Feature: dict update.
|
||||
Description: support dict update.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def dict_net_4():
|
||||
x = {'a': ["aa", "bb"], 'b': 2}
|
||||
y = {'c': 3, "a": {"sub": "test"}}
|
||||
x.update(y)
|
||||
return str(x)
|
||||
out = dict_net_4()
|
||||
assert ast.literal_eval(out) == {'a': {"sub": "test"}, 'b': 2, 'c': 3}
|
Loading…
Reference in New Issue