forked from mindspore-Ecosystem/mindspore
[JIT Fallback] Support tensor.asnumpy() and returning a dictionary from construct() in GraphMode.
This commit is contained in:
parent
a1d5aebc72
commit
00a000c5ce
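Before the diff, a minimal usage sketch of the behaviour this commit targets under graph mode. The cell below and the MS_DEV_ENABLE_FALLBACK_RUNTIME switch are assumptions drawn from the commit message and the environment-variable checks added in the diff, not a documented API:

# Hypothetical usage sketch, not part of this change; assumes a MindSpore build with
# this commit and MS_DEV_ENABLE_FALLBACK_RUNTIME=1 exported before the process starts.
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

ms.set_context(mode=ms.GRAPH_MODE)

class Net(nn.Cell):
    def construct(self, x):
        y = x + 1
        # tensor.asnumpy() inside construct() is handled by the JIT Fallback runtime.
        mean = Tensor(y.asnumpy().mean())
        # Returning a dict from construct() is rewritten into a PyExecute node.
        return {"y": y, "mean": mean}

out = Net()(Tensor(np.ones((2, 2), np.float32)))
print(out["mean"])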
@@ -127,7 +127,13 @@ tensor::TensorPtr GetDependValueTensor(const AnfNodePtr &node, size_t i,
  auto input_device_address = reinterpret_cast<std::vector<device::DeviceAddress *> *>(args);
  if (i >= input_device_address->size() || input_device_address->at(i) == nullptr) {
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(EXCEPTION) << "There is no valid address for " << i << " input of " << node->fullname_with_scope();
    if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
      MS_LOG(INFO) << "There is no valid address for " << i << " input of " << node->DebugString() << ", "
                   << node->fullname_with_scope();
      return out_tensor;
    }
    MS_LOG(EXCEPTION) << "There is no valid address for " << i << " input of " << node->DebugString() << ", "
                      << node->fullname_with_scope();
  }

  out_tensor->data_sync_directly(input_device_address->at(i));
@@ -846,10 +846,10 @@ AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
  return rectify_abs_list;
}

AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
                                                    const AbstractBasePtrList &input_abstract) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
  MS_EXCEPTION_IF_NULL(prim);
  auto dynamic_inputs_list = prim->GetAttr(kAttrDynInputSizes);
  if (dynamic_inputs_list == nullptr) {
    return input_abstract;
  }

@@ -871,6 +871,10 @@ AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitiv
  AbstractBasePtrList dynamic_inputs_abs;
  for (auto index = item; index > 0; --index) {
    if (input_index >= input_abstract.size()) {
      // Not to check for PyExecute.
      if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
        continue;
      }
      MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range in input abstract "
                        << input_abstract.size();
    }
@ -39,11 +39,11 @@ namespace mindspore {
|
|||
namespace prim {
|
||||
constexpr auto kStepDefault = 1;
|
||||
|
||||
using mindspore::abstract::AbstractAttribute;
|
||||
using mindspore::abstract::AbstractBase;
|
||||
using mindspore::abstract::AbstractBasePtr;
|
||||
using mindspore::abstract::AbstractDictionary;
|
||||
using mindspore::abstract::AbstractDictionaryPtr;
|
||||
using mindspore::abstract::AbstractElementPair;
|
||||
using mindspore::abstract::AbstractEllipsis;
|
||||
using mindspore::abstract::AbstractEllipsisPtr;
|
||||
using mindspore::abstract::AbstractFunction;
|
||||
|
|
|
@ -49,12 +49,12 @@ FuncGraphPtr DictHasKey::GenerateFuncGraph(const abstract::AbstractBasePtrList &
|
|||
ValuePtr key_value = args_list[1]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(dict);
|
||||
MS_EXCEPTION_IF_NULL(key_value);
|
||||
auto dict_elems = dict->elements();
|
||||
auto elems = dict->elements();
|
||||
bool has_key = false;
|
||||
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const abstract::AbstractAttribute &item) {
|
||||
auto it = std::find_if(elems.cbegin(), elems.cend(), [&key_value](const abstract::AbstractElementPair &item) {
|
||||
return *key_value == *item.first->BuildValue();
|
||||
});
|
||||
if (it != dict_elems.cend()) {
|
||||
if (it != elems.cend()) {
|
||||
has_key = true;
|
||||
}
|
||||
|
||||
|
@ -153,7 +153,7 @@ abstract::AbstractBasePtrList DictFromKeys::ParseIterableObject(const abstract::
|
|||
AbstractBasePtrList keys;
|
||||
auto &dict_elems = dict_keys->elements();
|
||||
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys),
|
||||
[](const abstract::AbstractAttribute &item) { return item.first; });
|
||||
[](const abstract::AbstractElementPair &item) { return item.first; });
|
||||
return keys;
|
||||
}
|
||||
if (key_type->IsSameTypeId(String::kTypeId)) {
|
||||
|
|
|
@ -28,10 +28,10 @@
|
|||
namespace mindspore {
|
||||
// namespace to support composite operators definition
|
||||
namespace prim {
|
||||
using mindspore::abstract::AbstractAttribute;
|
||||
using mindspore::abstract::AbstractBase;
|
||||
using mindspore::abstract::AbstractDictionary;
|
||||
using mindspore::abstract::AbstractDictionaryPtr;
|
||||
using mindspore::abstract::AbstractElementPair;
|
||||
using mindspore::abstract::AbstractFunction;
|
||||
using mindspore::abstract::AbstractKeywordArg;
|
||||
using mindspore::abstract::AbstractList;
|
||||
|
@ -78,7 +78,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_abs_l
|
|||
auto dict_elems = arg_dict->elements();
|
||||
(void)std::transform(
|
||||
dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(elems),
|
||||
[res_graph, para_dict](const AbstractAttribute &item) {
|
||||
[res_graph, para_dict](const AbstractElementPair &item) {
|
||||
// Dict_elems's first element represents parameter names, which should be string type.
|
||||
auto key_value = GetValue<std::string>(item.first->BuildValue());
|
||||
auto dict_get_item =
|
||||
|
|
|
@ -38,11 +38,11 @@
|
|||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
using mindspore::abstract::AbstractAttribute;
|
||||
using mindspore::abstract::AbstractBase;
|
||||
using mindspore::abstract::AbstractBasePtr;
|
||||
using mindspore::abstract::AbstractDictionary;
|
||||
using mindspore::abstract::AbstractDictionaryPtr;
|
||||
using mindspore::abstract::AbstractElementPair;
|
||||
using mindspore::abstract::AbstractList;
|
||||
using mindspore::abstract::AbstractListPtr;
|
||||
using mindspore::abstract::AbstractRowTensor;
|
||||
|
@ -164,7 +164,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
public:
|
||||
using ThisClass = SimplifyDataStructuresRewriter;
|
||||
SimplifyDataStructuresRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
|
||||
: BaseRewriter(root_graph, manager) {}
|
||||
: BaseRewriter(root_graph, manager), is_dict_output_{IsDictOutput()} {}
|
||||
~SimplifyDataStructuresRewriter() override = default;
|
||||
|
||||
protected:
|
||||
|
@ -176,7 +176,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
return str->value();
|
||||
}
|
||||
|
||||
static int64_t GetAttrIndex(const std::vector<AbstractAttribute> &attrs, const AnfNodePtr &name) {
|
||||
static int64_t GetElementIndex(const std::vector<AbstractElementPair> &attrs, const AnfNodePtr &name) {
|
||||
auto n_attrs = attrs.size();
|
||||
auto name_abstract = GetAbstract<AbstractBase>(name);
|
||||
MS_EXCEPTION_IF_NULL(name_abstract);
|
||||
|
@ -191,15 +191,15 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
}
|
||||
|
||||
static CNodePtr NewTupleGetCNode(const AnfNodePtr &cnode, const AnfNodePtr &data_node,
|
||||
const std::vector<AbstractAttribute> &attributes, const AnfNodePtr &name_node) {
|
||||
int64_t index = GetAttrIndex(attributes, name_node);
|
||||
const std::vector<AbstractElementPair> &elements, const AnfNodePtr &name_node) {
|
||||
int64_t index = GetElementIndex(elements, name_node);
|
||||
auto index_node = NewValueNode(index);
|
||||
auto prim_node = NewValueNode(prim::kPrimTupleGetItem);
|
||||
return cnode->func_graph()->NewCNode({prim_node, data_node, index_node});
|
||||
}
|
||||
|
||||
// From:
|
||||
// DictGetItem(data:AbstractDictionary, cons:AbstractBase)
|
||||
// DictGetItem(data:AbstractDictionary, key:AbstractBase)
|
||||
// To:
|
||||
// TupleGetItem(data, index:Int64Imm)
|
||||
AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
|
||||
|
@ -211,27 +211,98 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
CheckInputsSize(node, expect_inputs_size);
|
||||
|
||||
constexpr size_t data_index = 1;
|
||||
constexpr size_t attr_index = 2;
|
||||
constexpr size_t key_index = 2;
|
||||
const auto &inputs = node->inputs();
|
||||
auto &data = inputs[data_index];
|
||||
auto &attr = inputs[attr_index];
|
||||
auto &key = inputs[key_index];
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
MS_EXCEPTION_IF_NULL(attr);
|
||||
MS_EXCEPTION_IF_NULL(key);
|
||||
|
||||
auto abs_dict = GetAbstract<AbstractDictionary>(data);
|
||||
if (abs_dict == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return NewTupleGetCNode(node, data, abs_dict->elements(), attr);
|
||||
return NewTupleGetCNode(node, data, abs_dict->elements(), key);
|
||||
}
|
||||
|
||||
// DictGetItem --> PyExecute()
|
||||
AnfNodePtr RebuidDictGetItem(const CNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Inputs should be [dict_setitem, dict, item]
|
||||
const size_t expect_inputs_size = 3;
|
||||
CheckInputsSize(node, expect_inputs_size);
|
||||
|
||||
const size_t data_index = 1;
|
||||
const size_t item_key_index = 2;
|
||||
const auto &inputs = node->inputs();
|
||||
auto &data = inputs[data_index];
|
||||
auto &key = inputs[item_key_index];
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
MS_EXCEPTION_IF_NULL(key);
|
||||
|
||||
auto abs_dict = GetAbstract<AbstractDictionary>(data);
|
||||
if (abs_dict == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto func_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
// Script
|
||||
constexpr auto internal_dict_self_str = "__internal_dict_self__";
|
||||
constexpr auto internal_dict_key_str = "__internal_dict_key__";
|
||||
std::stringstream script_buffer;
|
||||
script_buffer << internal_dict_self_str << "[" << internal_dict_key_str << "]";
|
||||
const std::string &script = script_buffer.str();
|
||||
const auto script_str = std::make_shared<StringImm>(script);
|
||||
|
||||
// Pack local parameters keys.
|
||||
const auto script_dict_self_name = std::make_shared<StringImm>(internal_dict_self_str);
|
||||
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_key_str);
|
||||
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
|
||||
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
|
||||
const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
|
||||
|
||||
// Pack the local parameters values, not support list, tuple, or dict.
|
||||
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)key_value_list.emplace_back(data);
|
||||
(void)key_value_list.emplace_back(key);
|
||||
const auto key_value_tuple = func_graph->NewCNode(key_value_list);
|
||||
|
||||
// Build the new dict node.
|
||||
const auto dict_getitem_node = func_graph->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
|
||||
int64_t index = GetElementIndex(abs_dict->elements(), key);
|
||||
const auto &val = abs_dict->elements()[index].second;
|
||||
const auto &tensor_val = dyn_cast<abstract::AbstractTensor>(val);
|
||||
if (tensor_val != nullptr) {
|
||||
const auto &tensor_type = tensor_val->element()->BuildType();
|
||||
dict_getitem_node->set_user_data<Type>("__py_execute_tensor_type__", tensor_type);
|
||||
const auto &tensor_shape = dyn_cast<abstract::Shape>(tensor_val->BuildShape());
|
||||
MS_EXCEPTION_IF_NULL(tensor_shape);
|
||||
dict_getitem_node->set_user_data<abstract::Shape>("__py_execute_tensor_shape__", tensor_shape);
|
||||
MS_LOG(DEBUG) << "key: " << key->abstract()->BuildValue()->ToString() << ", type: " << tensor_type->ToString()
|
||||
<< ", shape: " << tensor_shape->ToString() << ", val: " << tensor_val->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Made dict getitem node: " << dict_getitem_node->DebugString();
|
||||
return dict_getitem_node;
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertDictGetItem(const CNodePtr &node) {
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
|
||||
if (support_fallback_runtime && is_dict_output_) {
|
||||
return RebuidDictGetItem(node);
|
||||
}
|
||||
return ConvertDictGetItemToTupleGetItem(node);
|
||||
}
|
||||
|
||||
// From:
|
||||
// DictSetItem(data:AbstractDictionary, cons:AbstractBase, value)
|
||||
// DictSetItem(data:AbstractDictionary, key:AbstractBase, value)
|
||||
// To:
|
||||
// TupleSetItem(data, index:Int64Imm, value)
|
||||
// Or:
|
||||
// tuple_add(data, value)
|
||||
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
|
||||
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
|
||||
|
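As an aside on the PyExecute convention used by RebuidDictGetItem above: the node carries a script string plus parallel tuples of local-variable names and values. A small Python sketch of what such a node evaluates at runtime (the eval-based model is only an illustration of the convention, not the actual PyExecute kernel):

# Illustration of the script/keys/values convention built above.
script = "__internal_dict_self__[__internal_dict_key__]"
local_keys = ("__internal_dict_self__", "__internal_dict_key__")
local_values = ({"a": 1, "b": 2}, "b")

# A PyExecute-style evaluation binds the names to the values and runs the script.
result = eval(script, {}, dict(zip(local_keys, local_values)))
assert result == 2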
@ -244,16 +315,18 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
const size_t item_value_index = 3;
|
||||
const auto &inputs = node->inputs();
|
||||
auto &data = inputs[data_index];
|
||||
auto &cons = inputs[cons_index];
|
||||
auto &key = inputs[cons_index];
|
||||
auto &item_value = inputs[item_value_index];
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
MS_EXCEPTION_IF_NULL(cons);
|
||||
MS_EXCEPTION_IF_NULL(key);
|
||||
|
||||
auto abs_dict = GetAbstract<AbstractDictionary>(data);
|
||||
if (abs_dict == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
int64_t index = GetAttrIndex(abs_dict->elements(), cons);
|
||||
int64_t index = GetElementIndex(abs_dict->elements(), key);
|
||||
auto func_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
if (index >= static_cast<int64_t>(abs_dict->elements().size())) {
|
||||
// For dictionary set, if the key does not exist, we should create a new item.
|
||||
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
|
@ -269,11 +342,86 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, index_node, item_value});
|
||||
}
|
||||
|
||||
bool IsDictOutput() const {
|
||||
const AnfNodePtr &output = root_graph_->output();
|
||||
auto abs_dict = GetAbstract<AbstractDictionary>(output);
|
||||
if (abs_dict != nullptr) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// DictSetItem --> PyExecute()
|
||||
AnfNodePtr RebuidDictSetItem(const CNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Inputs should be [dict_setitem, dict, item, value]
|
||||
const size_t expect_inputs_size = 4;
|
||||
CheckInputsSize(node, expect_inputs_size);
|
||||
|
||||
const size_t data_index = 1;
|
||||
const size_t item_key_index = 2;
|
||||
const size_t item_value_index = 3;
|
||||
const auto &inputs = node->inputs();
|
||||
auto &data = inputs[data_index];
|
||||
auto &key = inputs[item_key_index];
|
||||
auto &item_value = inputs[item_value_index];
|
||||
MS_EXCEPTION_IF_NULL(data);
|
||||
MS_EXCEPTION_IF_NULL(key);
|
||||
|
||||
auto abs_dict = GetAbstract<AbstractDictionary>(data);
|
||||
if (abs_dict == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto func_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
// Script
|
||||
constexpr auto internal_dict_self_str = "__internal_dict_self__";
|
||||
constexpr auto internal_dict_key_str = "__internal_dict_key__";
|
||||
constexpr auto internal_dict_value_str = "__internal_dict_value__";
|
||||
std::stringstream script_buffer;
|
||||
script_buffer << "__import__('mindspore').update_and_return_dict(" << internal_dict_self_str << ", "
|
||||
<< internal_dict_key_str << ", " << internal_dict_value_str << ")";
|
||||
const std::string &script = script_buffer.str();
|
||||
const auto script_str = std::make_shared<StringImm>(script);
|
||||
|
||||
// Pack local parameters keys.
|
||||
const auto script_dict_self_name = std::make_shared<StringImm>(internal_dict_self_str);
|
||||
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_key_str);
|
||||
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_value_str);
|
||||
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
|
||||
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
|
||||
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
|
||||
const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
|
||||
|
||||
// Pack the local parameters values, not support list, tuple, or dict.
|
||||
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)key_value_list.emplace_back(data);
|
||||
(void)key_value_list.emplace_back(key);
|
||||
(void)key_value_list.emplace_back(item_value);
|
||||
const auto key_value_tuple = func_graph->NewCNode(key_value_list);
|
||||
|
||||
// Build the new dict node.
|
||||
const auto dict_setitem_node = func_graph->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
|
||||
MS_LOG(DEBUG) << "Made dict setitem node: " << dict_setitem_node->DebugString();
|
||||
return dict_setitem_node;
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertDictSetItem(const CNodePtr &node) {
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
|
||||
if (support_fallback_runtime && is_dict_output_) {
|
||||
return RebuidDictSetItem(node);
|
||||
}
|
||||
return ConvertDictSetItemToTupleSetItem(node);
|
||||
}
|
||||
|
||||
// From:
|
||||
// MakeDict(name, input)
|
||||
// To:
|
||||
// input
|
||||
AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
|
||||
AnfNodePtr EraseMakeDictNode(const CNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
constexpr size_t expect_inputs_size = 3;
|
||||
constexpr size_t input_index = 2;
|
||||
|
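The DictSetItem rewrite above generates a script that calls __import__('mindspore').update_and_return_dict(...). That helper's definition is not part of this diff; the sketch below only shows the semantics the generated script appears to assume (update in place, then return the dict):

# Assumed semantics of mindspore.update_and_return_dict as used by the generated script.
def update_and_return_dict(d, key, value):
    d[key] = value
    return d

d = {"a": 1}
assert update_and_return_dict(d, "b", 2) == {"a": 1, "b": 2}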
@ -281,6 +429,62 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
return node->input(input_index);
|
||||
}
|
||||
|
||||
// MakeDict(keys, values) --> PyExecute('dict(zip(keys, values))', ...)
|
||||
AnfNodePtr RebuildMakeDictNode(const CNodePtr &node) const {
|
||||
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
|
||||
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
|
||||
constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
|
||||
constexpr auto internal_dict_zip_values_str = "__internal_dict_zip_values__";
|
||||
const auto &fg = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
|
||||
// Local parameters values.
|
||||
// Pack the key tuple.
|
||||
constexpr size_t values_input_index = 2;
|
||||
const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
|
||||
const auto make_key_tuple_node =
|
||||
fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_key_tuple_str),
|
||||
NewValueNode(script_key_tuple_str), node->input(values_input_index)});
|
||||
// Pack the value tuple.
|
||||
const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
|
||||
const auto make_value_tuple_node =
|
||||
fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_value_tuple_str),
|
||||
NewValueNode(script_value_tuple_str), node->input(1)});
|
||||
// Pack the local parameters values
|
||||
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)key_value_list.emplace_back(make_key_tuple_node);
|
||||
(void)key_value_list.emplace_back(make_value_tuple_node);
|
||||
const auto key_value_tuple = fg->NewCNode(key_value_list);
|
||||
|
||||
// Pack local parameters keys.
|
||||
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_zip_keys_str);
|
||||
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_zip_values_str);
|
||||
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
|
||||
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
|
||||
const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
|
||||
|
||||
// Script
|
||||
std::stringstream script_buffer;
|
||||
script_buffer << "dict(zip(" << internal_dict_zip_keys_str << "," << internal_dict_zip_values_str << "),)";
|
||||
const std::string &script = script_buffer.str();
|
||||
const auto script_str = std::make_shared<StringImm>(script);
|
||||
|
||||
// Build the new dict node.
|
||||
const auto make_dict_node = fg->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
|
||||
MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
|
||||
return make_dict_node;
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertMakeDict(const CNodePtr &node) {
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
|
||||
if (support_fallback_runtime && is_dict_output_) {
|
||||
return RebuildMakeDictNode(node);
|
||||
}
|
||||
return EraseMakeDictNode(node);
|
||||
}
|
||||
|
||||
// From:
|
||||
// DictGetValues(dict:AbstractDictionary)
|
||||
// To:
|
||||
|
@ -350,22 +554,79 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
}
|
||||
|
||||
// dict(k0:v0, k1:v1, ...) --> tuple(v0, v1, ...)
|
||||
ValueTuplePtr DictToTuple(const ValueDictionaryPtr &dict) const {
|
||||
const auto &elements = dict->value();
|
||||
std::vector<ValuePtr> values;
|
||||
values.reserve(elements.size());
|
||||
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(values),
|
||||
[](const auto &element) { return element.second; });
|
||||
return std::make_shared<ValueTuple>(values);
|
||||
AnfNodePtr DictToTuple(const ValueDictionaryPtr &dict) const {
|
||||
const auto &keys_values = dict->value();
|
||||
std::vector<ValuePtr> value_list;
|
||||
value_list.reserve(keys_values.size());
|
||||
(void)std::transform(keys_values.begin(), keys_values.end(), std::back_inserter(value_list),
|
||||
[](const auto &value) { return value.second; });
|
||||
return NewValueNode(std::make_shared<ValueTuple>(value_list));
|
||||
}
|
||||
|
||||
// dict(k0:v0, k1:v1, ...) --> PyExecute('dict(zip(keys, values))', ...)
|
||||
AnfNodePtr RebuildValueDict(const ValueDictionaryPtr &dict) const {
|
||||
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
|
||||
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
|
||||
constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
|
||||
constexpr auto internal_dict_zip_values_str = "__internal_dict_zip_values__";
|
||||
|
||||
const auto &keys_values = dict->value();
|
||||
std::vector<ValuePtr> key_list;
|
||||
key_list.reserve(keys_values.size());
|
||||
std::vector<ValuePtr> value_list;
|
||||
value_list.reserve(keys_values.size());
|
||||
for (const auto &key_value : keys_values) {
|
||||
(void)key_list.emplace_back(key_value.first);
|
||||
(void)value_list.emplace_back(key_value.second);
|
||||
}
|
||||
|
||||
// Local parameters values.
|
||||
// Pack the key tuple.
|
||||
const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
|
||||
const auto key_tuple = std::make_shared<ValueTuple>(key_list);
|
||||
const auto make_key_tuple_node =
|
||||
root_graph_->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_key_tuple_str),
|
||||
NewValueNode(script_key_tuple_str), NewValueNode(key_tuple)});
|
||||
// Pack the value tuple.
|
||||
const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
|
||||
const auto value_tuple = std::make_shared<ValueTuple>(value_list);
|
||||
const auto make_value_tuple_node =
|
||||
root_graph_->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_value_tuple_str),
|
||||
NewValueNode(script_value_tuple_str), NewValueNode(value_tuple)});
|
||||
// Pack the local parameters values
|
||||
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)key_value_list.emplace_back(make_key_tuple_node);
|
||||
(void)key_value_list.emplace_back(make_value_tuple_node);
|
||||
const auto key_value_tuple = root_graph_->NewCNode(key_value_list);
|
||||
|
||||
// Pack local parameters keys.
|
||||
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_zip_keys_str);
|
||||
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_zip_values_str);
|
||||
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
|
||||
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
|
||||
const auto key_value_name_tuple = root_graph_->NewCNode(key_value_names_list);
|
||||
|
||||
// Script
|
||||
std::stringstream script_buffer;
|
||||
script_buffer << "dict(zip(" << internal_dict_zip_keys_str << "," << internal_dict_zip_values_str << "),)";
|
||||
const std::string &script = script_buffer.str();
|
||||
const auto script_str = std::make_shared<StringImm>(script);
|
||||
|
||||
// Build the new dict node.
|
||||
const auto make_dict_node = root_graph_->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
|
||||
MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
|
||||
return make_dict_node;
|
||||
}
|
||||
|
||||
using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &);
|
||||
using ConverterMap = mindspore::HashMap<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>;
|
||||
static inline const ConverterMap converters_{
|
||||
{prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItemToTupleGetItem},
|
||||
{prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItemToTupleSetItem},
|
||||
{prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItem},
|
||||
{prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItem},
|
||||
{prim::kPrimDictGetValues, &ThisClass::EraseDictGetValues},
|
||||
{prim::kPrimMakeDict, &ThisClass::EraseMakeDictNode},
|
||||
{prim::kPrimMakeDict, &ThisClass::ConvertMakeDict},
|
||||
{prim::kPrimMakeKeywordArg, &ThisClass::EraseMakeKeywordArgNode},
|
||||
{prim::kPrimExtractKeywordArg, &ThisClass::EraseExtractKeywordArg},
|
||||
{prim::kPrimDictItems, &ThisClass::EraseDictItems},
|
||||
|
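Both RebuildMakeDictNode and RebuildValueDict above emit a PyExecute node whose script rebuilds the dictionary from two packed tuples. A short sketch of what that script computes:

# What the generated script "dict(zip(__internal_dict_zip_keys__,__internal_dict_zip_values__),)"
# evaluates to, given the packed key and value tuples.
keys = ("w", "loss")
values = (3.14, 0.25)
local_vars = {"__internal_dict_zip_keys__": keys, "__internal_dict_zip_values__": values}
rebuilt = eval("dict(zip(__internal_dict_zip_keys__,__internal_dict_zip_values__),)", {}, local_vars)
assert rebuilt == {"w": 3.14, "loss": 0.25}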
@ -384,12 +645,16 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
AnfNodePtr ConvertValueNode(const ValueNodePtr &, const ValuePtr &value) override {
|
||||
// Convert Dictionary value node.
|
||||
if (value->isa<ValueDictionary>()) {
|
||||
return NewValueNode(DictToTuple(value->cast<ValueDictionaryPtr>()));
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
|
||||
if (support_fallback_runtime && is_dict_output_) {
|
||||
return RebuildValueDict(value->cast<ValueDictionaryPtr>());
|
||||
}
|
||||
return DictToTuple(value->cast<ValueDictionaryPtr>());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static std::shared_ptr<AbstractTuple> MakeAbstractTuple(const std::vector<AbstractAttribute> &attrs) {
|
||||
static std::shared_ptr<AbstractTuple> MakeAbstractTuple(const std::vector<AbstractElementPair> &attrs) {
|
||||
std::vector<AbstractBasePtr> elements;
|
||||
elements.reserve(attrs.size());
|
||||
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(elements),
|
||||
|
@ -459,6 +724,9 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
// AbstractDictionary --> AbstractSequence.
|
||||
return ConvertToAbstractSequence(abs, 0);
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_dict_output_{false};
|
||||
};
|
||||
|
||||
// ==================================================================
|
||||
|
@ -489,9 +757,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
}
|
||||
|
||||
// From:
|
||||
// ListGetItem(list, cons)
|
||||
// ListGetItem(list, key)
|
||||
// To:
|
||||
// TupleGetItem(list, cons)
|
||||
// TupleGetItem(list, key)
|
||||
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
|
@ -503,8 +771,8 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
constexpr size_t cons_index = 2;
|
||||
const auto &inputs = node->inputs();
|
||||
auto &data = inputs[data_index];
|
||||
auto &cons = inputs[cons_index];
|
||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons});
|
||||
auto &key = inputs[cons_index];
|
||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, key});
|
||||
}
|
||||
|
||||
// From:
|
||||
|
@ -524,9 +792,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
const size_t value_index = 3;
|
||||
const auto &inputs = node->inputs();
|
||||
auto &data = inputs[data_index];
|
||||
auto &cons = inputs[cons_index];
|
||||
auto &key = inputs[cons_index];
|
||||
auto &value = inputs[value_index];
|
||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
|
||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, key, value});
|
||||
}
|
||||
|
||||
// From:
|
||||
|
|
|
@ -51,7 +51,7 @@ class ResolveNodeResolve : public AnfVisitor {
|
|||
if (IsValueNode<parse::NameSpace>(vnode)) {
|
||||
auto name_space = GetValueNode<parse::NameSpacePtr>(vnode);
|
||||
MS_EXCEPTION_IF_NULL(name_space);
|
||||
obj_ = name_space->obj();
|
||||
obj_ = name_space->namespace_obj();
|
||||
} else if (IsValueNode<parse::Symbol>(vnode)) {
|
||||
auto symbol_value = GetValueNode<parse::SymbolPtr>(vnode);
|
||||
MS_EXCEPTION_IF_NULL(symbol_value);
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/optimizer/py_interpret_to_execute.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "include/common/utils/convert_utils_py.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/anf_utils.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
|
||||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
namespace {
|
||||
py::object CallPythonPushGlobalParams(const py::object &dict) {
|
||||
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
|
||||
py::module mod = python_adapter::GetPyModule(python_mod_parse);
|
||||
constexpr auto python_merge_dict = "merge_global_params";
|
||||
return python_adapter::CallPyModFn(mod, python_merge_dict, dict);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Convert PyInterpret into PyExecute:
|
||||
// PyInterpret(script, global_dict, local_dict)
|
||||
// -->
|
||||
// PyExecute(script, local_dict_keys, local_dict_values),
|
||||
// with side-effect operation:
|
||||
// Push global_dict into global parameters list.
|
||||
// (So it requires no same key name.)
|
||||
bool PyInterpretToExecute(const pipeline::ResourcePtr &resource) {
|
||||
auto manager = resource->manager();
|
||||
const auto &all_nodes = manager->all_nodes();
|
||||
auto transact = manager->Transact();
|
||||
constexpr auto input_index_one = 1;
|
||||
constexpr auto input_index_two = 2;
|
||||
constexpr auto input_index_three = 3;
|
||||
for (const auto &node : all_nodes) {
|
||||
if (!IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
|
||||
continue;
|
||||
}
|
||||
const auto &cnode = node->cast<CNodePtr>();
|
||||
MS_LOG(DEBUG) << "cnode: " << cnode->DebugString();
|
||||
auto new_cnode = std::make_shared<CNode>(*cnode);
|
||||
new_cnode->set_input(0, NewValueNode(prim::kPrimPyExecute));
|
||||
|
||||
if (!IsValueNode<parse::Script>(cnode->input(input_index_one))) {
|
||||
MS_LOG(EXCEPTION) << "The first input should be a Script, but got "
|
||||
<< cnode->input(input_index_one)->DebugString();
|
||||
}
|
||||
const auto &script = GetValueNode<std::shared_ptr<parse::Script>>(cnode->input(input_index_one));
|
||||
const auto &script_str = script->script();
|
||||
const auto &script_strimm_node = NewValueNode(std::make_shared<StringImm>(script_str));
|
||||
new_cnode->set_input(input_index_one, script_strimm_node);
|
||||
|
||||
if (!IsValueNode<ValueDictionary>(cnode->input(input_index_two))) {
|
||||
MS_LOG(EXCEPTION) << "The second input should be a dictionary, but got "
|
||||
<< cnode->input(input_index_two)->DebugString();
|
||||
}
|
||||
const auto &global_dict = GetValueNode<ValueDictionaryPtr>(cnode->input(input_index_two));
|
||||
py::object py_global_dict = ValueToPyData(global_dict);
|
||||
MS_LOG(DEBUG) << "py_global_dict: " << py::str(py_global_dict);
|
||||
CallPythonPushGlobalParams(py_global_dict);
|
||||
|
||||
if (!IsPrimitiveCNode(cnode->input(input_index_three), prim::kPrimMakeDict)) {
|
||||
MS_LOG(EXCEPTION) << "The 3rd input should be a dictionary, but got "
|
||||
<< cnode->input(input_index_three)->DebugString();
|
||||
}
|
||||
const auto &local_dict_cnode = dyn_cast<CNode>(cnode->input(input_index_three));
|
||||
MS_EXCEPTION_IF_NULL(local_dict_cnode);
|
||||
const auto &local_dict_keys = local_dict_cnode->input(input_index_one);
|
||||
const auto &local_dict_values = local_dict_cnode->input(input_index_two);
|
||||
if (!IsValueNode<ValueTuple>(local_dict_keys) || !IsPrimitiveCNode(local_dict_values, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(EXCEPTION) << "The dictionary's keys and values should be a tuple, but got "
|
||||
<< local_dict_cnode->DebugString();
|
||||
}
|
||||
new_cnode->set_input(input_index_two, local_dict_keys);
|
||||
new_cnode->set_input(input_index_three, local_dict_values);
|
||||
(void)transact.Replace(cnode, new_cnode);
|
||||
}
|
||||
transact.Commit();
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
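The new pass above rewrites PyInterpret(script, global_dict, local_dict) into PyExecute(script, local_dict_keys, local_dict_values) and pushes the global dict into a shared table via merge_global_params. A rough Python model of that calling-convention change (names follow the pass; the merge helper shown here is an assumption about its behaviour):

# Rough model of the node rewrite performed by PyInterpretToExecute.
_global_params = {}

def merge_global_params(global_dict):
    # Assumed behaviour of mindspore._extends.parse.merge_global_params:
    # fold the script's global dict into one shared table (keys must not clash).
    _global_params.update(global_dict)

def py_execute(script, local_keys, local_values):
    # PyExecute only carries the local names/values; globals come from the merged table.
    return eval(script, dict(_global_params), dict(zip(local_keys, local_values)))

merge_global_params({"np": __import__("numpy")})
print(py_execute("np.add(x, y)", ("x", "y"), (1, 2)))  # 3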
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_INTERPRET_TO_EXECUTE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_INTERPRET_TO_EXECUTE_H_
|
||||
|
||||
#include "pipeline/jit/resource.h"
|
||||
|
||||
namespace mindspore {
|
||||
/* namespace to support opt */
|
||||
namespace opt {
|
||||
bool PyInterpretToExecute(const pipeline::ResourcePtr &resource);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_INTERPRET_TO_EXECUTE_H_
|
|
@ -1432,7 +1432,7 @@ tensor::TensorPtr GetDependValueByConstTensor(const AnfNodePtr &input_node, cons
|
|||
auto value = value_node->value();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (!value->isa<tensor::Tensor>()) {
|
||||
MS_EXCEPTION(ValueError) << "the cnode " << cnode_name << "'s input[" << i << "], must be tensor, but got "
|
||||
MS_EXCEPTION(ValueError) << "The CNode " << cnode_name << "'s input[" << i << "], must be tensor, but got "
|
||||
<< value->ToString();
|
||||
}
|
||||
auto tensor = value->cast<tensor::TensorPtr>();
|
||||
|
|
|
@ -239,8 +239,7 @@ ValuePtr ConvertModuleNameSpace(const py::object &obj) {
|
|||
MS_LOG(DEBUG) << "Converting python module";
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
|
||||
auto converted =
|
||||
std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace), obj);
|
||||
auto converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, module_namespace, obj);
|
||||
MS_LOG(DEBUG) << "name_space: " << converted->ToString();
|
||||
return converted;
|
||||
}
|
||||
|
|
|
@ -376,8 +376,9 @@ AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
|
|||
}
|
||||
|
||||
AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
|
||||
MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
|
||||
<< " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resolve symbol.");
|
||||
MS_LOG(DEBUG) << "MakeResolve for "
|
||||
<< (name_space ? (std::string)py::str(name_space->namespace_obj()) : "null namespace") << " , "
|
||||
<< (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resolve symbol.");
|
||||
ValueNodePtr module_node = NewValueNode(name_space);
|
||||
ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
|
||||
auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
|
||||
|
@ -388,7 +389,7 @@ AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const An
|
|||
const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node) {
|
||||
MS_LOG(DEBUG) << "MakeInterpret for " << script_text;
|
||||
MS_EXCEPTION_IF_NULL(orig_node);
|
||||
ScriptPtr script = std::make_shared<Script>(script_text);
|
||||
auto script = std::make_shared<Script>(script_text);
|
||||
auto script_node = NewValueNode(script);
|
||||
auto node = func_graph_->NewCNodeInOrder(
|
||||
{NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
|
||||
|
|
|
@@ -1218,6 +1218,22 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
  UpdateInterpretForUserNode(call_cnode, call_function_node);
  MS_EXCEPTION_IF_NULL(call_cnode);

  MS_LOG(DEBUG) << "call_cnode: " << call_cnode->DebugString()
                << ", call_function_node: " << call_function_node->DebugString();
  // Support tensor.asnumpy() in runtime by JIT Fallback.
  static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
  if (support_fallback_runtime && IsPrimitiveCNode(call_function_node, prim::kPrimGetAttr)) {
    constexpr size_t index_two = 2;
    const auto &attr_node = call_function_node->cast<CNodePtr>()->input(index_two);
    const auto &attr_str = GetValueNode<StringImmPtr>(attr_node);
    MS_EXCEPTION_IF_NULL(attr_str);
    if (attr_str->value() == "asnumpy") {
      call_cnode->set_interpret(true);
      call_cnode = HandleInterpret(block, call_cnode, node);
      return call_cnode;
    }
  }

  // Process bulitin function, for example, sum(np.array(xx))
  py::tuple namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL, name_id);
  constexpr size_t namespace_info_size = 4;
@ -1432,7 +1448,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
|
|||
|
||||
// Process the node attr
|
||||
auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
|
||||
MS_LOG(DEBUG) << "Attr = " << attr_str;
|
||||
MS_LOG(DEBUG) << "node: " << node << ", attr: " << attr_str << ", value: " << value_body;
|
||||
// The fallback feature is enabled in default.
|
||||
static const auto use_fallback = (support_fallback() != "0");
|
||||
if (use_fallback && attr_str == "Tensor") {
|
||||
|
@ -1449,21 +1465,19 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
|
|||
return ret;
|
||||
}
|
||||
}
|
||||
AnfNodePtr attr_node = nullptr;
|
||||
{
|
||||
TraceGuard guard(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
|
||||
attr_node = NewValueNode(attr_str);
|
||||
}
|
||||
|
||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
// Create the apply node
|
||||
AnfNodePtr attr_node = NewValueNode(attr_str);
|
||||
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
|
||||
if (use_fallback) {
|
||||
// Check whether it is constant, constant does not need interpret.
|
||||
auto value_str = py::cast<std::string>(ast()->GetAstNodeText(value_body));
|
||||
py::bool_ is_const_value =
|
||||
ast()->CallParserObjMethod(PYTHON_PARSE_CHECK_IS_CONSTANT_VALUE, value_str, common::SafeCStr(attr_str));
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
|
||||
auto is_constant = py::cast<bool>(is_const_value);
|
||||
if (!is_constant) {
|
||||
if (!is_constant || (support_fallback_runtime && attr_str == "asnumpy")) {
|
||||
UpdateInterpretForUserNode(attr_cnode, value_node);
|
||||
}
|
||||
}
|
||||
|
@ -2114,7 +2128,7 @@ CNodePtr GenerateInterpretGetItem(const FuncGraphPtr &fg, const AnfNodePtr &iter
|
|||
auto local_dict_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeDict), local_dict_key, local_dict_value});
|
||||
|
||||
// Construct script text node.
|
||||
parse::ScriptPtr script = std::make_shared<parse::Script>("x[i]");
|
||||
auto script = std::make_shared<Script>("x[i]");
|
||||
auto script_node = NewValueNode(script);
|
||||
|
||||
return fg->NewCNodeInOrder({NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
|
||||
|
|
|
@ -481,7 +481,7 @@ py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symb
|
|||
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph is nullptr.";
|
||||
}
|
||||
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
|
||||
auto &obj = name_space->obj();
|
||||
auto &obj = name_space->namespace_obj();
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
MS_EXCEPTION(NameError) << "The name \'" << symbol << "\' is not defined.";
|
||||
}
|
||||
|
|
|
@ -40,15 +40,15 @@ namespace parse {
|
|||
// NameSpace class for resolving python code.
|
||||
class NameSpace final : public Named {
|
||||
public:
|
||||
NameSpace(const std::string &module, const py::object &obj, const py::object &module_obj = py::object())
|
||||
: Named(module + ": \'" + std::string(py::str(obj)) + "\'"),
|
||||
NameSpace(const std::string &module, const py::object &namespace_obj, const py::object &module_obj = py::object())
|
||||
: Named(module + ": \'" + std::string(py::str(namespace_obj)) + "\'"),
|
||||
module_(module),
|
||||
obj_(obj),
|
||||
namespace_obj_(namespace_obj),
|
||||
module_obj_(module_obj) {}
|
||||
~NameSpace() override = default;
|
||||
MS_DECLARE_PARENT(NameSpace, Named);
|
||||
|
||||
const py::object &obj() const { return obj_; }
|
||||
const py::object &namespace_obj() const { return namespace_obj_; }
|
||||
const py::object &module_obj() const { return module_obj_; }
|
||||
const std::string &module() const { return module_; }
|
||||
abstract::AbstractBasePtr ToAbstract() override {
|
||||
|
@ -59,7 +59,7 @@ class NameSpace final : public Named {
|
|||
// namespace of the module
|
||||
std::string module_;
|
||||
// namespace object
|
||||
py::object obj_;
|
||||
py::object namespace_obj_;
|
||||
// module object
|
||||
py::object module_obj_;
|
||||
};
|
||||
|
|
|
@ -48,6 +48,7 @@
|
|||
#include "frontend/optimizer/environ_conversion.h"
|
||||
#include "frontend/optimizer/comm_op_reuse_tag.h"
|
||||
#include "frontend/optimizer/overlap_opt_shard_in_pipeline.h"
|
||||
#include "frontend/optimizer/py_interpret_to_execute.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "pipeline/jit/pipeline_split.h"
|
||||
#include "pipeline/pynative/pynative_execute.h"
|
||||
|
@@ -86,6 +87,19 @@ void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource)
}
} // namespace

bool PyInterpretToExecutePass(const ResourcePtr &resource) {
  static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
  if (!support_fallback_runtime) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr func_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  (void)opt::PyInterpretToExecute(resource);
  UpdateArgsSpec(func_graph, resource);
  return true;
}

bool SimplifyDataStructuresPass(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr func_graph = resource->func_graph();
@@ -840,6 +854,7 @@ bool AddEmbeddingCachePass(const ResourcePtr &resource) {
}

std::vector<PassItem> kVmPasses = {
  {"py_interpret_to_execute", PyInterpretToExecutePass},
  {"simplify_data_structures", SimplifyDataStructuresPass},
  {"opt_a", OptPassAGroup},
  {"clean_after_opta", CleanAfterOptAPass},

@@ -859,7 +874,8 @@ std::vector<PassItem> kVmPasses = {
  {"overlap_opt_shard_in_pipeline", OverlapOptShardInPipelinePass},
};

std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
std::vector<PassItem> kGePasses = {{"py_interpret_to_execute", PyInterpretToExecutePass},
                                   {"simplify_data_structures", SimplifyDataStructuresPass},
                                   {"opt_a", OptPassAGroup},
                                   {"clean_after_opta", CleanAfterOptAPass},
                                   {"opt_b", OptPassBGroup},
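The new pass is registered at the head of kVmPasses and kGePasses but only acts when the fallback runtime is switched on. The flag is read once per process, so it has to be set in the environment before MindSpore is initialised; a sketch:

# The new rewrites are gated on this environment variable (name taken from the diff);
# set it before importing/initialising MindSpore so the backend sees it.
import os
os.environ["MS_DEV_ENABLE_FALLBACK_RUNTIME"] = "1"
import mindspore as ms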
@ -69,6 +69,7 @@
|
|||
#include "distributed/collective/collective_manager.h"
|
||||
#include "mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h"
|
||||
#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h"
|
||||
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
|
||||
#include "distributed/init.h"
|
||||
#include "profiler/device/profiling.h"
|
||||
#include "kernel/akg/akg_kernel_build_manager.h"
|
||||
|
@@ -269,6 +270,33 @@ std::string ToOrdinal(const size_t &i) {
  }
  return std::to_string(i) + suffix;
}

py::object GetPyExecuteOutput(const AnfNodePtr &output) {
  static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
  if (support_fallback_runtime) {
    std::function<AnfNodePtr(const AnfNodePtr &)> get_real_output = [&get_real_output](const AnfNodePtr &node) {
      if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
        const auto cnode = dyn_cast<CNode>(node);
        MS_EXCEPTION_IF_NULL(cnode);
        return get_real_output(cnode->input(1));
      }
      return node;
    };
    const auto &real_output = get_real_output(output);
    MS_LOG(INFO) << "Real output: " << real_output << ", " << real_output->DebugString()
                 << ", has \'PyExecuteOutputData\': " << real_output->has_user_data<kernel::PyExecuteOutputData>();
    if (real_output->has_user_data<kernel::PyExecuteOutputData>()) {
      py::gil_scoped_acquire gil_acquire;
      const auto &output_data = real_output->user_data<kernel::PyExecuteOutputData>();
      py::object res_obj = output_data->obj;
      MS_LOG(INFO) << "Has \'PyExecuteOutputData\', just return it. res_obj: " << res_obj;
      if (!py::isinstance<py::none>(res_obj)) {
        return res_obj;
      }
    }
  }
  return py::none();
}
} // namespace

std::string GetObjDesc(const py::object &source_obj) {
@@ -1409,16 +1437,25 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
    ConfigManager::GetInstance().set_gpu_loopsink_size(loop_size);
  }
  MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << vm_loop;
  py::object ret;
  py::object res;
  MS_LOG(DEBUG) << "Eval run" << ms_context->backend_policy();
  auto output = execute_info->func_graph->output()->abstract();
  const auto &output = execute_info->func_graph->output();
  MS_EXCEPTION_IF_NULL(output);
  const auto &output_abs = output->abstract();
  MS_EXCEPTION_IF_NULL(output_abs);
  for (int64_t i = 0; i < vm_loop; i++) {
    BaseRef value = (*run)(execute_info->arg_list);
    ret = BaseRefToPyData(value, output);
    res = BaseRefToPyData(value, output_abs);
  }

  // Replace the output if it's not Tensor, but Python data.
  const auto &py_res = GetPyExecuteOutput(output);
  if (py_res != py::none()) {
    return py_res;
  }

  MS_LOG(DEBUG) << "Run end";
  return ret;
  return res;
} // namespace pipeline

FuncGraphPtr GraphExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase) const {
@ -106,8 +106,8 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
const AnfNodeConfigPtr &out_conf) {
|
||||
MS_EXCEPTION_IF_NULL(engine);
|
||||
MS_EXCEPTION_IF_NULL(out_conf);
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
AbstractBasePtrList args_abs_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
|
||||
[](const ConfigPtr &config) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(config);
|
||||
const auto &eval_result = config->ObtainEvalResult();
|
||||
|
@ -122,7 +122,7 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
auto do_signature_func = dyn_cast_ptr<Primitive>(func);
|
||||
if (do_signature_func != nullptr) {
|
||||
if (prims_to_skip_undetermined_infer.find(do_signature_func->name()) == prims_to_skip_undetermined_infer.end()) {
|
||||
auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
|
||||
auto ret_abstract = EvalUndeterminedArgs(args_abs_list);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined for " << do_signature_func->name()
|
||||
<< ", ret_abstract: " << ret_abstract->ToString();
|
||||
|
@ -148,9 +148,9 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
ScopeGuard scope_guard(scope);
|
||||
if (bound_node() != nullptr) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
|
||||
new_node = prim::GenerateCNode(out_cnode->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
|
||||
new_node = prim::GenerateCNode(out_cnode->func_graph(), prim_->ToString(), func, args_abs_list, args_inputs);
|
||||
} else {
|
||||
new_node = prim::GenerateCNode(out_cnode->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
|
||||
new_node = prim::GenerateCNode(out_cnode->func_graph(), prim_->ToString(), func, args_abs_list, args_inputs);
|
||||
}
|
||||
// Update new CNode info.
|
||||
auto new_cnode = dyn_cast<CNode>(new_node);
|
||||
|
@ -162,9 +162,9 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
return engine->ForwardConfig(out_conf, new_conf);
|
||||
}
|
||||
|
||||
static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
|
||||
static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_abs_list, bool need_unpack) {
|
||||
// arg[0] is the func graph to unpack, ignore it
|
||||
AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
|
||||
AbstractBasePtrList specialize_args_before_unpack(args_abs_list.begin() + 1, args_abs_list.end());
|
||||
AbstractBasePtrList graph_specialize_args;
|
||||
if (need_unpack) {
|
||||
for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
|
||||
|
@ -177,7 +177,7 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
|
|||
auto arg_dict = specialize_args_before_unpack[index]->cast_ptr<AbstractDictionary>();
|
||||
auto dict_elems = arg_dict->elements();
|
||||
(void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(graph_specialize_args),
|
||||
[](const AbstractAttribute &item) {
|
||||
[](const AbstractElementPair &item) {
|
||||
// Dict_elems's first element represents parameter names, which should be string type.
|
||||
return std::make_shared<AbstractKeywordArg>(
|
||||
GetValue<std::string>(item.first->BuildValue()), item.second);
|
||||
|
@ -212,8 +212,8 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
<< " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
|
||||
<< ", inputs size " << out_node_inputs.size();
|
||||
}
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
AbstractBasePtrList args_abs_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
|
||||
[](const ConfigPtr &ref) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(ref);
|
||||
const auto &eval_result = ref->ObtainEvalResult();
|
||||
|
@ -221,20 +221,20 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
|
|||
return eval_result->abstract();
|
||||
});
|
||||
// get the forward graph
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
|
||||
if (args_abs_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "args_abs_list can't be empty.";
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
auto fn = args_spec_list[0]->cast_ptr<AbstractFunction>();
|
||||
MS_EXCEPTION_IF_NULL(args_abs_list[0]);
|
||||
auto fn = args_abs_list[0]->cast_ptr<AbstractFunction>();
|
||||
if (fn == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
|
||||
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_abs_list[0]->ToString();
|
||||
}
|
||||
auto real_fn = fn->cast_ptr<FuncGraphAbstractClosure>();
|
||||
MS_EXCEPTION_IF_NULL(real_fn);
|
||||
FuncGraphPtr forward_graph = real_fn->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(forward_graph);
|
||||
AbstractBasePtrList graph_specialize_args =
|
||||
GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
|
||||
GetUnpackGraphSpecArgsList(args_abs_list, unpack_graph->need_unpack_args());
|
||||
AbstractBasePtrList graph_specialize_args_without_sens;
|
||||
if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
|
||||
MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
|
||||
|
@@ -316,7 +316,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(engine);
AbstractBasePtrList args_spec_list;
AbstractBasePtrList args_abs_list;
MS_EXCEPTION_IF_NULL(out_conf);
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";

@@ -329,7 +329,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
<< " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
<< ", inputs size " << out_node_inputs.size();
}
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
[](const ConfigPtr &ref) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(ref);
const auto &eval_result = ref->ObtainEvalResult();

@@ -347,7 +347,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
}

AnfNodePtr new_node =
MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1], out_node_inputs[1], func_graph);
MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_abs_list[1], out_node_inputs[1], func_graph);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());

if (new_node->isa<CNode>()) {
@@ -1434,10 +1434,10 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
return eng->ForwardConfig(old_conf, fn_conf);
}

EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &args_spec_list, const ValuePtr &data_value,
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &args_abs_list, const ValuePtr &data_value,
const AnfNodeConfigPtr &out_conf, const std::string &data) {
constexpr size_t item_index = 1;
auto item_args = args_spec_list[item_index];
auto item_args = args_abs_list[item_index];
ValuePtr item_value = item_args->BuildValue();
MS_EXCEPTION_IF_NULL(data_value);
MS_EXCEPTION_IF_NULL(item_value);

@@ -1463,7 +1463,7 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &arg
if (IsValueNode<TypeNull>(new_node)) {
// Do not find the attribute.
constexpr auto max_args_len = 3;
bool has_default = (args_spec_list.size() == max_args_len);
bool has_default = (args_abs_list.size() == max_args_len);
if (!has_default) {
MS_EXCEPTION(AttributeError) << data << " object has no attribute " << symbol->symbol();
}

@@ -1490,19 +1490,19 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &arg
return eng->ForwardConfig(out_conf, fn_conf);
}

EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_spec_list,
EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_abs_list,
const AnfNodeConfigPtr &out_conf) {
// args_spec_list: same as StaticGetter
// args_abs_list: same as StaticGetter
constexpr size_t args_min_size = 2;
if (args_spec_list.size() < args_min_size) {
MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
if (args_abs_list.size() < args_min_size) {
MS_LOG(EXCEPTION) << "Size of args_abs_list is less than 2";
}
MS_EXCEPTION_IF_NULL(out_conf);
// An external type.
constexpr auto data_index = 0;
constexpr auto item_index = 1;
auto data = args_spec_list[data_index];
auto item = args_spec_list[item_index];
auto data = args_abs_list[data_index];
auto item = args_abs_list[item_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(item);
auto data_value = data->BuildValue();
@@ -1527,13 +1527,13 @@ EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_spec
auto data_type = data->BuildType();
MS_EXCEPTION_IF_NULL(data_type);
const auto &data_id_str = TypeIdToString(data_type->type_id());
return GetEvaluatedValueForNameSpaceString(args_spec_list, data_value, out_conf, data_id_str);
return GetEvaluatedValueForNameSpaceString(args_abs_list, data_value, out_conf, data_id_str);
}

EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AbstractBasePtrList &args_spec_list,
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AbstractBasePtrList &args_abs_list,
const ValuePtr &data_value, const AnfNodeConfigPtr &out_conf) {
constexpr size_t item_index = 1;
auto item_args = args_spec_list[item_index];
auto item_args = args_abs_list[item_index];
ValuePtr item_value = item_args->BuildValue();

MS_EXCEPTION_IF_NULL(item_value);

@@ -1558,7 +1558,7 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AbstractBasePtrList
auto new_node = parse::ResolveMsClassWithAttr(func_graph->manager(), ms_class->obj(), item_name, out_node);
if (new_node == nullptr) {
constexpr auto max_args_len = 3;
bool has_default = (args_spec_list.size() == max_args_len);
bool has_default = (args_abs_list.size() == max_args_len);
if (!has_default) {
MS_EXCEPTION(AttributeError) << py::str(ms_class->obj()) << " object has no attribute: " << item_name << ".";
}

@@ -1575,10 +1575,10 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AbstractBasePtrList
return eng->ForwardConfig(out_conf, fn_conf);
}

EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AbstractBasePtrList &args_spec_list,
EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AbstractBasePtrList &args_abs_list,
const FuncGraphPtr &func_value, const AnfNodeConfigPtr &out_conf) {
constexpr size_t item_index = 1;
auto item_args = args_spec_list[item_index];
auto item_args = args_abs_list[item_index];
ValuePtr item_value = item_args->BuildValue();

MS_EXCEPTION_IF_NULL(item_value);

@@ -1601,19 +1601,19 @@ EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AbstractBasePtrList &ar
py::object ns_obj =
python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, real_python_obj);
auto ns = std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, ns_obj);
return GetEvaluatedValueForNameSpaceString(args_spec_list, ns, out_conf, py_obj_str);
return GetEvaluatedValueForNameSpaceString(args_abs_list, ns, out_conf, py_obj_str);
}
return nullptr;
}

EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list,
const AbstractBasePtrList &args_abs_list,
const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
constexpr size_t data_index = 0;
constexpr size_t item_index = 1;
auto data_args = args_spec_list[data_index];
auto item_args = args_spec_list[item_index];
auto data_args = args_abs_list[data_index];
auto item_args = args_abs_list[item_index];
MS_EXCEPTION_IF_NULL(data_args);
MS_EXCEPTION_IF_NULL(item_args);
ValuePtr item_value = item_args->BuildValue();
@@ -1631,9 +1631,23 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
if (require.empty()) {
constexpr auto max_args_len = 3;
bool has_default = (args_spec_list.size() == max_args_len);
bool has_default = (args_abs_list.size() == max_args_len);
if (!has_default) {
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
constexpr auto tensor_asnumpy_attr_name = "asnumpy";
if (item_name != tensor_asnumpy_attr_name) {
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
}
auto out_node = out_conf->node();
auto out_cnode = out_node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(out_cnode);
auto fg = out_cnode->func_graph();
auto py_interpret_node =
fg->NewCNode({NewValueNode(prim::kPrimPyInterpret), NewValueNode(out_node->debug_info()->debug_name())});
fg->ReplaceInOrder(out_node, py_interpret_node);
auto eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
auto fn_conf = eng->MakeConfig(py_interpret_node, out_conf->context(), out_conf->func_graph());
return eng->ForwardConfig(out_conf, fn_conf);
}
auto out_node = out_conf->node();
auto out_cnode = out_node->cast_ptr<CNode>();
@@ -1694,14 +1708,14 @@ ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
return nullptr;
}

EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
// Inputs: namespace and its static function; or class and its member function

constexpr size_t data_index = 0;
constexpr size_t item_index = 1;
auto data_args = args_spec_list[data_index];
auto item_args = args_spec_list[item_index];
auto data_args = args_abs_list[data_index];
auto item_args = args_abs_list[item_index];
MS_EXCEPTION_IF_NULL(data_args);
MS_EXCEPTION_IF_NULL(item_args);
MS_LOG(DEBUG) << "StaticGetter, data: " << data_args->ToString() << ", item: " << item_args->ToString();

@@ -1730,9 +1744,9 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
}

constexpr auto max_args_size = 3;
if (args_spec_list.size() == max_args_size) {
if (args_abs_list.size() == max_args_size) {
constexpr size_t default_index = 2;
auto default_args = args_spec_list[default_index];
auto default_args = args_abs_list[default_index];
if (default_args->isa<abstract::AbstractScalar>()) {
ValuePtr default_value = default_args->BuildValue();
if (default_value->isa<parse::InterpretedObject>()) {

@@ -1747,12 +1761,12 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
// Get attribute or method of class object decorated with 'jit_class'.
auto class_value = GetMsClassObject(data_args);
if (class_value != nullptr) {
return GetEvaluatedValueForMsClassAttrOrMethod(args_spec_list, class_value, out_conf);
return GetEvaluatedValueForMsClassAttrOrMethod(args_abs_list, class_value, out_conf);
}
// Get attribute or method of nn.Cell object.
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
if (data_func_graph != nullptr) {
auto res = GetEvaluatedValueForCellAttrOrMethod(args_spec_list, data_func_graph->func_graph(), out_conf);
auto res = GetEvaluatedValueForCellAttrOrMethod(args_abs_list, data_func_graph->func_graph(), out_conf);
if (res != nullptr) {
return res;
}
@@ -1760,58 +1774,112 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
// Try to search method map, if not found, the data_type should be External type.
TypePtr data_type = data_args->BuildType();
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, args_spec_list, data_conf, out_conf);
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, args_abs_list, data_conf, out_conf);
}
return GetEvaluatedValueForNameSpace(args_spec_list, out_conf);
return GetEvaluatedValueForNameSpace(args_abs_list, out_conf);
}
} // namespace

EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
if (out_conf != nullptr) { // 'out_conf' maybe nullptr in PyNative mode.
if (args_spec_list.empty()) {
if (args_abs_list.empty()) {
MS_LOG(INFO) << "For MakeTuple, the inputs should not be empty. node: " << out_conf->node()->DebugString();
}
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
if (enable_eliminate_unused_element) {
auto flags = GetSequenceNodeElementsUseFlags(out_conf->node());
if (flags == nullptr) {
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size()));
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_abs_list.size()));
}

(void)sequence_nodes->emplace_back(AnfNodeWeakPtr(out_conf->node()));
}
}
auto abs = std::make_shared<AbstractTuple>(args_spec_list, sequence_nodes);
auto abs = std::make_shared<AbstractTuple>(args_abs_list, sequence_nodes);
auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, res);
evaluator_cache_mgr_->SetValue(args_abs_list, res);
return res;
}

EvalResultPtr MakeListEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
EvalResultPtr MakeListEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
if (out_conf != nullptr) { // 'out_conf' maybe nullptr in PyNative mode.
if (args_spec_list.empty()) {
if (args_abs_list.empty()) {
MS_LOG(INFO) << "For MakeList, the inputs should not be empty. node: " << out_conf->node()->DebugString();
}
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
if (enable_eliminate_unused_element) {
auto flags = GetSequenceNodeElementsUseFlags(out_conf->node());
if (flags == nullptr) {
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size()));
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_abs_list.size()));
}

(void)sequence_nodes->emplace_back(AnfNodeWeakPtr(out_conf->node()));
}
}
auto abs = std::make_shared<AbstractList>(args_spec_list, sequence_nodes);
auto abs = std::make_shared<AbstractList>(args_abs_list, sequence_nodes);
auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, res);
evaluator_cache_mgr_->SetValue(args_abs_list, res);
return res;
}

EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
if (args_abs_list.empty()) {
MS_LOG(EXCEPTION) << "'args_abs_list' should not be empty";
}

// Handle for DDE.
for (size_t i = 0; i < args_abs_list.size(); ++i) {
MS_EXCEPTION_IF_NULL(args_abs_list[i]);
if (args_abs_list[i]->isa<abstract::AbstractSequence>()) {
MS_LOG(DEBUG) << "Primitive \'PyExecute\' is consuming tuple/list arguments[" << i
<< "]: " << args_abs_list[i]->ToString();
SetSequenceElementsUseFlagsRecursively(args_abs_list[i], true);
}
}

auto current_interpret_node = out_conf->node();
MS_EXCEPTION_IF_NULL(current_interpret_node);
MS_LOG(DEBUG) << "The current interpret node: " << current_interpret_node->DebugString();
// Get the type parameter.
MS_EXCEPTION_IF_NULL(args_abs_list[0]);
ValuePtr value_track = args_abs_list[0]->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);

auto script_obj = dyn_cast_ptr<StringImm>(value_track);
if (script_obj == nullptr) {
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
}

// Make global and local parameters.
const std::string &script = script_obj->value();
// Call python script string.
MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list;

TypePtr type = kFloat32;
if (current_interpret_node->has_user_data("__py_execute_tensor_type__")) {
type = current_interpret_node->user_data<Type>("__py_execute_tensor_type__");
MS_LOG(DEBUG) << "type: " << type->ToString();
}
BaseShapePtr shape;
if (current_interpret_node->has_user_data("__py_execute_tensor_shape__")) {
shape = current_interpret_node->user_data<BaseShape>("__py_execute_tensor_shape__");
MS_LOG(DEBUG) << "shape: " << shape->ToString();
} else {
ShapeVector shp;
(void)shp.emplace_back(Shape::kShapeRankAny);
shape = std::make_shared<Shape>(shp);
}
AbstractBasePtr res = std::make_shared<AbstractTensor>(type, shape);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}
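// Illustrative sketch, not part of this change: the evaluator above only infers a concrete output
// abstract when the node already carries type/shape user data; otherwise it falls back to a float32
// tensor of unknown rank. Assuming the keyed AnfNode::set_user_data overload (an assumption here,
// the getter side is what the code above actually uses), a producer pass could pin the output like:
//
//   node->set_user_data<Type>("__py_execute_tensor_type__", kFloat32);
//   node->set_user_data<BaseShape>("__py_execute_tensor_shape__",
//                                  std::make_shared<abstract::Shape>(ShapeVector{2, 3}));
//
// so that EvalPrim would build an AbstractTensor with that shape instead of the rank-any default.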

namespace {
class EmbedEvaluator : public SymbolicPrimEvaluator {
 public:
@ -1917,23 +1985,23 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
|
|||
GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
|
||||
~GetAttrEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
|
||||
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
|
||||
constexpr auto args_min_size = 2;
|
||||
constexpr auto args_max_size = 3;
|
||||
constexpr auto attr_index = 1;
|
||||
auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
|
||||
auto ret_abstract = EvalUndeterminedArgs(args_abs_list);
|
||||
if (ret_abstract != nullptr) {
|
||||
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
|
||||
return ret_abstract;
|
||||
}
|
||||
// Inputs: data, item
|
||||
const auto args_size = args_spec_list.size();
|
||||
const auto args_size = args_abs_list.size();
|
||||
if (args_size != args_min_size && args_size != args_max_size) {
|
||||
MS_LOG(EXCEPTION) << "For Primitive GetAttr, the input size should be " << args_min_size << " or "
|
||||
<< args_max_size << ", but got size:" << args_size;
|
||||
}
|
||||
auto attr_abs = args_spec_list[attr_index];
|
||||
auto attr_abs = args_abs_list[attr_index];
|
||||
auto attr_abs_type = attr_abs->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(attr_abs_type);
|
||||
auto type_id = attr_abs_type->type_id();
|
||||
|
@ -1943,13 +2011,13 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
|
|||
EvalResultPtr ret = nullptr;
|
||||
if (bound_node() != nullptr) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
|
||||
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
|
||||
ret = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
|
||||
} else {
|
||||
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
|
||||
ret = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
|
||||
}
|
||||
// don't lookup from cache, as different out_conf with same node but different context
|
||||
// may add different entry to anfnode_config_map, like getattr primitive;
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, ret);
|
||||
evaluator_cache_mgr_->SetValue(args_abs_list, ret);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
@ -1959,19 +2027,19 @@ class ResolveEvaluator : public TransitionPrimEvaluator {
|
|||
ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
|
||||
~ResolveEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
|
||||
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
|
||||
constexpr auto resolve_args_size = 2;
|
||||
// Inputs: namespace, symbol
|
||||
if (args_spec_list.size() != resolve_args_size) {
|
||||
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
|
||||
if (args_abs_list.size() != resolve_args_size) {
|
||||
MS_LOG(EXCEPTION) << "Expected args_abs_list size = 2, but has size:" << args_abs_list.size();
|
||||
}
|
||||
EvalResultPtr ret = nullptr;
|
||||
if (bound_node() != nullptr) {
|
||||
TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
|
||||
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
|
||||
ret = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
|
||||
} else {
|
||||
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
|
||||
ret = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
@ -1997,14 +2065,14 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
|
||||
~CreateInstanceEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override {
|
||||
// Check the type parameter.
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
|
||||
if (args_abs_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "'args_abs_list' should not be empty";
|
||||
}
|
||||
constexpr size_t type_index = 0;
|
||||
auto arg_class_type = args_spec_list[type_index];
|
||||
auto arg_class_type = args_abs_list[type_index];
|
||||
MS_EXCEPTION_IF_NULL(arg_class_type);
|
||||
TypePtr type = arg_class_type->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
|
@ -2029,7 +2097,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
MS_LOG(DEBUG) << "Get class type: " << type_obj->ToString() << ".";
|
||||
|
||||
// Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
|
||||
py::tuple params = GetParameters(args_spec_list);
|
||||
py::tuple params = GetParameters(args_abs_list);
|
||||
// Create class instance.
|
||||
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
|
@ -2069,24 +2137,24 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
}
|
||||
AbstractBasePtr ret = ToAbstract(converted_res, AnalysisContext::DummyContext(), out_conf);
|
||||
auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
|
||||
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
|
||||
return infer_result;
|
||||
}
|
||||
|
||||
py::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
|
||||
if (args_spec_list.empty()) {
|
||||
py::tuple GetParameters(const AbstractBasePtrList &args_abs_list) const {
|
||||
if (args_abs_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Unexpected arguments num, the min arguments num must be 1, but got 0.";
|
||||
}
|
||||
// Exclude class type by minus 1;
|
||||
std::size_t params_size = args_spec_list.size() - 1;
|
||||
std::size_t params_size = args_abs_list.size() - 1;
|
||||
auto params = py::tuple(params_size);
|
||||
for (size_t i = 0; i < params_size; i++) {
|
||||
// Only support the Scalar parameters type. Bypass class type by offset with 1.
|
||||
auto arg = args_spec_list[i + 1];
|
||||
auto arg = args_abs_list[i + 1];
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
if (IsContainUndetermined(arg)) {
|
||||
MS_EXCEPTION(TypeError) << "The " << i << "th initializing input to create instance for "
|
||||
<< args_spec_list[0]->BuildValue()->ToString()
|
||||
<< args_abs_list[0]->BuildValue()->ToString()
|
||||
<< " should be a constant, but got: " << arg->ToString();
|
||||
}
|
||||
// Because the Tensor's AbstractTensor can't get value from GetValueTrack.
|
||||
|
@ -2103,13 +2171,13 @@ class CallInstanceEvaluator : public TransitionPrimEvaluator {
|
|||
CallInstanceEvaluator() : TransitionPrimEvaluator("CallInstanceEvaluator") {}
|
||||
~CallInstanceEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(CallInstanceEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override {
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "args_spec_list should not be empty.";
|
||||
if (args_abs_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "args_abs_list should not be empty.";
|
||||
}
|
||||
constexpr size_t cls_index = 0;
|
||||
auto arg_cls = args_spec_list[cls_index];
|
||||
auto arg_cls = args_abs_list[cls_index];
|
||||
MS_EXCEPTION_IF_NULL(arg_cls);
|
||||
TypePtr type = arg_cls->GetTypeTrack();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
|
@ -2160,18 +2228,18 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
|||
PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
|
||||
~PyInterpretEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(PyInterpretEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override {
|
||||
if (args_spec_list.empty()) {
|
||||
MS_LOG(ERROR) << "'args_spec_list' should not be empty";
|
||||
if (args_abs_list.empty()) {
|
||||
MS_LOG(EXCEPTION) << "'args_abs_list' should not be empty";
|
||||
}
|
||||
|
||||
auto current_interpret_node = out_conf->node();
|
||||
MS_EXCEPTION_IF_NULL(current_interpret_node);
|
||||
MS_LOG(DEBUG) << "The current interpret node: " << current_interpret_node->DebugString();
|
||||
// Get the type parameter.
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(args_abs_list[0]);
|
||||
ValuePtr value_track = args_abs_list[0]->GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
|
||||
auto script_obj = dyn_cast_ptr<parse::Script>(value_track);
|
||||
|
@@ -2180,8 +2248,17 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
}

// Make global and local parameters.
non_const_err_ = false;
const std::string &script = script_obj->script();
py::tuple params = MakeParameters(args_spec_list, script);
py::tuple params = MakeParameters(args_abs_list, script);
if (non_const_err_) { // Would convert PyInterpret to PyExecute then.
ShapeVector shp;
(void)shp.emplace_back(Shape::kShapeDimAny);
AbstractBasePtr res = std::make_shared<AbstractTensor>(kInt32, std::make_shared<Shape>(shp));
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}

// Call python script string.
MS_LOG(DEBUG) << "Call script: " << script << ", params: " << py::str(params);

@@ -2189,7 +2266,7 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
if (py::isinstance<py::none>(obj)) {
AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
auto infer_result = std::make_shared<EvalResult>(res, nullptr);
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}

@@ -2206,7 +2283,7 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
}
AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}

@@ -2223,9 +2300,17 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
const auto &local_abs_val = local_abs->BuildValue();
MS_EXCEPTION_IF_NULL(local_abs_val);
if (local_abs_val == kAnyValue) {
MS_EXCEPTION(ValueError) << "When using JIT Fallback to handle script '" << script
<< "', the inputs should be constant, but found variable '" << name
<< "' to be nonconstant.";
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
if (support_fallback_runtime) {
MS_LOG(INFO) << "When using JIT Fallback to handle script '" << script
<< "', the inputs should be constant, but found variable '" << name
<< "' to be nonconstant. To convert to PyExecute() afterwards";
non_const_err_ = true;
} else {
MS_EXCEPTION(ValueError) << "When using JIT Fallback to handle script '" << script
<< "', the inputs should be constant, but found variable '" << name
<< "' to be nonconstant.";
}
}
if (local_abs->isa<abstract::AbstractTensor>()) {
MS_LOG(WARNING) << "When using JIT Fallback to handle script '" << script << "', found variable '" << name

@@ -2256,18 +2341,20 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
return;
}

py::tuple MakeParameters(const AbstractBasePtrList &args_spec_list, const std::string &script) const {
py::tuple MakeParameters(const AbstractBasePtrList &args_abs_list, const std::string &script) const {
constexpr int params_size = 3;
if (params_size != args_spec_list.size()) {
if (params_size != args_abs_list.size()) {
MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size
<< ", not equal to arguments.size:" << args_spec_list.size();
<< ", not equal to arguments.size:" << args_abs_list.size();
}
// The first argument is script string, ignore it.
auto params = py::tuple(params_size - 1);

// Make the global parameters.
auto global_dict = dyn_cast<AbstractDictionary>(args_spec_list[1]); // Global parameters dict.
MS_EXCEPTION_IF_NULL(global_dict);
auto global_dict = dyn_cast<AbstractDictionary>(args_abs_list[1]); // Global parameters dict.
if (global_dict == nullptr) {
MS_LOG(EXCEPTION) << "The second argument should be a dictionary, but got " << args_abs_list[1]->ToString();
}
auto filtered_global_dict = FilterParameters(global_dict);
MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString()
<< ", filtered_global_dict: " << filtered_global_dict->ToString();

@@ -2282,8 +2369,11 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {

// Make the local parameters.
constexpr size_t local_index = 2;
auto local_dict = dyn_cast<AbstractDictionary>(args_spec_list[local_index]); // Local parameters dict.
MS_EXCEPTION_IF_NULL(local_dict);
auto local_dict = dyn_cast<AbstractDictionary>(args_abs_list[local_index]); // Local parameters dict.
if (local_dict == nullptr) {
MS_LOG(EXCEPTION) << "The third argument should be a dictionary, but got "
<< args_abs_list[local_index]->ToString();
}
auto filtered_local_dict = FilterParameters(local_dict);
MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString()
<< ", filtered_local_dict: " << filtered_local_dict->ToString();
@ -2312,11 +2402,11 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
|||
|
||||
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
|
||||
MS_EXCEPTION_IF_NULL(abstract_dict);
|
||||
std::vector<AbstractAttribute> kv;
|
||||
std::vector<AbstractElementPair> kv;
|
||||
const auto &keys_values = abstract_dict->elements();
|
||||
// Filter out the element of Function type.
|
||||
(void)std::copy_if(keys_values.cbegin(), keys_values.cend(), std::back_inserter(kv),
|
||||
[](const AbstractAttribute &item) {
|
||||
[](const AbstractElementPair &item) {
|
||||
MS_EXCEPTION_IF_NULL(item.second);
|
||||
return (!item.second->isa<abstract::AbstractFunction>());
|
||||
});
|
||||
|
@ -2327,6 +2417,9 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
|
|||
constexpr char const_arg_attr[] = "const_arg";
|
||||
return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
|
||||
}
|
||||
|
||||
private:
|
||||
mutable bool non_const_err_{false};
|
||||
};
|
||||
|
||||
class PartialEvaluator : public Evaluator {
|
||||
|
@ -2346,7 +2439,7 @@ class PartialEvaluator : public Evaluator {
|
|||
MS_EXCEPTION_IF_NULL(arg0_eval_result);
|
||||
auto arg0_value = arg0_eval_result->abstract();
|
||||
MS_EXCEPTION_IF_NULL(arg0_value);
|
||||
AbstractBasePtrList args_spec_list{arg0_value};
|
||||
AbstractBasePtrList args_abs_list{arg0_value};
|
||||
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
|
||||
if (arg0_value->isa<AbstractError>()) {
|
||||
MS_EXCEPTION_IF_NULL(arg0_value->GetValueTrack());
|
||||
|
@ -2354,10 +2447,10 @@ class PartialEvaluator : public Evaluator {
|
|||
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
|
||||
<< " as func is: " << arg0_value->ToString();
|
||||
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
|
||||
evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
|
||||
return eval_result;
|
||||
}
|
||||
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
|
||||
auto func = CheckArg<AbstractFunction>("partial", args_abs_list, 0);
|
||||
// Sometimes, node[0] in out_conf becomes phi0;
|
||||
if (func->isa<PrimitiveAbstractClosure>()) {
|
||||
auto prim_func = dyn_cast_ptr<PrimitiveAbstractClosure>(func);
|
||||
|
@ -2369,14 +2462,14 @@ class PartialEvaluator : public Evaluator {
|
|||
}
|
||||
}
|
||||
|
||||
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_abs_list),
|
||||
[](const ConfigPtr &config) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(config);
|
||||
const auto &eval_result = config->ObtainEvalResult();
|
||||
MS_EXCEPTION_IF_NULL(eval_result);
|
||||
return eval_result->abstract();
|
||||
});
|
||||
AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
|
||||
AbstractBasePtrList args(args_abs_list.begin() + 1, args_abs_list.end());
|
||||
|
||||
auto cnode = out_conf->node()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -2393,7 +2486,7 @@ class PartialEvaluator : public Evaluator {
|
|||
|
||||
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
|
||||
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
|
||||
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
|
||||
evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
|
||||
return eval_result;
|
||||
}
|
||||
|
||||
|
@ -2429,7 +2522,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
RaiseEvaluator() : TransitionPrimEvaluator("RaiseEvaluator") {}
|
||||
~RaiseEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(RaiseEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override {
|
||||
auto node = out_conf->node();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -2441,12 +2534,12 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
<< "Please check your conditions which raise node is located at: "
|
||||
<< trace::GetDebugInfo(node->debug_info());
|
||||
}
|
||||
if (args_spec_list.empty()) {
|
||||
if (args_abs_list.empty()) {
|
||||
// process raise
|
||||
MS_LOG(EXCEPTION) << "No active exception to reraise.";
|
||||
}
|
||||
|
||||
std::string exception_type = GetExceptionType(args_spec_list[0]);
|
||||
std::string exception_type = GetExceptionType(args_abs_list[0]);
|
||||
auto iter = exception_types_map.find(exception_type);
|
||||
if (iter == exception_types_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type
|
||||
|
@ -2454,7 +2547,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
<< SupportedExceptionsToString();
|
||||
}
|
||||
ExceptionType type = iter->second;
|
||||
if (args_spec_list.size() == 1) {
|
||||
if (args_abs_list.size() == 1) {
|
||||
// Process raise ValueError()
|
||||
MS_EXCEPTION(type);
|
||||
}
|
||||
|
@ -2470,7 +2563,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
|
|||
}
|
||||
for (size_t index = index_begin; index < inputs.size(); ++index) {
|
||||
const auto input = inputs[index];
|
||||
auto input_abs = args_spec_list[index - 1];
|
||||
auto input_abs = args_abs_list[index - 1];
|
||||
MS_EXCEPTION_IF_NULL(input_abs);
|
||||
bool need_symbol = CheckNeedSymbol(input, input_abs);
|
||||
if (need_symbol) {
|
||||
|
@ -2636,19 +2729,19 @@ class WithEnterEvaluator : public TransitionPrimEvaluator {
|
|||
WithEnterEvaluator() : TransitionPrimEvaluator("WithEnterEvaluator") {}
|
||||
~WithEnterEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(WithEnterEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override {
|
||||
auto node = out_conf->node()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cur_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(cur_graph);
|
||||
|
||||
if (args_spec_list.size() != 1) {
|
||||
if (args_abs_list.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "The enter node has wrong input." << node->debug_info();
|
||||
}
|
||||
|
||||
// Check class object
|
||||
auto partial_abs = args_spec_list[0]->cast<PartialAbstractClosurePtr>();
|
||||
auto partial_abs = args_abs_list[0]->cast<PartialAbstractClosurePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_abs);
|
||||
if (!IsCallInstance(partial_abs)) {
|
||||
MS_LOG(EXCEPTION) << "The enter node has wrong input." << node->debug_info();
|
||||
|
@ -2693,19 +2786,19 @@ class WithExitEvaluator : public TransitionPrimEvaluator {
|
|||
WithExitEvaluator() : TransitionPrimEvaluator("WithExitEvaluator") {}
|
||||
~WithExitEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(WithExitEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override {
|
||||
auto node = out_conf->node()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cur_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(cur_graph);
|
||||
|
||||
if (args_spec_list.size() != 1) {
|
||||
if (args_abs_list.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "The exit node has wrong input." << node->debug_info();
|
||||
}
|
||||
|
||||
// Check class object
|
||||
auto partial_abs = args_spec_list[0]->cast<PartialAbstractClosurePtr>();
|
||||
auto partial_abs = args_abs_list[0]->cast<PartialAbstractClosurePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_abs);
|
||||
if (!IsCallInstance(partial_abs)) {
|
||||
MS_LOG(EXCEPTION) << "The exit node has wrong input." << node->debug_info();
|
||||
|
@ -2755,13 +2848,13 @@ class JoinedStrEvaluator : public TransitionPrimEvaluator {
|
|||
JoinedStrEvaluator() : TransitionPrimEvaluator("JoinedStrEvaluator") {}
|
||||
~JoinedStrEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(JoinedStrEvaluator, TransitionPrimEvaluator);
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
|
||||
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
|
||||
const AnfNodeConfigPtr &out_conf) override {
|
||||
auto node = out_conf->node()->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cur_graph = node->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(cur_graph);
|
||||
bool exist_tensor = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &arg) {
|
||||
bool exist_tensor = std::any_of(args_abs_list.begin(), args_abs_list.end(), [](const AbstractBasePtr &arg) {
|
||||
auto arg_value = arg->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(arg_value);
|
||||
return arg_value->isa<AnyValue>();
|
||||
|
@ -2777,7 +2870,7 @@ class JoinedStrEvaluator : public TransitionPrimEvaluator {
|
|||
new_node = cur_graph->NewCNode(new_inputs);
|
||||
} else {
|
||||
std::string ret;
|
||||
for (const auto &arg : args_spec_list) {
|
||||
for (const auto &arg : args_abs_list) {
|
||||
auto arg_value = arg->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(arg_value);
|
||||
ret += arg_value->ToString();
|
||||
|
|
|
@@ -178,6 +178,15 @@ class MakeListEvaluator : public TransitionPrimEvaluator {
const AnfNodeConfigPtr &out_conf) override;
};

class PyExecuteEvaluator : public TransitionPrimEvaluator {
 public:
PyExecuteEvaluator() : TransitionPrimEvaluator("PyExecuteEvaluator") {}
~PyExecuteEvaluator() override = default;
MS_DECLARE_PARENT(PyExecuteEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override;
};

bool IsInWhiteList(const PrimitivePtr &primitive);

PrimEvaluatorMap &GetPrimEvaluatorConstructors();
@@ -534,6 +534,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
return std::make_shared<MixedPrecisionCastEvaluator>(prim);
}
if (prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name()) {
prim::kPrimPyExecute->AddAttr("primitive_target", MakeValue("CPU"));
return std::make_shared<PyExecuteEvaluator>();
}

// Find prim infer function in the prim function map return a standard evaluator
auto eval_impl = GetPrimitiveInferImpl(prim);
@@ -58,6 +58,9 @@ void ValidateOperation(const AnfNodePtr &node) {
if (prim->HasAttr("is_load")) {
return;
}
if (prim->name() == "PyExecute") {
return;
}
if (prim->name() == "TensorMove") {
return;
}

@@ -127,7 +130,8 @@ void ValidateAbstract(const AnfNodePtr &node) {
abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
abstract->isa<AbstractRefTensor>() || abstract->isa<AbstractMapTensor>() ||
abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>() ||
abstract->isa<abstract::AbstractScript>();
if (is_legal_abstract) {
return;
}
@@ -8,6 +8,7 @@ if(ENABLE_CPU)
"eigen/*.cc"
"mkldnn/*.cc"
"ps/*.cc"
"pyexecute/*.cc"
"pyfunc/*.cc"
"rl/*.cc"
"custom/*.cc"
@ -0,0 +1,379 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "Eigen/Core"
|
||||
#include "abstract/utils.h"
|
||||
#include "plugin/device/cpu/hal/device/cpu_common.h"
|
||||
#include "include/common/utils/python_adapter.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "mindspore/ccsrc/pipeline/jit/parse/resolve.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
py::object CallPythonGetGlobalParams() {
|
||||
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
|
||||
py::module mod = python_adapter::GetPyModule(python_mod_parse);
|
||||
constexpr auto python_get_dict = "get_global_params";
|
||||
return python_adapter::CallPyModFn(mod, python_get_dict);
|
||||
}
|
||||
|
||||
// Call the python script string. The same codes as parse/data_converter.h, we must copy it here.
|
||||
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs) {
|
||||
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
|
||||
py::module mod = python_adapter::GetPyModule(python_mod_parse);
|
||||
constexpr auto python_mode_eval = "eval_script";
|
||||
// The `args_kwargs` is a tuple(dict(global), dict(local)).
|
||||
return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, python_mode_eval, script)
|
||||
: python_adapter::CallPyModFn(mod, python_mode_eval, script, args_kwargs);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void PyExecuteCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_LOG(DEBUG) << "kernel_node: " << kernel_node << ", " << kernel_node->DebugString();
|
||||
inputs_info_.clear();
|
||||
kernel_node_ = kernel_node;
|
||||
for (size_t i = 1; i < kernel_node->size(); ++i) {
|
||||
const auto &input = kernel_node->inputs()[i];
|
||||
|
||||
// Check if PyExecuteOutputData exists.
|
||||
py::object obj = py::none();
|
||||
if (input->has_user_data<PyExecuteOutputData>()) {
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
const auto &output_data = input->user_data<PyExecuteOutputData>();
|
||||
obj = output_data->obj;
|
||||
MS_LOG(DEBUG) << "Has \'PyExecuteOutputData\', obj: " << obj;
|
||||
}
|
||||
|
||||
// Record the inputs' information by their abstract types.
|
||||
const auto &input_abstract = input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(input_abstract);
|
||||
if (input_abstract->isa<abstract::AbstractRefTensor>()) {
|
||||
const auto ¶m = dyn_cast<Parameter>(input);
|
||||
MS_EXCEPTION_IF_NULL(param);
|
||||
MS_LOG(DEBUG) << "AbstractRefTensor, input[" << i << "]: " << param->default_param()->ToString();
|
||||
(void)inputs_info_.emplace_back(PyExecuteInputInfo({obj, input_abstract, kTypeUnknown, {}}));
|
||||
} else if (input_abstract->isa<abstract::AbstractTensor>()) {
|
||||
const auto &tensor_abstract = dyn_cast<abstract::AbstractTensor>(input_abstract);
|
||||
MS_EXCEPTION_IF_NULL(tensor_abstract);
|
||||
MS_LOG(DEBUG) << "AbstractTensor, input[" << i << "]: " << tensor_abstract->BuildType()->ToString() << ", "
|
||||
<< tensor_abstract->BuildShape()->ToString();
|
||||
const auto &in_type = AnfAlgo::GetInputDeviceDataType(kernel_node, i - 1);
|
||||
const auto &in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i - 1);
|
||||
(void)inputs_info_.emplace_back(PyExecuteInputInfo({obj, input_abstract, in_type, in_shape}));
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Other, input[" << i << "]: " << input->DebugString() << ", " << input_abstract->ToString();
|
||||
(void)inputs_info_.emplace_back(PyExecuteInputInfo({obj, input_abstract, kTypeUnknown, {}}));
|
||||
}
|
||||
MS_LOG(DEBUG) << "Kernel node's input[" << i << "]: " << input->DebugString() << ", " << input_abstract->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
void ArrayToRawMemory(const py::array &array, const AddressPtr &address) {
|
||||
MS_EXCEPTION_IF_NULL(address);
|
||||
if (static_cast<unsigned int>(array.flags()) &
|
||||
static_cast<unsigned int>(pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_)) {
|
||||
const py::buffer_info &buf_info = array.request();
|
||||
const auto &res =
|
||||
memcpy_s(address->addr, address->size, buf_info.ptr, LongToSize(buf_info.size * buf_info.itemsize));
|
||||
if (res != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy failed. res: " << res << ", address->size: " << address->size
|
||||
<< ", size: " << LongToSize(buf_info.size * buf_info.itemsize);
|
||||
}
|
||||
} else {
|
||||
// Transform numpy array to contiguous data.
|
||||
Py_buffer pybuf;
|
||||
if (PyObject_GetBuffer(array.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS) != 0) {
|
||||
MS_LOG(EXCEPTION) << "Failed to get buffer from the input!";
|
||||
}
|
||||
auto buffer = std::make_unique<char[]>(LongToSize(pybuf.len));
|
||||
if (PyBuffer_ToContiguous(buffer.get(), &pybuf, pybuf.len, 'C')) {
|
||||
PyBuffer_Release(&pybuf);
|
||||
MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer.";
|
||||
}
|
||||
PyBuffer_Release(&pybuf);
|
||||
const auto &res = memcpy_s(address->addr, address->size, buffer.get(), LongToSize(pybuf.len));
|
||||
if (res != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy failed. res: " << res;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PyExecuteCpuKernelMod::AttachPyOutputData(const py::object &py_res) {
|
||||
const auto &py_output = std::make_shared<PyExecuteOutputData>();
|
||||
py_output->obj = py_res;
|
||||
// Set Python data for kernel node.
|
||||
kernel_node_->set_user_data<PyExecuteOutputData>(py_output);
|
||||
|
||||
// Set Python data for front node.
|
||||
const auto &kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(kernel_node_->func_graph());
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
const auto &graph_output_map = kernel_graph->graph_output_map();
|
||||
session::AnfWithOutIndex anf_index = std::make_pair(kernel_node_, 0);
|
||||
const auto &iter = graph_output_map.find(anf_index);
|
||||
if (iter != graph_output_map.cend()) {
|
||||
const auto &front_node = iter->second.first;
|
||||
MS_LOG(INFO) << "Found front output for " << kernel_node_ << ", " << kernel_node_->DebugString();
|
||||
front_node->set_user_data<PyExecuteOutputData>(py_output);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Not found, kernel node is not output, " << kernel_node_ << ", " << kernel_node_->DebugString();
|
||||
if (!IS_OUTPUT_ON(mindspore::kDebug)) {
|
||||
return;
|
||||
}
|
||||
for (const auto &output_pair : graph_output_map) {
|
||||
MS_EXCEPTION_IF_NULL(output_pair.first.first);
|
||||
MS_EXCEPTION_IF_NULL(output_pair.second.first);
|
||||
MS_LOG(DEBUG) << "backend node: " << output_pair.first.first << ", " << output_pair.first.first->DebugString()
|
||||
<< ", front node: " << output_pair.second.first << ", " << output_pair.second.first->DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Notice: Remove here after BE kernel supports tuple input.
|
||||
py::object PyExecuteCpuKernelMod::BuildLocalTupleParameters(const std::vector<AddressPtr> &inputs) {
|
||||
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
|
||||
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
|
||||
constexpr auto number_two = 2;
|
||||
std::string tuple_key_str;
|
||||
py::tuple local_tuple_inputs(inputs_info_.size() - number_two); // Exclude the script and key.
|
||||
MS_LOG(DEBUG) << "Local parameter tuple size: " << (inputs_info_.size() - number_two);
|
||||
bool tuple_input_start = false;
|
||||
size_t tuple_index = 0;
|
||||
py::dict local_tuple_dict;
|
||||
for (size_t i = 1; i < inputs.size() && i < inputs_info_.size(); ++i) {
|
||||
const auto &input = inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
const auto &input_info = inputs_info_[i];
|
||||
const auto &input_abstract = input_info.abstract;
|
||||
MS_EXCEPTION_IF_NULL(input_abstract);
|
||||
const auto &input_type = input_abstract->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(input_type);
|
||||
if (!tuple_input_start && input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
|
||||
const auto &value = input_abstract->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
const auto &str_value = dyn_cast<StringImm>(value);
|
||||
MS_EXCEPTION_IF_NULL(str_value);
|
||||
const auto &str = str_value->value();
|
||||
if (str != internal_tuple_keys_str && str != internal_tuple_values_str) {
|
||||
return py::none();
|
||||
}
|
||||
tuple_key_str = str;
|
||||
tuple_input_start = true;
|
||||
MS_LOG(DEBUG) << "String, key input[" << i << "]: " << input_abstract->ToString();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Rebuild the tuple with all left inputs.
|
||||
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
|
||||
const auto &value = input_abstract->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
const auto &str_value = dyn_cast<StringImm>(value);
|
||||
MS_EXCEPTION_IF_NULL(str_value);
|
||||
const auto &str = str_value->value();
|
||||
local_tuple_inputs[tuple_index++] = py::str(str);
|
||||
MS_LOG(DEBUG) << "String, value input[" << i << "]: " << input_abstract->ToString();
|
||||
} else if (input_abstract->isa<abstract::AbstractTensor>()) {
|
||||
const auto &py_array_value = input_info.py_obj_output;
|
||||
bool is_py_middle_data = !py::isinstance<py::none>(py_array_value);
|
||||
MS_LOG(DEBUG) << "Tensor, value input[" << i << "]: " << input_abstract->ToString()
|
||||
<< ", type: " << input_info.type << ", shape: " << input_info.shape << ", addr: " << inputs[i]->addr
|
||||
<< ", size: " << inputs[i]->size << ", py_array_value: " << py_array_value
|
||||
<< ", is_py_middle_data: " << is_py_middle_data;
|
||||
if (!is_py_middle_data) {
|
||||
const auto tensor =
|
||||
std::make_shared<tensor::Tensor>(input_info.type, input_info.shape, inputs[i]->addr, inputs[i]->size);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
local_tuple_inputs[tuple_index++] = tensor;
|
||||
} else {
|
||||
local_tuple_inputs[tuple_index++] = py_array_value;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupported value type.";
|
||||
}
|
||||
}
|
||||
local_tuple_dict[py::str(tuple_key_str)] = local_tuple_inputs;
|
||||
return local_tuple_dict;
|
||||
}
|
||||
|
||||
py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<AddressPtr> &inputs) {
|
||||
const auto &local_tuple_params = BuildLocalTupleParameters(inputs);
|
||||
if (local_tuple_params != py::none()) {
|
||||
return local_tuple_params;
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Build normal local parameters.";
|
||||
// Build local parameters dict.
|
||||
std::vector<std::string> keys;
|
||||
std::vector<tensor::TensorPtr> tensor_values;
|
||||
std::vector<py::object> py_object_values;
|
||||
std::vector<bool> py_array_flags;
|
||||
constexpr auto number_two = 2;
|
||||
size_t pair_size = (inputs_info_.size() - 1) / number_two;
|
||||
|
||||
// Handle the keys.
|
||||
size_t i = 1;
|
||||
for (; i < inputs.size() && i < pair_size + 1; ++i) {
|
||||
const auto &input = inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
const auto &input_info = inputs_info_[i];
|
||||
const auto &input_abstract = input_info.abstract;
|
||||
MS_EXCEPTION_IF_NULL(input_abstract);
|
||||
const auto &input_type = input_abstract->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(input_type);
|
||||
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
|
||||
const auto &value = input_abstract->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
const auto &str_value = dyn_cast<StringImm>(value);
|
||||
MS_EXCEPTION_IF_NULL(str_value);
|
||||
const auto &str = str_value->value();
|
||||
(void)keys.emplace_back(str);
|
||||
MS_LOG(DEBUG) << "String, input[" << i << "]: " << input_abstract->ToString();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The key input[" << i << "] should be a string, but got: " << input_abstract->ToString();
|
||||
}
|
||||
}
|
||||
// Handle the values.
|
||||
for (; i < inputs.size() && i < inputs_info_.size(); ++i) {
|
||||
const auto &input = inputs[i];
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
const auto &input_info = inputs_info_[i];
|
||||
const auto &input_abstract = input_info.abstract;
|
||||
MS_EXCEPTION_IF_NULL(input_abstract);
|
||||
const auto &input_type = input_abstract->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(input_type);
|
||||
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
|
||||
const auto &value = input_abstract->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
const auto &str_value = dyn_cast<StringImm>(value);
|
||||
MS_EXCEPTION_IF_NULL(str_value);
|
||||
const auto &str = str_value->value();
|
||||
(void)py_object_values.emplace_back(py::str(str));
|
||||
(void)tensor_values.emplace_back(nullptr);
|
||||
(void)py_array_flags.emplace_back(true);
|
||||
MS_LOG(DEBUG) << "String, input[" << i << "]: " << input_abstract->ToString();
|
||||
} else if (input_abstract->isa<abstract::AbstractTensor>()) {
|
||||
const auto &py_array_value = input_info.py_obj_output;
|
||||
bool is_py_middle_data = !py::isinstance<py::none>(py_array_value);
|
||||
MS_LOG(DEBUG) << "Tensor, input[" << i << "]: " << input_abstract->ToString() << ", type: " << input_info.type
|
||||
<< ", shape: " << input_info.shape << ", addr: " << inputs[i]->addr << ", size: " << inputs[i]->size
|
||||
<< ", py_array_value: " << py_array_value << ", is_py_middle_data: " << is_py_middle_data;
|
||||
tensor::TensorPtr tensor = nullptr;
|
||||
if (!is_py_middle_data) {
|
||||
tensor = std::make_shared<tensor::Tensor>(input_info.type, input_info.shape, inputs[i]->addr, inputs[i]->size);
|
||||
MS_EXCEPTION_IF_NULL(tensor);
|
||||
}
|
||||
(void)py_object_values.emplace_back(py_array_value);
|
||||
(void)tensor_values.emplace_back(tensor);
|
||||
(void)py_array_flags.emplace_back(is_py_middle_data);
|
||||
} else if (input_abstract->isa<abstract::AbstractRefTensor>()) {
|
||||
MS_LOG(DEBUG) << "Parameter, input[" << i << "]: " << input_abstract->ToString();
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "Other, input[" << i << "]: " << input_abstract->ToString();
|
||||
}
|
||||
}
|
||||
|
||||
if (keys.size() != tensor_values.size() || keys.size() != pair_size) {
|
||||
MS_LOG(EXCEPTION) << "The local dict input is invalid, keys size: " << keys.size() << ", values size: " << tensor_values.size() << ", inputs size: "
|
||||
<< inputs_info_.size();
|
||||
}
|
||||
|
||||
// To call the script with global and local parameters.
|
||||
py::dict local_dict;
|
||||
for (i = 0; i < keys.size(); ++i) {
|
||||
if (py_array_flags[i]) {
|
||||
local_dict[py::str(keys[i])] = py_object_values[i];
|
||||
} else {
|
||||
local_dict[py::str(keys[i])] = tensor_values[i];
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "local_dict: " << local_dict;
|
||||
return local_dict;
|
||||
}
|
||||
|
||||
bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
MS_LOG(DEBUG) << "Launch PyExecute(), inputs.size: " << inputs.size() << ", outputs: " << outputs.size();
|
||||
if (Py_IsInitialized() != true) {
|
||||
MS_LOG(ERROR) << "The Python interpreter is not initialized.";
|
||||
return false;
|
||||
}
|
||||
if (outputs.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "The output num should be 1, but got " << outputs.size();
|
||||
}
|
||||
|
||||
// Build the script.
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
const auto &input0_info = inputs_info_[0];
|
||||
const auto &input0_abstract = input0_info.abstract;
|
||||
const auto &input0_abstract_scalar = dyn_cast<abstract::AbstractScalar>(input0_abstract);
|
||||
MS_EXCEPTION_IF_NULL(input0_abstract_scalar);
|
||||
if (!input0_abstract_scalar->BuildType()->isa<String>()) {
|
||||
MS_LOG(EXCEPTION) << "The first input should be a string, but got " << input0_abstract_scalar->ToString();
|
||||
}
|
||||
const auto &input0_value = input0_abstract_scalar->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(input0_value);
|
||||
const auto &input0_str = dyn_cast<StringImm>(input0_value);
MS_EXCEPTION_IF_NULL(input0_str);
|
||||
MS_LOG(DEBUG) << "Script: " << input0_str->ToString();
|
||||
const std::string &script = input0_str->value();
|
||||
|
||||
// Build local parameters dict.
|
||||
const auto &local_dict = BuildLocalParameters(inputs);
|
||||
// To call the script with global and local parameters.
|
||||
const auto &global_dict = CallPythonGetGlobalParams();
|
||||
const auto &py_script = py::str(script);
|
||||
auto params = py::tuple(2);
|
||||
params[0] = global_dict;
|
||||
params[1] = local_dict;
|
||||
MS_LOG(DEBUG) << "py_script: " << py_script << ", params: " << params;
|
||||
const auto &py_res = CallPythonScript(py_script, params);
|
||||
// Check Python result.
|
||||
if (py::isinstance<py::none>(py_res)) {
|
||||
MS_LOG(EXCEPTION) << "Real output is None.";
|
||||
} else if (py::isinstance<py::array>(py_res)) {
|
||||
MS_LOG(DEBUG) << "Real output is py::array, py_res: " << py_res;
|
||||
ArrayToRawMemory(py_res.cast<py::array>(), outputs[0]);
|
||||
} else if (py::isinstance<py::float_>(py_res)) {
|
||||
MS_LOG(DEBUG) << "Real output is py::float_, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::int_>(py_res)) {
|
||||
MS_LOG(DEBUG) << "Real output is py::int_, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::bool_>(py_res)) {
|
||||
MS_LOG(DEBUG) << "Real output is py::bool_, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::str>(py_res)) {
|
||||
MS_LOG(DEBUG) << "Real output is py::str, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::tuple>(py_res)) {
|
||||
MS_LOG(DEBUG) << "Real output is py::tuple, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::list>(py_res)) {
|
||||
MS_LOG(DEBUG) << "Real output is py::list, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::dict>(py_res)) {
|
||||
MS_LOG(DEBUG) << "Real output is py::dict, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::set>(py_res)) {
|
||||
MS_LOG(DEBUG) << "Real output is py::set, py_res: " << py_res;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "The output is invalid, py_res: " << py_res;
|
||||
}
|
||||
AttachPyOutputData(py_res);
|
||||
return true;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, PyExecute, PyExecuteCpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
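For readers following the kernel above, here is a minimal Python sketch (not part of the patch; the helper name is illustrative) of the call Launch() ultimately performs once BuildLocalParameters() has paired the key strings with their tensor or numpy values: the script is evaluated against a global dict and a local dict, matching the eval_script() contract on the Python side.
import numpy as np

def run_py_execute(script, local_keys, local_values, global_params=None):
    # Mirror BuildLocalParameters(): one value per key, tensors already
    # converted to numpy arrays where middle data is available.
    local_dict = dict(zip(local_keys, local_values))
    # Same (globals, locals) contract that eval_script() uses.
    return eval(script, global_params or {}, local_dict)

print(run_py_execute("x + y", ["x", "y"], [np.ones((2, 2)), np.ones((2, 2))]))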
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYEXECUTE_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYEXECUTE_KERNEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <Python.h>
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/numpy.h"
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
struct PyExecuteInputInfo {
|
||||
py::object py_obj_output;
|
||||
abstract::AbstractBasePtr abstract;
|
||||
TypeId type;
|
||||
std::vector<int64_t> shape;
|
||||
};
|
||||
|
||||
struct PyExecuteOutputData {
|
||||
py::object obj;
|
||||
constexpr static char key[] = "PyExecuteOutputData";
|
||||
};
|
||||
|
||||
class PyExecuteCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
public:
|
||||
PyExecuteCpuKernelMod() : kernel_node_(nullptr) {}
|
||||
~PyExecuteCpuKernelMod() = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
void AttachPyOutputData(const py::object &py_res);
|
||||
py::object BuildLocalParameters(const std::vector<AddressPtr> &inputs);
|
||||
py::object BuildLocalTupleParameters(const std::vector<AddressPtr> &inputs);
|
||||
|
||||
CNodePtr kernel_node_{nullptr};
|
||||
std::vector<PyExecuteInputInfo> inputs_info_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYEXECUTE_KERNEL_H_
|
|
@ -0,0 +1,115 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind_api/pybind_patch.h"
|
||||
|
||||
#include "mindspore/core/ops/py_execute.h"
|
||||
#include "mindspore/ccsrc/include/common/utils/python_adapter.h"
|
||||
#include "mindspore/ccsrc/pipeline/jit/parse/data_converter.h"
|
||||
#include "mindspore/ccsrc/pybind_api/ir/tensor_py.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace mindspore {
|
||||
namespace {
|
||||
py::object CallPythonGetGlobalParams() {
|
||||
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
|
||||
py::module mod = python_adapter::GetPyModule(python_mod_parse);
|
||||
constexpr auto python_get_dict = "get_global_params";
|
||||
return python_adapter::CallPyModFn(mod, python_get_dict);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class PyExecuteInitializer {
|
||||
public:
|
||||
PyExecuteInitializer() { mindspore::ops::PyExecuteInfer::set_infer_handler(InferPy); }
|
||||
|
||||
~PyExecuteInitializer() = default;
|
||||
|
||||
private:
|
||||
// TODO(zh_qh): Will check the abstract shape and type later.
|
||||
static void InferPy(const std::vector<AbstractBasePtr> &input_args) {
|
||||
const auto &script_abs = input_args[0];
|
||||
const auto &script = script_abs->BuildValue();
|
||||
const auto &script_str = dyn_cast<StringImm>(script);
MS_EXCEPTION_IF_NULL(script_str);
|
||||
|
||||
const auto &keys_tuple_abs = input_args[1];
|
||||
const auto &keys_tuple = keys_tuple_abs->BuildValue();
|
||||
const auto &keys = dyn_cast<ValueSequence>(keys_tuple);
MS_EXCEPTION_IF_NULL(keys);
|
||||
const auto &values_tuple_abs = input_args[2];
|
||||
const auto &values_tuple = values_tuple_abs->BuildValue();
|
||||
if (values_tuple == kAnyValue) {
|
||||
MS_LOG(EXCEPTION) << "The value tuple should not be AnyValue.";
|
||||
}
|
||||
const auto &values = dyn_cast<ValueSequence>(values_tuple);
MS_EXCEPTION_IF_NULL(values);
|
||||
MS_LOG(DEBUG) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
|
||||
<< ", values_tuple: " << values_tuple->ToString();
|
||||
|
||||
py::gil_scoped_acquire gil_acquire;
|
||||
py::dict local_dict;
|
||||
for (size_t i = 0; i < keys->size(); ++i) {
|
||||
const auto &key = (*keys)[i];
|
||||
const auto &key_str = dyn_cast<StringImm>(key);
|
||||
MS_EXCEPTION_IF_NULL(key_str);
|
||||
const auto &value = (*values)[i];
|
||||
const auto &tuple_abs = values_tuple_abs->cast<abstract::AbstractSequencePtr>();
|
||||
const auto &value_abs = (*tuple_abs)[i];
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
const auto &tensor = value->cast<tensor::TensorPtr>();
|
||||
const auto &py_array_value = python_adapter::PyAdapterCallback::TensorToNumpy(*tensor);
|
||||
local_dict[py::str(key_str->value())] = py_array_value;
|
||||
continue;
|
||||
}
|
||||
local_dict[py::str(key_str->value())] = value;
|
||||
}
|
||||
const auto &global_dict = CallPythonGetGlobalParams();
|
||||
const auto &py_script = py::str(script_str->value());
|
||||
auto params = py::tuple(2);
|
||||
params[0] = global_dict;
|
||||
params[1] = local_dict;
|
||||
MS_LOG(DEBUG) << "py_script: " << py_script << ", params: " << params;
|
||||
const auto &py_res = parse::data_converter::CallPythonScript(py_script, params);
|
||||
MS_LOG(DEBUG) << "py_res: " << py_res;
|
||||
if (py::isinstance<py::none>(py_res)) {
|
||||
MS_LOG(EXCEPTION) << "The result of the script should not be None.";
|
||||
} else if (py::isinstance<py::array>(py_res)) {
|
||||
MS_LOG(DEBUG) << "is py::array, py_res: " << py_res;
|
||||
const auto &res_tensor = tensor::TensorPy::MakeTensorOfNumpy(py_res);
|
||||
MS_LOG(DEBUG) << "res_tensor: " << res_tensor->ToString();
|
||||
} else if (py::isinstance<py::float_>(py_res)) {
|
||||
MS_LOG(DEBUG) << "is py::float_, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::int_>(py_res)) {
|
||||
MS_LOG(DEBUG) << "is py::int_, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::bool_>(py_res)) {
|
||||
MS_LOG(DEBUG) << "is py::bool_, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::str>(py_res)) {
|
||||
MS_LOG(DEBUG) << "is py::str, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::tuple>(py_res)) {
|
||||
MS_LOG(DEBUG) << "is py::tuple, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::list>(py_res)) {
|
||||
MS_LOG(DEBUG) << "is py::list, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::dict>(py_res)) {
|
||||
MS_LOG(DEBUG) << "is py::dict, py_res: " << py_res;
|
||||
} else if (py::isinstance<py::set>(py_res)) {
|
||||
MS_LOG(DEBUG) << "is py::set, py_res: " << py_res;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "py_res is invalid, py_res: " << py_res;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static PyExecuteInitializer py_execute_initializer;
|
||||
} // namespace mindspore
|
|
@ -886,8 +886,9 @@ AnfNodePtr AnfAlgo::GetInputNode(const CNodePtr &node, size_t index) {
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto get_input_index = index + 1;
|
||||
if (get_input_index >= node->inputs().size()) {
|
||||
MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
|
||||
<< node->inputs().size() << "." << trace::DumpSourceLines(node);
|
||||
MS_LOG(EXCEPTION) << "Input index size " << get_input_index << ", but the node input size just "
|
||||
<< node->inputs().size() << ". node: " << node->DebugString() << "."
|
||||
<< trace::DumpSourceLines(node);
|
||||
}
|
||||
// input 0 is primitive node
|
||||
return node->input(get_input_index);
|
||||
|
|
|
@ -1350,9 +1350,9 @@ bool AbstractDictionary::operator==(const AbstractBase &other) const {
|
|||
}
|
||||
|
||||
AbstractBasePtr AbstractDictionary::Clone() const {
|
||||
std::vector<AbstractAttribute> kv;
|
||||
std::vector<AbstractElementPair> kv;
|
||||
(void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
|
||||
[](const AbstractAttribute &item) {
|
||||
[](const AbstractElementPair &item) {
|
||||
MS_EXCEPTION_IF_NULL(item.first);
|
||||
MS_EXCEPTION_IF_NULL(item.second);
|
||||
return std::make_pair(item.first->Clone(), item.second->Clone());
|
||||
|
@ -1361,9 +1361,9 @@ AbstractBasePtr AbstractDictionary::Clone() const {
|
|||
}
|
||||
|
||||
AbstractBasePtr AbstractDictionary::Broaden() const {
|
||||
std::vector<AbstractAttribute> kv;
|
||||
std::vector<AbstractElementPair> kv;
|
||||
(void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
|
||||
[](const AbstractAttribute &item) {
|
||||
[](const AbstractElementPair &item) {
|
||||
MS_EXCEPTION_IF_NULL(item.second);
|
||||
return std::make_pair(item.first, item.second->Broaden());
|
||||
});
|
||||
|
@ -1384,7 +1384,7 @@ std::string AbstractDictionary::ToString() const {
|
|||
|
||||
std::size_t AbstractDictionary::hash() const {
|
||||
std::size_t hash_sum = std::accumulate(key_values_.cbegin(), key_values_.cend(), tid(),
|
||||
[](std::size_t hash_sum, const AbstractAttribute &item) {
|
||||
[](std::size_t hash_sum, const AbstractElementPair &item) {
|
||||
MS_EXCEPTION_IF_NULL(item.first);
|
||||
MS_EXCEPTION_IF_NULL(item.second);
|
||||
hash_sum = hash_combine(hash_sum, item.first->hash());
|
||||
|
|
|
@ -1023,8 +1023,8 @@ class MS_CORE_API AbstractDictionary final : public AbstractBase {
|
|||
public:
|
||||
/// \brief Constructor of AbstractDictionary.
|
||||
///
|
||||
/// \param[in] key_values The vector of AbstractAttribute.
|
||||
explicit AbstractDictionary(const std::vector<AbstractAttribute> &key_values) : key_values_(key_values) {}
|
||||
/// \param[in] key_values The vector of AbstractElementPair.
|
||||
explicit AbstractDictionary(const std::vector<AbstractElementPair> &key_values) : key_values_(key_values) {}
|
||||
|
||||
/// \brief Destructor of AbstractDictionary.
|
||||
~AbstractDictionary() override = default;
|
||||
|
@ -1056,12 +1056,12 @@ class MS_CORE_API AbstractDictionary final : public AbstractBase {
|
|||
|
||||
/// \brief Get the key values.
|
||||
///
|
||||
/// \return A vector of AbstractAttribute.
|
||||
const std::vector<AbstractAttribute> &elements() const { return key_values_; }
|
||||
/// \return A vector of AbstractElementPair.
|
||||
const std::vector<AbstractElementPair> &elements() const { return key_values_; }
|
||||
|
||||
protected:
|
||||
ValuePtr RealBuildValue() const override;
|
||||
std::vector<AbstractAttribute> key_values_;
|
||||
std::vector<AbstractElementPair> key_values_;
|
||||
};
|
||||
using AbstractDictionaryPtr = std::shared_ptr<AbstractDictionary>;
|
||||
|
||||
|
|
|
@ -196,8 +196,8 @@ bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spe
|
|||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(key_value);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) {
|
||||
std::vector<AbstractElementPair> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
|
||||
return *key_value == *item.first->BuildValue();
|
||||
});
|
||||
return it != dict_elems.end();
|
||||
|
|
|
@ -68,7 +68,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
const auto &key = key_list[index];
|
||||
CheckDictKey(key, op_name);
|
||||
}
|
||||
std::vector<AbstractAttribute> key_value;
|
||||
std::vector<AbstractElementPair> key_value;
|
||||
AbstractBasePtrList value_list = values->elements();
|
||||
for (size_t index = 0; index < keys_size; index++) {
|
||||
(void)key_value.emplace_back(key_list[index], value_list[index]);
|
||||
|
@ -259,8 +259,8 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
|
|||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(key_value);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) {
|
||||
std::vector<AbstractElementPair> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
|
||||
return *key_value == *item.first->BuildValue();
|
||||
});
|
||||
if (it == dict_elems.end()) {
|
||||
|
@ -284,8 +284,8 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
|
|||
|
||||
ValuePtr key_value = key->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(key_value);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) {
|
||||
std::vector<AbstractElementPair> dict_elems = dict->elements();
|
||||
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
|
||||
return *key_value == *item.first->BuildValue();
|
||||
});
|
||||
|
||||
|
@ -307,10 +307,10 @@ AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitiveP
|
|||
constexpr int args_spec_size = 1;
|
||||
CheckArgsSize(op_name, args_spec_list, args_spec_size);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
std::vector<AbstractElementPair> dict_elems = dict->elements();
|
||||
AbstractBasePtrList keys;
|
||||
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys),
|
||||
[](const AbstractAttribute &item) { return item.first; });
|
||||
[](const AbstractElementPair &item) { return item.first; });
|
||||
return std::make_shared<AbstractTuple>(keys);
|
||||
}
|
||||
|
||||
|
@ -321,10 +321,10 @@ AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const Primitiv
|
|||
constexpr int args_spec_size = 1;
|
||||
CheckArgsSize(op_name, args_spec_list, args_spec_size);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
std::vector<AbstractElementPair> dict_elems = dict->elements();
|
||||
AbstractBasePtrList values;
|
||||
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(values),
|
||||
[](const AbstractAttribute &item) { return item.second; });
|
||||
[](const AbstractElementPair &item) { return item.second; });
|
||||
return std::make_shared<AbstractTuple>(values);
|
||||
}
|
||||
|
||||
|
@ -335,10 +335,10 @@ AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
constexpr int args_spec_size = 1;
|
||||
CheckArgsSize(op_name, args_spec_list, args_spec_size);
|
||||
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||
std::vector<AbstractElementPair> dict_elems = dict->elements();
|
||||
AbstractBasePtrList items;
|
||||
(void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(items),
|
||||
[](const AbstractAttribute &item) {
|
||||
[](const AbstractElementPair &item) {
|
||||
return std::make_shared<AbstractTuple>(AbstractBasePtrList{item.first, item.second});
|
||||
});
|
||||
return std::make_shared<AbstractList>(items);
|
||||
|
|
|
@ -217,7 +217,15 @@ PrimShapeDependMap &GetHostDependsMap() {
|
|||
return host_depends;
|
||||
}
|
||||
|
||||
std::set<int64_t> GetValueDependArgIndices(const std::string &prim_name, size_t input_num) {
|
||||
std::set<int64_t> GetValueDependArgIndices(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().empty()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid inputs";
|
||||
}
|
||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
|
||||
auto ms_context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(ms_context);
|
||||
auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
|
@ -238,25 +246,27 @@ std::set<int64_t> GetValueDependArgIndices(const std::string &prim_name, size_t
|
|||
ori = op_infer->GetValueDependArgIndices();
|
||||
}
|
||||
|
||||
if (!ori.empty()) {
|
||||
(void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()),
|
||||
[&](int64_t idx) { return idx < SizeToLong(input_num); });
|
||||
if (ori.empty()) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// To support {-1}, filter all the real tensor input index here.
|
||||
constexpr auto all_tensor_inputs = -1;
|
||||
if (ori.size() == 1 && *(ori.cbegin()) == all_tensor_inputs) {
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
const auto &input = cnode->inputs()[i];
|
||||
const auto &input_abstract = input->abstract();
|
||||
if (input_abstract != nullptr && input_abstract->isa<abstract::AbstractTensor>()) {
|
||||
(void)res.emplace(SizeToLong(i - 1));
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
size_t input_num = cnode->inputs().size() - 1;
|
||||
(void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()),
|
||||
[&](int64_t idx) { return idx < SizeToLong(input_num); });
|
||||
return res;
|
||||
}
|
||||
|
||||
std::set<int64_t> GetValueDependArgIndices(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (cnode->inputs().empty()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid inputs";
|
||||
}
|
||||
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->ToString();
|
||||
return GetValueDependArgIndices(prim_name, cnode->inputs().size() - 1);
|
||||
}
|
||||
|
||||
void RegisterHostDependsImpl(const std::string &prim_name, const std::set<int64_t> &host_depends) {
|
||||
auto &host_depends_map = GetHostDependsMap();
|
||||
host_depends_map[prim_name] = host_depends;
|
||||
|
|
|
@ -80,8 +80,6 @@ MS_CORE_API PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap();
|
|||
|
||||
MS_CORE_API StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive);
|
||||
|
||||
MS_CORE_API std::set<int64_t> GetValueDependArgIndices(const std::string &prim_name, size_t input_num);
|
||||
|
||||
MS_CORE_API std::set<int64_t> GetValueDependArgIndices(const CNodePtr &cnode);
|
||||
|
||||
MS_CORE_API void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);
|
||||
|
|
|
@ -231,7 +231,7 @@ using FuncGraphWeakPtr = std::weak_ptr<FuncGraph>;
|
|||
namespace abstract {
|
||||
class AbstractBase;
|
||||
using AbstractBasePtr = std::shared_ptr<AbstractBase>;
|
||||
using AbstractAttribute = std::pair<AbstractBasePtr, AbstractBasePtr>;
|
||||
using AbstractElementPair = std::pair<AbstractBasePtr, AbstractBasePtr>;
|
||||
class AnalysisContext;
|
||||
using AnalysisContextPtr = std::shared_ptr<AnalysisContext>;
|
||||
} // namespace abstract
|
||||
|
|
|
@ -183,7 +183,7 @@ template <typename T>
|
|||
std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, size_t data_len) {
|
||||
size_t size = SizeOf(shape);
|
||||
if (size * sizeof(T) != data_len) {
|
||||
MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
|
||||
MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
|
||||
<< " item size " << sizeof(T);
|
||||
}
|
||||
auto buf = static_cast<T *>(data);
|
||||
|
|
|
@ -1447,6 +1447,7 @@ GVAR_DEF(PrimitivePtr, kPrimNPUGetFloatStatus, std::make_shared<Primitive>("NPUG
|
|||
GVAR_DEF(PrimitivePtr, kPrimNPUAllocFloatStatus, std::make_shared<Primitive>("NPUAllocFloatStatus"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimNPUClearFloatStatus, std::make_shared<Primitive>("NPUClearFloatStatus"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPyFunc, std::make_shared<Primitive>("PyFunc"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPyExecute, std::make_shared<Primitive>("PyExecute"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimDynamicLossScale, std::make_shared<Primitive>("_DynamicLossScale"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimScaleGrad, std::make_shared<Primitive>("ScaleGrad"));
|
||||
GVAR_DEF(PrimitivePtr, kPrimPopulationCount, std::make_shared<Primitive>("PopulationCount"));
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/py_execute.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_OPERATOR_IMPL(PyExecute, BaseOperator);
|
||||
|
||||
BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
ShapeVector out_shape = {1};
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr PyExecuteInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
|
||||
return kFloat64;
|
||||
}
|
||||
|
||||
AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
|
||||
const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
MS_LOG(DEBUG) << "item: " << item->ToString();
|
||||
}
|
||||
|
||||
if (infer_handler_ == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "infer_handler_ should not be null.";
|
||||
}
|
||||
infer_handler_(input_args);
|
||||
|
||||
const auto &type = InferType(primitive, input_args);
|
||||
const auto &shape = InferShape(primitive, input_args);
|
||||
const auto &abstract = MakeAbstract(shape, type);
|
||||
return abstract;
|
||||
}
|
||||
|
||||
std::set<int64_t> PyExecuteInfer::GetValueDependArgIndices() const { return {-1}; }
|
||||
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(PyExecute, prim::kPrimPyExecute, PyExecuteInfer, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_PY_EXECUTE_H_
|
||||
#define MINDSPORE_CORE_OPS_PY_EXECUTE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
|
||||
#include "ops/base_operator.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "mindapi/base/types.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNamePyExecute = "PyExecute";
|
||||
/// \brief Implement for JIT Fallback.
|
||||
/// Refer to Python API @ref mindspore.ops.PyExecute for more details.
|
||||
class MIND_API PyExecute : public BaseOperator {
|
||||
public:
|
||||
MIND_API_BASE_MEMBER(PyExecute);
|
||||
/// \brief Constructor.
|
||||
PyExecute() : BaseOperator(kNamePyExecute) { InitIOName({"script", "local_keys", "local_values"}, {"result"}); }
|
||||
};
|
||||
|
||||
class MIND_API PyExecuteInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
|
||||
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override;
|
||||
|
||||
std::set<int64_t> GetValueDependArgIndices() const override;
|
||||
|
||||
using InferHandler = void (*)(const std::vector<AbstractBasePtr> &);
|
||||
static void set_infer_handler(const InferHandler &infer_handler) { infer_handler_ = infer_handler; }
|
||||
|
||||
private:
|
||||
inline static InferHandler infer_handler_{nullptr};
|
||||
};
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_OPS_PY_EXECUTE_H_
|
|
@ -26,7 +26,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
|||
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
|
||||
convert_to_ms_cootensor, convert_class_to_function, convert_cell_list_to_sequence, is_cell_list,
|
||||
get_obj_from_sequence, get_type, is_class_member_recursive)
|
||||
get_obj_from_sequence, get_type, is_class_member_recursive, merge_global_params, get_global_params)
|
||||
|
||||
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
||||
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
||||
|
|
|
@ -123,6 +123,8 @@ _unsupported_convert_data_type = (
|
|||
mutable,
|
||||
)
|
||||
|
||||
_global_params = {}
|
||||
|
||||
|
||||
def create_slice_obj(start, end, step):
|
||||
"""Create slice object"""
|
||||
|
@ -755,7 +757,8 @@ def eval_script(exp_str, params):
|
|||
local_params = params[1]
|
||||
try:
|
||||
local_params = _convert_python_data(local_params)
|
||||
obj = eval(exp_str, global_params, local_params)
|
||||
res = eval(exp_str, global_params, local_params)
|
||||
logger.debug(f"eval res: '{res}'")
|
||||
except Exception as e:
|
||||
error_info = f"When eval '{exp_str}' by using JIT Fallback feature, an error occurred: " + str(e) + \
|
||||
". You can try to turn off JIT Fallback feature by 'export MS_DEV_ENABLE_FALLBACK=0'."
|
||||
|
@ -763,9 +766,9 @@ def eval_script(exp_str, params):
|
|||
raise e
|
||||
|
||||
# Convert set to tuple.
|
||||
if isinstance(obj, set):
|
||||
return tuple(obj)
|
||||
return obj
|
||||
if isinstance(res, set):
|
||||
return tuple(res)
|
||||
return res
|
||||
|
||||
|
||||
def get_script_ids(script):
|
||||
|
@ -777,6 +780,16 @@ def get_script_ids(script):
|
|||
return set(ids)
|
||||
|
||||
|
||||
def merge_global_params(global_dict):
|
||||
logger.debug(f'merge global_dict: {global_dict}')
|
||||
_global_params.update(global_dict)
|
||||
|
||||
|
||||
def get_global_params():
|
||||
logger.debug(f'get global_dict: {_global_params}')
|
||||
return _global_params
|
||||
|
||||
|
||||
class Parser:
|
||||
"""
|
||||
Parser python code to ast tree.
|
||||
|
@ -800,6 +813,7 @@ class Parser:
|
|||
# Used to resolve the function's nonlocals.
|
||||
self.closure_namespace = ClosureNamespace(self.fn)
|
||||
self.function_name = self.fn.__qualname__
|
||||
self.lines = []
|
||||
self.col_offset = 0
|
||||
|
||||
@staticmethod
|
||||
|
@ -874,15 +888,15 @@ class Parser:
|
|||
raise OSError(f"Mindspore can not compile temporary source code in terminal. "
|
||||
f"Please write source code to a python file and run the file.")
|
||||
raise e
|
||||
lines, self.line_offset = source
|
||||
original_src = ''.join(lines)
|
||||
self.lines, self.line_offset = source
|
||||
original_src = ''.join(self.lines)
|
||||
hexstr = hashlib.sha256(original_src.encode()).hexdigest()
|
||||
ast_tokens_cache = Parser.ast_cache.get(hexstr)
|
||||
if not ast_tokens_cache:
|
||||
src = dedent(original_src)
|
||||
self.col_offset = \
|
||||
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
|
||||
logger.debug("Get source: %s", src)
|
||||
logger.info("Get source: %s", src)
|
||||
try:
|
||||
ast_tokens = asttokens.ASTTokens(src, parse=True)
|
||||
except IndentationError as idt_err:
|
||||
|
@ -976,6 +990,35 @@ class Parser:
|
|||
f"but got {subclass_instance}.")
|
||||
return super(target_father_class, subclass_instance)
|
||||
|
||||
def get_source_code(self, start_lineno, start_colno, end_lineno, end_colno):
|
||||
"""
|
||||
Get the script source at the location.
|
||||
|
||||
Args:
|
||||
start_lineno: The start line no.
|
||||
start_colno: The start column no.
|
||||
end_lineno: The end line no.
|
||||
end_colno: The end column no.
|
||||
|
||||
Returns:
|
||||
str, the source string.
|
||||
"""
|
||||
if start_lineno == 0:
|
||||
logger.critical('start_lineno should not be 0')
|
||||
|
||||
first_line = self.lines[start_lineno - 1]
|
||||
if start_lineno == end_lineno:
|
||||
src = first_line[self.col_offset + start_colno:self.col_offset + end_colno]
|
||||
return src
|
||||
|
||||
src = first_line[self.col_offset + start_colno:]
|
||||
while start_lineno < end_lineno - 1:
|
||||
src += self.lines[start_lineno]
|
||||
start_lineno += 1
|
||||
last_line = self.lines[end_lineno - 1]
|
||||
src += last_line[:self.col_offset + end_colno]
|
||||
return src
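A self-contained sketch of the span-extraction logic above, assuming col_offset is 0 and the usual 1-based line / 0-based column positions; it is illustrative only.
lines = ["a = func(x,\n", "         y)\n"]
start_lineno, start_colno, end_lineno, end_colno = 1, 4, 2, 11
src = lines[start_lineno - 1][start_colno:]      # tail of the first line
for ln in range(start_lineno, end_lineno - 1):   # whole middle lines, if any
    src += lines[ln]
src += lines[end_lineno - 1][:end_colno]         # head of the last line
print(src)  # prints "func(x," and the indented "y)" across two lines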
|
||||
|
||||
def get_location(self, node):
|
||||
"""
|
||||
Get location of node start and end line no.
|
||||
|
@ -987,7 +1030,7 @@ class Parser:
|
|||
Returns:
|
||||
List, [fileName, linestart, colstart, lineend, colend].
|
||||
"""
|
||||
ret = [self.filename]
|
||||
res = [self.filename]
|
||||
err_exit = 0
|
||||
if isinstance(node, (list, tuple)):
|
||||
node_size = len(node)
|
||||
|
@ -1009,7 +1052,7 @@ class Parser:
|
|||
start_colno += self.col_offset
|
||||
end_lineno += self.line_offset - 1
|
||||
end_colno += self.col_offset
|
||||
ret = ret + [start_lineno, start_colno, end_lineno, end_colno]
|
||||
res = res + [start_lineno, start_colno, end_lineno, end_colno]
|
||||
else:
|
||||
ret = ret + [0, 0, 0, 0]
|
||||
return ret
|
||||
res = res + [0, 0, 0, 0]
|
||||
return res
|
||||
|
|
|
@ -29,6 +29,7 @@ from mindspore.common.tensor import Tensor
|
|||
from mindspore.common.sparse_tensor import RowTensor, RowTensorInner, SparseTensor, COOTensor, CSRTensor
|
||||
from mindspore.common.mutable import mutable
|
||||
from mindspore.common.jit_config import JitConfig
|
||||
from mindspore.common._utils import update_and_return_dict
|
||||
|
||||
# symbols from dtype
|
||||
__all__ = [
|
||||
|
@ -66,4 +67,5 @@ __all__.extend([
|
|||
"set_dump",
|
||||
"ms_memory_recycle",
|
||||
"mutable", "JitConfig",
|
||||
"update_and_return_dict",
|
||||
])
|
||||
|
|
|
@ -53,3 +53,8 @@ def split_to_slice_if_need(dtype, shape):
|
|||
return slice_num
|
||||
slice_num = math.ceil(data_size / emb_cache_size)
|
||||
return slice_num
|
||||
|
||||
|
||||
def update_and_return_dict(dic, key, val):
|
||||
dic.__setitem__(key, val)
|
||||
return dic
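A quick usage note (illustrative, not from the patch): the helper mutates the dict in place and returns the same object, so graph code can express "update, then use" as a single call.
d = {"loss": 0.0}
assert update_and_return_dict(d, "acc", 0.98) is d
assert d == {"loss": 0.0, "acc": 0.98}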
|
||||
|
|
|
@ -231,3 +231,9 @@ def bprop_scalar_not(x, out, dout):
|
|||
def bprop_tensor_move(x, out, dout):
|
||||
"""Backpropagator for primitive `TensorMove`."""
|
||||
return (dout,)
|
||||
|
||||
|
||||
@bprops.register("PyExecute")
|
||||
def get_bprop_py_execute(x, y, z, out, dout):
|
||||
"""Generate bprop for PyExecute"""
|
||||
return x, y, z
|
||||
|
|
|
@ -69,6 +69,7 @@ from .pad import _pad_cpu
|
|||
from .range import _range_cpu
|
||||
from .tensor_copy_slices import _tensor_copy_slices_cpu
|
||||
from .l2loss import _l2loss_cpu
|
||||
from .pyexecute import _pyexecute_cpu
|
||||
from .pyfunc import _pyfunc_cpu
|
||||
from .buffer_append import _buffer_append_cpu
|
||||
from .buffer_get import _buffer_get_cpu
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""PyExecute op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
|
||||
|
||||
pyexecute_op_info = CpuRegOp("PyExecute") \
|
||||
.input(0, "x", "dynamic") \
|
||||
.output(0, "y", "dynamic") \
|
||||
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(pyexecute_op_info)
|
||||
def _pyexecute_cpu():
|
||||
"""PyExecute cpu register"""
|
||||
return
|
|
@ -116,7 +116,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
|
|||
GridSampler2D, TripletMarginLoss, UpsampleNearest3D, UpsampleTrilinear3D, PadV3)
|
||||
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
||||
ConfusionMatrix, UpdateState, Load, StopGradient,
|
||||
CheckValid, Partial, Depend, identity, Push, Pull, PyFunc, _DynamicLossScale,
|
||||
CheckValid, Partial, Depend, identity, Push, Pull, PyExecute, PyFunc, _DynamicLossScale,
|
||||
SampleDistortedBoundingBoxV2)
|
||||
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, RandomGamma, Poisson, UniformInt, UniformReal,
|
||||
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
|
||||
|
@ -501,6 +501,7 @@ __all__ = [
|
|||
"TensorScatterDiv",
|
||||
"SoftShrink",
|
||||
"HShrink",
|
||||
"PyExecute",
|
||||
"PyFunc",
|
||||
"BufferAppend",
|
||||
"BufferGetItem",
|
||||
|
|
|
@ -706,6 +706,39 @@ class identity(Primitive):
|
|||
return x
|
||||
|
||||
|
||||
class PyInterpret(Primitive):
|
||||
r"""
|
||||
Interpret Python expression.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
super(PyInterpret, self).__init__(self.__class__.__name__)
|
||||
self.add_prim_attr('side_effect_io', True)
|
||||
|
||||
|
||||
class PyExecute(PrimitiveWithInfer):
|
||||
r"""
|
||||
Execute Python expression.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
super(PyExecute, self).__init__(self.__class__.__name__)
|
||||
self.add_prim_attr('side_effect_io', True)
|
||||
self.add_prim_attr("primitive_target", "CPU")
|
||||
|
||||
def infer_shape(self, *args):
|
||||
logger.error("The function output is an empty tuple. A placeholder is added instead. "
|
||||
"Do not use it as it could be any uninitialized data.")
|
||||
return ((1,),)
|
||||
|
||||
def infer_dtype(self, *args):
|
||||
logger.error("The function output is an empty tuple. A placeholder is added instead. "
|
||||
"Do not use it as it could be any uninitialized data.")
|
||||
return (mstype.int32,)
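For orientation, a hedged sketch of how the new primitive lines up with its declared IO (script, local_keys, local_values -> result, per ops/py_execute.h). The intended entry point is JIT Fallback inside construct(), so treat this as conceptual rather than a supported direct call; the import path assumes the re-export added in this commit.
import numpy as np
from mindspore import Tensor
from mindspore.ops.operations import PyExecute  # re-exported by this commit

py_exec = PyExecute()
script = "x.asnumpy() + y.asnumpy()"
local_keys = ("x", "y")
local_values = (Tensor(np.ones(2, np.float32)), Tensor(np.ones(2, np.float32)))
# In GraphMode, JIT Fallback lowers such expressions to
# PyExecute(script, local_keys, local_values) and runs them on the CPU backend.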
|
||||
|
||||
|
||||
class PyFunc(PrimitiveWithInfer):
|
||||
r"""
|
||||
Execute Python function.
|
||||
|
|
|
@ -42,7 +42,7 @@ using AbstractTensor = abstract::AbstractTensor;
|
|||
using AbstractTensorPtr = abstract::AbstractTensorPtr;
|
||||
|
||||
using AbstractNone = abstract::AbstractNone;
|
||||
using AbstractAttribute = abstract::AbstractAttribute;
|
||||
using AbstractAttribute = abstract::AbstractElementPair;
|
||||
using AnalysisEngine = abstract::AnalysisEngine;
|
||||
using AnalysisEnginePtr = abstract::AnalysisEnginePtr;
|
||||
|
||||
|
|
|
@ -1002,7 +1002,7 @@ TEST_F(TestPrim, test_DictGetItem2) {
|
|||
AbstractBasePtr arr_x = ArrayOfTensor(UTPrimUtils::kF64, {3, 4, 5});
|
||||
AbstractBasePtr arr_y = ArrayOfTensor(UTPrimUtils::kF64, {1, 4, 5});
|
||||
AbstractBasePtr arr_z = ArrayOfTensor(UTPrimUtils::kF64, {3, 1, 5});
|
||||
std::vector<AbstractAttribute> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
|
||||
std::vector<AbstractElementPair> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
|
||||
AbstractDictionaryPtr array_dict = std::make_shared<AbstractDictionary>(array_map);
|
||||
AbstractBasePtr key = abstract::FromValue("x");
|
||||
AbstractBasePtrList args_spec_list = {array_dict, key};
|
||||
|
|