forked from mindspore-Ecosystem/mindspore
!536 support ellipsis and bool for tensor slice
Merge pull request !536 from zhangbuxue/support_elipis_for_tensor_slice
This commit is contained in:
commit
24a9f4974c
|
@ -495,6 +495,8 @@ TypePtr StringToType(const std::string &type_name) {
|
|||
TypePtr type = nullptr;
|
||||
if (type_name.compare("None") == 0) {
|
||||
type = std::make_shared<TypeNone>();
|
||||
} else if (type_name.compare("Ellipsis") == 0) {
|
||||
type = std::make_shared<Ellipsis>();
|
||||
} else if (type_name.compare("TypeType") == 0) {
|
||||
type = std::make_shared<TypeType>();
|
||||
} else if (type_name.compare("SymbolicKeyType") == 0) {
|
||||
|
|
|
@ -18,6 +18,5 @@
|
|||
|
||||
namespace mindspore {
|
||||
const TypePtr kTypeNone = std::make_shared<TypeNone>();
|
||||
const TypePtr kTypeAnything = std::make_shared<TypeAnything>();
|
||||
const TypePtr kAnyType = std::make_shared<TypeAnything>();
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -71,8 +71,20 @@ class TypeNull : public Type {
|
|||
};
|
||||
using TypeNullPtr = std::shared_ptr<TypeNull>;
|
||||
|
||||
class Ellipsis : public Type {
|
||||
public:
|
||||
Ellipsis() : Type(kMetaTypeEllipsis) {}
|
||||
~Ellipsis() override {}
|
||||
MS_DECLARE_PARENT(Ellipsis, Type)
|
||||
|
||||
TypeId generic_type_id() const override { return kMetaTypeEllipsis; }
|
||||
TypePtr DeepCopy() const override { return std::make_shared<Ellipsis>(); }
|
||||
std::string ToReprString() const override { return "Ellipsis"; }
|
||||
std::string DumpText() const override { return "Ellipsis"; }
|
||||
};
|
||||
using EllipsisPtr = std::shared_ptr<Ellipsis>;
|
||||
|
||||
extern const TypePtr kTypeNone;
|
||||
extern const TypePtr kTypeAnything;
|
||||
extern const TypePtr kAnyType;
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -49,6 +49,7 @@ enum TypeId : int {
|
|||
kMetaTypeExternal,
|
||||
kMetaTypeNone,
|
||||
kMetaTypeNull,
|
||||
kMetaTypeEllipsis,
|
||||
kMetaTypeEnd,
|
||||
//
|
||||
// Object types
|
||||
|
|
|
@ -31,5 +31,8 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared<abstract:
|
|||
const NamedPtr kNone = std::make_shared<None>();
|
||||
|
||||
abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared<abstract::AbstractNull>(); }
|
||||
const NamedPtr kNullObj = std::make_shared<NullObj>();
|
||||
const NamedPtr kNull = std::make_shared<NullObj>();
|
||||
|
||||
abstract::AbstractBasePtr EllipsisObj::ToAbstract() { return std::make_shared<abstract::AbstractEllipsis>(); }
|
||||
const NamedPtr kEllipsis = std::make_shared<EllipsisObj>();
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -61,7 +61,6 @@ class Named : public Value {
|
|||
std::string name_;
|
||||
std::size_t hash_id_;
|
||||
};
|
||||
|
||||
using NamedPtr = std::shared_ptr<Named>;
|
||||
|
||||
class None : public Named {
|
||||
|
@ -71,7 +70,6 @@ class None : public Named {
|
|||
MS_DECLARE_PARENT(None, Named);
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
};
|
||||
|
||||
extern const NamedPtr kNone;
|
||||
|
||||
class NullObj : public Named {
|
||||
|
@ -81,7 +79,15 @@ class NullObj : public Named {
|
|||
MS_DECLARE_PARENT(NullObj, Named);
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
};
|
||||
extern const NamedPtr kNull;
|
||||
|
||||
extern const NamedPtr kNullObj;
|
||||
class EllipsisObj : public Named {
|
||||
public:
|
||||
EllipsisObj() : Named("Ellipsis") {}
|
||||
~EllipsisObj() override = default;
|
||||
MS_DECLARE_PARENT(EllipsisObj, Named);
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
};
|
||||
extern const NamedPtr kEllipsis;
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_NAMED_H_
|
||||
|
|
|
@ -135,9 +135,9 @@ T InnerScalarMod(T x, T y) {
|
|||
if (std::is_integral<T>::value) {
|
||||
return static_cast<int>(x) % static_cast<int>(y);
|
||||
}
|
||||
float x_int = std::floor(x);
|
||||
float y_int = std::ceil(y);
|
||||
float max = x_int / y_int;
|
||||
int x_int = std::floor(x);
|
||||
int y_int = std::ceil(y);
|
||||
int max = x_int / y_int;
|
||||
float ret = x - y * max;
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -46,6 +46,8 @@ using mindspore::abstract::AbstractBase;
|
|||
using mindspore::abstract::AbstractClass;
|
||||
using mindspore::abstract::AbstractDictionary;
|
||||
using mindspore::abstract::AbstractDictionaryPtr;
|
||||
using mindspore::abstract::AbstractEllipsis;
|
||||
using mindspore::abstract::AbstractEllipsisPtr;
|
||||
using mindspore::abstract::AbstractFunction;
|
||||
using mindspore::abstract::AbstractFunctionPtr;
|
||||
using mindspore::abstract::AbstractList;
|
||||
|
@ -1081,6 +1083,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
|
|||
|
||||
std::vector<unsigned int> shrink;
|
||||
auto slice_tuple_eles = slice_tuple->elements();
|
||||
size_t ellipsis_num = 0;
|
||||
for (size_t index = 0; index < slice_tuple_size; index++) {
|
||||
if (slice_tuple_eles[index]->isa<AbstractSlice>()) {
|
||||
AbstractSlicePtr slice = dyn_cast<AbstractSlice>(slice_tuple_eles[index]);
|
||||
|
@ -1098,7 +1101,20 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple,
|
|||
continue;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice or int number, but got "
|
||||
if (slice_tuple_eles[index]->isa<AbstractEllipsis>()) {
|
||||
ellipsis_num++;
|
||||
if (ellipsis_num > 1) {
|
||||
MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis";
|
||||
}
|
||||
size_t ellipsis_len = shape_size - (slice_tuple_size - 1);
|
||||
begin->insert(begin->end(), ellipsis_len, 0);
|
||||
end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len);
|
||||
strides->insert(strides->end(), ellipsis_len, 1);
|
||||
shrink.insert(shrink.end(), ellipsis_len, 0);
|
||||
continue;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got "
|
||||
<< slice_tuple_eles[index]->ToString();
|
||||
}
|
||||
|
||||
|
@ -1160,6 +1176,11 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
|||
abstract::CheckArgsSize(op_name, args_spec_list, 2);
|
||||
AbstractTensorPtr tensorPtr = abstract::CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
|
||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
AnfNodePtr tensor_node = ret_graph->add_parameter();
|
||||
(void)ret_graph->add_parameter();
|
||||
|
||||
auto shape = tensorPtr->shape()->shape();
|
||||
std::vector<int> begin;
|
||||
std::vector<int> end;
|
||||
|
@ -1174,23 +1195,28 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
|||
shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides);
|
||||
} else if (args_spec_list[1]->isa<AbstractScalar>()) {
|
||||
AbstractScalarPtr scalar_ptr = dyn_cast<AbstractScalar>(args_spec_list[1]);
|
||||
if (scalar_ptr->BuildValue()->isa<BoolImm>()) {
|
||||
if (scalar_ptr->BuildValue()->cast<BoolImmPtr>()->value()) {
|
||||
return ExpandADim(ret_graph, tensor_node);
|
||||
}
|
||||
}
|
||||
shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides);
|
||||
} else if (args_spec_list[1]->isa<AbstractEllipsis>()) {
|
||||
ret_graph->set_output(tensor_node);
|
||||
return ret_graph;
|
||||
} else if (args_spec_list[1]->isa<AbstractNone>()) {
|
||||
return ExpandADim(ret_graph, tensor_node);
|
||||
} else {
|
||||
std::ostringstream args_info;
|
||||
for (const auto &arg : args_spec_list) {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
args_info << arg->ToString() << "\n";
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "TensorSlice requires to input a tensor and a slice or slice tuple, but got "
|
||||
<< args_info.str();
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got "
|
||||
<< args_info.str();
|
||||
}
|
||||
|
||||
FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
|
||||
AnfNodePtr tensor_node = ret_graph->add_parameter();
|
||||
(void)ret_graph->add_parameter();
|
||||
|
||||
auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations");
|
||||
auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0),
|
||||
NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)});
|
||||
|
@ -1199,6 +1225,12 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec
|
|||
return ret_graph;
|
||||
}
|
||||
|
||||
FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const {
|
||||
auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional");
|
||||
ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph));
|
||||
return ret_graph;
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
|
||||
(void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
|
||||
.def(py::init<std::string &>());
|
||||
|
|
|
@ -206,6 +206,8 @@ class TensorSlice : public MetaFuncGraph {
|
|||
MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||
friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
|
||||
|
||||
FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
|
||||
};
|
||||
using TensorSlicePtr = std::shared_ptr<TensorSlice>;
|
||||
|
||||
|
|
|
@ -109,6 +109,7 @@ void Parser::BuildMethodMap() {
|
|||
expr_method_map_["Index"] = &Parser::ParseIndex;
|
||||
expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
|
||||
expr_method_map_["Dict"] = &Parser::ParseDict;
|
||||
expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
|
||||
}
|
||||
|
||||
void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
|
||||
|
@ -187,7 +188,7 @@ void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block,
|
|||
|
||||
namelist_for_default_value.push_back(arg_name);
|
||||
if (py::isinstance<py::none>(defaults[i])) {
|
||||
default_values.push_back(NewValueNode(kNullObj));
|
||||
default_values.push_back(NewValueNode(kNull));
|
||||
} else {
|
||||
default_values.push_back(ParseExprNode(block, defaults[i]));
|
||||
}
|
||||
|
@ -437,6 +438,11 @@ AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
|
|||
return NewValueNode(kNone);
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
|
||||
MS_LOG(DEBUG) << "Process ast Ellipsis";
|
||||
return NewValueNode(kEllipsis);
|
||||
}
|
||||
|
||||
AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
|
||||
MS_LOG(DEBUG) << "Process ast Num";
|
||||
py::object obj = python_adapter::GetPyObjAttr(node, "n");
|
||||
|
|
|
@ -92,6 +92,8 @@ class Parser {
|
|||
AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process NoneType
|
||||
AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process Ellipsis
|
||||
AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a integer or float number
|
||||
AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
|
||||
// process a string variable
|
||||
|
|
|
@ -892,10 +892,27 @@ bool AbstractNull::operator==(const AbstractBase &other) const {
|
|||
|
||||
std::string AbstractNull::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
buffer << type_name() << "("
|
||||
<< "Value: "
|
||||
<< "Null"
|
||||
<< ")";
|
||||
buffer << type_name() << "(Value: Null)";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }
|
||||
|
||||
bool AbstractEllipsis::operator==(const AbstractBase &other) const {
|
||||
if (&other == this) {
|
||||
return true;
|
||||
}
|
||||
if (other.isa<AbstractEllipsis>()) {
|
||||
auto other_none = static_cast<const AbstractEllipsis *>(&other);
|
||||
return *this == *other_none;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
std::string AbstractEllipsis::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
buffer << type_name() << "(Value: Ellipsis)";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
|
|
|
@ -498,7 +498,7 @@ using AbstractNonePtr = std::shared_ptr<AbstractNone>;
|
|||
// the un assigned state value for variable, which means the variable is not assigned
|
||||
class AbstractNull : public AbstractBase {
|
||||
public:
|
||||
AbstractNull() : AbstractBase(kNullObj) { set_type(std::make_shared<TypeNull>()); }
|
||||
AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared<TypeNull>()); }
|
||||
~AbstractNull() override = default;
|
||||
MS_DECLARE_PARENT(AbstractNull, AbstractBase)
|
||||
|
||||
|
@ -510,6 +510,20 @@ class AbstractNull : public AbstractBase {
|
|||
};
|
||||
using AbstractNullPtr = std::shared_ptr<AbstractNull>;
|
||||
|
||||
class AbstractEllipsis : public AbstractBase {
|
||||
public:
|
||||
AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared<Ellipsis>()); }
|
||||
~AbstractEllipsis() override = default;
|
||||
MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase)
|
||||
|
||||
TypePtr BuildType() const override { return std::make_shared<Ellipsis>(); }
|
||||
bool operator==(const AbstractEllipsis &other) const;
|
||||
bool operator==(const AbstractBase &other) const override;
|
||||
AbstractBasePtr Clone() const override { return std::make_shared<AbstractEllipsis>(); }
|
||||
std::string ToString() const override;
|
||||
};
|
||||
using AbstractEllipsisPtr = std::shared_ptr<AbstractEllipsis>;
|
||||
|
||||
class AbstractRefKey : public AbstractBase {
|
||||
public:
|
||||
AbstractRefKey() : AbstractBase() { set_type(std::make_shared<RefKeyType>()); }
|
||||
|
|
|
@ -150,7 +150,7 @@ def _tensor_getitem_by_number(data, number_index):
|
|||
@getitem.register("Tensor", "Slice")
|
||||
def _tensor_getitem_by_slice(data, slice_index):
|
||||
"""
|
||||
Getting item of tensor by slice index.
|
||||
Getting item of tensor by slice.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): A tensor.
|
||||
|
@ -165,7 +165,7 @@ def _tensor_getitem_by_slice(data, slice_index):
|
|||
@getitem.register("Tensor", "Tuple")
|
||||
def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
|
||||
"""
|
||||
Getting item of tensor by slice tuple index.
|
||||
Getting item of tensor by slice tuple.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): A tensor.
|
||||
|
@ -175,3 +175,18 @@ def _tensor_getitem_by_slice_tuple(data, slice_tuple_index):
|
|||
Tensor, element type is same as the element type of data.
|
||||
"""
|
||||
return _tensor_slice(data, slice_tuple_index)
|
||||
|
||||
|
||||
@getitem.register("Tensor", "Ellipsis")
|
||||
def _tensor_getitem_by_ellipsis(data, ellipsis_index):
|
||||
"""
|
||||
Getting item of tensor by Ellipsis.
|
||||
|
||||
Inputs:
|
||||
data (Tensor): A tensor.
|
||||
ellipsis (Ellipsis): A Ellipsis object.
|
||||
|
||||
Outputs:
|
||||
Tensor, same as data.
|
||||
"""
|
||||
return _tensor_slice(data, ellipsis_index)
|
||||
|
|
|
@ -67,6 +67,7 @@ scalar_to_tensor = P.ScalarToTensor()
|
|||
tuple_to_array = P.TupleToArray()
|
||||
scalar_cast = P.ScalarCast()
|
||||
print_ = P.Print()
|
||||
expand_dims = P.ExpandDims()
|
||||
|
||||
tuple_setitem = Primitive('tuple_setitem')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
|
|
|
@ -42,6 +42,20 @@ class NetWorkSlicePositive(Cell):
|
|||
return ret0, ret1, ret2, ret3
|
||||
|
||||
|
||||
class NetWorkSliceEllipsis(Cell):
|
||||
def __init__(self):
|
||||
super(NetWorkSliceEllipsis, self).__init__()
|
||||
self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32))
|
||||
self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32))
|
||||
self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32))
|
||||
|
||||
def construct(self, tensor):
|
||||
ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0
|
||||
ret1 = tensor[...] + self.tensor_ret1
|
||||
ret2 = tensor[True] + self.tensor_ret2
|
||||
return ret0, ret1, ret2
|
||||
|
||||
|
||||
class NetWorkReduceDimension(Cell):
|
||||
def __init__(self):
|
||||
super(NetWorkReduceDimension, self).__init__()
|
||||
|
@ -83,7 +97,7 @@ class NetWorkReduceToScalar(Cell):
|
|||
class TensorAssignWithBoolTensorIndex(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithBoolTensorIndex, self).__init__()
|
||||
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64)
|
||||
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
|
||||
|
||||
def construct(self, a, b, c, u_tensor, _scalar):
|
||||
a[c] = u_scalar
|
||||
|
@ -104,14 +118,14 @@ class TensorAssignWithBoolTensorIndexError(Cell):
|
|||
class TensorAssignWithBoolTensorIndex2(Cell):
|
||||
def __init__(self):
|
||||
super(TensorAssignWithBoolTensorIndex2, self).__init__()
|
||||
self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64)
|
||||
self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64)
|
||||
|
||||
def construct(self, a, u_tensor, _scalar):
|
||||
a[a>8] = u_tensor
|
||||
a[a>=6] = u_scalar
|
||||
a[a<3] = u_scalar
|
||||
a[a<=5] = u_tensor
|
||||
a[a==5] = u_scalar
|
||||
a[a > 8] = u_tensor
|
||||
a[a >= 6] = u_scalar
|
||||
a[a < 3] = u_scalar
|
||||
a[a <= 5] = u_tensor
|
||||
a[a == 5] = u_scalar
|
||||
z = a + self.t
|
||||
return z
|
||||
|
||||
|
@ -121,11 +135,11 @@ class TensorAssignWithBoolTensorIndex2Error(Cell):
|
|||
super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
|
||||
|
||||
def construct(self, a, u_tensor):
|
||||
a[a>8][a>5] = u_tensor
|
||||
a[a > 8][a > 5] = u_tensor
|
||||
return a
|
||||
|
||||
|
||||
a = np.random.uniform(1,10,[2,3])
|
||||
a = np.random.uniform(1, 10, [2, 3])
|
||||
b = a > 5
|
||||
c = a < 3
|
||||
Ta = Tensor(a)
|
||||
|
@ -152,7 +166,7 @@ def test_tensor_assign_bool_index():
|
|||
net1(Ta, Tb, Ta, u_tensor, u_scalar)
|
||||
with pytest.raises(ValueError):
|
||||
net1(Ta, Tb, Tc, u_tensor_error, u_scalar)
|
||||
#net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
|
||||
# net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar)
|
||||
with pytest.raises(ValueError):
|
||||
net2(Ta, u_tensor_error, u_scalar)
|
||||
net3 = TensorAssignWithBoolTensorIndexError()
|
||||
|
@ -192,7 +206,10 @@ test_cases = [
|
|||
'block': NetWorkReduceToScalar(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))],
|
||||
}),
|
||||
|
||||
('NetWorkSliceEllipsis', {
|
||||
'block': NetWorkSliceEllipsis(),
|
||||
'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))],
|
||||
}),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -162,14 +162,15 @@ def test_ops():
|
|||
if self.int > self.float:
|
||||
if [1, 2, 3] != None:
|
||||
if self.str_a + self.str_b == "helloworld":
|
||||
print("hello world")
|
||||
return ret
|
||||
if q == 86:
|
||||
print("hello world")
|
||||
return ret
|
||||
return x
|
||||
|
||||
net = OpsNet(9, 2)
|
||||
x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32))
|
||||
y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32))
|
||||
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net(x, y)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue