forked from mindspore-Ecosystem/mindspore
support list slice assign
This commit is contained in:
parent
a7a9486d32
commit
11b4836040
|
@ -18,7 +18,7 @@
|
|||
|
||||
#include "frontend/operator/composite/composite.h"
|
||||
#include <algorithm>
|
||||
|
||||
#include <tuple>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
|
@ -37,6 +37,7 @@
|
|||
namespace mindspore {
|
||||
// namespace to support composite operators definition
|
||||
namespace prim {
|
||||
constexpr auto kStepDefault = 1;
|
||||
using AbstractTensor = mindspore::abstract::AbstractTensor;
|
||||
using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
|
||||
|
||||
|
@ -54,14 +55,6 @@ using mindspore::abstract::AbstractScalar;
|
|||
using mindspore::abstract::AbstractSlice;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
|
||||
ElemwiseMap kElemwiseMap = {{"__add__", kPrimScalarAdd}, {"__sub__", kPrimScalarSub}, {"__mul__", kPrimScalarMul},
|
||||
{"__truediv__", nullptr}, {"__floordiv__", nullptr}, {"__mod__", kPrimScalarMod},
|
||||
{"__pow__", kPrimScalarPow}, {"__eq__", kPrimScalarEq}, {"__lt__", kPrimScalarLt},
|
||||
{"__gt__", kPrimScalarGt}, {"__ne__", kPrimScalarNe}, {"__le__", kPrimScalarLe},
|
||||
{"__ge__", kPrimScalarGe}};
|
||||
|
||||
ValuePtr kCompositeHyperMap = std::make_shared<HyperMap>();
|
||||
|
||||
void HyperMap::Init() {
|
||||
if (fn_leaf_) {
|
||||
name_ = "hyper_map[" + fn_leaf_->name() + "]";
|
||||
|
@ -1246,78 +1239,81 @@ int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, c
|
|||
<< member->BuildType()->ToString();
|
||||
}
|
||||
|
||||
void GenerateTupleSliceParameter(const abstract::AbstractSequencePtr &tuple, const AbstractSlicePtr &slice,
|
||||
int64_t *start_index, int64_t *stop_index, int64_t *step_value) {
|
||||
MS_EXCEPTION_IF_NULL(tuple);
|
||||
std::tuple<int64_t, int64_t, int64_t> GenerateTupleSliceParameter(const abstract::AbstractSequencePtr &sequence,
|
||||
const AbstractSlicePtr &slice) {
|
||||
MS_EXCEPTION_IF_NULL(sequence);
|
||||
MS_EXCEPTION_IF_NULL(slice);
|
||||
MS_EXCEPTION_IF_NULL(start_index);
|
||||
MS_EXCEPTION_IF_NULL(stop_index);
|
||||
MS_EXCEPTION_IF_NULL(step_value);
|
||||
int64_t start_index;
|
||||
int64_t stop_index;
|
||||
int64_t step_value;
|
||||
|
||||
const std::string start_name("Slice start index");
|
||||
const std::string stop_name("Slice stop index");
|
||||
const std::string step_name("Slice step value");
|
||||
|
||||
int64_t tuple_size = SizeToLong(tuple->size());
|
||||
int64_t tuple_size = SizeToLong(sequence->size());
|
||||
int64_t start_default = 0;
|
||||
int64_t stop_default = tuple_size;
|
||||
int64_t step_default = 1;
|
||||
int64_t step_default = kStepDefault;
|
||||
|
||||
*step_value = CheckSliceMember(slice->step(), step_default, step_name);
|
||||
if (*step_value == 0) {
|
||||
step_value = CheckSliceMember(slice->step(), step_default, step_name);
|
||||
if (step_value == 0) {
|
||||
MS_EXCEPTION(ValueError) << "Slice step cannot be zero.";
|
||||
}
|
||||
|
||||
if (*step_value < 0) {
|
||||
if (step_value < 0) {
|
||||
start_default = tuple_size - 1;
|
||||
stop_default = -1;
|
||||
stop_default = ((-tuple_size) - 1);
|
||||
}
|
||||
|
||||
*start_index = CheckSliceMember(slice->start(), start_default, start_name);
|
||||
*stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
|
||||
start_index = CheckSliceMember(slice->start(), start_default, start_name);
|
||||
stop_index = CheckSliceMember(slice->stop(), stop_default, stop_name);
|
||||
|
||||
if (*start_index < -tuple_size) {
|
||||
*start_index = 0;
|
||||
}
|
||||
if (*stop_index > tuple_size) {
|
||||
*stop_index = tuple_size;
|
||||
}
|
||||
if (*start_index > tuple_size || *stop_index < -tuple_size) {
|
||||
*start_index = 0;
|
||||
*stop_index = 0;
|
||||
if (start_index < -tuple_size) {
|
||||
start_index = 0;
|
||||
}
|
||||
|
||||
*start_index = GetPositiveIndex(*start_index, tuple_size);
|
||||
if (!slice->stop()->isa<AbstractNone>()) {
|
||||
*stop_index = GetPositiveIndex(*stop_index, tuple_size);
|
||||
if (stop_index > tuple_size) {
|
||||
stop_index = tuple_size;
|
||||
}
|
||||
|
||||
if (start_index > tuple_size) {
|
||||
start_index = tuple_size;
|
||||
}
|
||||
|
||||
if (stop_index < ((-tuple_size) - 1)) {
|
||||
stop_index = 0;
|
||||
}
|
||||
|
||||
start_index = GetPositiveIndex(start_index, tuple_size);
|
||||
|
||||
stop_index = GetPositiveIndex(stop_index, tuple_size);
|
||||
|
||||
return std::make_tuple(start_index, stop_index, step_value);
|
||||
}
|
||||
|
||||
FuncGraphPtr SequenceSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
// slice a tuple
|
||||
// args: tuple, start index, end index, step
|
||||
auto seq_pair = CheckArgs(args_spec_list);
|
||||
auto sequence = seq_pair.first;
|
||||
auto slice = seq_pair.second;
|
||||
int64_t start_index;
|
||||
int64_t stop_index;
|
||||
int64_t step_value;
|
||||
GenerateTupleSliceParameter(sequence, slice, &start_index, &stop_index, &step_value);
|
||||
void SequenceSliceGetItem::CheckArgs(const AbstractBasePtrList &args_spec_list) {
|
||||
constexpr size_t arg_size = 2;
|
||||
abstract::CheckArgsSize(this->name(), args_spec_list, arg_size);
|
||||
sequence_ = abstract::CheckArg<abstract::AbstractSequence>(this->name(), args_spec_list, 0);
|
||||
slice_ = abstract::CheckArg<AbstractSlice>(this->name(), args_spec_list, 1);
|
||||
}
|
||||
|
||||
FuncGraphPtr SequenceSliceGetItem::BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) {
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
AnfNodePtr p_tuple = ret->add_parameter();
|
||||
AnfNodePtr p_seq = ret->add_parameter();
|
||||
(void)ret->add_parameter();
|
||||
|
||||
std::vector<AnfNodePtr> elems;
|
||||
elems.push_back(NewValueNode(prim_));
|
||||
if (step_value > 0) {
|
||||
for (int64_t index = start_index; index < stop_index; index = index + step_value) {
|
||||
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_tuple, NewValueNode(index)}));
|
||||
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_seq, NewValueNode(index)}));
|
||||
}
|
||||
} else {
|
||||
for (int64_t index = start_index; index > stop_index; index = index + step_value) {
|
||||
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_tuple, NewValueNode(index)}));
|
||||
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_seq, NewValueNode(index)}));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1325,24 +1321,6 @@ FuncGraphPtr SequenceSlice::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
return ret;
|
||||
}
|
||||
|
||||
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> TupleSlice::CheckArgs(
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
constexpr size_t arg_size = 2;
|
||||
abstract::CheckArgsSize("TupleSlice", args_spec_list, arg_size);
|
||||
auto sequence = abstract::CheckArg<abstract::AbstractSequence>("TupleSlice", args_spec_list, 0);
|
||||
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>("TupleSlice", args_spec_list, 1);
|
||||
return std::make_pair(sequence, slice);
|
||||
}
|
||||
|
||||
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> ListSlice::CheckArgs(
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
constexpr size_t arg_size = 2;
|
||||
abstract::CheckArgsSize("ListSlice", args_spec_list, arg_size);
|
||||
auto sequence = abstract::CheckArg<abstract::AbstractSequence>("ListSlice", args_spec_list, 0);
|
||||
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>("ListSlice", args_spec_list, 1);
|
||||
return std::make_pair(sequence, slice);
|
||||
}
|
||||
|
||||
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
// select indexed item
|
||||
// args: tuple of items, index
|
||||
|
@ -1363,20 +1341,16 @@ REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
|
|||
.def(py::init<std::string &>());
|
||||
}));
|
||||
|
||||
REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) {
|
||||
(void)py::class_<TupleSlice, MetaFuncGraph, std::shared_ptr<TupleSlice>>(*m, "TupleSlice_")
|
||||
.def(py::init<std::string &>());
|
||||
}));
|
||||
|
||||
REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
|
||||
(void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
|
||||
*m, "TupleGetItemTensor_")
|
||||
.def(py::init<std::string &>());
|
||||
}));
|
||||
|
||||
REGISTER_PYBIND_DEFINE(ListSlice_, ([](const py::module *m) {
|
||||
(void)py::class_<ListSlice, MetaFuncGraph, std::shared_ptr<ListSlice>>(*m, "ListSlice_")
|
||||
.def(py::init<std::string &>());
|
||||
REGISTER_PYBIND_DEFINE(SequenceSliceGetItem, ([](const py::module *m) {
|
||||
(void)py::class_<SequenceSliceGetItem, MetaFuncGraph, std::shared_ptr<SequenceSliceGetItem>>(
|
||||
*m, "SequenceSliceGetItem_")
|
||||
.def(py::init<std::string &, std::string &, std::string &>());
|
||||
}));
|
||||
|
||||
namespace {
|
||||
|
@ -1449,5 +1423,114 @@ REGISTER_PYBIND_DEFINE(Shard_, ([](const py::module *m) {
|
|||
(void)py::class_<Shard, MetaFuncGraph, std::shared_ptr<Shard>>(*m, "Shard_")
|
||||
.def(py::init<const std::string &>(), py::arg("fn"));
|
||||
}));
|
||||
|
||||
void ListSliceSetItem::CheckArgs(const AbstractBasePtrList &args_spec_list) {
|
||||
constexpr size_t kSliceSetItemArgsSizeargs_size = 3;
|
||||
constexpr size_t kSliceSetItemListIndex = 0;
|
||||
constexpr size_t kSliceSetItemSliceIndex = 1;
|
||||
constexpr size_t kSliceSetItemValueIndex = 2;
|
||||
abstract::CheckArgsSize("list_slice_set_item", args_spec_list, kSliceSetItemArgsSizeargs_size);
|
||||
this->sequence_ = abstract::CheckArg<AbstractList>("list_slice_set_item", args_spec_list, kSliceSetItemListIndex);
|
||||
this->slice_ = abstract::CheckArg<AbstractSlice>("list_slice_set_item", args_spec_list, kSliceSetItemSliceIndex);
|
||||
this->value_list_ = abstract::CheckArg<AbstractList>("list_slice_set_item", args_spec_list, kSliceSetItemValueIndex);
|
||||
}
|
||||
|
||||
FuncGraphPtr ListSliceSetItem::BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) {
|
||||
// Init graph with the input list_node slice assign_node
|
||||
CheckAssignRange(start_index, stop_index, step_value);
|
||||
auto graph = std::make_shared<FuncGraph>();
|
||||
graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
auto list_node = graph->add_parameter();
|
||||
(void)graph->add_parameter();
|
||||
auto assign_parameter = graph->add_parameter();
|
||||
auto assign_node = GetAssignNode(graph, assign_parameter, step_value);
|
||||
std::vector<AnfNodePtr> elems = {NewValueNode(prim::kPrimMakeList)};
|
||||
int64_t list_index = 0;
|
||||
// check the index is in the slice range
|
||||
auto check_in_range = [start_index, stop_index, step_value](int64_t index) -> bool {
|
||||
if (step_value > 0) {
|
||||
return (index >= start_index && index < stop_index);
|
||||
}
|
||||
return (index <= start_index && index > stop_index);
|
||||
};
|
||||
int64_t list_size = SizeToLong(sequence_->size());
|
||||
int64_t assign_index = 0;
|
||||
int64_t value_size = SizeToLong(value_list_->size());
|
||||
while (list_index < list_size || assign_index < value_size) {
|
||||
if (!check_in_range(list_index)) {
|
||||
// list start <= stop && step = 1 insert the assign node to target node
|
||||
while (assign_index < value_size && list_index == start_index) {
|
||||
(void)elems.emplace_back(
|
||||
graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
|
||||
}
|
||||
if (list_index < list_size) {
|
||||
(void)elems.emplace_back(
|
||||
graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), list_node, NewValueNode(list_index++)}));
|
||||
}
|
||||
} else {
|
||||
if (((list_index - start_index) % step_value) == 0) {
|
||||
++list_index;
|
||||
if (assign_index >= value_size) {
|
||||
continue;
|
||||
}
|
||||
(void)elems.emplace_back(
|
||||
graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
|
||||
} else {
|
||||
(void)elems.emplace_back(
|
||||
graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), list_node, NewValueNode(list_index++)}));
|
||||
}
|
||||
// the assign node's len is larger than the range
|
||||
while (!check_in_range(list_index) && assign_index < value_size) {
|
||||
(void)elems.emplace_back(
|
||||
graph->NewCNodeInOrder({NewValueNode(kPrimListGetItem), assign_node, NewValueNode(assign_index++)}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
graph->set_output(graph->NewCNodeInOrder(elems));
|
||||
return graph;
|
||||
}
|
||||
|
||||
void ListSliceSetItem::CheckAssignRange(int64_t start_index, int64_t stop_index, int64_t step_value) {
|
||||
if (step_value != kStepDefault) {
|
||||
int64_t start_include = 0;
|
||||
if (start_index < SizeToLong(sequence_->size()) && start_index >= -SizeToLong(sequence_->size())) {
|
||||
start_include = 1;
|
||||
}
|
||||
auto assign_size = ((stop_index - start_index - 1) / step_value) + start_include;
|
||||
if (step_value < 0) {
|
||||
assign_size = ((start_index - stop_index) / -step_value) + start_include;
|
||||
}
|
||||
if (assign_size != SizeToLong(value_list_->size())) {
|
||||
MS_EXCEPTION(ValueError) << "attempt to assign sequence of size " << value_list_->size()
|
||||
<< " to extended slice of size " << assign_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
AnfNodePtr ListSliceSetItem::GetAssignNode(const FuncGraphPtr &func_graph, const AnfNodePtr &assign_node,
|
||||
int64_t step_value) {
|
||||
if (step_value > 0) {
|
||||
return assign_node;
|
||||
}
|
||||
std::vector<AnfNodePtr> elems = {NewValueNode(prim::kPrimMakeList)};
|
||||
for (int64_t i = SizeToInt(value_list_->size()) - 1; i >= 0; --i) {
|
||||
elems.emplace_back(
|
||||
func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), assign_node, NewValueNode(i)}));
|
||||
}
|
||||
return func_graph->NewCNodeInOrder(elems);
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(ListSliceSetItem_, ([](const py::module *m) {
|
||||
(void)py::class_<ListSliceSetItem, MetaFuncGraph, std::shared_ptr<ListSliceSetItem>>(
|
||||
*m, "ListSliceSetItem_")
|
||||
.def(py::init<const std::string &>());
|
||||
}));
|
||||
|
||||
FuncGraphPtr SequenceSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
this->CheckArgs(args_spec_list);
|
||||
auto [start, stop, step] = GenerateTupleSliceParameter(sequence_, slice_);
|
||||
return this->BuildFuncGraph(start, stop, step);
|
||||
}
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,6 +47,7 @@ using AbstractScalarPtr = abstract::AbstractScalarPtr;
|
|||
using AbstractTensorPtr = abstract::AbstractTensorPtr;
|
||||
using ElemwiseMap = mindspore::HashMap<std::string, PrimitivePtr>;
|
||||
using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
|
||||
using AbstractListPtr = abstract::AbstractListPtr;
|
||||
|
||||
class HyperMap : public MetaFuncGraph {
|
||||
public:
|
||||
|
@ -202,37 +203,55 @@ using TupleAddPtr = std::shared_ptr<TupleAdd>;
|
|||
|
||||
class SequenceSlice : public MetaFuncGraph {
|
||||
public:
|
||||
explicit SequenceSlice(const std::string &name, const PrimitivePtr &prim, const PrimitivePtr &get_item)
|
||||
: MetaFuncGraph(name), prim_(prim), get_item_(get_item) {}
|
||||
explicit SequenceSlice(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~SequenceSlice() override = default;
|
||||
MS_DECLARE_PARENT(SequenceSlice, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) final;
|
||||
friend bool operator==(const SequenceSlice &lhs, const SequenceSlice &rhs) { return lhs.name_ == rhs.name_; }
|
||||
virtual std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
|
||||
const AbstractBasePtrList &args_spec_list) = 0;
|
||||
|
||||
protected:
|
||||
virtual void CheckArgs(const AbstractBasePtrList &args_spec_list) = 0;
|
||||
virtual FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) = 0;
|
||||
abstract::AbstractSequencePtr sequence_ = nullptr;
|
||||
AbstractSlicePtr slice_ = nullptr;
|
||||
};
|
||||
|
||||
class SequenceSliceGetItem : public SequenceSlice {
|
||||
public:
|
||||
explicit SequenceSliceGetItem(const std::string &name, const std::string &prim_name, const std::string &get_item_name)
|
||||
: SequenceSlice(name),
|
||||
prim_(std::make_shared<Primitive>(prim_name)),
|
||||
get_item_(std::make_shared<Primitive>(get_item_name)) {}
|
||||
~SequenceSliceGetItem() override = default;
|
||||
MS_DECLARE_PARENT(SequenceSliceGetItem, MetaFuncGraph)
|
||||
friend bool operator==(const SequenceSliceGetItem &lhs, const SequenceSliceGetItem &rhs) {
|
||||
return lhs.name_ == rhs.name_;
|
||||
}
|
||||
|
||||
protected:
|
||||
void CheckArgs(const AbstractBasePtrList &args_spec_list) override;
|
||||
FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) override;
|
||||
|
||||
private:
|
||||
PrimitivePtr prim_;
|
||||
PrimitivePtr get_item_;
|
||||
};
|
||||
|
||||
class TupleSlice : public SequenceSlice {
|
||||
class ListSliceSetItem : public SequenceSlice {
|
||||
public:
|
||||
explicit TupleSlice(const std::string &name) : SequenceSlice(name, prim::kPrimMakeTuple, prim::kPrimTupleGetItem) {}
|
||||
~TupleSlice() override = default;
|
||||
MS_DECLARE_PARENT(TupleSlice, SequenceSlice)
|
||||
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
|
||||
const AbstractBasePtrList &args_spec_list) override;
|
||||
};
|
||||
using TupleSlicePtr = std::shared_ptr<TupleSlice>;
|
||||
explicit ListSliceSetItem(const std::string &name) : SequenceSlice(name) {}
|
||||
~ListSliceSetItem() override = default;
|
||||
MS_DECLARE_PARENT(ListSliceSetItem, MetaFuncGraph)
|
||||
friend bool operator==(const ListSliceSetItem &lhs, const ListSliceSetItem &rhs) { return lhs.name_ == rhs.name_; }
|
||||
|
||||
class ListSlice : public SequenceSlice {
|
||||
public:
|
||||
explicit ListSlice(const std::string &name) : SequenceSlice(name, prim::kPrimMakeList, prim::kPrimListGetItem) {}
|
||||
~ListSlice() override = default;
|
||||
MS_DECLARE_PARENT(ListSlice, SequenceSlice)
|
||||
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
|
||||
const AbstractBasePtrList &args_spec_list) override;
|
||||
protected:
|
||||
void CheckArgs(const AbstractBasePtrList &args_spec_list) override;
|
||||
FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) override;
|
||||
|
||||
private:
|
||||
void CheckAssignRange(int64_t start_index, int64_t stop_index, int64_t step_value);
|
||||
AnfNodePtr GetAssignNode(const FuncGraphPtr &func_graph, const AnfNodePtr &assign_node, int64_t step_value);
|
||||
AbstractListPtr value_list_ = nullptr;
|
||||
};
|
||||
|
||||
class TupleGetItemTensor : public MetaFuncGraph {
|
||||
|
|
|
@ -168,10 +168,11 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
|
|||
inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
|
||||
return meta_func_graph->isa<prim::Tail>() || meta_func_graph->isa<prim::MakeTupleGradient>() ||
|
||||
meta_func_graph->isa<prim::MakeListGradient>() || meta_func_graph->isa<prim::TupleAdd>() ||
|
||||
meta_func_graph->isa<prim::SequenceSlice>() || meta_func_graph->isa<prim::UnpackCall>() ||
|
||||
meta_func_graph->isa<prim::ZipOperation>() || meta_func_graph->isa<prim::ListAppend>() ||
|
||||
meta_func_graph->isa<prim::ListInsert>() || meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>() ||
|
||||
meta_func_graph->isa<prim::VmapMatchOutAxis>() || meta_func_graph->isa<prim::VmapGeneralPreprocess>();
|
||||
meta_func_graph->isa<prim::SequenceSliceGetItem>() || meta_func_graph->isa<prim::ListSliceSetItem>() ||
|
||||
meta_func_graph->isa<prim::UnpackCall>() || meta_func_graph->isa<prim::ZipOperation>() ||
|
||||
meta_func_graph->isa<prim::ListAppend>() || meta_func_graph->isa<prim::ListInsert>() ||
|
||||
meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>() || meta_func_graph->isa<prim::VmapMatchOutAxis>() ||
|
||||
meta_func_graph->isa<prim::VmapGeneralPreprocess>();
|
||||
}
|
||||
|
||||
/* inherit relation of MetaFuncGraph
|
||||
|
@ -186,7 +187,10 @@ inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
|
|||
* ├── MakeTupleGradient
|
||||
* ├── MakeListGradient
|
||||
* ├── GradOperation
|
||||
* └── TupleAdd
|
||||
* ├── TupleAdd
|
||||
* └── SequenceSlice
|
||||
* ├── SequenceSliceGetItem
|
||||
* └── ListSliceSetItem
|
||||
*/
|
||||
std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) {
|
||||
if (meta_func_graph == nullptr) {
|
||||
|
|
|
@ -2314,14 +2314,14 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje
|
|||
}
|
||||
if (AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, value_obj))) ==
|
||||
AST_SUB_TYPE_SUBSCRIPT) {
|
||||
HandleAssignSubscript(block, value_obj, setitem_app);
|
||||
return;
|
||||
if (IsSubscriptReferenceType(value_obj)) {
|
||||
HandleAssignSubscript(block, value_obj, setitem_app);
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (!py::hasattr(value_obj, "id")) {
|
||||
MS_EXCEPTION(TypeError) << "Attribute id not found in " << py::str(value_obj).cast<std::string>() << "\n\n"
|
||||
<< trace::GetDebugInfo(value_node->debug_info());
|
||||
if (py::hasattr(value_obj, "id")) {
|
||||
var_name = value_obj.attr("id").cast<std::string>();
|
||||
}
|
||||
var_name = value_obj.attr("id").cast<std::string>();
|
||||
block->WriteVariable(var_name, setitem_app);
|
||||
}
|
||||
|
||||
|
@ -2910,5 +2910,12 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
|
|||
}
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
bool Parser::IsSubscriptReferenceType(const py::object &obj) {
|
||||
py::object slice_node = python_adapter::GetPyObjAttr(obj, "slice");
|
||||
auto node_type = ast_->GetNodeType(slice_node);
|
||||
auto node_name = node_type->node_name();
|
||||
return node_name != "Slice";
|
||||
}
|
||||
} // namespace parse
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -291,6 +291,8 @@ class Parser {
|
|||
const std::vector<AnfNodePtr> &packed_arguments,
|
||||
const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const;
|
||||
ScopePtr GetScopeForParseFunction();
|
||||
// Check the value is subscript is reference type
|
||||
bool IsSubscriptReferenceType(const py::object &obj);
|
||||
void BuildMethodMap();
|
||||
FunctionBlockPtr MakeFunctionBlock(const Parser &parse) {
|
||||
FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse);
|
||||
|
|
|
@ -778,7 +778,10 @@ class MS_CORE_API AbstractSequence : public AbstractBase {
|
|||
///
|
||||
/// \return A size_t.
|
||||
std::size_t size() const { return elements_.size(); }
|
||||
|
||||
/// \brief Get the size of the stored elements.
|
||||
///
|
||||
/// \return A size_t.
|
||||
bool empty() const { return elements_.empty(); }
|
||||
/// \brief Get the stored elements.
|
||||
///
|
||||
/// \return A vector of elements.
|
||||
|
|
|
@ -21,8 +21,8 @@ from types import FunctionType
|
|||
import mindspore as ms
|
||||
from mindspore import context
|
||||
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
|
||||
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
|
||||
ListSlice_, VmapOperation_, TaylorOperation_
|
||||
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
|
||||
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_
|
||||
from ...common import dtype as mstype
|
||||
from ...common.api import ms_function, _pynative_executor, _wrap_func
|
||||
from ..primitive import Primitive
|
||||
|
@ -30,7 +30,7 @@ from ..operations import _grad_ops
|
|||
from .. import operations as P
|
||||
from .. import signature as sig
|
||||
|
||||
__all__ = [TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_, ListSlice_]
|
||||
__all__ = [TupleAdd_, UnpackCall_, TupleGetItemTensor_, SequenceSliceGetItem_, ListSliceSetItem_]
|
||||
|
||||
|
||||
def add_flags(fn=None, **flags):
|
||||
|
|
|
@ -26,7 +26,7 @@ using ".register" decorator.
|
|||
"""
|
||||
|
||||
|
||||
class _TupleSlice(base.TupleSlice_):
|
||||
class _TupleSlice(base.SequenceSliceGetItem_):
|
||||
"""
|
||||
Slices a tuple.
|
||||
|
||||
|
@ -40,7 +40,7 @@ class _TupleSlice(base.TupleSlice_):
|
|||
|
||||
def __init__(self, name):
|
||||
"""Initialize _TupleSlice."""
|
||||
base.TupleSlice_.__init__(self, name)
|
||||
base.SequenceSliceGetItem_.__init__(self, name, "MakeTuple", "TupleGetItem")
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
@ -50,7 +50,7 @@ _tuple_slice = _TupleSlice('tuple_slice')
|
|||
"""_tuple_slice is a metafuncgraph object which will slice a tuple."""
|
||||
|
||||
|
||||
class _ListSlice(base.ListSlice_):
|
||||
class _ListSlice(base.SequenceSliceGetItem_):
|
||||
"""
|
||||
Slices a List.
|
||||
|
||||
|
@ -64,7 +64,7 @@ class _ListSlice(base.ListSlice_):
|
|||
|
||||
def __init__(self, name):
|
||||
"""Initialize _TupleSlice."""
|
||||
base.ListSlice_.__init__(self, name)
|
||||
base.SequenceSliceGetItem_.__init__(self, name, "make_list", "list_getitem")
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
|
|
@ -17,11 +17,38 @@
|
|||
|
||||
from . import _compile_utils as compile_utils
|
||||
from ... import functional as F
|
||||
from ...operations._inner_ops import SliceGetItem
|
||||
from ...composite import base
|
||||
from ....common import Tensor
|
||||
|
||||
setitem = base.MultitypeFuncGraph('setitem')
|
||||
|
||||
slice_get_item = SliceGetItem()
|
||||
|
||||
|
||||
class _ListSliceSetItem(base.ListSliceSetItem_):
|
||||
"""
|
||||
List slice assign.
|
||||
|
||||
Inputs:
|
||||
data (List): A List to be sliced.
|
||||
s (slice): The index to slice list data.
|
||||
value : The value to be assign
|
||||
|
||||
Outputs:
|
||||
List, consists of some elements of data.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _TupleSlice."""
|
||||
base.ListSliceSetItem_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
_list_slice_set_item = _ListSliceSetItem('list_slice_set_item')
|
||||
"""_list_slice_set_item is a MetaFuncGraph object which assign a list will slice."""
|
||||
|
||||
|
||||
@setitem.register("List", "Number", "String")
|
||||
def _list_setitem_with_string(data, number_index, value):
|
||||
|
@ -102,6 +129,75 @@ def _list_setitem_with_tuple(data, number_index, value):
|
|||
return F.list_setitem(data, number_index, value)
|
||||
|
||||
|
||||
@setitem.register("List", "Slice", "Tuple")
|
||||
def _list_slice_setitem_with_tuple(data, slice_index, value):
|
||||
"""
|
||||
Assigns value to list.
|
||||
|
||||
Inputs:
|
||||
data (list): Data of type list.
|
||||
slice_index (slice): Index of data.
|
||||
value (tuple): Value given.
|
||||
|
||||
Outputs:
|
||||
list, type is the same as the element type of data.
|
||||
"""
|
||||
list_value = list(value)
|
||||
return _list_slice_set_item(data, slice_index, list_value)
|
||||
|
||||
|
||||
@setitem.register("List", "Slice", "List")
|
||||
def _list_slice_setitem_with_list(data, slice_index, value):
|
||||
"""
|
||||
Assigns value to list.
|
||||
|
||||
Inputs:
|
||||
data (list): Data of type list.
|
||||
slice_index (slice): Index of data.
|
||||
value (list): Value given.
|
||||
|
||||
Outputs:
|
||||
list, type is the same as the element type of data.
|
||||
"""
|
||||
return _list_slice_set_item(data, slice_index, value)
|
||||
|
||||
|
||||
@setitem.register("List", "Slice", "Tensor")
|
||||
def _list_slice_setitem_with_tensor(data, slice_index, value):
|
||||
"""
|
||||
Assigns value to list.
|
||||
|
||||
Inputs:
|
||||
data (list): Data of type list.
|
||||
slice_index (slice): Index of data.
|
||||
value (Tensor): Value given.
|
||||
|
||||
Outputs:
|
||||
list, type is the same as the element type of data.
|
||||
"""
|
||||
value_list = list(value)
|
||||
return _list_slice_set_item(data, slice_index, value_list)
|
||||
|
||||
|
||||
@setitem.register("List", "Slice", "Number")
|
||||
def _list_slice_setitem_with_number(data, slice_index, value):
|
||||
"""
|
||||
Assigns value to list.
|
||||
|
||||
Inputs:
|
||||
data (list): Data of type list.
|
||||
slice_index (slice): Index of data.
|
||||
value (number): Value given.
|
||||
|
||||
Outputs:
|
||||
lis/t, type is the same as the element type of data.
|
||||
"""
|
||||
step = slice_get_item(slice_index, "step")
|
||||
if step == 1 or step is None:
|
||||
raise TypeError("can only assign an iterable")
|
||||
raise TypeError("must assign iterable to extended slice")
|
||||
|
||||
|
||||
@setitem.register("Dictionary", "String", "Tensor")
|
||||
def _dict_setitem_with_tensor(data, key, value):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
import pytest
|
||||
from mindspore.nn import Cell
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_list_slice_tensor_no_step():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice assign with tensor
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class NetInner(Cell):
|
||||
def construct(self, start=None, stop=None, step=None):
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = Tensor([11, 22, 33])
|
||||
a[start:stop:step] = b
|
||||
return tuple(a)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net = NetInner()
|
||||
python_out = (Tensor(11), Tensor(22), Tensor(33), 4, 5, 6, 7, 8, 9)
|
||||
pynative_out = net(0, 3, None)
|
||||
assert pynative_out == python_out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(0, 3, None)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_pynative_list_slice_tensor_with_step():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice assign with tensor
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class NetInner(Cell):
|
||||
def construct(self, start=None, stop=None, step=None):
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = Tensor([11, 22, 33])
|
||||
a[start:stop:step] = b
|
||||
return tuple(a)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net = NetInner()
|
||||
python_out = (Tensor(11), 2, 3, Tensor(22), 5, 6, Tensor(33), 8, 9)
|
||||
pynative_out = net(0, None, 3)
|
||||
assert python_out == pynative_out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(0, None, 3)
|
||||
assert python_out == graph_out
|
|
@ -88,7 +88,8 @@ class UTCompositeUtils {
|
|||
};
|
||||
|
||||
TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
|
||||
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
|
||||
MetaFuncGraphPtr tupleSlicePtr =
|
||||
std::make_shared<prim::SequenceSliceGetItem>("TupleSlice", "MakeTuple", "TupleGetItem");
|
||||
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 3);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -114,7 +115,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_two_numbers) {
|
|||
}
|
||||
|
||||
TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
|
||||
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
|
||||
MetaFuncGraphPtr tupleSlicePtr =
|
||||
std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
|
||||
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -146,7 +148,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) {
|
|||
|
||||
TEST_F(TestComposite, test_TupleSlice_arg_slice) {
|
||||
std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
|
||||
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
|
||||
MetaFuncGraphPtr tupleSlicePtr =
|
||||
std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
|
||||
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -173,7 +176,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice) {
|
|||
}
|
||||
|
||||
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
|
||||
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
|
||||
MetaFuncGraphPtr tupleSlicePtr =
|
||||
std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
|
||||
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -200,7 +204,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_none) {
|
|||
}
|
||||
|
||||
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
|
||||
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
|
||||
MetaFuncGraphPtr tupleSlicePtr =
|
||||
std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
|
||||
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -227,7 +232,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_negative) {
|
|||
}
|
||||
|
||||
TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
|
||||
MetaFuncGraphPtr tupleSlicePtr = std::make_shared<prim::TupleSlice>("tuple_slice");
|
||||
MetaFuncGraphPtr tupleSlicePtr =
|
||||
std::make_shared<prim::SequenceSliceGetItem>("tuple_slice", "MakeTuple", "TupleGetItem");
|
||||
FuncGraphPtr tupleSliceGraphPtr = UTCompositeUtils::MakeFuncGraph(tupleSlicePtr, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -257,7 +263,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
|
|||
/// Description: The second input is a scalar
|
||||
/// Expectation: Throw type error
|
||||
TEST_F(TestComposite, test_ListSlice_arg_one_number) {
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
|
||||
FuncGraphPtr list_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 3);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -292,7 +298,7 @@ TEST_F(TestComposite, test_ListSlice_arg_one_number) {
|
|||
/// Expectation: No Expectation
|
||||
TEST_F(TestComposite, test_ListSlice_arg_slice) {
|
||||
std::shared_ptr<py::scoped_interpreter> env = python_adapter::set_python_scoped();
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
|
||||
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -321,7 +327,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice) {
|
|||
/// Description: Test List slice the step is none
|
||||
/// Expectation: No Expectation
|
||||
TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
|
||||
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -350,7 +356,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
|
|||
/// Description: Test List slice the step is negative
|
||||
/// Expectation: No Expectation
|
||||
TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice", "make_list", "list_getitem");
|
||||
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
@ -379,7 +385,7 @@ TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
|
|||
/// Description: Test List slice the step is positive
|
||||
/// Expectation: No Expectation
|
||||
TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
|
||||
MetaFuncGraphPtr list_slice = std::make_shared<prim::SequenceSliceGetItem>("list_slice","make_list","list_getitem");
|
||||
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||
|
||||
AbstractBasePtrList eles;
|
||||
|
|
|
@ -13,19 +13,23 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test enumerate"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import pytest
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn import Cell
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore import Tensor, ms_function
|
||||
from mindspore import context
|
||||
|
||||
|
||||
def test_list_index_1D():
|
||||
def test_list_index_1d():
|
||||
"""
|
||||
Feature: List index assign
|
||||
Description: Test list assign in pynative mode
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
|
@ -38,8 +42,22 @@ def test_list_index_1D():
|
|||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
out = net()
|
||||
assert list(out[0]) == [100]
|
||||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
def test_list_neg_index_1D():
|
||||
|
||||
|
||||
def test_list_neg_index_1d():
|
||||
"""
|
||||
Feature: List index assign
|
||||
Description: Test list assign in pynative mode
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
|
@ -52,8 +70,20 @@ def test_list_neg_index_1D():
|
|||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
out = net()
|
||||
assert list(out[0]) == [100]
|
||||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
def test_list_index_2D():
|
||||
|
||||
def test_list_index_2d():
|
||||
"""
|
||||
Feature: List index assign
|
||||
Description: Test list assign in pynative mode
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
|
@ -67,23 +97,48 @@ def test_list_index_2D():
|
|||
assert list(out[1]) == [200, 201]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_neg_index_2D():
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
list_[1][-2] = 200
|
||||
list_[1][-1] = 201
|
||||
return list_
|
||||
|
||||
net = Net()
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
out = net()
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [200, 201]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_index_3D():
|
||||
def test_list_neg_index_2d():
|
||||
"""
|
||||
Feature: List index assign
|
||||
Description: Test list assign in pynative mode
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [3, 3, 3]]
|
||||
list_[1][-2] = 20
|
||||
list_[1][-1] = 21
|
||||
return list_
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [20, 21]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
out = net()
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [20, 21]
|
||||
assert list(out[2]) == [3, 3, 3]
|
||||
|
||||
|
||||
def test_list_index_3d():
|
||||
"""
|
||||
Feature: List index assign
|
||||
Description: Test list assign in pynative mode
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [[3, 3, 3]]]
|
||||
|
@ -92,30 +147,53 @@ def test_list_index_3D():
|
|||
list_[2][0][2] = 302
|
||||
return list_
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net = Net()
|
||||
out = net()
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2][0]) == [300, 301, 302]
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
out = net()
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2][0]) == [300, 301, 302]
|
||||
|
||||
|
||||
def test_list_neg_index_3d():
|
||||
"""
|
||||
Feature: List index assign
|
||||
Description: Test list assign in pynative mode
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
|
||||
def test_list_neg_index_3D():
|
||||
class Net(nn.Cell):
|
||||
def construct(self):
|
||||
list_ = [[1], [2, 2], [[3, 3, 3]]]
|
||||
list_[2][0][-3] = 300
|
||||
list_[2][0][-2] = 301
|
||||
list_[2][0][-1] = 302
|
||||
list_[2][0][-3] = 30
|
||||
list_[2][0][-2] = 31
|
||||
list_[2][0][-1] = 32
|
||||
return list_
|
||||
|
||||
net = Net()
|
||||
out = net()
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2][0]) == [300, 301, 302]
|
||||
assert list(out[2][0]) == [30, 31, 32]
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
out = net()
|
||||
assert list(out[0]) == [1]
|
||||
assert list(out[1]) == [2, 2]
|
||||
assert list(out[2][0]) == [30, 31, 32]
|
||||
|
||||
|
||||
|
||||
|
||||
def test_list_index_1D_parameter():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
list_ = [x]
|
||||
|
@ -127,6 +205,7 @@ def test_list_index_1D_parameter():
|
|||
|
||||
|
||||
def test_list_index_2D_parameter():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
list_ = [[x, x]]
|
||||
|
@ -138,6 +217,7 @@ def test_list_index_2D_parameter():
|
|||
|
||||
|
||||
def test_list_index_3D_parameter():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
class Net(nn.Cell):
|
||||
def construct(self, x):
|
||||
list_ = [[[x, x]]]
|
||||
|
@ -149,6 +229,7 @@ def test_list_index_3D_parameter():
|
|||
|
||||
|
||||
def test_const_list_index_3D_bprop():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -177,6 +258,7 @@ def test_const_list_index_3D_bprop():
|
|||
|
||||
|
||||
def test_parameter_list_index_3D_bprop():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Net, self).__init__()
|
||||
|
@ -203,3 +285,427 @@ def test_parameter_list_index_3D_bprop():
|
|||
value = Tensor(np.ones((2, 3), np.int64))
|
||||
sens = Tensor(np.arange(2 * 3).reshape(2, 3))
|
||||
grad_net(x, value, sens)
|
||||
|
||||
|
||||
|
||||
class Net1(Cell):
|
||||
def construct(self, a, b, start=None, stop=None, step=None):
|
||||
a[start:stop:step] = b[start:stop:step]
|
||||
return tuple(a)
|
||||
|
||||
|
||||
def compare_func1(a, b, start=None, stop=None, step=None):
|
||||
a[start:stop:step] = b[start:stop:step]
|
||||
return tuple(a)
|
||||
|
||||
|
||||
|
||||
def test_list_slice_length_equal():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list assign the size is equal
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
a = [1, 2, 3, 4]
|
||||
b = [5, 6, 7, 8]
|
||||
python_out = compare_func1(a, b, 0, None, 2)
|
||||
|
||||
a = [1, 2, 3, 4]
|
||||
b = [5, 6, 7, 8]
|
||||
net = Net1()
|
||||
pynative_mode_out = net(a, b, 0, None, 2)
|
||||
assert pynative_mode_out == python_out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(a, b, 0, None, 2)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_slice_length_error():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list assign the size is not equal
|
||||
Expectation: ValueError.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
a = [1, 2, 3, 4, 5]
|
||||
b = [5, 6, 7, 8]
|
||||
net = Net1()
|
||||
with pytest.raises(ValueError) as err:
|
||||
net(a, b, 0, None, 2)
|
||||
assert "attempt to assign sequence of size 2 to extended slice of size 3" in str(err.value)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
with pytest.raises(ValueError) as err:
|
||||
net(a, b, 0, None, 2)
|
||||
assert "attempt to assign sequence of size 2 to extended slice of size 3" in str(err.value)
|
||||
|
||||
|
||||
def compare_func2(a, b, start=None, stop=None, step=None):
|
||||
a[start:stop:step] = b
|
||||
return tuple(a)
|
||||
|
||||
|
||||
class Net2(Cell):
|
||||
def construct(self, a, b, start=None, stop=None, step=None):
|
||||
a[start:stop:step] = b
|
||||
return tuple(a)
|
||||
|
||||
|
||||
def test_list_slice_shrink():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice shrink assign
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33]
|
||||
python_out = compare_func2(a, b, 0, 5)
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33]
|
||||
net = Net2()
|
||||
pynative_out = net(a, b, 0, 5)
|
||||
assert pynative_out == python_out
|
||||
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33]
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(a, b, 0, 5)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_slice_insert():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice insert assign
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
python_out = compare_func2(a, b, 0, 1)
|
||||
net = Net2()
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
pynative_out = net(a, b, 0, 1)
|
||||
assert pynative_out == python_out
|
||||
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(a, b, 0, 1)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_slice_assign():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice start and stop is larger than size
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
python_out = compare_func2(a, b, -12, 456)
|
||||
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
net = Net2()
|
||||
pynative_out = net(a, b, -12, 456)
|
||||
assert pynative_out == python_out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(a, b, -12, 456)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_slice_extend():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice extend
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
net = Net2()
|
||||
python_out = compare_func2(a, b, 1234, 0)
|
||||
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
pynative_out = net(a, b, 1234, 0)
|
||||
assert pynative_out == python_out
|
||||
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(a, b, 1234, 0)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_slice_extend_front():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice extend
|
||||
Expectation: No exception.
|
||||
"""
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
python_out = compare_func2(a, b, 0, 0)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net = Net2()
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
pynative_out = net(a, b, 0, 0)
|
||||
assert pynative_out == python_out
|
||||
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(a, b, 0, 0)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_slice_extend_inner():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice extend
|
||||
Expectation: No exception.
|
||||
"""
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
python_out = compare_func2(a, b, 5, 5)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
net = Net2()
|
||||
pynative_out = net(a, b, 5, 5)
|
||||
assert pynative_out == python_out
|
||||
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(a, b, 5, 5)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_slice_erase():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice erase
|
||||
Expectation: No exception.
|
||||
"""
|
||||
a = [1, 2, 3, 4, 5, 6, 7]
|
||||
python_out = compare_func2(a, [], 1, 3)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
a = [1, 2, 3, 4, 5, 6, 7]
|
||||
net = Net2()
|
||||
pynative_out = net(a, [], 1, 3)
|
||||
assert pynative_out == python_out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
a = [1, 2, 3, 4, 5, 6, 7]
|
||||
graph_out = net(a, [], 1, 3)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_slice_tuple_without_step():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice assign with tuple
|
||||
Expectation: No exception.
|
||||
"""
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = (11, 22, 33)
|
||||
python_out = compare_func2(a, b, 0, 4, None)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = (11, 22, 33)
|
||||
net = Net2()
|
||||
pynative_out = net(a, b, 0, 4, None)
|
||||
assert pynative_out == python_out
|
||||
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = (11, 22, 33)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(a, b, 0, 4, None)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_slice_tuple_with_step():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list slice assign with tuple
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = (11, 22, 33)
|
||||
python_out = compare_func2(a, b, 1, None, 3)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = (11, 22, 33)
|
||||
net = Net2()
|
||||
pynative_out = net(a, b, 1, None, 3)
|
||||
assert pynative_out == python_out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = net(a, b, 1, None, 3)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_double_slice():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test list double slice assign
|
||||
Expectation: ValueError
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
@ms_function
|
||||
def foo(a, b, start1, stop1, step1, start2, stop2, step2):
|
||||
a[start1:stop1:step1][start2: stop2: step2] = b
|
||||
return a
|
||||
|
||||
class NetInner(Cell):
|
||||
def construct(self, a, b, start1, stop1, step1, start2, stop2, step2):
|
||||
a[start1:stop1:step1][start2: stop2: step2] = b
|
||||
return tuple(a)
|
||||
|
||||
net = NetInner()
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [11, 22, 33]
|
||||
assert foo(a, b, 0, None, 1, 0, None, 3) == net(a, b, 0, None, 1, 0, None, 3)
|
||||
|
||||
|
||||
def convert_tuple(a):
|
||||
result = tuple()
|
||||
for i in a:
|
||||
if isinstance(i, list):
|
||||
result += (tuple(i),)
|
||||
continue
|
||||
result += (i,)
|
||||
return result
|
||||
|
||||
|
||||
def test_list_in_list_slice():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test high dimension list slice assign
|
||||
Expectation: No exception.
|
||||
"""
|
||||
|
||||
class TestNet(Cell):
|
||||
def construct(self, a, b, index, start=None, stop=None, step=None):
|
||||
a[index][start:stop:step] = b
|
||||
return tuple(a)
|
||||
|
||||
def com_func3(a, b, index, start=None, stop=None, step=None):
|
||||
a[index][start:stop:step] = b
|
||||
return convert_tuple(a)
|
||||
|
||||
a = [1, 2, [1, 2, 3, 4, 5, 6, 7], 8, 9]
|
||||
b = [1111, 2222]
|
||||
python_out = com_func3(a, b, 2, 1, None, 3)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net = TestNet()
|
||||
a = [1, 2, [1, 2, 3, 4, 5, 6, 7], 8, 9]
|
||||
b = [1111, 2222]
|
||||
pynative_out = convert_tuple(net(a, b, 2, 1, None, 3))
|
||||
assert pynative_out == python_out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
graph_out = convert_tuple(net(a, b, 2, 1, None, 3))
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_list_slice_negative_step():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test negative step list slice assign
|
||||
Expectation: No exception.
|
||||
"""
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [33, 44, 55]
|
||||
python_out = compare_func2(a, b, -1, -9, -3)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
net = Net2()
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [33, 44, 55]
|
||||
pynative_out = net(a, b, -1, -9, -3)
|
||||
assert pynative_out == python_out
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [33, 44, 55]
|
||||
graph_out = net(a, b, -1, -9, -3)
|
||||
assert graph_out == python_out
|
||||
|
||||
|
||||
def test_graph_list_slice_assign_extended_number():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test negative step list slice assign
|
||||
Expectation: No exception.
|
||||
"""
|
||||
a = [1, 2, 3, 4, 5, 6]
|
||||
b = 1
|
||||
|
||||
net = Net2()
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
with pytest.raises(TypeError) as err:
|
||||
net(a, b, 0, None, 2)
|
||||
assert "must assign iterable to extended slice" in str(err.value)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
with pytest.raises(TypeError) as err:
|
||||
net(a, b, 0, None, 2)
|
||||
assert "must assign iterable to extended slice" in str(err.value)
|
||||
|
||||
|
||||
def test_graph_list_slice_assign_number():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test negative step list slice assign
|
||||
Expectation: No exception.
|
||||
"""
|
||||
a = [1, 2, 3, 4, 5, 6]
|
||||
b = 1
|
||||
net = Net2()
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
with pytest.raises(TypeError) as err:
|
||||
net(a, b, 0, None, 1)
|
||||
assert "can only assign an iterable" in str(err.value)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
with pytest.raises(TypeError) as err:
|
||||
net(a, b, 0, None, 1)
|
||||
assert "can only assign an iterable" in str(err.value)
|
||||
|
||||
|
||||
def test_list_slice_negetive_error():
|
||||
"""
|
||||
Feature: List assign
|
||||
Description: Test negative step list slice assign
|
||||
Expectation: ValueError
|
||||
"""
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
b = [33, 44, 55]
|
||||
net = Net2()
|
||||
context.set_context(mode=context.PYNATIVE_MODE)
|
||||
with pytest.raises(ValueError) as err:
|
||||
net(a, b, -1, -3, -3)
|
||||
assert "attempt to assign sequence of size 3 to extended slice of size 1" in str(err.value)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
with pytest.raises(ValueError) as err:
|
||||
net(a, b, -1, -3, -3)
|
||||
assert "attempt to assign sequence of size 3 to extended slice of size 1" in str(err.value)
|
||||
|
|
Loading…
Reference in New Issue