fix MakeCOOTensor node

yanglf1121 2022-04-11 13:04:55 +08:00
parent 169b412bb7
commit 70b59c1175
13 changed files with 178 additions and 34 deletions

View File

@@ -32,7 +32,7 @@ constexpr auto kCSRValueNodeNum = 2;
constexpr auto kSparseAttrIndex = 1;
// Convert CSRTensor or COOTensor Parameter or ValueNode to Tuple by setting its abstract.
- void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) {
+ void AbstractSparseToAbstractTuple(const AnfNodePtr &sparse) {
MS_EXCEPTION_IF_NULL(sparse);
if (!(sparse->isa<Parameter>() || sparse->isa<ValueNode>())) {
return;
@@ -46,6 +46,12 @@ void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) {
auto abs_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
abs_tuple->set_type(abs_tuple->BuildType());
sparse->set_abstract(abs_tuple);
} else if (param_abs->isa<abstract::AbstractCOOTensor>()) {
auto abs_sparse = param_abs->cast<abstract::AbstractCOOTensorPtr>();
std::vector<AbstractBasePtr> abstract_list{abs_sparse->indices(), abs_sparse->values(), abs_sparse->dense_shape()};
auto abs_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
abs_tuple->set_type(abs_tuple->BuildType());
sparse->set_abstract(abs_tuple);
}
}
@@ -183,7 +189,7 @@ CNodePtr ConvertSparseGetAttrToTupleGetItem(int64_t index, const AnfNodePtr &nod
if (inputs.size() <= kSparseAttrIndex) {
MS_LOG(EXCEPTION) << "For SparseGetAttr, CNode must have 2 inputs (Prim, Sparse)";
}
- AbstractCSRToAbstractTuple(inputs[kSparseAttrIndex]);
+ AbstractSparseToAbstractTuple(inputs[kSparseAttrIndex]);
auto index_node = NewValueNode(index);
AbstractBasePtr index_abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
index_node->set_abstract(index_abs);
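Note: the pass above flattens a sparse abstract into a plain AbstractTuple whose element order matches the constructor order. A minimal Python sketch of the layout this implies for COOTensor (illustrative only; the dtypes and shapes are assumptions):
from mindspore import Tensor, COOTensor
from mindspore import dtype as mstype
indices = Tensor([[0, 1], [1, 2]], dtype=mstype.int64)
values = Tensor([1.0, 2.0], dtype=mstype.float32)
coo = COOTensor(indices, values, (3, 4))
# After AbstractSparseToAbstractTuple, the graph sees the equivalent of:
flattened = (coo.indices, coo.values, coo.shape)  # (indices, values, dense_shape)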

View File

@@ -230,6 +230,11 @@ class COMMON_EXPORT AnfAlgo {
MS_EXCEPTION_IF_NULL(node->abstract());
return node->abstract()->isa<abstract::AbstractCSRTensor>();
}
static bool CheckAbsCOOTensor(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->abstract());
return node->abstract()->isa<abstract::AbstractCOOTensor>();
}
};
} // namespace common
} // namespace mindspore

View File

@@ -604,6 +604,7 @@ constexpr auto kAnfPrimitiveIndex = 0;
constexpr auto kFirstDataInputIndex = 1;
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
constexpr auto kSparseGetAttrInputSize = 2;
constexpr auto kTupleGetItemInputSize = 3;
// index define of partial
constexpr auto kPartialMinInputSize = 2;

View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <unordered_map>
#include "runtime/graph_scheduler/control_node_parser.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "include/common/utils/convert_utils.h"
@@ -111,6 +112,25 @@ std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, std:
default:
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
}
} else if (abstract->isa<abstract::AbstractCOOTensor>()) {
auto coo_abs = abstract->cast<abstract::AbstractCOOTensorPtr>();
MS_EXCEPTION_IF_NULL(coo_abs);
switch (index) {
case kCooTensorIndicesIndex:
dst_abstract = coo_abs->indices();
pre_abstract_num = kCooTensorIndicesIndex;
break;
case kCooTensorValuesIndex:
dst_abstract = coo_abs->values();
pre_abstract_num = kCooTensorValuesIndex;
break;
case kCooTensorDenseShapeIndex:
dst_abstract = coo_abs->dense_shape();
pre_abstract_num = kCooTensorDenseShapeIndex;
break;
default:
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
}
} else if (abstract->isa<abstract::AbstractTuple>()) {
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
@@ -394,8 +414,11 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
// The node is divided into the following types:
// 1. depend and load.
- const auto &node_with_index =
- common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
+ const auto &node_with_index = common::AnfAlgo::VisitKernelWithReturnType(
+ node, 0, false,
+ {prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimCSRTensorGetIndptr, prim::kPrimCSRTensorGetIndices,
+ prim::kPrimCSRTensorGetValues, prim::kPrimCSRTensorGetDenseShape, prim::kPrimCOOTensorGetIndices,
+ prim::kPrimCOOTensorGetValues, prim::kPrimCOOTensorGetDenseShape});
auto real_node = node_with_index.first;
size_t real_index = node_with_index.second;
MS_EXCEPTION_IF_NULL(real_node);
@@ -435,12 +458,14 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
common::AnfAlgo::CheckPrimitiveType(src_node, prim::kPrimMakeCOOTensor)) {
const auto &make_tensor_cnode = src_node->cast<CNodePtr>();
const auto &make_tensor_inputs = make_tensor_cnode->inputs();
- if (make_tensor_inputs.size() <= kMakeCSRTensorInputNum) {
- MS_LOG(EXCEPTION) << "Invalid make csr tensor node:" << cnode->DebugString();
+ if (make_tensor_inputs.size() <= kMakeCSRTensorInputNum && make_tensor_inputs.size() <= kMakeCOOTensorInputNum) {
+ MS_LOG(EXCEPTION) << "Invalid make sparse tensor node:" << cnode->DebugString();
}
const auto &sub_results =
FetchInputNodeByNode(make_tensor_inputs[LongToSize(iter->second) + kMakeTensorInputStartPos]);
(void)results.insert(results.end(), sub_results.begin(), sub_results.end());
+ } else if (src_node->isa<Parameter>()) {
+ results.emplace_back(src_node, iter->second);
} else {
// Csr node from parameter or call node.
auto abstract = src_node->abstract();

View File

@@ -66,6 +66,7 @@ constexpr size_t kCooTensorDenseShapeIndex = 2;
constexpr size_t kMakeCSRTensorInputStartPos = 1;
constexpr size_t kMakeTensorInputStartPos = 1;
constexpr size_t kMakeCSRTensorInputNum = 4;
constexpr size_t kMakeCOOTensorInputNum = 3;
using NodeWithContext = std::pair<AnfNodePtr, DeviceContext *>;
struct NodeWithContextCmp {

View File

@@ -123,6 +123,7 @@ class IrExportBuilder {
bool SetParamToTensorProto(const ParameterPtr &param, mind_ir::TensorProto *const tensor_proto);
bool SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto);
bool SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
bool SetCOOTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetAbstractToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
bool SetAbstractToNodeProto(const abstract::AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
@@ -549,6 +550,16 @@ bool IrExportBuilder::SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_
return SetAbstractToNodeProto(csr_tensor_abs->dense_shape(), dense_proto);
}
bool IrExportBuilder::SetCOOTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto) {
abstract::AbstractCOOTensorPtr coo_tensor_abs = abstract->cast<abstract::AbstractCOOTensorPtr>();
MS_EXCEPTION_IF_NULL(coo_tensor_abs);
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_COO_TENSOR);
(void)SetTensorProto(coo_tensor_abs->indices(), attr_proto->add_tensors());
(void)SetTensorProto(coo_tensor_abs->values(), attr_proto->add_tensors());
auto dense_proto = attr_proto->add_values();
return SetAbstractToNodeProto(coo_tensor_abs->dense_shape(), dense_proto);
}
bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto) {
auto type = abstract->BuildType();
auto shape = abstract->BuildShape();
@@ -714,6 +725,11 @@ bool IrExportBuilder::SetAbstractToNodeProto(const AbstractBasePtr &abs, mind_ir
if (!SetCSRTensorToProto(csr_tensor_abs, attr_proto)) {
return false;
}
} else if (type->isa<COOTensorType>()) {
auto coo_tensor_abs = abs->cast<abstract::AbstractCOOTensorPtr>();
if (!SetCOOTensorToProto(coo_tensor_abs, attr_proto)) {
return false;
}
} else {
MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
return false;
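For orientation, here is a schematic view (not the real mind_ir schema; the field names are assumptions read off SetCOOTensorToProto above) of what the exported COO_TENSOR attribute carries:
# Schematic only: mirrors the three calls in SetCOOTensorToProto.
coo_attr_proto = {
    "type": "COO_TENSOR",                                          # attr_proto->set_type(...)
    "tensors": ["<indices TensorProto>", "<values TensorProto>"],  # add_tensors()
    "values": ["<dense_shape sub-attribute>"],                     # add_values()
}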

View File

@@ -37,6 +37,12 @@ namespace {
constexpr size_t kNopNodeInputSize = 2;
constexpr size_t kNopNodeRealInputIndex = 1;
+ const PrimitiveSet expand_prims{
+ prim::kPrimMakeTuple,
+ prim::kPrimMakeCSRTensor,
+ prim::kPrimMakeCOOTensor,
+ prim::kPrimMakeRowTensor,
+ };
const std::set<std::string> kNodeTupleOutSet = {prim::kMakeTuple, prim::kGetNext};
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
@@ -106,12 +112,6 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index, std:
std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node) {
std::vector<KernelWithIndex> ret;
std::vector<KernelWithIndex> ret_empty;
- const PrimitiveSet expand_prims{
- prim::kPrimMakeTuple,
- prim::kPrimMakeCSRTensor,
- prim::kPrimMakeCOOTensor,
- prim::kPrimMakeRowTensor,
- };
// The MakeTuple/MakeSparse nodes need to be expanded and recursed into.
if (IsOneOfPrimitiveCNode(node, expand_prims)) {
auto make_tuple = node->cast<CNodePtr>();
@@ -157,7 +157,7 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
// If the node is a call, the number of outputs should be obtained from the abstract.
if (AnfAlgo::IsCallNode(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
- AnfAlgo::CheckAbsCSRTensor(node)) {
+ AnfAlgo::CheckAbsCSRTensor(node) || AnfAlgo::CheckAbsCOOTensor(node)) {
outputs_num = AnfAlgo::GetOutputNumByAbstract(node->abstract());
}
@@ -211,16 +211,28 @@ bool IsNodeDynamicShape(const AnfNodePtr &node) {
AnfNodePtr AnfAlgo::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
MS_EXCEPTION_IF_NULL(tuple_get_item);
- if (tuple_get_item->size() != kTupleGetItemInputSize) {
- MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
+ if (CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
+ if (tuple_get_item->size() != kTupleGetItemInputSize) {
+ MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
+ }
+ } else if (tuple_get_item->size() != kSparseGetAttrInputSize) {
+ MS_LOG(EXCEPTION) << "The node sparse_get_attribute must have 1 input!";
}
return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
}
size_t AnfAlgo::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
MS_EXCEPTION_IF_NULL(tuple_get_item);
- if (tuple_get_item->size() != kTupleGetItemInputSize) {
- MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
+ if (CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
+ if (tuple_get_item->size() != kTupleGetItemInputSize) {
+ MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
+ }
+ } else if (tuple_get_item->size() != kSparseGetAttrInputSize) {
+ MS_LOG(EXCEPTION) << "The node sparse_get_attribute must have 1 input!";
}
+ std::string prim_name = GetCNodeFuncName(tuple_get_item);
+ if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
+ return sparse_attr_map.at(prim_name);
+ }
auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(output_index_value_node);
@@ -248,11 +260,13 @@ KernelWithIndex AnfAlgo::VisitKernelWithReturnType(const AnfNodePtr &anf_node, s
}
auto cnode = anf_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
- if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
+ std::string prim_name = GetCNodeFuncName(cnode);
+ // TupleGetItem and SparseGetAttr need to find the real input
+ if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
abstract::AbstractBasePtr abs = nullptr;
auto item_with_index_tmp = VisitKernelWithReturnType(
GetTupleGetItemRealInput(cnode), GetTupleGetItemOutIndex(cnode), skip_nop_node, return_types, &abs);
- if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
+ if (IsOneOfPrimitiveCNode(item_with_index_tmp.first, expand_prims)) {
MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
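The sparse_attr_map consulted above is assumed to map each sparse getter primitive name to the fixed output index of the corresponding member in the flattened tuple. A hedged Python sketch of that mapping, inferred from the member order used elsewhere in this commit:
# Assumed contents of sparse_attr_map (primitive name -> tuple index).
SPARSE_ATTR_INDEX = {
    "CSRTensorGetIndptr": 0,
    "CSRTensorGetIndices": 1,
    "CSRTensorGetValues": 2,
    "CSRTensorGetDenseShape": 3,
    "COOTensorGetIndices": 0,
    "COOTensorGetValues": 1,
    "COOTensorGetDenseShape": 2,
}
# Under this mapping, coo.values resolves like TupleGetItem(coo_tuple, 1).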

View File

@@ -73,6 +73,12 @@ inline void CheckSparseIndicesDtype(const mindspore::TypePtr data_type, const st
<< data_type->ToString() << ".";
}
}
inline void CheckSparseIndicesDtypeInt32(const mindspore::TypePtr data_type, const std::string &arg_name) {
if (!data_type->equal(mindspore::kInt32)) {
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " only supports Int32 for now, but got "
<< data_type->ToString() << ".";
}
}
} // namespace
namespace mindspore {
@@ -485,6 +491,11 @@ AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const Primiti
MS_EXCEPTION_IF_NULL(sparse->indices());
MS_EXCEPTION_IF_NULL(dense);
auto indptr = sparse->indptr();
auto indices = sparse->indices();
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
auto sparse_shape = sparse->shape()->shape();
auto dense_shape = dense->shape()->shape();
CheckSparseShape(sparse_shape, dense_shape);
@@ -514,6 +525,11 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
MS_EXCEPTION_IF_NULL(sparse->indices());
MS_EXCEPTION_IF_NULL(dense);
auto indptr = sparse->indptr();
auto indices = sparse->indices();
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
auto sparse_shape = sparse->shape()->shape();
auto dense_shape = dense->shape()->shape();
if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) {
@@ -553,8 +569,14 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
MS_EXCEPTION_IF_NULL(sparse->shape());
MS_EXCEPTION_IF_NULL(sparse->values());
MS_EXCEPTION_IF_NULL(sparse->indices());
MS_EXCEPTION_IF_NULL(sparse->indptr());
MS_EXCEPTION_IF_NULL(axis);
auto indptr = sparse->indptr();
auto indices = sparse->indices();
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
auto sparse_shape = sparse->shape()->shape();
if (sparse_shape.size() != kCSRReduceSumShapeSize) {
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRReduceSumShapeSize << "-D inputs!"
@@ -566,7 +588,7 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
int64_t axis_value = GetValue<int64_t>(axis->BuildValue());
int64_t dim = static_cast<int64_t>(sparse_shape.size());
if (axis_value < -dim || axis_value >= dim || (axis_value != 1 && axis_value != -1)) {
- MS_EXCEPTION(ValueError) << "For CSRReduceSum, `axis` should be -1 or 1. But got `axis`: " << axis_value;
+ MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be -1 or 1. But got `axis`: " << axis_value;
}
if (axis_value < 0) {
axis_value += dim;
@@ -605,6 +627,9 @@ AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr
MS_EXCEPTION_IF_NULL(dense);
MS_EXCEPTION_IF_NULL(sparse_shape);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
if (sparse_shape->size() != kCSRShapeSize) {
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRShapeSize << "-D inputs!"
<< "But sparse tensor has " << sparse_shape->size() << " dimensions.";
@@ -631,6 +656,8 @@ AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
auto nnz = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(indptr);
MS_EXCEPTION_IF_NULL(nnz);
@@ -665,7 +692,7 @@ AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &
auto height = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
MS_EXCEPTION_IF_NULL(row_indices);
MS_EXCEPTION_IF_NULL(height);
CheckSparseIndicesDtypeInt32(row_indices->element()->BuildType(), "row_indices");
MS_EXCEPTION_IF_NULL(height->BuildValue());
ShapeVector out_shape;
if (height->BuildValue()->isa<Int32Imm>() || height->BuildValue()->isa<Int64Imm>()) {
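A short sketch of the user-visible effect of the new CheckSparseIndicesDtypeInt32 checks, assuming they fire when a CSR op's shape is inferred in a compiled graph:
from mindspore import Tensor, CSRTensor
from mindspore import dtype as mstype
indptr = Tensor([0, 1, 2], dtype=mstype.int64)   # int64 on purpose
indices = Tensor([0, 1], dtype=mstype.int32)
values = Tensor([1.0, 2.0], dtype=mstype.float32)
csr = CSRTensor(indptr, indices, values, (2, 4))
# Running one of the CSR ops above on `csr` is now expected to raise:
# TypeError: The dtype of Indptr only supports Int32 for now, but got Int64.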

View File

@@ -126,7 +126,8 @@ TypePtr TypeIdToType(TypeId id) {
{kObjectTypeIOMonad, kIOMonadType},
{kTypeUnknown, kTypeNone},
{kMetaTypeProblem, kTypeNone},
- {kObjectTypeCSRTensorType, kCSRTensorType}};
+ {kObjectTypeCSRTensorType, kCSRTensorType},
+ {kObjectTypeCOOTensorType, kCOOTensorType}};
const auto &it = type_id_to_type.find(id);
if (it == type_id_to_type.end()) {
MS_LOG(EXCEPTION) << "Not support the type: " << GetExcptionTypeString(id);

View File

@@ -1930,8 +1930,7 @@ def csr_abs(x):
def csr_mv(x, dense_vector):
- """Implementation of `abs` for CSRTensor."""
check_value_type('dense_vector', dense_vector, (Tensor,), 'CSRTensor.mv')
"""Implementation of `mv` for CSRTensor."""
return F.csr_mv(x, dense_vector)

View File

@@ -2413,18 +2413,13 @@ class SparseTensor(COOTensor_):
(3, 4)
"""
- def __init__(self, indices=None, values=None, shape=None, coo_tensor=None):
+ def __init__(self, indices, values, shape):
"Init COOTensor"
print("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " +
"Please use 'COOTensor' instead.")
- if indices is None and values is None and shape is None and coo_tensor is not None:
- if not isinstance(coo_tensor, (COOTensor, COOTensor_)):
- raise TypeError("If only one input provided, it must be a COOTensor.")
- COOTensor_.__init__(self, coo_tensor)
- else:
- if not (isinstance(indices, Tensor) and isinstance(values, Tensor) and isinstance(shape, tuple)):
- raise TypeError("Inputs must follow: COOTensor(indices, values, shape).")
- COOTensor_.__init__(self, indices, values, shape)
+ if not (isinstance(indices, Tensor) and isinstance(values, Tensor) and isinstance(shape, tuple)):
+ raise TypeError("Inputs must follow: COOTensor(indices, values, shape).")
+ COOTensor_.__init__(self, indices, values, shape)
@property
def indices(self):

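Migration note: with the simplified __init__ above, SparseTensor is a deprecated alias that forwards straight to COOTensor_, so existing code migrates with a rename; a minimal sketch:
from mindspore import Tensor, COOTensor
from mindspore import dtype as mstype
indices = Tensor([[0, 1], [1, 2]], dtype=mstype.int64)
values = Tensor([1.0, 2.0], dtype=mstype.float32)
# Deprecated since 1.7: SparseTensor(indices, values, (3, 4))
coo = COOTensor(indices, values, (3, 4))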
View File

@@ -1023,7 +1023,7 @@ def print_info(info):
def make_sparse_tensor(indices, values, dense_shape):
"""Call make_coo_tensor in this function."""
- print_info("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " +
+ print_info("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " + \
"Please use 'COOTensor' instead.")
return make_coo_tensor(indices, values, dense_shape)

View File

@@ -59,6 +59,60 @@ def test_make_coo():
compare_coo(coo3, coo2)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_coo_tensor_with_control_if():
"""
Feature: Test COOTensor in if.
Description: Test COOTensor computation in if/else control flow.
Expectation: Success.
"""
class COOTensorValuesDouble(nn.Cell):
def construct(self, x):
indices = x.indices
values = x.values * 2
shape = x.shape
return COOTensor(indices, values, shape)
class COOTensorValuesAdd2(nn.Cell):
def construct(self, x):
indices = x.indices
values = x.values + 2
shape = x.shape
return COOTensor(indices, values, shape)
class COOTensorWithControlIf(nn.Cell):
def __init__(self, shape):
super(COOTensorWithControlIf, self).__init__()
self.op1 = COOTensorValuesDouble()
self.op2 = COOTensorValuesAdd2()
self.shape = shape
def construct(self, a, b, indices, values):
x = COOTensor(indices, values, self.shape)
if a > b:
x = self.op1(x)
else:
x = self.op2(x)
return x.indices, x.values, x.shape
a = Tensor(0, mstype.int32)
b = Tensor(2, mstype.int32)
indices = Tensor([[0, 1], [1, 2]])
values = Tensor([1, 2], dtype=mstype.float32)
shape = (3, 4)
net = COOTensorWithControlIf(shape)
out = net(a, b, indices, values)
assert np.allclose(out[0].asnumpy(), indices.asnumpy(), .0, .0)
assert np.allclose(out[1].asnumpy(), values.asnumpy() + 2, .0, .0)
assert out[2] == shape
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training