!61455 Support starred_expression in graph mode.

Merge pull request !61455 from Margaret_wangrui/starred_expression_unpack_r2.3

Commit: 1da59bf782

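For context, a minimal Python sketch (illustrative only, not part of the diff) of the starred-expression forms this change targets in graph mode; it mirrors test cases added at the end of this pull request, and which forms run on a given build depends on the branch:

    from mindspore import jit, context

    context.set_context(mode=context.GRAPH_MODE)

    @jit
    def foo():
        x = (1, 2, 3, 4)
        a = *x,                    # starred expression on the value side: a == (1, 2, 3, 4)
        msg = f"output is {*x,}"   # starred expression inside an f-string
        return a, msg
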
@@ -31,6 +31,8 @@ mindspore/model_zoo/official/recommend/wide_and_deep_multitable/src/wide_and_dee
mindspore/mindspore/ccsrc/pipeline/jit/ps/resource.cc:mindspore::pipeline::GetMethodMap
mindspore/mindspore/ccsrc/pipeline/jit/ps/fallback.cc:mindspore::fallback::GetJitAnnotationTypeFromComment
mindspore/mindspore/ccsrc/pipeline/jit/ps/pipeline.cc:mindspore::pipeline::GraphExecutorPy::RunInner
mindspore/mindspore/ccsrc/pipeline/jit/ps/debug/anf_ir_utils.cc:mindspore::Skip
mindspore/mindspore/ccsrc/common/debug/anf_ir_dump.cc:mindspore::Skip
mindspore/mindspore/python/mindspore/ops/operations/array_ops.py:_compute_slicing_shape
mindspore/mindspore/python/mindspore/ops/function/array_func.py:scatter_nd
mindspore/mindspore/python/mindspore/ops/function/nn_func.py:interpolate

@@ -1,5 +1,5 @@
/**
- * Copyright 2019-2022 Huawei Technologies Co., Ltd
+ * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -81,7 +81,8 @@ inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
meta_func_graph->isa<prim::VmapMatchOutAxis>() || meta_func_graph->isa<prim::VmapGeneralPreprocess>() ||
meta_func_graph->isa<prim::GradAux>() || meta_func_graph->isa<prim::PyExecuteGradient>() ||
meta_func_graph->isa<prim::MutableGradient>() || meta_func_graph->isa<prim::ZerosLike>() ||
- meta_func_graph->isa<prim::ListAdd>();
+ meta_func_graph->isa<prim::ListAdd>() || meta_func_graph->isa<prim::StarredGetItem>() ||
+ meta_func_graph->isa<prim::StarredUnpack>() || meta_func_graph->isa<prim::StarredUnpackMerge>();
}

std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {

@@ -593,6 +593,7 @@ FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &arg

// Make fprop first result, PyExecute's forward result.
AnfNodePtr out = fg->NewCNodeInOrder(params);
+ InterpretNodeRecorder::GetInstance().PushPyExecuteNode(out);

// Make fprop second result, PyExecute's backward function.
FuncGraphPtr bprop = std::make_shared<FuncGraph>();

@@ -1,7 +1,7 @@
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
- * Copyright 2019-2022 Huawei Technologies Co., Ltd
+ * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -31,6 +31,7 @@
#include "frontend/operator/composite/do_signature.h"
#include "frontend/operator/composite/unpack_call.h"
#include "frontend/operator/composite/multitype_funcgraph.h"
+ #include "frontend/operator/composite/starred_operation.h"
#include "pipeline/jit/ps/static_analysis/static_analysis.h"
#include "utils/misc.h"
#include "utils/any.h"

@@ -1,7 +1,7 @@
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
- * Copyright 2019-2022 Huawei Technologies Co., Ltd
+ * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -25,6 +25,7 @@
#include "frontend/operator/composite/multitype_funcgraph.h"
#include "frontend/operator/composite/zip_operation.h"
#include "frontend/operator/composite/tensor_index.h"
+ #include "frontend/operator/composite/starred_operation.h"
namespace mindspore {
namespace prim {
void RegCompositeOpsGroup(const py::module *m) {

@@ -148,6 +149,18 @@ void RegCompositeOpsGroup(const py::module *m) {
(void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m, "ZipOperation_")
    .def(py::init<std::string &>());
+
+ // Reg StarredUnpack
+ (void)py::class_<StarredUnpack, MetaFuncGraph, std::shared_ptr<StarredUnpack>>(*m, "StarredUnpack_")
+     .def(py::init<std::string &>());
+
+ // Reg StarredGetItem
+ (void)py::class_<StarredGetItem, MetaFuncGraph, std::shared_ptr<StarredGetItem>>(*m, "StarredGetItem_")
+     .def(py::init<std::string &>());
+
+ // Reg StarredUnpackMerge
+ (void)py::class_<StarredUnpackMerge, MetaFuncGraph, std::shared_ptr<StarredUnpackMerge>>(*m, "StarredUnpackMerge_")
+     .def(py::init<std::string &>());

// Reg VmapGeneralPreprocess
(void)py::class_<VmapGeneralPreprocess, MetaFuncGraph, std::shared_ptr<VmapGeneralPreprocess>>(
    *m, "VmapGeneralPreprocess_")

@ -0,0 +1,245 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "frontend/operator/composite/starred_operation.h"
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
#include "mindspore/core/ops/sequence_ops.h"
|
||||
#include "mindspore/core/ops/array_ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support composite operators definition
|
||||
namespace prim {
|
||||
using mindspore::abstract::AbstractBase;
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractSequence;
|
||||
using mindspore::abstract::AbstractSequencePtr;
|
||||
using mindspore::abstract::AbstractTensor;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
|
||||
// x = (1, 2, 3, 4)
|
||||
// a, *b, c = x // targets(a, *b, c) = assign(x)
|
||||
// a = 1, *b = [2, 3], c = 4
|
||||
// convert:
|
||||
// StarredGetItem(sequence, position_in_target, targets_num)
|
||||
// *b: StarredGetItem(x, 1, 3)
|
||||
// output: *b = makelist(getitem(x, 1), getitem(x, 2))
|
||||
FuncGraphPtr StarredGetItem::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
|
||||
// Check inputs
|
||||
constexpr size_t starred_getitem_args_size = 3;
|
||||
constexpr size_t sequence_index = 0;
|
||||
constexpr size_t position_in_target_index = 1;
|
||||
constexpr size_t targets_num_index = 2;
|
||||
|
||||
if (args_abs_list.size() != starred_getitem_args_size) {
|
||||
MS_LOG(EXCEPTION) << "For 'StarredGetItem', the number of input should be " << starred_getitem_args_size
|
||||
<< ", but got " << args_abs_list.size();
|
||||
}
|
||||
|
||||
auto first_input__abs = args_abs_list[sequence_index];
|
||||
MS_EXCEPTION_IF_NULL(first_input__abs);
|
||||
if (!first_input__abs->isa<AbstractSequence>()) {
|
||||
MS_LOG(EXCEPTION) << "The first input of StarredGetItem operation must be sequence, but got "
|
||||
<< first_input__abs->ToString();
|
||||
}
|
||||
auto seq_abs = first_input__abs->cast<AbstractSequencePtr>();
|
||||
auto elements = seq_abs->elements();
|
||||
size_t elements_size = elements.size();
|
||||
|
||||
auto pos_abs = args_abs_list[position_in_target_index];
|
||||
MS_EXCEPTION_IF_NULL(pos_abs);
|
||||
int64_t position_in_target = GetValue<int64_t>(pos_abs->GetValueTrack());
|
||||
|
||||
auto targets_num_abs = args_abs_list[targets_num_index];
|
||||
MS_EXCEPTION_IF_NULL(targets_num_abs);
|
||||
int64_t targets_num = GetValue<int64_t>(targets_num_abs->GetValueTrack());
|
||||
|
||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
|
||||
std::vector<AnfNodePtr> make_list_inputs;
|
||||
make_list_inputs.push_back(NewValueNode(prim::kPrimMakeList));
|
||||
int64_t list_input_num = elements_size - (targets_num - 1);
|
||||
auto assign_node = ret_graph->add_parameter();
|
||||
|
||||
for (int64_t index = 0; index < list_input_num; ++index) {
|
||||
auto get_item_prim = NewValueNode(prim::kPrimTupleGetItem);
|
||||
std::vector<AnfNodePtr> get_item_inputs{get_item_prim, assign_node};
|
||||
auto index_value = NewValueNode(static_cast<int64_t>(position_in_target + index));
|
||||
get_item_inputs.push_back(index_value);
|
||||
auto get_item = ret_graph->NewCNodeInOrder(get_item_inputs);
|
||||
make_list_inputs.push_back(get_item);
|
||||
}
|
||||
|
||||
for (size_t idx = 0; idx < args_abs_list.size() - 1; idx++) {
|
||||
(void)ret_graph->add_parameter();
|
||||
}
|
||||
|
||||
auto list_out = ret_graph->NewCNodeInOrder(make_list_inputs);
|
||||
ret_graph->set_output(list_out);
|
||||
return ret_graph;
|
||||
}
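A plain-Python sketch (illustrative only, not part of this file) of the slot arithmetic StarredGetItem implements above: the starred target absorbs len(sequence) - (targets_num - 1) consecutive elements starting at its position in the target list.

    def starred_get_item(sequence, position_in_target, targets_num):
        # Number of elements the starred target absorbs.
        take = len(sequence) - (targets_num - 1)
        return [sequence[position_in_target + i] for i in range(take)]

    # a, *b, c = (1, 2, 3, 4)  ->  b == starred_get_item((1, 2, 3, 4), 1, 3) == [2, 3]
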
|
||||
|
||||
// x = [1, 2, 3, 4]
|
||||
// a = *x, // targets(a) = assign(*x,)
|
||||
// a = (1, 2, 3, 4)
|
||||
// convert:
|
||||
// StarredUnpackMerge(StarredUnpack(sequence))
|
||||
// StarredUnpackMerge(((1, 2, 3, 4), )
|
||||
// StarredUnpackMerge(tuple_getitem(x, 0), ...)
|
||||
FuncGraphPtr StarredUnpack::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
|
||||
// Check inputs
|
||||
constexpr size_t starred_unpack_args_size = 1;
|
||||
constexpr size_t sequence_index = 0;
|
||||
if (args_abs_list.size() != starred_unpack_args_size) {
|
||||
MS_LOG(EXCEPTION) << "For 'StarredUnpack', the number of input should be " << starred_unpack_args_size
|
||||
<< ", but got " << args_abs_list.size();
|
||||
}
|
||||
auto &unpack_arg = args_abs_list[sequence_index];
|
||||
MS_EXCEPTION_IF_NULL(unpack_arg);
|
||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
if (unpack_arg->isa<AbstractScalar>()) {
|
||||
auto arg_scalar = dyn_cast_ptr<AbstractScalar>(unpack_arg);
|
||||
const auto &arg_value = arg_scalar->GetValueTrack();
|
||||
if (arg_value->isa<StringImm>()) {
|
||||
auto str = arg_value->cast_ptr<StringImm>();
|
||||
MS_EXCEPTION_IF_NULL(str);
|
||||
std::string str_value = str->value();
|
||||
AbstractBasePtrList ptr_list;
|
||||
for (size_t index = 0; index < str_value.size(); ++index) {
|
||||
std::stringstream stream;
|
||||
stream << str_value[index];
|
||||
string index_str = stream.str();
|
||||
auto index_abs = std::make_shared<AbstractScalar>(static_cast<std::string>(index_str));
|
||||
ptr_list.push_back(index_abs);
|
||||
}
|
||||
auto tuple_abs = std::make_shared<abstract::AbstractTuple>(ptr_list);
|
||||
auto unpack_node = ret_graph->add_parameter();
|
||||
unpack_node->set_abstract(tuple_abs);
|
||||
ret_graph->set_output(unpack_node);
|
||||
return ret_graph;
|
||||
}
|
||||
} else if (unpack_arg->isa<AbstractSequence>()) {
|
||||
auto seq = args_abs_list[0]->cast<AbstractSequencePtr>();
|
||||
const auto &elements = seq->elements();
|
||||
auto tuple_abs = std::make_shared<abstract::AbstractTuple>(elements);
|
||||
auto unpack_node = ret_graph->add_parameter();
|
||||
unpack_node->set_abstract(tuple_abs);
|
||||
ret_graph->set_output(unpack_node);
|
||||
return ret_graph;
|
||||
} else if (unpack_arg->isa<AbstractTensor>()) {
|
||||
auto input = ret_graph->add_parameter();
|
||||
auto prim = prim::kPrimUnstack;
|
||||
auto unstack_node = ret_graph->NewCNodeInOrder({NewValueNode(prim), input});
|
||||
prim->set_attr(kAttrAxis, MakeValue(static_cast<int64_t>(0)));
|
||||
ret_graph->set_output(unstack_node);
|
||||
return ret_graph;
|
||||
}
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The object is not iterable, " << unpack_arg->ToString();
|
||||
}
|
||||
|
||||
std::pair<std::vector<int64_t>, int64_t> StarredUnpackMerge::GetStarredUnpackMergeFlags(
|
||||
const AbstractBasePtrList &args_abs_list) {
|
||||
constexpr size_t args_size = 3;
|
||||
constexpr size_t flags_num = 2;
|
||||
size_t starred_flags_index = args_abs_list.size() - 2;
|
||||
size_t is_tuple_index = args_abs_list.size() - 1;
|
||||
|
||||
if (args_abs_list.size() < args_size) {
|
||||
MS_LOG(EXCEPTION) << "For 'StarredUnpackMerge', the number of input should be " << args_size
|
||||
<< " at least, but got " << args_abs_list.size();
|
||||
}
|
||||
if (!args_abs_list[starred_flags_index]->isa<AbstractSequence>()) {
|
||||
MS_LOG(EXCEPTION) << "For 'StarredUnpackMerge', the starred_flags input should be sequence, but got "
|
||||
<< args_abs_list[starred_flags_index]->ToString();
|
||||
}
|
||||
if (!args_abs_list[is_tuple_index]->isa<AbstractScalar>()) {
|
||||
MS_LOG(EXCEPTION) << "For 'StarredUnpackMerge', the is_tuple input should be scalar, but got "
|
||||
<< args_abs_list[is_tuple_index]->ToString();
|
||||
}
|
||||
|
||||
auto abs_seq = args_abs_list[starred_flags_index]->cast<AbstractSequencePtr>();
|
||||
const auto &elements = abs_seq->elements();
|
||||
std::vector<int64_t> starred_flags(elements.size(), 0);
|
||||
for (size_t index = 0; index < elements.size(); ++index) {
|
||||
auto ele = elements[index];
|
||||
auto ele_value = ele->GetValueTrack();
|
||||
auto val = GetValue<int64_t>(ele_value);
|
||||
starred_flags[index] = val;
|
||||
}
|
||||
int64_t is_tuple = GetValue<int64_t>(args_abs_list[is_tuple_index]->GetValueTrack());
|
||||
size_t sequence_input_size = args_abs_list.size() - flags_num;
|
||||
if (sequence_input_size != elements.size()) {
|
||||
MS_LOG(EXCEPTION) << "For 'StarredUnpackMerge', the input is wrong, please check.";
|
||||
}
|
||||
return {starred_flags, is_tuple};
|
||||
}
|
||||
|
||||
// a = *[1, 2], (3, 4)
|
||||
// convert:
|
||||
// StarredUnpackMerge(assign_node1, assign_node2, starred_flags_node, is_tuple)
|
||||
// StarredUnpackMerge(StarredUnpack(*[1, 2]), (3, 4), (1, 0), 1) --> (1, 2, (3, 4))
|
||||
// a: (1, 2, (3, 4))
|
||||
FuncGraphPtr StarredUnpackMerge::GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) {
|
||||
// Check inputs, and get flags info.
|
||||
auto [starred_flags, is_tuple] = GetStarredUnpackMergeFlags(args_abs_list);
|
||||
|
||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
std::vector<AnfNodePtr> new_inputs;
|
||||
if (is_tuple == 1) {
|
||||
new_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
} else if (is_tuple == 0) {
|
||||
new_inputs.push_back(NewValueNode(prim::kPrimMakeList));
|
||||
}
|
||||
constexpr size_t unpack_flags_num = 2;
|
||||
for (size_t index = 0; index < args_abs_list.size() - unpack_flags_num; ++index) {
|
||||
auto &unpack_arg = args_abs_list[index];
|
||||
MS_EXCEPTION_IF_NULL(unpack_arg);
|
||||
int64_t is_starred = starred_flags[index];
|
||||
auto input = ret_graph->add_parameter();
|
||||
if (!is_starred) {
|
||||
new_inputs.push_back(input);
|
||||
} else {
|
||||
// starred must be sequence.
|
||||
if (!unpack_arg->isa<AbstractSequence>()) {
|
||||
MS_LOG(EXCEPTION) << "The starred unpack merge input must be sequence, but got " << unpack_arg->ToString();
|
||||
}
|
||||
auto unpack_abs_seq = unpack_arg->cast<AbstractSequencePtr>();
|
||||
const auto &elements = unpack_abs_seq->elements();
|
||||
size_t unpack_size = elements.size();
|
||||
for (size_t ele_index = 0; ele_index < unpack_size; ++ele_index) {
|
||||
std::vector<AnfNodePtr> get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), input};
|
||||
get_item_inputs.push_back(NewValueNode(static_cast<int64_t>(ele_index)));
|
||||
auto get_item = ret_graph->NewCNodeInOrder(get_item_inputs);
|
||||
new_inputs.push_back(get_item);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (size_t index = 0; index < unpack_flags_num; ++index) {
|
||||
(void)ret_graph->add_parameter();
|
||||
}
|
||||
auto new_node = ret_graph->NewCNodeInOrder(new_inputs);
|
||||
ret_graph->set_output(new_node);
|
||||
return ret_graph;
|
||||
}
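A plain-Python sketch (illustrative only, not part of this file) of the merge semantics implemented above: starred inputs are spliced element by element, non-starred inputs are kept whole, and is_tuple selects the output container.

    def starred_unpack_merge(inputs, starred_flags, is_tuple):
        out = []
        for value, is_starred in zip(inputs, starred_flags):
            if is_starred:
                out.extend(value)   # splice the already-unpacked sequence
            else:
                out.append(value)   # keep as a single element
        return tuple(out) if is_tuple else out

    # a = *[1, 2], (3, 4)
    # starred_unpack_merge([(1, 2), (3, 4)], (1, 0), True) == (1, 2, (3, 4))
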
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
|
@@ -0,0 +1,90 @@
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_STARRED_OPERATION_H_
#define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_STARRED_OPERATION_H_

#include <string>
#include <map>
#include <set>
#include <memory>
#include <utility>
#include <vector>

#include "utils/hash_map.h"
#include "pipeline/jit/ps/static_analysis/static_analysis.h"
#include "utils/misc.h"
#include "utils/any.h"
#include "ir/dtype.h"
#include "ir/meta_func_graph.h"

namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using AbstractBasePtr = abstract::AbstractBasePtr;
using AbstractBasePtrList = abstract::AbstractBasePtrList;
using AbstractTuplePtr = abstract::AbstractTuplePtr;

class StarredGetItem : public MetaFuncGraph {
 public:
  explicit StarredGetItem(const std::string &name) : MetaFuncGraph(name) {}
  ~StarredGetItem() override = default;
  MS_DECLARE_PARENT(StarredGetItem, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
  friend std::ostream &operator<<(std::ostream &os, const StarredGetItem &op) {
    os << op.name_;
    return os;
  }
  friend bool operator==(const StarredGetItem &lhs, const StarredGetItem &rhs) { return lhs.name_ == rhs.name_; }
};
using StarredGetItemPtr = std::shared_ptr<StarredGetItem>;

class StarredUnpack : public MetaFuncGraph {
 public:
  explicit StarredUnpack(const std::string &name) : MetaFuncGraph(name) {}
  ~StarredUnpack() override = default;
  MS_DECLARE_PARENT(StarredUnpack, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
  friend std::ostream &operator<<(std::ostream &os, const StarredUnpack &op) {
    os << op.name_;
    return os;
  }
  friend bool operator==(const StarredUnpack &lhs, const StarredUnpack &rhs) { return lhs.name_ == rhs.name_; }
};
using StarredUnpackPtr = std::shared_ptr<StarredUnpack>;

class StarredUnpackMerge : public MetaFuncGraph {
 public:
  explicit StarredUnpackMerge(const std::string &name) : MetaFuncGraph(name) {}
  ~StarredUnpackMerge() override = default;
  MS_DECLARE_PARENT(StarredUnpackMerge, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override;
  friend std::ostream &operator<<(std::ostream &os, const StarredUnpackMerge &op) {
    os << op.name_;
    return os;
  }
  friend bool operator==(const StarredUnpackMerge &lhs, const StarredUnpackMerge &rhs) {
    return lhs.name_ == rhs.name_;
  }
  std::pair<std::vector<int64_t>, int64_t> GetStarredUnpackMergeFlags(const AbstractBasePtrList &args_abs_list);
};
using StarredUnpackMergePtr = std::shared_ptr<StarredUnpackMerge>;
}  // namespace prim
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_STARRED_OPERATION_H_

@@ -103,7 +103,10 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr
// 'node' is setattr node.
const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
if (!allow_fallback_runtime) {
-   MS_LOG(EXCEPTION) << "Not support setattr during JIT Fallback disabled.";
+   MS_LOG(EXCEPTION) << "Not support setattr during JIT Fallback disabled. You can use"
+                        " os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' \n"
+                     << "to enable the JIT lax mode to support the current syntax.\n"
+                     << trace::GetDebugInfoStr(target_node->debug_info());
}
return parse::ResolveInterpretedObjectOfSetAttr(target_node, attr_node, assigned_node);
}

@@ -227,7 +227,8 @@ inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
meta_func_graph->isa<prim::VmapMatchOutAxis>() || meta_func_graph->isa<prim::VmapGeneralPreprocess>() ||
meta_func_graph->isa<prim::GradAux>() || meta_func_graph->isa<prim::PyExecuteGradient>() ||
meta_func_graph->isa<prim::MutableGradient>() || meta_func_graph->isa<prim::ZerosLike>() ||
- meta_func_graph->isa<prim::ListAdd>();
+ meta_func_graph->isa<prim::ListAdd>() || meta_func_graph->isa<prim::StarredGetItem>() ||
+ meta_func_graph->isa<prim::StarredUnpack>() || meta_func_graph->isa<prim::StarredUnpackMerge>();
}

/* inherit relation of MetaFuncGraph

@@ -246,6 +246,7 @@ CNodePtr CreatePyInterpretCNodeInOrder(const FuncGraphPtr &fg, const std::string
auto node =
    fg->NewCNodeInOrder({NewValueNode(prim::kPrimPyInterpret), script_node, global_dict_node, local_dict_node});
node->set_debug_info(debug_info);
+ InterpretNodeRecorder::GetInstance().PushPyInterpretNode(node);
return node;
}

@@ -311,6 +312,7 @@ AnfNodePtr ConvertPyObjectToPyInterpret(const FuncGraphPtr &fg, const std::strin
auto local_dict_node = fg->NewCNode({NewValueNode(prim::kPrimMakeDict), local_key_tuple, local_value_tuple});
auto prim = NewValueNode(prim::kPrimPyInterpret);
auto interpret_node = fg->NewCNode({prim, script_node, global_dict_node, local_dict_node});
+ InterpretNodeRecorder::GetInstance().PushPyInterpretNode(interpret_node);
if (replace) {
  fg->ReplaceInOrder(node, interpret_node);
}

@@ -128,6 +128,7 @@ void Parser::BuildMethodMap() {
expr_method_map_["GeneratorExp"] = &Parser::ParseListComp;  // We treat 'GeneratorExp' the same as 'ListComp'.
expr_method_map_["JoinedStr"] = &Parser::ParseJoinedStr;
expr_method_map_["FormattedValue"] = &Parser::ParseFormattedValue;
+ expr_method_map_["Starred"] = &Parser::ParseStarred;
condition_method_map_["Attribute"] = &Parser::CheckAttributeConstantCond;
condition_method_map_["Name"] = &Parser::CheckNameConstantCond;
condition_method_map_["UnaryOp"] = &Parser::CheckUnaryOpConstantCond;

@ -2315,50 +2316,97 @@ AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &
|
|||
return const_graph;
|
||||
}
|
||||
|
||||
// Process a tuple
|
||||
AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Tuple";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
// a = *[1, 2], (3, 4)
|
||||
// StarredUnpackMerge(assign_node1, assign_node2, starred_flags_node, is_tuple)
|
||||
// StarredUnpackMerge(StarredUnpack(*[1, 2]), (3, 4), (1, 0), 1)
|
||||
// --> StarredUnpackMerge((1, 2), (3, 4), (1, 0), 1)
|
||||
// --> (1, 2, (3, 4))
|
||||
AnfNodePtr Parser::ParseTupleOrListWithStarred(const FunctionBlockPtr &block, const py::object &node, bool is_tuple,
|
||||
const std::vector<AnfNodePtr> &starred_flags) {
|
||||
auto prim = std::make_shared<prim::StarredUnpackMerge>(NAMED_METAGRAPH_STARRED_UNPACK_MERGE);
|
||||
std::vector<AnfNodePtr> unpack_merge_inputs{NewValueNode(prim)};
|
||||
auto starred_flags_node = block->func_graph()->NewCNodeInOrder(starred_flags);
|
||||
py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
|
||||
if (elts.empty()) {
|
||||
auto empty_tuple = std::vector<ValuePtr>();
|
||||
return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> tuple_vec;
|
||||
AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
|
||||
(void)tuple_vec.emplace_back(make_tuple_op);
|
||||
for (size_t i = 0; i < elts.size(); i++) {
|
||||
AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
|
||||
node_ptr = HandleInterpret(block, node_ptr, elts[i]);
|
||||
(void)tuple_vec.emplace_back(node_ptr);
|
||||
auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elts[i])));
|
||||
if (elt_type == AST_SUB_TYPE_STARRED) {
|
||||
auto starred_unpack_prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
|
||||
CNodePtr unpack_node = block->func_graph()->NewCNodeInOrder({NewValueNode(starred_unpack_prim), node_ptr});
|
||||
(void)unpack_merge_inputs.emplace_back(unpack_node);
|
||||
} else {
|
||||
(void)unpack_merge_inputs.emplace_back(node_ptr);
|
||||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
CNodePtr tuple_app = block->func_graph()->NewCNodeInOrder(std::move(tuple_vec));
|
||||
return tuple_app;
|
||||
(void)unpack_merge_inputs.emplace_back(starred_flags_node);
|
||||
if (is_tuple) {
|
||||
auto is_tuple_node = NewValueNode(static_cast<int64_t>(1));
|
||||
(void)unpack_merge_inputs.emplace_back(is_tuple_node);
|
||||
} else {
|
||||
auto is_tuple_node = NewValueNode(static_cast<int64_t>(0));
|
||||
(void)unpack_merge_inputs.emplace_back(is_tuple_node);
|
||||
}
|
||||
|
||||
CNodePtr unpack_merge_node = block->func_graph()->NewCNodeInOrder(unpack_merge_inputs);
|
||||
return unpack_merge_node;
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ParseTupleOrList(const FunctionBlockPtr &block, const py::object &node, bool is_tuple) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
|
||||
if (elts.empty()) {
|
||||
if (is_tuple) {
|
||||
auto empty_tuple = std::vector<ValuePtr>();
|
||||
return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
|
||||
}
|
||||
auto empty_list = std::vector<ValuePtr>();
|
||||
return NewValueNode(std::make_shared<ValueList>(empty_list));
|
||||
}
|
||||
|
||||
bool exist_starred_expression = false;
|
||||
std::vector<AnfNodePtr> starred_flags{NewValueNode(prim::kPrimMakeTuple)};
|
||||
for (size_t i = 0; i < elts.size(); i++) {
|
||||
auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elts[i])));
|
||||
if (elt_type == AST_SUB_TYPE_STARRED) {
|
||||
exist_starred_expression = true;
|
||||
starred_flags.push_back(NewValueNode(static_cast<int64_t>(1)));
|
||||
} else {
|
||||
starred_flags.push_back(NewValueNode(static_cast<int64_t>(0)));
|
||||
}
|
||||
}
|
||||
|
||||
if (!exist_starred_expression) {
|
||||
std::vector<AnfNodePtr> sequence_vec;
|
||||
AnfNodePtr sequence_op = nullptr;
|
||||
if (is_tuple) {
|
||||
sequence_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
|
||||
} else {
|
||||
sequence_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
|
||||
}
|
||||
(void)sequence_vec.emplace_back(sequence_op);
|
||||
for (size_t i = 0; i < elts.size(); i++) {
|
||||
AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
|
||||
node_ptr = HandleInterpret(block, node_ptr, elts[i]);
|
||||
(void)sequence_vec.emplace_back(node_ptr);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
CNodePtr sequence_app = block->func_graph()->NewCNodeInOrder(std::move(sequence_vec));
|
||||
return sequence_app;
|
||||
}
|
||||
return ParseTupleOrListWithStarred(block, node, is_tuple, starred_flags);
|
||||
}
|
||||
|
||||
// Process a tuple
|
||||
AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Tuple";
|
||||
return ParseTupleOrList(block, node, true);
|
||||
}
|
||||
|
||||
// Process a list
|
||||
AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast List";
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
py::list elts = python_adapter::GetPyObjAttr(node, "elts");
|
||||
if (elts.empty()) {
|
||||
auto empty_list = std::vector<ValuePtr>();
|
||||
return NewValueNode(std::make_shared<ValueList>(empty_list));
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> list_vec;
|
||||
AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
|
||||
(void)list_vec.emplace_back(make_list_op);
|
||||
for (size_t i = 0; i < elts.size(); i++) {
|
||||
AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
|
||||
node_ptr = HandleInterpret(block, node_ptr, elts[i]);
|
||||
(void)list_vec.emplace_back(node_ptr);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
CNodePtr list_app = block->func_graph()->NewCNodeInOrder(std::move(list_vec));
|
||||
return list_app;
|
||||
return ParseTupleOrList(block, node, false);
|
||||
}
|
||||
|
||||
// Process a subscript, such as x[y] , node expressed as value[slice]
|
||||
|
@ -2471,23 +2519,47 @@ AnfNodePtr Parser::ParseDictByKeysAndValues(const FunctionBlockPtr &block, const
|
|||
return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Dict";
|
||||
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> Parser::GetRealKeysValues(const FunctionBlockPtr &block,
|
||||
const py::object &node) {
|
||||
py::list keys = node.attr("keys");
|
||||
py::list values = node.attr("values");
|
||||
if (keys.size() != values.size()) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The keys' size is not equal to the values' size.";
|
||||
}
|
||||
std::vector<AnfNodePtr> key_nodes;
|
||||
std::vector<AnfNodePtr> value_nodes;
|
||||
for (size_t i = 0; i < keys.size(); i++) {
|
||||
AnfNodePtr key_node = ParseExprNode(block, keys[i]);
|
||||
key_node = HandleInterpret(block, key_node, keys[i]);
|
||||
key_nodes.push_back(key_node);
|
||||
AnfNodePtr value_node = ParseExprNode(block, values[i]);
|
||||
value_node = HandleInterpret(block, value_node, values[i]);
|
||||
value_nodes.push_back(value_node);
|
||||
std::vector<AnfNodePtr> inner_key_nodes;
|
||||
std::vector<AnfNodePtr> inner_value_nodes;
|
||||
for (size_t index = 0; index < keys.size(); ++index) {
|
||||
auto inner_key_node_type = ast_->GetNodeType(keys[index]);
|
||||
const std::string &inner_key_node_type_name = inner_key_node_type->node_name();
|
||||
// The key does not exist, mean the value is a dict which need unpack.
|
||||
if (inner_key_node_type_name == "NoneType") {
|
||||
auto unpack_dict = values[index];
|
||||
auto inner_value_node_type =
|
||||
AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, unpack_dict)));
|
||||
if (inner_value_node_type != AST_SUB_TYPE_DICT) {
|
||||
MS_LOG(INTERNAL_EXCEPTION) << "The input of dict which need unpack must be dict, but got "
|
||||
<< inner_value_node_type;
|
||||
}
|
||||
auto [unpack_keys, unpack_values] = GetRealKeysValues(block, unpack_dict);
|
||||
for (size_t i = 0; i < unpack_keys.size(); ++i) {
|
||||
inner_key_nodes.push_back(unpack_keys[i]);
|
||||
inner_value_nodes.push_back(unpack_values[i]);
|
||||
}
|
||||
} else {
|
||||
AnfNodePtr key_node = ParseExprNode(block, keys[index]);
|
||||
key_node = HandleInterpret(block, key_node, keys[index]);
|
||||
inner_key_nodes.push_back(key_node);
|
||||
AnfNodePtr value_node = ParseExprNode(block, values[index]);
|
||||
value_node = HandleInterpret(block, value_node, values[index]);
|
||||
inner_value_nodes.push_back(value_node);
|
||||
}
|
||||
}
|
||||
return {inner_key_nodes, inner_value_nodes};
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Dict";
|
||||
auto [key_nodes, value_nodes] = GetRealKeysValues(block, node);
|
||||
return ParseDictByKeysAndValues(block, key_nodes, value_nodes);
|
||||
}
|
||||
|
||||
|
@ -3045,6 +3117,7 @@ CNodePtr GenerateInterpretGetItem(const FuncGraphPtr &fg, const AnfNodePtr &iter
|
|||
auto prim = NewValueNode(prim::kPrimPyInterpret);
|
||||
auto interpret_get_item = fg->NewCNodeInOrder({prim, script_node, empty_global_dict_node, local_dict_node});
|
||||
interpret_get_item->set_debug_info(iter_node->debug_info());
|
||||
InterpretNodeRecorder::GetInstance().PushPyInterpretNode(interpret_get_item);
|
||||
return interpret_get_item;
|
||||
}
|
||||
|
||||
|
@ -3733,6 +3806,34 @@ AnfNodePtr Parser::ParseFormattedValue(const FunctionBlockPtr &block, const py::
|
|||
return value_node;
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ParseStarred(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Starred.";
|
||||
TraceGuard trace_guard(GetLocation(node));
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
|
||||
AnfNodePtr value_node = ParseExprNode(block, value_object);
|
||||
AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
|
||||
auto func = block->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
CNodePtr iterated_node = func->NewCNodeInOrder({op_iter, value_node});
|
||||
auto prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
|
||||
CNodePtr unpack_node = func->NewCNodeInOrder({NewValueNode(prim), iterated_node});
|
||||
return unpack_node;
|
||||
}
|
||||
|
||||
void Parser::HandleAssignStarred(const FunctionBlockPtr &block, const py::object &target,
|
||||
const AnfNodePtr &assigned_node) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
MS_EXCEPTION_IF_NULL(assigned_node);
|
||||
py::object value_object = python_adapter::GetPyObjAttr(target, "value");
|
||||
py::str name = python_adapter::GetPyObjAttr(value_object, "id");
|
||||
std::string name_id = name;
|
||||
MS_EXCEPTION_IF_NULL(assigned_node->debug_info());
|
||||
assigned_node->debug_info()->set_name(name_id);
|
||||
MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
|
||||
block->WriteVariable(name_id, assigned_node);
|
||||
}
|
||||
|
||||
void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target,
|
||||
const AnfNodePtr &assigned_node) const {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
|
@ -3755,21 +3856,85 @@ void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &t
|
|||
block->WriteVariable(name_id, assigned_node);
|
||||
}
|
||||
|
||||
void Parser::HandleAssignTupleWithStarredExpression(const FunctionBlockPtr &block, const py::object &target,
|
||||
const AnfNodePtr &assigned_node,
|
||||
const std::vector<int64_t> &positions) {
|
||||
// Process assigned_node
|
||||
auto func = block->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
|
||||
CNodePtr iterated_node = func->NewCNodeInOrder({op_iter, assigned_node});
|
||||
auto starred_unpack_prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
|
||||
CNodePtr unpack_node = func->NewCNodeInOrder({NewValueNode(starred_unpack_prim), iterated_node});
|
||||
|
||||
py::list items = python_adapter::GetPyObjAttr(target, "elts");
|
||||
for (size_t i = 0; i < items.size(); i++) {
|
||||
py::object elt = items[i];
|
||||
auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elt)));
|
||||
if (elt_type != AST_SUB_TYPE_STARRED) {
|
||||
std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
|
||||
ValuePtr op = prim::GetPythonOps("getitem", module_name);
|
||||
std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(op), unpack_node, NewValueNode(positions[i])};
|
||||
AnfNodePtr tuple_get_item = func->NewCNodeInOrder(tuple_get_item_inputs);
|
||||
MS_LOG(DEBUG) << "Assign name: `" << py::str(elt) << "` to node: " << tuple_get_item->DebugString();
|
||||
WriteAssignVars(block, elt, tuple_get_item);
|
||||
} else {
|
||||
auto starred_get_item_prim = std::make_shared<prim::StarredGetItem>(NAMED_METAGRAPH_STARRED_GET_ITEM);
|
||||
std::vector<AnfNodePtr> starred_get_item_inputs{NewValueNode(starred_get_item_prim), unpack_node,
|
||||
NewValueNode(positions[i]),
|
||||
NewValueNode(SizeToLong(items.size()))};
|
||||
AnfNodePtr starred_get_item = func->NewCNodeInOrder(starred_get_item_inputs);
|
||||
MS_LOG(DEBUG) << "Assign name: `" << py::str(elt) << "` to node: " << starred_get_item->DebugString();
|
||||
WriteAssignVars(block, elt, starred_get_item);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Parser::HandleAssignTupleOrList(const FunctionBlockPtr &block, const py::object &target,
|
||||
const AnfNodePtr &assigned_node) {
|
||||
MS_EXCEPTION_IF_NULL(block);
|
||||
AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
|
||||
py::list items = python_adapter::GetPyObjAttr(target, "elts");
|
||||
for (size_t i = 0; i < items.size(); i++) {
|
||||
// Use the Primitive replace the operation resolve node (getitem),
|
||||
// because the getitem will eventually be converted to Primitive node
|
||||
MS_EXCEPTION_IF_NULL(block->func_graph());
|
||||
CNodePtr item_apply =
|
||||
block->func_graph()->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast<int64_t>(i))});
|
||||
|
||||
// Record the position with targets.
|
||||
size_t target_starred_num = 0;
|
||||
size_t starred_pos = items.size();
|
||||
std::vector<int64_t> positions(items.size(), 0);
|
||||
for (size_t i = 0; i < items.size(); i++) {
|
||||
py::object elt = items[i];
|
||||
WriteAssignVars(block, elt, item_apply);
|
||||
auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elt)));
|
||||
if (elt_type == AST_SUB_TYPE_STARRED) {
|
||||
target_starred_num++;
|
||||
if (target_starred_num > 1) {
|
||||
MS_LOG(EXCEPTION) << "SyntaxError: " << target_starred_num << " starred expressions in assignment.";
|
||||
}
|
||||
starred_pos = i;
|
||||
positions[i] = i;
|
||||
} else {
|
||||
if (i > starred_pos) {
|
||||
positions[i] = i - items.size();
|
||||
} else {
|
||||
positions[i] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
auto func = block->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(func);
|
||||
|
||||
if (target_starred_num == 0) {
|
||||
for (size_t i = 0; i < items.size(); i++) {
|
||||
// Use the Primitive replace the operation resolve node (getitem),
|
||||
// because the getitem will eventually be converted to Primitive node
|
||||
CNodePtr item_apply = func->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast<int64_t>(i))});
|
||||
py::object elt = items[i];
|
||||
WriteAssignVars(block, elt, item_apply);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Process AssignTuple with starred expression.
|
||||
// a, *b, c = x // targets(a, *b, c) = assign(x)
|
||||
HandleAssignTupleWithStarredExpression(block, target, assigned_node, positions);
|
||||
}
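A plain-Python sketch (illustrative only, not MindSpore API) of the position bookkeeping above: targets after the starred slot are addressed with negative indices so they bind from the end of the assigned sequence, independent of how many elements the starred target absorbs.

    def target_positions(num_targets, starred_pos):
        # Mirrors the loop above: non-negative indices up to and including the
        # starred slot, negative (from-the-end) indices after it.
        return [i if i <= starred_pos else i - num_targets for i in range(num_targets)]

    # a, *b, c = x  ->  starred_pos == 1, positions == [0, 1, -1]
    # so a = x[0], c = x[-1], and *b takes the remaining middle elements.
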
|
||||
|
||||
bool Parser::IsClassParameterMember(const py::object &target_obj, const AnfNodePtr &target_node) const {
|
||||
|
@ -3981,6 +4146,8 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta
|
|||
HandleAssignClassMember(block, target_object, value_node);
|
||||
} else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
|
||||
HandleAssignClassMember(block, target_object, value_node);
|
||||
} else if (ast_type == AST_SUB_TYPE_STARRED) {
|
||||
HandleAssignStarred(block, target_object, value_node);
|
||||
} else {
|
||||
TraceGuard trace_guard(GetLocation(target_object));
|
||||
MS_EXCEPTION(TypeError) << "Only supported augassign to attribute of self, variable and index value, but got "
|
||||
|
@ -4196,10 +4363,11 @@ bool Parser::IsPopOperation(const AnfNodePtr &node) const {
|
|||
FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast assign";
|
||||
py::object value_object = python_adapter::GetPyObjAttr(node, "value");
|
||||
py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
|
||||
|
||||
AnfNodePtr value_node = ParseExprNode(block, value_object);
|
||||
value_node = HandleInterpret(block, value_node, value_object);
|
||||
|
||||
py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
|
||||
py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
|
||||
size_t count = LongToSize(pcount);
|
||||
MS_LOG(DEBUG) << "The nodes count is " << count;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -213,22 +213,31 @@ class Parser {
|
|||
AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a tuple.
|
||||
AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a tuple.
|
||||
// Process a list.
|
||||
AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a tuple.
|
||||
// Process a tuple or list.
|
||||
AnfNodePtr ParseTupleOrList(const FunctionBlockPtr &block, const py::object &node, bool is_tuple);
|
||||
// Process a tuple or list with starred expression.
|
||||
AnfNodePtr ParseTupleOrListWithStarred(const FunctionBlockPtr &block, const py::object &node, bool is_tuple,
|
||||
const std::vector<AnfNodePtr> &starred_flags);
|
||||
// Process a subscript.
|
||||
AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a slice.
|
||||
AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a extslice.
|
||||
AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a tuple.
|
||||
// Process a index.
|
||||
AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a unaryop.
|
||||
AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
|
||||
// Process a dict ast node expression.
|
||||
AnfNodePtr ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
|
||||
const std::vector<AnfNodePtr> &value_nodes);
|
||||
// Process a dict.
|
||||
AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
|
||||
|
||||
std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> GetRealKeysValues(const FunctionBlockPtr &block,
|
||||
const py::object &node);
|
||||
// Process DictComp expression.
|
||||
AnfNodePtr ParseDictComp(const FunctionBlockPtr &block, const py::object &node);
|
||||
FunctionBlockPtr ParseDictCompIter(const FunctionBlockPtr &block, const py::object &node,
|
||||
|
@ -243,6 +252,7 @@ class Parser {
|
|||
const py::object &node, const py::object &generator_node);
|
||||
AnfNodePtr ParseJoinedStr(const FunctionBlockPtr &block, const py::object &node);
|
||||
AnfNodePtr ParseFormattedValue(const FunctionBlockPtr &block, const py::object &node);
|
||||
AnfNodePtr ParseStarred(const FunctionBlockPtr &block, const py::object &node);
|
||||
std::vector<AnfNodePtr> HandleException(const FunctionBlockPtr &block, const py::list &args, const std::string &name);
|
||||
std::vector<AnfNodePtr> ParseRaiseCall(const FunctionBlockPtr &block, const py::object &node);
|
||||
void HandleStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *str_nodes);
|
||||
|
@ -327,10 +337,17 @@ class Parser {
|
|||
// Assign value to single variable name.
|
||||
void HandleAssignName(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &assigned_node) const;
|
||||
|
||||
// Assign value to starred expression.
|
||||
void HandleAssignStarred(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &assigned_node);
|
||||
|
||||
// Assign value to tuple.
|
||||
void HandleAssignTupleOrList(const FunctionBlockPtr &block, const py::object &target,
|
||||
const AnfNodePtr &assigned_node);
|
||||
|
||||
// Assign value to tuple with starred expression.
|
||||
void HandleAssignTupleWithStarredExpression(const FunctionBlockPtr &block, const py::object &target,
|
||||
const AnfNodePtr &assigned_node, const std::vector<int64_t> &positions);
|
||||
|
||||
// Assign value to class Parameter member. Return false if not a Parameter member.
|
||||
bool HandleAssignClassParameterMember(const FunctionBlockPtr &block, const py::object &target,
|
||||
const AnfNodePtr &value_node);
|
||||
|
|
|
@@ -44,6 +44,7 @@ enum AstSubType : int64_t {
  AST_SUB_TYPE_SUBSCRIPT = 8,   // ast.Subscript
  AST_SUB_TYPE_STARRED = 9,     // ast.Starred
  AST_SUB_TYPE_ATTRIBUTE = 10,  // ast.Attribute
+ AST_SUB_TYPE_DICT = 11,       // ast.Dict
  AST_SUB_TYPE_UNKNOWN = 0xFF   // Unknown type
};

@@ -170,6 +171,9 @@ const char NAMED_PRIMITIVE_MAKELIST[] = "make_list";
const char NAMED_PRIMITIVE_MAKESLICE[] = "make_slice";
const char NAMED_PRIMITIVE_MAKEDICT[] = "make_dict";
const char NAMED_METAGRAPH_UNPACKCALL[] = "unpack_call";
+ const char NAMED_METAGRAPH_STARRED_UNPACK[] = "starred_unpack";
+ const char NAMED_METAGRAPH_STARRED_GET_ITEM[] = "starred_get_item";
+ const char NAMED_METAGRAPH_STARRED_UNPACK_MERGE[] = "starred_unpack_merge";

// Define NAMED_PRIMITIVE_GETATTR "getattr".
// Define python inline attr.

@ -1925,7 +1925,10 @@ EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_abs_
|
|||
const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
|
||||
if (!allow_fallback_runtime) {
|
||||
MS_EXCEPTION(TypeError) << "Do not support to get attribute from " << data_value->ToString()
|
||||
<< "\nThe first argument should be a NameSpace, but got " << data->ToString();
|
||||
<< " in JIT strict mode. You can use os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' \"\n"
|
||||
<< " to enable the JIT lax mode to support the current syntax."
|
||||
<< "\nThe first argument should be a NameSpace, but got " << data->ToString()
|
||||
<< trace::GetDebugInfoStr(out_conf->node()->debug_info());
|
||||
}
|
||||
|
||||
auto item_value = item->BuildValue();
|
||||
|
@ -2061,7 +2064,11 @@ EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePt
|
|||
if (!has_default) {
|
||||
const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
|
||||
if (!allow_fallback_runtime) {
|
||||
MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
|
||||
MS_EXCEPTION(AttributeError) << "In JIT strict mode, cannot get attributes " << item_name << " or the "
|
||||
<< data_type->ToString() << " object has no attribute: " << item_name
|
||||
<< "'. You can use os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' "
|
||||
<< "to enable the JIT lax mode to support the current syntax.\n\n"
|
||||
<< trace::GetDebugInfoStr(out_conf->node()->debug_info());
|
||||
}
|
||||
|
||||
constexpr auto recursive_level = 3;
|
||||
|
@ -2201,7 +2208,10 @@ EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePt
|
|||
MS_EXCEPTION(TypeError) << "Do not support to get attribute from " << py::str(type_str) << " object "
|
||||
<< py::str(obj) << ".\nFor more details, please refer to "
|
||||
<< "https://mindspore.cn/docs/zh-CN/master/faq/network_compilation.html?highlight=do"
|
||||
<< "%20support%20get%20attribute%20from";
|
||||
<< "%20support%20get%20attribute%20from\n"
|
||||
<< "You can use os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' \n"
|
||||
<< "to enable the JIT lax mode to support the current syntax.\n"
|
||||
<< trace::GetDebugInfoStr(out_conf->node()->debug_info());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -56,14 +56,28 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
}
|
||||
|
||||
AbstractBasePtrList key_list = keys->elements();
|
||||
for (size_t index = 0; index < keys_size; index++) {
|
||||
const auto &key = key_list[index];
|
||||
CheckDictKey(key, op_name);
|
||||
}
|
||||
std::unordered_map<std::string, AbstractBasePtr> key_str_value_set;
|
||||
std::vector<AbstractBasePtr> key_set;
|
||||
std::vector<AbstractElementPair> key_value;
|
||||
AbstractBasePtrList value_list = values->elements();
|
||||
for (size_t index = 0; index < keys_size; index++) {
|
||||
(void)key_value.emplace_back(key_list[index], value_list[index]);
|
||||
const auto &key = key_list[index];
|
||||
CheckDictKey(key, op_name);
|
||||
auto key_val = key->BuildValue()->ToString();
|
||||
auto iter = key_str_value_set.find(key_val);
|
||||
// Remove duplicate keys.
|
||||
// {Tensor[1]: x, Tensor[1}: y} the length of dict is 2, means the two keys are not duplicate.
|
||||
if (iter != key_str_value_set.end() && !key->isa<AbstractTensor>()) {
|
||||
iter->second = value_list[index];
|
||||
} else {
|
||||
auto key_str = key->BuildValue()->ToString();
|
||||
key_str_value_set.insert(std::pair<std::string, AbstractBasePtr>(key_str, value_list[index]));
|
||||
key_set.push_back(key);
|
||||
}
|
||||
}
|
||||
for (auto &key : key_set) {
|
||||
auto key_str = key->BuildValue()->ToString();
|
||||
(void)key_value.emplace_back(key, key_str_value_set[key_str]);
|
||||
}
|
||||
return std::make_shared<AbstractDictionary>(key_value);
|
||||
}
|
||||
|
|
|
@@ -86,6 +86,7 @@ AST_SUB_TYPE_LIST = 7 # ast.List
AST_SUB_TYPE_SUBSCRIPT = 8  # ast.Subscript
AST_SUB_TYPE_STARRED = 9  # ast.Starred
AST_SUB_TYPE_ATTRIBUTE = 10  # ast.Attribute
+ AST_SUB_TYPE_DICT = 11  # ast.Dict
AST_SUB_TYPE_UNKNOWN = 0xFF  # unknown

# Syntax support

@@ -692,6 +693,8 @@ def get_ast_type(node):
        ast_type = AST_SUB_TYPE_STARRED
    elif isinstance(node, ast.Attribute):
        ast_type = AST_SUB_TYPE_ATTRIBUTE
+   elif isinstance(node, ast.Dict):
+       ast_type = AST_SUB_TYPE_DICT
    else:
        ast_type = AST_SUB_TYPE_UNKNOWN
    return ast_type

@@ -30,7 +30,8 @@ from mindspore._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFu
    SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
    ListClear_, ListReverse_, ListExtend_, DictClear_, DictHasKey_, DictUpdate_, DictFromKeys_, \
    ZerosLike_, TensorIndexGetitem_, TensorIndexSetitem_, ListAdd_, DictSetItem_, \
-   HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_
+   HandleBoolTensor_, HandleEmptySlice_, PreSetitemByTuple_, HandleScalarTensorIndex_, StarredGetItem_,\
+   StarredUnpack_, StarredUnpackMerge_
from mindspore.common import dtype as mstype
from mindspore.common.api import jit, _pynative_executor, _wrap_func
from mindspore.common.api import _add_flags, _core

@ -1195,3 +1196,48 @@ class _ZipOperation(ZipOperation_):
|
|||
|
||||
zip_operation = _ZipOperation('zip_operation')
|
||||
"""`zip_operation` will generate a tuple of zip iterations of inputs."""
|
||||
|
||||
|
||||
class _StarredGetItem(StarredGetItem_):
|
||||
"""Generates a list of starred get_item for inputs."""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _StarredGetItem."""
|
||||
StarredGetItem_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
starred_get_item = _StarredGetItem('starred_get_item')
|
||||
"""`starred_get_item` will generate a list of starred get_item for inputs."""
|
||||
|
||||
|
||||
class _StarredUnpack(StarredUnpack_):
|
||||
"""Generates a tuple of starred unpack for inputs."""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _StarredUnpack."""
|
||||
StarredUnpack_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
starred_unpack = _StarredUnpack('starred_unpack')
|
||||
"""`starred_unpack` will generate a tuple of starred unpack for inputs."""
|
||||
|
||||
|
||||
class _StarredUnpackMerge(StarredUnpackMerge_):
|
||||
"""Generates a tuple of starred unpack merge for inputs."""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _StarredUnpackMerge."""
|
||||
StarredUnpackMerge_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
starred_unpack_merge = _StarredUnpackMerge('starred_unpack_merge')
|
||||
"""`starred_unpack_merge` will generate a tuple of starred unpack merge for inputs."""
|
||||
|
|
|
@ -239,31 +239,6 @@ def test_compress_with_mutable_input():
|
|||
assert (foo(Tensor([1])) == [1, 3, 4]).all()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not support now")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_star_to_compress_input():
|
||||
"""
|
||||
Feature: Support JIT Fallback runtime feature.
|
||||
Description: use star to compress assigned input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo():
|
||||
x = [1, 2, 3, 4]
|
||||
a, *b = x
|
||||
return a, b
|
||||
|
||||
ret = foo()
|
||||
assert len(ret) == 2
|
||||
assert ret[0] == 1
|
||||
assert ret[1] == [2, 3, 4]
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
@ -312,27 +287,6 @@ def test_unpack_interpret_node():
|
|||
assert ret == 10
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not support now")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_starred_to_unpack_input():
|
||||
"""
|
||||
Feature: Support JIT Fallback runtime feature.
|
||||
Description: * operator can not unpack a list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
@jit
|
||||
def foo(x):
|
||||
return f"output is {*a,}"
|
||||
|
||||
ret = foo([1, 2, 3, 4])
|
||||
assert ret == "output is (1, 2, 3, 4)"
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
|
@ -0,0 +1,558 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test graph starred expression. """
import pytest
from mindspore import context, jit, Tensor

context.set_context(mode=context.GRAPH_MODE)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_list_input():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = [1, 2, 3, 4]
        a = *x,  # pylint: disable=R1707
        return a

    ret = foo()
    assert ret == (1, 2, 3, 4)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_assign_list_input():
    """
    Feature: Support assign list.
    Description: Support assign in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = [1, 2, 3, 4]
        a = x,  # pylint: disable=R1707
        return a

    ret = foo()
    assert ret == ([1, 2, 3, 4],)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_tuple_input():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = (1, 2, 3, 4)
        a = *x,  # pylint: disable=R1707
        return a

    ret = foo()
    assert ret == (1, 2, 3, 4)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_dict_input():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = {"a": 1, "b": 2}
        out = *x,  # pylint: disable=R1707
        return out

    ret = foo()
    assert ret == ("a", "b")


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_string_input():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = "abcde"
        out = *x,  # pylint: disable=R1707
        return out

    ret = foo()
    assert ret == ('a', 'b', 'c', 'd', 'e')

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_tensor_input():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo(x):
        out = *x,  # pylint: disable=R1707
        return out

    ret = foo(Tensor([1, 2, 3, 4]))
    assert ret == (Tensor(1), Tensor(2), Tensor(3), Tensor(4))


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_in_format_string():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo(x):
        return f"output is {*x,}"

    ret = foo([1, 2, 3, 4])
    assert ret == "output is (1, 2, 3, 4)"


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_tuple():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = (1, 2, 3, 4)
        *b, = x
        return b

    ret = foo()
    assert ret == [1, 2, 3, 4]


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_list():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = [1, 2, 3, 4]
        *b, = x
        return b

    ret = foo()
    assert ret == [1, 2, 3, 4]

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_dict():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = {"a": 1, "b": 2}
        *b, = x
        return b

    ret = foo()
    assert ret == ['a', 'b']


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_string():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = "abcde"
        *b, = x
        return b

    ret = foo()
    assert ret == ['a', 'b', 'c', 'd', 'e']


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_tensor():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo(x):
        *b, c = x
        return b, c

    ret = foo(Tensor([1, 2, 3, 4]))
    assert ret[0] == [Tensor(1), Tensor(2), Tensor(3)]
    assert ret[1] == Tensor(4)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_nested_tuple():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = (1, 2, 3, 4)
        y = (5, 6)
        a, b, *c = x, y
        return a, b, c

    ret = foo()
    assert ret == ((1, 2, 3, 4), (5, 6), [])

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_list_2():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = [1, 2, 3, 4]
        a, *b = x
        return a, b

    ret = foo()
    assert len(ret) == 2
    assert ret[0] == 1
    assert ret[1] == [2, 3, 4]


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_tuple_2():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = (1, 2, 3, 4)
        a, *b, c = x
        return a, b, c

    ret = foo()
    assert len(ret) == 3
    assert ret[0] == 1
    assert ret[1] == [2, 3]
    assert ret[2] == 4


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_target_list_tuple():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        x = [1, 2, 3, 4]
        y = [5, 6]
        z = (7, 8)
        a, *b = x, y, z
        return a, b

    ret = foo()
    assert len(ret) == 2
    assert ret[0] == [1, 2, 3, 4]
    assert ret[1] == [[5, 6], (7, 8)]


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_with_range():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        a, *b, c = range(5)
        return a, b, c

    ret = foo()
    assert ret == (0, [1, 2, 3], 4)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_for_in():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        ret = []
        for _, *b in [(1, 2, 3), (4, 5, 6, 7)]:
            ret.append(b)
        return ret

    ret = foo()
    assert ret == [[2, 3], [5, 6, 7]]

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_assign_tuple():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        a = *[1, 2], *(3, 4), (5, 6)
        return a

    ret = foo()
    assert ret == (1, 2, 3, 4, (5, 6))


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_range_tuple():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        a = *range(4), 4
        return a

    ret = foo()
    assert ret == (0, 1, 2, 3, 4)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_range_list():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        a = [*range(4), 4]
        return a

    ret = foo()
    assert ret == [0, 1, 2, 3, 4]


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_return_tuple():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        a = (*[1], *[2], 3)
        return a

    ret = foo()
    assert ret == (1, 2, 3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_dict():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        a = {'x': 1, **{'y': 2}}
        return a

    ret = foo()
    assert len(ret) == 2
    assert ret == {'x': 1, 'y': 2}

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_dict_2():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        a = {'x': 1, **{'y': 2}, "w": 4, **{'z': 3}}
        return a

    ret = foo()
    assert ret == {'x': 1, 'y': 2, 'w': 4, 'z': 3}


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_dict_3():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        a = {'x': 1, **{'y': 2, 'z': 3, **{'w': 4}}}
        return a

    ret = foo()
    assert ret == {'x': 1, 'y': 2, 'z': 3, 'w': 4}


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_dict_4():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        a = {'x': 1, 'y': {'z': 3, **{'w': 4}}}
        return a

    ret = foo()
    assert ret == {'x': 1, 'y': {'z': 3, 'w': 4}}


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_starred_expression_dict_key_deduplicate():
    """
    Feature: Support starred expression.
    Description: Support starred expression in graph mode.
    Expectation: No exception.
    """

    @jit
    def foo():
        a = {'x': 1, **{'x': 2}}
        return a

    ret = foo()
    assert ret == {'x': 2}