forked from mindspore-Ecosystem/mindspore
[JIT Fallback] Support interpreted node.
This commit is contained in:
parent
66400d742d
commit
4cfa4ac94c
|
@ -285,6 +285,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
<< ", shape: " << tensor_shape->ToString() << ", val: " << tensor_val->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Made dict getitem node: " << dict_getitem_node->DebugString();
|
||||
dict_getitem_node->set_debug_info(node->debug_info());
|
||||
return dict_getitem_node;
|
||||
}
|
||||
|
||||
|
@ -410,6 +411,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
const auto dict_setitem_node = func_graph->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
|
||||
MS_LOG(DEBUG) << "Made dict setitem node: " << dict_setitem_node->DebugString();
|
||||
dict_setitem_node->set_debug_info(node->debug_info());
|
||||
return dict_setitem_node;
|
||||
}
|
||||
|
||||
|
@ -444,16 +446,18 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
|
||||
// Local parameters values.
|
||||
// Pack the key tuple.
|
||||
constexpr size_t values_input_index = 2;
|
||||
const auto script_key_tuple_str = std::make_shared<StringImm>(internal_tuple_keys_str);
|
||||
const auto make_key_tuple_node =
|
||||
fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_key_tuple_str),
|
||||
NewValueNode(script_key_tuple_str), node->input(values_input_index)});
|
||||
NewValueNode(script_key_tuple_str), node->input(1)});
|
||||
make_key_tuple_node->set_debug_info(node->input(1)->debug_info());
|
||||
// Pack the value tuple.
|
||||
constexpr size_t values_input_index = 2;
|
||||
const auto script_value_tuple_str = std::make_shared<StringImm>(internal_tuple_values_str);
|
||||
const auto make_value_tuple_node =
|
||||
fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_value_tuple_str),
|
||||
NewValueNode(script_value_tuple_str), node->input(1)});
|
||||
NewValueNode(script_value_tuple_str), node->input(values_input_index)});
|
||||
make_value_tuple_node->set_debug_info(node->input(values_input_index)->debug_info());
|
||||
// Pack the local parameters values
|
||||
std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)key_value_list.emplace_back(make_key_tuple_node);
|
||||
|
@ -478,6 +482,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
const auto make_dict_node = fg->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
|
||||
MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
|
||||
make_dict_node->set_debug_info(node->debug_info());
|
||||
return make_dict_node;
|
||||
}
|
||||
|
||||
|
@ -568,7 +573,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
}
|
||||
|
||||
// dict(k0:v0, k1:v1, ...) --> PyExecute('dict(zip(keys, values))', ...)
|
||||
AnfNodePtr RebuildValueDict(const ValueDictionaryPtr &dict) const {
|
||||
AnfNodePtr RebuildValueDict(const ValueNodePtr &value_node, const ValueDictionaryPtr &dict) const {
|
||||
constexpr auto internal_tuple_keys_str = "__internal_tuple_keys__";
|
||||
constexpr auto internal_tuple_values_str = "__internal_tuple_values__";
|
||||
constexpr auto internal_dict_zip_keys_str = "__internal_dict_zip_keys__";
|
||||
|
@ -621,6 +626,7 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
const auto make_dict_node = root_graph_->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str), key_value_name_tuple, key_value_tuple});
|
||||
MS_LOG(DEBUG) << "Made dict node: " << make_dict_node->DebugString();
|
||||
make_dict_node->set_debug_info(value_node->debug_info());
|
||||
return make_dict_node;
|
||||
}
|
||||
|
||||
|
@ -646,12 +652,12 @@ class SimplifyDataStructuresRewriter : public BaseRewriter {
|
|||
return (this->*(iter->second))(cnode);
|
||||
}
|
||||
|
||||
AnfNodePtr ConvertValueNode(const ValueNodePtr &, const ValuePtr &value) override {
|
||||
AnfNodePtr ConvertValueNode(const ValueNodePtr &value_node, const ValuePtr &value) override {
|
||||
// Convert Dictionary value node.
|
||||
if (value->isa<ValueDictionary>()) {
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
|
||||
if (support_fallback_runtime && is_dict_output_) {
|
||||
return RebuildValueDict(value->cast<ValueDictionaryPtr>());
|
||||
return RebuildValueDict(value_node, value->cast<ValueDictionaryPtr>());
|
||||
}
|
||||
return DictToTuple(value->cast<ValueDictionaryPtr>());
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -35,7 +35,6 @@ namespace mindspore {
|
|||
namespace trace {
|
||||
using TraceGraphEvalStack = std::deque<std::pair<abstract::AnalysisContextPtr, abstract::AnfNodeConfigPtr>>;
|
||||
using TraceCNodeEvalStack = std::vector<abstract::AnfNodeConfigPtr>;
|
||||
DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info);
|
||||
void TraceGraphEval();
|
||||
void GetEvalStackInfo(std::ostringstream &oss);
|
||||
void TraceGraphEvalEnter(const abstract::AnalysisContextPtr &context, const abstract::AnfNodeConfigPtr &node);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -297,7 +297,8 @@ AnfNodePtr ConvertObjectToNode(const AnfNodePtr &origin_node, const py::object &
|
|||
bool interpret_without_internal =
|
||||
(IsPrimitiveCNode(origin_node, prim::kPrimPyInterpret) && !origin_node->interpret_internal_type()) ||
|
||||
origin_node->interpret();
|
||||
if (!interpret_without_internal && convert_result->isa<InterpretedObject>()) {
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
|
||||
if (!support_fallback_runtime && !interpret_without_internal && convert_result->isa<InterpretedObject>()) {
|
||||
auto type_str = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_GET_TYPE, obj);
|
||||
MS_EXCEPTION(TypeError) << "Do not support to convert " << py::str(type_str) << " object into graph node."
|
||||
<< ".\nFor more details, please refer to "
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -34,15 +34,16 @@
|
|||
#include "abstract/utils.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "pipeline/jit/debug/trace.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "pipeline/jit/parse/parse_base.h"
|
||||
#include "pipeline/jit/parse/resolve.h"
|
||||
#include "pipeline/jit/pipeline.h"
|
||||
#include "pipeline/jit/resource.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "pipeline/jit/debug/trace.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
#include "include/common/utils/convert_utils_py.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "pipeline/jit/parse/data_converter.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "abstract/param_validator.h"
|
||||
#include "utils/ms_utils.h"
|
||||
|
@ -1307,6 +1308,119 @@ inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr fun
|
|||
|
||||
enum class REQUIRE_TYPE { ATTR, METHOD };
|
||||
|
||||
AnfNodePtr ConvertInterpretedNodeToCNode(const FuncGraphPtr &fg, const ValuePtr &value, const AnfNodePtr &node) {
|
||||
const auto &interpreted_value = dyn_cast<parse::InterpretedObject>(value);
|
||||
if (interpreted_value == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
const auto &value_node_value = interpreted_value->obj();
|
||||
|
||||
auto value_node_key = interpreted_value->name();
|
||||
(void)value_node_key.erase(
|
||||
std::remove_if(value_node_key.begin(), value_node_key.end(),
|
||||
[](char c) { return std::isspace(c) || c == ':' || c == '\'' || c == '<' || c == '>' || c == '.'; }),
|
||||
value_node_key.end());
|
||||
|
||||
// Set the value node into dict firstly.
|
||||
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||
constexpr auto set_local_variable = "set_local_variable";
|
||||
(void)python_adapter::CallPyModFn(mod, set_local_variable, value_node_key, value_node_value);
|
||||
|
||||
// Get the value node from the dict in IR.
|
||||
std::stringstream script_buffer;
|
||||
script_buffer << "__import__('mindspore')._extends.parse.get_local_variable(" << value_node_key << ")";
|
||||
const std::string &script = script_buffer.str();
|
||||
const auto script_str = std::make_shared<StringImm>(script);
|
||||
|
||||
// Build new CNode for value node.
|
||||
ValuePtrList keys({std::make_shared<StringImm>(value_node_key)});
|
||||
ValuePtrList values({std::make_shared<StringImm>(value_node_key)});
|
||||
const auto interpreted_cnode = fg->NewCNode({NewValueNode(prim::kPrimPyExecute), NewValueNode(script_str),
|
||||
NewValueNode(std::make_shared<ValueTuple>(keys)),
|
||||
NewValueNode(std::make_shared<ValueTuple>(values))});
|
||||
constexpr auto debug_recursive_level = 2;
|
||||
MS_LOG(DEBUG) << "interpreted_cnode: " << interpreted_cnode->DebugString(debug_recursive_level);
|
||||
interpreted_cnode->set_debug_info(node->debug_info());
|
||||
return interpreted_cnode;
|
||||
}
|
||||
|
||||
EvalResultPtr InterpretGetAttrNode(const AbstractBasePtrList &args_abs_list, const AnfNodeConfigPtr &out_conf) {
|
||||
auto out_node = out_conf->node();
|
||||
const auto cnode = dyn_cast<CNode>(out_node);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto fg = cnode->func_graph();
|
||||
|
||||
const auto &debug_info = trace::GetSourceCodeDebugInfo(out_conf->node()->debug_info());
|
||||
const auto &location = debug_info->location();
|
||||
if (location == nullptr) {
|
||||
MS_LOG(WARNING) << "Location info is null, node: " << out_conf->node()->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
const auto expr = location->expr_src();
|
||||
if (expr.empty()) {
|
||||
MS_LOG(WARNING) << "Location's expr is empty, node: " << out_conf->node()->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
auto owner_abs = args_abs_list[0];
|
||||
auto owner_value = owner_abs->BuildValue();
|
||||
auto owner_node = cnode->input(1);
|
||||
constexpr auto debug_recursive_level = 2;
|
||||
MS_LOG(DEBUG) << "expr: " << expr << ", for node: " << out_conf->node()->DebugString(debug_recursive_level)
|
||||
<< ", owner_value: " << owner_value->ToString();
|
||||
if (owner_value->isa<parse::InterpretedObject>()) {
|
||||
owner_node = ConvertInterpretedNodeToCNode(fg, owner_value, owner_node);
|
||||
}
|
||||
|
||||
constexpr auto internal_getattr_owner_str = "__internal_getattr_owner__";
|
||||
std::stringstream script_buffer;
|
||||
script_buffer << internal_getattr_owner_str;
|
||||
// Check "x.xxx"
|
||||
auto pos = expr.rfind('.');
|
||||
if (pos == std::string::npos) {
|
||||
// Check "getattr(x, 'xxx')"
|
||||
constexpr auto get_attr_expr = "getattr";
|
||||
pos = expr.find(get_attr_expr);
|
||||
if (pos == std::string::npos) {
|
||||
MS_LOG(EXCEPTION) << "The expression is wrong: " << expr;
|
||||
}
|
||||
pos = expr.find(", ", pos);
|
||||
if (pos == std::string::npos) {
|
||||
MS_LOG(EXCEPTION) << "The expression is wrong: " << expr;
|
||||
}
|
||||
constexpr auto get_attr_call_input_sep_num = 3;
|
||||
pos += get_attr_call_input_sep_num;
|
||||
auto end_pos = expr.find(")", pos);
|
||||
if (end_pos == std::string::npos) {
|
||||
MS_LOG(EXCEPTION) << "The expression is wrong: " << expr;
|
||||
}
|
||||
script_buffer << "." << expr.substr(pos, end_pos - pos - 1);
|
||||
} else {
|
||||
script_buffer << expr.substr(pos);
|
||||
}
|
||||
MS_LOG(DEBUG) << "attr: " << script_buffer.str();
|
||||
|
||||
const auto script_getattr_str = std::make_shared<StringImm>(script_buffer.str());
|
||||
std::vector<ValuePtr> key_list;
|
||||
const auto owner_str = std::make_shared<StringImm>(internal_getattr_owner_str);
|
||||
(void)key_list.emplace_back(owner_str);
|
||||
const auto key_tuple = std::make_shared<ValueTuple>(key_list);
|
||||
|
||||
std::vector<AnfNodePtr> value_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)value_list.emplace_back(owner_node);
|
||||
const auto value_tuple_node = fg->NewCNode(value_list);
|
||||
|
||||
const auto getattr_node = fg->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_getattr_str), NewValueNode(key_tuple), value_tuple_node});
|
||||
getattr_node->set_debug_info(cnode->debug_info());
|
||||
MS_LOG(DEBUG) << "getattr_node: " << getattr_node->DebugString();
|
||||
|
||||
fg->ReplaceInOrder(cnode, getattr_node);
|
||||
auto eng = out_conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
auto fn_conf = eng->MakeConfig(getattr_node, out_conf->context(), out_conf->func_graph());
|
||||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
|
||||
REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
|
||||
MS_EXCEPTION_IF_NULL(old_conf);
|
||||
|
@ -1413,18 +1527,31 @@ EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_abs_
|
|||
auto data_value = data->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(data_value);
|
||||
if (!data_value->isa<parse::NameSpace>()) {
|
||||
auto item_value = item->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(item_value);
|
||||
if (data_value->isa<parse::ClassType>()) {
|
||||
auto class_val = dyn_cast_ptr<parse::ClassType>(data_value);
|
||||
const auto &class_name = class_val->name();
|
||||
auto item_value = item->BuildValue();
|
||||
MS_EXCEPTION(TypeError)
|
||||
<< "Can not get attribute '" << item_value->ToString() << "' from " << class_name
|
||||
<< " in graph mode. Try using jit_class to decorate the class? "
|
||||
<< ".\nFor more details, please refer to "
|
||||
<< "https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.jit_class.html \n";
|
||||
}
|
||||
MS_EXCEPTION(TypeError) << "Not supported to get attribute for " << data_value->ToString()
|
||||
<< "\nThe first argument should be a NameSpace, but got " << data->ToString();
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
|
||||
if (!support_fallback_runtime) {
|
||||
MS_EXCEPTION(TypeError) << "Do not support to get attribute from " << data_value->ToString()
|
||||
<< "\nThe first argument should be a NameSpace, but got " << data->ToString();
|
||||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Evaluate " << data_value->ToString() << " attribute: " << item_value->ToString()
|
||||
<< ".\nnode: " << out_conf->node()->DebugString() << "\n"
|
||||
<< trace::GetDebugInfo(out_conf->node()->debug_info());
|
||||
auto res = InterpretGetAttrNode(args_abs_list, out_conf);
|
||||
if (res == nullptr) {
|
||||
MS_EXCEPTION(AttributeError) << data_value->ToString() << " object has no attribute: " << item_value->ToString();
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
auto item_value = item->BuildValue();
|
||||
|
@ -1581,46 +1708,6 @@ EvalResultPtr GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEngine
|
|||
return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
|
||||
}
|
||||
|
||||
EvalResultPtr InterpretGetAttrNode(const AnfNodeConfigPtr &out_conf) {
|
||||
auto out_node = out_conf->node();
|
||||
const auto cnode = dyn_cast<CNode>(out_node);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto fg = cnode->func_graph();
|
||||
|
||||
const auto &location = out_conf->node()->debug_info()->location();
|
||||
if (location == nullptr) {
|
||||
MS_LOG(WARNING) << "Location info is null, node: " << out_conf->node()->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
const auto expr = location->expr_src();
|
||||
if (expr.empty()) {
|
||||
MS_LOG(WARNING) << "Location's expr is empty, node: " << out_conf->node()->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "expr: " << expr << ", for node: " << out_conf->node()->DebugString();
|
||||
const auto script_getattr_str = std::make_shared<StringImm>(expr);
|
||||
std::vector<ValuePtr> key_list;
|
||||
const auto &owner_node = cnode->input(1);
|
||||
const auto owner_str = std::make_shared<StringImm>(owner_node->debug_info()->name());
|
||||
(void)key_list.emplace_back(owner_str);
|
||||
const auto key_tuple = std::make_shared<ValueTuple>(key_list);
|
||||
|
||||
std::vector<AnfNodePtr> value_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)value_list.emplace_back(owner_node);
|
||||
const auto value_tuple_node = fg->NewCNode(value_list);
|
||||
|
||||
const auto getattr_node = fg->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_getattr_str), NewValueNode(key_tuple), value_tuple_node});
|
||||
getattr_node->set_debug_info(cnode->debug_info());
|
||||
MS_LOG(DEBUG) << "getattr_node: " << getattr_node->DebugString();
|
||||
|
||||
fg->ReplaceInOrder(cnode, getattr_node);
|
||||
auto eng = out_conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
auto fn_conf = eng->MakeConfig(getattr_node, out_conf->context(), out_conf->func_graph());
|
||||
return eng->ForwardConfig(out_conf, fn_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine,
|
||||
const AbstractBasePtrList &args_abs_list,
|
||||
const ConfigPtr &data_conf,
|
||||
|
@ -1654,9 +1741,9 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
|
|||
}
|
||||
|
||||
MS_LOG(DEBUG) << "Evaluate " << data_type->ToString() << " attribute: " << item_name
|
||||
<< ".\nnode:" << out_conf->node()->DebugString() << "\n"
|
||||
<< ".\nnode: " << out_conf->node()->DebugString() << "\n"
|
||||
<< trace::GetDebugInfo(out_conf->node()->debug_info());
|
||||
auto res = InterpretGetAttrNode(out_conf);
|
||||
auto res = InterpretGetAttrNode(args_abs_list, out_conf);
|
||||
if (res == nullptr) {
|
||||
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
|
||||
}
|
||||
|
@ -1744,11 +1831,13 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
|
||||
}
|
||||
|
||||
if (data_args->isa<abstract::AbstractScalar>()) {
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
|
||||
if (!support_fallback_runtime && data_args->isa<abstract::AbstractScalar>()) {
|
||||
ValuePtr data_value = data_args->BuildValue();
|
||||
if (data_value->isa<parse::InterpretedObject>()) {
|
||||
auto obj = ValueToPyData(data_value);
|
||||
auto type_str = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_GET_TYPE, obj);
|
||||
|
||||
MS_EXCEPTION(TypeError) << "Do not support to get attribute from " << py::str(type_str) << " object "
|
||||
<< py::str(obj) << ".\nFor more details, please refer to "
|
||||
<< "https://mindspore.cn/docs/zh-CN/master/faq/network_compilation.html?highlight=do"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -327,8 +327,10 @@ EvalResultPtr AnalysisEngine::InterpretedNodeCall(const CNodePtr &cnode, const A
|
|||
// Check if the operator input is PyExecute CNode.
|
||||
auto &func_node = inputs[0];
|
||||
MS_EXCEPTION_IF_NULL(func_node);
|
||||
MS_LOG(DEBUG) << "Current CNode function: " << func_node->DebugString();
|
||||
if (!IsPrimitiveCNode(func_node, prim::kPrimGetAttr)) { // Optimize the performance.
|
||||
constexpr auto debug_recursive_level = 2;
|
||||
MS_LOG(DEBUG) << "Current CNode: " << cnode->DebugString(debug_recursive_level);
|
||||
auto prim = GetCNodePrimitiveWithoutDoSignature(func_node);
|
||||
if (!IsPrimitiveEquals(prim, prim::kPrimGetAttr)) { // Optimize the performance.
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodeConfigPtr func_conf = MakeConfig(func_node, conf->context(), conf->func_graph());
|
||||
|
@ -341,21 +343,37 @@ EvalResultPtr AnalysisEngine::InterpretedNodeCall(const CNodePtr &cnode, const A
|
|||
// Forward getattr CNode call to py_execute CNode.
|
||||
constexpr auto internal_getattr_callable_obj_str = "__internal_getattr_callable_obj__";
|
||||
std::stringstream script_buffer;
|
||||
script_buffer << internal_getattr_callable_obj_str << "()";
|
||||
const auto script_call_str = std::make_shared<StringImm>(script_buffer.str());
|
||||
script_buffer << internal_getattr_callable_obj_str << "(";
|
||||
|
||||
std::vector<ValuePtr> key_list;
|
||||
const auto callable_obj_name_str = std::make_shared<StringImm>(internal_getattr_callable_obj_str);
|
||||
(void)key_list.emplace_back(callable_obj_name_str);
|
||||
constexpr auto internal_getattr_callable_input_str = "__internal_getattr_callable_obj_input__";
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
std::stringstream key_input_buffer;
|
||||
key_input_buffer << internal_getattr_callable_input_str << i;
|
||||
(void)key_list.emplace_back(std::make_shared<StringImm>(key_input_buffer.str()));
|
||||
script_buffer << key_input_buffer.str();
|
||||
if (i < inputs.size() - 1) {
|
||||
script_buffer << ", ";
|
||||
}
|
||||
}
|
||||
script_buffer << ")";
|
||||
const auto key_tuple = std::make_shared<ValueTuple>(key_list);
|
||||
const auto script_call_str = std::make_shared<StringImm>(script_buffer.str());
|
||||
|
||||
std::vector<AnfNodePtr> value_list{NewValueNode(prim::kPrimMakeTuple)};
|
||||
(void)value_list.emplace_back(func_node);
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
const auto &input = inputs[i];
|
||||
(void)value_list.emplace_back(input);
|
||||
}
|
||||
auto fg = cnode->func_graph();
|
||||
const auto value_tuple_node = fg->NewCNode(value_list);
|
||||
|
||||
const auto getattr_obj_call_node = fg->NewCNode(
|
||||
{NewValueNode(prim::kPrimPyExecute), NewValueNode(script_call_str), NewValueNode(key_tuple), value_tuple_node});
|
||||
MS_LOG(DEBUG) << "getattr_obj_call_node: " << getattr_obj_call_node->DebugString();
|
||||
|
||||
getattr_obj_call_node->set_debug_info(cnode->debug_info());
|
||||
fg->ReplaceInOrder(cnode, getattr_obj_call_node);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -166,19 +166,23 @@ py::object PyExecuteCpuKernelMod::BuildLocalTupleParameters(const std::vector<Ad
|
|||
MS_EXCEPTION_IF_NULL(input_abstract);
|
||||
const auto &input_type = input_abstract->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(input_type);
|
||||
if (!tuple_input_start && input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
|
||||
const auto &value = input_abstract->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
const auto &str_value = dyn_cast<StringImm>(value);
|
||||
MS_EXCEPTION_IF_NULL(str_value);
|
||||
const auto &str = str_value->value();
|
||||
if (str != internal_tuple_keys_str && str != internal_tuple_values_str) {
|
||||
if (!tuple_input_start) {
|
||||
if (input_abstract->isa<abstract::AbstractScalar>() && input_type->isa<String>()) {
|
||||
const auto &value = input_abstract->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
const auto &str_value = dyn_cast<StringImm>(value);
|
||||
MS_EXCEPTION_IF_NULL(str_value);
|
||||
std::string str = str_value->value();
|
||||
if (str != internal_tuple_keys_str && str != internal_tuple_values_str) {
|
||||
return py::none();
|
||||
}
|
||||
tuple_key_str = str;
|
||||
tuple_input_start = true;
|
||||
MS_LOG(DEBUG) << "String, key input[" << i << "]: " << input_abstract->ToString();
|
||||
continue;
|
||||
} else {
|
||||
return py::none();
|
||||
}
|
||||
tuple_key_str = str;
|
||||
tuple_input_start = true;
|
||||
MS_LOG(DEBUG) << "String, key input[" << i << "]: " << input_abstract->ToString();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Rebuild the tuple with all left inputs.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -30,6 +30,7 @@ namespace mindspore {
|
|||
namespace trace {
|
||||
constexpr auto kSectionPrefix = " - ";
|
||||
|
||||
MS_CORE_API DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info);
|
||||
MS_CORE_API std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine);
|
||||
MS_CORE_API std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix,
|
||||
SourceLineTip tip = kSourceLineTipNextLine);
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -27,7 +27,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
|||
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
|
||||
convert_to_ms_cootensor, convert_class_to_function, convert_cell_list_to_sequence, is_cell_list,
|
||||
get_obj_from_sequence, get_type, is_class_member_recursive, merge_global_params, get_global_params,
|
||||
get_adapter_tensor_attr)
|
||||
get_adapter_tensor_attr, get_local_variable, set_local_variable)
|
||||
|
||||
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
||||
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
||||
|
@ -38,4 +38,5 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
|
|||
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor',
|
||||
'convert_to_ms_cootensor', 'convert_class_to_function', 'convert_cell_list_to_sequence', 'is_cell_list',
|
||||
'get_obj_from_sequence', 'get_type', 'is_class_member_recursive', 'get_adapter_tensor_attr']
|
||||
'get_obj_from_sequence', 'get_type', 'is_class_member_recursive', 'get_adapter_tensor_attr',
|
||||
'get_local_variable', 'set_local_variable']
|
||||
|
|
|
@ -126,6 +126,7 @@ _unsupported_convert_data_type = (
|
|||
)
|
||||
|
||||
_global_params = {}
|
||||
_local_value_nodes = {}
|
||||
|
||||
|
||||
def _convert_map():
|
||||
|
@ -818,6 +819,14 @@ def get_global_params():
|
|||
return _global_params
|
||||
|
||||
|
||||
def set_local_variable(name, value):
|
||||
_local_value_nodes[name] = value
|
||||
|
||||
|
||||
def get_local_variable(name):
|
||||
return _local_value_nodes.get(name)
|
||||
|
||||
|
||||
class Parser:
|
||||
"""
|
||||
Parser python code to ast tree.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
# Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -95,6 +95,29 @@ def test_fallback_np_asnumpy():
|
|||
np.testing.assert_almost_equal(output, const_output, 3)
|
||||
|
||||
|
||||
@ms.jit
|
||||
def tensor_asnumpy():
|
||||
tensor = ms.Tensor(np.arange(0, 6).reshape(2, 3))
|
||||
res = tensor.asnumpy()
|
||||
return res
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_jit_tensor_asnumpy():
|
||||
"""
|
||||
Feature: Support JIT Fallback runtime feature.
|
||||
Description: Support JIT Fallback runtime feature.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
res = tensor_asnumpy()
|
||||
print(res)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -331,3 +354,234 @@ def test_net_dict_2():
|
|||
assert outputs['conv1'].shape == (64, 6, 28, 28)
|
||||
assert outputs['conv2'].shape == (64, 16, 10, 10)
|
||||
assert outputs['fc'].shape == (64, 10)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_getattr_cust_class():
|
||||
"""
|
||||
Feature: getattr for custom class.
|
||||
Description: Support getattr for custom class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
class GetattrClass():
|
||||
def __init__(self):
|
||||
self.attr1 = 99
|
||||
self.attr2 = 1
|
||||
|
||||
def method1(self, x):
|
||||
return x + self.attr2
|
||||
|
||||
class GetattrClassNet(ms.nn.Cell):
|
||||
def __init__(self):
|
||||
super(GetattrClassNet, self).__init__()
|
||||
self.cls = GetattrClass()
|
||||
|
||||
def construct(self):
|
||||
return self.cls.method1(self.cls.attr1)
|
||||
|
||||
net = GetattrClassNet()
|
||||
out = net()
|
||||
print(f'out: {out}')
|
||||
assert out == 100
|
||||
|
||||
|
||||
class ClassTest:
|
||||
""" ClassTest definition """
|
||||
|
||||
def __init__(self, name, value1):
|
||||
self.name = name
|
||||
self.value = value1
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def get_name(self):
|
||||
return self.name
|
||||
|
||||
def get_value(self, inc):
|
||||
ret = self.value + inc
|
||||
return ret
|
||||
|
||||
|
||||
class SelfObjectGetattrNet(ms.nn.Cell):
|
||||
""" SelfObjectGetattrNet definition """
|
||||
|
||||
def __init__(self, v1, v2):
|
||||
super(SelfObjectGetattrNet, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.softmax = nn.Softmax(0)
|
||||
self.axis = 0
|
||||
self.test_class = ClassTest("test_class", v1)
|
||||
self.value = v2
|
||||
|
||||
@ms.jit
|
||||
def construct(self, x):
|
||||
x = x + self.test_class.get_value(self.value)
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Stuck by ScopedLongRunning() invocation in forward.cc during JIT Fallback Python running.")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_call_other_object_method_runtime():
|
||||
"""
|
||||
Feature: getattr for custom class.
|
||||
Description: Support getattr for custom class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = ms.Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
|
||||
y = ms.Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
|
||||
y1 = ms.Tensor(np.array([[5, 4, 5], [1, 1, 2]]).astype(np.int32))
|
||||
z = np.array([[8, 9, 12], [3, 4, 7]]).astype(np.int32)
|
||||
|
||||
net = SelfObjectGetattrNet(y, y1)
|
||||
output = net.construct(x)
|
||||
result = output.asnumpy()
|
||||
print(result)
|
||||
assert np.all(result == z)
|
||||
|
||||
|
||||
# Test: call global object method(not self) on parse graph code
|
||||
value = ms.Tensor(np.array([[3, 4, 5], [1, 1, 2]]).astype(np.int32))
|
||||
test_class = ClassTest("test_class", value)
|
||||
|
||||
|
||||
class GlobalObjectGetattrNet(ms.nn.Cell):
|
||||
""" GlobalObjectGetattrNet definition """
|
||||
|
||||
def __init__(self, value1):
|
||||
super(GlobalObjectGetattrNet, self).__init__()
|
||||
self.value = value1
|
||||
|
||||
@ms.jit
|
||||
def construct(self, x):
|
||||
x = x + test_class.get_value(self.value)
|
||||
return x
|
||||
|
||||
@ms.jit
|
||||
def construct1(self, x):
|
||||
x = x + test_class.value
|
||||
x = x + self.value
|
||||
return x
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Stuck by ScopedLongRunning() invocation in forward.cc during JIT Fallback Python running.")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_call_no_self_other_object_method_runtime():
|
||||
"""
|
||||
Feature: getattr for custom class.
|
||||
Description: Support getattr for custom class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
x = ms.Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
|
||||
y = ms.Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
|
||||
z = np.array([[6, 9, 12], [3, 4, 7]]).astype(np.int32)
|
||||
|
||||
net = GlobalObjectGetattrNet(y)
|
||||
output = net.construct(x)
|
||||
result = output.asnumpy()
|
||||
print(result)
|
||||
assert np.all(result == z)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_getattr_tensor_with_wrong_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
|
||||
@ms.jit
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "abs2")
|
||||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo(Tensor([-1, -2, -3])) # Not throw error any more, should move to ST.
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_getattr_list_with_wrong_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support list input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
|
||||
@ms.jit
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "abs2")
|
||||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo([1, 2, 3, 4]) # Not throw error any more, should move to ST.
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_getattr_tuple_with_wrong_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
|
||||
@ms.jit
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "shape")
|
||||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo((1, 2, 3, 4)) # Not throw error any more, should move to ST.
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not supported by now")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_getattr_dict_with_wrong_attr():
|
||||
"""
|
||||
Feature: Syntax getattr.
|
||||
Description: Graph syntax getattr support tensor input.
|
||||
Expectation: AttributeError.
|
||||
"""
|
||||
|
||||
@ms.jit
|
||||
def foo(x):
|
||||
abs_func = getattr(x, "abs2")
|
||||
return abs_func()
|
||||
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo({"1": 1, "2": 2}) # Not throw error any more, should move to ST.
|
||||
assert "object has no attribute" in str(err.value)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
# Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
"""test graph getattr"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
|
@ -69,9 +70,11 @@ def test_getattr_tensor_with_wrong_attr():
|
|||
abs_func = getattr(x, "abs2")
|
||||
return abs_func()
|
||||
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo(Tensor([-1, -2, -3]))
|
||||
foo(Tensor([-1, -2, -3])) # Not throw error any more, should move to ST.
|
||||
assert "object has no attribute" in str(err.value)
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
|
||||
|
||||
|
||||
def test_getattr_tensor_with_default():
|
||||
|
@ -220,9 +223,11 @@ def test_getattr_list_with_wrong_attr():
|
|||
abs_func = getattr(x, "abs2")
|
||||
return abs_func()
|
||||
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo([1, 2, 3, 4])
|
||||
foo([1, 2, 3, 4]) # Not throw error any more, should move to ST.
|
||||
assert "object has no attribute" in str(err.value)
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
|
||||
|
||||
|
||||
def test_getattr_tuple():
|
||||
|
@ -339,9 +344,11 @@ def test_getattr_tuple_with_wrong_attr():
|
|||
abs_func = getattr(x, "shape")
|
||||
return abs_func()
|
||||
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo((1, 2, 3, 4))
|
||||
foo((1, 2, 3, 4)) # Not throw error any more, should move to ST.
|
||||
assert "object has no attribute" in str(err.value)
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
|
||||
|
||||
|
||||
def test_getattr_dict():
|
||||
|
@ -424,9 +431,11 @@ def test_getattr_dict_with_wrong_attr():
|
|||
abs_func = getattr(x, "abs2")
|
||||
return abs_func()
|
||||
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
|
||||
with pytest.raises(AttributeError) as err:
|
||||
foo({"1": 1, "2": 2})
|
||||
foo({"1": 1, "2": 2}) # Not throw error any more, should move to ST.
|
||||
assert "object has no attribute" in str(err.value)
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
|
||||
|
||||
|
||||
def test_getattr_dict_with_default():
|
||||
|
@ -682,9 +691,11 @@ def test_getattr_numpy_array():
|
|||
x = np.array([1, 2, 3, 4])
|
||||
return getattr(x, "shape")[0]
|
||||
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '0'
|
||||
with pytest.raises(TypeError) as err:
|
||||
foo()
|
||||
foo() # Not throw error any more, should move to ST.
|
||||
assert "Do not support to get attribute" in str(err.value)
|
||||
os.environ['MS_DEV_ENABLE_FALLBACK_RUNTIME'] = '1'
|
||||
|
||||
|
||||
def test_getattr_numpy_array_2():
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -167,7 +167,11 @@ class Net1(nn.Cell):
|
|||
|
||||
@non_graph_engine
|
||||
def test_call_other_object_method():
|
||||
""" test_call_other_object_method """
|
||||
"""
|
||||
Feature: getattr for custom class.
|
||||
Description: Support getattr for custom class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
log.debug("begin test_call_other_object_method")
|
||||
|
||||
x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
|
||||
|
@ -176,7 +180,7 @@ def test_call_other_object_method():
|
|||
z = np.array([[8, 9, 12], [3, 4, 7]]).astype(np.int32)
|
||||
|
||||
net = Net1(y, y1)
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(NotImplementedError): # NotImplementedError: PyExecute, should move to ST.
|
||||
output = net.construct(x)
|
||||
result = output.asnumpy()
|
||||
print(result)
|
||||
|
@ -211,14 +215,18 @@ class Net2(nn.Cell):
|
|||
|
||||
@non_graph_engine
|
||||
def test_call_no_self_other_object_method():
|
||||
""" test_call_no_self_other_object_method """
|
||||
"""
|
||||
Feature: getattr for custom class.
|
||||
Description: Support getattr for custom class.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
log.debug("begin test_call_other_object_method")
|
||||
x = Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.int32))
|
||||
y = Tensor(np.array([[2, 3, 4], [1, 1, 2]]).astype(np.int32))
|
||||
z = np.array([[6, 9, 12], [3, 4, 7]]).astype(np.int32)
|
||||
|
||||
net = Net2(y)
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(NotImplementedError): # NotImplementedError: PyExecute, should move to ST.
|
||||
output = net.construct(x)
|
||||
result = output.asnumpy()
|
||||
print(result)
|
||||
|
|
Loading…
Reference in New Issue