support ellipsis and bool for tensor slice

This commit is contained in:
buxue 2020-04-21 15:39:59 +08:00
parent 53b3d187b9
commit 437bb8c27c
17 changed files with 170 additions and 40 deletions

View File

@ -495,6 +495,8 @@ TypePtr StringToType(const std::string &type_name) {
TypePtr type = nullptr; TypePtr type = nullptr;
if (type_name.compare("None") == 0) { if (type_name.compare("None") == 0) {
type = std::make_shared<TypeNone>(); type = std::make_shared<TypeNone>();
} else if (type_name.compare("Ellipsis") == 0) {
type = std::make_shared<Ellipsis>();
} else if (type_name.compare("TypeType") == 0) { } else if (type_name.compare("TypeType") == 0) {
type = std::make_shared<TypeType>(); type = std::make_shared<TypeType>();
} else if (type_name.compare("SymbolicKeyType") == 0) { } else if (type_name.compare("SymbolicKeyType") == 0) {

View File

@ -18,6 +18,5 @@
namespace mindspore { namespace mindspore {
const TypePtr kTypeNone = std::make_shared<TypeNone>(); const TypePtr kTypeNone = std::make_shared<TypeNone>();
const TypePtr kTypeAnything = std::make_shared<TypeAnything>();
const TypePtr kAnyType = std::make_shared<TypeAnything>(); const TypePtr kAnyType = std::make_shared<TypeAnything>();
} // namespace mindspore } // namespace mindspore

View File

@ -71,8 +71,20 @@ class TypeNull : public Type {
}; };
using TypeNullPtr = std::shared_ptr<TypeNull>; using TypeNullPtr = std::shared_ptr<TypeNull>;
class Ellipsis : public Type {
public:
Ellipsis() : Type(kMetaTypeEllipsis) {}
~Ellipsis() override {}
MS_DECLARE_PARENT(Ellipsis, Type)
TypeId generic_type_id() const override { return kMetaTypeEllipsis; }
TypePtr DeepCopy() const override { return std::make_shared<Ellipsis>(); }
std::string ToReprString() const override { return "Ellipsis"; }
std::string DumpText() const override { return "Ellipsis"; }
};
using EllipsisPtr = std::shared_ptr<Ellipsis>;
extern const TypePtr kTypeNone; extern const TypePtr kTypeNone;
extern const TypePtr kTypeAnything;
extern const TypePtr kAnyType; extern const TypePtr kAnyType;
} // namespace mindspore } // namespace mindspore

View File

@ -49,6 +49,7 @@ enum TypeId : int {
kMetaTypeExternal, kMetaTypeExternal,
kMetaTypeNone, kMetaTypeNone,
kMetaTypeNull, kMetaTypeNull,
kMetaTypeEllipsis,
kMetaTypeEnd, kMetaTypeEnd,
// //
// Object types // Object types

View File

@ -31,5 +31,8 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract:
const NamedPtr kNone = std::make_shared<None>(); const NamedPtr kNone = std::make_shared<None>();
abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); } abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); }
const NamedPtr kNullObj = std::make_shared<NullObj>(); const NamedPtr kNull = std::make_shared<NullObj>();
abstract::AbstractBasePtr EllipsisObj::ToAbstract() { return std::make_shared<abstract::AbstractEllipsis>(); }
const NamedPtr kEllipsis = std::make_shared<EllipsisObj>();
} // namespace mindspore } // namespace mindspore

View File

@ -61,7 +61,6 @@ class Named : public Value {
std::string name_; std::string name_;
std::size_t hash_id_; std::size_t hash_id_;
}; };
using NamedPtr = std::shared_ptr<Named>; using NamedPtr = std::shared_ptr<Named>;
class None : public Named { class None : public Named {
@ -71,7 +70,6 @@ class None : public Named {
MS_DECLARE_PARENT(None, Named); MS_DECLARE_PARENT(None, Named);
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
}; };
extern const NamedPtr kNone; extern const NamedPtr kNone;
class NullObj : public Named { class NullObj : public Named {
@ -81,7 +79,15 @@ class NullObj : public Named {
MS_DECLARE_PARENT(NullObj, Named); MS_DECLARE_PARENT(NullObj, Named);
abstract::AbstractBasePtr ToAbstract() override; abstract::AbstractBasePtr ToAbstract() override;
}; };
extern const NamedPtr kNull;
extern const NamedPtr kNullObj; class EllipsisObj : public Named {
public:
EllipsisObj() : Named("Ellipsis") {}
~EllipsisObj() override = default;
MS_DECLARE_PARENT(EllipsisObj, Named);
abstract::AbstractBasePtr ToAbstract() override;
};
extern const NamedPtr kEllipsis;
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_NAMED_H_ #endif // MINDSPORE_CCSRC_IR_NAMED_H_

View File

@ -135,9 +135,9 @@ T InnerScalarMod(T x, T y) {
if (std::is_integral<T>::value) { if (std::is_integral<T>::value) {
return static_cast<int>(x) % static_cast<int>(y); return static_cast<int>(x) % static_cast<int>(y);
} }
float x_int = std::floor(x); int x_int = std::floor(x);
float y_int = std::ceil(y); int y_int = std::ceil(y);
float max = x_int / y_int; int max = x_int / y_int;
float ret = x - y * max; float ret = x - y * max;
return ret; return ret;
} }

View File

@ -46,6 +46,8 @@ using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractClass; using mindspore::abstract::AbstractClass;
using mindspore::abstract::AbstractDictionary; using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr; using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractEllipsis;
using mindspore::abstract::AbstractEllipsisPtr;
using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractFunctionPtr; using mindspore::abstract::AbstractFunctionPtr;
using mindspore::abstract::AbstractList; using mindspore::abstract::AbstractList;
@ -1081,6 +1083,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
std::vector<unsigned int> shrink; std::vector<unsigned int> shrink;
auto slice_tuple_eles = slice_tuple->elements(); auto slice_tuple_eles = slice_tuple->elements();
size_t ellipsis_num = 0;
for (size_t index = 0; index < slice_tuple_size; index++) { for (size_t index = 0; index < slice_tuple_size; index++) {
if (slice_tuple_eles[index]->isa<AbstractSlice>()) { if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]); AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
@ -1098,7 +1101,20 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
continue; continue;
} }
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice or int number, but got " if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) {
ellipsis_num++;
if (ellipsis_num > 1) {
MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis";
}
size_t ellipsis_len = shape_size - (slice_tuple_size - 1);
begin->insert(begin->end(), ellipsis_len, 0);
end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len);
strides->insert(strides->end(), ellipsis_len, 1);
shrink.insert(shrink.end(), ellipsis_len, 0);
continue;
}
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got "
<< slice_tuple_eles[index]->ToString(); << slice_tuple_eles[index]->ToString();
} }
@ -1160,6 +1176,11 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
abstract::CheckArgsSize(op_name, args_spec_list, 2); abstract::CheckArgsSize(op_name, args_spec_list, 2);
AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0); AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr tensor_node = ret_graph->add_parameter();
(void)ret_graph->add_parameter();
auto shape = tensorPtr->shape()->shape(); auto shape = tensorPtr->shape()->shape();
std::vector<int> begin; std::vector<int> begin;
std::vector<int> end; std::vector<int> end;
@ -1174,23 +1195,28 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides); shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
} else if (args_spec_list[1]->isa<AbstractScalar>()) { } else if (args_spec_list[1]->isa<AbstractScalar>()) {
AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]); AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
if (scalar_ptr->BuildValue()->isa<BoolImm>()) {
if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
return ExpandADim(ret_graph, tensor_node);
}
}
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
} else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
ret_graph->set_output(tensor_node);
return ret_graph;
} else if (args_spec_list[1]->isa<AbstractNone>()) {
return ExpandADim(ret_graph, tensor_node);
} else { } else {
std::ostringstream args_info; std::ostringstream args_info;
for (const auto &arg : args_spec_list) { for (const auto &arg : args_spec_list) {
MS_EXCEPTION_IF_NULL(arg); MS_EXCEPTION_IF_NULL(arg);
args_info << arg->ToString() << "\n"; args_info << arg->ToString() << "\n";
} }
MS_LOG(EXCEPTION) << "TensorSlice requires to input a tensor and a slice or slice tuple, but got " MS_LOG(EXCEPTION)
<< args_info.str(); << "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got "
<< args_info.str();
} }
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
AnfNodePtr tensor_node = ret_graph->add_parameter();
(void)ret_graph->add_parameter();
auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations"); auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0), auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)}); NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
@ -1199,6 +1225,12 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
return ret_graph; return ret_graph;
} }
FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
return ret_graph;
}
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_") (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
.def(py::init<std::string &>()); .def(py::init<std::string &>());

View File

@ -206,6 +206,8 @@ class TensorSlice : public MetaFuncGraph {
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
}; };
using TensorSlicePtr = std::shared_ptr<TensorSlice>; using TensorSlicePtr = std::shared_ptr<TensorSlice>;

View File

@ -109,6 +109,7 @@ void Parser::BuildMethodMap() {
expr_method_map_["Index"] = &Parser::ParseIndex; expr_method_map_["Index"] = &Parser::ParseIndex;
expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
expr_method_map_["Dict"] = &Parser::ParseDict; expr_method_map_["Dict"] = &Parser::ParseDict;
expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
} }
void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
@ -187,7 +188,7 @@ void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block,
namelist_for_default_value.push_back(arg_name); namelist_for_default_value.push_back(arg_name);
if (py::isinstance<py::none>(defaults[i])) { if (py::isinstance<py::none>(defaults[i])) {
default_values.push_back(NewValueNode(kNullObj)); default_values.push_back(NewValueNode(kNull));
} else { } else {
default_values.push_back(ParseExprNode(block, defaults[i])); default_values.push_back(ParseExprNode(block, defaults[i]));
} }
@ -437,6 +438,11 @@ AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
return NewValueNode(kNone); return NewValueNode(kNone);
} }
AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
MS_LOG(DEBUG) << "Process ast Ellipsis";
return NewValueNode(kEllipsis);
}
AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) { AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
MS_LOG(DEBUG) << "Process ast Num"; MS_LOG(DEBUG) << "Process ast Num";
py::object obj = python_adapter::GetPyObjAttr(node, "n"); py::object obj = python_adapter::GetPyObjAttr(node, "n");

View File

@ -92,6 +92,8 @@ class Parser {
AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node); AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
// process NoneType // process NoneType
AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node); AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
// process Ellipsis
AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
// process a integer or float number // process a integer or float number
AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node); AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
// process a string variable // process a string variable

View File

@ -892,10 +892,27 @@ bool AbstractNull::operator==(const AbstractBase &other) const {
std::string AbstractNull::ToString() const { std::string AbstractNull::ToString() const {
std::ostringstream buffer; std::ostringstream buffer;
buffer << type_name() << "(" buffer << type_name() << "(Value: Null)";
<< "Value: " return buffer.str();
<< "Null" }
<< ")";
bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }
bool AbstractEllipsis::operator==(const AbstractBase &other) const {
if (&other == this) {
return true;
}
if (other.isa<AbstractEllipsis>()) {
auto other_none = static_cast<const AbstractEllipsis *>(&other);
return *this == *other_none;
} else {
return false;
}
}
std::string AbstractEllipsis::ToString() const {
std::ostringstream buffer;
buffer << type_name() << "(Value: Ellipsis)";
return buffer.str(); return buffer.str();
} }

View File

@ -498,7 +498,7 @@ using AbstractNonePtr = std::shared_ptr<AbstractNone>;
// the un assigned state value for variable, which means the variable is not assigned // the un assigned state value for variable, which means the variable is not assigned
class AbstractNull : public AbstractBase { class AbstractNull : public AbstractBase {
public: public:
AbstractNull() : AbstractBase(kNullObj) { set_type(std::make_shared<TypeNull>()); } AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
~AbstractNull() override = default; ~AbstractNull() override = default;
MS_DECLARE_PARENT(AbstractNull, AbstractBase) MS_DECLARE_PARENT(AbstractNull, AbstractBase)
@ -510,6 +510,20 @@ class AbstractNull : public AbstractBase {
}; };
using AbstractNullPtr = std::shared_ptr<AbstractNull>; using AbstractNullPtr = std::shared_ptr<AbstractNull>;
class AbstractEllipsis : public AbstractBase {
public:
AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<Ellipsis>()); }
~AbstractEllipsis() override = default;
MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<Ellipsis>(); }
bool operator==(const AbstractEllipsis &other) const;
bool operator==(const AbstractBase &other) const override;
AbstractBasePtr Clone() const override { return std::make_shared<AbstractEllipsis>(); }
std::string ToString() const override;
};
using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
class AbstractRefKey : public AbstractBase { class AbstractRefKey : public AbstractBase {
public: public:
AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); } AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); }

View File

@ -150,7 +150,7 @@ def _tensor_getitem_by_number(data, number_index):
@getitem.register("Tensor", "Slice") @getitem.register("Tensor", "Slice")
def _tensor_getitem_by_slice(data, slice_index): def _tensor_getitem_by_slice(data, slice_index):
""" """
Getting item of tensor by slice index. Getting item of tensor by slice.
Inputs: Inputs:
data (Tensor): A tensor. data (Tensor): A tensor.
@ -165,7 +165,7 @@ def _tensor_getitem_by_slice(data, slice_index):
@getitem.register("Tensor", "Tuple") @getitem.register("Tensor", "Tuple")
def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
""" """
Getting item of tensor by slice tuple index. Getting item of tensor by slice tuple.
Inputs: Inputs:
data (Tensor): A tensor. data (Tensor): A tensor.
@ -175,3 +175,18 @@ def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
Tensor, element type is same as the element type of data. Tensor, element type is same as the element type of data.
""" """
return _tensor_slice(data, slice_tuple_index) return _tensor_slice(data, slice_tuple_index)
@getitem.register("Tensor", "Ellipsis")
def _tensor_getitem_by_ellipsis(data, ellipsis_index):
"""
Getting item of tensor by Ellipsis.
Inputs:
data (Tensor): A tensor.
ellipsis (Ellipsis): A Ellipsis object.
Outputs:
Tensor, same as data.
"""
return _tensor_slice(data, ellipsis_index)

View File

@ -67,6 +67,7 @@ scalar_to_tensor = P.ScalarToTensor()
tuple_to_array = P.TupleToArray() tuple_to_array = P.TupleToArray()
scalar_cast = P.ScalarCast() scalar_cast = P.ScalarCast()
print_ = P.Print() print_ = P.Print()
expand_dims = P.ExpandDims()
tuple_setitem = Primitive('tuple_setitem') tuple_setitem = Primitive('tuple_setitem')
tuple_getitem = Primitive('tuple_getitem') tuple_getitem = Primitive('tuple_getitem')

View File

@ -42,6 +42,20 @@ class NetWorkSlicePositive(Cell):
return ret0, ret1, ret2, ret3 return ret0, ret1, ret2, ret3
class NetWorkSliceEllipsis(Cell):
def __init__(self):
super(NetWorkSliceEllipsis, self).__init__()
self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32))
self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32))
self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32))
def construct(self, tensor):
ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
ret1 = tensor[...] + self.tensor_ret1
ret2 = tensor[True] + self.tensor_ret2
return ret0, ret1, ret2
class NetWorkReduceDimension(Cell): class NetWorkReduceDimension(Cell):
def __init__(self): def __init__(self):
super(NetWorkReduceDimension, self).__init__() super(NetWorkReduceDimension, self).__init__()
@ -83,7 +97,7 @@ class NetWorkReduceToScalar(Cell):
class TensorAssignWithBoolTensorIndex(Cell): class TensorAssignWithBoolTensorIndex(Cell):
def __init__(self): def __init__(self):
super(TensorAssignWithBoolTensorIndex, self).__init__() super(TensorAssignWithBoolTensorIndex, self).__init__()
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
def construct(self, a, b, c, u_tensor, _scalar): def construct(self, a, b, c, u_tensor, _scalar):
a[c] = u_scalar a[c] = u_scalar
@ -104,14 +118,14 @@ class TensorAssignWithBoolTensorIndexError(Cell):
class TensorAssignWithBoolTensorIndex2(Cell): class TensorAssignWithBoolTensorIndex2(Cell):
def __init__(self): def __init__(self):
super(TensorAssignWithBoolTensorIndex2, self).__init__() super(TensorAssignWithBoolTensorIndex2, self).__init__()
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
def construct(self, a, u_tensor, _scalar): def construct(self, a, u_tensor, _scalar):
a[a>8] = u_tensor a[a > 8] = u_tensor
a[a>=6] = u_scalar a[a >= 6] = u_scalar
a[a<3] = u_scalar a[a < 3] = u_scalar
a[a<=5] = u_tensor a[a <= 5] = u_tensor
a[a==5] = u_scalar a[a == 5] = u_scalar
z = a + self.t z = a + self.t
return z return z
@ -121,11 +135,11 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
super(TensorAssignWithBoolTensorIndex2Error, self).__init__() super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
def construct(self, a, u_tensor): def construct(self, a, u_tensor):
a[a>8][a>5] = u_tensor a[a > 8][a > 5] = u_tensor
return a return a
a = np.random.uniform(1,10,[2,3]) a = np.random.uniform(1, 10, [2, 3])
b = a > 5 b = a > 5
c = a < 3 c = a < 3
Ta = Tensor(a) Ta = Tensor(a)
@ -152,7 +166,7 @@ def test_tensor_assign_bool_index():
net1(Ta, Tb, Ta, u_tensor, u_scalar) net1(Ta, Tb, Ta, u_tensor, u_scalar)
with pytest.raises(ValueError): with pytest.raises(ValueError):
net1(Ta, Tb, Tc, u_tensor_error, u_scalar) net1(Ta, Tb, Tc, u_tensor_error, u_scalar)
#net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
with pytest.raises(ValueError): with pytest.raises(ValueError):
net2(Ta, u_tensor_error, u_scalar) net2(Ta, u_tensor_error, u_scalar)
net3 = TensorAssignWithBoolTensorIndexError() net3 = TensorAssignWithBoolTensorIndexError()
@ -192,7 +206,10 @@ test_cases = [
'block': NetWorkReduceToScalar(), 'block': NetWorkReduceToScalar(),
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
}), }),
('NetWorkSliceEllipsis', {
'block': NetWorkSliceEllipsis(),
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
}),
] ]

View File

@ -162,14 +162,15 @@ def test_ops():
if self.int > self.float: if self.int > self.float:
if [1, 2, 3] != None: if [1, 2, 3] != None:
if self.str_a + self.str_b == "helloworld": if self.str_a + self.str_b == "helloworld":
print("hello world") if q == 86:
return ret print("hello world")
return ret
return x return x
net = OpsNet(9, 2) net = OpsNet(9, 2)
x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32)) x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32)) y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
context.set_context(mode=context.GRAPH_MODE, save_graphs=True) context.set_context(mode=context.GRAPH_MODE)
net(x, y) net(x, y)