[JIT Fallback] Support return Python dict in top func graph.

This commit is contained in:
张清华 2022-12-12 17:28:31 +08:00
parent b6d99e4a08
commit bc38782b94
30 changed files with 855 additions and 123 deletions

View File

@ -127,7 +127,13 @@ tensor::TensorPtr GetDependValueTensor(const AnfNodePtr &node, size_t i,
auto input_device_address = reinterpret_cast<std::vector<device::DeviceAddress *> *>(args);
if (i >= input_device_address->size() || input_device_address->at(i) == nullptr) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(EXCEPTION) << "There is no valid address for " << i << " input of " << node->fullname_with_scope();
if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
MS_LOG(INFO) << "There is no valid address for " << i << " input of " << node->DebugString() << ", "
<< node->fullname_with_scope();
return out_tensor;
}
MS_LOG(EXCEPTION) << "There is no valid address for " << i << " input of " << node->DebugString() << ", "
<< node->fullname_with_scope();
}
out_tensor->data_sync_directly(input_device_address->at(i));

View File

@ -848,10 +848,10 @@ AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
return rectify_abs_list;
}
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
const AbstractBasePtrList &input_abstract) {
MS_EXCEPTION_IF_NULL(primitive);
auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
MS_EXCEPTION_IF_NULL(prim);
auto dynamic_inputs_list = prim->GetAttr(kAttrDynInputSizes);
if (dynamic_inputs_list == nullptr) {
return input_abstract;
}
@ -873,6 +873,10 @@ AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitiv
AbstractBasePtrList dynamic_inputs_abs;
for (auto index = item; index > 0; --index) {
if (input_index >= input_abstract.size()) {
// Not to check for PyExecute.
if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
continue;
}
MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range in input abstract "
<< input_abstract.size();
}

View File

@ -41,11 +41,11 @@ namespace mindspore {
namespace prim {
constexpr auto kStepDefault = 1;
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractElementPair;
using mindspore::abstract::AbstractEllipsis;
using mindspore::abstract::AbstractEllipsisPtr;
using mindspore::abstract::AbstractFunction;

View File

@ -49,12 +49,12 @@ FuncGraphPtr DictHasKey::GenerateFuncGraph(const abstract::AbstractBasePtrList &
ValuePtr key_value = args_list[1]->BuildValue();
MS_EXCEPTION_IF_NULL(dict);
MS_EXCEPTION_IF_NULL(key_value);
auto dict_elems = dict->elements();
auto elems = dict->elements();
bool has_key = false;
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const abstract::AbstractAttribute &item) {
auto it = std::find_if(elems.cbegin(), elems.cend(), [&key_value](const abstract::AbstractElementPair &item) {
return *key_value == *item.first->BuildValue();
});
if (it != dict_elems.cend()) {
if (it != elems.cend()) {
has_key = true;
}
@ -153,7 +153,7 @@ abstract::AbstractBasePtrList DictFromKeys::ParseIterableObject(const abstract::
AbstractBasePtrList keys;
auto &dict_elems = dict_keys->elements();
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys),
[](const abstract::AbstractAttribute &item) { return item.first; });
[](const abstract::AbstractElementPair &item) { return item.first; });
return keys;
}
if (key_type->IsSameTypeId(String::kTypeId)) {

View File

@ -28,10 +28,10 @@
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractElementPair;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractKeywordArg;
using mindspore::abstract::AbstractList;
@ -78,7 +78,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_abs_l
auto dict_elems = arg_dict->elements();
(void)std::transform(
dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(elems),
[res_graph, para_dict](const AbstractAttribute &item) {
[res_graph, para_dict](const AbstractElementPair &item) {
// Dict_elems's first element represents parameter names, which should be string type.
auto key_value = GetValue<std::string>(item.first->BuildValue());
auto dict_get_item =

View File

@ -38,11 +38,11 @@
namespace mindspore {
/* namespace to support opt */
namespace opt {
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractElementPair;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractRowTensor;
@ -164,7 +164,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
public:
using ThisClass = SimplifyDataStructuresRewriter;
SimplifyDataStructuresRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
: BaseRewriter(root_graph, manager) {}
: BaseRewriter(root_graph, manager), is_dict_output_{IsDictOutput()} {}
~SimplifyDataStructuresRewriter() override = default;
protected:
@ -176,7 +176,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return str->value();
}
static int64_t GetAttrIndex(const std::vector<AbstractAttribute> &attrs, const AnfNodePtr &name) {
static int64_t GetElementIndex(const std::vector<AbstractElementPair> &attrs, const AnfNodePtr &name) {
auto n_attrs = attrs.size();
auto name_abstract = GetAbstract<AbstractBase>(name);
MS_EXCEPTION_IF_NULL(name_abstract);
@ -191,15 +191,15 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
}
static CNodePtr NewTupleGetCNode(const AnfNodePtr &cnode, const AnfNodePtr &data_node,
const std::vector<AbstractAttribute> &attributes, const AnfNodePtr &name_node) {
int64_t index = GetAttrIndex(attributes, name_node);
const std::vector<AbstractElementPair> &elements, const AnfNodePtr &name_node) {
int64_t index = GetElementIndex(elements, name_node);
auto index_node = NewValueNode(index);
auto prim_node = NewValueNode(prim::kPrimTupleGetItem);
return cnode->func_graph()->NewCNode({prim_node, data_node, index_node});
}
// From:
// DictGetItem(data:AbstractDictionary, cons:AbstractBase)
// DictGetItem(data:AbstractDictionary, key:AbstractBase)
// To:
// TupleGetItem(data, index:Int64Imm)
AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
@ -211,27 +211,98 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
CheckInputsSize(node, expect_inputs_size);
constexpr size_t data_index = 1;
constexpr size_t attr_index = 2;
constexpr size_t key_index = 2;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &attr = inputs[attr_index];
auto &key = inputs[key_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(attr);
MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) {
return nullptr;
}
return NewTupleGetCNode(node, data, abs_dict->elements(), attr);
return NewTupleGetCNode(node, data, abs_dict->elements(), key);
}
// DictGetItem --> PyExecute()
AnfNodePtr RebuidDictGetItem(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
// Inputs should be [dict_setitem, dict, item]
const size_t expect_inputs_size = 3;
CheckInputsSize(node, expect_inputs_size);
const size_t data_index = 1;
const size_t item_key_index = 2;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &key = inputs[item_key_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) {
return nullptr;
}
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// Script
constexpr auto internal_dict_self_str = "__internal_dict_self__";
constexpr auto internal_dict_key_str = "__internal_dict_key__";
std::stringstream script_buffer;
script_buffer << internal_dict_self_str << "[" << internal_dict_key_str << "]";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Pack local parameters keys.
const auto script_dict_self_name = std::make_shared<StringImm>(internal_dict_self_str);
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_key_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
// Pack the local parameters values, not support list, tuple, or dict.
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(data);
(void)key_value_list.emplace_back(key);
const auto key_value_tuple = func_graph->NewCNode(key_value_list);
// Build the new dict node.
const auto dict_getitem_node = func_graph->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
int64_t index = GetElementIndex(abs_dict->elements(), key);
const auto &val = abs_dict->elements()[index].second;
const auto &tensor_val = dyn_cast<abstract::AbstractTensor>(val);
if (tensor_val != nullptr) {
const auto &tensor_type = tensor_val->element()->BuildType();
dict_getitem_node->set_user_data<Type>("__py_execute_tensor_type__", tensor_type);
const auto &tensor_shape = dyn_cast<abstract::Shape>(tensor_val->BuildShape());
MS_EXCEPTION_IF_NULL(tensor_shape);
dict_getitem_node->set_user_data<abstract::Shape>("__py_execute_tensor_shape__", tensor_shape);
MS_LOG(DEBUG) << "key: " << key->abstract()->BuildValue()->ToString() << ", type: " << tensor_type->ToString()
<< ", shape: " << tensor_shape->ToString() << ", val: " << tensor_val->ToString();
}
MS_LOG(DEBUG) << "Made dict getitem node: " << dict_getitem_node->DebugString();
return dict_getitem_node;
}
AnfNodePtr ConvertDictGetItem(const CNodePtr &node) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime && is_dict_output_) {
return RebuidDictGetItem(node);
}
return ConvertDictGetItemToTupleGetItem(node);
}
// From:
// DictSetItem(data:AbstractDictionary, cons:AbstractBase, value)
// DictSetItem(data:AbstractDictionary, key:AbstractBase, value)
// To:
// TupleSetItem(data, index:Int64Imm, value)
// Or:
// tuple_add(data, value)
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
@ -244,16 +315,16 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
const size_t item_value_index = 3;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &cons = inputs[cons_index];
auto &key = inputs[cons_index];
auto &item_value = inputs[item_value_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(cons);
MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) {
return nullptr;
}
int64_t index = GetAttrIndex(abs_dict->elements(), cons);
int64_t index = GetElementIndex(abs_dict->elements(), key);
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
if (index >= static_cast<int64_t>(abs_dict->elements().size())) {
@ -275,11 +346,86 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return new_node;
}
bool IsDictOutput() const {
const AnfNodePtr &output = root_graph_->output();
auto abs_dict = GetAbstract<AbstractDictionary>(output);
if (abs_dict != nullptr) {
return true;
}
return false;
}
// DictSetItem --> PyExecute()
AnfNodePtr RebuidDictSetItem(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
// Inputs should be [dict_setitem, dict, item, value]
const size_t expect_inputs_size = 4;
CheckInputsSize(node, expect_inputs_size);
const size_t data_index = 1;
const size_t item_key_index = 2;
const size_t item_value_index = 3;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &key = inputs[item_key_index];
auto &item_value = inputs[item_value_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) {
return nullptr;
}
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// Script
constexpr auto internal_dict_self_str = "__internal_dict_self__";
constexpr auto internal_dict_key_str = "__internal_dict_key__";
constexpr auto internal_dict_value_str = "__internal_dict_value__";
std::stringstream script_buffer;
script_buffer << "__import__('mindspore').update_and_return_dict(" << internal_dict_self_str << ", "
<< internal_dict_key_str << ", " << internal_dict_value_str << ")";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Pack local parameters keys.
const auto script_dict_self_name = std::make_shared<StringImm>(internal_dict_self_str);
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_key_str);
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_value_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
// Pack the local parameters values, not support list, tuple, or dict.
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(data);
(void)key_value_list.emplace_back(key);
(void)key_value_list.emplace_back(item_value);
const auto key_value_tuple = func_graph->NewCNode(key_value_list);
// Build the new dict node.
const auto dict_setitem_node = func_graph->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
MS_LOG(DEBUG) << "Made dict setitem node: " << dict_setitem_node->DebugString();
return dict_setitem_node;
}
AnfNodePtr ConvertDictSetItem(const CNodePtr &node) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime && is_dict_output_) {
return RebuidDictSetItem(node);
}
return ConvertDictSetItemToTupleSetItem(node);
}
// From:
// MakeDict(name, input)
// To:
// input
AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
AnfNodePtr EraseMakeDictNode(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
constexpr size_t expect_inputs_size = 3;
constexpr size_t input_index = 2;
@ -287,6 +433,62 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return node->input(input_index);
}
// MakeDict(keys, values) --> PyExecute('dict(zip(keys, values))', ...)
AnfNodePtr RebuildMakeDictNode(const CNodePtr &node) const {
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
constexpr auto internal_dict_zip_values_str = "__internal_dict_zip_values__";
const auto &fg = node->func_graph();
MS_EXCEPTION_IF_NULL(fg);
// Local parameters values.
// Pack the key tuple.
constexpr size_t values_input_index = 2;
const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
const auto make_key_tuple_node =
fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_key_tuple_str),
NewValueNode(script_key_tuple_str), node->input(values_input_index)});
// Pack the value tuple.
const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
const auto make_value_tuple_node =
fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_value_tuple_str),
NewValueNode(script_value_tuple_str), node->input(1)});
// Pack the local parameters values
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(make_key_tuple_node);
(void)key_value_list.emplace_back(make_value_tuple_node);
const auto key_value_tuple = fg->NewCNode(key_value_list);
// Pack local parameters keys.
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_zip_keys_str);
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_zip_values_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
// Script
std::stringstream script_buffer;
script_buffer << "dict(zip(" << internal_dict_zip_keys_str << "," << internal_dict_zip_values_str << "),)";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Build the new dict node.
const auto make_dict_node = fg->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
return make_dict_node;
}
AnfNodePtr ConvertMakeDict(const CNodePtr &node) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime && is_dict_output_) {
return RebuildMakeDictNode(node);
}
return EraseMakeDictNode(node);
}
// From:
// DictGetValues(dict:AbstractDictionary)
// To:
@ -356,22 +558,79 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
}
// dict(k0:v0, k1:v1, ...) --> tuple(v0, v1, ...)
ValueTuplePtr DictToTuple(const ValueDictionaryPtr &dict) const {
const auto &elements = dict->value();
std::vector<ValuePtr> values;
values.reserve(elements.size());
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(values),
[](const auto &element) { return element.second; });
return std::make_shared<ValueTuple>(values);
AnfNodePtr DictToTuple(const ValueDictionaryPtr &dict) const {
const auto &keys_values = dict->value();
std::vector<ValuePtr> value_list;
value_list.reserve(keys_values.size());
(void)std::transform(keys_values.begin(), keys_values.end(), std::back_inserter(value_list),
[](const auto &value) { return value.second; });
return NewValueNode(std::make_shared<ValueTuple>(value_list));
}
// dict(k0:v0, k1:v1, ...) --> PyExecute('dict(zip(keys, values))', ...)
AnfNodePtr RebuildValueDict(const ValueDictionaryPtr &dict) const {
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
constexpr auto internal_dict_zip_values_str = "__internal_dict_zip_values__";
const auto &keys_values = dict->value();
std::vector<ValuePtr> key_list;
key_list.reserve(keys_values.size());
std::vector<ValuePtr> value_list;
value_list.reserve(keys_values.size());
for (const auto &key_value : keys_values) {
(void)key_list.emplace_back(key_value.first);
(void)value_list.emplace_back(key_value.second);
}
// Local parameters values.
// Pack the key tuple.
const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
const auto key_tuple = std::make_shared<ValueTuple>(key_list);
const auto make_key_tuple_node =
root_graph_->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_key_tuple_str),
NewValueNode(script_key_tuple_str), NewValueNode(key_tuple)});
// Pack the value tuple.
const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
const auto value_tuple = std::make_shared<ValueTuple>(value_list);
const auto make_value_tuple_node =
root_graph_->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_value_tuple_str),
NewValueNode(script_value_tuple_str), NewValueNode(value_tuple)});
// Pack the local parameters values
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(make_key_tuple_node);
(void)key_value_list.emplace_back(make_value_tuple_node);
const auto key_value_tuple = root_graph_->NewCNode(key_value_list);
// Pack local parameters keys.
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_zip_keys_str);
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_zip_values_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
const auto key_value_name_tuple = root_graph_->NewCNode(key_value_names_list);
// Script
std::stringstream script_buffer;
script_buffer << "dict(zip(" << internal_dict_zip_keys_str << "," << internal_dict_zip_values_str << "),)";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Build the new dict node.
const auto make_dict_node = root_graph_->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
return make_dict_node;
}
using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &);
using ConverterMap = mindspore::HashMap<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>;
static inline const ConverterMap converters_{
{prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItemToTupleGetItem},
{prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItemToTupleSetItem},
{prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItem},
{prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItem},
{prim::kPrimDictGetValues, &ThisClass::EraseDictGetValues},
{prim::kPrimMakeDict, &ThisClass::EraseMakeDictNode},
{prim::kPrimMakeDict, &ThisClass::ConvertMakeDict},
{prim::kPrimMakeKeywordArg, &ThisClass::EraseMakeKeywordArgNode},
{prim::kPrimExtractKeywordArg, &ThisClass::EraseExtractKeywordArg},
{prim::kPrimDictItems, &ThisClass::EraseDictItems},
@ -390,12 +649,16 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
AnfNodePtr ConvertValueNode(const ValueNodePtr &, const ValuePtr &value) override {
// Convert Dictionary value node.
if (value->isa<ValueDictionary>()) {
return NewValueNode(DictToTuple(value->cast<ValueDictionaryPtr>()));
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime && is_dict_output_) {
return RebuildValueDict(value->cast<ValueDictionaryPtr>());
}
return DictToTuple(value->cast<ValueDictionaryPtr>());
}
return nullptr;
}
static std::shared_ptr<AbstractTuple> MakeAbstractTuple(const std::vector<AbstractAttribute> &attrs) {
static std::shared_ptr<AbstractTuple> MakeAbstractTuple(const std::vector<AbstractElementPair> &attrs) {
std::vector<AbstractBasePtr> elements;
elements.reserve(attrs.size());
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(elements),
@ -465,6 +728,9 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
// AbstractDictionary --> AbstractSequence.
return ConvertToAbstractSequence(abs, 0);
}
private:
bool is_dict_output_{false};
};
// ==================================================================
@ -495,9 +761,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
}
// From:
// ListGetItem(list, cons)
// ListGetItem(list, key)
// To:
// TupleGetItem(list, cons)
// TupleGetItem(list, key)
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
@ -509,8 +775,8 @@ class CleanAfterOptARewriter : public BaseRewriter {
constexpr size_t cons_index = 2;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &cons = inputs[cons_index];
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons});
auto &key = inputs[cons_index];
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, key});
}
// From:
@ -530,9 +796,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
const size_t value_index = 3;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &cons = inputs[cons_index];
auto &key = inputs[cons_index];
auto &value = inputs[value_index];
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, key, value});
}
// From:

View File

@ -33,13 +33,20 @@ namespace mindspore {
namespace opt {
namespace {
py::object CallPythonPushGlobalParams(const py::object &dict) {
constexpr auto python_mod_parse = "mindspore._extends.parse";
py::module mod = python_adapter::GetPyModule(python_mod_parse); // The same as PYTHON_MOD_PARSE_MODULE[]
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
py::module mod = python_adapter::GetPyModule(python_mod_parse);
constexpr auto python_merge_dict = "merge_global_params";
return python_adapter::CallPyModFn(mod, python_merge_dict, dict);
}
} // namespace
// Convert PyInterpret into PyExecute:
// PyInterpret(script, global_dict, local_dict)
// -->
// PyExecute(script, local_dict_keys, local_dict_values),
// with side-effect operation:
// Push global_dict into global parameters list.
// (So it requires no same key name.)
bool PyInterpretToExecute(const pipeline::ResourcePtr &resource) {
auto manager = resource->manager();
const auto &all_nodes = manager->all_nodes();

View File

@ -1221,7 +1221,8 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
MS_LOG(DEBUG) << "call_cnode: " << call_cnode->DebugString()
<< ", call_function_node: " << call_function_node->DebugString();
// Support tensor.asnumpy() in runtime by JIT Fallback.
if (IsPrimitiveCNode(call_function_node, prim::kPrimGetAttr)) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime && IsPrimitiveCNode(call_function_node, prim::kPrimGetAttr)) {
constexpr size_t index_two = 2;
const auto &attr_node = call_function_node->cast<CNodePtr>()->input(index_two);
const auto &attr_str = GetValueNode<StringImmPtr>(attr_node);
@ -1474,8 +1475,9 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
auto value_str = py::cast<std::string>(ast()->GetAstNodeText(value_body));
py::bool_ is_const_value =
ast()->CallParserObjMethod(PYTHON_PARSE_CHECK_IS_CONSTANT_VALUE, value_str, common::SafeCStr(attr_str));
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
auto is_constant = py::cast<bool>(is_const_value);
if (!is_constant || attr_str == "asnumpy") {
if (!is_constant || (support_fallback_runtime && attr_str == "asnumpy")) {
UpdateInterpretForUserNode(attr_cnode, value_node);
}
}

View File

@ -88,6 +88,10 @@ void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource)
} // namespace
bool PyInterpretToExecutePass(const ResourcePtr &resource) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (!support_fallback_runtime) {
return true;
}
MS_EXCEPTION_IF_NULL(resource);
FuncGraphPtr func_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);

View File

@ -270,6 +270,33 @@ std::string ToOrdinal(const size_t &i) {
}
return std::to_string(i) + suffix;
}
py::object GetPyExecuteOutput(const AnfNodePtr &output) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime) {
std::function<AnfNodePtr(const AnfNodePtr &)> get_real_output = [&get_real_output](const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
const auto cnode = dyn_cast<CNode>(node);
MS_EXCEPTION_IF_NULL(cnode);
return get_real_output(cnode->input(1));
}
return node;
};
const auto &real_output = get_real_output(output);
MS_LOG(INFO) << "Real output: " << real_output << ", " << real_output->DebugString()
<< ", has \'PyExecuteOutputData\': " << real_output->has_user_data<kernel::PyExecuteOutputData>();
if (real_output->has_user_data<kernel::PyExecuteOutputData>()) {
py::gil_scoped_acquire gil_acquire;
const auto &output_data = real_output->user_data<kernel::PyExecuteOutputData>();
py::object res_obj = output_data->obj;
MS_LOG(INFO) << "Has \'PyExecuteOutputData\', just return it. res_obj: " << res_obj;
if (!py::isinstance<py::none>(res_obj)) {
return res_obj;
}
}
}
return py::none();
}
} // namespace
std::string GetObjDesc(const py::object &source_obj) {
@ -1429,14 +1456,9 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
}
// Replace the output if it's not Tensor, but Python data.
if (output->has_user_data<kernel::PyExecuteOutputData>()) {
py::gil_scoped_acquire gil_acquire;
const auto &output_data = output->user_data<kernel::PyExecuteOutputData>();
py::object res_obj = output_data->obj;
MS_LOG(INFO) << "Has \'PyExecuteOutputData\', just return it. res_obj: " << res_obj;
if (!py::isinstance<py::none>(res_obj)) {
return res_obj;
}
const auto &py_res = GetPyExecuteOutput(output);
if (py_res != py::none()) {
return py_res;
}
MS_LOG(DEBUG) << "Run end";

View File

@ -177,7 +177,7 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_a
auto arg_dict = specialize_args_before_unpack[index]->cast_ptr<AbstractDictionary>();
auto dict_elems = arg_dict->elements();
(void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(graph_specialize_args),
[](const AbstractAttribute &item) {
[](const AbstractElementPair &item) {
// Dict_elems's first element represents parameter names, which should be string type.
return std::make_shared<AbstractKeywordArg>(
GetValue<std::string>(item.first->BuildValue()), item.second);
@ -1817,9 +1817,21 @@ EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abst
// Call python script string.
MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list;
TypePtr type = kFloat32;
if (current_interpret_node->has_user_data("__py_execute_tensor_type__")) {
type = current_interpret_node->user_data<Type>("__py_execute_tensor_type__");
MS_LOG(DEBUG) << "type: " << type->ToString();
}
BaseShapePtr shape;
if (current_interpret_node->has_user_data("__py_execute_tensor_shape__")) {
shape = current_interpret_node->user_data<BaseShape>("__py_execute_tensor_shape__");
MS_LOG(DEBUG) << "shape: " << shape->ToString();
} else {
ShapeVector shp;
(void)shp.emplace_back(Shape::kShapeRankAny);
AbstractBasePtr res = std::make_shared<AbstractTensor>(kFloat64, std::make_shared<Shape>(shp));
shape = std::make_shared<Shape>(shp);
}
AbstractBasePtr res = std::make_shared<AbstractTensor>(type, shape);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
@ -2246,10 +2258,17 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
MS_EXCEPTION_IF_NULL(local_abs_val);
auto py_data_name = py::str(ValueToPyData(name->BuildValue()));
if (local_abs_val == kAnyValue) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
if (support_fallback_runtime) {
MS_LOG(INFO) << "When using JIT Fallback to handle script '" << script
<< "', the inputs should be constant, but found variable '" << py_data_name
<< "' to be nonconstant. To convert to PyExecute() afterwards";
non_const_err_ = true;
} else {
MS_EXCEPTION(ValueError) << "When using JIT Fallback to handle script '" << script
<< "', the inputs should be constant, but found variable '" << py_data_name
<< "' to be nonconstant.";
}
}
if (local_abs->isa<abstract::AbstractTensor>()) {
MS_LOG(WARNING) << "When using JIT Fallback to handle script '" << script << "', found variable '"
@ -2341,11 +2360,11 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
MS_EXCEPTION_IF_NULL(abstract_dict);
std::vector<AbstractAttribute> kv;
std::vector<AbstractElementPair> kv;
const auto &keys_values = abstract_dict->elements();
// Filter out the element of Function type.
(void)std::copy_if(keys_values.cbegin(), keys_values.cend(), std::back_inserter(kv),
[](const AbstractAttribute &item) {
[](const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.second);
return (!item.second->isa<abstract::AbstractFunction>());
});

View File

@ -132,20 +132,106 @@ void PyExecuteCpuKernelMod::AttachPyOutputData(const py::object &py_res) {
const auto &iter = graph_output_map.find(anf_index);
if (iter != graph_output_map.cend()) {
const auto &front_node = iter->second.first;
MS_LOG(INFO) << "Found front output for " << kernel_node_->DebugString();
MS_LOG(INFO) << "Found front output for " << kernel_node_ << ", " << kernel_node_->DebugString();
front_node->set_user_data<PyExecuteOutputData>(py_output);
} else {
MS_LOG(DEBUG) << "Not found, kernel node is not output, " << kernel_node_->DebugString();
MS_LOG(DEBUG) << "Not found, kernel node is not output, " << kernel_node_ << ", " << kernel_node_->DebugString();
if (!IS_OUTPUT_ON(mindspore::kDebug)) {
return;
}
for (const auto &output_pair : graph_output_map) {
MS_EXCEPTION_IF_NULL(output_pair.first.first);
MS_EXCEPTION_IF_NULL(output_pair.second.first);
MS_LOG(DEBUG) << "backend node: " << output_pair.first.first << ", " << output_pair.first.first->DebugString()
<< ", front node: " << output_pair.second.first << ", " << output_pair.second.first->DebugString();
}
}
}
// Notice: Remove here after BE kernel supports tuple input.
py::object PyExecuteCpuKernelMod::BuildLocalTupleParameters(const std::vector<AddressPtr> &inputs) {
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
constexpr auto number_two = 2;
std::string tuple_key_str;
py::tuple local_tuple_inputs(inputs_info_.size() - number_two); // Exclude the script and key.
MS_LOG(DEBUG) << "Local parameter tuple size: " << (inputs_info_.size() - number_two);
bool tuple_input_start = false;
size_t tuple_index = 0;
py::dict local_tuple_dict;
for (size_t i = 1; i < inputs.size() && i < inputs_info_.size(); ++i) {
const auto &input = inputs[i];
MS_EXCEPTION_IF_NULL(input);
const auto &input_info = inputs_info_[i];
const auto &input_abstract = input_info.abstract;
MS_EXCEPTION_IF_NULL(input_abstract);
const auto &input_type = input_abstract->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
if (!tuple_input_start && input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
const auto &value = input_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(value);
const auto &str_value = dyn_cast<StringImm>(value);
MS_EXCEPTION_IF_NULL(str_value);
const auto &str = str_value->value();
if (str != internal_tuple_keys_str && str != internal_tuple_values_str) {
return py::none();
}
tuple_key_str = str;
tuple_input_start = true;
MS_LOG(DEBUG) << "String, key input[" << i << "]: " << input_abstract->ToString();
continue;
}
// Rebuild the tuple with all left inputs.
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
const auto &value = input_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(value);
const auto &str_value = dyn_cast<StringImm>(value);
MS_EXCEPTION_IF_NULL(str_value);
const auto &str = str_value->value();
local_tuple_inputs[tuple_index++] = py::str(str);
MS_LOG(DEBUG) << "String, value input[" << i << "]: " << input_abstract->ToString();
} else if (input_abstract->isa<abstract::AbstractTensor>()) {
const auto &py_array_value = input_info.py_obj_output;
bool is_py_middle_data = !py::isinstance<py::none>(py_array_value);
MS_LOG(DEBUG) << "Tensor, value input[" << i << "]: " << input_abstract->ToString()
<< ", type: " << input_info.type << ", shape: " << input_info.shape << ", addr: " << inputs[i]->addr
<< ", size: " << inputs[i]->size << ", py_array_value: " << py_array_value
<< ", is_py_middle_data: " << is_py_middle_data;
if (!is_py_middle_data) {
const auto tensor =
std::make_shared<tensor::Tensor>(input_info.type, input_info.shape, inputs[i]->addr, inputs[i]->size);
MS_EXCEPTION_IF_NULL(tensor);
local_tuple_inputs[tuple_index++] = tensor;
} else {
local_tuple_inputs[tuple_index++] = py_array_value;
}
} else {
MS_LOG(ERROR) << "Unsupported value type.";
}
}
local_tuple_dict[py::str(tuple_key_str)] = local_tuple_inputs;
return local_tuple_dict;
}
py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<AddressPtr> &inputs) {
const auto &local_tuple_params = BuildLocalTupleParameters(inputs);
if (local_tuple_params != py::none()) {
return local_tuple_params;
}
MS_LOG(DEBUG) << "Build normal local parameters.";
// Build local parameters dict.
std::vector<std::string> keys;
std::vector<tensor::TensorPtr> tensor_values;
std::vector<py::array> array_values;
std::vector<py::object> py_object_values;
std::vector<bool> py_array_flags;
for (size_t i = 1; i < inputs.size() && i < inputs_info_.size(); ++i) {
constexpr auto number_two = 2;
size_t pair_size = (inputs_info_.size() - 1) / number_two;
// Handle the keys.
size_t i = 1;
for (; i < inputs.size() && i < pair_size + 1; ++i) {
const auto &input = inputs[i];
MS_EXCEPTION_IF_NULL(input);
const auto &input_info = inputs_info_[i];
@ -161,6 +247,29 @@ py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<Address
const auto &str = str_value->value();
(void)keys.emplace_back(str);
MS_LOG(DEBUG) << "String, input[" << i << "]: " << input_abstract->ToString();
} else {
MS_LOG(EXCEPTION) << "Other, input[" << i << "]: " << input_abstract->ToString();
}
}
// Handle the values.
for (; i < inputs.size() && i < inputs_info_.size(); ++i) {
const auto &input = inputs[i];
MS_EXCEPTION_IF_NULL(input);
const auto &input_info = inputs_info_[i];
const auto &input_abstract = input_info.abstract;
MS_EXCEPTION_IF_NULL(input_abstract);
const auto &input_type = input_abstract->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
const auto &value = input_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(value);
const auto &str_value = dyn_cast<StringImm>(value);
MS_EXCEPTION_IF_NULL(str_value);
const auto &str = str_value->value();
(void)py_object_values.emplace_back(py::str(str));
(void)tensor_values.emplace_back(nullptr);
(void)py_array_flags.emplace_back(true);
MS_LOG(DEBUG) << "String, input[" << i << "]: " << input_abstract->ToString();
} else if (input_abstract->isa<abstract::AbstractTensor>()) {
const auto &py_array_value = input_info.py_obj_output;
bool is_py_middle_data = !py::isinstance<py::none>(py_array_value);
@ -172,7 +281,7 @@ py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<Address
tensor = std::make_shared<tensor::Tensor>(input_info.type, input_info.shape, inputs[i]->addr, inputs[i]->size);
MS_EXCEPTION_IF_NULL(tensor);
}
(void)array_values.emplace_back(py_array_value);
(void)py_object_values.emplace_back(py_array_value);
(void)tensor_values.emplace_back(tensor);
(void)py_array_flags.emplace_back(is_py_middle_data);
} else if (input_abstract->isa<abstract::AbstractRefTensor>()) {
@ -181,21 +290,22 @@ py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<Address
MS_LOG(DEBUG) << "Other, input[" << i << "]: " << input_abstract->ToString();
}
}
constexpr auto number_two = 2;
if (keys.size() != tensor_values.size() || keys.size() != (inputs_info_.size() - 1) / number_two) {
if (keys.size() != tensor_values.size() || keys.size() != pair_size) {
MS_LOG(EXCEPTION) << "The local dict input is invalid, " << keys.size() << ", " << tensor_values.size() << ", "
<< inputs_info_.size();
}
// To call the script with global and local parameters.
py::dict local_dict;
for (size_t i = 0; i < keys.size(); ++i) {
for (i = 0; i < keys.size(); ++i) {
if (py_array_flags[i]) {
local_dict[py::str(keys[i])] = array_values[i];
local_dict[py::str(keys[i])] = py_object_values[i];
} else {
local_dict[py::str(keys[i])] = tensor_values[i];
}
}
MS_LOG(DEBUG) << "local_dict: " << local_dict;
return local_dict;
}
@ -249,6 +359,14 @@ bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
MS_LOG(DEBUG) << "Real output is py::bool_, py_res: " << py_res;
} else if (py::isinstance<py::str>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::str, py_res: " << py_res;
} else if (py::isinstance<py::tuple>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::tuple, py_res: " << py_res;
} else if (py::isinstance<py::list>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::list, py_res: " << py_res;
} else if (py::isinstance<py::dict>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::dict, py_res: " << py_res;
} else if (py::isinstance<py::set>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::set, py_res: " << py_res;
} else {
MS_LOG(EXCEPTION) << "The output is invalid, py_res: " << py_res;
}

View File

@ -52,6 +52,7 @@ class PyExecuteCpuKernelMod : public DeprecatedNativeCpuKernelMod {
private:
void AttachPyOutputData(const py::object &py_res);
py::object BuildLocalParameters(const std::vector<AddressPtr> &inputs);
py::object BuildLocalTupleParameters(const std::vector<AddressPtr> &inputs);
CNodePtr kernel_node_{nullptr};
std::vector<PyExecuteInputInfo> inputs_info_;

View File

@ -55,7 +55,7 @@ class PyExecuteInitializer {
MS_LOG(EXCEPTION) << "Value tuple should not be anyvalue.";
}
const auto &values = dyn_cast<ValueSequence>(values_tuple);
MS_LOG(ERROR) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
MS_LOG(DEBUG) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
<< ", values_tuple: " << values_tuple->ToString();
py::gil_scoped_acquire gil_acquire;
@ -90,13 +90,21 @@ class PyExecuteInitializer {
const auto &res_tensor = tensor::TensorPy::MakeTensorOfNumpy(py_res);
MS_LOG(DEBUG) << "res_tensor: " << res_tensor->ToString();
} else if (py::isinstance<py::float_>(py_res)) {
MS_LOG(ERROR) << "is py::float_, py_res: " << py_res;
MS_LOG(DEBUG) << "is py::float_, py_res: " << py_res;
} else if (py::isinstance<py::int_>(py_res)) {
MS_LOG(ERROR) << "is py::int_, py_res: " << py_res;
MS_LOG(DEBUG) << "is py::int_, py_res: " << py_res;
} else if (py::isinstance<py::bool_>(py_res)) {
MS_LOG(ERROR) << "is py::bool_, py_res: " << py_res;
MS_LOG(DEBUG) << "is py::bool_, py_res: " << py_res;
} else if (py::isinstance<py::str>(py_res)) {
MS_LOG(ERROR) << "is py::str, py_res: " << py_res;
MS_LOG(DEBUG) << "is py::str, py_res: " << py_res;
} else if (py::isinstance<py::tuple>(py_res)) {
MS_LOG(DEBUG) << "is py::tuple, py_res: " << py_res;
} else if (py::isinstance<py::list>(py_res)) {
MS_LOG(DEBUG) << "is py::list, py_res: " << py_res;
} else if (py::isinstance<py::dict>(py_res)) {
MS_LOG(DEBUG) << "is py::dict, py_res: " << py_res;
} else if (py::isinstance<py::set>(py_res)) {
MS_LOG(DEBUG) << "is py::set, py_res: " << py_res;
} else {
MS_LOG(EXCEPTION) << "py_res is invalid, py_res: " << py_res;
}

View File

@ -1350,9 +1350,9 @@ bool AbstractDictionary::operator==(const AbstractBase &other) const {
}
AbstractBasePtr AbstractDictionary::Clone() const {
std::vector<AbstractAttribute> kv;
std::vector<AbstractElementPair> kv;
(void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
[](const AbstractAttribute &item) {
[](const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.first);
MS_EXCEPTION_IF_NULL(item.second);
return std::make_pair(item.first->Clone(), item.second->Clone());
@ -1361,9 +1361,9 @@ AbstractBasePtr AbstractDictionary::Clone() const {
}
AbstractBasePtr AbstractDictionary::Broaden() const {
std::vector<AbstractAttribute> kv;
std::vector<AbstractElementPair> kv;
(void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
[](const AbstractAttribute &item) {
[](const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.second);
return std::make_pair(item.first, item.second->Broaden());
});
@ -1384,7 +1384,7 @@ std::string AbstractDictionary::ToString() const {
std::size_t AbstractDictionary::hash() const {
std::size_t hash_sum = std::accumulate(key_values_.cbegin(), key_values_.cend(), tid(),
[](std::size_t hash_sum, const AbstractAttribute &item) {
[](std::size_t hash_sum, const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.first);
MS_EXCEPTION_IF_NULL(item.second);
hash_sum = hash_combine(hash_sum, item.first->hash());

View File

@ -1025,8 +1025,8 @@ class MS_CORE_API AbstractDictionary final : public AbstractBase {
public:
/// \brief Constructor of AbstractDictionary.
///
/// \param[in] key_values The vector of AbstractAttribute.
explicit AbstractDictionary(const std::vector<AbstractAttribute> &key_values) : key_values_(key_values) {}
/// \param[in] key_values The vector of AbstractElementPair.
explicit AbstractDictionary(const std::vector<AbstractElementPair> &key_values) : key_values_(key_values) {}
/// \brief Destructor of AbstractDictionary.
~AbstractDictionary() override = default;
@ -1051,12 +1051,12 @@ class MS_CORE_API AbstractDictionary final : public AbstractBase {
/// \brief Get the key values.
///
/// \return A vector of AbstractAttribute.
const std::vector<AbstractAttribute> &elements() const { return key_values_; }
/// \return A vector of AbstractElementPair.
const std::vector<AbstractElementPair> &elements() const { return key_values_; }
protected:
ValuePtr RealBuildValue() const override;
std::vector<AbstractAttribute> key_values_;
std::vector<AbstractElementPair> key_values_;
};
using AbstractDictionaryPtr = std::shared_ptr<AbstractDictionary>;

View File

@ -196,8 +196,8 @@ bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spe
ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) {
std::vector<AbstractElementPair> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
return *key_value == *item.first->BuildValue();
});
return it != dict_elems.end();

View File

@ -69,7 +69,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
const auto &key = key_list[index];
CheckDictKey(key, op_name);
}
std::vector<AbstractAttribute> key_value;
std::vector<AbstractElementPair> key_value;
AbstractBasePtrList value_list = values->elements();
for (size_t index = 0; index < keys_size; index++) {
(void)key_value.emplace_back(key_list[index], value_list[index]);
@ -277,8 +277,8 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) {
std::vector<AbstractElementPair> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
return *key_value == *item.first->BuildValue();
});
if (it == dict_elems.end()) {
@ -302,8 +302,8 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) {
std::vector<AbstractElementPair> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
return *key_value == *item.first->BuildValue();
});
@ -325,10 +325,10 @@ AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitiveP
constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
std::vector<AbstractElementPair> dict_elems = dict->elements();
AbstractBasePtrList keys;
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys),
[](const AbstractAttribute &item) { return item.first; });
[](const AbstractElementPair &item) { return item.first; });
return std::make_shared<AbstractTuple>(keys);
}
@ -339,10 +339,10 @@ AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const Primitiv
constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
std::vector<AbstractElementPair> dict_elems = dict->elements();
AbstractBasePtrList values;
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(values),
[](const AbstractAttribute &item) { return item.second; });
[](const AbstractElementPair &item) { return item.second; });
return std::make_shared<AbstractTuple>(values);
}
@ -353,10 +353,10 @@ AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr
constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
std::vector<AbstractElementPair> dict_elems = dict->elements();
AbstractBasePtrList items;
(void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(items),
[](const AbstractAttribute &item) {
[](const AbstractElementPair &item) {
return std::make_shared<AbstractTuple>(AbstractBasePtrList{item.first, item.second});
});
return std::make_shared<AbstractList>(items);

View File

@ -231,7 +231,7 @@ using FuncGraphWeakPtr = std::weak_ptr<FuncGraph>;
namespace abstract {
class AbstractBase;
using AbstractBasePtr = std::shared_ptr<AbstractBase>;
using AbstractAttribute = std::pair<AbstractBasePtr, AbstractBasePtr>;
using AbstractElementPair = std::pair<AbstractBasePtr, AbstractBasePtr>;
class AnalysisContext;
using AnalysisContextPtr = std::shared_ptr<AnalysisContext>;
} // namespace abstract

View File

@ -28,30 +28,28 @@ MIND_API_OPERATOR_IMPL(PyExecute, BaseOperator);
BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
if (infer_handler_ == nullptr) {
MS_LOG(EXCEPTION) << "infer_handler_ should not be null.";
}
// TODO(zh_qh): Will call 'infer_handler_(input_args)' check the abstract shape and type later.
ShapeVector out_shape = {1};
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr PyExecuteInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
// TODO(zh_qh): Will call 'infer_handler_(input_args)' check the abstract shape and type later.
return kFloat64;
}
AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(DEBUG) << "item: " << item->ToString();
}
if (infer_handler_ == nullptr) {
MS_LOG(EXCEPTION) << "infer_handler_ should not be null.";
}
infer_handler_(input_args);
const auto &type = InferType(primitive, input_args);
const auto &shape = InferShape(primitive, input_args);
const auto &abstract = MakeAbstract(shape, type);

View File

@ -29,6 +29,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.sparse_tensor import RowTensor, RowTensorInner, SparseTensor, COOTensor, CSRTensor
from mindspore.common.mutable import mutable
from mindspore.common.jit_config import JitConfig
from mindspore.common._utils import update_and_return_dict
# symbols from dtype
__all__ = [
@ -66,4 +67,5 @@ __all__.extend([
"set_dump",
"ms_memory_recycle",
"mutable", "JitConfig",
"update_and_return_dict",
])

View File

@ -53,3 +53,8 @@ def split_to_slice_if_need(dtype, shape):
return slice_num
slice_num = math.ceil(data_size / emb_cache_size)
return slice_num
def update_and_return_dict(dic, key, val):
dic.__setitem__(key, val)
return dic

View File

@ -225,3 +225,9 @@ def bprop_scalar_not(x, out, dout):
def bprop_tensor_move(x, out, dout):
"""Backpropagator for primitive `TensorMove`."""
return (dout,)
@bprops.register("PyExecute")
def get_bprop_py_execute(x, y, z, out, dout):
"""Generate bprop for PyExecute"""
return x, y, z

View File

@ -17,6 +17,7 @@ import pytest
import numpy as np
import mindspore as ms
from mindspore.common.initializer import TruncatedNormal
ms.set_context(mode=ms.GRAPH_MODE)
@ -92,3 +93,217 @@ def test_fallback_np_asnumpy():
const_output = ConstNet()()
print(f'const_output: {const_output}')
np.testing.assert_almost_equal(output, const_output, 3)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dict_return_1():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_1():
x = {'a': 'a', 'b': 'b'}
y = x.get('a')
z = dict(y=y)
return z
out = dict_net_1()
print(f'out: {out}')
assert out == {'y': 'a'}
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dict_get_1():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_1():
x = {'a': 1, 'b': 2}
y = x.get('a')
y_tensor = ms.Tensor([y])
z = dict(a=y_tensor)
return z
out = dict_net_1()
print(f'out: {out}')
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dict_get_2():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_2():
x = {'a': 1, 'b': 2}
y = x.get('a')
y_tensor = ms.Tensor([y])
z = dict(a=y_tensor, b='hello', c='world')
return z
out = dict_net_2()
print(f'out: {out}')
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_dict_get_3():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
@ms.jit
def dict_net_3():
x = {'a': 1, 'b': 2}
y = x.get('a')
y_tensor = ms.Tensor([y])
z = dict(y=y_tensor, a='a', b='c')
return z
out = dict_net_3()
print(f'out: {out}')
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return ms.nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return ms.nn.Dense(input_channels, out_channels, weight, bias)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net_dict_1():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
class DictLeNetNet(ms.nn.Cell):
def __init__(self, num_class=10):
super(DictLeNetNet, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = ms.nn.ReLU()
self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = ms.nn.Flatten()
def construct(self, x):
x = self.conv1(x)
conv1_x = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
conv2_x = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
fc_x = x
outputs = dict(conv1=conv1_x, conv2=conv2_x, fc=fc_x)
return outputs
net = DictLeNetNet()
x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32))
outputs = net(x)
print(f'outputs: {outputs}')
assert outputs['conv1'].shape == (64, 6, 28, 28)
assert outputs['conv2'].shape == (64, 16, 10, 10)
assert outputs['fc'].shape == (64, 10)
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_net_dict_2():
"""
Feature: Return dict.
Description: Support dict return.
Expectation: No exception.
"""
class DictLeNetNet(ms.nn.Cell):
def __init__(self, num_class=10):
super(DictLeNetNet, self).__init__()
self.conv1 = conv(1, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, 10)
self.relu = ms.nn.ReLU()
self.max_pool2d = ms.nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = ms.nn.Flatten()
def construct(self, x):
outputs = dict()
x = self.conv1(x)
outputs['conv1'] = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
outputs['conv2'] = x
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
outputs['fc'] = x
return outputs
net = DictLeNetNet()
x = ms.Tensor(np.random.rand(64, 1, 32, 32).astype(np.float32))
outputs = net(x)
print(f'outputs: {outputs}')
assert outputs['conv1'].shape == (64, 6, 28, 28)
assert outputs['conv2'].shape == (64, 16, 10, 10)
assert outputs['fc'].shape == (64, 10)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""test getting gradient of mutable input"""
import os
import numpy as np
import pytest
import mindspore.nn as nn
@ -160,10 +161,12 @@ def test_grad_mutable_dict_tensor():
gradient_function = self.grad_op(self.net)
return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE)
t = mutable({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
@ -350,11 +353,13 @@ def test_grad_mutable_tuple_dict_tensor():
gradient_function = self.grad_op(self.net)
return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE)
t = mutable(({'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)))
output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
@ -398,11 +403,13 @@ def test_grad_mutable_dict_tuple_tensor():
gradient_function = self.grad_op(self.net)
return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE)
t = mutable({'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
@ -446,11 +453,13 @@ def test_grad_mutable_list_dict_tensor():
gradient_function = self.grad_op(self.net)
return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE)
t = mutable([{'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)},
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
@ -494,11 +503,13 @@ def test_grad_mutable_dict_list_tensor():
gradient_function = self.grad_op(self.net)
return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE)
t = mutable({'a': [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)],
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
@ -699,11 +710,13 @@ def test_grad_mutable_unused_dict_tensor():
gradient_function = self.grad_op(self.net)
return gradient_function(z)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE)
t = mutable({'x1': Tensor([[4.0, 6.0, 6.0], [4.0, 6.0, 6.0]], dtype=mstype.float32),
'x2': Tensor([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], dtype=mstype.float32),
'x3': Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [np.array([[3., 3., 3.],
[3., 3., 3.]]).astype(np.float32),
@ -746,10 +759,12 @@ def test_grad_mutable_single_element_dict_tensor():
gradient_function = self.grad_op(self.net)
return gradient_function(x, t)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE)
x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
y = mutable({'a': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)})
output = GradNetWrtX(Net())(x, y)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""test the feature of mutable in graph"""
import os
import numpy as np
import pytest
import mindspore.nn as nn
@ -456,8 +457,10 @@ def test_grad_const_dict_tensor_to_mutable():
gradient_function = self.grad_op(self.net)
return gradient_function(self.x)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
grad_net = GradNetWrtX(Net())
output = grad_net()
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
@ -502,10 +505,12 @@ def test_grad_const_dict_tensor_arg_to_mutable():
gradient_function = self.grad_op(self.net)
return gradient_function(mutable(x))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
x = {'a': Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}
grad_net = GradNetWrtX(Net())
output = grad_net(x)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
@ -563,8 +568,10 @@ def test_grad_const_dict_and_tuple_tensor_to_mutable():
gradient_function = self.grad_op(self.net)
return gradient_function(self.x)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
grad_net = GradNetWrtX(Net())
output = grad_net()
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [(np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
@ -574,8 +581,10 @@ def test_grad_const_dict_and_tuple_tensor_to_mutable():
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
grad_net = GradNetWrtX1(Net())
output = grad_net()
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
assert compare(output, expect)
@ -612,11 +621,13 @@ def test_grad_const_dict_and_tuple_tensor_arg_to_mutable():
gradient_function = self.grad_op(self.net)
return gradient_function(mutable(x))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
x = {'a': (Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.5, 0.6, 4.0], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
'b': Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)}
grad_net = GradNetWrtX(Net())
output = grad_net(x)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
assert isinstance(output, tuple)
expect = [(np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),

View File

@ -14,6 +14,7 @@
# ============================================================================
"""st for scipy.optimize."""
import os
import pytest
import numpy as onp
import scipy as osp
@ -210,6 +211,7 @@ def test_bfgs_graph(dtype, func_x0):
Description: test cases for bfgs in GRAPH mode
Expectation: the result match scipy
"""
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE)
func, x0 = func_x0
x0 = x0.astype(dtype)
@ -218,6 +220,7 @@ def test_bfgs_graph(dtype, func_x0):
options=dict(maxiter=None, gtol=1e-6))
scipy_res = osp.optimize.minimize(func(onp), x0, method='BFGS')
match_array(ms_res.x.asnumpy(), scipy_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
def _scalar_func_1(np):
@ -349,6 +352,7 @@ def test_line_search_graph(maxiter, func, x, p):
Description: test cases for n-d function in GRAPH mode
Expectation: the result match scipy
"""
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE)
A = [[1.76405235, 0.40015721, 0.97873798, 2.2408932, 1.86755799],
[-0.97727788, 0.95008842, -0.15135721, -0.10321885, 0.4105985],
@ -356,8 +360,8 @@ def test_line_search_graph(maxiter, func, x, p):
[0.33367433, 1.49407907, -0.20515826, 0.3130677, -0.85409574],
[-2.55298982, 0.6536186, 0.8644362, -0.74216502, 2.26975462]]
osp_x, osp_p, osp_A = onp.array(x), onp.array(p), onp.array(A)
osp_f, osp_fp = func(onp, osp_A)
osp_x, osp_p, osp_a = onp.array(x), onp.array(p), onp.array(A)
osp_f, osp_fp = func(onp, osp_a)
osp_res = osp_line_search(osp_f, osp_fp, osp_x, osp_p, maxiter=maxiter)
msp_x, msp_p, msp_A = mnp.array(x), mnp.array(p), mnp.array(A)
@ -366,6 +370,7 @@ def test_line_search_graph(maxiter, func, x, p):
match_array(msp_res.a_k, osp_res[0], error=5)
match_array(msp_res.f_k, osp_res[3], error=5)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0
@ -380,6 +385,7 @@ def test_lbfgs1(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs
"""
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0
x0 = x0.astype(dtype)
x0_tensor = Tensor(x0)
@ -388,6 +394,7 @@ def test_lbfgs1(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0
@ -402,6 +409,7 @@ def test_lbfgs2(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs
"""
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0
x0 = x0.astype(dtype)
x0_tensor = Tensor(x0)
@ -410,6 +418,7 @@ def test_lbfgs2(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0
@ -424,6 +433,7 @@ def test_lbfgs3(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs
"""
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0
x0 = x0.astype(dtype)
x0_tensor = Tensor(x0)
@ -432,6 +442,7 @@ def test_lbfgs3(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0
@ -446,6 +457,7 @@ def test_lbfgs4(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs
"""
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0
x0 = x0.astype(dtype)
x0_tensor = Tensor(x0)
@ -454,6 +466,7 @@ def test_lbfgs4(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0
@ -468,6 +481,7 @@ def test_lbfgs5(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs
"""
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0
x0 = x0.astype(dtype)
x0_tensor = Tensor(x0)
@ -476,6 +490,7 @@ def test_lbfgs5(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0
@ -490,6 +505,7 @@ def test_lbfgs6(dtype, func_x0):
Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs
"""
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
func, x0 = func_x0
x0 = x0.astype(dtype)
x0_tensor = Tensor(x0)
@ -498,6 +514,7 @@ def test_lbfgs6(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level0
@ -511,6 +528,7 @@ def test_lbfgs_fixes4594(dtype):
Description: test cases for lbfgs in PYNATIVE mode
Expectation: the result match bfgs
"""
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
n = 2
a = Tensor(onp.eye(n, dtype=dtype)) * 1e4
@ -520,6 +538,7 @@ def test_lbfgs_fixes4594(dtype):
results = msp.optimize.minimize(func, Tensor(onp.ones(n, dtype=dtype)), method='LBFGS',
options=dict(maxiter=None, gtol=1e-6)).x
onp.testing.assert_allclose(results.asnumpy(), onp.zeros(n, dtype=dtype), rtol=1e-6, atol=1e-6)
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
@pytest.mark.level1
@ -534,6 +553,7 @@ def test_lbfgs_graph(dtype, func_x0):
Description: test cases for lbfgs in GRAPH mode
Expectation: the result match bfgs
"""
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
context.set_context(mode=context.GRAPH_MODE)
func, x0 = func_x0
x0 = x0.astype(dtype)
@ -543,3 +563,4 @@ def test_lbfgs_graph(dtype, func_x0):
ma_res = msp.optimize.minimize(func(mnp), x0_tensor, method='BFGS',
options=dict(maxiter=None, gtol=1e-6))
match_array(ms_res.x.asnumpy(), ma_res.x, error=5, err_msg=str(ms_res))
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'

View File

@ -42,7 +42,7 @@ using AbstractTensor = abstract::AbstractTensor;
using AbstractTensorPtr = abstract::AbstractTensorPtr;
using AbstractNone = abstract::AbstractNone;
using AbstractAttribute = abstract::AbstractAttribute;
using AbstractAttribute = abstract::AbstractElementPair;
using AnalysisEngine = abstract::AnalysisEngine;
using AnalysisEnginePtr = abstract::AnalysisEnginePtr;

View File

@ -1002,7 +1002,7 @@ TEST_F(TestPrim, test_DictGetItem2) {
AbstractBasePtr arr_x = ArrayOfTensor(UTPrimUtils::kF64, {3, 4, 5});
AbstractBasePtr arr_y = ArrayOfTensor(UTPrimUtils::kF64, {1, 4, 5});
AbstractBasePtr arr_z = ArrayOfTensor(UTPrimUtils::kF64, {3, 1, 5});
std::vector<AbstractAttribute> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
std::vector<AbstractElementPair> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
AbstractDictionaryPtr array_dict = std::make_shared<AbstractDictionary>(array_map);
AbstractBasePtr key = abstract::FromValue("x");
AbstractBasePtrList args_spec_list = {array_dict, key};

View File

@ -155,6 +155,7 @@ def test_dict_set_item():
_ = net(x)
@pytest.mark.skip(reason="Do not support dict value for dict set item yet.")
def test_dict_set_item_2():
"""
Description: test dict in dict set item.
@ -184,6 +185,7 @@ def test_dict_set_item_2():
assert second[1] == 1
@pytest.mark.skip(reason="Do not support dict value for dict set item yet.")
def test_dict_set_item_3():
"""
Description: test dict in dict set item.
@ -207,7 +209,7 @@ def test_dict_set_item_3():
assert first[0][1] == 3
# if the dictionary item does not exist, create a new one
# If the dictionary item does not exist, create a new one
def test_dict_set_item_create_new():
class DictSetNet(Cell):
def __init__(self):