[JIT Fallback] Supports tensor.asnumpy() and return dictionary features in construct() for GraphMode.

This commit is contained in:
张清华 2022-12-26 09:51:50 +08:00
parent a1d5aebc72
commit 00a000c5ce
48 changed files with 1657 additions and 258 deletions

View File

@ -127,7 +127,13 @@ tensor::TensorPtr GetDependValueTensor(const AnfNodePtr &node, size_t i,
auto input_device_address = reinterpret_cast<std::vector<device::DeviceAddress *> *>(args);
if (i >= input_device_address->size() || input_device_address->at(i) == nullptr) {
MS_EXCEPTION_IF_NULL(node);
MS_LOG(EXCEPTION) << "There is no valid address for " << i << " input of " << node->fullname_with_scope();
if (IsPrimitiveCNode(node, prim::kPrimPyExecute)) {
MS_LOG(INFO) << "There is no valid address for " << i << " input of " << node->DebugString() << ", "
<< node->fullname_with_scope();
return out_tensor;
}
MS_LOG(EXCEPTION) << "There is no valid address for " << i << " input of " << node->DebugString() << ", "
<< node->fullname_with_scope();
}
out_tensor->data_sync_directly(input_device_address->at(i));

View File

@ -846,10 +846,10 @@ AbstractBasePtrList RectifyAbstractFromRegAttr(const PrimitivePtr &primitive,
return rectify_abs_list;
}
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitive,
AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &prim,
const AbstractBasePtrList &input_abstract) {
MS_EXCEPTION_IF_NULL(primitive);
auto dynamic_inputs_list = primitive->GetAttr(kAttrDynInputSizes);
MS_EXCEPTION_IF_NULL(prim);
auto dynamic_inputs_list = prim->GetAttr(kAttrDynInputSizes);
if (dynamic_inputs_list == nullptr) {
return input_abstract;
}
@ -871,6 +871,10 @@ AbstractBasePtrList RectifyAbstractFromDynamicInput(const PrimitivePtr &primitiv
AbstractBasePtrList dynamic_inputs_abs;
for (auto index = item; index > 0; --index) {
if (input_index >= input_abstract.size()) {
// Not to check for PyExecute.
if ((prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name())) {
continue;
}
MS_LOG(EXCEPTION) << "Index " << input_index << " is out of range in input abstract "
<< input_abstract.size();
}

View File

@ -39,11 +39,11 @@ namespace mindspore {
namespace prim {
constexpr auto kStepDefault = 1;
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractElementPair;
using mindspore::abstract::AbstractEllipsis;
using mindspore::abstract::AbstractEllipsisPtr;
using mindspore::abstract::AbstractFunction;

View File

@ -49,12 +49,12 @@ FuncGraphPtr DictHasKey::GenerateFuncGraph(const abstract::AbstractBasePtrList &
ValuePtr key_value = args_list[1]->BuildValue();
MS_EXCEPTION_IF_NULL(dict);
MS_EXCEPTION_IF_NULL(key_value);
auto dict_elems = dict->elements();
auto elems = dict->elements();
bool has_key = false;
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const abstract::AbstractAttribute &item) {
auto it = std::find_if(elems.cbegin(), elems.cend(), [&key_value](const abstract::AbstractElementPair &item) {
return *key_value == *item.first->BuildValue();
});
if (it != dict_elems.cend()) {
if (it != elems.cend()) {
has_key = true;
}
@ -153,7 +153,7 @@ abstract::AbstractBasePtrList DictFromKeys::ParseIterableObject(const abstract::
AbstractBasePtrList keys;
auto &dict_elems = dict_keys->elements();
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys),
[](const abstract::AbstractAttribute &item) { return item.first; });
[](const abstract::AbstractElementPair &item) { return item.first; });
return keys;
}
if (key_type->IsSameTypeId(String::kTypeId)) {

View File

@ -28,10 +28,10 @@
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractElementPair;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractKeywordArg;
using mindspore::abstract::AbstractList;
@ -78,7 +78,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_abs_l
auto dict_elems = arg_dict->elements();
(void)std::transform(
dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(elems),
[res_graph, para_dict](const AbstractAttribute &item) {
[res_graph, para_dict](const AbstractElementPair &item) {
// Dict_elems's first element represents parameter names, which should be string type.
auto key_value = GetValue<std::string>(item.first->BuildValue());
auto dict_get_item =

View File

@ -38,11 +38,11 @@
namespace mindspore {
/* namespace to support opt */
namespace opt {
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractElementPair;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractRowTensor;
@ -164,7 +164,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
public:
using ThisClass = SimplifyDataStructuresRewriter;
SimplifyDataStructuresRewriter(const FuncGraphPtr &root_graph, const FuncGraphManagerPtr &manager)
: BaseRewriter(root_graph, manager) {}
: BaseRewriter(root_graph, manager), is_dict_output_{IsDictOutput()} {}
~SimplifyDataStructuresRewriter() override = default;
protected:
@ -176,7 +176,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return str->value();
}
static int64_t GetAttrIndex(const std::vector<AbstractAttribute> &attrs, const AnfNodePtr &name) {
static int64_t GetElementIndex(const std::vector<AbstractElementPair> &attrs, const AnfNodePtr &name) {
auto n_attrs = attrs.size();
auto name_abstract = GetAbstract<AbstractBase>(name);
MS_EXCEPTION_IF_NULL(name_abstract);
@ -191,15 +191,15 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
}
static CNodePtr NewTupleGetCNode(const AnfNodePtr &cnode, const AnfNodePtr &data_node,
const std::vector<AbstractAttribute> &attributes, const AnfNodePtr &name_node) {
int64_t index = GetAttrIndex(attributes, name_node);
const std::vector<AbstractElementPair> &elements, const AnfNodePtr &name_node) {
int64_t index = GetElementIndex(elements, name_node);
auto index_node = NewValueNode(index);
auto prim_node = NewValueNode(prim::kPrimTupleGetItem);
return cnode->func_graph()->NewCNode({prim_node, data_node, index_node});
}
// From:
// DictGetItem(data:AbstractDictionary, cons:AbstractBase)
// DictGetItem(data:AbstractDictionary, key:AbstractBase)
// To:
// TupleGetItem(data, index:Int64Imm)
AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) {
@ -211,27 +211,98 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
CheckInputsSize(node, expect_inputs_size);
constexpr size_t data_index = 1;
constexpr size_t attr_index = 2;
constexpr size_t key_index = 2;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &attr = inputs[attr_index];
auto &key = inputs[key_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(attr);
MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) {
return nullptr;
}
return NewTupleGetCNode(node, data, abs_dict->elements(), attr);
return NewTupleGetCNode(node, data, abs_dict->elements(), key);
}
// DictGetItem --> PyExecute()
AnfNodePtr RebuidDictGetItem(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
// Inputs should be [dict_setitem, dict, item]
const size_t expect_inputs_size = 3;
CheckInputsSize(node, expect_inputs_size);
const size_t data_index = 1;
const size_t item_key_index = 2;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &key = inputs[item_key_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) {
return nullptr;
}
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// Script
constexpr auto internal_dict_self_str = "__internal_dict_self__";
constexpr auto internal_dict_key_str = "__internal_dict_key__";
std::stringstream script_buffer;
script_buffer << internal_dict_self_str << "[" << internal_dict_key_str << "]";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Pack local parameters keys.
const auto script_dict_self_name = std::make_shared<StringImm>(internal_dict_self_str);
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_key_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
// Pack the local parameters values, not support list, tuple, or dict.
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(data);
(void)key_value_list.emplace_back(key);
const auto key_value_tuple = func_graph->NewCNode(key_value_list);
// Build the new dict node.
const auto dict_getitem_node = func_graph->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
int64_t index = GetElementIndex(abs_dict->elements(), key);
const auto &val = abs_dict->elements()[index].second;
const auto &tensor_val = dyn_cast<abstract::AbstractTensor>(val);
if (tensor_val != nullptr) {
const auto &tensor_type = tensor_val->element()->BuildType();
dict_getitem_node->set_user_data<Type>("__py_execute_tensor_type__", tensor_type);
const auto &tensor_shape = dyn_cast<abstract::Shape>(tensor_val->BuildShape());
MS_EXCEPTION_IF_NULL(tensor_shape);
dict_getitem_node->set_user_data<abstract::Shape>("__py_execute_tensor_shape__", tensor_shape);
MS_LOG(DEBUG) << "key: " << key->abstract()->BuildValue()->ToString() << ", type: " << tensor_type->ToString()
<< ", shape: " << tensor_shape->ToString() << ", val: " << tensor_val->ToString();
}
MS_LOG(DEBUG) << "Made dict getitem node: " << dict_getitem_node->DebugString();
return dict_getitem_node;
}
AnfNodePtr ConvertDictGetItem(const CNodePtr &node) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
if (support_fallback_runtime && is_dict_output_) {
return RebuidDictGetItem(node);
}
return ConvertDictGetItemToTupleGetItem(node);
}
// From:
// DictSetItem(data:AbstractDictionary, cons:AbstractBase, value)
// DictSetItem(data:AbstractDictionary, key:AbstractBase, value)
// To:
// TupleSetItem(data, index:Int64Imm, value)
// Or:
// tuple_add(data, value)
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) {
AnfNodePtr ConvertDictSetItemToTupleSetItem(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
@ -244,16 +315,18 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
const size_t item_value_index = 3;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &cons = inputs[cons_index];
auto &key = inputs[cons_index];
auto &item_value = inputs[item_value_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(cons);
MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) {
return nullptr;
}
int64_t index = GetAttrIndex(abs_dict->elements(), cons);
int64_t index = GetElementIndex(abs_dict->elements(), key);
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
if (index >= static_cast<int64_t>(abs_dict->elements().size())) {
// For dictionary set, if the key does not exist, we should create a new item.
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
@ -269,11 +342,86 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, index_node, item_value});
}
bool IsDictOutput() const {
const AnfNodePtr &output = root_graph_->output();
auto abs_dict = GetAbstract<AbstractDictionary>(output);
if (abs_dict != nullptr) {
return true;
}
return false;
}
// DictSetItem --> PyExecute()
AnfNodePtr RebuidDictSetItem(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
// Inputs should be [dict_setitem, dict, item, value]
const size_t expect_inputs_size = 4;
CheckInputsSize(node, expect_inputs_size);
const size_t data_index = 1;
const size_t item_key_index = 2;
const size_t item_value_index = 3;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &key = inputs[item_key_index];
auto &item_value = inputs[item_value_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(key);
auto abs_dict = GetAbstract<AbstractDictionary>(data);
if (abs_dict == nullptr) {
return nullptr;
}
auto func_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// Script
constexpr auto internal_dict_self_str = "__internal_dict_self__";
constexpr auto internal_dict_key_str = "__internal_dict_key__";
constexpr auto internal_dict_value_str = "__internal_dict_value__";
std::stringstream script_buffer;
script_buffer << "__import__('mindspore').update_and_return_dict(" << internal_dict_self_str << ", "
<< internal_dict_key_str << ", " << internal_dict_value_str << ")";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Pack local parameters keys.
const auto script_dict_self_name = std::make_shared<StringImm>(internal_dict_self_str);
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_key_str);
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_value_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_self_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
const auto key_value_name_tuple = func_graph->NewCNode(key_value_names_list);
// Pack the local parameters values, not support list, tuple, or dict.
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(data);
(void)key_value_list.emplace_back(key);
(void)key_value_list.emplace_back(item_value);
const auto key_value_tuple = func_graph->NewCNode(key_value_list);
// Build the new dict node.
const auto dict_setitem_node = func_graph->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
MS_LOG(DEBUG) << "Made dict setitem node: " << dict_setitem_node->DebugString();
return dict_setitem_node;
}
AnfNodePtr ConvertDictSetItem(const CNodePtr &node) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
if (support_fallback_runtime && is_dict_output_) {
return RebuidDictSetItem(node);
}
return ConvertDictSetItemToTupleSetItem(node);
}
// From:
// MakeDict(name, input)
// To:
// input
AnfNodePtr EraseMakeDictNode(const CNodePtr &node) {
AnfNodePtr EraseMakeDictNode(const CNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
constexpr size_t expect_inputs_size = 3;
constexpr size_t input_index = 2;
@ -281,6 +429,62 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
return node->input(input_index);
}
// MakeDict(keys, values) --> PyExecute('dict(zip(keys, values))', ...)
AnfNodePtr RebuildMakeDictNode(const CNodePtr &node) const {
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
constexpr auto internal_dict_zip_values_str = "__internal_dict_zip_values__";
const auto &fg = node->func_graph();
MS_EXCEPTION_IF_NULL(fg);
// Local parameters values.
// Pack the key tuple.
constexpr size_t values_input_index = 2;
const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
const auto make_key_tuple_node =
fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_key_tuple_str),
NewValueNode(script_key_tuple_str), node->input(values_input_index)});
// Pack the value tuple.
const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
const auto make_value_tuple_node =
fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_value_tuple_str),
NewValueNode(script_value_tuple_str), node->input(1)});
// Pack the local parameters values
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(make_key_tuple_node);
(void)key_value_list.emplace_back(make_value_tuple_node);
const auto key_value_tuple = fg->NewCNode(key_value_list);
// Pack local parameters keys.
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_zip_keys_str);
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_zip_values_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
const auto key_value_name_tuple = fg->NewCNode(key_value_names_list);
// Script
std::stringstream script_buffer;
script_buffer << "dict(zip(" << internal_dict_zip_keys_str << "," << internal_dict_zip_values_str << "),)";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Build the new dict node.
const auto make_dict_node = fg->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
return make_dict_node;
}
AnfNodePtr ConvertMakeDict(const CNodePtr &node) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
if (support_fallback_runtime && is_dict_output_) {
return RebuildMakeDictNode(node);
}
return EraseMakeDictNode(node);
}
// From:
// DictGetValues(dict:AbstractDictionary)
// To:
@ -350,22 +554,79 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
}
// dict(k0:v0, k1:v1, ...) --> tuple(v0, v1, ...)
ValueTuplePtr DictToTuple(const ValueDictionaryPtr &dict) const {
const auto &elements = dict->value();
std::vector<ValuePtr> values;
values.reserve(elements.size());
(void)std::transform(elements.begin(), elements.end(), std::back_inserter(values),
[](const auto &element) { return element.second; });
return std::make_shared<ValueTuple>(values);
AnfNodePtr DictToTuple(const ValueDictionaryPtr &dict) const {
const auto &keys_values = dict->value();
std::vector<ValuePtr> value_list;
value_list.reserve(keys_values.size());
(void)std::transform(keys_values.begin(), keys_values.end(), std::back_inserter(value_list),
[](const auto &value) { return value.second; });
return NewValueNode(std::make_shared<ValueTuple>(value_list));
}
// dict(k0:v0, k1:v1, ...) --> PyExecute('dict(zip(keys, values))', ...)
AnfNodePtr RebuildValueDict(const ValueDictionaryPtr &dict) const {
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
constexpr auto internal_dict_zip_values_str = "__internal_dict_zip_values__";
const auto &keys_values = dict->value();
std::vector<ValuePtr> key_list;
key_list.reserve(keys_values.size());
std::vector<ValuePtr> value_list;
value_list.reserve(keys_values.size());
for (const auto &key_value : keys_values) {
(void)key_list.emplace_back(key_value.first);
(void)value_list.emplace_back(key_value.second);
}
// Local parameters values.
// Pack the key tuple.
const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
const auto key_tuple = std::make_shared<ValueTuple>(key_list);
const auto make_key_tuple_node =
root_graph_->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_key_tuple_str),
NewValueNode(script_key_tuple_str), NewValueNode(key_tuple)});
// Pack the value tuple.
const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
const auto value_tuple = std::make_shared<ValueTuple>(value_list);
const auto make_value_tuple_node =
root_graph_->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_value_tuple_str),
NewValueNode(script_value_tuple_str), NewValueNode(value_tuple)});
// Pack the local parameters values
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_list.emplace_back(make_key_tuple_node);
(void)key_value_list.emplace_back(make_value_tuple_node);
const auto key_value_tuple = root_graph_->NewCNode(key_value_list);
// Pack local parameters keys.
const auto script_dict_key_name = std::make_shared<StringImm>(internal_dict_zip_keys_str);
const auto script_dict_value_name = std::make_shared<StringImm>(internal_dict_zip_values_str);
std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_key_name));
(void)key_value_names_list.emplace_back(NewValueNode(script_dict_value_name));
const auto key_value_name_tuple = root_graph_->NewCNode(key_value_names_list);
// Script
std::stringstream script_buffer;
script_buffer << "dict(zip(" << internal_dict_zip_keys_str << "," << internal_dict_zip_values_str << "),)";
const std::string &script = script_buffer.str();
const auto script_str = std::make_shared<StringImm>(script);
// Build the new dict node.
const auto make_dict_node = root_graph_->NewCNode(
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
return make_dict_node;
}
using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &);
using ConverterMap = mindspore::HashMap<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>;
static inline const ConverterMap converters_{
{prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItemToTupleGetItem},
{prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItemToTupleSetItem},
{prim::kPrimDictGetItem, &ThisClass::ConvertDictGetItem},
{prim::kPrimDictSetItem, &ThisClass::ConvertDictSetItem},
{prim::kPrimDictGetValues, &ThisClass::EraseDictGetValues},
{prim::kPrimMakeDict, &ThisClass::EraseMakeDictNode},
{prim::kPrimMakeDict, &ThisClass::ConvertMakeDict},
{prim::kPrimMakeKeywordArg, &ThisClass::EraseMakeKeywordArgNode},
{prim::kPrimExtractKeywordArg, &ThisClass::EraseExtractKeywordArg},
{prim::kPrimDictItems, &ThisClass::EraseDictItems},
@ -384,12 +645,16 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
AnfNodePtr ConvertValueNode(const ValueNodePtr &, const ValuePtr &value) override {
// Convert Dictionary value node.
if (value->isa<ValueDictionary>()) {
return NewValueNode(DictToTuple(value->cast<ValueDictionaryPtr>()));
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
if (support_fallback_runtime && is_dict_output_) {
return RebuildValueDict(value->cast<ValueDictionaryPtr>());
}
return DictToTuple(value->cast<ValueDictionaryPtr>());
}
return nullptr;
}
static std::shared_ptr<AbstractTuple> MakeAbstractTuple(const std::vector<AbstractAttribute> &attrs) {
static std::shared_ptr<AbstractTuple> MakeAbstractTuple(const std::vector<AbstractElementPair> &attrs) {
std::vector<AbstractBasePtr> elements;
elements.reserve(attrs.size());
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(elements),
@ -459,6 +724,9 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
// AbstractDictionary --> AbstractSequence.
return ConvertToAbstractSequence(abs, 0);
}
private:
bool is_dict_output_{false};
};
// ==================================================================
@ -489,9 +757,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
}
// From:
// ListGetItem(list, cons)
// ListGetItem(list, key)
// To:
// TupleGetItem(list, cons)
// TupleGetItem(list, key)
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
@ -503,8 +771,8 @@ class CleanAfterOptARewriter : public BaseRewriter {
constexpr size_t cons_index = 2;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &cons = inputs[cons_index];
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons});
auto &key = inputs[cons_index];
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, key});
}
// From:
@ -524,9 +792,9 @@ class CleanAfterOptARewriter : public BaseRewriter {
const size_t value_index = 3;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &cons = inputs[cons_index];
auto &key = inputs[cons_index];
auto &value = inputs[value_index];
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value});
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, key, value});
}
// From:

View File

@ -51,7 +51,7 @@ class ResolveNodeResolve : public AnfVisitor {
if (IsValueNode<parse::NameSpace>(vnode)) {
auto name_space = GetValueNode<parse::NameSpacePtr>(vnode);
MS_EXCEPTION_IF_NULL(name_space);
obj_ = name_space->obj();
obj_ = name_space->namespace_obj();
} else if (IsValueNode<parse::Symbol>(vnode)) {
auto symbol_value = GetValueNode<parse::SymbolPtr>(vnode);
MS_EXCEPTION_IF_NULL(symbol_value);

View File

@ -0,0 +1,104 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/py_interpret_to_execute.h"
#include <memory>
#include <string>
#include <utility>
#include <unordered_map>
#include "abstract/abstract_function.h"
#include "include/common/utils/convert_utils_py.h"
#include "include/common/utils/utils.h"
#include "utils/anf_utils.h"
#include "utils/symbolic.h"
#include "pipeline/jit/parse/resolve.h"
namespace mindspore {
/* namespace to support opt */
namespace opt {
namespace {
py::object CallPythonPushGlobalParams(const py::object &dict) {
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
py::module mod = python_adapter::GetPyModule(python_mod_parse);
constexpr auto python_merge_dict = "merge_global_params";
return python_adapter::CallPyModFn(mod, python_merge_dict, dict);
}
} // namespace
// Convert PyInterpret into PyExecute:
// PyInterpret(script, global_dict, local_dict)
// -->
// PyExecute(script, local_dict_keys, local_dict_values),
// with side-effect operation:
// Push global_dict into global parameters list.
// (So it requires no same key name.)
bool PyInterpretToExecute(const pipeline::ResourcePtr &resource) {
auto manager = resource->manager();
const auto &all_nodes = manager->all_nodes();
auto transact = manager->Transact();
constexpr auto input_index_one = 1;
constexpr auto input_index_two = 2;
constexpr auto input_index_three = 3;
for (const auto &node : all_nodes) {
if (!IsPrimitiveCNode(node, prim::kPrimPyInterpret)) {
continue;
}
const auto &cnode = node->cast<CNodePtr>();
MS_LOG(DEBUG) << "cnode: " << cnode->DebugString();
auto new_cnode = std::make_shared<CNode>(*cnode);
new_cnode->set_input(0, NewValueNode(prim::kPrimPyExecute));
if (!IsValueNode<parse::Script>(cnode->input(input_index_one))) {
MS_LOG(EXCEPTION) << "The first input should be a Script, but got "
<< cnode->input(input_index_one)->DebugString();
}
const auto &script = GetValueNode<std::shared_ptr<parse::Script>>(cnode->input(input_index_one));
const auto &script_str = script->script();
const auto &script_strimm_node = NewValueNode(std::make_shared<StringImm>(script_str));
new_cnode->set_input(input_index_one, script_strimm_node);
if (!IsValueNode<ValueDictionary>(cnode->input(input_index_two))) {
MS_LOG(EXCEPTION) << "The second input should be a dictionary, but got "
<< cnode->input(input_index_two)->DebugString();
}
const auto &global_dict = GetValueNode<ValueDictionaryPtr>(cnode->input(input_index_two));
py::object py_global_dict = ValueToPyData(global_dict);
MS_LOG(DEBUG) << "py_global_dict: " << py::str(py_global_dict);
CallPythonPushGlobalParams(py_global_dict);
if (!IsPrimitiveCNode(cnode->input(input_index_three), prim::kPrimMakeDict)) {
MS_LOG(EXCEPTION) << "The 3rd input should be a dictionary, but got "
<< cnode->input(input_index_three)->DebugString();
}
const auto &local_dict_cnode = dyn_cast<CNode>(cnode->input(input_index_three));
MS_EXCEPTION_IF_NULL(local_dict_cnode);
const auto &local_dict_keys = local_dict_cnode->input(input_index_one);
const auto &local_dict_values = local_dict_cnode->input(input_index_two);
if (!IsValueNode<ValueTuple>(local_dict_keys) || !IsPrimitiveCNode(local_dict_values, prim::kPrimMakeTuple)) {
MS_LOG(EXCEPTION) << "The dictionary's keys and values should be a tuple, but got "
<< local_dict_cnode->DebugString();
}
new_cnode->set_input(input_index_two, local_dict_keys);
new_cnode->set_input(input_index_three, local_dict_values);
(void)transact.Replace(cnode, new_cnode);
}
transact.Commit();
return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,29 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_INTERPRET_TO_EXECUTE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_INTERPRET_TO_EXECUTE_H_
#include "pipeline/jit/resource.h"
namespace mindspore {
/* namespace to support opt */
namespace opt {
bool PyInterpretToExecute(const pipeline::ResourcePtr &resource);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_PY_INTERPRET_TO_EXECUTE_H_

View File

@ -1432,7 +1432,7 @@ tensor::TensorPtr GetDependValueByConstTensor(const AnfNodePtr &input_node, cons
auto value = value_node->value();
MS_EXCEPTION_IF_NULL(value);
if (!value->isa<tensor::Tensor>()) {
MS_EXCEPTION(ValueError) << "the cnode " << cnode_name << "'s input[" << i << "], must be tensor, but got "
MS_EXCEPTION(ValueError) << "The CNode " << cnode_name << "'s input[" << i << "], must be tensor, but got "
<< value->ToString();
}
auto tensor = value->cast<tensor::TensorPtr>();

View File

@ -239,8 +239,7 @@ ValuePtr ConvertModuleNameSpace(const py::object &obj) {
MS_LOG(DEBUG) << "Converting python module";
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj);
auto converted =
std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, py::cast<py::module>(module_namespace), obj);
auto converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_MODULE, module_namespace, obj);
MS_LOG(DEBUG) << "name_space: " << converted->ToString();
return converted;
}

View File

@ -376,8 +376,9 @@ AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
}
AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
MS_LOG(DEBUG) << "MakeResolve for " << (name_space ? (std::string)py::str(name_space->obj()) : "null namespace")
<< " , " << (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resolve symbol.");
MS_LOG(DEBUG) << "MakeResolve for "
<< (name_space ? (std::string)py::str(name_space->namespace_obj()) : "null namespace") << " , "
<< (resolve_symbol ? (std::string)resolve_symbol->symbol() : "null resolve symbol.");
ValueNodePtr module_node = NewValueNode(name_space);
ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
@ -388,7 +389,7 @@ AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const An
const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node) {
MS_LOG(DEBUG) << "MakeInterpret for " << script_text;
MS_EXCEPTION_IF_NULL(orig_node);
ScriptPtr script = std::make_shared<Script>(script_text);
auto script = std::make_shared<Script>(script_text);
auto script_node = NewValueNode(script);
auto node = func_graph_->NewCNodeInOrder(
{NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});

View File

@ -1218,6 +1218,22 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
UpdateInterpretForUserNode(call_cnode, call_function_node);
MS_EXCEPTION_IF_NULL(call_cnode);
MS_LOG(DEBUG) << "call_cnode: " << call_cnode->DebugString()
<< ", call_function_node: " << call_function_node->DebugString();
// Support tensor.asnumpy() in runtime by JIT Fallback.
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
if (support_fallback_runtime && IsPrimitiveCNode(call_function_node, prim::kPrimGetAttr)) {
constexpr size_t index_two = 2;
const auto &attr_node = call_function_node->cast<CNodePtr>()->input(index_two);
const auto &attr_str = GetValueNode<StringImmPtr>(attr_node);
MS_EXCEPTION_IF_NULL(attr_str);
if (attr_str->value() == "asnumpy") {
call_cnode->set_interpret(true);
call_cnode = HandleInterpret(block, call_cnode, node);
return call_cnode;
}
}
// Process bulitin function, for example, sum(np.array(xx))
py::tuple namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL, name_id);
constexpr size_t namespace_info_size = 4;
@ -1432,7 +1448,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
// Process the node attr
auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
MS_LOG(DEBUG) << "Attr = " << attr_str;
MS_LOG(DEBUG) << "node: " << node << ", attr: " << attr_str << ", value: " << value_body;
// The fallback feature is enabled in default.
static const auto use_fallback = (support_fallback() != "0");
if (use_fallback && attr_str == "Tensor") {
@ -1449,21 +1465,19 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec
return ret;
}
}
AnfNodePtr attr_node = nullptr;
{
TraceGuard guard(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
attr_node = NewValueNode(attr_str);
}
MS_EXCEPTION_IF_NULL(block->func_graph());
// Create the apply node
AnfNodePtr attr_node = NewValueNode(attr_str);
auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
if (use_fallback) {
// Check whether it is constant, constant does not need interpret.
auto value_str = py::cast<std::string>(ast()->GetAstNodeText(value_body));
py::bool_ is_const_value =
ast()->CallParserObjMethod(PYTHON_PARSE_CHECK_IS_CONSTANT_VALUE, value_str, common::SafeCStr(attr_str));
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
auto is_constant = py::cast<bool>(is_const_value);
if (!is_constant) {
if (!is_constant || (support_fallback_runtime && attr_str == "asnumpy")) {
UpdateInterpretForUserNode(attr_cnode, value_node);
}
}
@ -2114,7 +2128,7 @@ CNodePtr GenerateInterpretGetItem(const FuncGraphPtr &fg, const AnfNodePtr &iter
auto local_dict_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeDict), local_dict_key, local_dict_value});
// Construct script text node.
parse::ScriptPtr script = std::make_shared<parse::Script>("x[i]");
auto script = std::make_shared<Script>("x[i]");
auto script_node = NewValueNode(script);
return fg->NewCNodeInOrder({NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});

View File

@ -481,7 +481,7 @@ py::object GetSymbolObject(const NameSpacePtr &name_space, const SymbolPtr &symb
MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph is nullptr.";
}
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
auto &obj = name_space->obj();
auto &obj = name_space->namespace_obj();
if (py::isinstance<py::none>(obj)) {
MS_EXCEPTION(NameError) << "The name \'" << symbol << "\' is not defined.";
}

View File

@ -40,15 +40,15 @@ namespace parse {
// NameSpace class for resolving python code.
class NameSpace final : public Named {
public:
NameSpace(const std::string &module, const py::object &obj, const py::object &module_obj = py::object())
: Named(module + ": \'" + std::string(py::str(obj)) + "\'"),
NameSpace(const std::string &module, const py::object &namespace_obj, const py::object &module_obj = py::object())
: Named(module + ": \'" + std::string(py::str(namespace_obj)) + "\'"),
module_(module),
obj_(obj),
namespace_obj_(namespace_obj),
module_obj_(module_obj) {}
~NameSpace() override = default;
MS_DECLARE_PARENT(NameSpace, Named);
const py::object &obj() const { return obj_; }
const py::object &namespace_obj() const { return namespace_obj_; }
const py::object &module_obj() const { return module_obj_; }
const std::string &module() const { return module_; }
abstract::AbstractBasePtr ToAbstract() override {
@ -59,7 +59,7 @@ class NameSpace final : public Named {
// namespace of the module
std::string module_;
// namespace object
py::object obj_;
py::object namespace_obj_;
// module object
py::object module_obj_;
};

View File

@ -48,6 +48,7 @@
#include "frontend/optimizer/environ_conversion.h"
#include "frontend/optimizer/comm_op_reuse_tag.h"
#include "frontend/optimizer/overlap_opt_shard_in_pipeline.h"
#include "frontend/optimizer/py_interpret_to_execute.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
#include "pipeline/pynative/pynative_execute.h"
@ -86,6 +87,19 @@ void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource)
}
} // namespace
bool PyInterpretToExecutePass(const ResourcePtr &resource) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
if (!support_fallback_runtime) {
return true;
}
MS_EXCEPTION_IF_NULL(resource);
FuncGraphPtr func_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
(void)opt::PyInterpretToExecute(resource);
UpdateArgsSpec(func_graph, resource);
return true;
}
bool SimplifyDataStructuresPass(const ResourcePtr &resource) {
MS_EXCEPTION_IF_NULL(resource);
FuncGraphPtr func_graph = resource->func_graph();
@ -840,6 +854,7 @@ bool AddEmbeddingCachePass(const ResourcePtr &resource) {
}
std::vector<PassItem> kVmPasses = {
{"py_interpret_to_execute", PyInterpretToExecutePass},
{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup},
{"clean_after_opta", CleanAfterOptAPass},
@ -859,7 +874,8 @@ std::vector<PassItem> kVmPasses = {
{"overlap_opt_shard_in_pipeline", OverlapOptShardInPipelinePass},
};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
std::vector<PassItem> kGePasses = {{"py_interpret_to_execute", PyInterpretToExecutePass},
{"simplify_data_structures", SimplifyDataStructuresPass},
{"opt_a", OptPassAGroup},
{"clean_after_opta", CleanAfterOptAPass},
{"opt_b", OptPassBGroup},

View File

@ -69,6 +69,7 @@
#include "distributed/collective/collective_manager.h"
#include "mindspore/ccsrc/utils/dynamic_obfuscation/dynamic_obfuscation.h"
#include "mindspore/ccsrc/utils/dynamic_obfuscation/registry_opaque_predicate.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "distributed/init.h"
#include "profiler/device/profiling.h"
#include "kernel/akg/akg_kernel_build_manager.h"
@ -269,6 +270,33 @@ std::string ToOrdinal(const size_t &i) {
}
return std::to_string(i) + suffix;
}
py::object GetPyExecuteOutput(const AnfNodePtr &output) {
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
if (support_fallback_runtime) {
std::function<AnfNodePtr(const AnfNodePtr &)> get_real_output = [&get_real_output](const AnfNodePtr &node) {
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
const auto cnode = dyn_cast<CNode>(node);
MS_EXCEPTION_IF_NULL(cnode);
return get_real_output(cnode->input(1));
}
return node;
};
const auto &real_output = get_real_output(output);
MS_LOG(INFO) << "Real output: " << real_output << ", " << real_output->DebugString()
<< ", has \'PyExecuteOutputData\': " << real_output->has_user_data<kernel::PyExecuteOutputData>();
if (real_output->has_user_data<kernel::PyExecuteOutputData>()) {
py::gil_scoped_acquire gil_acquire;
const auto &output_data = real_output->user_data<kernel::PyExecuteOutputData>();
py::object res_obj = output_data->obj;
MS_LOG(INFO) << "Has \'PyExecuteOutputData\', just return it. res_obj: " << res_obj;
if (!py::isinstance<py::none>(res_obj)) {
return res_obj;
}
}
}
return py::none();
}
} // namespace
std::string GetObjDesc(const py::object &source_obj) {
@ -1409,16 +1437,25 @@ py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase_o
ConfigManager::GetInstance().set_gpu_loopsink_size(loop_size);
}
MS_LOG(INFO) << "VM loop size " << vm_loop << ", loopsink size " << vm_loop;
py::object ret;
py::object res;
MS_LOG(DEBUG) << "Eval run" << ms_context->backend_policy();
auto output = execute_info->func_graph->output()->abstract();
const auto &output = execute_info->func_graph->output();
MS_EXCEPTION_IF_NULL(output);
const auto &output_abs = output->abstract();
MS_EXCEPTION_IF_NULL(output_abs);
for (int64_t i = 0; i < vm_loop; i++) {
BaseRef value = (*run)(execute_info->arg_list);
ret = BaseRefToPyData(value, output);
res = BaseRefToPyData(value, output_abs);
}
// Replace the output if it's not Tensor, but Python data.
const auto &py_res = GetPyExecuteOutput(output);
if (py_res != py::none()) {
return py_res;
}
MS_LOG(DEBUG) << "Run end";
return ret;
return res;
} // namespace pipeline
FuncGraphPtr GraphExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase) const {

View File

@ -106,8 +106,8 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(engine);
MS_EXCEPTION_IF_NULL(out_conf);
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
AbstractBasePtrList args_abs_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
[](const ConfigPtr &config) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(config);
const auto &eval_result = config->ObtainEvalResult();
@ -122,7 +122,7 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
auto do_signature_func = dyn_cast_ptr<Primitive>(func);
if (do_signature_func != nullptr) {
if (prims_to_skip_undetermined_infer.find(do_signature_func->name()) == prims_to_skip_undetermined_infer.end()) {
auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
auto ret_abstract = EvalUndeterminedArgs(args_abs_list);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined for " << do_signature_func->name()
<< ", ret_abstract: " << ret_abstract->ToString();
@ -148,9 +148,9 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
ScopeGuard scope_guard(scope);
if (bound_node() != nullptr) {
TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
new_node = prim::GenerateCNode(out_cnode->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
new_node = prim::GenerateCNode(out_cnode->func_graph(), prim_->ToString(), func, args_abs_list, args_inputs);
} else {
new_node = prim::GenerateCNode(out_cnode->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
new_node = prim::GenerateCNode(out_cnode->func_graph(), prim_->ToString(), func, args_abs_list, args_inputs);
}
// Update new CNode info.
auto new_cnode = dyn_cast<CNode>(new_node);
@ -162,9 +162,9 @@ EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
return engine->ForwardConfig(out_conf, new_conf);
}
static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_abs_list, bool need_unpack) {
// arg[0] is the func graph to unpack, ignore it
AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
AbstractBasePtrList specialize_args_before_unpack(args_abs_list.begin() + 1, args_abs_list.end());
AbstractBasePtrList graph_specialize_args;
if (need_unpack) {
for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
@ -177,7 +177,7 @@ static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_s
auto arg_dict = specialize_args_before_unpack[index]->cast_ptr<AbstractDictionary>();
auto dict_elems = arg_dict->elements();
(void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(graph_specialize_args),
[](const AbstractAttribute &item) {
[](const AbstractElementPair &item) {
// Dict_elems's first element represents parameter names, which should be string type.
return std::make_shared<AbstractKeywordArg>(
GetValue<std::string>(item.first->BuildValue()), item.second);
@ -212,8 +212,8 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
<< " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
<< ", inputs size " << out_node_inputs.size();
}
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
AbstractBasePtrList args_abs_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
[](const ConfigPtr &ref) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(ref);
const auto &eval_result = ref->ObtainEvalResult();
@ -221,20 +221,20 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
return eval_result->abstract();
});
// get the forward graph
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
if (args_abs_list.empty()) {
MS_LOG(EXCEPTION) << "args_abs_list can't be empty.";
}
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
auto fn = args_spec_list[0]->cast_ptr<AbstractFunction>();
MS_EXCEPTION_IF_NULL(args_abs_list[0]);
auto fn = args_abs_list[0]->cast_ptr<AbstractFunction>();
if (fn == nullptr) {
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_abs_list[0]->ToString();
}
auto real_fn = fn->cast_ptr<FuncGraphAbstractClosure>();
MS_EXCEPTION_IF_NULL(real_fn);
FuncGraphPtr forward_graph = real_fn->func_graph();
MS_EXCEPTION_IF_NULL(forward_graph);
AbstractBasePtrList graph_specialize_args =
GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
GetUnpackGraphSpecArgsList(args_abs_list, unpack_graph->need_unpack_args());
AbstractBasePtrList graph_specialize_args_without_sens;
if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
@ -316,7 +316,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
MS_EXCEPTION_IF_NULL(engine);
AbstractBasePtrList args_spec_list;
AbstractBasePtrList args_abs_list;
MS_EXCEPTION_IF_NULL(out_conf);
if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
@ -329,7 +329,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
<< " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
<< ", inputs size " << out_node_inputs.size();
}
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
[](const ConfigPtr &ref) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(ref);
const auto &eval_result = ref->ObtainEvalResult();
@ -347,7 +347,7 @@ EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const C
}
AnfNodePtr new_node =
MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1], out_node_inputs[1], func_graph);
MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_abs_list[1], out_node_inputs[1], func_graph);
AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
if (new_node->isa<CNode>()) {
@ -1434,10 +1434,10 @@ EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_
return eng->ForwardConfig(old_conf, fn_conf);
}
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &args_spec_list, const ValuePtr &data_value,
EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &args_abs_list, const ValuePtr &data_value,
const AnfNodeConfigPtr &out_conf, const std::string &data) {
constexpr size_t item_index = 1;
auto item_args = args_spec_list[item_index];
auto item_args = args_abs_list[item_index];
ValuePtr item_value = item_args->BuildValue();
MS_EXCEPTION_IF_NULL(data_value);
MS_EXCEPTION_IF_NULL(item_value);
@ -1463,7 +1463,7 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &arg
if (IsValueNode<TypeNull>(new_node)) {
// Do not find the attribute.
constexpr auto max_args_len = 3;
bool has_default = (args_spec_list.size() == max_args_len);
bool has_default = (args_abs_list.size() == max_args_len);
if (!has_default) {
MS_EXCEPTION(AttributeError) << data << " object has no attribute " << symbol->symbol();
}
@ -1490,19 +1490,19 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &arg
return eng->ForwardConfig(out_conf, fn_conf);
}
EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_spec_list,
EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_abs_list,
const AnfNodeConfigPtr &out_conf) {
// args_spec_list: same as StaticGetter
// args_abs_list: same as StaticGetter
constexpr size_t args_min_size = 2;
if (args_spec_list.size() < args_min_size) {
MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
if (args_abs_list.size() < args_min_size) {
MS_LOG(EXCEPTION) << "Size of args_abs_list is less than 2";
}
MS_EXCEPTION_IF_NULL(out_conf);
// An external type.
constexpr auto data_index = 0;
constexpr auto item_index = 1;
auto data = args_spec_list[data_index];
auto item = args_spec_list[item_index];
auto data = args_abs_list[data_index];
auto item = args_abs_list[item_index];
MS_EXCEPTION_IF_NULL(data);
MS_EXCEPTION_IF_NULL(item);
auto data_value = data->BuildValue();
@ -1527,13 +1527,13 @@ EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_spec
auto data_type = data->BuildType();
MS_EXCEPTION_IF_NULL(data_type);
const auto &data_id_str = TypeIdToString(data_type->type_id());
return GetEvaluatedValueForNameSpaceString(args_spec_list, data_value, out_conf, data_id_str);
return GetEvaluatedValueForNameSpaceString(args_abs_list, data_value, out_conf, data_id_str);
}
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AbstractBasePtrList &args_spec_list,
EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AbstractBasePtrList &args_abs_list,
const ValuePtr &data_value, const AnfNodeConfigPtr &out_conf) {
constexpr size_t item_index = 1;
auto item_args = args_spec_list[item_index];
auto item_args = args_abs_list[item_index];
ValuePtr item_value = item_args->BuildValue();
MS_EXCEPTION_IF_NULL(item_value);
@ -1558,7 +1558,7 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AbstractBasePtrList
auto new_node = parse::ResolveMsClassWithAttr(func_graph->manager(), ms_class->obj(), item_name, out_node);
if (new_node == nullptr) {
constexpr auto max_args_len = 3;
bool has_default = (args_spec_list.size() == max_args_len);
bool has_default = (args_abs_list.size() == max_args_len);
if (!has_default) {
MS_EXCEPTION(AttributeError) << py::str(ms_class->obj()) << " object has no attribute: " << item_name << ".";
}
@ -1575,10 +1575,10 @@ EvalResultPtr GetEvaluatedValueForMsClassAttrOrMethod(const AbstractBasePtrList
return eng->ForwardConfig(out_conf, fn_conf);
}
EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AbstractBasePtrList &args_spec_list,
EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AbstractBasePtrList &args_abs_list,
const FuncGraphPtr &func_value, const AnfNodeConfigPtr &out_conf) {
constexpr size_t item_index = 1;
auto item_args = args_spec_list[item_index];
auto item_args = args_abs_list[item_index];
ValuePtr item_value = item_args->BuildValue();
MS_EXCEPTION_IF_NULL(item_value);
@ -1601,19 +1601,19 @@ EvalResultPtr GetEvaluatedValueForCellAttrOrMethod(const AbstractBasePtrList &ar
py::object ns_obj =
python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, real_python_obj);
auto ns = std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, ns_obj);
return GetEvaluatedValueForNameSpaceString(args_spec_list, ns, out_conf, py_obj_str);
return GetEvaluatedValueForNameSpaceString(args_abs_list, ns, out_conf, py_obj_str);
}
return nullptr;
}
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine,
const AbstractBasePtrList &args_spec_list,
const AbstractBasePtrList &args_abs_list,
const ConfigPtr &data_conf,
const AnfNodeConfigPtr &out_conf) {
constexpr size_t data_index = 0;
constexpr size_t item_index = 1;
auto data_args = args_spec_list[data_index];
auto item_args = args_spec_list[item_index];
auto data_args = args_abs_list[data_index];
auto item_args = args_abs_list[item_index];
MS_EXCEPTION_IF_NULL(data_args);
MS_EXCEPTION_IF_NULL(item_args);
ValuePtr item_value = item_args->BuildValue();
@ -1631,9 +1631,23 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
if (require.empty()) {
constexpr auto max_args_len = 3;
bool has_default = (args_spec_list.size() == max_args_len);
bool has_default = (args_abs_list.size() == max_args_len);
if (!has_default) {
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
constexpr auto tensor_asnumpy_attr_name = "asnumpy";
if (item_name != tensor_asnumpy_attr_name) {
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
}
auto out_node = out_conf->node();
auto out_cnode = out_node->cast_ptr<CNode>();
MS_EXCEPTION_IF_NULL(out_cnode);
auto fg = out_cnode->func_graph();
auto py_interpret_node =
fg->NewCNode({NewValueNode(prim::kPrimPyInterpret), NewValueNode(out_node->debug_info()->debug_name())});
fg->ReplaceInOrder(out_node, py_interpret_node);
auto eng = out_conf->engine();
MS_EXCEPTION_IF_NULL(eng);
auto fn_conf = eng->MakeConfig(py_interpret_node, out_conf->context(), out_conf->func_graph());
return eng->ForwardConfig(out_conf, fn_conf);
}
auto out_node = out_conf->node();
auto out_cnode = out_node->cast_ptr<CNode>();
@ -1694,14 +1708,14 @@ ValuePtr GetMsClassObject(const AbstractBasePtr &abs) {
return nullptr;
}
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
// Inputs: namespace and its static function; or class and its member function
constexpr size_t data_index = 0;
constexpr size_t item_index = 1;
auto data_args = args_spec_list[data_index];
auto item_args = args_spec_list[item_index];
auto data_args = args_abs_list[data_index];
auto item_args = args_abs_list[item_index];
MS_EXCEPTION_IF_NULL(data_args);
MS_EXCEPTION_IF_NULL(item_args);
MS_LOG(DEBUG) << "StaticGetter, data: " << data_args->ToString() << ", item: " << item_args->ToString();
@ -1730,9 +1744,9 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
}
constexpr auto max_args_size = 3;
if (args_spec_list.size() == max_args_size) {
if (args_abs_list.size() == max_args_size) {
constexpr size_t default_index = 2;
auto default_args = args_spec_list[default_index];
auto default_args = args_abs_list[default_index];
if (default_args->isa<abstract::AbstractScalar>()) {
ValuePtr default_value = default_args->BuildValue();
if (default_value->isa<parse::InterpretedObject>()) {
@ -1747,12 +1761,12 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
// Get attribute or method of class object decorated with 'jit_class'.
auto class_value = GetMsClassObject(data_args);
if (class_value != nullptr) {
return GetEvaluatedValueForMsClassAttrOrMethod(args_spec_list, class_value, out_conf);
return GetEvaluatedValueForMsClassAttrOrMethod(args_abs_list, class_value, out_conf);
}
// Get attribute or method of nn.Cell object.
auto data_func_graph = dyn_cast_ptr<FuncGraphAbstractClosure>(data_args);
if (data_func_graph != nullptr) {
auto res = GetEvaluatedValueForCellAttrOrMethod(args_spec_list, data_func_graph->func_graph(), out_conf);
auto res = GetEvaluatedValueForCellAttrOrMethod(args_abs_list, data_func_graph->func_graph(), out_conf);
if (res != nullptr) {
return res;
}
@ -1760,58 +1774,112 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
// Try to search method map, if not found, the data_type should be External type.
TypePtr data_type = data_args->BuildType();
if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, args_spec_list, data_conf, out_conf);
return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, args_abs_list, data_conf, out_conf);
}
return GetEvaluatedValueForNameSpace(args_spec_list, out_conf);
return GetEvaluatedValueForNameSpace(args_abs_list, out_conf);
}
} // namespace
EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
if (out_conf != nullptr) { // 'out_conf' maybe nullptr in PyNative mode.
if (args_spec_list.empty()) {
if (args_abs_list.empty()) {
MS_LOG(INFO) << "For MakeTuple, the inputs should not be empty. node: " << out_conf->node()->DebugString();
}
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
if (enable_eliminate_unused_element) {
auto flags = GetSequenceNodeElementsUseFlags(out_conf->node());
if (flags == nullptr) {
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size()));
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_abs_list.size()));
}
(void)sequence_nodes->emplace_back(AnfNodeWeakPtr(out_conf->node()));
}
}
auto abs = std::make_shared<AbstractTuple>(args_spec_list, sequence_nodes);
auto abs = std::make_shared<AbstractTuple>(args_abs_list, sequence_nodes);
auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, res);
evaluator_cache_mgr_->SetValue(args_abs_list, res);
return res;
}
EvalResultPtr MakeListEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
EvalResultPtr MakeListEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
if (out_conf != nullptr) { // 'out_conf' maybe nullptr in PyNative mode.
if (args_spec_list.empty()) {
if (args_abs_list.empty()) {
MS_LOG(INFO) << "For MakeList, the inputs should not be empty. node: " << out_conf->node()->DebugString();
}
static const auto enable_eliminate_unused_element = (common::GetEnv("MS_DEV_ENABLE_DDE") != "0");
if (enable_eliminate_unused_element) {
auto flags = GetSequenceNodeElementsUseFlags(out_conf->node());
if (flags == nullptr) {
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_spec_list.size()));
SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_abs_list.size()));
}
(void)sequence_nodes->emplace_back(AnfNodeWeakPtr(out_conf->node()));
}
}
auto abs = std::make_shared<AbstractList>(args_spec_list, sequence_nodes);
auto abs = std::make_shared<AbstractList>(args_abs_list, sequence_nodes);
auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, res);
evaluator_cache_mgr_->SetValue(args_abs_list, res);
return res;
}
EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
if (args_abs_list.empty()) {
MS_LOG(EXCEPTION) << "'args_abs_list' should not be empty";
}
// Handle for DDE.
for (size_t i = 0; i < args_abs_list.size(); ++i) {
MS_EXCEPTION_IF_NULL(args_abs_list[i]);
if (args_abs_list[i]->isa<abstract::AbstractSequence>()) {
MS_LOG(DEBUG) << "Primitive \'PyExecute\' is consuming tuple/list arguments[" << i
<< "]: " << args_abs_list[i]->ToString();
SetSequenceElementsUseFlagsRecursively(args_abs_list[i], true);
}
}
auto current_interpret_node = out_conf->node();
MS_EXCEPTION_IF_NULL(current_interpret_node);
MS_LOG(DEBUG) << "The current interpret node: " << current_interpret_node->DebugString();
// Get the type parameter.
MS_EXCEPTION_IF_NULL(args_abs_list[0]);
ValuePtr value_track = args_abs_list[0]->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
auto script_obj = dyn_cast_ptr<StringImm>(value_track);
if (script_obj == nullptr) {
MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
}
// Make global and local parameters.
const std::string &script = script_obj->value();
// Call python script string.
MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list;
TypePtr type = kFloat32;
if (current_interpret_node->has_user_data("__py_execute_tensor_type__")) {
type = current_interpret_node->user_data<Type>("__py_execute_tensor_type__");
MS_LOG(DEBUG) << "type: " << type->ToString();
}
BaseShapePtr shape;
if (current_interpret_node->has_user_data("__py_execute_tensor_shape__")) {
shape = current_interpret_node->user_data<BaseShape>("__py_execute_tensor_shape__");
MS_LOG(DEBUG) << "shape: " << shape->ToString();
} else {
ShapeVector shp;
(void)shp.emplace_back(Shape::kShapeRankAny);
shape = std::make_shared<Shape>(shp);
}
AbstractBasePtr res = std::make_shared<AbstractTensor>(type, shape);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}
namespace {
class EmbedEvaluator : public SymbolicPrimEvaluator {
public:
@ -1917,23 +1985,23 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
~GetAttrEvaluator() override = default;
MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
constexpr auto args_min_size = 2;
constexpr auto args_max_size = 3;
constexpr auto attr_index = 1;
auto ret_abstract = EvalUndeterminedArgs(args_spec_list);
auto ret_abstract = EvalUndeterminedArgs(args_abs_list);
if (ret_abstract != nullptr) {
MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
return ret_abstract;
}
// Inputs: data, item
const auto args_size = args_spec_list.size();
const auto args_size = args_abs_list.size();
if (args_size != args_min_size && args_size != args_max_size) {
MS_LOG(EXCEPTION) << "For Primitive GetAttr, the input size should be " << args_min_size << " or "
<< args_max_size << ", but got size:" << args_size;
}
auto attr_abs = args_spec_list[attr_index];
auto attr_abs = args_abs_list[attr_index];
auto attr_abs_type = attr_abs->BuildType();
MS_EXCEPTION_IF_NULL(attr_abs_type);
auto type_id = attr_abs_type->type_id();
@ -1943,13 +2011,13 @@ class GetAttrEvaluator : public TransitionPrimEvaluator {
EvalResultPtr ret = nullptr;
if (bound_node() != nullptr) {
TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
ret = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
} else {
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
ret = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
}
// don't lookup from cache, as different out_conf with same node but different context
// may add different entry to anfnode_config_map, like getattr primitive;
evaluator_cache_mgr_->SetValue(args_spec_list, ret);
evaluator_cache_mgr_->SetValue(args_abs_list, ret);
return ret;
}
};
@ -1959,19 +2027,19 @@ class ResolveEvaluator : public TransitionPrimEvaluator {
ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
~ResolveEvaluator() override = default;
MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
constexpr auto resolve_args_size = 2;
// Inputs: namespace, symbol
if (args_spec_list.size() != resolve_args_size) {
MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
if (args_abs_list.size() != resolve_args_size) {
MS_LOG(EXCEPTION) << "Expected args_abs_list size = 2, but has size:" << args_abs_list.size();
}
EvalResultPtr ret = nullptr;
if (bound_node() != nullptr) {
TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
ret = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
} else {
ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
ret = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
}
return ret;
}
@ -1997,14 +2065,14 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
~CreateInstanceEvaluator() override = default;
MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
// Check the type parameter.
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
if (args_abs_list.empty()) {
MS_LOG(EXCEPTION) << "'args_abs_list' should not be empty";
}
constexpr size_t type_index = 0;
auto arg_class_type = args_spec_list[type_index];
auto arg_class_type = args_abs_list[type_index];
MS_EXCEPTION_IF_NULL(arg_class_type);
TypePtr type = arg_class_type->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
@ -2029,7 +2097,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
MS_LOG(DEBUG) << "Get class type: " << type_obj->ToString() << ".";
// Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
py::tuple params = GetParameters(args_spec_list);
py::tuple params = GetParameters(args_abs_list);
// Create class instance.
auto obj = parse::data_converter::CreatePythonObject(class_type, params);
if (py::isinstance<py::none>(obj)) {
@ -2069,24 +2137,24 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
}
AbstractBasePtr ret = ToAbstract(converted_res, AnalysisContext::DummyContext(), out_conf);
auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}
py::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
if (args_spec_list.empty()) {
py::tuple GetParameters(const AbstractBasePtrList &args_abs_list) const {
if (args_abs_list.empty()) {
MS_LOG(EXCEPTION) << "Unexpected arguments num, the min arguments num must be 1, but got 0.";
}
// Exclude class type by minus 1;
std::size_t params_size = args_spec_list.size() - 1;
std::size_t params_size = args_abs_list.size() - 1;
auto params = py::tuple(params_size);
for (size_t i = 0; i < params_size; i++) {
// Only support the Scalar parameters type. Bypass class type by offset with 1.
auto arg = args_spec_list[i + 1];
auto arg = args_abs_list[i + 1];
MS_EXCEPTION_IF_NULL(arg);
if (IsContainUndetermined(arg)) {
MS_EXCEPTION(TypeError) << "The " << i << "th initializing input to create instance for "
<< args_spec_list[0]->BuildValue()->ToString()
<< args_abs_list[0]->BuildValue()->ToString()
<< " should be a constant, but got: " << arg->ToString();
}
// Because the Tensor's AbstractTensor can't get value from GetValueTrack.
@ -2103,13 +2171,13 @@ class CallInstanceEvaluator : public TransitionPrimEvaluator {
CallInstanceEvaluator() : TransitionPrimEvaluator("CallInstanceEvaluator") {}
~CallInstanceEvaluator() override = default;
MS_DECLARE_PARENT(CallInstanceEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
if (args_spec_list.empty()) {
MS_LOG(EXCEPTION) << "args_spec_list should not be empty.";
if (args_abs_list.empty()) {
MS_LOG(EXCEPTION) << "args_abs_list should not be empty.";
}
constexpr size_t cls_index = 0;
auto arg_cls = args_spec_list[cls_index];
auto arg_cls = args_abs_list[cls_index];
MS_EXCEPTION_IF_NULL(arg_cls);
TypePtr type = arg_cls->GetTypeTrack();
MS_EXCEPTION_IF_NULL(type);
@ -2160,18 +2228,18 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
~PyInterpretEvaluator() override = default;
MS_DECLARE_PARENT(PyInterpretEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
if (args_spec_list.empty()) {
MS_LOG(ERROR) << "'args_spec_list' should not be empty";
if (args_abs_list.empty()) {
MS_LOG(EXCEPTION) << "'args_abs_list' should not be empty";
}
auto current_interpret_node = out_conf->node();
MS_EXCEPTION_IF_NULL(current_interpret_node);
MS_LOG(DEBUG) << "The current interpret node: " << current_interpret_node->DebugString();
// Get the type parameter.
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
MS_EXCEPTION_IF_NULL(args_abs_list[0]);
ValuePtr value_track = args_abs_list[0]->GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
auto script_obj = dyn_cast_ptr<parse::Script>(value_track);
@ -2180,8 +2248,17 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
}
// Make global and local parameters.
non_const_err_ = false;
const std::string &script = script_obj->script();
py::tuple params = MakeParameters(args_spec_list, script);
py::tuple params = MakeParameters(args_abs_list, script);
if (non_const_err_) { // Would convert PyInterpret to PyExecute then.
ShapeVector shp;
(void)shp.emplace_back(Shape::kShapeDimAny);
AbstractBasePtr res = std::make_shared<AbstractTensor>(kInt32, std::make_shared<Shape>(shp));
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}
// Call python script string.
MS_LOG(DEBUG) << "Call script: " << script << ", params: " << py::str(params);
@ -2189,7 +2266,7 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
if (py::isinstance<py::none>(obj)) {
AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
auto infer_result = std::make_shared<EvalResult>(res, nullptr);
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}
@ -2206,7 +2283,7 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
}
AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
return infer_result;
}
@ -2223,9 +2300,17 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
const auto &local_abs_val = local_abs->BuildValue();
MS_EXCEPTION_IF_NULL(local_abs_val);
if (local_abs_val == kAnyValue) {
MS_EXCEPTION(ValueError) << "When using JIT Fallback to handle script '" << script
<< "', the inputs should be constant, but found variable '" << name
<< "' to be nonconstant.";
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") == "1");
if (support_fallback_runtime) {
MS_LOG(INFO) << "When using JIT Fallback to handle script '" << script
<< "', the inputs should be constant, but found variable '" << name
<< "' to be nonconstant. To convert to PyExecute() afterwards";
non_const_err_ = true;
} else {
MS_EXCEPTION(ValueError) << "When using JIT Fallback to handle script '" << script
<< "', the inputs should be constant, but found variable '" << name
<< "' to be nonconstant.";
}
}
if (local_abs->isa<abstract::AbstractTensor>()) {
MS_LOG(WARNING) << "When using JIT Fallback to handle script '" << script << "', found variable '" << name
@ -2256,18 +2341,20 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
return;
}
py::tuple MakeParameters(const AbstractBasePtrList &args_spec_list, const std::string &script) const {
py::tuple MakeParameters(const AbstractBasePtrList &args_abs_list, const std::string &script) const {
constexpr int params_size = 3;
if (params_size != args_spec_list.size()) {
if (params_size != args_abs_list.size()) {
MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size
<< ", not equal to arguments.size:" << args_spec_list.size();
<< ", not equal to arguments.size:" << args_abs_list.size();
}
// The first argument is script string, ignore it.
auto params = py::tuple(params_size - 1);
// Make the global parameters.
auto global_dict = dyn_cast<AbstractDictionary>(args_spec_list[1]); // Global parameters dict.
MS_EXCEPTION_IF_NULL(global_dict);
auto global_dict = dyn_cast<AbstractDictionary>(args_abs_list[1]); // Global parameters dict.
if (global_dict == nullptr) {
MS_LOG(EXCEPTION) << "The second argument should be a dictionary, but got " << args_abs_list[1]->ToString();
}
auto filtered_global_dict = FilterParameters(global_dict);
MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString()
<< ", filtered_global_dict: " << filtered_global_dict->ToString();
@ -2282,8 +2369,11 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
// Make the local parameters.
constexpr size_t local_index = 2;
auto local_dict = dyn_cast<AbstractDictionary>(args_spec_list[local_index]); // Local parameters dict.
MS_EXCEPTION_IF_NULL(local_dict);
auto local_dict = dyn_cast<AbstractDictionary>(args_abs_list[local_index]); // Local parameters dict.
if (local_dict == nullptr) {
MS_LOG(EXCEPTION) << "The third argument should be a dictionary, but got "
<< args_abs_list[local_index]->ToString();
}
auto filtered_local_dict = FilterParameters(local_dict);
MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString()
<< ", filtered_local_dict: " << filtered_local_dict->ToString();
@ -2312,11 +2402,11 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
MS_EXCEPTION_IF_NULL(abstract_dict);
std::vector<AbstractAttribute> kv;
std::vector<AbstractElementPair> kv;
const auto &keys_values = abstract_dict->elements();
// Filter out the element of Function type.
(void)std::copy_if(keys_values.cbegin(), keys_values.cend(), std::back_inserter(kv),
[](const AbstractAttribute &item) {
[](const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.second);
return (!item.second->isa<abstract::AbstractFunction>());
});
@ -2327,6 +2417,9 @@ class PyInterpretEvaluator : public TransitionPrimEvaluator {
constexpr char const_arg_attr[] = "const_arg";
return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
}
private:
mutable bool non_const_err_{false};
};
class PartialEvaluator : public Evaluator {
@ -2346,7 +2439,7 @@ class PartialEvaluator : public Evaluator {
MS_EXCEPTION_IF_NULL(arg0_eval_result);
auto arg0_value = arg0_eval_result->abstract();
MS_EXCEPTION_IF_NULL(arg0_value);
AbstractBasePtrList args_spec_list{arg0_value};
AbstractBasePtrList args_abs_list{arg0_value};
// Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
if (arg0_value->isa<AbstractError>()) {
MS_EXCEPTION_IF_NULL(arg0_value->GetValueTrack());
@ -2354,10 +2447,10 @@ class PartialEvaluator : public Evaluator {
MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
<< " as func is: " << arg0_value->ToString();
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
return eval_result;
}
auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
auto func = CheckArg<AbstractFunction>("partial", args_abs_list, 0);
// Sometimes, node[0] in out_conf becomes phi0;
if (func->isa<PrimitiveAbstractClosure>()) {
auto prim_func = dyn_cast_ptr<PrimitiveAbstractClosure>(func);
@ -2369,14 +2462,14 @@ class PartialEvaluator : public Evaluator {
}
}
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
(void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_abs_list),
[](const ConfigPtr &config) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(config);
const auto &eval_result = config->ObtainEvalResult();
MS_EXCEPTION_IF_NULL(eval_result);
return eval_result->abstract();
});
AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
AbstractBasePtrList args(args_abs_list.begin() + 1, args_abs_list.end());
auto cnode = out_conf->node()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
@ -2393,7 +2486,7 @@ class PartialEvaluator : public Evaluator {
auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
return eval_result;
}
@ -2429,7 +2522,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
RaiseEvaluator() : TransitionPrimEvaluator("RaiseEvaluator") {}
~RaiseEvaluator() override = default;
MS_DECLARE_PARENT(RaiseEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
auto node = out_conf->node();
MS_EXCEPTION_IF_NULL(node);
@ -2441,12 +2534,12 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
<< "Please check your conditions which raise node is located at: "
<< trace::GetDebugInfo(node->debug_info());
}
if (args_spec_list.empty()) {
if (args_abs_list.empty()) {
// process raise
MS_LOG(EXCEPTION) << "No active exception to reraise.";
}
std::string exception_type = GetExceptionType(args_spec_list[0]);
std::string exception_type = GetExceptionType(args_abs_list[0]);
auto iter = exception_types_map.find(exception_type);
if (iter == exception_types_map.end()) {
MS_LOG(EXCEPTION) << "Unsupported exception type: " << exception_type
@ -2454,7 +2547,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
<< SupportedExceptionsToString();
}
ExceptionType type = iter->second;
if (args_spec_list.size() == 1) {
if (args_abs_list.size() == 1) {
// Process raise ValueError()
MS_EXCEPTION(type);
}
@ -2470,7 +2563,7 @@ class RaiseEvaluator : public TransitionPrimEvaluator {
}
for (size_t index = index_begin; index < inputs.size(); ++index) {
const auto input = inputs[index];
auto input_abs = args_spec_list[index - 1];
auto input_abs = args_abs_list[index - 1];
MS_EXCEPTION_IF_NULL(input_abs);
bool need_symbol = CheckNeedSymbol(input, input_abs);
if (need_symbol) {
@ -2636,19 +2729,19 @@ class WithEnterEvaluator : public TransitionPrimEvaluator {
WithEnterEvaluator() : TransitionPrimEvaluator("WithEnterEvaluator") {}
~WithEnterEvaluator() override = default;
MS_DECLARE_PARENT(WithEnterEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
auto node = out_conf->node()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto cur_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(cur_graph);
if (args_spec_list.size() != 1) {
if (args_abs_list.size() != 1) {
MS_LOG(EXCEPTION) << "The enter node has wrong input." << node->debug_info();
}
// Check class object
auto partial_abs = args_spec_list[0]->cast<PartialAbstractClosurePtr>();
auto partial_abs = args_abs_list[0]->cast<PartialAbstractClosurePtr>();
MS_EXCEPTION_IF_NULL(partial_abs);
if (!IsCallInstance(partial_abs)) {
MS_LOG(EXCEPTION) << "The enter node has wrong input." << node->debug_info();
@ -2693,19 +2786,19 @@ class WithExitEvaluator : public TransitionPrimEvaluator {
WithExitEvaluator() : TransitionPrimEvaluator("WithExitEvaluator") {}
~WithExitEvaluator() override = default;
MS_DECLARE_PARENT(WithExitEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
auto node = out_conf->node()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto cur_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(cur_graph);
if (args_spec_list.size() != 1) {
if (args_abs_list.size() != 1) {
MS_LOG(EXCEPTION) << "The exit node has wrong input." << node->debug_info();
}
// Check class object
auto partial_abs = args_spec_list[0]->cast<PartialAbstractClosurePtr>();
auto partial_abs = args_abs_list[0]->cast<PartialAbstractClosurePtr>();
MS_EXCEPTION_IF_NULL(partial_abs);
if (!IsCallInstance(partial_abs)) {
MS_LOG(EXCEPTION) << "The exit node has wrong input." << node->debug_info();
@ -2755,13 +2848,13 @@ class JoinedStrEvaluator : public TransitionPrimEvaluator {
JoinedStrEvaluator() : TransitionPrimEvaluator("JoinedStrEvaluator") {}
~JoinedStrEvaluator() override = default;
MS_DECLARE_PARENT(JoinedStrEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override {
auto node = out_conf->node()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(node);
auto cur_graph = node->func_graph();
MS_EXCEPTION_IF_NULL(cur_graph);
bool exist_tensor = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &arg) {
bool exist_tensor = std::any_of(args_abs_list.begin(), args_abs_list.end(), [](const AbstractBasePtr &arg) {
auto arg_value = arg->BuildValue();
MS_EXCEPTION_IF_NULL(arg_value);
return arg_value->isa<AnyValue>();
@ -2777,7 +2870,7 @@ class JoinedStrEvaluator : public TransitionPrimEvaluator {
new_node = cur_graph->NewCNode(new_inputs);
} else {
std::string ret;
for (const auto &arg : args_spec_list) {
for (const auto &arg : args_abs_list) {
auto arg_value = arg->BuildValue();
MS_EXCEPTION_IF_NULL(arg_value);
ret += arg_value->ToString();

View File

@ -178,6 +178,15 @@ class MakeListEvaluator : public TransitionPrimEvaluator {
const AnfNodeConfigPtr &out_conf) override;
};
class PyExecuteEvaluator : public TransitionPrimEvaluator {
public:
PyExecuteEvaluator() : TransitionPrimEvaluator("PyExecuteEvaluator") {}
~PyExecuteEvaluator() override = default;
MS_DECLARE_PARENT(PyExecuteEvaluator, TransitionPrimEvaluator);
EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
const AnfNodeConfigPtr &out_conf) override;
};
bool IsInWhiteList(const PrimitivePtr &primitive);
PrimEvaluatorMap &GetPrimEvaluatorConstructors();

View File

@ -534,6 +534,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
if (prim->Hash() == prim::kPrimMixedPrecisionCast->Hash() && prim->name() == prim::kPrimMixedPrecisionCast->name()) {
return std::make_shared<MixedPrecisionCastEvaluator>(prim);
}
if (prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name()) {
prim::kPrimPyExecute->AddAttr("primitive_target", MakeValue("CPU"));
return std::make_shared<PyExecuteEvaluator>();
}
// Find prim infer function in the prim function map return a standard evaluator
auto eval_impl = GetPrimitiveInferImpl(prim);

View File

@ -58,6 +58,9 @@ void ValidateOperation(const AnfNodePtr &node) {
if (prim->HasAttr("is_load")) {
return;
}
if (prim->name() == "PyExecute") {
return;
}
if (prim->name() == "TensorMove") {
return;
}
@ -127,7 +130,8 @@ void ValidateAbstract(const AnfNodePtr &node) {
abstract->isa<AbstractTuple>() || abstract->isa<AbstractList>() ||
abstract->isa<AbstractTensor>() || abstract->isa<AbstractRowTensor>() ||
abstract->isa<AbstractRefTensor>() || abstract->isa<AbstractMapTensor>() ||
abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>();
abstract->isa<abstract::AbstractNone>() || abstract->isa<abstract::AbstractMonad>() ||
abstract->isa<abstract::AbstractScript>();
if (is_legal_abstract) {
return;
}

View File

@ -8,6 +8,7 @@ if(ENABLE_CPU)
"eigen/*.cc"
"mkldnn/*.cc"
"ps/*.cc"
"pyexecute/*.cc"
"pyfunc/*.cc"
"rl/*.cc"
"custom/*.cc"

View File

@ -0,0 +1,379 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include <memory>
#include <vector>
#include <utility>
#include "Eigen/Core"
#include "abstract/utils.h"
#include "plugin/device/cpu/hal/device/cpu_common.h"
#include "include/common/utils/python_adapter.h"
#include "plugin/factory/ms_factory.h"
#include "mindspore/ccsrc/pipeline/jit/parse/resolve.h"
namespace mindspore {
namespace kernel {
namespace {
py::object CallPythonGetGlobalParams() {
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
py::module mod = python_adapter::GetPyModule(python_mod_parse);
constexpr auto python_get_dict = "get_global_params";
return python_adapter::CallPyModFn(mod, python_get_dict);
}
// Call the python script string. The same codes as parse/data_converter.h, we must copy it here.
py::object CallPythonScript(const py::object &script, const py::tuple &args_kwargs) {
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
py::module mod = python_adapter::GetPyModule(python_mod_parse);
constexpr auto python_mode_eval = "eval_script";
// The `args_kwargs` is a tuple(dict(global), dict(local)).
return args_kwargs.empty() ? python_adapter::CallPyModFn(mod, python_mode_eval, script)
: python_adapter::CallPyModFn(mod, python_mode_eval, script, args_kwargs);
}
} // namespace
void PyExecuteCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(DEBUG) << "kernel_node: " << kernel_node << ", " << kernel_node->DebugString();
inputs_info_.clear();
kernel_node_ = kernel_node;
for (size_t i = 1; i < kernel_node->size(); ++i) {
const auto &input = kernel_node->inputs()[i];
// Check if PyExecuteOutputData exists.
py::object obj = py::none();
if (input->has_user_data<PyExecuteOutputData>()) {
py::gil_scoped_acquire gil_acquire;
const auto &output_data = input->user_data<PyExecuteOutputData>();
obj = output_data->obj;
MS_LOG(DEBUG) << "Has \'PyExecuteOutputData\', obj: " << obj;
}
// Record the inputs' information by their abstract types.
const auto &input_abstract = input->abstract();
MS_EXCEPTION_IF_NULL(input_abstract);
if (input_abstract->isa<abstract::AbstractRefTensor>()) {
const auto &param = dyn_cast<Parameter>(input);
MS_EXCEPTION_IF_NULL(param);
MS_LOG(DEBUG) << "AbstractRefTensor, input[" << i << "]: " << param->default_param()->ToString();
(void)inputs_info_.emplace_back(PyExecuteInputInfo({obj, input_abstract, kTypeUnknown, {}}));
} else if (input_abstract->isa<abstract::AbstractTensor>()) {
const auto &tensor_abstract = dyn_cast<abstract::AbstractTensor>(input_abstract);
MS_EXCEPTION_IF_NULL(tensor_abstract);
MS_LOG(DEBUG) << "AbstractTensor, input[" << i << "]: " << tensor_abstract->BuildType()->ToString() << ", "
<< tensor_abstract->BuildShape()->ToString();
const auto &in_type = AnfAlgo::GetInputDeviceDataType(kernel_node, i - 1);
const auto &in_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i - 1);
(void)inputs_info_.emplace_back(PyExecuteInputInfo({obj, input_abstract, in_type, in_shape}));
} else {
MS_LOG(DEBUG) << "Other, input[" << i << "]: " << input->DebugString() << ", " << input_abstract->ToString();
(void)inputs_info_.emplace_back(PyExecuteInputInfo({obj, input_abstract, kTypeUnknown, {}}));
}
MS_LOG(DEBUG) << "Kernel node's input[" << i << "]: " << input->DebugString() << ", " << input_abstract->ToString();
}
}
void ArrayToRawMemory(const py::array &array, const AddressPtr &address) {
MS_EXCEPTION_IF_NULL(address);
if (static_cast<unsigned int>(array.flags()) &
static_cast<unsigned int>(pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_)) {
const py::buffer_info &buf_info = array.request();
const auto &res =
memcpy_s(address->addr, address->size, buf_info.ptr, LongToSize(buf_info.size * buf_info.itemsize));
if (res != EOK) {
MS_LOG(EXCEPTION) << "memcpy failed. res: " << res << ", address->size: " << address->size
<< ", size: " << LongToSize(buf_info.size * buf_info.itemsize);
}
} else {
// Transform numpy array to contiguous data.
Py_buffer pybuf;
if (PyObject_GetBuffer(array.ptr(), &pybuf, PyBUF_ANY_CONTIGUOUS) != 0) {
MS_LOG(EXCEPTION) << "Failed to get buffer from the input!";
}
auto buffer = std::make_unique<char[]>(LongToSize(pybuf.len));
if (PyBuffer_ToContiguous(buffer.get(), &pybuf, pybuf.len, 'C')) {
PyBuffer_Release(&pybuf);
MS_LOG(EXCEPTION) << "Can't copy numpy.ndarray to a contiguous buffer.";
}
PyBuffer_Release(&pybuf);
const auto &res = memcpy_s(address->addr, address->size, buffer.get(), LongToSize(pybuf.len));
if (res != EOK) {
MS_LOG(EXCEPTION) << "memcpy failed. res: " << res;
}
}
}
void PyExecuteCpuKernelMod::AttachPyOutputData(const py::object &py_res) {
const auto &py_output = std::make_shared<PyExecuteOutputData>();
py_output->obj = py_res;
// Set Python data for kernel node.
kernel_node_->set_user_data<PyExecuteOutputData>(py_output);
// Set Python data for front node.
const auto &kernel_graph = std::dynamic_pointer_cast<session::KernelGraph>(kernel_node_->func_graph());
MS_EXCEPTION_IF_NULL(kernel_graph);
const auto &graph_output_map = kernel_graph->graph_output_map();
session::AnfWithOutIndex anf_index = std::make_pair(kernel_node_, 0);
const auto &iter = graph_output_map.find(anf_index);
if (iter != graph_output_map.cend()) {
const auto &front_node = iter->second.first;
MS_LOG(INFO) << "Found front output for " << kernel_node_ << ", " << kernel_node_->DebugString();
front_node->set_user_data<PyExecuteOutputData>(py_output);
} else {
MS_LOG(DEBUG) << "Not found, kernel node is not output, " << kernel_node_ << ", " << kernel_node_->DebugString();
if (!IS_OUTPUT_ON(mindspore::kDebug)) {
return;
}
for (const auto &output_pair : graph_output_map) {
MS_EXCEPTION_IF_NULL(output_pair.first.first);
MS_EXCEPTION_IF_NULL(output_pair.second.first);
MS_LOG(DEBUG) << "backend node: " << output_pair.first.first << ", " << output_pair.first.first->DebugString()
<< ", front node: " << output_pair.second.first << ", " << output_pair.second.first->DebugString();
}
}
}
// Notice: Remove here after BE kernel supports tuple input.
py::object PyExecuteCpuKernelMod::BuildLocalTupleParameters(const std::vector<AddressPtr> &inputs) {
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
constexpr auto number_two = 2;
std::string tuple_key_str;
py::tuple local_tuple_inputs(inputs_info_.size() - number_two); // Exclude the script and key.
MS_LOG(DEBUG) << "Local parameter tuple size: " << (inputs_info_.size() - number_two);
bool tuple_input_start = false;
size_t tuple_index = 0;
py::dict local_tuple_dict;
for (size_t i = 1; i < inputs.size() && i < inputs_info_.size(); ++i) {
const auto &input = inputs[i];
MS_EXCEPTION_IF_NULL(input);
const auto &input_info = inputs_info_[i];
const auto &input_abstract = input_info.abstract;
MS_EXCEPTION_IF_NULL(input_abstract);
const auto &input_type = input_abstract->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
if (!tuple_input_start && input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
const auto &value = input_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(value);
const auto &str_value = dyn_cast<StringImm>(value);
MS_EXCEPTION_IF_NULL(str_value);
const auto &str = str_value->value();
if (str != internal_tuple_keys_str && str != internal_tuple_values_str) {
return py::none();
}
tuple_key_str = str;
tuple_input_start = true;
MS_LOG(DEBUG) << "String, key input[" << i << "]: " << input_abstract->ToString();
continue;
}
// Rebuild the tuple with all left inputs.
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
const auto &value = input_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(value);
const auto &str_value = dyn_cast<StringImm>(value);
MS_EXCEPTION_IF_NULL(str_value);
const auto &str = str_value->value();
local_tuple_inputs[tuple_index++] = py::str(str);
MS_LOG(DEBUG) << "String, value input[" << i << "]: " << input_abstract->ToString();
} else if (input_abstract->isa<abstract::AbstractTensor>()) {
const auto &py_array_value = input_info.py_obj_output;
bool is_py_middle_data = !py::isinstance<py::none>(py_array_value);
MS_LOG(DEBUG) << "Tensor, value input[" << i << "]: " << input_abstract->ToString()
<< ", type: " << input_info.type << ", shape: " << input_info.shape << ", addr: " << inputs[i]->addr
<< ", size: " << inputs[i]->size << ", py_array_value: " << py_array_value
<< ", is_py_middle_data: " << is_py_middle_data;
if (!is_py_middle_data) {
const auto tensor =
std::make_shared<tensor::Tensor>(input_info.type, input_info.shape, inputs[i]->addr, inputs[i]->size);
MS_EXCEPTION_IF_NULL(tensor);
local_tuple_inputs[tuple_index++] = tensor;
} else {
local_tuple_inputs[tuple_index++] = py_array_value;
}
} else {
MS_LOG(ERROR) << "Unsupported value type.";
}
}
local_tuple_dict[py::str(tuple_key_str)] = local_tuple_inputs;
return local_tuple_dict;
}
py::object PyExecuteCpuKernelMod::BuildLocalParameters(const std::vector<AddressPtr> &inputs) {
const auto &local_tuple_params = BuildLocalTupleParameters(inputs);
if (local_tuple_params != py::none()) {
return local_tuple_params;
}
MS_LOG(DEBUG) << "Build normal local parameters.";
// Build local parameters dict.
std::vector<std::string> keys;
std::vector<tensor::TensorPtr> tensor_values;
std::vector<py::object> py_object_values;
std::vector<bool> py_array_flags;
constexpr auto number_two = 2;
size_t pair_size = (inputs_info_.size() - 1) / number_two;
// Handle the keys.
size_t i = 1;
for (; i < inputs.size() && i < pair_size + 1; ++i) {
const auto &input = inputs[i];
MS_EXCEPTION_IF_NULL(input);
const auto &input_info = inputs_info_[i];
const auto &input_abstract = input_info.abstract;
MS_EXCEPTION_IF_NULL(input_abstract);
const auto &input_type = input_abstract->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
const auto &value = input_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(value);
const auto &str_value = dyn_cast<StringImm>(value);
MS_EXCEPTION_IF_NULL(str_value);
const auto &str = str_value->value();
(void)keys.emplace_back(str);
MS_LOG(DEBUG) << "String, input[" << i << "]: " << input_abstract->ToString();
} else {
MS_LOG(EXCEPTION) << "Other, input[" << i << "]: " << input_abstract->ToString();
}
}
// Handle the values.
for (; i < inputs.size() && i < inputs_info_.size(); ++i) {
const auto &input = inputs[i];
MS_EXCEPTION_IF_NULL(input);
const auto &input_info = inputs_info_[i];
const auto &input_abstract = input_info.abstract;
MS_EXCEPTION_IF_NULL(input_abstract);
const auto &input_type = input_abstract->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
const auto &value = input_abstract->BuildValue();
MS_EXCEPTION_IF_NULL(value);
const auto &str_value = dyn_cast<StringImm>(value);
MS_EXCEPTION_IF_NULL(str_value);
const auto &str = str_value->value();
(void)py_object_values.emplace_back(py::str(str));
(void)tensor_values.emplace_back(nullptr);
(void)py_array_flags.emplace_back(true);
MS_LOG(DEBUG) << "String, input[" << i << "]: " << input_abstract->ToString();
} else if (input_abstract->isa<abstract::AbstractTensor>()) {
const auto &py_array_value = input_info.py_obj_output;
bool is_py_middle_data = !py::isinstance<py::none>(py_array_value);
MS_LOG(DEBUG) << "Tensor, input[" << i << "]: " << input_abstract->ToString() << ", type: " << input_info.type
<< ", shape: " << input_info.shape << ", addr: " << inputs[i]->addr << ", size: " << inputs[i]->size
<< ", py_array_value: " << py_array_value << ", is_py_middle_data: " << is_py_middle_data;
tensor::TensorPtr tensor = nullptr;
if (!is_py_middle_data) {
tensor = std::make_shared<tensor::Tensor>(input_info.type, input_info.shape, inputs[i]->addr, inputs[i]->size);
MS_EXCEPTION_IF_NULL(tensor);
}
(void)py_object_values.emplace_back(py_array_value);
(void)tensor_values.emplace_back(tensor);
(void)py_array_flags.emplace_back(is_py_middle_data);
} else if (input_abstract->isa<abstract::AbstractRefTensor>()) {
MS_LOG(DEBUG) << "Parameter, input[" << i << "]: " << input_abstract->ToString();
} else {
MS_LOG(DEBUG) << "Other, input[" << i << "]: " << input_abstract->ToString();
}
}
if (keys.size() != tensor_values.size() || keys.size() != pair_size) {
MS_LOG(EXCEPTION) << "The local dict input is invalid, " << keys.size() << ", " << tensor_values.size() << ", "
<< inputs_info_.size();
}
// To call the script with global and local parameters.
py::dict local_dict;
for (i = 0; i < keys.size(); ++i) {
if (py_array_flags[i]) {
local_dict[py::str(keys[i])] = py_object_values[i];
} else {
local_dict[py::str(keys[i])] = tensor_values[i];
}
}
MS_LOG(DEBUG) << "local_dict: " << local_dict;
return local_dict;
}
bool PyExecuteCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
MS_LOG(DEBUG) << "Launch PyExecute(), inputs.size: " << inputs.size() << ", outputs: " << outputs.size();
if (Py_IsInitialized() != true) {
MS_LOG(ERROR) << "Py_IsInitialized failed.";
return false;
}
if (outputs.size() != 1) {
MS_LOG(EXCEPTION) << "The output num is 1, but got " << outputs.size();
}
// Build the script.
py::gil_scoped_acquire gil_acquire;
const auto &input0_info = inputs_info_[0];
const auto &input0_abstract = input0_info.abstract;
const auto &input0_abstract_scalar = dyn_cast<abstract::AbstractScalar>(input0_abstract);
MS_EXCEPTION_IF_NULL(input0_abstract_scalar);
if (!input0_abstract_scalar->BuildType()->isa<String>()) {
MS_LOG(EXCEPTION) << "Should be a string, but got " << input0_abstract_scalar->ToString();
}
const auto &input0_value = input0_abstract_scalar->BuildValue();
MS_EXCEPTION_IF_NULL(input0_value);
const auto &input0_str = dyn_cast<StringImm>(input0_value);
MS_LOG(DEBUG) << "Script: " << input0_str->ToString();
const std::string &script = input0_str->value();
// Build local parameters dict.
const auto &local_dict = BuildLocalParameters(inputs);
// To call the script with global and local parameters.
const auto &global_dict = CallPythonGetGlobalParams();
const auto &py_script = py::str(script);
auto params = py::tuple(2);
params[0] = global_dict;
params[1] = local_dict;
MS_LOG(DEBUG) << "py_script: " << py_script << ", params: " << params;
const auto &py_res = CallPythonScript(py_script, params);
// Check Python result.
if (py::isinstance<py::none>(py_res)) {
MS_LOG(EXCEPTION) << "Real output is None.";
} else if (py::isinstance<py::array>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::array, py_res: " << py_res;
ArrayToRawMemory(py_res.cast<py::array>(), outputs[0]);
} else if (py::isinstance<py::float_>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::float_, py_res: " << py_res;
} else if (py::isinstance<py::int_>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::int_, py_res: " << py_res;
} else if (py::isinstance<py::bool_>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::bool_, py_res: " << py_res;
} else if (py::isinstance<py::str>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::str, py_res: " << py_res;
} else if (py::isinstance<py::tuple>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::tuple, py_res: " << py_res;
} else if (py::isinstance<py::list>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::list, py_res: " << py_res;
} else if (py::isinstance<py::dict>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::dict, py_res: " << py_res;
} else if (py::isinstance<py::set>(py_res)) {
MS_LOG(DEBUG) << "Real output is py::set, py_res: " << py_res;
} else {
MS_LOG(EXCEPTION) << "The output is invalid, py_res: " << py_res;
}
AttachPyOutputData(py_res);
return true;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, PyExecute, PyExecuteCpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,63 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYEXECUTE_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYEXECUTE_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include <Python.h>
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "plugin/device/cpu/kernel/cpu_kernel.h"
namespace py = pybind11;
namespace mindspore {
namespace kernel {
struct PyExecuteInputInfo {
py::object py_obj_output;
abstract::AbstractBasePtr abstract;
TypeId type;
std::vector<int64_t> shape;
};
struct PyExecuteOutputData {
py::object obj;
constexpr static char key[] = "PyExecuteOutputData";
};
class PyExecuteCpuKernelMod : public DeprecatedNativeCpuKernelMod {
public:
PyExecuteCpuKernelMod() : kernel_node_(nullptr) {}
~PyExecuteCpuKernelMod() = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) override;
private:
void AttachPyOutputData(const py::object &py_res);
py::object BuildLocalParameters(const std::vector<AddressPtr> &inputs);
py::object BuildLocalTupleParameters(const std::vector<AddressPtr> &inputs);
CNodePtr kernel_node_{nullptr};
std::vector<PyExecuteInputInfo> inputs_info_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_PYEXECUTE_KERNEL_H_

View File

@ -0,0 +1,115 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind_api/pybind_patch.h"
#include "mindspore/core/ops/py_execute.h"
#include "mindspore/ccsrc/include/common/utils/python_adapter.h"
#include "mindspore/ccsrc/pipeline/jit/parse/data_converter.h"
#include "mindspore/ccsrc/pybind_api/ir/tensor_py.h"
namespace py = pybind11;
namespace mindspore {
namespace {
py::object CallPythonGetGlobalParams() {
constexpr auto python_mod_parse = "mindspore._extends.parse"; // The same as PYTHON_MOD_PARSE_MODULE[]
py::module mod = python_adapter::GetPyModule(python_mod_parse);
constexpr auto python_get_dict = "get_global_params";
return python_adapter::CallPyModFn(mod, python_get_dict);
}
} // namespace
class PyExecuteInitializer {
public:
PyExecuteInitializer() { mindspore::ops::PyExecuteInfer::set_infer_handler(InferPy); }
~PyExecuteInitializer() = default;
private:
// TODO(zh_qh): Will check the abstract shape and type later.
static void InferPy(const std::vector<AbstractBasePtr> &input_args) {
const auto &script_abs = input_args[0];
const auto &script = script_abs->BuildValue();
const auto &script_str = dyn_cast<StringImm>(script);
const auto &keys_tuple_abs = input_args[1];
const auto &keys_tuple = keys_tuple_abs->BuildValue();
const auto &keys = dyn_cast<ValueSequence>(keys_tuple);
const auto &values_tuple_abs = input_args[2];
const auto &values_tuple = values_tuple_abs->BuildValue();
if (values_tuple == kAnyValue) {
MS_LOG(EXCEPTION) << "Value tuple should not be anyvalue.";
}
const auto &values = dyn_cast<ValueSequence>(values_tuple);
MS_LOG(DEBUG) << "script: " << script->ToString() << ", keys_tuple: " << keys_tuple->ToString()
<< ", values_tuple: " << values_tuple->ToString();
py::gil_scoped_acquire gil_acquire;
py::dict local_dict;
for (size_t i = 0; i < keys->size(); ++i) {
const auto &key = (*keys)[i];
const auto &key_str = dyn_cast<StringImm>(key);
MS_EXCEPTION_IF_NULL(key_str);
const auto &value = (*values)[i];
const auto &tuple_abs = values_tuple_abs->cast<abstract::AbstractSequencePtr>();
const auto &value_abs = (*tuple_abs)[i];
if (value->isa<tensor::Tensor>()) {
const auto &tensor = value->cast<tensor::TensorPtr>();
const auto &py_array_value = python_adapter::PyAdapterCallback::TensorToNumpy(*tensor);
local_dict[py::str(key_str->value())] = py_array_value;
continue;
}
local_dict[py::str(key_str->value())] = value;
}
const auto &global_dict = CallPythonGetGlobalParams();
const auto &py_script = py::str(script_str->value());
auto params = py::tuple(2);
params[0] = global_dict;
params[1] = local_dict;
MS_LOG(DEBUG) << "py_script: " << py_script << ", params: " << params;
const auto &py_res = parse::data_converter::CallPythonScript(py_script, params);
MS_LOG(DEBUG) << "py_res: " << py_res;
if (py::isinstance<py::none>(py_res)) {
MS_LOG(EXCEPTION) << "py_res is none";
} else if (py::isinstance<py::array>(py_res)) {
MS_LOG(DEBUG) << "is py::array, py_res: " << py_res;
const auto &res_tensor = tensor::TensorPy::MakeTensorOfNumpy(py_res);
MS_LOG(DEBUG) << "res_tensor: " << res_tensor->ToString();
} else if (py::isinstance<py::float_>(py_res)) {
MS_LOG(DEBUG) << "is py::float_, py_res: " << py_res;
} else if (py::isinstance<py::int_>(py_res)) {
MS_LOG(DEBUG) << "is py::int_, py_res: " << py_res;
} else if (py::isinstance<py::bool_>(py_res)) {
MS_LOG(DEBUG) << "is py::bool_, py_res: " << py_res;
} else if (py::isinstance<py::str>(py_res)) {
MS_LOG(DEBUG) << "is py::str, py_res: " << py_res;
} else if (py::isinstance<py::tuple>(py_res)) {
MS_LOG(DEBUG) << "is py::tuple, py_res: " << py_res;
} else if (py::isinstance<py::list>(py_res)) {
MS_LOG(DEBUG) << "is py::list, py_res: " << py_res;
} else if (py::isinstance<py::dict>(py_res)) {
MS_LOG(DEBUG) << "is py::dict, py_res: " << py_res;
} else if (py::isinstance<py::set>(py_res)) {
MS_LOG(DEBUG) << "is py::set, py_res: " << py_res;
} else {
MS_LOG(EXCEPTION) << "py_res is invalid, py_res: " << py_res;
}
}
};
static PyExecuteInitializer py_execute_initializer;
} // namespace mindspore

View File

@ -886,8 +886,9 @@ AnfNodePtr AnfAlgo::GetInputNode(const CNodePtr &node, size_t index) {
MS_EXCEPTION_IF_NULL(node);
auto get_input_index = index + 1;
if (get_input_index >= node->inputs().size()) {
MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
<< node->inputs().size() << "." << trace::DumpSourceLines(node);
MS_LOG(EXCEPTION) << "Input index size " << get_input_index << ", but the node input size just "
<< node->inputs().size() << ". node: " << node->DebugString() << "."
<< trace::DumpSourceLines(node);
}
// input 0 is primitive node
return node->input(get_input_index);

View File

@ -1350,9 +1350,9 @@ bool AbstractDictionary::operator==(const AbstractBase &other) const {
}
AbstractBasePtr AbstractDictionary::Clone() const {
std::vector<AbstractAttribute> kv;
std::vector<AbstractElementPair> kv;
(void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
[](const AbstractAttribute &item) {
[](const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.first);
MS_EXCEPTION_IF_NULL(item.second);
return std::make_pair(item.first->Clone(), item.second->Clone());
@ -1361,9 +1361,9 @@ AbstractBasePtr AbstractDictionary::Clone() const {
}
AbstractBasePtr AbstractDictionary::Broaden() const {
std::vector<AbstractAttribute> kv;
std::vector<AbstractElementPair> kv;
(void)std::transform(key_values_.cbegin(), key_values_.cend(), std::back_inserter(kv),
[](const AbstractAttribute &item) {
[](const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.second);
return std::make_pair(item.first, item.second->Broaden());
});
@ -1384,7 +1384,7 @@ std::string AbstractDictionary::ToString() const {
std::size_t AbstractDictionary::hash() const {
std::size_t hash_sum = std::accumulate(key_values_.cbegin(), key_values_.cend(), tid(),
[](std::size_t hash_sum, const AbstractAttribute &item) {
[](std::size_t hash_sum, const AbstractElementPair &item) {
MS_EXCEPTION_IF_NULL(item.first);
MS_EXCEPTION_IF_NULL(item.second);
hash_sum = hash_combine(hash_sum, item.first->hash());

View File

@ -1023,8 +1023,8 @@ class MS_CORE_API AbstractDictionary final : public AbstractBase {
public:
/// \brief Constructor of AbstractDictionary.
///
/// \param[in] key_values The vector of AbstractAttribute.
explicit AbstractDictionary(const std::vector<AbstractAttribute> &key_values) : key_values_(key_values) {}
/// \param[in] key_values The vector of AbstractElementPair.
explicit AbstractDictionary(const std::vector<AbstractElementPair> &key_values) : key_values_(key_values) {}
/// \brief Destructor of AbstractDictionary.
~AbstractDictionary() override = default;
@ -1056,12 +1056,12 @@ class MS_CORE_API AbstractDictionary final : public AbstractBase {
/// \brief Get the key values.
///
/// \return A vector of AbstractAttribute.
const std::vector<AbstractAttribute> &elements() const { return key_values_; }
/// \return A vector of AbstractElementPair.
const std::vector<AbstractElementPair> &elements() const { return key_values_; }
protected:
ValuePtr RealBuildValue() const override;
std::vector<AbstractAttribute> key_values_;
std::vector<AbstractElementPair> key_values_;
};
using AbstractDictionaryPtr = std::shared_ptr<AbstractDictionary>;

View File

@ -196,8 +196,8 @@ bool IsInDict(const PrimitivePtr &primitive, const AbstractBasePtrList &args_spe
ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) {
std::vector<AbstractElementPair> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
return *key_value == *item.first->BuildValue();
});
return it != dict_elems.end();

View File

@ -68,7 +68,7 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
const auto &key = key_list[index];
CheckDictKey(key, op_name);
}
std::vector<AbstractAttribute> key_value;
std::vector<AbstractElementPair> key_value;
AbstractBasePtrList value_list = values->elements();
for (size_t index = 0; index < keys_size; index++) {
(void)key_value.emplace_back(key_list[index], value_list[index]);
@ -259,8 +259,8 @@ AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) {
std::vector<AbstractElementPair> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
return *key_value == *item.first->BuildValue();
});
if (it == dict_elems.end()) {
@ -284,8 +284,8 @@ AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitiveP
ValuePtr key_value = key->BuildValue();
MS_EXCEPTION_IF_NULL(key_value);
std::vector<AbstractAttribute> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractAttribute &item) {
std::vector<AbstractElementPair> dict_elems = dict->elements();
auto it = std::find_if(dict_elems.cbegin(), dict_elems.cend(), [&key_value](const AbstractElementPair &item) {
return *key_value == *item.first->BuildValue();
});
@ -307,10 +307,10 @@ AbstractBasePtr InferImplDictGetKeys(const AnalysisEnginePtr &, const PrimitiveP
constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
std::vector<AbstractElementPair> dict_elems = dict->elements();
AbstractBasePtrList keys;
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(keys),
[](const AbstractAttribute &item) { return item.first; });
[](const AbstractElementPair &item) { return item.first; });
return std::make_shared<AbstractTuple>(keys);
}
@ -321,10 +321,10 @@ AbstractBasePtr InferImplDictGetValues(const AnalysisEnginePtr &, const Primitiv
constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
std::vector<AbstractElementPair> dict_elems = dict->elements();
AbstractBasePtrList values;
std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(values),
[](const AbstractAttribute &item) { return item.second; });
[](const AbstractElementPair &item) { return item.second; });
return std::make_shared<AbstractTuple>(values);
}
@ -335,10 +335,10 @@ AbstractBasePtr InferImplDictItems(const AnalysisEnginePtr &, const PrimitivePtr
constexpr int args_spec_size = 1;
CheckArgsSize(op_name, args_spec_list, args_spec_size);
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
std::vector<AbstractAttribute> dict_elems = dict->elements();
std::vector<AbstractElementPair> dict_elems = dict->elements();
AbstractBasePtrList items;
(void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(items),
[](const AbstractAttribute &item) {
[](const AbstractElementPair &item) {
return std::make_shared<AbstractTuple>(AbstractBasePtrList{item.first, item.second});
});
return std::make_shared<AbstractList>(items);

View File

@ -217,7 +217,15 @@ PrimShapeDependMap &GetHostDependsMap() {
return host_depends;
}
std::set<int64_t> GetValueDependArgIndices(const std::string &prim_name, size_t input_num) {
std::set<int64_t> GetValueDependArgIndices(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Invalid inputs";
}
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto device = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
@ -238,25 +246,27 @@ std::set<int64_t> GetValueDependArgIndices(const std::string &prim_name, size_t
ori = op_infer->GetValueDependArgIndices();
}
if (!ori.empty()) {
(void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()),
[&](int64_t idx) { return idx < SizeToLong(input_num); });
if (ori.empty()) {
return res;
}
// To support {-1}, filter all the real tensor input index here.
constexpr auto all_tensor_inputs = -1;
if (ori.size() == 1 && *(ori.cbegin()) == all_tensor_inputs) {
for (size_t i = 1; i < cnode->size(); ++i) {
const auto &input = cnode->inputs()[i];
const auto &input_abstract = input->abstract();
if (input_abstract != nullptr && input_abstract->isa<abstract::AbstractTensor>()) {
(void)res.emplace(SizeToLong(i - 1));
}
}
return res;
}
size_t input_num = cnode->inputs().size() - 1;
(void)std::copy_if(ori.begin(), ori.end(), std::inserter(res, res.begin()),
[&](int64_t idx) { return idx < SizeToLong(input_num); });
return res;
}
std::set<int64_t> GetValueDependArgIndices(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().empty()) {
MS_LOG(EXCEPTION) << "Invalid inputs";
}
auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->ToString();
return GetValueDependArgIndices(prim_name, cnode->inputs().size() - 1);
}
void RegisterHostDependsImpl(const std::string &prim_name, const std::set<int64_t> &host_depends) {
auto &host_depends_map = GetHostDependsMap();
host_depends_map[prim_name] = host_depends;

View File

@ -80,8 +80,6 @@ MS_CORE_API PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap();
MS_CORE_API StandardPrimitiveImplReg GetPrimitiveInferImpl(const PrimitivePtr &primitive);
MS_CORE_API std::set<int64_t> GetValueDependArgIndices(const std::string &prim_name, size_t input_num);
MS_CORE_API std::set<int64_t> GetValueDependArgIndices(const CNodePtr &cnode);
MS_CORE_API void RegisterStandardPrimitiveImpl(const PrimitivePtr &primitive, const StandardPrimitiveImplReg &impl_reg);

View File

@ -231,7 +231,7 @@ using FuncGraphWeakPtr = std::weak_ptr<FuncGraph>;
namespace abstract {
class AbstractBase;
using AbstractBasePtr = std::shared_ptr<AbstractBase>;
using AbstractAttribute = std::pair<AbstractBasePtr, AbstractBasePtr>;
using AbstractElementPair = std::pair<AbstractBasePtr, AbstractBasePtr>;
class AnalysisContext;
using AnalysisContextPtr = std::shared_ptr<AnalysisContext>;
} // namespace abstract

View File

@ -183,7 +183,7 @@ template <typename T>
std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, size_t data_len) {
size_t size = SizeOf(shape);
if (size * sizeof(T) != data_len) {
MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
<< " item size " << sizeof(T);
}
auto buf = static_cast<T *>(data);

View File

@ -1447,6 +1447,7 @@ GVAR_DEF(PrimitivePtr, kPrimNPUGetFloatStatus, std::make_shared<Primitive>("NPUG
GVAR_DEF(PrimitivePtr, kPrimNPUAllocFloatStatus, std::make_shared<Primitive>("NPUAllocFloatStatus"));
GVAR_DEF(PrimitivePtr, kPrimNPUClearFloatStatus, std::make_shared<Primitive>("NPUClearFloatStatus"));
GVAR_DEF(PrimitivePtr, kPrimPyFunc, std::make_shared<Primitive>("PyFunc"));
GVAR_DEF(PrimitivePtr, kPrimPyExecute, std::make_shared<Primitive>("PyExecute"));
GVAR_DEF(PrimitivePtr, kPrimDynamicLossScale, std::make_shared<Primitive>("_DynamicLossScale"));
GVAR_DEF(PrimitivePtr, kPrimScaleGrad, std::make_shared<Primitive>("ScaleGrad"));
GVAR_DEF(PrimitivePtr, kPrimPopulationCount, std::make_shared<Primitive>("PopulationCount"));

View File

@ -0,0 +1,63 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/py_execute.h"
#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <vector>
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(PyExecute, BaseOperator);
BaseShapePtr PyExecuteInfer::InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
ShapeVector out_shape = {1};
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr PyExecuteInfer::InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const {
return kFloat64;
}
AbstractBasePtr PyExecuteInfer::InferShapeAndType(const abstract::AnalysisEnginePtr &engine,
const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
MS_LOG(DEBUG) << "item: " << item->ToString();
}
if (infer_handler_ == nullptr) {
MS_LOG(EXCEPTION) << "infer_handler_ should not be null.";
}
infer_handler_(input_args);
const auto &type = InferType(primitive, input_args);
const auto &shape = InferShape(primitive, input_args);
const auto &abstract = MakeAbstract(shape, type);
return abstract;
}
std::set<int64_t> PyExecuteInfer::GetValueDependArgIndices() const { return {-1}; }
REGISTER_PRIMITIVE_OP_INFER_IMPL(PyExecute, prim::kPrimPyExecute, PyExecuteInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_PY_EXECUTE_H_
#define MINDSPORE_CORE_OPS_PY_EXECUTE_H_
#include <vector>
#include <memory>
#include <set>
#include "ops/base_operator.h"
#include "ops/op_utils.h"
#include "mindapi/base/types.h"
#include "mindapi/src/helper.h"
#include "utils/check_convert_utils.h"
#include "utils/ms_context.h"
namespace mindspore {
namespace ops {
constexpr auto kNamePyExecute = "PyExecute";
/// \brief Implement for JIT Fallback.
/// Refer to Python API @ref mindspore.ops.PyExecute for more details.
class MIND_API PyExecute : public BaseOperator {
public:
MIND_API_BASE_MEMBER(PyExecute);
/// \brief Constructor.
PyExecute() : BaseOperator(kNamePyExecute) { InitIOName({"script", "local_keys", "local_values"}, {"result"}); }
};
class MIND_API PyExecuteInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override;
AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override;
std::set<int64_t> GetValueDependArgIndices() const override;
using InferHandler = void (*)(const std::vector<AbstractBasePtr> &);
static void set_infer_handler(const InferHandler &infer_handler) { infer_handler_ = infer_handler; }
private:
inline static InferHandler infer_handler_{nullptr};
};
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_PY_EXECUTE_H_

View File

@ -26,7 +26,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
convert_to_ms_cootensor, convert_class_to_function, convert_cell_list_to_sequence, is_cell_list,
get_obj_from_sequence, get_type, is_class_member_recursive)
get_obj_from_sequence, get_type, is_class_member_recursive, merge_global_params, get_global_params)
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',

View File

@ -123,6 +123,8 @@ _unsupported_convert_data_type = (
mutable,
)
_global_params = {}
def create_slice_obj(start, end, step):
"""Create slice object"""
@ -755,7 +757,8 @@ def eval_script(exp_str, params):
local_params = params[1]
try:
local_params = _convert_python_data(local_params)
obj = eval(exp_str, global_params, local_params)
res = eval(exp_str, global_params, local_params)
logger.debug(f"eval res: '{res}'")
except Exception as e:
error_info = f"When eval '{exp_str}' by using JIT Fallback feature, an error occurred: " + str(e) + \
". You can try to turn off JIT Fallback feature by 'export MS_DEV_ENABLE_FALLBACK=0'."
@ -763,9 +766,9 @@ def eval_script(exp_str, params):
raise e
# Convert set to tuple.
if isinstance(obj, set):
return tuple(obj)
return obj
if isinstance(res, set):
return tuple(res)
return res
def get_script_ids(script):
@ -777,6 +780,16 @@ def get_script_ids(script):
return set(ids)
def merge_global_params(global_dict):
logger.debug(f'merge global_dict: {global_dict}')
_global_params.update(global_dict)
def get_global_params():
logger.debug(f'get global_dict: {_global_params}')
return _global_params
class Parser:
"""
Parser python code to ast tree.
@ -800,6 +813,7 @@ class Parser:
# Used to resolve the function's nonlocals.
self.closure_namespace = ClosureNamespace(self.fn)
self.function_name = self.fn.__qualname__
self.lines = []
self.col_offset = 0
@staticmethod
@ -874,15 +888,15 @@ class Parser:
raise OSError(f"Mindspore can not compile temporary source code in terminal. "
f"Please write source code to a python file and run the file.")
raise e
lines, self.line_offset = source
original_src = ''.join(lines)
self.lines, self.line_offset = source
original_src = ''.join(self.lines)
hexstr = hashlib.sha256(original_src.encode()).hexdigest()
ast_tokens_cache = Parser.ast_cache.get(hexstr)
if not ast_tokens_cache:
src = dedent(original_src)
self.col_offset = \
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
logger.debug("Get source: %s", src)
logger.info("Get source: %s", src)
try:
ast_tokens = asttokens.ASTTokens(src, parse=True)
except IndentationError as idt_err:
@ -976,6 +990,35 @@ class Parser:
f"but got {subclass_instance}.")
return super(target_father_class, subclass_instance)
def get_source_code(self, start_lineno, start_colno, end_lineno, end_colno):
"""
Get the script source at the location.
Args:
start_lineno: The start line no.
start_colno: The start column no.
end_lineno: The end line no.
end_colno: The end column no.
Returns:
str, the source string.
"""
if start_lineno == 0:
logger.critical('start_lineno should not be 0')
first_line = self.lines[start_lineno - 1]
if start_lineno == end_lineno:
src = first_line[self.col_offset + start_colno:self.col_offset + end_colno]
return src
src = first_line[self.col_offset + start_colno:]
while start_lineno < end_lineno - 1:
src += self.lines[start_lineno]
start_lineno += 1
last_line = self.lines[end_lineno - 1]
src += last_line[:self.col_offset + end_colno]
return src
def get_location(self, node):
"""
Get location of node start and end line no.
@ -987,7 +1030,7 @@ class Parser:
Returns:
List, [fileName, linestart, colstart, lineend, colend].
"""
ret = [self.filename]
res = [self.filename]
err_exit = 0
if isinstance(node, (list, tuple)):
node_size = len(node)
@ -1009,7 +1052,7 @@ class Parser:
start_colno += self.col_offset
end_lineno += self.line_offset - 1
end_colno += self.col_offset
ret = ret + [start_lineno, start_colno, end_lineno, end_colno]
res = res + [start_lineno, start_colno, end_lineno, end_colno]
else:
ret = ret + [0, 0, 0, 0]
return ret
res = res + [0, 0, 0, 0]
return res

View File

@ -29,6 +29,7 @@ from mindspore.common.tensor import Tensor
from mindspore.common.sparse_tensor import RowTensor, RowTensorInner, SparseTensor, COOTensor, CSRTensor
from mindspore.common.mutable import mutable
from mindspore.common.jit_config import JitConfig
from mindspore.common._utils import update_and_return_dict
# symbols from dtype
__all__ = [
@ -66,4 +67,5 @@ __all__.extend([
"set_dump",
"ms_memory_recycle",
"mutable", "JitConfig",
"update_and_return_dict",
])

View File

@ -53,3 +53,8 @@ def split_to_slice_if_need(dtype, shape):
return slice_num
slice_num = math.ceil(data_size / emb_cache_size)
return slice_num
def update_and_return_dict(dic, key, val):
dic.__setitem__(key, val)
return dic

View File

@ -231,3 +231,9 @@ def bprop_scalar_not(x, out, dout):
def bprop_tensor_move(x, out, dout):
"""Backpropagator for primitive `TensorMove`."""
return (dout,)
@bprops.register("PyExecute")
def get_bprop_py_execute(x, y, z, out, dout):
"""Generate bprop for PyExecute"""
return x, y, z

View File

@ -69,6 +69,7 @@ from .pad import _pad_cpu
from .range import _range_cpu
from .tensor_copy_slices import _tensor_copy_slices_cpu
from .l2loss import _l2loss_cpu
from .pyexecute import _pyexecute_cpu
from .pyfunc import _pyfunc_cpu
from .buffer_append import _buffer_append_cpu
from .buffer_get import _buffer_get_cpu

View File

@ -0,0 +1,29 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""PyExecute op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
pyexecute_op_info = CpuRegOp("PyExecute") \
.input(0, "x", "dynamic") \
.output(0, "y", "dynamic") \
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(pyexecute_op_info)
def _pyexecute_cpu():
"""PyExecute cpu register"""
return

View File

@ -116,7 +116,7 @@ from .nn_ops import (LSTM, SGD, Adam, AdamWeightDecay, FusedSparseAdam, FusedSpa
GridSampler2D, TripletMarginLoss, UpsampleNearest3D, UpsampleTrilinear3D, PadV3)
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
ConfusionMatrix, UpdateState, Load, StopGradient,
CheckValid, Partial, Depend, identity, Push, Pull, PyFunc, _DynamicLossScale,
CheckValid, Partial, Depend, identity, Push, Pull, PyExecute, PyFunc, _DynamicLossScale,
SampleDistortedBoundingBoxV2)
from .random_ops import (RandomChoiceWithMask, StandardNormal, Gamma, RandomGamma, Poisson, UniformInt, UniformReal,
RandomCategorical, StandardLaplace, Multinomial, UniformCandidateSampler,
@ -501,6 +501,7 @@ __all__ = [
"TensorScatterDiv",
"SoftShrink",
"HShrink",
"PyExecute",
"PyFunc",
"BufferAppend",
"BufferGetItem",

View File

@ -706,6 +706,39 @@ class identity(Primitive):
return x
class PyInterpret(Primitive):
r"""
Interpret Python expression.
"""
@prim_attr_register
def __init__(self):
super(PyInterpret, self).__init__(self.__class__.__name__)
self.add_prim_attr('side_effect_io', True)
class PyExecute(PrimitiveWithInfer):
r"""
Execute Python expression.
"""
@prim_attr_register
def __init__(self):
super(PyExecute, self).__init__(self.__class__.__name__)
self.add_prim_attr('side_effect_io', True)
self.add_prim_attr("primitive_target", "CPU")
def infer_shape(self, *args):
logger.error("The function output are empty tuple. Add a placeholder instead. "
"Do not use it as it could be any uninitialized data.")
return ((1,),)
def infer_dtype(self, *args):
logger.error("The function output are empty tuple. Add a placeholder instead. "
"Do not use it as it could be any uninitialized data.")
return (mstype.int32,)
class PyFunc(PrimitiveWithInfer):
r"""
Execute Python function.

View File

@ -42,7 +42,7 @@ using AbstractTensor = abstract::AbstractTensor;
using AbstractTensorPtr = abstract::AbstractTensorPtr;
using AbstractNone = abstract::AbstractNone;
using AbstractAttribute = abstract::AbstractAttribute;
using AbstractAttribute = abstract::AbstractElementPair;
using AnalysisEngine = abstract::AnalysisEngine;
using AnalysisEnginePtr = abstract::AnalysisEnginePtr;

View File

@ -1002,7 +1002,7 @@ TEST_F(TestPrim, test_DictGetItem2) {
AbstractBasePtr arr_x = ArrayOfTensor(UTPrimUtils::kF64, {3, 4, 5});
AbstractBasePtr arr_y = ArrayOfTensor(UTPrimUtils::kF64, {1, 4, 5});
AbstractBasePtr arr_z = ArrayOfTensor(UTPrimUtils::kF64, {3, 1, 5});
std::vector<AbstractAttribute> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
std::vector<AbstractElementPair> array_map = {{"x", arr_x}, {"y", arr_y}, {"z", arr_z}};
AbstractDictionaryPtr array_dict = std::make_shared<AbstractDictionary>(array_map);
AbstractBasePtr key = abstract::FromValue("x");
AbstractBasePtrList args_spec_list = {array_dict, key};