!61455 Support starred_expression in graph mode.

Merge pull request !61455 from Margaret_wangrui/starred_expression_unpack_r2.3
This commit is contained in:
i-robot 2023-11-11 02:16:15 +00:00 committed by Gitee
commit 1da59bf782
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
19 changed files with 1252 additions and 119 deletions

View File

@ -31,6 +31,8 @@ mindspore/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_dee
mindspore/mindspore/ccsrc/pipeline/jit/ps/resource.cc:mindspore::pipeline::GetMethodMap
mindspore/mindspore/ccsrc/pipeline/jit/ps/fallback.cc:mindspore::fallback::GetJitAnnotationTypeFromComment
mindspore/mindspore/ccsrc/pipeline/jit/ps/pipeline.cc:mindspore::pipeline::GraphExecutorPy::RunInner
mindspore/mindspore/ccsrc/pipeline/jit/ps/debug/anf_ir_utils.cc:mindspore::Skip
mindspore/mindspore/ccsrc/common/debug/anf_ir_dump.cc:mindspore::Skip
mindspore/mindspore/python/mindspore/ops/operations/array_ops.py:_compute_slicing_shape
mindspore/mindspore/python/mindspore/ops/function/array_func.py:scatter_nd
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:interpolate

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -81,7 +81,8 @@ inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
meta_func_graph->isa<prim::VmapMatchOutAxis>() || meta_func_graph->isa<prim::VmapGeneralPreprocess>() ||
meta_func_graph->isa<prim::GradAux>() || meta_func_graph->isa<prim::PyExecuteGradient>() ||
meta_func_graph->isa<prim::MutableGradient>() || meta_func_graph->isa<prim::ZerosLike>() ||
meta_func_graph->isa<prim::ListAdd>();
meta_func_graph->isa<prim::ListAdd>() || meta_func_graph->isa<prim::StarredGetItem>() ||
meta_func_graph->isa<prim::StarredUnpack>() || meta_func_graph->isa<prim::StarredUnpackMerge>();
}
std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {

View File

@ -593,6 +593,7 @@ FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
// Make fprop first result, PyExecute's forward result.
AnfNodePtr out = fg->NewCNodeInOrder(params);
InterpretNodeRecorder::GetInstance().PushPyExecuteNode(out);
// Make fprop second result, PyExecute's backward function.
FuncGraphPtr bprop = std::make_shared<FuncGraph>();

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,6 +31,7 @@
#include "frontend/operator/composite/do_signature.h"
#include "frontend/operator/composite/unpack_call.h"
#include "frontend/operator/composite/multitype_funcgraph.h"
#include "frontend/operator/composite/starred_operation.h"
#include "pipeline/jit/ps/static_analysis/static_analysis.h"
#include "utils/misc.h"
#include "utils/any.h"

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -25,6 +25,7 @@
#include "frontend/operator/composite/multitype_funcgraph.h"
#include "frontend/operator/composite/zip_operation.h"
#include "frontend/operator/composite/tensor_index.h"
#include "frontend/operator/composite/starred_operation.h"
namespace mindspore {
namespace prim {
void RegCompositeOpsGroup(const py::module *m) {
@ -148,6 +149,18 @@ void RegCompositeOpsGroup(const py::module *m) {
(void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m, "ZipOperation_")
.def(py::init<std::string &>());
// Reg StarredUnpack
(void)py::class_<StarredUnpack, MetaFuncGraph, std::shared_ptr<StarredUnpack>>(*m, "StarredUnpack_")
.def(py::init<std::string &>());
// Reg StarredGetItem
(void)py::class_<StarredGetItem, MetaFuncGraph, std::shared_ptr<StarredGetItem>>(*m, "StarredGetItem_")
.def(py::init<std::string &>());
// Reg StarredUnpackMerge
(void)py::class_<StarredUnpackMerge, MetaFuncGraph, std::shared_ptr<StarredUnpackMerge>>(*m, "StarredUnpackMerge_")
.def(py::init<std::string &>());
// Reg VmapGeneralPreprocess
(void)py::class_<VmapGeneralPreprocess, MetaFuncGraph, std::shared_ptr<VmapGeneralPreprocess>>(
*m, "VmapGeneralPreprocess_")

View File

@ -0,0 +1,245 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/operator/composite/starred_operation.h"
#include <algorithm>
#include <vector>
#include <utility>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/array_ops.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSequence;
using mindspore::abstract::AbstractSequencePtr;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
// x = (1, 2, 3, 4)
// a, *b, c = x // targets(a, *b, c) = assign(x)
// a = 1, *b = [2, 3], c = 4
// convert:
// StarredGetItem(sequence, position_in_target, targets_num)
// *b: StarredGetItem(x, 1, 3)
// output: *b = makelist(getitem(x, 1), getitem(x, 2))
FuncGraphPtr StarredGetItem::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
// Check inputs
constexpr size_t starred_getitem_args_size = 3;
constexpr size_t sequence_index = 0;
constexpr size_t position_in_target_index = 1;
constexpr size_t targets_num_index = 2;
if (args_abs_list.size() != starred_getitem_args_size) {
MS_LOG(EXCEPTION) << "For 'StarredGetItem', the number of input should be " << starred_getitem_args_size
<< ", but got " << args_abs_list.size();
}
auto first_input__abs = args_abs_list[sequence_index];
MS_EXCEPTION_IF_NULL(first_input__abs);
if (!first_input__abs->isa<AbstractSequence>()) {
MS_LOG(EXCEPTION) << "The first input of StarredGetItem operation must be sequence, but got "
<< first_input__abs->ToString();
}
auto seq_abs = first_input__abs->cast<AbstractSequencePtr>();
auto elements = seq_abs->elements();
size_t elements_size = elements.size();
auto pos_abs = args_abs_list[position_in_target_index];
MS_EXCEPTION_IF_NULL(pos_abs);
int64_t position_in_target = GetValue<int64_t>(pos_abs->GetValueTrack());
auto targets_num_abs = args_abs_list[targets_num_index];
MS_EXCEPTION_IF_NULL(targets_num_abs);
int64_t targets_num = GetValue<int64_t>(targets_num_abs->GetValueTrack());
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
std::vector<AnfNodePtr> make_list_inputs;
make_list_inputs.push_back(NewValueNode(prim::kPrimMakeList));
int64_t list_input_num = elements_size - (targets_num - 1);
auto assign_node = ret_graph->add_parameter();
for (int64_t index = 0; index < list_input_num; ++index) {
auto get_item_prim = NewValueNode(prim::kPrimTupleGetItem);
std::vector<AnfNodePtr> get_item_inputs{get_item_prim, assign_node};
auto index_value = NewValueNode(static_cast<int64_t>(position_in_target + index));
get_item_inputs.push_back(index_value);
auto get_item = ret_graph->NewCNodeInOrder(get_item_inputs);
make_list_inputs.push_back(get_item);
}
for (size_t idx = 0; idx < args_abs_list.size() - 1; idx++) {
(void)ret_graph->add_parameter();
}
auto list_out = ret_graph->NewCNodeInOrder(make_list_inputs);
ret_graph->set_output(list_out);
return ret_graph;
}
// x = [1, 2, 3, 4]
// a = *x, // targets(a) = assign(*x,)
// a = (1, 2, 3, 4)
// convert:
// StarredUnpackMerge(StarredUnpack(sequence))
// StarredUnpackMerge(((1, 2, 3, 4), )
// StarredUnpackMerge(tuple_getitem(x, 0), ...)
FuncGraphPtr StarredUnpack::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
// Check inputs
constexpr size_t starred_unpack_args_size = 1;
constexpr size_t sequence_index = 0;
if (args_abs_list.size() != starred_unpack_args_size) {
MS_LOG(EXCEPTION) << "For 'StarredUnpack', the number of input should be " << starred_unpack_args_size
<< ", but got " << args_abs_list.size();
}
auto &unpack_arg = args_abs_list[sequence_index];
MS_EXCEPTION_IF_NULL(unpack_arg);
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
if (unpack_arg->isa<AbstractScalar>()) {
auto arg_scalar = dyn_cast_ptr<AbstractScalar>(unpack_arg);
const auto &arg_value = arg_scalar->GetValueTrack();
if (arg_value->isa<StringImm>()) {
auto str = arg_value->cast_ptr<StringImm>();
MS_EXCEPTION_IF_NULL(str);
std::string str_value = str->value();
AbstractBasePtrList ptr_list;
for (size_t index = 0; index < str_value.size(); ++index) {
std::stringstream stream;
stream << str_value[index];
string index_str = stream.str();
auto index_abs = std::make_shared<AbstractScalar>(static_cast<std::string>(index_str));
ptr_list.push_back(index_abs);
}
auto tuple_abs = std::make_shared<abstract::AbstractTuple>(ptr_list);
auto unpack_node = ret_graph->add_parameter();
unpack_node->set_abstract(tuple_abs);
ret_graph->set_output(unpack_node);
return ret_graph;
}
} else if (unpack_arg->isa<AbstractSequence>()) {
auto seq = args_abs_list[0]->cast<AbstractSequencePtr>();
const auto &elements = seq->elements();
auto tuple_abs = std::make_shared<abstract::AbstractTuple>(elements);
auto unpack_node = ret_graph->add_parameter();
unpack_node->set_abstract(tuple_abs);
ret_graph->set_output(unpack_node);
return ret_graph;
} else if (unpack_arg->isa<AbstractTensor>()) {
auto input = ret_graph->add_parameter();
auto prim = prim::kPrimUnstack;
auto unstack_node = ret_graph->NewCNodeInOrder({NewValueNode(prim), input});
prim->set_attr(kAttrAxis, MakeValue(static_cast<int64_t>(0)));
ret_graph->set_output(unstack_node);
return ret_graph;
}
MS_LOG(INTERNAL_EXCEPTION) << "The object is not iterable, " << unpack_arg->ToString();
}
std::pair<std::vector<int64_t>, int64_t> StarredUnpackMerge::GetStarredUnpackMergeFlags(
const AbstractBasePtrList &args_abs_list) {
constexpr size_t args_size = 3;
constexpr size_t flags_num = 2;
size_t starred_flags_index = args_abs_list.size() - 2;
size_t is_tuple_index = args_abs_list.size() - 1;
if (args_abs_list.size() < args_size) {
MS_LOG(EXCEPTION) << "For 'StarredUnpackMerge', the number of input should be " << args_size
<< " at least, but got " << args_abs_list.size();
}
if (!args_abs_list[starred_flags_index]->isa<AbstractSequence>()) {
MS_LOG(EXCEPTION) << "For 'StarredUnpackMerge', the starred_flags input should be sequence, but got "
<< args_abs_list[starred_flags_index]->ToString();
}
if (!args_abs_list[is_tuple_index]->isa<AbstractScalar>()) {
MS_LOG(EXCEPTION) << "For 'StarredUnpackMerge', the is_tuple input should be scalar, but got "
<< args_abs_list[is_tuple_index]->ToString();
}
auto abs_seq = args_abs_list[starred_flags_index]->cast<AbstractSequencePtr>();
const auto &elements = abs_seq->elements();
std::vector<int64_t> starred_flags(elements.size(), 0);
for (size_t index = 0; index < elements.size(); ++index) {
auto ele = elements[index];
auto ele_value = ele->GetValueTrack();
auto val = GetValue<int64_t>(ele_value);
starred_flags[index] = val;
}
int64_t is_tuple = GetValue<int64_t>(args_abs_list[is_tuple_index]->GetValueTrack());
size_t sequence_input_size = args_abs_list.size() - flags_num;
if (sequence_input_size != elements.size()) {
MS_LOG(EXCEPTION) << "For 'StarredUnpackMerge', the input is wrong, please check.";
}
return {starred_flags, is_tuple};
}
// a = *[1, 2], (3, 4)
// convert:
// StarredUnpackMerge(assign_node1, assign_node2, starred_flags_node, is_tuple)
// StarredUnpackMerge(StarredUnpack(*[1, 2]), (3, 4), (1, 0), 1) --> (1, 2, (3, 4))
// a: (1, 2, (3, 4))
FuncGraphPtr StarredUnpackMerge::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
// Check inputs, and get flags info.
auto [starred_flags, is_tuple] = GetStarredUnpackMergeFlags(args_abs_list);
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
std::vector<AnfNodePtr> new_inputs;
if (is_tuple == 1) {
new_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
} else if (is_tuple == 0) {
new_inputs.push_back(NewValueNode(prim::kPrimMakeList));
}
constexpr size_t unpack_flags_num = 2;
for (size_t index = 0; index < args_abs_list.size() - unpack_flags_num; ++index) {
auto &unpack_arg = args_abs_list[index];
MS_EXCEPTION_IF_NULL(unpack_arg);
int64_t is_starred = starred_flags[index];
auto input = ret_graph->add_parameter();
if (!is_starred) {
new_inputs.push_back(input);
} else {
// starred must be sequence.
if (!unpack_arg->isa<AbstractSequence>()) {
MS_LOG(EXCEPTION) << "The starred unpack merge input must be sequence, but got " << unpack_arg->ToString();
}
auto unpack_abs_seq = unpack_arg->cast<AbstractSequencePtr>();
const auto &elements = unpack_abs_seq->elements();
size_t unpack_size = elements.size();
for (size_t ele_index = 0; ele_index < unpack_size; ++ele_index) {
std::vector<AnfNodePtr> get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), input};
get_item_inputs.push_back(NewValueNode(static_cast<int64_t>(ele_index)));
auto get_item = ret_graph->NewCNodeInOrder(get_item_inputs);
new_inputs.push_back(get_item);
}
}
}
for (size_t index = 0; index < unpack_flags_num; ++index) {
(void)ret_graph->add_parameter();
}
auto new_node = ret_graph->NewCNodeInOrder(new_inputs);
ret_graph->set_output(new_node);
return ret_graph;
}
} // namespace prim
} // namespace mindspore

View File

@ -0,0 +1,90 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_STARRED_OPERATION_H_
#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_STARRED_OPERATION_H_
#include <string>
#include <map>
#include <set>
#include <memory>
#include <utility>
#include <vector>
#include "utils/hash_map.h"
#include "pipeline/jit/ps/static_analysis/static_analysis.h"
#include "utils/misc.h"
#include "utils/any.h"
#include "ir/dtype.h"
#include "ir/meta_func_graph.h"
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using AbstractBasePtr = abstract::AbstractBasePtr;
using AbstractBasePtrList = abstract::AbstractBasePtrList;
using AbstractTuplePtr = abstract::AbstractTuplePtr;
class StarredGetItem : public MetaFuncGraph {
public:
explicit StarredGetItem(const std::string &name) : MetaFuncGraph(name) {}
~StarredGetItem() override = default;
MS_DECLARE_PARENT(StarredGetItem, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
friend std::ostream &operator<<(std::ostream &os, const StarredGetItem &op) {
os << op.name_;
return os;
}
friend bool operator==(const StarredGetItem &lhs, const StarredGetItem &rhs) { return lhs.name_ == rhs.name_; }
};
using StarredGetItemPtr = std::shared_ptr<StarredGetItem>;
class StarredUnpack : public MetaFuncGraph {
public:
explicit StarredUnpack(const std::string &name) : MetaFuncGraph(name) {}
~StarredUnpack() override = default;
MS_DECLARE_PARENT(StarredUnpack, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
friend std::ostream &operator<<(std::ostream &os, const StarredUnpack &op) {
os << op.name_;
return os;
}
friend bool operator==(const StarredUnpack &lhs, const StarredUnpack &rhs) { return lhs.name_ == rhs.name_; }
};
using StarredUnpackPtr = std::shared_ptr<StarredUnpack>;
class StarredUnpackMerge : public MetaFuncGraph {
public:
explicit StarredUnpackMerge(const std::string &name) : MetaFuncGraph(name) {}
~StarredUnpackMerge() override = default;
MS_DECLARE_PARENT(StarredUnpackMerge, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
friend std::ostream &operator<<(std::ostream &os, const StarredUnpackMerge &op) {
os << op.name_;
return os;
}
friend bool operator==(const StarredUnpackMerge &lhs, const StarredUnpackMerge &rhs) {
return lhs.name_ == rhs.name_;
}
std::pair<std::vector<int64_t>, int64_t> GetStarredUnpackMergeFlags(const AbstractBasePtrList &args_abs_list);
};
using StarredUnpackMergePtr = std::shared_ptr<StarredUnpackMerge>;
} // namespace prim
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_STARRED_OPERATION_H_

View File

@ -103,7 +103,10 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
// 'node' is setattr node.
const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
if (!allow_fallback_runtime) {
MS_LOG(EXCEPTION) << "Not support setattr during JIT Fallback disabled.";
MS_LOG(EXCEPTION) << "Not support setattr during JIT Fallback disabled. You can use"
" os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' \n"
<< "to enable the JIT lax mode to support the current syntax.\n"
<< trace::GetDebugInfoStr(target_node->debug_info());
}
return parse::ResolveInterpretedObjectOfSetAttr(target_node, attr_node, assigned_node);
}

View File

@ -227,7 +227,8 @@ inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
meta_func_graph->isa<prim::VmapMatchOutAxis>() || meta_func_graph->isa<prim::VmapGeneralPreprocess>() ||
meta_func_graph->isa<prim::GradAux>() || meta_func_graph->isa<prim::PyExecuteGradient>() ||
meta_func_graph->isa<prim::MutableGradient>() || meta_func_graph->isa<prim::ZerosLike>() ||
meta_func_graph->isa<prim::ListAdd>();
meta_func_graph->isa<prim::ListAdd>() || meta_func_graph->isa<prim::StarredGetItem>() ||
meta_func_graph->isa<prim::StarredUnpack>() || meta_func_graph->isa<prim::StarredUnpackMerge>();
}
/* inherit relation of MetaFuncGraph

View File

@ -246,6 +246,7 @@ CNodePtr CreatePyInterpretCNodeInOrder(const FuncGraphPtr &fg, const std::string
auto node =
fg->NewCNodeInOrder({NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
node->set_debug_info(debug_info);
InterpretNodeRecorder::GetInstance().PushPyInterpretNode(node);
return node;
}
@ -311,6 +312,7 @@ AnfNodePtr ConvertPyObjectToPyInterpret(const FuncGraphPtr &fg, const std::strin
auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_tuple, local_value_tuple});
auto prim = NewValueNode(prim::kPrimPyInterpret);
auto interpret_node = fg->NewCNode({prim, script_node, global_dict_node, local_dict_node});
InterpretNodeRecorder::GetInstance().PushPyInterpretNode(interpret_node);
if (replace) {
fg->ReplaceInOrder(node, interpret_node);
}

View File

@ -128,6 +128,7 @@ void Parser::BuildMethodMap() {
expr_method_map_["GeneratorExp"] = &Parser::ParseListComp; // We treat 'GeneratorExp' the same as 'ListComp'.
expr_method_map_["JoinedStr"] = &Parser::ParseJoinedStr;
expr_method_map_["FormattedValue"] = &Parser::ParseFormattedValue;
expr_method_map_["Starred"] = &Parser::ParseStarred;
condition_method_map_["Attribute"] = &Parser::CheckAttributeConstantCond;
condition_method_map_["Name"] = &Parser::CheckNameConstantCond;
condition_method_map_["UnaryOp"] = &Parser::CheckUnaryOpConstantCond;
@ -2315,50 +2316,97 @@ AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &
return const_graph;
}
// Process a tuple
AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Tuple";
MS_EXCEPTION_IF_NULL(block);
// a = *[1, 2], (3, 4)
// StarredUnpackMerge(assign_node1, assign_node2, starred_flags_node, is_tuple)
// StarredUnpackMerge(StarredUnpack(*[1, 2]), (3, 4), (1, 0), 1)
// --> StarredUnpackMerge((1, 2), (3, 4), (1, 0), 1)
// --> (1, 2, (3, 4))
AnfNodePtr Parser::ParseTupleOrListWithStarred(const FunctionBlockPtr &block, const py::object &node, bool is_tuple,
const std::vector<AnfNodePtr> &starred_flags) {
auto prim = std::make_shared<prim::StarredUnpackMerge>(NAMED_METAGRAPH_STARRED_UNPACK_MERGE);
std::vector<AnfNodePtr> unpack_merge_inputs{NewValueNode(prim)};
auto starred_flags_node = block->func_graph()->NewCNodeInOrder(starred_flags);
py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
if (elts.empty()) {
auto empty_tuple = std::vector<ValuePtr>();
return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
}
std::vector<AnfNodePtr> tuple_vec;
AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
(void)tuple_vec.emplace_back(make_tuple_op);
for (size_t i = 0; i < elts.size(); i++) {
AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
node_ptr = HandleInterpret(block, node_ptr, elts[i]);
(void)tuple_vec.emplace_back(node_ptr);
auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elts[i])));
if (elt_type == AST_SUB_TYPE_STARRED) {
auto starred_unpack_prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
CNodePtr unpack_node = block->func_graph()->NewCNodeInOrder({NewValueNode(starred_unpack_prim), node_ptr});
(void)unpack_merge_inputs.emplace_back(unpack_node);
} else {
(void)unpack_merge_inputs.emplace_back(node_ptr);
}
}
MS_EXCEPTION_IF_NULL(block->func_graph());
CNodePtr tuple_app = block->func_graph()->NewCNodeInOrder(std::move(tuple_vec));
return tuple_app;
(void)unpack_merge_inputs.emplace_back(starred_flags_node);
if (is_tuple) {
auto is_tuple_node = NewValueNode(static_cast<int64_t>(1));
(void)unpack_merge_inputs.emplace_back(is_tuple_node);
} else {
auto is_tuple_node = NewValueNode(static_cast<int64_t>(0));
(void)unpack_merge_inputs.emplace_back(is_tuple_node);
}
CNodePtr unpack_merge_node = block->func_graph()->NewCNodeInOrder(unpack_merge_inputs);
return unpack_merge_node;
}
AnfNodePtr Parser::ParseTupleOrList(const FunctionBlockPtr &block, const py::object &node, bool is_tuple) {
MS_EXCEPTION_IF_NULL(block);
py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
if (elts.empty()) {
if (is_tuple) {
auto empty_tuple = std::vector<ValuePtr>();
return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
}
auto empty_list = std::vector<ValuePtr>();
return NewValueNode(std::make_shared<ValueList>(empty_list));
}
bool exist_starred_expression = false;
std::vector<AnfNodePtr> starred_flags{NewValueNode(prim::kPrimMakeTuple)};
for (size_t i = 0; i < elts.size(); i++) {
auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elts[i])));
if (elt_type == AST_SUB_TYPE_STARRED) {
exist_starred_expression = true;
starred_flags.push_back(NewValueNode(static_cast<int64_t>(1)));
} else {
starred_flags.push_back(NewValueNode(static_cast<int64_t>(0)));
}
}
if (!exist_starred_expression) {
std::vector<AnfNodePtr> sequence_vec;
AnfNodePtr sequence_op = nullptr;
if (is_tuple) {
sequence_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
} else {
sequence_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
}
(void)sequence_vec.emplace_back(sequence_op);
for (size_t i = 0; i < elts.size(); i++) {
AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
node_ptr = HandleInterpret(block, node_ptr, elts[i]);
(void)sequence_vec.emplace_back(node_ptr);
}
MS_EXCEPTION_IF_NULL(block->func_graph());
CNodePtr sequence_app = block->func_graph()->NewCNodeInOrder(std::move(sequence_vec));
return sequence_app;
}
return ParseTupleOrListWithStarred(block, node, is_tuple, starred_flags);
}
// Process a tuple
AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Tuple";
return ParseTupleOrList(block, node, true);
}
// Process a list
AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast List";
MS_EXCEPTION_IF_NULL(block);
py::list elts = python_adapter::GetPyObjAttr(node, "elts");
if (elts.empty()) {
auto empty_list = std::vector<ValuePtr>();
return NewValueNode(std::make_shared<ValueList>(empty_list));
}
std::vector<AnfNodePtr> list_vec;
AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
(void)list_vec.emplace_back(make_list_op);
for (size_t i = 0; i < elts.size(); i++) {
AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
node_ptr = HandleInterpret(block, node_ptr, elts[i]);
(void)list_vec.emplace_back(node_ptr);
}
MS_EXCEPTION_IF_NULL(block->func_graph());
CNodePtr list_app = block->func_graph()->NewCNodeInOrder(std::move(list_vec));
return list_app;
return ParseTupleOrList(block, node, false);
}
// Process a subscript, such as x[y] , node expressed as value[slice]
@ -2471,23 +2519,47 @@ AnfNodePtr Parser::ParseDictByKeysAndValues(const FunctionBlockPtr &block, const
return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
}
AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Dict";
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> Parser::GetRealKeysValues(const FunctionBlockPtr &block,
const py::object &node) {
py::list keys = node.attr("keys");
py::list values = node.attr("values");
if (keys.size() != values.size()) {
MS_LOG(INTERNAL_EXCEPTION) << "The keys' size is not equal to the values' size.";
}
std::vector<AnfNodePtr> key_nodes;
std::vector<AnfNodePtr> value_nodes;
for (size_t i = 0; i < keys.size(); i++) {
AnfNodePtr key_node = ParseExprNode(block, keys[i]);
key_node = HandleInterpret(block, key_node, keys[i]);
key_nodes.push_back(key_node);
AnfNodePtr value_node = ParseExprNode(block, values[i]);
value_node = HandleInterpret(block, value_node, values[i]);
value_nodes.push_back(value_node);
std::vector<AnfNodePtr> inner_key_nodes;
std::vector<AnfNodePtr> inner_value_nodes;
for (size_t index = 0; index < keys.size(); ++index) {
auto inner_key_node_type = ast_->GetNodeType(keys[index]);
const std::string &inner_key_node_type_name = inner_key_node_type->node_name();
// The key does not exist, mean the value is a dict which need unpack.
if (inner_key_node_type_name == "NoneType") {
auto unpack_dict = values[index];
auto inner_value_node_type =
AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, unpack_dict)));
if (inner_value_node_type != AST_SUB_TYPE_DICT) {
MS_LOG(INTERNAL_EXCEPTION) << "The input of dict which need unpack must be dict, but got "
<< inner_value_node_type;
}
auto [unpack_keys, unpack_values] = GetRealKeysValues(block, unpack_dict);
for (size_t i = 0; i < unpack_keys.size(); ++i) {
inner_key_nodes.push_back(unpack_keys[i]);
inner_value_nodes.push_back(unpack_values[i]);
}
} else {
AnfNodePtr key_node = ParseExprNode(block, keys[index]);
key_node = HandleInterpret(block, key_node, keys[index]);
inner_key_nodes.push_back(key_node);
AnfNodePtr value_node = ParseExprNode(block, values[index]);
value_node = HandleInterpret(block, value_node, values[index]);
inner_value_nodes.push_back(value_node);
}
}
return {inner_key_nodes, inner_value_nodes};
}
AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Dict";
auto [key_nodes, value_nodes] = GetRealKeysValues(block, node);
return ParseDictByKeysAndValues(block, key_nodes, value_nodes);
}
@ -3045,6 +3117,7 @@ CNodePtr GenerateInterpretGetItem(const FuncGraphPtr &fg, const AnfNodePtr &iter
auto prim = NewValueNode(prim::kPrimPyInterpret);
auto interpret_get_item = fg->NewCNodeInOrder({prim, script_node, empty_global_dict_node, local_dict_node});
interpret_get_item->set_debug_info(iter_node->debug_info());
InterpretNodeRecorder::GetInstance().PushPyInterpretNode(interpret_get_item);
return interpret_get_item;
}
@ -3733,6 +3806,34 @@ AnfNodePtr Parser::ParseFormattedValue(const FunctionBlockPtr &block, const py::
return value_node;
}
AnfNodePtr Parser::ParseStarred(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Starred.";
TraceGuard trace_guard(GetLocation(node));
MS_EXCEPTION_IF_NULL(block);
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
AnfNodePtr value_node = ParseExprNode(block, value_object);
AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
auto func = block->func_graph();
MS_EXCEPTION_IF_NULL(func);
CNodePtr iterated_node = func->NewCNodeInOrder({op_iter, value_node});
auto prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
CNodePtr unpack_node = func->NewCNodeInOrder({NewValueNode(prim), iterated_node});
return unpack_node;
}
void Parser::HandleAssignStarred(const FunctionBlockPtr &block, const py::object &target,
const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block);
MS_EXCEPTION_IF_NULL(assigned_node);
py::object value_object = python_adapter::GetPyObjAttr(target, "value");
py::str name = python_adapter::GetPyObjAttr(value_object, "id");
std::string name_id = name;
MS_EXCEPTION_IF_NULL(assigned_node->debug_info());
assigned_node->debug_info()->set_name(name_id);
MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
block->WriteVariable(name_id, assigned_node);
}
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target,
const AnfNodePtr &assigned_node) const {
MS_EXCEPTION_IF_NULL(block);
@ -3755,21 +3856,85 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
block->WriteVariable(name_id, assigned_node);
}
void Parser::HandleAssignTupleWithStarredExpression(const FunctionBlockPtr &block, const py::object &target,
const AnfNodePtr &assigned_node,
const std::vector<int64_t> &positions) {
// Process assigned_node
auto func = block->func_graph();
MS_EXCEPTION_IF_NULL(func);
AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
CNodePtr iterated_node = func->NewCNodeInOrder({op_iter, assigned_node});
auto starred_unpack_prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
CNodePtr unpack_node = func->NewCNodeInOrder({NewValueNode(starred_unpack_prim), iterated_node});
py::list items = python_adapter::GetPyObjAttr(target, "elts");
for (size_t i = 0; i < items.size(); i++) {
py::object elt = items[i];
auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elt)));
if (elt_type != AST_SUB_TYPE_STARRED) {
std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
ValuePtr op = prim::GetPythonOps("getitem", module_name);
std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(op), unpack_node, NewValueNode(positions[i])};
AnfNodePtr tuple_get_item = func->NewCNodeInOrder(tuple_get_item_inputs);
MS_LOG(DEBUG) << "Assign name: `" << py::str(elt) << "` to node: " << tuple_get_item->DebugString();
WriteAssignVars(block, elt, tuple_get_item);
} else {
auto starred_get_item_prim = std::make_shared<prim::StarredGetItem>(NAMED_METAGRAPH_STARRED_GET_ITEM);
std::vector<AnfNodePtr> starred_get_item_inputs{NewValueNode(starred_get_item_prim), unpack_node,
NewValueNode(positions[i]),
NewValueNode(SizeToLong(items.size()))};
AnfNodePtr starred_get_item = func->NewCNodeInOrder(starred_get_item_inputs);
MS_LOG(DEBUG) << "Assign name: `" << py::str(elt) << "` to node: " << starred_get_item->DebugString();
WriteAssignVars(block, elt, starred_get_item);
}
}
}
void Parser::HandleAssignTupleOrList(const FunctionBlockPtr &block, const py::object &target,
const AnfNodePtr &assigned_node) {
MS_EXCEPTION_IF_NULL(block);
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
py::list items = python_adapter::GetPyObjAttr(target, "elts");
for (size_t i = 0; i < items.size(); i++) {
// Use the Primitive replace the operation resolve node (getitem),
// because the getitem will eventually be converted to Primitive node
MS_EXCEPTION_IF_NULL(block->func_graph());
CNodePtr item_apply =
block->func_graph()->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast<int64_t>(i))});
// Record the position with targets.
size_t target_starred_num = 0;
size_t starred_pos = items.size();
std::vector<int64_t> positions(items.size(), 0);
for (size_t i = 0; i < items.size(); i++) {
py::object elt = items[i];
WriteAssignVars(block, elt, item_apply);
auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elt)));
if (elt_type == AST_SUB_TYPE_STARRED) {
target_starred_num++;
if (target_starred_num > 1) {
MS_LOG(EXCEPTION) << "SyntaxError: " << target_starred_num << " starred expressions in assignment.";
}
starred_pos = i;
positions[i] = i;
} else {
if (i > starred_pos) {
positions[i] = i - items.size();
} else {
positions[i] = i;
}
}
}
auto func = block->func_graph();
MS_EXCEPTION_IF_NULL(func);
if (target_starred_num == 0) {
for (size_t i = 0; i < items.size(); i++) {
// Use the Primitive replace the operation resolve node (getitem),
// because the getitem will eventually be converted to Primitive node
CNodePtr item_apply = func->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast<int64_t>(i))});
py::object elt = items[i];
WriteAssignVars(block, elt, item_apply);
}
return;
}
// Process AssignTuple with starred expression.
// a, *b, c = x // targets(a, *b, c) = assign(x)
HandleAssignTupleWithStarredExpression(block, target, assigned_node, positions);
}
bool Parser::IsClassParameterMember(const py::object &target_obj, const AnfNodePtr &target_node) const {
@ -3981,6 +4146,8 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
HandleAssignClassMember(block, target_object, value_node);
} else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
HandleAssignClassMember(block, target_object, value_node);
} else if (ast_type == AST_SUB_TYPE_STARRED) {
HandleAssignStarred(block, target_object, value_node);
} else {
TraceGuard trace_guard(GetLocation(target_object));
MS_EXCEPTION(TypeError) << "Only supported augassign to attribute of self, variable and index value, but got "
@ -4196,10 +4363,11 @@ bool Parser::IsPopOperation(const AnfNodePtr &node) const {
FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
MS_LOG(DEBUG) << "Process ast assign";
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
AnfNodePtr value_node = ParseExprNode(block, value_object);
value_node = HandleInterpret(block, value_node, value_object);
py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
size_t count = LongToSize(pcount);
MS_LOG(DEBUG) << "The nodes count is " << count;

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -213,22 +213,31 @@ class Parser {
AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node);
// Process a tuple.
AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node);
// Process a tuple.
// Process a list.
AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node);
// Process a tuple.
// Process a tuple or list.
AnfNodePtr ParseTupleOrList(const FunctionBlockPtr &block, const py::object &node, bool is_tuple);
// Process a tuple or list with starred expression.
AnfNodePtr ParseTupleOrListWithStarred(const FunctionBlockPtr &block, const py::object &node, bool is_tuple,
const std::vector<AnfNodePtr> &starred_flags);
// Process a subscript.
AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node);
// Process a slice.
AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node);
// Process a extslice.
AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node);
// Process a tuple.
// Process a index.
AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node);
// Process a unaryop.
AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
// Process a dict ast node expression.
AnfNodePtr ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
const std::vector<AnfNodePtr> &value_nodes);
// Process a dict.
AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> GetRealKeysValues(const FunctionBlockPtr &block,
const py::object &node);
// Process DictComp expression.
AnfNodePtr ParseDictComp(const FunctionBlockPtr &block, const py::object &node);
FunctionBlockPtr ParseDictCompIter(const FunctionBlockPtr &block, const py::object &node,
@ -243,6 +252,7 @@ class Parser {
const py::object &node, const py::object &generator_node);
AnfNodePtr ParseJoinedStr(const FunctionBlockPtr &block, const py::object &node);
AnfNodePtr ParseFormattedValue(const FunctionBlockPtr &block, const py::object &node);
AnfNodePtr ParseStarred(const FunctionBlockPtr &block, const py::object &node);
std::vector<AnfNodePtr> HandleException(const FunctionBlockPtr &block, const py::list &args, const std::string &name);
std::vector<AnfNodePtr> ParseRaiseCall(const FunctionBlockPtr &block, const py::object &node);
void HandleStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *str_nodes);
@ -327,10 +337,17 @@ class Parser {
// Assign value to single variable name.
void HandleAssignName(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &assigned_node) const;
// Assign value to starred expression.
void HandleAssignStarred(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &assigned_node);
// Assign value to tuple.
void HandleAssignTupleOrList(const FunctionBlockPtr &block, const py::object &target,
const AnfNodePtr &assigned_node);
// Assign value to tuple with starred expression.
void HandleAssignTupleWithStarredExpression(const FunctionBlockPtr &block, const py::object &target,
const AnfNodePtr &assigned_node, const std::vector<int64_t> &positions);
// Assign value to class Parameter member. Return false if not a Parameter member.
bool HandleAssignClassParameterMember(const FunctionBlockPtr &block, const py::object &target,
const AnfNodePtr &value_node);

View File

@ -44,6 +44,7 @@ enum AstSubType : int64_t {
AST_SUB_TYPE_SUBSCRIPT = 8, // ast.Subscript
AST_SUB_TYPE_STARRED = 9, // ast.Starred
AST_SUB_TYPE_ATTRIBUTE = 10, // ast.Attribute
AST_SUB_TYPE_DICT = 11, // ast.Dict
AST_SUB_TYPE_UNKNOWN = 0xFF // Unknown type
};
@ -170,6 +171,9 @@ const char NAMED_PRIMITIVE_MAKELIST[] = "make_list";
const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice";
const char NAMED_PRIMITIVE_MAKEDICT[] = "make_dict";
const char NAMED_METAGRAPH_UNPACKCALL[] = "unpack_call";
const char NAMED_METAGRAPH_STARRED_UNPACK[] = "starred_unpack";
const char NAMED_METAGRAPH_STARRED_GET_ITEM[] = "starred_get_item";
const char NAMED_METAGRAPH_STARRED_UNPACK_MERGE[] = "starred_unpack_merge";
// Define NAMED_PRIMITIVE_GETATTR "getattr".
// Define python inline attr.

View File

@ -1925,7 +1925,10 @@ EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_abs_
const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
if (!allow_fallback_runtime) {
MS_EXCEPTION(TypeError) << "Do not support to get attribute from " << data_value->ToString()
<< "\nThe first argument should be a NameSpace, but got " << data->ToString();
<< " in JIT strict mode. You can use os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' \"\n"
<< " to enable the JIT lax mode to support the current syntax."
<< "\nThe first argument should be a NameSpace, but got " << data->ToString()
<< trace::GetDebugInfoStr(out_conf->node()->debug_info());
}
auto item_value = item->BuildValue();
@ -2061,7 +2064,11 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
if (!has_default) {
const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
if (!allow_fallback_runtime) {
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
MS_EXCEPTION(AttributeError) << "In JIT strict mode, cannot get attributes " << item_name << " or the "
<< data_type->ToString() << " object has no attribute: " << item_name
<< "'. You can use os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' "
<< "to enable the JIT lax mode to support the current syntax.\n\n"
<< trace::GetDebugInfoStr(out_conf->node()->debug_info());
}
constexpr auto recursive_level = 3;
@ -2201,7 +2208,10 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
MS_EXCEPTION(TypeError) << "Do not support to get attribute from " << py::str(type_str) << " object "
<< py::str(obj) << ".\nFor more details, please refer to "
<< "https://mindspore.cn/docs/zh-CN/master/faq/network_compilation.html?highlight=do"
<< "%20support%20get%20attribute%20from";
<< "%20support%20get%20attribute%20from\n"
<< "You can use os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' \n"
<< "to enable the JIT lax mode to support the current syntax.\n"
<< trace::GetDebugInfoStr(out_conf->node()->debug_info());
}
}

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -56,14 +56,28 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
}
AbstractBasePtrList key_list = keys->elements();
for (size_t index = 0; index < keys_size; index++) {
const auto &key = key_list[index];
CheckDictKey(key, op_name);
}
std::unordered_map<std::string, AbstractBasePtr> key_str_value_set;
std::vector<AbstractBasePtr> key_set;
std::vector<AbstractElementPair> key_value;
AbstractBasePtrList value_list = values->elements();
for (size_t index = 0; index < keys_size; index++) {
(void)key_value.emplace_back(key_list[index], value_list[index]);
const auto &key = key_list[index];
CheckDictKey(key, op_name);
auto key_val = key->BuildValue()->ToString();
auto iter = key_str_value_set.find(key_val);
// Remove duplicate keys.
// {Tensor[1]: x, Tensor[1}: y} the length of dict is 2, means the two keys are not duplicate.
if (iter != key_str_value_set.end() && !key->isa<AbstractTensor>()) {
iter->second = value_list[index];
} else {
auto key_str = key->BuildValue()->ToString();
key_str_value_set.insert(std::pair<std::string, AbstractBasePtr>(key_str, value_list[index]));
key_set.push_back(key);
}
}
for (auto &key : key_set) {
auto key_str = key->BuildValue()->ToString();
(void)key_value.emplace_back(key, key_str_value_set[key_str]);
}
return std::make_shared<AbstractDictionary>(key_value);
}

View File

@ -86,6 +86,7 @@ AST_SUB_TYPE_LIST = 7 # ast.List
AST_SUB_TYPE_SUBSCRIPT = 8 # ast.Subscript
AST_SUB_TYPE_STARRED = 9 # ast.Starred
AST_SUB_TYPE_ATTRIBUTE = 10 # ast.Attribute
AST_SUB_TYPE_DICT = 11 # ast.Dict
AST_SUB_TYPE_UNKNOWN = 0xFF # unknown
# Syntax support
@ -692,6 +693,8 @@ def get_ast_type(node):
ast_type = AST_SUB_TYPE_STARRED
elif isinstance(node, ast.Attribute):
ast_type = AST_SUB_TYPE_ATTRIBUTE
elif isinstance(node, ast.Dict):
ast_type = AST_SUB_TYPE_DICT
else:
ast_type = AST_SUB_TYPE_UNKNOWN
return ast_type

View File

@ -30,7 +30,8 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_
HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_, StarredGetItem_,\
StarredUnpack_, StarredUnpackMerge_
from mindspore.common import dtype as mstype
from mindspore.common.api import jit, _pynative_executor, _wrap_func
from mindspore.common.api import _add_flags, _core
@ -1195,3 +1196,48 @@ class _ZipOperation(ZipOperation_):
zip_operation = _ZipOperation('zip_operation')
"""`zip_operation` will generate a tuple of zip iterations of inputs."""
class _StarredGetItem(StarredGetItem_):
"""Generates a list of starred get_item for inputs."""
def __init__(self, name):
"""Initialize _StarredGetItem."""
StarredGetItem_.__init__(self, name)
def __call__(self, *args):
pass
starred_get_item = _StarredGetItem('starred_get_item')
"""`starred_get_item` will generate a list of starred get_item for inputs."""
class _StarredUnpack(StarredUnpack_):
"""Generates a tuple of starred unpack for inputs."""
def __init__(self, name):
"""Initialize _StarredUnpack."""
StarredUnpack_.__init__(self, name)
def __call__(self, *args):
pass
starred_unpack = _StarredUnpack('starred_unpack')
"""`starred_unpack` will generate a tuple of starred unpack for inputs."""
class _StarredUnpackMerge(StarredUnpackMerge_):
"""Generates a tuple of starred unpack merge for inputs."""
def __init__(self, name):
"""Initialize _StarredUnpackMerge."""
StarredUnpackMerge_.__init__(self, name)
def __call__(self, *args):
pass
starred_unpack_merge = _StarredUnpackMerge('starred_unpack_merge')
"""`starred_unpack_merge` will generate a tuple of starred unpack merge for inputs."""

View File

@ -239,31 +239,6 @@ def test_compress_with_mutable_input():
assert (foo(Tensor([1])) == [1, 3, 4]).all()
@pytest.mark.skip(reason="not support now")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_star_to_compress_input():
"""
Feature: Support JIT Fallback runtime feature.
Description: use star to compress assigned input.
Expectation: No exception.
"""
@jit
def foo():
x = [1, 2, 3, 4]
a, *b = x
return a, b
ret = foo()
assert len(ret) == 2
assert ret[0] == 1
assert ret[1] == [2, 3, 4]
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@ -312,27 +287,6 @@ def test_unpack_interpret_node():
assert ret == 10
@pytest.mark.skip(reason="not support now")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_starred_to_unpack_input():
"""
Feature: Support JIT Fallback runtime feature.
Description: * operator can not unpack a list.
Expectation: No exception.
"""
@jit
def foo(x):
return f"output is {*a,}"
ret = foo([1, 2, 3, 4])
assert ret == "output is (1, 2, 3, 4)"
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training

View File

@ -0,0 +1,558 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph starred expression. """
import pytest
from mindspore import context, jit, Tensor
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_list_input():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = [1, 2, 3, 4]
a = *x, # pylint: disable=R1707
return a
ret = foo()
assert ret == (1, 2, 3, 4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_assign_list_input():
"""
Feature: Support assign list.
Description: Support assign in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = [1, 2, 3, 4]
a = x, # pylint: disable=R1707
return a
ret = foo()
assert ret == ([1, 2, 3, 4],)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_tuple_input():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = (1, 2, 3, 4)
a = *x, # pylint: disable=R1707
return a
ret = foo()
assert ret == (1, 2, 3, 4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_dict_input():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = {"a": 1, "b": 2}
out = *x, # pylint: disable=R1707
return out
ret = foo()
assert ret == ("a", "b")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_string_input():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = "abcde"
out = *x, # pylint: disable=R1707
return out
ret = foo()
assert ret == ('a', 'b', 'c', 'd', 'e')
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_tensor_input():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo(x):
out = *x, # pylint: disable=R1707
return out
ret = foo(Tensor([1, 2, 3, 4]))
assert ret == (Tensor(1), Tensor(2), Tensor(3), Tensor(4))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_in_format_string():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo(x):
return f"output is {*x,}"
ret = foo([1, 2, 3, 4])
assert ret == "output is (1, 2, 3, 4)"
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_tuple():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = (1, 2, 3, 4)
*b, = x
return b
ret = foo()
assert ret == [1, 2, 3, 4]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_list():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = [1, 2, 3, 4]
*b, = x
return b
ret = foo()
assert ret == [1, 2, 3, 4]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_dict():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = {"a": 1, "b": 2}
*b, = x
return b
ret = foo()
assert ret == ['a', 'b']
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_string():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = "abcde"
*b, = x
return b
ret = foo()
assert ret == ['a', 'b', 'c', 'd', 'e']
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_tensor():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo(x):
*b, c = x
return b, c
ret = foo(Tensor([1, 2, 3, 4]))
assert ret[0] == [Tensor(1), Tensor(2), Tensor(3)]
assert ret[1] == Tensor(4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_nested_tuple():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = (1, 2, 3, 4)
y = (5, 6)
a, b, *c = x, y
return a, b, c
ret = foo()
assert ret == ((1, 2, 3, 4), (5, 6), [])
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_list_2():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = [1, 2, 3, 4]
a, *b = x
return a, b
ret = foo()
assert len(ret) == 2
assert ret[0] == 1
assert ret[1] == [2, 3, 4]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_tuple_2():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = (1, 2, 3, 4)
a, *b, c = x
return a, b, c
ret = foo()
assert len(ret) == 3
assert ret[0] == 1
assert ret[1] == [2, 3]
assert ret[2] == 4
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_list_tuple():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
x = [1, 2, 3, 4]
y = [5, 6]
z = (7, 8)
a, *b = x, y, z
return a, b
ret = foo()
assert len(ret) == 2
assert ret[0] == [1, 2, 3, 4]
assert ret[1] == [[5, 6], (7, 8)]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_with_range():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
a, *b, c = range(5)
return a, b, c
ret = foo()
assert ret == (0, [1, 2, 3], 4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_for_in():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
ret = []
for _, *b in [(1, 2, 3), (4, 5, 6, 7)]:
ret.append(b)
return ret
ret = foo()
assert ret == [[2, 3], [5, 6, 7]]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_tuple():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
a = *[1, 2], *(3, 4), (5, 6)
return a
ret = foo()
assert ret == (1, 2, 3, 4, (5, 6))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_range_tuple():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
a = *range(4), 4
return a
ret = foo()
assert ret == (0, 1, 2, 3, 4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_range_list():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
a = [*range(4), 4]
return a
ret = foo()
assert ret == [0, 1, 2, 3, 4]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_return_tuple():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
a = (*[1], *[2], 3)
return a
ret = foo()
assert ret == (1, 2, 3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_dict():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
a = {'x': 1, **{'y': 2}}
return a
ret = foo()
assert len(ret) == 2
assert ret == {'x': 1, 'y': 2}
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_dict_2():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
a = {'x': 1, **{'y': 2}, "w": 4, **{'z': 3}}
return a
ret = foo()
assert ret == {'x': 1, 'y': 2, 'w': 4, 'z': 3}
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_dict_3():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
a = {'x': 1, **{'y': 2, 'z': 3, **{'w': 4}}}
return a
ret = foo()
assert ret == {'x': 1, 'y': 2, 'z': 3, 'w': 4}
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_dict_4():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
a = {'x': 1, 'y': {'z': 3, **{'w': 4}}}
return a
ret = foo()
assert ret == {'x': 1, 'y': {'z': 3, 'w': 4}}
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_dict_key_deduplicate():
"""
Feature: Support starred expression.
Description: Support starred expression in graph mode.
Expectation: No exception.
"""
@jit
def foo():
a = {'x': 1, **{'x': 2}}
return a
ret = foo()
assert ret == {'x': 2}