forked from mindspore-Ecosystem/mindspore
support ellipsis and bool for tensor slice
This commit is contained in:
parent
53b3d187b9
commit
437bb8c27c
|
@ -495,6 +495,8 @@ TypePtr StringToType(const std::string &type_name) {
|
||||||
TypePtr type = nullptr;
|
TypePtr type = nullptr;
|
||||||
if (type_name.compare("None") == 0) {
|
if (type_name.compare("None") == 0) {
|
||||||
type = std::make_shared<TypeNone>();
|
type = std::make_shared<TypeNone>();
|
||||||
|
} else if (type_name.compare("Ellipsis") == 0) {
|
||||||
|
type = std::make_shared<Ellipsis>();
|
||||||
} else if (type_name.compare("TypeType") == 0) {
|
} else if (type_name.compare("TypeType") == 0) {
|
||||||
type = std::make_shared<TypeType>();
|
type = std::make_shared<TypeType>();
|
||||||
} else if (type_name.compare("SymbolicKeyType") == 0) {
|
} else if (type_name.compare("SymbolicKeyType") == 0) {
|
||||||
|
|
|
@ -18,6 +18,5 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
const TypePtr kTypeNone = std::make_shared<TypeNone>();
|
const TypePtr kTypeNone = std::make_shared<TypeNone>();
|
||||||
const TypePtr kTypeAnything = std::make_shared<TypeAnything>();
|
|
||||||
const TypePtr kAnyType = std::make_shared<TypeAnything>();
|
const TypePtr kAnyType = std::make_shared<TypeAnything>();
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -71,8 +71,20 @@ class TypeNull : public Type {
|
||||||
};
|
};
|
||||||
using TypeNullPtr = std::shared_ptr<TypeNull>;
|
using TypeNullPtr = std::shared_ptr<TypeNull>;
|
||||||
|
|
||||||
|
class Ellipsis : public Type {
|
||||||
|
public:
|
||||||
|
Ellipsis() : Type(kMetaTypeEllipsis) {}
|
||||||
|
~Ellipsis() override {}
|
||||||
|
MS_DECLARE_PARENT(Ellipsis, Type)
|
||||||
|
|
||||||
|
TypeId generic_type_id() const override { return kMetaTypeEllipsis; }
|
||||||
|
TypePtr DeepCopy() const override { return std::make_shared<Ellipsis>(); }
|
||||||
|
std::string ToReprString() const override { return "Ellipsis"; }
|
||||||
|
std::string DumpText() const override { return "Ellipsis"; }
|
||||||
|
};
|
||||||
|
using EllipsisPtr = std::shared_ptr<Ellipsis>;
|
||||||
|
|
||||||
extern const TypePtr kTypeNone;
|
extern const TypePtr kTypeNone;
|
||||||
extern const TypePtr kTypeAnything;
|
|
||||||
extern const TypePtr kAnyType;
|
extern const TypePtr kAnyType;
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,7 @@ enum TypeId : int {
|
||||||
kMetaTypeExternal,
|
kMetaTypeExternal,
|
||||||
kMetaTypeNone,
|
kMetaTypeNone,
|
||||||
kMetaTypeNull,
|
kMetaTypeNull,
|
||||||
|
kMetaTypeEllipsis,
|
||||||
kMetaTypeEnd,
|
kMetaTypeEnd,
|
||||||
//
|
//
|
||||||
// Object types
|
// Object types
|
||||||
|
|
|
@ -31,5 +31,8 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract:
|
||||||
const NamedPtr kNone = std::make_shared<None>();
|
const NamedPtr kNone = std::make_shared<None>();
|
||||||
|
|
||||||
abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); }
|
abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); }
|
||||||
const NamedPtr kNullObj = std::make_shared<NullObj>();
|
const NamedPtr kNull = std::make_shared<NullObj>();
|
||||||
|
|
||||||
|
abstract::AbstractBasePtr EllipsisObj::ToAbstract() { return std::make_shared<abstract::AbstractEllipsis>(); }
|
||||||
|
const NamedPtr kEllipsis = std::make_shared<EllipsisObj>();
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -61,7 +61,6 @@ class Named : public Value {
|
||||||
std::string name_;
|
std::string name_;
|
||||||
std::size_t hash_id_;
|
std::size_t hash_id_;
|
||||||
};
|
};
|
||||||
|
|
||||||
using NamedPtr = std::shared_ptr<Named>;
|
using NamedPtr = std::shared_ptr<Named>;
|
||||||
|
|
||||||
class None : public Named {
|
class None : public Named {
|
||||||
|
@ -71,7 +70,6 @@ class None : public Named {
|
||||||
MS_DECLARE_PARENT(None, Named);
|
MS_DECLARE_PARENT(None, Named);
|
||||||
abstract::AbstractBasePtr ToAbstract() override;
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
};
|
};
|
||||||
|
|
||||||
extern const NamedPtr kNone;
|
extern const NamedPtr kNone;
|
||||||
|
|
||||||
class NullObj : public Named {
|
class NullObj : public Named {
|
||||||
|
@ -81,7 +79,15 @@ class NullObj : public Named {
|
||||||
MS_DECLARE_PARENT(NullObj, Named);
|
MS_DECLARE_PARENT(NullObj, Named);
|
||||||
abstract::AbstractBasePtr ToAbstract() override;
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
};
|
};
|
||||||
|
extern const NamedPtr kNull;
|
||||||
|
|
||||||
extern const NamedPtr kNullObj;
|
class EllipsisObj : public Named {
|
||||||
|
public:
|
||||||
|
EllipsisObj() : Named("Ellipsis") {}
|
||||||
|
~EllipsisObj() override = default;
|
||||||
|
MS_DECLARE_PARENT(EllipsisObj, Named);
|
||||||
|
abstract::AbstractBasePtr ToAbstract() override;
|
||||||
|
};
|
||||||
|
extern const NamedPtr kEllipsis;
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_IR_NAMED_H_
|
#endif // MINDSPORE_CCSRC_IR_NAMED_H_
|
||||||
|
|
|
@ -135,9 +135,9 @@ T InnerScalarMod(T x, T y) {
|
||||||
if (std::is_integral<T>::value) {
|
if (std::is_integral<T>::value) {
|
||||||
return static_cast<int>(x) % static_cast<int>(y);
|
return static_cast<int>(x) % static_cast<int>(y);
|
||||||
}
|
}
|
||||||
float x_int = std::floor(x);
|
int x_int = std::floor(x);
|
||||||
float y_int = std::ceil(y);
|
int y_int = std::ceil(y);
|
||||||
float max = x_int / y_int;
|
int max = x_int / y_int;
|
||||||
float ret = x - y * max;
|
float ret = x - y * max;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,6 +46,8 @@ using mindspore::abstract::AbstractBase;
|
||||||
using mindspore::abstract::AbstractClass;
|
using mindspore::abstract::AbstractClass;
|
||||||
using mindspore::abstract::AbstractDictionary;
|
using mindspore::abstract::AbstractDictionary;
|
||||||
using mindspore::abstract::AbstractDictionaryPtr;
|
using mindspore::abstract::AbstractDictionaryPtr;
|
||||||
|
using mindspore::abstract::AbstractEllipsis;
|
||||||
|
using mindspore::abstract::AbstractEllipsisPtr;
|
||||||
using mindspore::abstract::AbstractFunction;
|
using mindspore::abstract::AbstractFunction;
|
||||||
using mindspore::abstract::AbstractFunctionPtr;
|
using mindspore::abstract::AbstractFunctionPtr;
|
||||||
using mindspore::abstract::AbstractList;
|
using mindspore::abstract::AbstractList;
|
||||||
|
@ -1081,6 +1083,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
|
||||||
|
|
||||||
std::vector<unsigned int> shrink;
|
std::vector<unsigned int> shrink;
|
||||||
auto slice_tuple_eles = slice_tuple->elements();
|
auto slice_tuple_eles = slice_tuple->elements();
|
||||||
|
size_t ellipsis_num = 0;
|
||||||
for (size_t index = 0; index < slice_tuple_size; index++) {
|
for (size_t index = 0; index < slice_tuple_size; index++) {
|
||||||
if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
|
if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
|
||||||
AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
|
AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
|
||||||
|
@ -1098,7 +1101,20 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice or int number, but got "
|
if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) {
|
||||||
|
ellipsis_num++;
|
||||||
|
if (ellipsis_num > 1) {
|
||||||
|
MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis";
|
||||||
|
}
|
||||||
|
size_t ellipsis_len = shape_size - (slice_tuple_size - 1);
|
||||||
|
begin->insert(begin->end(), ellipsis_len, 0);
|
||||||
|
end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len);
|
||||||
|
strides->insert(strides->end(), ellipsis_len, 1);
|
||||||
|
shrink.insert(shrink.end(), ellipsis_len, 0);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got "
|
||||||
<< slice_tuple_eles[index]->ToString();
|
<< slice_tuple_eles[index]->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1160,6 +1176,11 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
||||||
abstract::CheckArgsSize(op_name, args_spec_list, 2);
|
abstract::CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||||
|
|
||||||
|
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
||||||
|
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||||
|
AnfNodePtr tensor_node = ret_graph->add_parameter();
|
||||||
|
(void)ret_graph->add_parameter();
|
||||||
|
|
||||||
auto shape = tensorPtr->shape()->shape();
|
auto shape = tensorPtr->shape()->shape();
|
||||||
std::vector<int> begin;
|
std::vector<int> begin;
|
||||||
std::vector<int> end;
|
std::vector<int> end;
|
||||||
|
@ -1174,23 +1195,28 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
||||||
shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
|
shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
|
||||||
} else if (args_spec_list[1]->isa<AbstractScalar>()) {
|
} else if (args_spec_list[1]->isa<AbstractScalar>()) {
|
||||||
AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
|
AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
|
||||||
|
if (scalar_ptr->BuildValue()->isa<BoolImm>()) {
|
||||||
|
if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
|
||||||
|
return ExpandADim(ret_graph, tensor_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
|
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
|
||||||
|
} else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
|
||||||
|
ret_graph->set_output(tensor_node);
|
||||||
|
return ret_graph;
|
||||||
|
} else if (args_spec_list[1]->isa<AbstractNone>()) {
|
||||||
|
return ExpandADim(ret_graph, tensor_node);
|
||||||
} else {
|
} else {
|
||||||
std::ostringstream args_info;
|
std::ostringstream args_info;
|
||||||
for (const auto &arg : args_spec_list) {
|
for (const auto &arg : args_spec_list) {
|
||||||
MS_EXCEPTION_IF_NULL(arg);
|
MS_EXCEPTION_IF_NULL(arg);
|
||||||
args_info << arg->ToString() << "\n";
|
args_info << arg->ToString() << "\n";
|
||||||
}
|
}
|
||||||
MS_LOG(EXCEPTION) << "TensorSlice requires to input a tensor and a slice or slice tuple, but got "
|
MS_LOG(EXCEPTION)
|
||||||
|
<< "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got "
|
||||||
<< args_info.str();
|
<< args_info.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
|
||||||
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
|
||||||
|
|
||||||
AnfNodePtr tensor_node = ret_graph->add_parameter();
|
|
||||||
(void)ret_graph->add_parameter();
|
|
||||||
|
|
||||||
auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
|
auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
|
||||||
auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
|
auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
|
||||||
NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
|
NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
|
||||||
|
@ -1199,6 +1225,12 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
||||||
return ret_graph;
|
return ret_graph;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
|
||||||
|
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
|
||||||
|
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
|
||||||
|
return ret_graph;
|
||||||
|
}
|
||||||
|
|
||||||
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
|
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
|
||||||
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
|
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
|
||||||
.def(py::init<std::string &>());
|
.def(py::init<std::string &>());
|
||||||
|
|
|
@ -206,6 +206,8 @@ class TensorSlice : public MetaFuncGraph {
|
||||||
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
|
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
|
|
||||||
|
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
|
||||||
};
|
};
|
||||||
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
|
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
|
||||||
|
|
||||||
|
|
|
@ -109,6 +109,7 @@ void Parser::BuildMethodMap() {
|
||||||
expr_method_map_["Index"] = &Parser::ParseIndex;
|
expr_method_map_["Index"] = &Parser::ParseIndex;
|
||||||
expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
|
expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
|
||||||
expr_method_map_["Dict"] = &Parser::ParseDict;
|
expr_method_map_["Dict"] = &Parser::ParseDict;
|
||||||
|
expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
|
void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
|
||||||
|
@ -187,7 +188,7 @@ void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block,
|
||||||
|
|
||||||
namelist_for_default_value.push_back(arg_name);
|
namelist_for_default_value.push_back(arg_name);
|
||||||
if (py::isinstance<py::none>(defaults[i])) {
|
if (py::isinstance<py::none>(defaults[i])) {
|
||||||
default_values.push_back(NewValueNode(kNullObj));
|
default_values.push_back(NewValueNode(kNull));
|
||||||
} else {
|
} else {
|
||||||
default_values.push_back(ParseExprNode(block, defaults[i]));
|
default_values.push_back(ParseExprNode(block, defaults[i]));
|
||||||
}
|
}
|
||||||
|
@ -437,6 +438,11 @@ AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
|
||||||
return NewValueNode(kNone);
|
return NewValueNode(kNone);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
|
||||||
|
MS_LOG(DEBUG) << "Process ast Ellipsis";
|
||||||
|
return NewValueNode(kEllipsis);
|
||||||
|
}
|
||||||
|
|
||||||
AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
|
AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
|
||||||
MS_LOG(DEBUG) << "Process ast Num";
|
MS_LOG(DEBUG) << "Process ast Num";
|
||||||
py::object obj = python_adapter::GetPyObjAttr(node, "n");
|
py::object obj = python_adapter::GetPyObjAttr(node, "n");
|
||||||
|
|
|
@ -92,6 +92,8 @@ class Parser {
|
||||||
AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
|
AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
|
||||||
// process NoneType
|
// process NoneType
|
||||||
AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
|
AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
|
||||||
|
// process Ellipsis
|
||||||
|
AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
|
||||||
// process a integer or float number
|
// process a integer or float number
|
||||||
AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
|
AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
|
||||||
// process a string variable
|
// process a string variable
|
||||||
|
|
|
@ -892,10 +892,27 @@ bool AbstractNull::operator==(const AbstractBase &other) const {
|
||||||
|
|
||||||
std::string AbstractNull::ToString() const {
|
std::string AbstractNull::ToString() const {
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
buffer << type_name() << "("
|
buffer << type_name() << "(Value: Null)";
|
||||||
<< "Value: "
|
return buffer.str();
|
||||||
<< "Null"
|
}
|
||||||
<< ")";
|
|
||||||
|
bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }
|
||||||
|
|
||||||
|
bool AbstractEllipsis::operator==(const AbstractBase &other) const {
|
||||||
|
if (&other == this) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (other.isa<AbstractEllipsis>()) {
|
||||||
|
auto other_none = static_cast<const AbstractEllipsis *>(&other);
|
||||||
|
return *this == *other_none;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string AbstractEllipsis::ToString() const {
|
||||||
|
std::ostringstream buffer;
|
||||||
|
buffer << type_name() << "(Value: Ellipsis)";
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -498,7 +498,7 @@ using AbstractNonePtr = std::shared_ptr<AbstractNone>;
|
||||||
// the un assigned state value for variable, which means the variable is not assigned
|
// the un assigned state value for variable, which means the variable is not assigned
|
||||||
class AbstractNull : public AbstractBase {
|
class AbstractNull : public AbstractBase {
|
||||||
public:
|
public:
|
||||||
AbstractNull() : AbstractBase(kNullObj) { set_type(std::make_shared<TypeNull>()); }
|
AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
|
||||||
~AbstractNull() override = default;
|
~AbstractNull() override = default;
|
||||||
MS_DECLARE_PARENT(AbstractNull, AbstractBase)
|
MS_DECLARE_PARENT(AbstractNull, AbstractBase)
|
||||||
|
|
||||||
|
@ -510,6 +510,20 @@ class AbstractNull : public AbstractBase {
|
||||||
};
|
};
|
||||||
using AbstractNullPtr = std::shared_ptr<AbstractNull>;
|
using AbstractNullPtr = std::shared_ptr<AbstractNull>;
|
||||||
|
|
||||||
|
class AbstractEllipsis : public AbstractBase {
|
||||||
|
public:
|
||||||
|
AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<Ellipsis>()); }
|
||||||
|
~AbstractEllipsis() override = default;
|
||||||
|
MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase)
|
||||||
|
|
||||||
|
TypePtr BuildType() const override { return std::make_shared<Ellipsis>(); }
|
||||||
|
bool operator==(const AbstractEllipsis &other) const;
|
||||||
|
bool operator==(const AbstractBase &other) const override;
|
||||||
|
AbstractBasePtr Clone() const override { return std::make_shared<AbstractEllipsis>(); }
|
||||||
|
std::string ToString() const override;
|
||||||
|
};
|
||||||
|
using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
|
||||||
|
|
||||||
class AbstractRefKey : public AbstractBase {
|
class AbstractRefKey : public AbstractBase {
|
||||||
public:
|
public:
|
||||||
AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); }
|
AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); }
|
||||||
|
|
|
@ -150,7 +150,7 @@ def _tensor_getitem_by_number(data, number_index):
|
||||||
@getitem.register("Tensor", "Slice")
|
@getitem.register("Tensor", "Slice")
|
||||||
def _tensor_getitem_by_slice(data, slice_index):
|
def _tensor_getitem_by_slice(data, slice_index):
|
||||||
"""
|
"""
|
||||||
Getting item of tensor by slice index.
|
Getting item of tensor by slice.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
data (Tensor): A tensor.
|
data (Tensor): A tensor.
|
||||||
|
@ -165,7 +165,7 @@ def _tensor_getitem_by_slice(data, slice_index):
|
||||||
@getitem.register("Tensor", "Tuple")
|
@getitem.register("Tensor", "Tuple")
|
||||||
def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
|
def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
|
||||||
"""
|
"""
|
||||||
Getting item of tensor by slice tuple index.
|
Getting item of tensor by slice tuple.
|
||||||
|
|
||||||
Inputs:
|
Inputs:
|
||||||
data (Tensor): A tensor.
|
data (Tensor): A tensor.
|
||||||
|
@ -175,3 +175,18 @@ def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
|
||||||
Tensor, element type is same as the element type of data.
|
Tensor, element type is same as the element type of data.
|
||||||
"""
|
"""
|
||||||
return _tensor_slice(data, slice_tuple_index)
|
return _tensor_slice(data, slice_tuple_index)
|
||||||
|
|
||||||
|
|
||||||
|
@getitem.register("Tensor", "Ellipsis")
|
||||||
|
def _tensor_getitem_by_ellipsis(data, ellipsis_index):
|
||||||
|
"""
|
||||||
|
Getting item of tensor by Ellipsis.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (Tensor): A tensor.
|
||||||
|
ellipsis (Ellipsis): A Ellipsis object.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, same as data.
|
||||||
|
"""
|
||||||
|
return _tensor_slice(data, ellipsis_index)
|
||||||
|
|
|
@ -67,6 +67,7 @@ scalar_to_tensor = P.ScalarToTensor()
|
||||||
tuple_to_array = P.TupleToArray()
|
tuple_to_array = P.TupleToArray()
|
||||||
scalar_cast = P.ScalarCast()
|
scalar_cast = P.ScalarCast()
|
||||||
print_ = P.Print()
|
print_ = P.Print()
|
||||||
|
expand_dims = P.ExpandDims()
|
||||||
|
|
||||||
tuple_setitem = Primitive('tuple_setitem')
|
tuple_setitem = Primitive('tuple_setitem')
|
||||||
tuple_getitem = Primitive('tuple_getitem')
|
tuple_getitem = Primitive('tuple_getitem')
|
||||||
|
|
|
@ -42,6 +42,20 @@ class NetWorkSlicePositive(Cell):
|
||||||
return ret0, ret1, ret2, ret3
|
return ret0, ret1, ret2, ret3
|
||||||
|
|
||||||
|
|
||||||
|
class NetWorkSliceEllipsis(Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(NetWorkSliceEllipsis, self).__init__()
|
||||||
|
self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32))
|
||||||
|
self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32))
|
||||||
|
self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32))
|
||||||
|
|
||||||
|
def construct(self, tensor):
|
||||||
|
ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
|
||||||
|
ret1 = tensor[...] + self.tensor_ret1
|
||||||
|
ret2 = tensor[True] + self.tensor_ret2
|
||||||
|
return ret0, ret1, ret2
|
||||||
|
|
||||||
|
|
||||||
class NetWorkReduceDimension(Cell):
|
class NetWorkReduceDimension(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(NetWorkReduceDimension, self).__init__()
|
super(NetWorkReduceDimension, self).__init__()
|
||||||
|
@ -83,7 +97,7 @@ class NetWorkReduceToScalar(Cell):
|
||||||
class TensorAssignWithBoolTensorIndex(Cell):
|
class TensorAssignWithBoolTensorIndex(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(TensorAssignWithBoolTensorIndex, self).__init__()
|
super(TensorAssignWithBoolTensorIndex, self).__init__()
|
||||||
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64)
|
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
|
||||||
|
|
||||||
def construct(self, a, b, c, u_tensor, _scalar):
|
def construct(self, a, b, c, u_tensor, _scalar):
|
||||||
a[c] = u_scalar
|
a[c] = u_scalar
|
||||||
|
@ -104,14 +118,14 @@ class TensorAssignWithBoolTensorIndexError(Cell):
|
||||||
class TensorAssignWithBoolTensorIndex2(Cell):
|
class TensorAssignWithBoolTensorIndex2(Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(TensorAssignWithBoolTensorIndex2, self).__init__()
|
super(TensorAssignWithBoolTensorIndex2, self).__init__()
|
||||||
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64)
|
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
|
||||||
|
|
||||||
def construct(self, a, u_tensor, _scalar):
|
def construct(self, a, u_tensor, _scalar):
|
||||||
a[a>8] = u_tensor
|
a[a > 8] = u_tensor
|
||||||
a[a>=6] = u_scalar
|
a[a >= 6] = u_scalar
|
||||||
a[a<3] = u_scalar
|
a[a < 3] = u_scalar
|
||||||
a[a<=5] = u_tensor
|
a[a <= 5] = u_tensor
|
||||||
a[a==5] = u_scalar
|
a[a == 5] = u_scalar
|
||||||
z = a + self.t
|
z = a + self.t
|
||||||
return z
|
return z
|
||||||
|
|
||||||
|
@ -121,11 +135,11 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
|
||||||
super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
|
super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
|
||||||
|
|
||||||
def construct(self, a, u_tensor):
|
def construct(self, a, u_tensor):
|
||||||
a[a>8][a>5] = u_tensor
|
a[a > 8][a > 5] = u_tensor
|
||||||
return a
|
return a
|
||||||
|
|
||||||
|
|
||||||
a = np.random.uniform(1,10,[2,3])
|
a = np.random.uniform(1, 10, [2, 3])
|
||||||
b = a > 5
|
b = a > 5
|
||||||
c = a < 3
|
c = a < 3
|
||||||
Ta = Tensor(a)
|
Ta = Tensor(a)
|
||||||
|
@ -152,7 +166,7 @@ def test_tensor_assign_bool_index():
|
||||||
net1(Ta, Tb, Ta, u_tensor, u_scalar)
|
net1(Ta, Tb, Ta, u_tensor, u_scalar)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
net1(Ta, Tb, Tc, u_tensor_error, u_scalar)
|
net1(Ta, Tb, Tc, u_tensor_error, u_scalar)
|
||||||
#net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
|
# net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
net2(Ta, u_tensor_error, u_scalar)
|
net2(Ta, u_tensor_error, u_scalar)
|
||||||
net3 = TensorAssignWithBoolTensorIndexError()
|
net3 = TensorAssignWithBoolTensorIndexError()
|
||||||
|
@ -192,7 +206,10 @@ test_cases = [
|
||||||
'block': NetWorkReduceToScalar(),
|
'block': NetWorkReduceToScalar(),
|
||||||
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
|
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
|
||||||
}),
|
}),
|
||||||
|
('NetWorkSliceEllipsis', {
|
||||||
|
'block': NetWorkSliceEllipsis(),
|
||||||
|
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
|
||||||
|
}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -162,6 +162,7 @@ def test_ops():
|
||||||
if self.int > self.float:
|
if self.int > self.float:
|
||||||
if [1, 2, 3] != None:
|
if [1, 2, 3] != None:
|
||||||
if self.str_a + self.str_b == "helloworld":
|
if self.str_a + self.str_b == "helloworld":
|
||||||
|
if q == 86:
|
||||||
print("hello world")
|
print("hello world")
|
||||||
return ret
|
return ret
|
||||||
return x
|
return x
|
||||||
|
@ -169,7 +170,7 @@ def test_ops():
|
||||||
net = OpsNet(9, 2)
|
net = OpsNet(9, 2)
|
||||||
x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
|
x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
|
||||||
y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
|
y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
|
||||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
context.set_context(mode=context.GRAPH_MODE)
|
||||||
net(x, y)
|
net(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue