support list slice assign

This commit is contained in:
lianliguang 2022-04-14 14:34:52 +08:00
parent a7a9486d32
commit 11b4836040
12 changed files with 936 additions and 149 deletions

View File

@ -18,7 +18,7 @@
#include "frontend/operator/composite/composite.h"
#include <algorithm>
#include <tuple>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "abstract/abstract_value.h"
@ -37,6 +37,7 @@
namespace mindspore {
// namespace to support composite operators definition
namespace prim {
constexpr auto kStepDefault = 1;
using AbstractTensor = mindspore::abstract::AbstractTensor;
using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
@ -54,14 +55,6 @@ using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSlice;
using mindspore::abstract::AbstractTuple;
ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul},
{"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod},
{"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt},
{"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe},
{"__ge__", kPrimScalarGe}};
ValuePtr kCompositeHyperMap = std::make_shared<HyperMap>();
void HyperMap::Init() {
if (fn_leaf_) {
name_ = "hyper_map[" + fn_leaf_->name() + "]";
@ -1246,78 +1239,81 @@ int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, c
<< member->BuildType()->ToString();
}
void GenerateTupleSliceParameter(const abstract::AbstractSequencePtr &tuple, const AbstractSlicePtr &slice,
int64_t *start_index, int64_t *stop_index, int64_t *step_value) {
MS_EXCEPTION_IF_NULL(tuple);
std::tuple<int64_t, int64_t, int64_t> GenerateTupleSliceParameter(const abstract::AbstractSequencePtr &sequence,
const AbstractSlicePtr &slice) {
MS_EXCEPTION_IF_NULL(sequence);
MS_EXCEPTION_IF_NULL(slice);
MS_EXCEPTION_IF_NULL(start_index);
MS_EXCEPTION_IF_NULL(stop_index);
MS_EXCEPTION_IF_NULL(step_value);
int64_t start_index;
int64_t stop_index;
int64_t step_value;
const std::string start_name("Slice start index");
const std::string stop_name("Slice stop index");
const std::string step_name("Slice step value");
int64_t tuple_size = SizeToLong(tuple->size());
int64_t tuple_size = SizeToLong(sequence->size());
int64_t start_default = 0;
int64_t stop_default = tuple_size;
int64_t step_default = 1;
int64_t step_default = kStepDefault;
*step_value = CheckSliceMember(slice->step(), step_default, step_name);
if (*step_value == 0) {
step_value = CheckSliceMember(slice->step(), step_default, step_name);
if (step_value == 0) {
MS_EXCEPTION(ValueError) << "Slice step cannot be zero.";
}
if (*step_value < 0) {
if (step_value < 0) {
start_default = tuple_size - 1;
stop_default = -1;
stop_default = ((-tuple_size) - 1);
}
*start_index = CheckSliceMember(slice->start(), start_default, start_name);
*stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
start_index = CheckSliceMember(slice->start(), start_default, start_name);
stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
if (*start_index < -tuple_size) {
*start_index = 0;
}
if (*stop_index > tuple_size) {
*stop_index = tuple_size;
}
if (*start_index > tuple_size || *stop_index < -tuple_size) {
*start_index = 0;
*stop_index = 0;
if (start_index < -tuple_size) {
start_index = 0;
}
*start_index = GetPositiveIndex(*start_index, tuple_size);
if (!slice->stop()->isa<AbstractNone>()) {
*stop_index = GetPositiveIndex(*stop_index, tuple_size);
if (stop_index > tuple_size) {
stop_index = tuple_size;
}
if (start_index > tuple_size) {
start_index = tuple_size;
}
if (stop_index < ((-tuple_size) - 1)) {
stop_index = 0;
}
start_index = GetPositiveIndex(start_index, tuple_size);
stop_index = GetPositiveIndex(stop_index, tuple_size);
return std::make_tuple(start_index, stop_index, step_value);
}
FuncGraphPtr SequenceSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tuple
// args: tuple, start index, end index, step
auto seq_pair = CheckArgs(args_spec_list);
auto sequence = seq_pair.first;
auto slice = seq_pair.second;
int64_t start_index;
int64_t stop_index;
int64_t step_value;
GenerateTupleSliceParameter(sequence, slice, &start_index, &stop_index, &step_value);
void SequenceSliceGetItem::CheckArgs(const AbstractBasePtrList &args_spec_list) {
constexpr size_t arg_size = 2;
abstract::CheckArgsSize(this->name(), args_spec_list, arg_size);
sequence_ = abstract::CheckArg<abstract::AbstractSequence>(this->name(), args_spec_list, 0);
slice_ = abstract::CheckArg<AbstractSlice>(this->name(), args_spec_list, 1);
}
FuncGraphPtr SequenceSliceGetItem::BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) {
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr p_tuple = ret->add_parameter();
AnfNodePtr p_seq = ret->add_parameter();
(void)ret->add_parameter();
std::vector<AnfNodePtr> elems;
elems.push_back(NewValueNode(prim_));
if (step_value > 0) {
for (int64_t index = start_index; index < stop_index; index = index + step_value) {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_tuple, NewValueNode(index)}));
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_seq, NewValueNode(index)}));
}
} else {
for (int64_t index = start_index; index > stop_index; index = index + step_value) {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_tuple, NewValueNode(index)}));
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_seq, NewValueNode(index)}));
}
}
@ -1325,24 +1321,6 @@ FuncGraphPtr SequenceSlice::GenerateFuncGraph(const AbstractBasePtrList &args_sp
return ret;
}
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> TupleSlice::CheckArgs(
const AbstractBasePtrList &args_spec_list) {
constexpr size_t arg_size = 2;
abstract::CheckArgsSize("TupleSlice", args_spec_list, arg_size);
auto sequence = abstract::CheckArg<abstract::AbstractSequence>("TupleSlice", args_spec_list, 0);
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>("TupleSlice", args_spec_list, 1);
return std::make_pair(sequence, slice);
}
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> ListSlice::CheckArgs(
const AbstractBasePtrList &args_spec_list) {
constexpr size_t arg_size = 2;
abstract::CheckArgsSize("ListSlice", args_spec_list, arg_size);
auto sequence = abstract::CheckArg<abstract::AbstractSequence>("ListSlice", args_spec_list, 0);
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>("ListSlice", args_spec_list, 1);
return std::make_pair(sequence, slice);
}
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// select indexed item
// args: tuple of items, index
@ -1363,20 +1341,16 @@ REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
.def(py::init<std::string &>());
}));
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
(void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
.def(py::init<std::string &>());
}));
REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
(void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
*m, "TupleGetItemTensor_")
.def(py::init<std::string &>());
}));
REGISTER_PYBIND_DEFINE(ListSlice_, ([](const py::module *m) {
(void)py::class_<ListSlice, MetaFuncGraph, std::shared_ptr<ListSlice>>(*m, "ListSlice_")
.def(py::init<std::string &>());
REGISTER_PYBIND_DEFINE(SequenceSliceGetItem, ([](const py::module *m) {
(void)py::class_<SequenceSliceGetItem, MetaFuncGraph, std::shared_ptr<SequenceSliceGetItem>>(
*m, "SequenceSliceGetItem_")
.def(py::init<std::string &, std::string &, std::string &>());
}));
namespace {
@ -1449,5 +1423,114 @@ REGISTER_PYBIND_DEFINE(Shard_, ([](const py::module *m) {
(void)py::class_<Shard, MetaFuncGraph, std::shared_ptr<Shard>>(*m, "Shard_")
.def(py::init<const std::string &>(), py::arg("fn"));
}));
void ListSliceSetItem::CheckArgs(const AbstractBasePtrList &args_spec_list) {
constexpr size_t kSliceSetItemArgsSizeargs_size = 3;
constexpr size_t kSliceSetItemListIndex = 0;
constexpr size_t kSliceSetItemSliceIndex = 1;
constexpr size_t kSliceSetItemValueIndex = 2;
abstract::CheckArgsSize("list_slice_set_item", args_spec_list, kSliceSetItemArgsSizeargs_size);
this->sequence_ = abstract::CheckArg<AbstractList>("list_slice_set_item", args_spec_list, kSliceSetItemListIndex);
this->slice_ = abstract::CheckArg<AbstractSlice>("list_slice_set_item", args_spec_list, kSliceSetItemSliceIndex);
this->value_list_ = abstract::CheckArg<AbstractList>("list_slice_set_item", args_spec_list, kSliceSetItemValueIndex);
}
FuncGraphPtr ListSliceSetItem::BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) {
// Init graph with the input list_node slice assign_node
CheckAssignRange(start_index, stop_index, step_value);
auto graph = std::make_shared<FuncGraph>();
graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
auto list_node = graph->add_parameter();
(void)graph->add_parameter();
auto assign_parameter = graph->add_parameter();
auto assign_node = GetAssignNode(graph, assign_parameter, step_value);
std::vector<AnfNodePtr> elems = {NewValueNode(prim::kPrimMakeList)};
int64_t list_index = 0;
// check the index is in the slice range
auto check_in_range = [start_index, stop_index, step_value](int64_t index) -> bool {
if (step_value > 0) {
return (index >= start_index && index < stop_index);
}
return (index <= start_index && index > stop_index);
};
int64_t list_size = SizeToLong(sequence_->size());
int64_t assign_index = 0;
int64_t value_size = SizeToLong(value_list_->size());
while (list_index < list_size || assign_index < value_size) {
if (!check_in_range(list_index)) {
// list start <= stop && step = 1 insert the assign node to target node
while (assign_index < value_size && list_index == start_index) {
(void)elems.emplace_back(
graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
}
if (list_index < list_size) {
(void)elems.emplace_back(
graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), list_node, NewValueNode(list_index++)}));
}
} else {
if (((list_index - start_index) % step_value) == 0) {
++list_index;
if (assign_index >= value_size) {
continue;
}
(void)elems.emplace_back(
graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
} else {
(void)elems.emplace_back(
graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), list_node, NewValueNode(list_index++)}));
}
// the assign node's len is larger than the range
while (!check_in_range(list_index) && assign_index < value_size) {
(void)elems.emplace_back(
graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
}
}
}
graph->set_output(graph->NewCNodeInOrder(elems));
return graph;
}
void ListSliceSetItem::CheckAssignRange(int64_t start_index, int64_t stop_index, int64_t step_value) {
if (step_value != kStepDefault) {
int64_t start_include = 0;
if (start_index < SizeToLong(sequence_->size()) && start_index >= -SizeToLong(sequence_->size())) {
start_include = 1;
}
auto assign_size = ((stop_index - start_index - 1) / step_value) + start_include;
if (step_value < 0) {
assign_size = ((start_index - stop_index) / -step_value) + start_include;
}
if (assign_size != SizeToLong(value_list_->size())) {
MS_EXCEPTION(ValueError) << "attempt to assign sequence of size " << value_list_->size()
<< " to extended slice of size " << assign_size;
}
}
}
AnfNodePtr ListSliceSetItem::GetAssignNode(const FuncGraphPtr &func_graph, const AnfNodePtr &assign_node,
int64_t step_value) {
if (step_value > 0) {
return assign_node;
}
std::vector<AnfNodePtr> elems = {NewValueNode(prim::kPrimMakeList)};
for (int64_t i = SizeToInt(value_list_->size()) - 1; i >= 0; --i) {
elems.emplace_back(
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), assign_node, NewValueNode(i)}));
}
return func_graph->NewCNodeInOrder(elems);
}
REGISTER_PYBIND_DEFINE(ListSliceSetItem_, ([](const py::module *m) {
(void)py::class_<ListSliceSetItem, MetaFuncGraph, std::shared_ptr<ListSliceSetItem>>(
*m, "ListSliceSetItem_")
.def(py::init<const std::string &>());
}));
FuncGraphPtr SequenceSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
this->CheckArgs(args_spec_list);
auto [start, stop, step] = GenerateTupleSliceParameter(sequence_, slice_);
return this->BuildFuncGraph(start, stop, step);
}
} // namespace prim
} // namespace mindspore

View File

@ -47,6 +47,7 @@ using AbstractScalarPtr = abstract::AbstractScalarPtr;
using AbstractTensorPtr = abstract::AbstractTensorPtr;
using ElemwiseMap = mindspore::HashMap<std::string, PrimitivePtr>;
using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
using AbstractListPtr = abstract::AbstractListPtr;
class HyperMap : public MetaFuncGraph {
public:
@ -202,37 +203,55 @@ using TupleAddPtr = std::shared_ptr<TupleAdd>;
class SequenceSlice : public MetaFuncGraph {
public:
explicit SequenceSlice(const std::string &name, const PrimitivePtr &prim, const PrimitivePtr &get_item)
: MetaFuncGraph(name), prim_(prim), get_item_(get_item) {}
explicit SequenceSlice(const std::string &name) : MetaFuncGraph(name) {}
~SequenceSlice() override = default;
MS_DECLARE_PARENT(SequenceSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) final;
friend bool operator==(const SequenceSlice &lhs, const SequenceSlice &rhs) { return lhs.name_ == rhs.name_; }
virtual std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
const AbstractBasePtrList &args_spec_list) = 0;
protected:
virtual void CheckArgs(const AbstractBasePtrList &args_spec_list) = 0;
virtual FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) = 0;
abstract::AbstractSequencePtr sequence_ = nullptr;
AbstractSlicePtr slice_ = nullptr;
};
class SequenceSliceGetItem : public SequenceSlice {
public:
explicit SequenceSliceGetItem(const std::string &name, const std::string &prim_name, const std::string &get_item_name)
: SequenceSlice(name),
prim_(std::make_shared<Primitive>(prim_name)),
get_item_(std::make_shared<Primitive>(get_item_name)) {}
~SequenceSliceGetItem() override = default;
MS_DECLARE_PARENT(SequenceSliceGetItem, MetaFuncGraph)
friend bool operator==(const SequenceSliceGetItem &lhs, const SequenceSliceGetItem &rhs) {
return lhs.name_ == rhs.name_;
}
protected:
void CheckArgs(const AbstractBasePtrList &args_spec_list) override;
FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) override;
private:
PrimitivePtr prim_;
PrimitivePtr get_item_;
};
class TupleSlice : public SequenceSlice {
class ListSliceSetItem : public SequenceSlice {
public:
explicit TupleSlice(const std::string &name) : SequenceSlice(name, prim::kPrimMakeTuple, prim::kPrimTupleGetItem) {}
~TupleSlice() override = default;
MS_DECLARE_PARENT(TupleSlice, SequenceSlice)
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
const AbstractBasePtrList &args_spec_list) override;
};
using TupleSlicePtr = std::shared_ptr<TupleSlice>;
explicit ListSliceSetItem(const std::string &name) : SequenceSlice(name) {}
~ListSliceSetItem() override = default;
MS_DECLARE_PARENT(ListSliceSetItem, MetaFuncGraph)
friend bool operator==(const ListSliceSetItem &lhs, const ListSliceSetItem &rhs) { return lhs.name_ == rhs.name_; }
class ListSlice : public SequenceSlice {
public:
explicit ListSlice(const std::string &name) : SequenceSlice(name, prim::kPrimMakeList, prim::kPrimListGetItem) {}
~ListSlice() override = default;
MS_DECLARE_PARENT(ListSlice, SequenceSlice)
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
const AbstractBasePtrList &args_spec_list) override;
protected:
void CheckArgs(const AbstractBasePtrList &args_spec_list) override;
FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) override;
private:
void CheckAssignRange(int64_t start_index, int64_t stop_index, int64_t step_value);
AnfNodePtr GetAssignNode(const FuncGraphPtr &func_graph, const AnfNodePtr &assign_node, int64_t step_value);
AbstractListPtr value_list_ = nullptr;
};
class TupleGetItemTensor : public MetaFuncGraph {

View File

@ -168,10 +168,11 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
return meta_func_graph->isa<prim::Tail>() || meta_func_graph->isa<prim::MakeTupleGradient>() ||
meta_func_graph->isa<prim::MakeListGradient>() || meta_func_graph->isa<prim::TupleAdd>() ||
meta_func_graph->isa<prim::SequenceSlice>() || meta_func_graph->isa<prim::UnpackCall>() ||
meta_func_graph->isa<prim::ZipOperation>() || meta_func_graph->isa<prim::ListAppend>() ||
meta_func_graph->isa<prim::ListInsert>() || meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>() ||
meta_func_graph->isa<prim::VmapMatchOutAxis>() || meta_func_graph->isa<prim::VmapGeneralPreprocess>();
meta_func_graph->isa<prim::SequenceSliceGetItem>() || meta_func_graph->isa<prim::ListSliceSetItem>() ||
meta_func_graph->isa<prim::UnpackCall>() || meta_func_graph->isa<prim::ZipOperation>() ||
meta_func_graph->isa<prim::ListAppend>() || meta_func_graph->isa<prim::ListInsert>() ||
meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>() || meta_func_graph->isa<prim::VmapMatchOutAxis>() ||
meta_func_graph->isa<prim::VmapGeneralPreprocess>();
}
/* inherit relation of MetaFuncGraph
@ -186,7 +187,10 @@ inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
* MakeTupleGradient
* MakeListGradient
* GradOperation
* TupleAdd
* TupleAdd
* SequenceSlice
* SequenceSliceGetItem
* ListSliceSetItem
*/
std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {
if (meta_func_graph == nullptr) {

View File

@ -2314,14 +2314,14 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
}
if (AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, value_obj))) ==
AST_SUB_TYPE_SUBSCRIPT) {
HandleAssignSubscript(block, value_obj, setitem_app);
return;
if (IsSubscriptReferenceType(value_obj)) {
HandleAssignSubscript(block, value_obj, setitem_app);
return;
}
}
if (!py::hasattr(value_obj, "id")) {
MS_EXCEPTION(TypeError) << "Attribute id not found in " << py::str(value_obj).cast<std::string>() << "\n\n"
<< trace::GetDebugInfo(value_node->debug_info());
if (py::hasattr(value_obj, "id")) {
var_name = value_obj.attr("id").cast<std::string>();
}
var_name = value_obj.attr("id").cast<std::string>();
block->WriteVariable(var_name, setitem_app);
}
@ -2910,5 +2910,12 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
}
return func_graph;
}
bool Parser::IsSubscriptReferenceType(const py::object &obj) {
py::object slice_node = python_adapter::GetPyObjAttr(obj, "slice");
auto node_type = ast_->GetNodeType(slice_node);
auto node_name = node_type->node_name();
return node_name != "Slice";
}
} // namespace parse
} // namespace mindspore

View File

@ -291,6 +291,8 @@ class Parser {
const std::vector<AnfNodePtr> &packed_arguments,
const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const;
ScopePtr GetScopeForParseFunction();
// Check the value is subscript is reference type
bool IsSubscriptReferenceType(const py::object &obj);
void BuildMethodMap();
FunctionBlockPtr MakeFunctionBlock(const Parser &parse) {
FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse);

View File

@ -778,7 +778,10 @@ class MS_CORE_API AbstractSequence : public AbstractBase {
///
/// \return A size_t.
std::size_t size() const { return elements_.size(); }
/// \brief Get the size of the stored elements.
///
/// \return A size_t.
bool empty() const { return elements_.empty(); }
/// \brief Get the stored elements.
///
/// \return A vector of elements.

View File

@ -21,8 +21,8 @@ from types import FunctionType
import mindspore as ms
from mindspore import context
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
ListSlice_, VmapOperation_, TaylorOperation_
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
from ..primitive import Primitive
@ -30,7 +30,7 @@ from ..operations import _grad_ops
from .. import operations as P
from .. import signature as sig
__all__ = [TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_, ListSlice_]
__all__ = [TupleAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_, ListSliceSetItem_]
def add_flags(fn=None, **flags):

View File

@ -26,7 +26,7 @@ using ".register" decorator.
"""
class _TupleSlice(base.TupleSlice_):
class _TupleSlice(base.SequenceSliceGetItem_):
"""
Slices a tuple.
@ -40,7 +40,7 @@ class _TupleSlice(base.TupleSlice_):
def __init__(self, name):
"""Initialize _TupleSlice."""
base.TupleSlice_.__init__(self, name)
base.SequenceSliceGetItem_.__init__(self, name, "MakeTuple", "TupleGetItem")
def __call__(self, *args):
pass
@ -50,7 +50,7 @@ _tuple_slice = _TupleSlice('tuple_slice')
"""_tuple_slice is a metafuncgraph object which will slice a tuple."""
class _ListSlice(base.ListSlice_):
class _ListSlice(base.SequenceSliceGetItem_):
"""
Slices a List.
@ -64,7 +64,7 @@ class _ListSlice(base.ListSlice_):
def __init__(self, name):
"""Initialize _TupleSlice."""
base.ListSlice_.__init__(self, name)
base.SequenceSliceGetItem_.__init__(self, name, "make_list", "list_getitem")
def __call__(self, *args):
pass

View File

@ -17,11 +17,38 @@
from . import _compile_utils as compile_utils
from ... import functional as F
from ...operations._inner_ops import SliceGetItem
from ...composite import base
from ....common import Tensor
setitem = base.MultitypeFuncGraph('setitem')
slice_get_item = SliceGetItem()
class _ListSliceSetItem(base.ListSliceSetItem_):
"""
List slice assign.
Inputs:
data (List): A List to be sliced.
s (slice): The index to slice list data.
value : The value to be assign
Outputs:
List, consists of some elements of data.
"""
def __init__(self, name):
"""Initialize _TupleSlice."""
base.ListSliceSetItem_.__init__(self, name)
def __call__(self, *args):
pass
_list_slice_set_item = _ListSliceSetItem('list_slice_set_item')
"""_list_slice_set_item is a MetaFuncGraph object which assign a list will slice."""
@setitem.register("List", "Number", "String")
def _list_setitem_with_string(data, number_index, value):
@ -102,6 +129,75 @@ def _list_setitem_with_tuple(data, number_index, value):
return F.list_setitem(data, number_index, value)
@setitem.register("List", "Slice", "Tuple")
def _list_slice_setitem_with_tuple(data, slice_index, value):
"""
Assigns value to list.
Inputs:
data (list): Data of type list.
slice_index (slice): Index of data.
value (tuple): Value given.
Outputs:
list, type is the same as the element type of data.
"""
list_value = list(value)
return _list_slice_set_item(data, slice_index, list_value)
@setitem.register("List", "Slice", "List")
def _list_slice_setitem_with_list(data, slice_index, value):
"""
Assigns value to list.
Inputs:
data (list): Data of type list.
slice_index (slice): Index of data.
value (list): Value given.
Outputs:
list, type is the same as the element type of data.
"""
return _list_slice_set_item(data, slice_index, value)
@setitem.register("List", "Slice", "Tensor")
def _list_slice_setitem_with_tensor(data, slice_index, value):
"""
Assigns value to list.
Inputs:
data (list): Data of type list.
slice_index (slice): Index of data.
value (Tensor): Value given.
Outputs:
list, type is the same as the element type of data.
"""
value_list = list(value)
return _list_slice_set_item(data, slice_index, value_list)
@setitem.register("List", "Slice", "Number")
def _list_slice_setitem_with_number(data, slice_index, value):
"""
Assigns value to list.
Inputs:
data (list): Data of type list.
slice_index (slice): Index of data.
value (number): Value given.
Outputs:
lis/t, type is the same as the element type of data.
"""
step = slice_get_item(slice_index, "step")
if step == 1 or step is None:
raise TypeError("can only assign an iterable")
raise TypeError("must assign iterable to extended slice")
@setitem.register("Dictionary", "String", "Tensor")
def _dict_setitem_with_tensor(data, key, value):
"""

View File

@ -0,0 +1,61 @@
import pytest
from mindspore.nn import Cell
from mindspore import Tensor
from mindspore import context
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_list_slice_tensor_no_step():
"""
Feature: List assign
Description: Test list slice assign with tensor
Expectation: No exception.
"""
class NetInner(Cell):
def construct(self, start=None, stop=None, step=None):
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = Tensor([11, 22, 33])
a[start:stop:step] = b
return tuple(a)
context.set_context(mode=context.PYNATIVE_MODE)
net = NetInner()
python_out = (Tensor(11), Tensor(22), Tensor(33), 4, 5, 6, 7, 8, 9)
pynative_out = net(0, 3, None)
assert pynative_out == python_out
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(0, 3, None)
assert graph_out == python_out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_pynative_list_slice_tensor_with_step():
"""
Feature: List assign
Description: Test list slice assign with tensor
Expectation: No exception.
"""
class NetInner(Cell):
def construct(self, start=None, stop=None, step=None):
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = Tensor([11, 22, 33])
a[start:stop:step] = b
return tuple(a)
context.set_context(mode=context.PYNATIVE_MODE)
net = NetInner()
python_out = (Tensor(11), 2, 3, Tensor(22), 5, 6, Tensor(33), 8, 9)
pynative_out = net(0, None, 3)
assert python_out == pynative_out
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(0, None, 3)
assert python_out == graph_out

View File

@ -88,7 +88,8 @@ class UTCompositeUtils {
};
TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
MetaFuncGraphPtr tupleSlicePtr =
std::make_shared<prim::SequenceSliceGetItem>("TupleSlice", "MakeTuple", "TupleGetItem");
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);
AbstractBasePtrList eles;
@ -114,7 +115,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
}
TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
MetaFuncGraphPtr tupleSlicePtr =
std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
AbstractBasePtrList eles;
@ -146,7 +148,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
TEST_F(TestComposite, test_TupleSlice_arg_slice) {
std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
MetaFuncGraphPtr tupleSlicePtr =
std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
AbstractBasePtrList eles;
@ -173,7 +176,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) {
}
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
MetaFuncGraphPtr tupleSlicePtr =
std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
AbstractBasePtrList eles;
@ -200,7 +204,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
}
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
MetaFuncGraphPtr tupleSlicePtr =
std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
AbstractBasePtrList eles;
@ -227,7 +232,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
}
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
MetaFuncGraphPtr tupleSlicePtr =
std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
AbstractBasePtrList eles;
@ -257,7 +263,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
/// Description: The second input is a scalar
/// Expectation: Throw type error
TEST_F(TestComposite, test_ListSlice_arg_one_number) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
FuncGraphPtr list_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 3);
AbstractBasePtrList eles;
@ -292,7 +298,7 @@ TEST_F(TestComposite, test_ListSlice_arg_one_number) {
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice) {
std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;
@ -321,7 +327,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice) {
/// Description: Test List slice the step is none
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;
@ -350,7 +356,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
/// Description: Test List slice the step is negative
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;
@ -379,7 +385,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
/// Description: Test List slice the step is positive
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice","make_list","list_getitem");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;

View File

@ -13,19 +13,23 @@
# limitations under the License.
# ============================================================================
""" test enumerate"""
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
from mindspore.nn import Cell
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE)
from mindspore.ops import operations as P
from mindspore import Tensor, ms_function
from mindspore import context
def test_list_index_1D():
def test_list_index_1d():
"""
Feature: List index assign
Description: Test list assign in pynative mode
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
@ -38,8 +42,22 @@ def test_list_index_1D():
assert list(out[1]) == [2, 2]
assert list(out[2]) == [3, 3, 3]
context.set_context(mode=context.GRAPH_MODE)
net = Net()
out = net()
assert list(out[0]) == [100]
assert list(out[1]) == [2, 2]
assert list(out[2]) == [3, 3, 3]
def test_list_neg_index_1D():
def test_list_neg_index_1d():
"""
Feature: List index assign
Description: Test list assign in pynative mode
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
@ -52,8 +70,20 @@ def test_list_neg_index_1D():
assert list(out[1]) == [2, 2]
assert list(out[2]) == [3, 3, 3]
context.set_context(mode=context.GRAPH_MODE)
out = net()
assert list(out[0]) == [100]
assert list(out[1]) == [2, 2]
assert list(out[2]) == [3, 3, 3]
def test_list_index_2D():
def test_list_index_2d():
"""
Feature: List index assign
Description: Test list assign in pynative mode
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
@ -67,23 +97,48 @@ def test_list_index_2D():
assert list(out[1]) == [200, 201]
assert list(out[2]) == [3, 3, 3]
def test_list_neg_index_2D():
class Net(nn.Cell):
def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
list_[1][-2] = 200
list_[1][-1] = 201
return list_
net = Net()
context.set_context(mode=context.GRAPH_MODE)
out = net()
assert list(out[0]) == [1]
assert list(out[1]) == [200, 201]
assert list(out[2]) == [3, 3, 3]
def test_list_index_3D():
def test_list_neg_index_2d():
"""
Feature: List index assign
Description: Test list assign in pynative mode
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
class Net(nn.Cell):
def construct(self):
list_ = [[1], [2, 2], [3, 3, 3]]
list_[1][-2] = 20
list_[1][-1] = 21
return list_
net = Net()
out = net()
assert list(out[0]) == [1]
assert list(out[1]) == [20, 21]
assert list(out[2]) == [3, 3, 3]
context.set_context(mode=context.GRAPH_MODE)
out = net()
assert list(out[0]) == [1]
assert list(out[1]) == [20, 21]
assert list(out[2]) == [3, 3, 3]
def test_list_index_3d():
"""
Feature: List index assign
Description: Test list assign in pynative mode
Expectation: No exception.
"""
class Net(nn.Cell):
def construct(self):
list_ = [[1], [2, 2], [[3, 3, 3]]]
@ -92,30 +147,53 @@ def test_list_index_3D():
list_[2][0][2] = 302
return list_
context.set_context(mode=context.PYNATIVE_MODE)
net = Net()
out = net()
assert list(out[0]) == [1]
assert list(out[1]) == [2, 2]
assert list(out[2][0]) == [300, 301, 302]
context.set_context(mode=context.GRAPH_MODE)
out = net()
assert list(out[0]) == [1]
assert list(out[1]) == [2, 2]
assert list(out[2][0]) == [300, 301, 302]
def test_list_neg_index_3d():
"""
Feature: List index assign
Description: Test list assign in pynative mode
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
def test_list_neg_index_3D():
class Net(nn.Cell):
def construct(self):
list_ = [[1], [2, 2], [[3, 3, 3]]]
list_[2][0][-3] = 300
list_[2][0][-2] = 301
list_[2][0][-1] = 302
list_[2][0][-3] = 30
list_[2][0][-2] = 31
list_[2][0][-1] = 32
return list_
net = Net()
out = net()
assert list(out[0]) == [1]
assert list(out[1]) == [2, 2]
assert list(out[2][0]) == [300, 301, 302]
assert list(out[2][0]) == [30, 31, 32]
context.set_context(mode=context.GRAPH_MODE)
out = net()
assert list(out[0]) == [1]
assert list(out[1]) == [2, 2]
assert list(out[2][0]) == [30, 31, 32]
def test_list_index_1D_parameter():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def construct(self, x):
list_ = [x]
@ -127,6 +205,7 @@ def test_list_index_1D_parameter():
def test_list_index_2D_parameter():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def construct(self, x):
list_ = [[x, x]]
@ -138,6 +217,7 @@ def test_list_index_2D_parameter():
def test_list_index_3D_parameter():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def construct(self, x):
list_ = [[[x, x]]]
@ -149,6 +229,7 @@ def test_list_index_3D_parameter():
def test_const_list_index_3D_bprop():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -177,6 +258,7 @@ def test_const_list_index_3D_bprop():
def test_parameter_list_index_3D_bprop():
context.set_context(mode=context.GRAPH_MODE)
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
@ -203,3 +285,427 @@ def test_parameter_list_index_3D_bprop():
value = Tensor(np.ones((2, 3), np.int64))
sens = Tensor(np.arange(2 * 3).reshape(2, 3))
grad_net(x, value, sens)
class Net1(Cell):
def construct(self, a, b, start=None, stop=None, step=None):
a[start:stop:step] = b[start:stop:step]
return tuple(a)
def compare_func1(a, b, start=None, stop=None, step=None):
a[start:stop:step] = b[start:stop:step]
return tuple(a)
def test_list_slice_length_equal():
"""
Feature: List assign
Description: Test list assign the size is equal
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
a = [1, 2, 3, 4]
b = [5, 6, 7, 8]
python_out = compare_func1(a, b, 0, None, 2)
a = [1, 2, 3, 4]
b = [5, 6, 7, 8]
net = Net1()
pynative_mode_out = net(a, b, 0, None, 2)
assert pynative_mode_out == python_out
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(a, b, 0, None, 2)
assert graph_out == python_out
def test_list_slice_length_error():
"""
Feature: List assign
Description: Test list assign the size is not equal
Expectation: ValueError.
"""
context.set_context(mode=context.GRAPH_MODE)
a = [1, 2, 3, 4, 5]
b = [5, 6, 7, 8]
net = Net1()
with pytest.raises(ValueError) as err:
net(a, b, 0, None, 2)
assert "attempt to assign sequence of size 2 to extended slice of size 3" in str(err.value)
context.set_context(mode=context.PYNATIVE_MODE)
with pytest.raises(ValueError) as err:
net(a, b, 0, None, 2)
assert "attempt to assign sequence of size 2 to extended slice of size 3" in str(err.value)
def compare_func2(a, b, start=None, stop=None, step=None):
a[start:stop:step] = b
return tuple(a)
class Net2(Cell):
def construct(self, a, b, start=None, stop=None, step=None):
a[start:stop:step] = b
return tuple(a)
def test_list_slice_shrink():
"""
Feature: List assign
Description: Test list slice shrink assign
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33]
python_out = compare_func2(a, b, 0, 5)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33]
net = Net2()
pynative_out = net(a, b, 0, 5)
assert pynative_out == python_out
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33]
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(a, b, 0, 5)
assert graph_out == python_out
def test_list_slice_insert():
"""
Feature: List assign
Description: Test list slice insert assign
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
python_out = compare_func2(a, b, 0, 1)
net = Net2()
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
pynative_out = net(a, b, 0, 1)
assert pynative_out == python_out
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(a, b, 0, 1)
assert graph_out == python_out
def test_list_slice_assign():
"""
Feature: List assign
Description: Test list slice start and stop is larger than size
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
python_out = compare_func2(a, b, -12, 456)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
net = Net2()
pynative_out = net(a, b, -12, 456)
assert pynative_out == python_out
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(a, b, -12, 456)
assert graph_out == python_out
def test_list_slice_extend():
"""
Feature: List assign
Description: Test list slice extend
Expectation: No exception.
"""
context.set_context(mode=context.PYNATIVE_MODE)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
net = Net2()
python_out = compare_func2(a, b, 1234, 0)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
pynative_out = net(a, b, 1234, 0)
assert pynative_out == python_out
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(a, b, 1234, 0)
assert graph_out == python_out
def test_list_slice_extend_front():
"""
Feature: List assign
Description: Test list slice extend
Expectation: No exception.
"""
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
python_out = compare_func2(a, b, 0, 0)
context.set_context(mode=context.PYNATIVE_MODE)
net = Net2()
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
pynative_out = net(a, b, 0, 0)
assert pynative_out == python_out
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(a, b, 0, 0)
assert graph_out == python_out
def test_list_slice_extend_inner():
"""
Feature: List assign
Description: Test list slice extend
Expectation: No exception.
"""
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
python_out = compare_func2(a, b, 5, 5)
context.set_context(mode=context.PYNATIVE_MODE)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
net = Net2()
pynative_out = net(a, b, 5, 5)
assert pynative_out == python_out
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33, 44, 55]
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(a, b, 5, 5)
assert graph_out == python_out
def test_list_slice_erase():
"""
Feature: List assign
Description: Test list slice erase
Expectation: No exception.
"""
a = [1, 2, 3, 4, 5, 6, 7]
python_out = compare_func2(a, [], 1, 3)
context.set_context(mode=context.PYNATIVE_MODE)
a = [1, 2, 3, 4, 5, 6, 7]
net = Net2()
pynative_out = net(a, [], 1, 3)
assert pynative_out == python_out
context.set_context(mode=context.GRAPH_MODE)
a = [1, 2, 3, 4, 5, 6, 7]
graph_out = net(a, [], 1, 3)
assert graph_out == python_out
def test_list_slice_tuple_without_step():
"""
Feature: List assign
Description: Test list slice assign with tuple
Expectation: No exception.
"""
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = (11, 22, 33)
python_out = compare_func2(a, b, 0, 4, None)
context.set_context(mode=context.PYNATIVE_MODE)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = (11, 22, 33)
net = Net2()
pynative_out = net(a, b, 0, 4, None)
assert pynative_out == python_out
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = (11, 22, 33)
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(a, b, 0, 4, None)
assert graph_out == python_out
def test_list_slice_tuple_with_step():
"""
Feature: List assign
Description: Test list slice assign with tuple
Expectation: No exception.
"""
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = (11, 22, 33)
python_out = compare_func2(a, b, 1, None, 3)
context.set_context(mode=context.PYNATIVE_MODE)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = (11, 22, 33)
net = Net2()
pynative_out = net(a, b, 1, None, 3)
assert pynative_out == python_out
context.set_context(mode=context.GRAPH_MODE)
graph_out = net(a, b, 1, None, 3)
assert graph_out == python_out
def test_list_double_slice():
"""
Feature: List assign
Description: Test list double slice assign
Expectation: ValueError
"""
context.set_context(mode=context.PYNATIVE_MODE)
@ms_function
def foo(a, b, start1, stop1, step1, start2, stop2, step2):
a[start1:stop1:step1][start2: stop2: step2] = b
return a
class NetInner(Cell):
def construct(self, a, b, start1, stop1, step1, start2, stop2, step2):
a[start1:stop1:step1][start2: stop2: step2] = b
return tuple(a)
net = NetInner()
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [11, 22, 33]
assert foo(a, b, 0, None, 1, 0, None, 3) == net(a, b, 0, None, 1, 0, None, 3)
def convert_tuple(a):
result = tuple()
for i in a:
if isinstance(i, list):
result += (tuple(i),)
continue
result += (i,)
return result
def test_list_in_list_slice():
"""
Feature: List assign
Description: Test high dimension list slice assign
Expectation: No exception.
"""
class TestNet(Cell):
def construct(self, a, b, index, start=None, stop=None, step=None):
a[index][start:stop:step] = b
return tuple(a)
def com_func3(a, b, index, start=None, stop=None, step=None):
a[index][start:stop:step] = b
return convert_tuple(a)
a = [1, 2, [1, 2, 3, 4, 5, 6, 7], 8, 9]
b = [1111, 2222]
python_out = com_func3(a, b, 2, 1, None, 3)
context.set_context(mode=context.PYNATIVE_MODE)
net = TestNet()
a = [1, 2, [1, 2, 3, 4, 5, 6, 7], 8, 9]
b = [1111, 2222]
pynative_out = convert_tuple(net(a, b, 2, 1, None, 3))
assert pynative_out == python_out
context.set_context(mode=context.GRAPH_MODE)
graph_out = convert_tuple(net(a, b, 2, 1, None, 3))
assert graph_out == python_out
def test_list_slice_negative_step():
"""
Feature: List assign
Description: Test negative step list slice assign
Expectation: No exception.
"""
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [33, 44, 55]
python_out = compare_func2(a, b, -1, -9, -3)
context.set_context(mode=context.PYNATIVE_MODE)
net = Net2()
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [33, 44, 55]
pynative_out = net(a, b, -1, -9, -3)
assert pynative_out == python_out
context.set_context(mode=context.GRAPH_MODE)
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [33, 44, 55]
graph_out = net(a, b, -1, -9, -3)
assert graph_out == python_out
def test_graph_list_slice_assign_extended_number():
"""
Feature: List assign
Description: Test negative step list slice assign
Expectation: No exception.
"""
a = [1, 2, 3, 4, 5, 6]
b = 1
net = Net2()
context.set_context(mode=context.PYNATIVE_MODE)
with pytest.raises(TypeError) as err:
net(a, b, 0, None, 2)
assert "must assign iterable to extended slice" in str(err.value)
context.set_context(mode=context.GRAPH_MODE)
with pytest.raises(TypeError) as err:
net(a, b, 0, None, 2)
assert "must assign iterable to extended slice" in str(err.value)
def test_graph_list_slice_assign_number():
"""
Feature: List assign
Description: Test negative step list slice assign
Expectation: No exception.
"""
a = [1, 2, 3, 4, 5, 6]
b = 1
net = Net2()
context.set_context(mode=context.PYNATIVE_MODE)
with pytest.raises(TypeError) as err:
net(a, b, 0, None, 1)
assert "can only assign an iterable" in str(err.value)
context.set_context(mode=context.GRAPH_MODE)
with pytest.raises(TypeError) as err:
net(a, b, 0, None, 1)
assert "can only assign an iterable" in str(err.value)
def test_list_slice_negetive_error():
"""
Feature: List assign
Description: Test negative step list slice assign
Expectation: ValueError
"""
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
b = [33, 44, 55]
net = Net2()
context.set_context(mode=context.PYNATIVE_MODE)
with pytest.raises(ValueError) as err:
net(a, b, -1, -3, -3)
assert "attempt to assign sequence of size 3 to extended slice of size 1" in str(err.value)
context.set_context(mode=context.GRAPH_MODE)
with pytest.raises(ValueError) as err:
net(a, b, -1, -3, -3)
assert "attempt to assign sequence of size 3 to extended slice of size 1" in str(err.value)