!46865 [JIT Fallback] Support return Python dict in top func graph.

Merge pull request !46865 from 张清华/opt_jit_fallback
This commit is contained in:
i-robot 2022-12-23 12:09:26 +00:00 committed by Gitee
commit 904d8e6873
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
30 changed files with 855 additions and 123 deletions

View File

@ -127,7 +127,13 @@ tensor::TensorPtr GetDependValueTensor(const AnfNodePtr &node, size_t i,
auto input_device_address = reinterpret_cast<std::vector<device::DeviceAddress *> *>(args); auto input_device_address = reinterpret_cast<std::vector<device::DeviceAddress *> *>(args);
if (i >= input_device_address->size() || input_device_address->at(i) == nullptr) { if (i >= input_device_address->size() || input_device_address->at(i) == nullptr) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_LOG(EXCEPTION) << "There is no valid address for " << i << " input of " << node->fullname_with_scope(); if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
MS_LOG(INFO) << "There is no valid address for " << i << " input of " << node->DebugString() << ", "
<< node->fullname_with_scope();
return out_tensor;
}
MS_LOG(EXCEPTION) << "There is no valid address for " << i << " input of " << node->DebugString() << ", "
<< node->fullname_with_scope();
} }
out_tensor->data_sync_directly(input_device_address->at(i)); out_tensor->data_sync_directly(input_device_address->at(i));

View File

@ -848,10 +848,10 @@ AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
return rectify_abs_list; return rectify_abs_list;
} }
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive, AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
const AbstractBasePtrList &input_abstract) { const AbstractBasePtrList &input_abstract) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(prim);
auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes); auto dynamic_inputs_list = prim->GetAttr(kAttrDynInputSizes);
if (dynamic_inputs_list == nullptr) { if (dynamic_inputs_list == nullptr) {
return input_abstract; return input_abstract;
} }
@ -873,6 +873,10 @@ AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitiv
AbstractBasePtrList dynamic_inputs_abs; AbstractBasePtrList dynamic_inputs_abs;
for (auto index = item; index > 0; --index) { for (auto index = item; index > 0; --index) {
if (input_index >= input_abstract.size()) { if (input_index >= input_abstract.size()) {
// Not to check for PyExecute.
if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
continue;
}
MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range in input abstract " MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range in input abstract "
<< input_abstract.size(); << input_abstract.size();
} }

View File

@ -41,11 +41,11 @@ namespace mindspore {
namespace prim { namespace prim {
constexpr auto kStepDefault = 1; constexpr auto kStepDefault = 1;
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractBasePtr; using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractDictionary; using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr; using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractElementPair;
using mindspore::abstract::AbstractEllipsis; using mindspore::abstract::AbstractEllipsis;
using mindspore::abstract::AbstractEllipsisPtr; using mindspore::abstract::AbstractEllipsisPtr;
using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractFunction;

View File

@ -49,12 +49,12 @@ FuncGraphPtr DictHasKey::GenerateFuncGraph(const abstract::AbstractBasePtrList &
ValuePtr key_value = args_list[1]->BuildValue(); ValuePtr key_value = args_list[1]->BuildValue();
MS_EXCEPTION_IF_NULL(dict); MS_EXCEPTION_IF_NULL(dict);
MS_EXCEPTION_IF_NULL(key_value); MS_EXCEPTION_IF_NULL(key_value);
auto dict_elems = dict->elements(); auto elems = dict->elements();
bool has_key = false; bool has_key = false;
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const abstract::AbstractAttribute &item) { auto it = std::find_if(elems.cbegin(), elems.cend(), [&key_value](const abstract::AbstractElementPair &item) {
return *key_value == *item.first->BuildValue(); return *key_value == *item.first->BuildValue();
}); });
if (it != dict_elems.cend()) { if (it != elems.cend()) {
has_key = true; has_key = true;
} }
@ -153,7 +153,7 @@ abstract::AbstractBasePtrList DictFromKeys::ParseIterableObject(const abstract::
AbstractBasePtrList keys; AbstractBasePtrList keys;
auto &dict_elems = dict_keys->elements(); auto &dict_elems = dict_keys->elements();
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys), std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys),
[](const abstract::AbstractAttribute &item) { return item.first; }); [](const abstract::AbstractElementPair &item) { return item.first; });
return keys; return keys;
} }
if (key_type->IsSameTypeId(String::kTypeId)) { if (key_type->IsSameTypeId(String::kTypeId)) {

View File

@ -28,10 +28,10 @@
namespace mindspore { namespace mindspore {
// namespace to support composite operators definition // namespace to support composite operators definition
namespace prim { namespace prim {
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractDictionary; using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr; using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractElementPair;
using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractKeywordArg; using mindspore::abstract::AbstractKeywordArg;
using mindspore::abstract::AbstractList; using mindspore::abstract::AbstractList;
@ -78,7 +78,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_abs_l
auto dict_elems = arg_dict->elements(); auto dict_elems = arg_dict->elements();
(void)std::transform( (void)std::transform(
dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(elems), dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(elems),
[res_graph, para_dict](const AbstractAttribute &item) { [res_graph, para_dict](const AbstractElementPair &item) {
// Dict_elems's first element represents parameter names, which should be string type. // Dict_elems's first element represents parameter names, which should be string type.
auto key_value = GetValue<std::string>(item.first->BuildValue()); auto key_value = GetValue<std::string>(item.first->BuildValue());
auto dict_get_item = auto dict_get_item =

View File

@ -38,11 +38,11 @@
namespace mindspore { namespace mindspore {
/* namespace to support opt */ /* namespace to support opt */
namespace opt { namespace opt {
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractBasePtr; using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractDictionary; using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr; using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractElementPair;
using mindspore::abstract::AbstractList; using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr; using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractRowTensor; using mindspore::abstract::AbstractRowTensor;
@ -164,7 +164,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
public: public:
using ThisClass = SimplifyDataStructuresRewriter; using ThisClass = SimplifyDataStructuresRewriter;
SimplifyDataStructuresRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager) SimplifyDataStructuresRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
: BaseRewriter(root_graph, manager) {} : BaseRewriter(root_graph, manager), is_dict_output_{IsDictOutput()} {}
~SimplifyDataStructuresRewriter() override = default; ~SimplifyDataStructuresRewriter() override = default;
protected: protected:
@ -176,7 +176,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return str->value(); return str->value();
} }
static int64_t GetAttrIndex(const std::vector<AbstractAttribute> &attrs, const AnfNodePtr &name) { static int64_t GetElementIndex(const std::vector<AbstractElementPair> &attrs, const AnfNodePtr &name) {
auto n_attrs = attrs.size(); auto n_attrs = attrs.size();
auto name_abstract = GetAbstract<AbstractBase>(name); auto name_abstract = GetAbstract<AbstractBase>(name);
MS_EXCEPTION_IF_NULL(name_abstract); MS_EXCEPTION_IF_NULL(name_abstract);
@ -191,15 +191,15 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
} }
static CNodePtr NewTupleGetCNode(const AnfNodePtr &cnode, const AnfNodePtr &data_node, static CNodePtr NewTupleGetCNode(const AnfNodePtr &cnode, const AnfNodePtr &data_node,
const std::vector<AbstractAttribute> &attributes, const AnfNodePtr &name_node) { const std::vector<AbstractElementPair> &elements, const AnfNodePtr &name_node) {
int64_t index = GetAttrIndex(attributes, name_node); int64_t index = GetElementIndex(elements, name_node);
auto index_node = NewValueNode(index); auto index_node = NewValueNode(index);
auto prim_node = NewValueNode(prim::kPrimTupleGetItem); auto prim_node = NewValueNode(prim::kPrimTupleGetItem);
return cnode->func_graph()->NewCNode({prim_node, data_node, index_node}); return cnode->func_graph()->NewCNode({prim_node, data_node, index_node});
} }
// From: // From:
// DictGetItem(data:AbstractDictionary, cons:AbstractBase) // DictGetItem(data:AbstractDictionary, key:AbstractBase)
// To: // To:
// TupleGetItem(data, index:Int64Imm) // TupleGetItem(data, index:Int64Imm)
AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
@ -211,27 +211,98 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
CheckInputsSize(node, expect_inputs_size); CheckInputsSize(node, expect_inputs_size);
constexpr size_t data_index = 1; constexpr size_t data_index = 1;
constexpr size_t attr_index = 2; constexpr size_t key_index = 2;
const auto &inputs = node->inputs(); const auto &inputs = node->inputs();
auto &data = inputs[data_index]; auto &data = inputs[data_index];
auto &attr = inputs[attr_index]; auto &key = inputs[key_index];
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(attr); MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data); auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) { if (abs_dict == nullptr) {
return nullptr; return nullptr;
} }
return NewTupleGetCNode(node, data, abs_dict->elements(), attr); return NewTupleGetCNode(node, data, abs_dict->elements(), key);
}
// DictGetItem --> PyExecute()
AnfNodePtr RebuidDictGetItem(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
// Inputs should be [dict_setitem, dict, item]
const size_t expect_inputs_size = 3;
CheckInputsSize(node, expect_inputs_size);
const size_t data_index = 1;
const size_t item_key_index = 2;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &key = inputs[item_key_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) {
return nullptr;
}
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// Script
constexpr auto internal_dict_self_str = "__internal_dict_self__";
constexpr auto internal_dict_key_str = "__internal_dict_key__";
std::stringstream script_buffer;
script_buffer << internal_dict_self_str << "[" << internal_dict_key_str << "]";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Pack local parameters keys.
const auto script_dict_self_name = std::make_shared<StringImm>(internal_dict_self_str);
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_key_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
// Pack the local parameters values, not support list, tuple, or dict.
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(data);
(void)key_value_list.emplace_back(key);
const auto key_value_tuple = func_graph->NewCNode(key_value_list);
// Build the new dict node.
const auto dict_getitem_node = func_graph->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
int64_t index = GetElementIndex(abs_dict->elements(), key);
const auto &val = abs_dict->elements()[index].second;
const auto &tensor_val = dyn_cast<abstract::AbstractTensor>(val);
if (tensor_val != nullptr) {
const auto &tensor_type = tensor_val->element()->BuildType();
dict_getitem_node->set_user_data<Type>("__py_execute_tensor_type__", tensor_type);
const auto &tensor_shape = dyn_cast<abstract::Shape>(tensor_val->BuildShape());
MS_EXCEPTION_IF_NULL(tensor_shape);
dict_getitem_node->set_user_data<abstract::Shape>("__py_execute_tensor_shape__", tensor_shape);
MS_LOG(DEBUG) << "key: " << key->abstract()->BuildValue()->ToString() << ", type: " << tensor_type->ToString()
<< ", shape: " << tensor_shape->ToString() << ", val: " << tensor_val->ToString();
}
MS_LOG(DEBUG) << "Made dict getitem node: " << dict_getitem_node->DebugString();
return dict_getitem_node;
}
AnfNodePtr ConvertDictGetItem(const CNodePtr &node) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime && is_dict_output_) {
return RebuidDictGetItem(node);
}
return ConvertDictGetItemToTupleGetItem(node);
} }
// From: // From:
// DictSetItem(data:AbstractDictionary, cons:AbstractBase, value) // DictSetItem(data:AbstractDictionary, key:AbstractBase, value)
// To: // To:
// TupleSetItem(data, index:Int64Imm, value) // TupleSetItem(data, index:Int64Imm, value)
// Or: // Or:
// tuple_add(data, value) // tuple_add(data, value)
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) { AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
@ -244,16 +315,16 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
const size_t item_value_index = 3; const size_t item_value_index = 3;
const auto &inputs = node->inputs(); const auto &inputs = node->inputs();
auto &data = inputs[data_index]; auto &data = inputs[data_index];
auto &cons = inputs[cons_index]; auto &key = inputs[cons_index];
auto &item_value = inputs[item_value_index]; auto &item_value = inputs[item_value_index];
MS_EXCEPTION_IF_NULL(data); MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(cons); MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data); auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) { if (abs_dict == nullptr) {
return nullptr; return nullptr;
} }
int64_t index = GetAttrIndex(abs_dict->elements(), cons); int64_t index = GetElementIndex(abs_dict->elements(), key);
auto func_graph = node->func_graph(); auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
if (index >= static_cast<int64_t>(abs_dict->elements().size())) { if (index >= static_cast<int64_t>(abs_dict->elements().size())) {
@ -275,11 +346,86 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return new_node; return new_node;
} }
bool IsDictOutput() const {
const AnfNodePtr &output = root_graph_->output();
auto abs_dict = GetAbstract<AbstractDictionary>(output);
if (abs_dict != nullptr) {
return true;
}
return false;
}
// DictSetItem --> PyExecute()
AnfNodePtr RebuidDictSetItem(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
// Inputs should be [dict_setitem, dict, item, value]
const size_t expect_inputs_size = 4;
CheckInputsSize(node, expect_inputs_size);
const size_t data_index = 1;
const size_t item_key_index = 2;
const size_t item_value_index = 3;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &key = inputs[item_key_index];
auto &item_value = inputs[item_value_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) {
return nullptr;
}
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// Script
constexpr auto internal_dict_self_str = "__internal_dict_self__";
constexpr auto internal_dict_key_str = "__internal_dict_key__";
constexpr auto internal_dict_value_str = "__internal_dict_value__";
std::stringstream script_buffer;
script_buffer << "__import__('mindspore').update_and_return_dict(" << internal_dict_self_str << ", "
<< internal_dict_key_str << ", " << internal_dict_value_str << ")";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Pack local parameters keys.
const auto script_dict_self_name = std::make_shared<StringImm>(internal_dict_self_str);
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_key_str);
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_value_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
// Pack the local parameters values, not support list, tuple, or dict.
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(data);
(void)key_value_list.emplace_back(key);
(void)key_value_list.emplace_back(item_value);
const auto key_value_tuple = func_graph->NewCNode(key_value_list);
// Build the new dict node.
const auto dict_setitem_node = func_graph->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
MS_LOG(DEBUG) << "Made dict setitem node: " << dict_setitem_node->DebugString();
return dict_setitem_node;
}
AnfNodePtr ConvertDictSetItem(const CNodePtr &node) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime && is_dict_output_) {
return RebuidDictSetItem(node);
}
return ConvertDictSetItemToTupleSetItem(node);
}
// From: // From:
// MakeDict(name, input) // MakeDict(name, input)
// To: // To:
// input // input
AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { AnfNodePtr EraseMakeDictNode(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
constexpr size_t expect_inputs_size = 3; constexpr size_t expect_inputs_size = 3;
constexpr size_t input_index = 2; constexpr size_t input_index = 2;
@ -287,6 +433,62 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return node->input(input_index); return node->input(input_index);
} }
// MakeDict(keys, values) --> PyExecute('dict(zip(keys, values))', ...)
AnfNodePtr RebuildMakeDictNode(const CNodePtr &node) const {
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
constexpr auto internal_dict_zip_values_str = "__internal_dict_zip_values__";
const auto &fg = node->func_graph();
MS_EXCEPTION_IF_NULL(fg);
// Local parameters values.
// Pack the key tuple.
constexpr size_t values_input_index = 2;
const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
const auto make_key_tuple_node =
fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_key_tuple_str),
NewValueNode(script_key_tuple_str), node->input(values_input_index)});
// Pack the value tuple.
const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
const auto make_value_tuple_node =
fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_value_tuple_str),
NewValueNode(script_value_tuple_str), node->input(1)});
// Pack the local parameters values
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(make_key_tuple_node);
(void)key_value_list.emplace_back(make_value_tuple_node);
const auto key_value_tuple = fg->NewCNode(key_value_list);
// Pack local parameters keys.
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_zip_keys_str);
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_zip_values_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
// Script
std::stringstream script_buffer;
script_buffer << "dict(zip(" << internal_dict_zip_keys_str << "," << internal_dict_zip_values_str << "),)";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Build the new dict node.
const auto make_dict_node = fg->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
return make_dict_node;
}
AnfNodePtr ConvertMakeDict(const CNodePtr &node) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime && is_dict_output_) {
return RebuildMakeDictNode(node);
}
return EraseMakeDictNode(node);
}
// From: // From:
// DictGetValues(dict:AbstractDictionary) // DictGetValues(dict:AbstractDictionary)
// To: // To:
@ -356,22 +558,79 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
} }
// dict(k0:v0, k1:v1, ...) --> tuple(v0, v1, ...) // dict(k0:v0, k1:v1, ...) --> tuple(v0, v1, ...)
ValueTuplePtr DictToTuple(const ValueDictionaryPtr &dict) const { AnfNodePtr DictToTuple(const ValueDictionaryPtr &dict) const {
const auto &elements = dict->value(); const auto &keys_values = dict->value();
std::vector<ValuePtr> values; std::vector<ValuePtr> value_list;
values.reserve(elements.size()); value_list.reserve(keys_values.size());
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(values), (void)std::transform(keys_values.begin(), keys_values.end(), std::back_inserter(value_list),
[](const auto &element) { return element.second; }); [](const auto &value) { return value.second; });
return std::make_shared<ValueTuple>(values); return NewValueNode(std::make_shared<ValueTuple>(value_list));
}
// dict(k0:v0, k1:v1, ...) --> PyExecute('dict(zip(keys, values))', ...)
AnfNodePtr RebuildValueDict(const ValueDictionaryPtr &dict) const {
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
constexpr auto internal_dict_zip_values_str = "__internal_dict_zip_values__";
const auto &keys_values = dict->value();
std::vector<ValuePtr> key_list;
key_list.reserve(keys_values.size());
std::vector<ValuePtr> value_list;
value_list.reserve(keys_values.size());
for (const auto &key_value : keys_values) {
(void)key_list.emplace_back(key_value.first);
(void)value_list.emplace_back(key_value.second);
}
// Local parameters values.
// Pack the key tuple.
const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
const auto key_tuple = std::make_shared<ValueTuple>(key_list);
const auto make_key_tuple_node =
root_graph_->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_key_tuple_str),
NewValueNode(script_key_tuple_str), NewValueNode(key_tuple)});
// Pack the value tuple.
const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
const auto value_tuple = std::make_shared<ValueTuple>(value_list);
const auto make_value_tuple_node =
root_graph_->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_value_tuple_str),
NewValueNode(script_value_tuple_str), NewValueNode(value_tuple)});
// Pack the local parameters values
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(make_key_tuple_node);
(void)key_value_list.emplace_back(make_value_tuple_node);
const auto key_value_tuple = root_graph_->NewCNode(key_value_list);
// Pack local parameters keys.
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_zip_keys_str);
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_zip_values_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
const auto key_value_name_tuple = root_graph_->NewCNode(key_value_names_list);
// Script
std::stringstream script_buffer;
script_buffer << "dict(zip(" << internal_dict_zip_keys_str << "," << internal_dict_zip_values_str << "),)";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Build the new dict node.
const auto make_dict_node = root_graph_->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
return make_dict_node;
} }
using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &); using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &);
using ConverterMap = mindspore::HashMap<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>; using ConverterMap = mindspore::HashMap<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>;
static inline const ConverterMap converters_{ static inline const ConverterMap converters_{
{prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItemToTupleGetItem}, {prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItem},
{prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItemToTupleSetItem}, {prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItem},
{prim::kPrimDictGetValues, &ThisClass::EraseDictGetValues}, {prim::kPrimDictGetValues, &ThisClass::EraseDictGetValues},
{prim::kPrimMakeDict, &ThisClass::EraseMakeDictNode}, {prim::kPrimMakeDict, &ThisClass::ConvertMakeDict},
{prim::kPrimMakeKeywordArg, &ThisClass::EraseMakeKeywordArgNode}, {prim::kPrimMakeKeywordArg, &ThisClass::EraseMakeKeywordArgNode},
{prim::kPrimExtractKeywordArg, &ThisClass::EraseExtractKeywordArg}, {prim::kPrimExtractKeywordArg, &ThisClass::EraseExtractKeywordArg},
{prim::kPrimDictItems, &ThisClass::EraseDictItems}, {prim::kPrimDictItems, &ThisClass::EraseDictItems},
@ -390,12 +649,16 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
AnfNodePtr ConvertValueNode(const ValueNodePtr &, const ValuePtr &value) override { AnfNodePtr ConvertValueNode(const ValueNodePtr &, const ValuePtr &value) override {
// Convert Dictionary value node. // Convert Dictionary value node.
if (value->isa<ValueDictionary>()) { if (value->isa<ValueDictionary>()) {
return NewValueNode(DictToTuple(value->cast<ValueDictionaryPtr>())); static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime && is_dict_output_) {
return RebuildValueDict(value->cast<ValueDictionaryPtr>());
}
return DictToTuple(value->cast<ValueDictionaryPtr>());
} }
return nullptr; return nullptr;
} }
static std::shared_ptr<AbstractTuple> MakeAbstractTuple(const std::vector<AbstractAttribute> &attrs) { static std::shared_ptr<AbstractTuple> MakeAbstractTuple(const std::vector<AbstractElementPair> &attrs) {
std::vector<AbstractBasePtr> elements; std::vector<AbstractBasePtr> elements;
elements.reserve(attrs.size()); elements.reserve(attrs.size());
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(elements), (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(elements),
@ -465,6 +728,9 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
// AbstractDictionary --> AbstractSequence. // AbstractDictionary --> AbstractSequence.
return ConvertToAbstractSequence(abs, 0); return ConvertToAbstractSequence(abs, 0);
} }
private:
bool is_dict_output_{false};
}; };
// ================================================================== // ==================================================================
@ -495,9 +761,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
} }
// From: // From:
// ListGetItem(list, cons) // ListGetItem(list, key)
// To: // To:
// TupleGetItem(list, cons) // TupleGetItem(list, key)
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) { AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph()); MS_EXCEPTION_IF_NULL(node->func_graph());
@ -509,8 +775,8 @@ class CleanAfterOptARewriter : public BaseRewriter {
constexpr size_t cons_index = 2; constexpr size_t cons_index = 2;
const auto &inputs = node->inputs(); const auto &inputs = node->inputs();
auto &data = inputs[data_index]; auto &data = inputs[data_index];
auto &cons = inputs[cons_index]; auto &key = inputs[cons_index];
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons}); return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, key});
} }
// From: // From:
@ -530,9 +796,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
const size_t value_index = 3; const size_t value_index = 3;
const auto &inputs = node->inputs(); const auto &inputs = node->inputs();
auto &data = inputs[data_index]; auto &data = inputs[data_index];
auto &cons = inputs[cons_index]; auto &key = inputs[cons_index];
auto &value = inputs[value_index]; auto &value = inputs[value_index];
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, key, value});
} }
// From: // From:

View File

@ -33,13 +33,20 @@ namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
py::object CallPythonPushGlobalParams(const py::object &dict) { py::object CallPythonPushGlobalParams(const py::object &dict) {
constexpr auto python_mod_parse = "mindspore._extends.parse"; constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
py::module mod = python_adapter::GetPyModule(python_mod_parse); // The same as PYTHON_MOD_PARSE_MODULE[] py::module mod = python_adapter::GetPyModule(python_mod_parse);
constexpr auto python_merge_dict = "merge_global_params"; constexpr auto python_merge_dict = "merge_global_params";
return python_adapter::CallPyModFn(mod, python_merge_dict, dict); return python_adapter::CallPyModFn(mod, python_merge_dict, dict);
} }
} // namespace } // namespace
// Convert PyInterpret into PyExecute:
// PyInterpret(script, global_dict, local_dict)
// -->
// PyExecute(script, local_dict_keys, local_dict_values),
// with side-effect operation:
// Push global_dict into global parameters list.
// (So it requires no same key name.)
bool PyInterpretToExecute(const pipeline::ResourcePtr &resource) { bool PyInterpretToExecute(const pipeline::ResourcePtr &resource) {
auto manager = resource->manager(); auto manager = resource->manager();
const auto &all_nodes = manager->all_nodes(); const auto &all_nodes = manager->all_nodes();

View File

@ -1221,7 +1221,8 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
MS_LOG(DEBUG) << "call_cnode: " << call_cnode->DebugString() MS_LOG(DEBUG) << "call_cnode: " << call_cnode->DebugString()
<< ", call_function_node: " << call_function_node->DebugString(); << ", call_function_node: " << call_function_node->DebugString();
// Support tensor.asnumpy() in runtime by JIT Fallback. // Support tensor.asnumpy() in runtime by JIT Fallback.
if (IsPrimitiveCNode(call_function_node, prim::kPrimGetAttr)) { static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime && IsPrimitiveCNode(call_function_node, prim::kPrimGetAttr)) {
constexpr size_t index_two = 2; constexpr size_t index_two = 2;
const auto &attr_node = call_function_node->cast<CNodePtr>()->input(index_two); const auto &attr_node = call_function_node->cast<CNodePtr>()->input(index_two);
const auto &attr_str = GetValueNode<StringImmPtr>(attr_node); const auto &attr_str = GetValueNode<StringImmPtr>(attr_node);
@ -1474,8 +1475,9 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
auto value_str = py::cast<std::string>(ast()->GetAstNodeText(value_body)); auto value_str = py::cast<std::string>(ast()->GetAstNodeText(value_body));
py::bool_ is_const_value = py::bool_ is_const_value =
ast()->CallParserObjMethod(PYTHON_PARSE_CHECK_IS_CONSTANT_VALUE, value_str, common::SafeCStr(attr_str)); ast()->CallParserObjMethod(PYTHON_PARSE_CHECK_IS_CONSTANT_VALUE, value_str, common::SafeCStr(attr_str));
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
auto is_constant = py::cast<bool>(is_const_value); auto is_constant = py::cast<bool>(is_const_value);
if (!is_constant || attr_str == "asnumpy") { if (!is_constant || (support_fallback_runtime && attr_str == "asnumpy")) {
UpdateInterpretForUserNode(attr_cnode, value_node); UpdateInterpretForUserNode(attr_cnode, value_node);
} }
} }

View File

@ -88,6 +88,10 @@ void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource)
} // namespace } // namespace
bool PyInterpretToExecutePass(const ResourcePtr &resource) { bool PyInterpretToExecutePass(const ResourcePtr &resource) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (!support_fallback_runtime) {
return true;
}
MS_EXCEPTION_IF_NULL(resource); MS_EXCEPTION_IF_NULL(resource);
FuncGraphPtr func_graph = resource->func_graph(); FuncGraphPtr func_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -270,6 +270,33 @@ std::string ToOrdinal(const size_t &i) {
} }
return std::to_string(i) + suffix; return std::to_string(i) + suffix;
} }
py::object GetPyExecuteOutput(const AnfNodePtr &output) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime) {
std::function<AnfNodePtr(const AnfNodePtr &)> get_real_output = [&get_real_output](const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
const auto cnode = dyn_cast<CNode>(node);
MS_EXCEPTION_IF_NULL(cnode);
return get_real_output(cnode->input(1));
}
return node;
};
const auto &real_output = get_real_output(output);
MS_LOG(INFO) << "Real output: " << real_output << ", " << real_output->DebugString()
<< ", has \'PyExecuteOutputData\': " << real_output->has_user_data<kernel::PyExecuteOutputData>();
if (real_output->has_user_data<kernel::PyExecuteOutputData>()) {
py::gil_scoped_acquire gil_acquire;
const auto &output_data = real_output->user_data<kernel::PyExecuteOutputData>();
py::object res_obj = output_data->obj;
MS_LOG(INFO) << "Has \'PyExecuteOutputData\', just return it. res_obj: " << res_obj;
if (!py::isinstance<py::none>(res_obj)) {
return res_obj;
}
}
}
return py::none();
}
} // namespace } // namespace
std::string GetObjDesc(const py::object &source_obj) { std::string GetObjDesc(const py::object &source_obj) {
@ -1313,14 +1340,9 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
} }
// Replace the output if it's not Tensor, but Python data. // Replace the output if it's not Tensor, but Python data.
if (output->has_user_data<kernel::PyExecuteOutputData>()) { const auto &py_res = GetPyExecuteOutput(output);
py::gil_scoped_acquire gil_acquire; if (py_res != py::none()) {
const auto &output_data = output->user_data<kernel::PyExecuteOutputData>(); return py_res;
py::object res_obj = output_data->obj;
MS_LOG(INFO) << "Has \'PyExecuteOutputData\', just return it. res_obj: " << res_obj;
if (!py::isinstance<py::none>(res_obj)) {
return res_obj;
}
} }
MS_LOG(DEBUG) << "Run end"; MS_LOG(DEBUG) << "Run end";

View File

@ -177,7 +177,7 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_a
auto arg_dict = specialize_args_before_unpack[index]->cast_ptr<AbstractDictionary>(); auto arg_dict = specialize_args_before_unpack[index]->cast_ptr<AbstractDictionary>();
auto dict_elems = arg_dict->elements(); auto dict_elems = arg_dict->elements();
(void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(graph_specialize_args), (void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(graph_specialize_args),
[](const AbstractAttribute &item) { [](const AbstractElementPair &item) {
// Dict_elems's first element represents parameter names, which should be string type. // Dict_elems's first element represents parameter names, which should be string type.
return std::make_shared<AbstractKeywordArg>( return std::make_shared<AbstractKeywordArg>(
GetValue<std::string>(item.first->BuildValue()), item.second); GetValue<std::string>(item.first->BuildValue()), item.second);
@ -1817,9 +1817,21 @@ EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abst
// Call python script string. // Call python script string.
MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list; MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list;
ShapeVector shp; TypePtr type = kFloat32;
(void)shp.emplace_back(Shape::kShapeRankAny); if (current_interpret_node->has_user_data("__py_execute_tensor_type__")) {
AbstractBasePtr res = std::make_shared<AbstractTensor>(kFloat64, std::make_shared<Shape>(shp)); type = current_interpret_node->user_data<Type>("__py_execute_tensor_type__");
MS_LOG(DEBUG) << "type: " << type->ToString();
}
BaseShapePtr shape;
if (current_interpret_node->has_user_data("__py_execute_tensor_shape__")) {
shape = current_interpret_node->user_data<BaseShape>("__py_execute_tensor_shape__");
MS_LOG(DEBUG) << "shape: " << shape->ToString();
} else {
ShapeVector shp;
(void)shp.emplace_back(Shape::kShapeRankAny);
shape = std::make_shared<Shape>(shp);
}
AbstractBasePtr res = std::make_shared<AbstractTensor>(type, shape);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>()); auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result); evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result; return infer_result;
@ -2246,10 +2258,17 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
MS_EXCEPTION_IF_NULL(local_abs_val); MS_EXCEPTION_IF_NULL(local_abs_val);
auto py_data_name = py::str(ValueToPyData(name->BuildValue())); auto py_data_name = py::str(ValueToPyData(name->BuildValue()));
if (local_abs_val == kAnyValue) { if (local_abs_val == kAnyValue) {
MS_LOG(INFO) << "When using JIT Fallback to handle script '" << script static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
<< "', the inputs should be constant, but found variable '" << py_data_name if (support_fallback_runtime) {
<< "' to be nonconstant. To convert to PyExecute() afterwards"; MS_LOG(INFO) << "When using JIT Fallback to handle script '" << script
non_const_err_ = true; << "', the inputs should be constant, but found variable '" << py_data_name
<< "' to be nonconstant. To convert to PyExecute() afterwards";
non_const_err_ = true;
} else {
MS_EXCEPTION(ValueError) << "When using JIT Fallback to handle script '" << script
<< "', the inputs should be constant, but found variable '" << py_data_name
<< "' to be nonconstant.";
}
} }
if (local_abs->isa<abstract::AbstractTensor>()) { if (local_abs->isa<abstract::AbstractTensor>()) {
MS_LOG(WARNING) << "When using JIT Fallback to handle script '" << script << "', found variable '" MS_LOG(WARNING) << "When using JIT Fallback to handle script '" << script << "', found variable '"
@ -2341,11 +2360,11 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const { AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
MS_EXCEPTION_IF_NULL(abstract_dict); MS_EXCEPTION_IF_NULL(abstract_dict);
std::vector<AbstractAttribute> kv; std::vector<AbstractElementPair> kv;
const auto &keys_values = abstract_dict->elements(); const auto &keys_values = abstract_dict->elements();
// Filter out the element of Function type. // Filter out the element of Function type.
(void)std::copy_if(keys_values.cbegin(), keys_values.cend(), std::back_inserter(kv), (void)std::copy_if(keys_values.cbegin(), keys_values.cend(), std::back_inserter(kv),
[](const AbstractAttribute &item) { [](const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.second); MS_EXCEPTION_IF_NULL(item.second);
return (!item.second->isa<abstract::AbstractFunction>()); return (!item.second->isa<abstract::AbstractFunction>());
}); });

View File

@ -132,20 +132,106 @@ void PyExecuteCpuKernelMod::AttachPyOutputData(const py::object &py_res) {
const auto &iter = graph_output_map.find(anf_index); const auto &iter = graph_output_map.find(anf_index);
if (iter != graph_output_map.cend()) { if (iter != graph_output_map.cend()) {
const auto &front_node = iter->second.first; const auto &front_node = iter->second.first;
MS_LOG(INFO) << "Found front output for " << kernel_node_->DebugString(); MS_LOG(INFO) << "Found front output for " << kernel_node_ << ", " << kernel_node_->DebugString();
front_node->set_user_data<PyExecuteOutputData>(py_output); front_node->set_user_data<PyExecuteOutputData>(py_output);
} else { } else {
MS_LOG(DEBUG) << "Not found, kernel node is not output, " << kernel_node_->DebugString(); MS_LOG(DEBUG) << "Not found, kernel node is not output, " << kernel_node_ << ", " << kernel_node_->DebugString();
if (!IS_OUTPUT_ON(mindspore::kDebug)) {
return;
}
for (const auto &output_pair : graph_output_map) {
MS_EXCEPTION_IF_NULL(output_pair.first.first);
MS_EXCEPTION_IF_NULL(output_pair.second.first);
MS_LOG(DEBUG) << "backend node: " << output_pair.first.first << ", " << output_pair.first.first->DebugString()
<< ", front node: " << output_pair.second.first << ", " << output_pair.second.first->DebugString();
}
} }
} }
// Notice: Remove here after BE kernel supports tuple input.
py::object PyExecuteCpuKernelMod::BuildLocalTupleParameters(const std::vector<AddressPtr> &inputs) {
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
constexpr auto number_two = 2;
std::string tuple_key_str;
py::tuple local_tuple_inputs(inputs_info_.size() - number_two); // Exclude the script and key.
MS_LOG(DEBUG) << "Local parameter tuple size: " << (inputs_info_.size() - number_two);
bool tuple_input_start = false;
size_t tuple_index = 0;
py::dict local_tuple_dict;
for (size_t i = 1; i < inputs.size() && i < inputs_info_.size(); ++i) {
const auto &input = inputs[i];
MS_EXCEPTION_IF_NULL(input);
const auto &input_info = inputs_info_[i];
const auto &input_abstract = input_info.abstract;
MS_EXCEPTION_IF_NULL(input_abstract);
const auto &input_type = input_abstract->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
if (!tuple_input_start && input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
const auto &value = input_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(value);
const auto &str_value = dyn_cast<StringImm>(value);
MS_EXCEPTION_IF_NULL(str_value);
const auto &str = str_value->value();
if (str != internal_tuple_keys_str && str != internal_tuple_values_str) {
return py::none();
}
tuple_key_str = str;
tuple_input_start = true;
MS_LOG(DEBUG) << "String, key input[" << i << "]: " << input_abstract->ToString();
continue;
}
// Rebuild the tuple with all left inputs.
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
const auto &value = input_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(value);
const auto &str_value = dyn_cast<StringImm>(value);
MS_EXCEPTION_IF_NULL(str_value);
const auto &str = str_value->value();
local_tuple_inputs[tuple_index++] = py::str(str);
MS_LOG(DEBUG) << "String, value input[" << i << "]: " << input_abstract->ToString();
} else if (input_abstract->isa<abstract::AbstractTensor>()) {
const auto &py_array_value = input_info.py_obj_output;
bool is_py_middle_data = !py::isinstance<py::none>(py_array_value);
MS_LOG(DEBUG) << "Tensor, value input[" << i << "]: " << input_abstract->ToString()
<< ", type: " << input_info.type << ", shape: " << input_info.shape << ", addr: " << inputs[i]->addr
<< ", size: " << inputs[i]->size << ", py_array_value: " << py_array_value
<< ", is_py_middle_data: " << is_py_middle_data;
if (!is_py_middle_data) {
const auto tensor =
std::make_shared<tensor::Tensor>(input_info.type, input_info.shape, inputs[i]->addr, inputs[i]->size);
MS_EXCEPTION_IF_NULL(tensor);
local_tuple_inputs[tuple_index++] = tensor;
} else {
local_tuple_inputs[tuple_index++] = py_array_value;
}
} else {
MS_LOG(ERROR) << "Unsupported value type.";
}
}
local_tuple_dict[py::str(tuple_key_str)] = local_tuple_inputs;
return local_tuple_dict;
}
py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<AddressPtr> &inputs) { py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<AddressPtr> &inputs) {
const auto &local_tuple_params = BuildLocalTupleParameters(inputs);
if (local_tuple_params != py::none()) {
return local_tuple_params;
}
MS_LOG(DEBUG) << "Build normal local parameters.";
// Build local parameters dict. // Build local parameters dict.
std::vector<std::string> keys; std::vector<std::string> keys;
std::vector<tensor::TensorPtr> tensor_values; std::vector<tensor::TensorPtr> tensor_values;
std::vector<py::array> array_values; std::vector<py::object> py_object_values;
std::vector<bool> py_array_flags; std::vector<bool> py_array_flags;
for (size_t i = 1; i < inputs.size() && i < inputs_info_.size(); ++i) { constexpr auto number_two = 2;
size_t pair_size = (inputs_info_.size() - 1) / number_two;
// Handle the keys.
size_t i = 1;
for (; i < inputs.size() && i < pair_size + 1; ++i) {
const auto &input = inputs[i]; const auto &input = inputs[i];
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
const auto &input_info = inputs_info_[i]; const auto &input_info = inputs_info_[i];
@ -161,6 +247,29 @@ py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<Address
const auto &str = str_value->value(); const auto &str = str_value->value();
(void)keys.emplace_back(str); (void)keys.emplace_back(str);
MS_LOG(DEBUG) << "String, input[" << i << "]: " << input_abstract->ToString(); MS_LOG(DEBUG) << "String, input[" << i << "]: " << input_abstract->ToString();
} else {
MS_LOG(EXCEPTION) << "Other, input[" << i << "]: " << input_abstract->ToString();
}
}
// Handle the values.
for (; i < inputs.size() && i < inputs_info_.size(); ++i) {
const auto &input = inputs[i];
MS_EXCEPTION_IF_NULL(input);
const auto &input_info = inputs_info_[i];
const auto &input_abstract = input_info.abstract;
MS_EXCEPTION_IF_NULL(input_abstract);
const auto &input_type = input_abstract->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
const auto &value = input_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(value);
const auto &str_value = dyn_cast<StringImm>(value);
MS_EXCEPTION_IF_NULL(str_value);
const auto &str = str_value->value();
(void)py_object_values.emplace_back(py::str(str));
(void)tensor_values.emplace_back(nullptr);
(void)py_array_flags.emplace_back(true);
MS_LOG(DEBUG) << "String, input[" << i << "]: " << input_abstract->ToString();
} else if (input_abstract->isa<abstract::AbstractTensor>()) { } else if (input_abstract->isa<abstract::AbstractTensor>()) {
const auto &py_array_value = input_info.py_obj_output; const auto &py_array_value = input_info.py_obj_output;
bool is_py_middle_data = !py::isinstance<py::none>(py_array_value); bool is_py_middle_data = !py::isinstance<py::none>(py_array_value);
@ -172,7 +281,7 @@ py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<Address
tensor = std::make_shared<tensor::Tensor>(input_info.type, input_info.shape, inputs[i]->addr, inputs[i]->size); tensor = std::make_shared<tensor::Tensor>(input_info.type, input_info.shape, inputs[i]->addr, inputs[i]->size);
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
} }
(void)array_values.emplace_back(py_array_value); (void)py_object_values.emplace_back(py_array_value);
(void)tensor_values.emplace_back(tensor); (void)tensor_values.emplace_back(tensor);
(void)py_array_flags.emplace_back(is_py_middle_data); (void)py_array_flags.emplace_back(is_py_middle_data);
} else if (input_abstract->isa<abstract::AbstractRefTensor>()) { } else if (input_abstract->isa<abstract::AbstractRefTensor>()) {
@ -181,21 +290,22 @@ py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<Address
MS_LOG(DEBUG) << "Other, input[" << i << "]: " << input_abstract->ToString(); MS_LOG(DEBUG) << "Other, input[" << i << "]: " << input_abstract->ToString();
} }
} }
constexpr auto number_two = 2;
if (keys.size() != tensor_values.size() || keys.size() != (inputs_info_.size() - 1) / number_two) { if (keys.size() != tensor_values.size() || keys.size() != pair_size) {
MS_LOG(EXCEPTION) << "The local dict input is invalid, " << keys.size() << ", " << tensor_values.size() << ", " MS_LOG(EXCEPTION) << "The local dict input is invalid, " << keys.size() << ", " << tensor_values.size() << ", "
<< inputs_info_.size(); << inputs_info_.size();
} }
// To call the script with global and local parameters. // To call the script with global and local parameters.
py::dict local_dict; py::dict local_dict;
for (size_t i = 0; i < keys.size(); ++i) { for (i = 0; i < keys.size(); ++i) {
if (py_array_flags[i]) { if (py_array_flags[i]) {
local_dict[py::str(keys[i])] = array_values[i]; local_dict[py::str(keys[i])] = py_object_values[i];
} else { } else {
local_dict[py::str(keys[i])] = tensor_values[i]; local_dict[py::str(keys[i])] = tensor_values[i];
} }
} }
MS_LOG(DEBUG) << "local_dict: " << local_dict;
return local_dict; return local_dict;
} }
@ -249,6 +359,14 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
MS_LOG(DEBUG) << "Real output is py::bool_, py_res: " << py_res; MS_LOG(DEBUG) << "Real output is py::bool_, py_res: " << py_res;
} else if (py::isinstance<py::str>(py_res)) { } else if (py::isinstance<py::str>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::str, py_res: " << py_res; MS_LOG(DEBUG) << "Real output is py::str, py_res: " << py_res;
} else if (py::isinstance<py::tuple>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::tuple, py_res: " << py_res;
} else if (py::isinstance<py::list>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::list, py_res: " << py_res;
} else if (py::isinstance<py::dict>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::dict, py_res: " << py_res;
} else if (py::isinstance<py::set>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::set, py_res: " << py_res;
} else { } else {
MS_LOG(EXCEPTION) << "The output is invalid, py_res: " << py_res; MS_LOG(EXCEPTION) << "The output is invalid, py_res: " << py_res;
} }

View File

@ -52,6 +52,7 @@ class PyExecuteCpuKernelMod : public DeprecatedNativeCpuKernelMod {
private: private:
void AttachPyOutputData(const py::object &py_res); void AttachPyOutputData(const py::object &py_res);
py::object BuildLocalParameters(const std::vector<AddressPtr> &inputs); py::object BuildLocalParameters(const std::vector<AddressPtr> &inputs);
py::object BuildLocalTupleParameters(const std::vector<AddressPtr> &inputs);
CNodePtr kernel_node_{nullptr}; CNodePtr kernel_node_{nullptr};
std::vector<PyExecuteInputInfo> inputs_info_; std::vector<PyExecuteInputInfo> inputs_info_;

View File

@ -55,7 +55,7 @@ class PyExecuteInitializer {
MS_LOG(EXCEPTION) << "Value tuple should not be anyvalue."; MS_LOG(EXCEPTION) << "Value tuple should not be anyvalue.";
} }
const auto &values = dyn_cast<ValueSequence>(values_tuple); const auto &values = dyn_cast<ValueSequence>(values_tuple);
MS_LOG(ERROR) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString() MS_LOG(DEBUG) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
<< ", values_tuple: " << values_tuple->ToString(); << ", values_tuple: " << values_tuple->ToString();
py::gil_scoped_acquire gil_acquire; py::gil_scoped_acquire gil_acquire;
@ -90,13 +90,21 @@ class PyExecuteInitializer {
const auto &res_tensor = tensor::TensorPy::MakeTensorOfNumpy(py_res); const auto &res_tensor = tensor::TensorPy::MakeTensorOfNumpy(py_res);
MS_LOG(DEBUG) << "res_tensor: " << res_tensor->ToString(); MS_LOG(DEBUG) << "res_tensor: " << res_tensor->ToString();
} else if (py::isinstance<py::float_>(py_res)) { } else if (py::isinstance<py::float_>(py_res)) {
MS_LOG(ERROR) << "is py::float_, py_res: " << py_res; MS_LOG(DEBUG) << "is py::float_, py_res: " << py_res;
} else if (py::isinstance<py::int_>(py_res)) { } else if (py::isinstance<py::int_>(py_res)) {
MS_LOG(ERROR) << "is py::int_, py_res: " << py_res; MS_LOG(DEBUG) << "is py::int_, py_res: " << py_res;
} else if (py::isinstance<py::bool_>(py_res)) { } else if (py::isinstance<py::bool_>(py_res)) {
MS_LOG(ERROR) << "is py::bool_, py_res: " << py_res; MS_LOG(DEBUG) << "is py::bool_, py_res: " << py_res;
} else if (py::isinstance<py::str>(py_res)) { } else if (py::isinstance<py::str>(py_res)) {
MS_LOG(ERROR) << "is py::str, py_res: " << py_res; MS_LOG(DEBUG) << "is py::str, py_res: " << py_res;
} else if (py::isinstance<py::tuple>(py_res)) {
MS_LOG(DEBUG) << "is py::tuple, py_res: " << py_res;
} else if (py::isinstance<py::list>(py_res)) {
MS_LOG(DEBUG) << "is py::list, py_res: " << py_res;
} else if (py::isinstance<py::dict>(py_res)) {
MS_LOG(DEBUG) << "is py::dict, py_res: " << py_res;
} else if (py::isinstance<py::set>(py_res)) {
MS_LOG(DEBUG) << "is py::set, py_res: " << py_res;
} else { } else {
MS_LOG(EXCEPTION) << "py_res is invalid, py_res: " << py_res; MS_LOG(EXCEPTION) << "py_res is invalid, py_res: " << py_res;
} }

View File

@ -1350,9 +1350,9 @@ bool AbstractDictionary::operator==(const AbstractBase &other) const {
} }
AbstractBasePtr AbstractDictionary::Clone() const { AbstractBasePtr AbstractDictionary::Clone() const {
std::vector<AbstractAttribute> kv; std::vector<AbstractElementPair> kv;
(void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv), (void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
[](const AbstractAttribute &item) { [](const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.first); MS_EXCEPTION_IF_NULL(item.first);
MS_EXCEPTION_IF_NULL(item.second); MS_EXCEPTION_IF_NULL(item.second);
return std::make_pair(item.first->Clone(), item.second->Clone()); return std::make_pair(item.first->Clone(), item.second->Clone());
@ -1361,9 +1361,9 @@ AbstractBasePtr AbstractDictionary::Clone() const {
} }
AbstractBasePtr AbstractDictionary::Broaden() const { AbstractBasePtr AbstractDictionary::Broaden() const {
std::vector<AbstractAttribute> kv; std::vector<AbstractElementPair> kv;
(void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv), (void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
[](const AbstractAttribute &item) { [](const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.second); MS_EXCEPTION_IF_NULL(item.second);
return std::make_pair(item.first, item.second->Broaden()); return std::make_pair(item.first, item.second->Broaden());
}); });
@ -1384,7 +1384,7 @@ std::string AbstractDictionary::ToString() const {
std::size_t AbstractDictionary::hash() const { std::size_t AbstractDictionary::hash() const {
std::size_t hash_sum = std::accumulate(key_values_.cbegin(), key_values_.cend(), tid(), std::size_t hash_sum = std::accumulate(key_values_.cbegin(), key_values_.cend(), tid(),
[](std::size_t hash_sum, const AbstractAttribute &item) { [](std::size_t hash_sum, const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.first); MS_EXCEPTION_IF_NULL(item.first);
MS_EXCEPTION_IF_NULL(item.second); MS_EXCEPTION_IF_NULL(item.second);
hash_sum = hash_combine(hash_sum, item.first->hash()); hash_sum = hash_combine(hash_sum, item.first->hash());

View File

@ -1025,8 +1025,8 @@ class MS_CORE_API AbstractDictionary final : public AbstractBase {
public: public:
/// \brief Constructor of AbstractDictionary. /// \brief Constructor of AbstractDictionary.
/// ///
/// \param[in] key_values The vector of AbstractAttribute. /// \param[in] key_values The vector of AbstractElementPair.
explicit AbstractDictionary(const std::vector<AbstractAttribute> &key_values) : key_values_(key_values) {} explicit AbstractDictionary(const std::vector<AbstractElementPair> &key_values) : key_values_(key_values) {}
/// \brief Destructor of AbstractDictionary. /// \brief Destructor of AbstractDictionary.
~AbstractDictionary() override = default; ~AbstractDictionary() override = default;
@ -1051,12 +1051,12 @@ class MS_CORE_API AbstractDictionary final : public AbstractBase {
/// \brief Get the key values. /// \brief Get the key values.
/// ///
/// \return A vector of AbstractAttribute. /// \return A vector of AbstractElementPair.
const std::vector<AbstractAttribute> &elements() const { return key_values_; } const std::vector<AbstractElementPair> &elements() const { return key_values_; }
protected: protected:
ValuePtr RealBuildValue() const override; ValuePtr RealBuildValue() const override;
std::vector<AbstractAttribute> key_values_; std::vector<AbstractElementPair> key_values_;
}; };
using AbstractDictionaryPtr = std::shared_ptr<AbstractDictionary>; using AbstractDictionaryPtr = std::shared_ptr<AbstractDictionary>;

View File

@ -196,8 +196,8 @@ bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spe
ValuePtr key_value = key->BuildValue(); ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value); MS_EXCEPTION_IF_NULL(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements(); std::vector<AbstractElementPair> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) { auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
return *key_value == *item.first->BuildValue(); return *key_value == *item.first->BuildValue();
}); });
return it != dict_elems.end(); return it != dict_elems.end();

View File

@ -69,7 +69,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
const auto &key = key_list[index]; const auto &key = key_list[index];
CheckDictKey(key, op_name); CheckDictKey(key, op_name);
} }
std::vector<AbstractAttribute> key_value; std::vector<AbstractElementPair> key_value;
AbstractBasePtrList value_list = values->elements(); AbstractBasePtrList value_list = values->elements();
for (size_t index = 0; index < keys_size; index++) { for (size_t index = 0; index < keys_size; index++) {
(void)key_value.emplace_back(key_list[index], value_list[index]); (void)key_value.emplace_back(key_list[index], value_list[index]);
@ -277,8 +277,8 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr key_value = key->BuildValue(); ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value); MS_EXCEPTION_IF_NULL(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements(); std::vector<AbstractElementPair> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) { auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
return *key_value == *item.first->BuildValue(); return *key_value == *item.first->BuildValue();
}); });
if (it == dict_elems.end()) { if (it == dict_elems.end()) {
@ -302,8 +302,8 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr key_value = key->BuildValue(); ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value); MS_EXCEPTION_IF_NULL(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements(); std::vector<AbstractElementPair> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) { auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
return *key_value == *item.first->BuildValue(); return *key_value == *item.first->BuildValue();
}); });
@ -325,10 +325,10 @@ AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitiveP
constexpr int args_spec_size = 1; constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size); CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements(); std::vector<AbstractElementPair> dict_elems = dict->elements();
AbstractBasePtrList keys; AbstractBasePtrList keys;
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys), std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys),
[](const AbstractAttribute &item) { return item.first; }); [](const AbstractElementPair &item) { return item.first; });
return std::make_shared<AbstractTuple>(keys); return std::make_shared<AbstractTuple>(keys);
} }
@ -339,10 +339,10 @@ AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const Primitiv
constexpr int args_spec_size = 1; constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size); CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements(); std::vector<AbstractElementPair> dict_elems = dict->elements();
AbstractBasePtrList values; AbstractBasePtrList values;
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(values), std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(values),
[](const AbstractAttribute &item) { return item.second; }); [](const AbstractElementPair &item) { return item.second; });
return std::make_shared<AbstractTuple>(values); return std::make_shared<AbstractTuple>(values);
} }
@ -353,10 +353,10 @@ AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr
constexpr int args_spec_size = 1; constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size); CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0); AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements(); std::vector<AbstractElementPair> dict_elems = dict->elements();
AbstractBasePtrList items; AbstractBasePtrList items;
(void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(items), (void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(items),
[](const AbstractAttribute &item) { [](const AbstractElementPair &item) {
return std::make_shared<AbstractTuple>(AbstractBasePtrList{item.first, item.second}); return std::make_shared<AbstractTuple>(AbstractBasePtrList{item.first, item.second});
}); });
return std::make_shared<AbstractList>(items); return std::make_shared<AbstractList>(items);

View File

@ -231,7 +231,7 @@ using FuncGraphWeakPtr = std::weak_ptr<FuncGraph>;
namespace abstract { namespace abstract {
class AbstractBase; class AbstractBase;
using AbstractBasePtr = std::shared_ptr<AbstractBase>; using AbstractBasePtr = std::shared_ptr<AbstractBase>;
using AbstractAttribute = std::pair<AbstractBasePtr, AbstractBasePtr>; using AbstractElementPair = std::pair<AbstractBasePtr, AbstractBasePtr>;
class AnalysisContext; class AnalysisContext;
using AnalysisContextPtr = std::shared_ptr<AnalysisContext>; using AnalysisContextPtr = std::shared_ptr<AnalysisContext>;
} // namespace abstract } // namespace abstract

View File

@ -28,30 +28,28 @@ MIND_API_OPERATOR_IMPL(PyExecute, BaseOperator);
BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive, BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const { const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
if (infer_handler_ == nullptr) {
MS_LOG(EXCEPTION) << "infer_handler_ should not be null.";
}
// TODO(zh_qh): Will call 'infer_handler_(input_args)' check the abstract shape and type later.
ShapeVector out_shape = {1}; ShapeVector out_shape = {1};
return std::make_shared<abstract::Shape>(out_shape); return std::make_shared<abstract::Shape>(out_shape);
} }
TypePtr PyExecuteInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const { TypePtr PyExecuteInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
// TODO(zh_qh): Will call 'infer_handler_(input_args)' check the abstract shape and type later.
return kFloat64; return kFloat64;
} }
AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine, AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
const PrimitivePtr &primitive, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const { const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(DEBUG) << "item: " << item->ToString();
}
if (infer_handler_ == nullptr) {
MS_LOG(EXCEPTION) << "infer_handler_ should not be null.";
}
infer_handler_(input_args);
const auto &type = InferType(primitive, input_args); const auto &type = InferType(primitive, input_args);
const auto &shape = InferShape(primitive, input_args); const auto &shape = InferShape(primitive, input_args);
const auto &abstract = MakeAbstract(shape, type); const auto &abstract = MakeAbstract(shape, type);

View File

@ -29,6 +29,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.sparse_tensor import RowTensor, RowTensorInner, SparseTensor, COOTensor, CSRTensor from mindspore.common.sparse_tensor import RowTensor, RowTensorInner, SparseTensor, COOTensor, CSRTensor
from mindspore.common.mutable import mutable from mindspore.common.mutable import mutable
from mindspore.common.jit_config import JitConfig from mindspore.common.jit_config import JitConfig
from mindspore.common._utils import update_and_return_dict
# symbols from dtype # symbols from dtype
__all__ = [ __all__ = [
@ -66,4 +67,5 @@ __all__.extend([
"set_dump", "set_dump",
"ms_memory_recycle", "ms_memory_recycle",
"mutable", "JitConfig", "mutable", "JitConfig",
"update_and_return_dict",
]) ])

View File

@ -53,3 +53,8 @@ def split_to_slice_if_need(dtype, shape):
return slice_num return slice_num
slice_num = math.ceil(data_size / emb_cache_size) slice_num = math.ceil(data_size / emb_cache_size)
return slice_num return slice_num
def update_and_return_dict(dic, key, val):
dic.__setitem__(key, val)
return dic

View File

@ -225,3 +225,9 @@ def bprop_scalar_not(x, out, dout):
def bprop_tensor_move(x, out, dout): def bprop_tensor_move(x, out, dout):
"""Backpropagator for primitive `TensorMove`.""" """Backpropagator for primitive `TensorMove`."""
return (dout,) return (dout,)
@bprops.register("PyExecute")
def get_bprop_py_execute(x, y, z, out, dout):
"""Generate bprop for PyExecute"""
return x, y, z

View File

@ -17,6 +17,7 @@ import pytest
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore.common.initializer import TruncatedNormal
ms.set_context(mode=ms.GRAPH_MODE) ms.set_context(mode=ms.GRAPH_MODE)
@ -92,3 +93,217 @@ def test_fallback_np_asnumpy():
const_output = ConstNet()() const_output = ConstNet()()
print(f'const_output: {const_output}') print(f'const_output: {const_output}')
np.testing.assert_almost_equal(output, const_output, 3) np.testing.assert_almost_equal(output, const_output, 3)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dict_return_1():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_1():
x = {'a': 'a', 'b': 'b'}
y = x.get('a')
z = dict(y=y)
return z
out = dict_net_1()
print(f'out: {out}')
assert out == {'y': 'a'}
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dict_get_1():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_1():
x = {'a': 1, 'b': 2}
y = x.get('a')
y_tensor = ms.Tensor([y])
z = dict(a=y_tensor)
return z
out = dict_net_1()
print(f'out: {out}')
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dict_get_2():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_2():
x = {'a': 1, 'b': 2}
y = x.get('a')
y_tensor = ms.Tensor([y])
z = dict(a=y_tensor, b='hello', c='world')
return z
out = dict_net_2()
print(f'out: {out}')
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dict_get_3():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_3():
x = {'a': 1, 'b': 2}
y = x.get('a')
y_tensor = ms.Tensor([y])
z = dict(y=y_tensor, a='a', b='c')
return z
out = dict_net_3()
print(f'out: {out}')
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return ms.nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return ms.nn.Dense(input_channels, out_channels, weight, bias)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net_dict_1():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
class DictLeNetNet(ms.nn.Cell):
def __init__(self, num_class=10):
super(DictLeNetNet, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = ms.nn.ReLU()
self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = ms.nn.Flatten()
def construct(self, x):
x = self.conv1(x)
conv1_x = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
conv2_x = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
fc_x = x
outputs = dict(conv1=conv1_x, conv2=conv2_x, fc=fc_x)
return outputs
net = DictLeNetNet()
x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32))
outputs = net(x)
print(f'outputs: {outputs}')
assert outputs['conv1'].shape == (64, 6, 28, 28)
assert outputs['conv2'].shape == (64, 16, 10, 10)
assert outputs['fc'].shape == (64, 10)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net_dict_2():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
class DictLeNetNet(ms.nn.Cell):
def __init__(self, num_class=10):
super(DictLeNetNet, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = ms.nn.ReLU()
self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = ms.nn.Flatten()
def construct(self, x):
outputs = dict()
x = self.conv1(x)
outputs['conv1'] = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
outputs['conv2'] = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
outputs['fc'] = x
return outputs
net = DictLeNetNet()
x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32))
outputs = net(x)
print(f'outputs: {outputs}')
assert outputs['conv1'].shape == (64, 6, 28, 28)
assert outputs['conv2'].shape == (64, 16, 10, 10)
assert outputs['fc'].shape == (64, 10)

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test getting gradient of mutable input""" """test getting gradient of mutable input"""
import os
import numpy as np import numpy as np
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
@ -160,10 +161,12 @@ def test_grad_mutable_dict_tensor():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(z) return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
t = mutable({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), t = mutable({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}) 'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t) output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6], expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
@ -350,11 +353,13 @@ def test_grad_mutable_tuple_dict_tensor():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(z) return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
t = mutable(({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), t = mutable(({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)}, 'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32))) Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
output = GradNetWrtX(Net())(t) output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6], expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
@ -398,11 +403,13 @@ def test_grad_mutable_dict_tuple_tensor():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(z) return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
t = mutable({'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), t = mutable({'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)), Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}) 'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t) output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6], expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
@ -446,11 +453,13 @@ def test_grad_mutable_list_dict_tensor():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(z) return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
t = mutable([{'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), t = mutable([{'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)}, 'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)]) Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t) output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6], expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
@ -494,11 +503,13 @@ def test_grad_mutable_dict_list_tensor():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(z) return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
t = mutable({'a': [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), t = mutable({'a': [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)], Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)],
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}) 'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t) output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6], expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
@ -699,11 +710,13 @@ def test_grad_mutable_unused_dict_tensor():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(z) return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
t = mutable({'x1': Tensor([[4.0, 6.0, 6.0], [4.0, 6.0, 6.0]], dtype=mstype.float32), t = mutable({'x1': Tensor([[4.0, 6.0, 6.0], [4.0, 6.0, 6.0]], dtype=mstype.float32),
'x2': Tensor([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], dtype=mstype.float32), 'x2': Tensor([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], dtype=mstype.float32),
'x3': Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=mstype.float32)}) 'x3': Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t) output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [np.array([[3., 3., 3.], expect = [np.array([[3., 3., 3.],
[3., 3., 3.]]).astype(np.float32), [3., 3., 3.]]).astype(np.float32),
@ -746,10 +759,12 @@ def test_grad_mutable_single_element_dict_tensor():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(x, t) return gradient_function(x, t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32) x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = mutable({'a': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}) y = mutable({'a': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(x, y) output = GradNetWrtX(Net())(x, y)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6], expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), [1.4100001, 1.5999999, 6.6]]).astype(np.float32),

View File

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""test the feature of mutable in graph""" """test the feature of mutable in graph"""
import os
import numpy as np import numpy as np
import pytest import pytest
import mindspore.nn as nn import mindspore.nn as nn
@ -456,8 +457,10 @@ def test_grad_const_dict_tensor_to_mutable():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(self.x) return gradient_function(self.x)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
grad_net = GradNetWrtX(Net()) grad_net = GradNetWrtX(Net())
output = grad_net() output = grad_net()
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6], expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
@ -502,10 +505,12 @@ def test_grad_const_dict_tensor_arg_to_mutable():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(mutable(x)) return gradient_function(mutable(x))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
x = {'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), x = {'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)} 'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}
grad_net = GradNetWrtX(Net()) grad_net = GradNetWrtX(Net())
output = grad_net(x) output = grad_net(x)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6], expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
@ -563,8 +568,10 @@ def test_grad_const_dict_and_tuple_tensor_to_mutable():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(self.x) return gradient_function(self.x)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
grad_net = GradNetWrtX(Net()) grad_net = GradNetWrtX(Net())
output = grad_net() output = grad_net()
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [(np.array([[1.4100001, 1.5999999, 6.6], expect = [(np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), [1.4100001, 1.5999999, 6.6]]).astype(np.float32),
@ -574,8 +581,10 @@ def test_grad_const_dict_and_tuple_tensor_to_mutable():
[1.9, 1.9, 1.9], [1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)] [1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect) assert compare(output, expect)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
grad_net = GradNetWrtX1(Net()) grad_net = GradNetWrtX1(Net())
output = grad_net() output = grad_net()
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
assert compare(output, expect) assert compare(output, expect)
@ -612,11 +621,13 @@ def test_grad_const_dict_and_tuple_tensor_arg_to_mutable():
gradient_function = self.grad_op(self.net) gradient_function = self.grad_op(self.net)
return gradient_function(mutable(x)) return gradient_function(mutable(x))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
x = {'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), x = {'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 4.0], [1.2, 1.3, 1.1]], dtype=mstype.float32)), Tensor([[0.5, 0.6, 4.0], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)} 'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}
grad_net = GradNetWrtX(Net()) grad_net = GradNetWrtX(Net())
output = grad_net(x) output = grad_net(x)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple) assert isinstance(output, tuple)
expect = [(np.array([[1.4100001, 1.5999999, 6.6], expect = [(np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), [1.4100001, 1.5999999, 6.6]]).astype(np.float32),

View File

@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
"""st for scipy.optimize.""" """st for scipy.optimize."""
import os
import pytest import pytest
import numpy as onp import numpy as onp
import scipy as osp import scipy as osp
@ -210,6 +211,7 @@ def test_bfgs_graph(dtype, func_x0):
Description: test cases for bfgs in GRAPH mode Description: test cases for bfgs in GRAPH mode
Expectation: the result match scipy Expectation: the result match scipy
""" """
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
func, x0 = func_x0 func, x0 = func_x0
x0 = x0.astype(dtype) x0 = x0.astype(dtype)
@ -218,6 +220,7 @@ def test_bfgs_graph(dtype, func_x0):
options=dict(maxiter=None, gtol=1e-6)) options=dict(maxiter=None, gtol=1e-6))
scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS') scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS')
match_array(ms_res.x.asnumpy(), scipy_res.x, error=5, err_msg=str(ms_res)) match_array(ms_res.x.asnumpy(), scipy_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
def _scalar_func_1(np): def _scalar_func_1(np):
@ -349,6 +352,7 @@ def test_line_search_graph(maxiter, func, x, p):
Description: test cases for n-d function in GRAPH mode Description: test cases for n-d function in GRAPH mode
Expectation: the result match scipy Expectation: the result match scipy
""" """
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
A = [[1.76405235, 0.40015721, 0.97873798, 2.2408932, 1.86755799], A = [[1.76405235, 0.40015721, 0.97873798, 2.2408932, 1.86755799],
[-0.97727788, 0.95008842, -0.15135721, -0.10321885, 0.4105985], [-0.97727788, 0.95008842, -0.15135721, -0.10321885, 0.4105985],
@ -356,8 +360,8 @@ def test_line_search_graph(maxiter, func, x, p):
[0.33367433, 1.49407907, -0.20515826, 0.3130677, -0.85409574], [0.33367433, 1.49407907, -0.20515826, 0.3130677, -0.85409574],
[-2.55298982, 0.6536186, 0.8644362, -0.74216502, 2.26975462]] [-2.55298982, 0.6536186, 0.8644362, -0.74216502, 2.26975462]]
osp_x, osp_p, osp_A = onp.array(x), onp.array(p), onp.array(A) osp_x, osp_p, osp_a = onp.array(x), onp.array(p), onp.array(A)
osp_f, osp_fp = func(onp, osp_A) osp_f, osp_fp = func(onp, osp_a)
osp_res = osp_line_search(osp_f, osp_fp, osp_x, osp_p, maxiter=maxiter) osp_res = osp_line_search(osp_f, osp_fp, osp_x, osp_p, maxiter=maxiter)
msp_x, msp_p, msp_A = mnp.array(x), mnp.array(p), mnp.array(A) msp_x, msp_p, msp_A = mnp.array(x), mnp.array(p), mnp.array(A)
@ -366,6 +370,7 @@ def test_line_search_graph(maxiter, func, x, p):
match_array(msp_res.a_k, osp_res[0], error=5) match_array(msp_res.a_k, osp_res[0], error=5)
match_array(msp_res.f_k, osp_res[3], error=5) match_array(msp_res.f_k, osp_res[3], error=5)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0 @pytest.mark.level0
@ -380,6 +385,7 @@ def test_lbfgs1(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs Expectation: the result match bfgs
""" """
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0 func, x0 = func_x0
x0 = x0.astype(dtype) x0 = x0.astype(dtype)
x0_tensor = Tensor(x0) x0_tensor = Tensor(x0)
@ -388,6 +394,7 @@ def test_lbfgs1(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS', ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6)) options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res)) match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0 @pytest.mark.level0
@ -402,6 +409,7 @@ def test_lbfgs2(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs Expectation: the result match bfgs
""" """
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0 func, x0 = func_x0
x0 = x0.astype(dtype) x0 = x0.astype(dtype)
x0_tensor = Tensor(x0) x0_tensor = Tensor(x0)
@ -410,6 +418,7 @@ def test_lbfgs2(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS', ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6)) options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res)) match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0 @pytest.mark.level0
@ -424,6 +433,7 @@ def test_lbfgs3(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs Expectation: the result match bfgs
""" """
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0 func, x0 = func_x0
x0 = x0.astype(dtype) x0 = x0.astype(dtype)
x0_tensor = Tensor(x0) x0_tensor = Tensor(x0)
@ -432,6 +442,7 @@ def test_lbfgs3(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS', ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6)) options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res)) match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0 @pytest.mark.level0
@ -446,6 +457,7 @@ def test_lbfgs4(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs Expectation: the result match bfgs
""" """
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0 func, x0 = func_x0
x0 = x0.astype(dtype) x0 = x0.astype(dtype)
x0_tensor = Tensor(x0) x0_tensor = Tensor(x0)
@ -454,6 +466,7 @@ def test_lbfgs4(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS', ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6)) options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res)) match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0 @pytest.mark.level0
@ -468,6 +481,7 @@ def test_lbfgs5(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs Expectation: the result match bfgs
""" """
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0 func, x0 = func_x0
x0 = x0.astype(dtype) x0 = x0.astype(dtype)
x0_tensor = Tensor(x0) x0_tensor = Tensor(x0)
@ -476,6 +490,7 @@ def test_lbfgs5(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS', ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6)) options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res)) match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0 @pytest.mark.level0
@ -490,6 +505,7 @@ def test_lbfgs6(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs Expectation: the result match bfgs
""" """
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0 func, x0 = func_x0
x0 = x0.astype(dtype) x0 = x0.astype(dtype)
x0_tensor = Tensor(x0) x0_tensor = Tensor(x0)
@ -498,6 +514,7 @@ def test_lbfgs6(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS', ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6)) options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res)) match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0 @pytest.mark.level0
@ -511,6 +528,7 @@ def test_lbfgs_fixes4594(dtype):
Description: test cases for lbfgs in PYNATIVE mode Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs Expectation: the result match bfgs
""" """
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
n = 2 n = 2
a = Tensor(onp.eye(n, dtype=dtype)) * 1e4 a = Tensor(onp.eye(n, dtype=dtype)) * 1e4
@ -520,6 +538,7 @@ def test_lbfgs_fixes4594(dtype):
results = msp.optimize.minimize(func, Tensor(onp.ones(n, dtype=dtype)), method='LBFGS', results = msp.optimize.minimize(func, Tensor(onp.ones(n, dtype=dtype)), method='LBFGS',
options=dict(maxiter=None, gtol=1e-6)).x options=dict(maxiter=None, gtol=1e-6)).x
onp.testing.assert_allclose(results.asnumpy(), onp.zeros(n, dtype=dtype), rtol=1e-6, atol=1e-6) onp.testing.assert_allclose(results.asnumpy(), onp.zeros(n, dtype=dtype), rtol=1e-6, atol=1e-6)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level1 @pytest.mark.level1
@ -534,6 +553,7 @@ def test_lbfgs_graph(dtype, func_x0):
Description: test cases for lbfgs in GRAPH mode Description: test cases for lbfgs in GRAPH mode
Expectation: the result match bfgs Expectation: the result match bfgs
""" """
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
func, x0 = func_x0 func, x0 = func_x0
x0 = x0.astype(dtype) x0 = x0.astype(dtype)
@ -543,3 +563,4 @@ def test_lbfgs_graph(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS', ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6)) options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res)) match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'

View File

@ -42,7 +42,7 @@ using AbstractTensor = abstract::AbstractTensor;
using AbstractTensorPtr = abstract::AbstractTensorPtr; using AbstractTensorPtr = abstract::AbstractTensorPtr;
using AbstractNone = abstract::AbstractNone; using AbstractNone = abstract::AbstractNone;
using AbstractAttribute = abstract::AbstractAttribute; using AbstractAttribute = abstract::AbstractElementPair;
using AnalysisEngine = abstract::AnalysisEngine; using AnalysisEngine = abstract::AnalysisEngine;
using AnalysisEnginePtr = abstract::AnalysisEnginePtr; using AnalysisEnginePtr = abstract::AnalysisEnginePtr;

View File

@ -1002,7 +1002,7 @@ TEST_F(TestPrim, test_DictGetItem2) {
AbstractBasePtr arr_x = ArrayOfTensor(UTPrimUtils::kF64, {3, 4, 5}); AbstractBasePtr arr_x = ArrayOfTensor(UTPrimUtils::kF64, {3, 4, 5});
AbstractBasePtr arr_y = ArrayOfTensor(UTPrimUtils::kF64, {1, 4, 5}); AbstractBasePtr arr_y = ArrayOfTensor(UTPrimUtils::kF64, {1, 4, 5});
AbstractBasePtr arr_z = ArrayOfTensor(UTPrimUtils::kF64, {3, 1, 5}); AbstractBasePtr arr_z = ArrayOfTensor(UTPrimUtils::kF64, {3, 1, 5});
std::vector<AbstractAttribute> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}}; std::vector<AbstractElementPair> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
AbstractDictionaryPtr array_dict = std::make_shared<AbstractDictionary>(array_map); AbstractDictionaryPtr array_dict = std::make_shared<AbstractDictionary>(array_map);
AbstractBasePtr key = abstract::FromValue("x"); AbstractBasePtr key = abstract::FromValue("x");
AbstractBasePtrList args_spec_list = {array_dict, key}; AbstractBasePtrList args_spec_list = {array_dict, key};

View File

@ -155,6 +155,7 @@ def test_dict_set_item():
_ = net(x) _ = net(x)
@pytest.mark.skip(reason="Do not support dict value for dict set item yet.")
def test_dict_set_item_2(): def test_dict_set_item_2():
""" """
Description: test dict in dict set item. Description: test dict in dict set item.
@ -184,6 +185,7 @@ def test_dict_set_item_2():
assert second[1] == 1 assert second[1] == 1
@pytest.mark.skip(reason="Do not support dict value for dict set item yet.")
def test_dict_set_item_3(): def test_dict_set_item_3():
""" """
Description: test dict in dict set item. Description: test dict in dict set item.
@ -207,7 +209,7 @@ def test_dict_set_item_3():
assert first[0][1] == 3 assert first[0][1] == 3
# if the dictionary item does not exist, create a new one # If the dictionary item does not exist, create a new one
def test_dict_set_item_create_new(): def test_dict_set_item_create_new():
class DictSetNet(Cell): class DictSetNet(Cell):
def __init__(self): def __init__(self):