forked from mindspore-Ecosystem/mindspore
fix MakeCOOTensor node
This commit is contained in:
parent
169b412bb7
commit
70b59c1175
|
@ -32,7 +32,7 @@ constexpr auto kCSRValueNodeNum = 2;
|
|||
constexpr auto kSparseAttrIndex = 1;
|
||||
|
||||
// Convert CSRTensor Parameter or ValueNode to Tuple by setting its abstract.
|
||||
void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) {
|
||||
void AbstractSparseToAbstractTuple(const AnfNodePtr &sparse) {
|
||||
MS_EXCEPTION_IF_NULL(sparse);
|
||||
if (!(sparse->isa<Parameter>() || sparse->isa<ValueNode>())) {
|
||||
return;
|
||||
|
@ -46,6 +46,12 @@ void AbstractCSRToAbstractTuple(const AnfNodePtr &sparse) {
|
|||
auto abs_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
abs_tuple->set_type(abs_tuple->BuildType());
|
||||
sparse->set_abstract(abs_tuple);
|
||||
} else if (param_abs->isa<abstract::AbstractCOOTensor>()) {
|
||||
auto abs_sparse = param_abs->cast<abstract::AbstractCOOTensorPtr>();
|
||||
std::vector<AbstractBasePtr> abstract_list{abs_sparse->indices(), abs_sparse->values(), abs_sparse->dense_shape()};
|
||||
auto abs_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
abs_tuple->set_type(abs_tuple->BuildType());
|
||||
sparse->set_abstract(abs_tuple);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -183,7 +189,7 @@ CNodePtr ConvertSparseGetAttrToTupleGetItem(int64_t index, const AnfNodePtr &nod
|
|||
if (inputs.size() <= kSparseAttrIndex) {
|
||||
MS_LOG(EXCEPTION) << "For SparseGetAttr, CNode must have 2 inputs (Prim, Sparse)";
|
||||
}
|
||||
AbstractCSRToAbstractTuple(inputs[kSparseAttrIndex]);
|
||||
AbstractSparseToAbstractTuple(inputs[kSparseAttrIndex]);
|
||||
auto index_node = NewValueNode(index);
|
||||
AbstractBasePtr index_abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
|
||||
index_node->set_abstract(index_abs);
|
||||
|
|
|
@ -230,6 +230,11 @@ class COMMON_EXPORT AnfAlgo {
|
|||
MS_EXCEPTION_IF_NULL(node->abstract());
|
||||
return node->abstract()->isa<abstract::AbstractCSRTensor>();
|
||||
}
|
||||
static bool CheckAbsCOOTensor(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->abstract());
|
||||
return node->abstract()->isa<abstract::AbstractCOOTensor>();
|
||||
}
|
||||
};
|
||||
} // namespace common
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -604,6 +604,7 @@ constexpr auto kAnfPrimitiveIndex = 0;
|
|||
constexpr auto kFirstDataInputIndex = 1;
|
||||
constexpr auto kRealInputNodeIndexInTupleGetItem = 1;
|
||||
constexpr auto kInputNodeOutputIndexInTupleGetItem = 2;
|
||||
constexpr auto kSparseGetAttrInputSize = 2;
|
||||
constexpr auto kTupleGetItemInputSize = 3;
|
||||
// index define of partial
|
||||
constexpr auto kPartialMinInputSize = 2;
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <unordered_map>
|
||||
#include "runtime/graph_scheduler/control_node_parser.h"
|
||||
#include "runtime/graph_scheduler/actor/actor_common.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
|
@ -111,6 +112,25 @@ std::set<size_t> FetchRealIndexByAbstract(const AbstractBasePtr &abstract, std::
|
|||
default:
|
||||
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
|
||||
}
|
||||
} else if (abstract->isa<abstract::AbstractCOOTensor>()) {
|
||||
auto coo_abs = abstract->cast<abstract::AbstractCOOTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(coo_abs);
|
||||
switch (index) {
|
||||
case kCooTensorIndicesIndex:
|
||||
dst_abstract = coo_abs->indices();
|
||||
pre_abstract_num = kCooTensorIndicesIndex;
|
||||
break;
|
||||
case kCooTensorValuesIndex:
|
||||
dst_abstract = coo_abs->values();
|
||||
pre_abstract_num = kCooTensorValuesIndex;
|
||||
break;
|
||||
case kCooTensorDenseShapeIndex:
|
||||
dst_abstract = coo_abs->values();
|
||||
pre_abstract_num = kCooTensorDenseShapeIndex;
|
||||
break;
|
||||
default:
|
||||
MS_LOG(EXCEPTION) << "Invalid index:" << index << " for abstract:" << abstract->ToString();
|
||||
}
|
||||
} else if (abstract->isa<abstract::AbstractTuple>()) {
|
||||
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
|
@ -394,8 +414,11 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
|||
|
||||
// The node is divided into the following types:
|
||||
// 1. depend and load.
|
||||
const auto &node_with_index =
|
||||
common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem, prim::kPrimMakeTuple});
|
||||
const auto &node_with_index = common::AnfAlgo::VisitKernelWithReturnType(
|
||||
node, 0, false,
|
||||
{prim::kPrimTupleGetItem, prim::kPrimMakeTuple, prim::kPrimCSRTensorGetIndptr, prim::kPrimCSRTensorGetIndices,
|
||||
prim::kPrimCSRTensorGetValues, prim::kPrimCSRTensorGetDenseShape, prim::kPrimCOOTensorGetIndices,
|
||||
prim::kPrimCOOTensorGetValues, prim::kPrimCOOTensorGetDenseShape});
|
||||
auto real_node = node_with_index.first;
|
||||
size_t real_index = node_with_index.second;
|
||||
MS_EXCEPTION_IF_NULL(real_node);
|
||||
|
@ -435,12 +458,14 @@ std::vector<KernelWithIndex> FetchInputNodeByNode(const AnfNodePtr &node) {
|
|||
common::AnfAlgo::CheckPrimitiveType(src_node, prim::kPrimMakeCOOTensor)) {
|
||||
const auto &make_tensor_cnode = src_node->cast<CNodePtr>();
|
||||
const auto &make_tensor_inputs = make_tensor_cnode->inputs();
|
||||
if (make_tensor_inputs.size() <= kMakeCSRTensorInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Invalid make csr tensor node:" << cnode->DebugString();
|
||||
if (make_tensor_inputs.size() <= kMakeCSRTensorInputNum && make_tensor_inputs.size() <= kMakeCOOTensorInputNum) {
|
||||
MS_LOG(EXCEPTION) << "Invalid make sparse tensor node:" << cnode->DebugString();
|
||||
}
|
||||
const auto &sub_results =
|
||||
FetchInputNodeByNode(make_tensor_inputs[LongToSize(iter->second) + kMakeTensorInputStartPos]);
|
||||
(void)results.insert(results.end(), sub_results.begin(), sub_results.end());
|
||||
} else if (src_node->isa<Parameter>()) {
|
||||
results.emplace_back(src_node, iter->second);
|
||||
} else {
|
||||
// Csr node from parameter or call node.
|
||||
auto abstract = src_node->abstract();
|
||||
|
|
|
@ -66,6 +66,7 @@ constexpr size_t kCooTensorDenseShapeIndex = 2;
|
|||
constexpr size_t kMakeCSRTensorInputStartPos = 1;
|
||||
constexpr size_t kMakeTensorInputStartPos = 1;
|
||||
constexpr size_t kMakeCSRTensorInputNum = 4;
|
||||
constexpr size_t kMakeCOOTensorInputNum = 3;
|
||||
|
||||
using NodeWithContext = std::pair<AnfNodePtr, DeviceContext *>;
|
||||
struct NodeWithContextCmp {
|
||||
|
|
|
@ -123,6 +123,7 @@ class IrExportBuilder {
|
|||
bool SetParamToTensorProto(const ParameterPtr ¶m, mind_ir::TensorProto *const tensor_proto);
|
||||
bool SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto);
|
||||
bool SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetCOOTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
|
||||
bool SetAttributeProto(const AnfNodePtr &node, mind_ir::NodeProto *const node_proto);
|
||||
bool SetAbstractToNodeProto(const CNodePtr &node, mind_ir::NodeProto *const node_proto);
|
||||
bool SetAbstractToNodeProto(const abstract::AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto);
|
||||
|
@ -549,6 +550,16 @@ bool IrExportBuilder::SetCSRTensorToProto(const AbstractBasePtr &abstract, mind_
|
|||
return SetAbstractToNodeProto(csr_tensor_abs->dense_shape(), dense_proto);
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetCOOTensorToProto(const AbstractBasePtr &abstract, mind_ir::AttributeProto *const attr_proto) {
|
||||
abstract::AbstractCOOTensorPtr coo_tensor_abs = abstract->cast<abstract::AbstractCOOTensorPtr>();
|
||||
MS_EXCEPTION_IF_NULL(coo_tensor_abs);
|
||||
attr_proto->set_type(mind_ir::AttributeProto_AttributeType_COO_TENSOR);
|
||||
(void)SetTensorProto(coo_tensor_abs->indices(), attr_proto->add_tensors());
|
||||
(void)SetTensorProto(coo_tensor_abs->values(), attr_proto->add_tensors());
|
||||
auto dense_proto = attr_proto->add_values();
|
||||
return SetAbstractToNodeProto(coo_tensor_abs->dense_shape(), dense_proto);
|
||||
}
|
||||
|
||||
bool IrExportBuilder::SetTensorProto(const AbstractBasePtr &abstract, mind_ir::TensorProto *const tensor_proto) {
|
||||
auto type = abstract->BuildType();
|
||||
auto shape = abstract->BuildShape();
|
||||
|
@ -714,6 +725,11 @@ bool IrExportBuilder::SetAbstractToNodeProto(const AbstractBasePtr &abs, mind_ir
|
|||
if (!SetCSRTensorToProto(csr_tensor_abs, attr_proto)) {
|
||||
return false;
|
||||
}
|
||||
} else if (type->isa<COOTensorType>()) {
|
||||
auto coo_tensor_abs = abs->cast<abstract::AbstractCOOTensorPtr>();
|
||||
if (!SetCOOTensorToProto(coo_tensor_abs, attr_proto)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Type of cnode need to be supported: " << type->type_name();
|
||||
return false;
|
||||
|
|
|
@ -37,6 +37,12 @@ namespace {
|
|||
constexpr size_t kNopNodeInputSize = 2;
|
||||
constexpr size_t kNopNodeRealInputIndex = 1;
|
||||
|
||||
const PrimitiveSet expand_prims{
|
||||
prim::kPrimMakeTuple,
|
||||
prim::kPrimMakeCSRTensor,
|
||||
prim::kPrimMakeCOOTensor,
|
||||
prim::kPrimMakeRowTensor,
|
||||
};
|
||||
const std::set<std::string> kNodeTupleOutSet = {prim::kMakeTuple, prim::kGetNext};
|
||||
|
||||
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
|
||||
|
@ -106,12 +112,6 @@ void GetRealOutputRecursively(const AnfNodePtr &node, size_t output_index, std::
|
|||
std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node) {
|
||||
std::vector<KernelWithIndex> ret;
|
||||
std::vector<KernelWithIndex> ret_empty;
|
||||
const PrimitiveSet expand_prims{
|
||||
prim::kPrimMakeTuple,
|
||||
prim::kPrimMakeCSRTensor,
|
||||
prim::kPrimMakeCOOTensor,
|
||||
prim::kPrimMakeRowTensor,
|
||||
};
|
||||
// The MakeTuple/MakeSparse node need expand and recurse.
|
||||
if (IsOneOfPrimitiveCNode(node, expand_prims)) {
|
||||
auto make_tuple = node->cast<CNodePtr>();
|
||||
|
@ -157,7 +157,7 @@ std::vector<KernelWithIndex> GetAllOutputWithIndexInner(const AnfNodePtr &node)
|
|||
|
||||
// If the node is a call, the outputs num should get from the abstract.
|
||||
if (AnfAlgo::IsCallNode(node) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem) ||
|
||||
AnfAlgo::CheckAbsCSRTensor(node)) {
|
||||
AnfAlgo::CheckAbsCSRTensor(node) || AnfAlgo::CheckAbsCOOTensor(node)) {
|
||||
outputs_num = AnfAlgo::GetOutputNumByAbstract(node->abstract());
|
||||
}
|
||||
|
||||
|
@ -211,16 +211,28 @@ bool IsNodeDynamicShape(const AnfNodePtr &node) {
|
|||
|
||||
AnfNodePtr AnfAlgo::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
if (tuple_get_item->size() != kTupleGetItemInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
||||
if (CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
|
||||
if (tuple_get_item->size() != kTupleGetItemInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
||||
}
|
||||
} else if (tuple_get_item->size() != kSparseGetAttrInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node sparse_get_attribute must have 1 input!";
|
||||
}
|
||||
return tuple_get_item->input(kRealInputNodeIndexInTupleGetItem);
|
||||
}
|
||||
|
||||
size_t AnfAlgo::GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_get_item);
|
||||
if (tuple_get_item->size() != kTupleGetItemInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
||||
if (CheckPrimitiveType(tuple_get_item, prim::kPrimTupleGetItem)) {
|
||||
if (tuple_get_item->size() != kTupleGetItemInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
|
||||
}
|
||||
} else if (tuple_get_item->size() != kSparseGetAttrInputSize) {
|
||||
MS_LOG(EXCEPTION) << "The node sparse_get_attribute must have 1 input!";
|
||||
}
|
||||
std::string prim_name = GetCNodeFuncName(tuple_get_item);
|
||||
if (sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
|
||||
return sparse_attr_map.at(prim_name);
|
||||
}
|
||||
auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(output_index_value_node);
|
||||
|
@ -248,11 +260,13 @@ KernelWithIndex AnfAlgo::VisitKernelWithReturnType(const AnfNodePtr &anf_node, s
|
|||
}
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem)) {
|
||||
std::string prim_name = GetCNodeFuncName(cnode);
|
||||
// TupleGetItem and SparseGetAttr needs to find real input
|
||||
if (CheckPrimitiveType(cnode, prim::kPrimTupleGetItem) || sparse_attr_map.find(prim_name) != sparse_attr_map.end()) {
|
||||
abstract::AbstractBasePtr abs = nullptr;
|
||||
auto item_with_index_tmp = VisitKernelWithReturnType(
|
||||
GetTupleGetItemRealInput(cnode), GetTupleGetItemOutIndex(cnode), skip_nop_node, return_types, &abs);
|
||||
if (CheckPrimitiveType(item_with_index_tmp.first, prim::kPrimMakeTuple)) {
|
||||
if (IsOneOfPrimitiveCNode(item_with_index_tmp.first, expand_prims)) {
|
||||
MS_EXCEPTION_IF_NULL(item_with_index_tmp.first);
|
||||
auto make_tuple = item_with_index_tmp.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(make_tuple);
|
||||
|
|
|
@ -73,6 +73,12 @@ inline void CheckSparseIndicesDtype(const mindspore::TypePtr data_type, const st
|
|||
<< data_type->ToString() << ".";
|
||||
}
|
||||
}
|
||||
inline void CheckSparseIndicesDtypeInt32(const mindspore::TypePtr data_type, const std::string &arg_name) {
|
||||
if (!data_type->equal(mindspore::kInt32)) {
|
||||
MS_EXCEPTION(mindspore::TypeError) << "The dtype of " << arg_name << " only support Int32 for now, but got "
|
||||
<< data_type->ToString() << ".";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -485,6 +491,11 @@ AbstractBasePtr InferImplCSRElementWise(const AnalysisEnginePtr &, const Primiti
|
|||
MS_EXCEPTION_IF_NULL(sparse->indices());
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
|
||||
auto indptr = sparse->indptr();
|
||||
auto indices = sparse->indices();
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
auto sparse_shape = sparse->shape()->shape();
|
||||
auto dense_shape = dense->shape()->shape();
|
||||
CheckSparseShape(sparse_shape, dense_shape);
|
||||
|
@ -514,6 +525,11 @@ AbstractBasePtr InferImplCSRMV(const AnalysisEnginePtr &, const PrimitivePtr &pr
|
|||
MS_EXCEPTION_IF_NULL(sparse->indices());
|
||||
MS_EXCEPTION_IF_NULL(dense);
|
||||
|
||||
auto indptr = sparse->indptr();
|
||||
auto indices = sparse->indices();
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
auto sparse_shape = sparse->shape()->shape();
|
||||
auto dense_shape = dense->shape()->shape();
|
||||
if (sparse_shape.size() != kCSRMVShapeSize || dense_shape.size() != kCSRMVShapeSize) {
|
||||
|
@ -553,8 +569,14 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
|
|||
MS_EXCEPTION_IF_NULL(sparse->shape());
|
||||
MS_EXCEPTION_IF_NULL(sparse->values());
|
||||
MS_EXCEPTION_IF_NULL(sparse->indices());
|
||||
MS_EXCEPTION_IF_NULL(sparse->indptr());
|
||||
MS_EXCEPTION_IF_NULL(axis);
|
||||
|
||||
auto indptr = sparse->indptr();
|
||||
auto indices = sparse->indices();
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
auto sparse_shape = sparse->shape()->shape();
|
||||
if (sparse_shape.size() != kCSRReduceSumShapeSize) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRReduceSumShapeSize << "-D inputs!"
|
||||
|
@ -566,7 +588,7 @@ AbstractBasePtr InferImplCSRReduceSum(const AnalysisEnginePtr &, const Primitive
|
|||
int64_t axis_value = GetValue<int64_t>(axis->BuildValue());
|
||||
int64_t dim = static_cast<int64_t>(sparse_shape.size());
|
||||
if (axis_value < -dim || axis_value >= dim || (axis_value != 1 && axis_value != -1)) {
|
||||
MS_EXCEPTION(ValueError) << "For CSRReduceSum, `axis` should be -1 or 1. But got `axis`: " << axis_value;
|
||||
MS_EXCEPTION(TypeError) << "For CSRReduceSum, `axis` should be -1 or 1. But got `axis`: " << axis_value;
|
||||
}
|
||||
if (axis_value < 0) {
|
||||
axis_value += dim;
|
||||
|
@ -605,6 +627,9 @@ AbstractBasePtr InferImplCSRGather(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
MS_EXCEPTION_IF_NULL(dense);
|
||||
MS_EXCEPTION_IF_NULL(sparse_shape);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
CheckSparseIndicesDtypeInt32(indices->element()->BuildType(), "Indices");
|
||||
|
||||
if (sparse_shape->size() != kCSRShapeSize) {
|
||||
MS_EXCEPTION(ValueError) << "Currently, only support " << kCSRShapeSize << "-D inputs!"
|
||||
<< "But sparse tensor has " << sparse_shape->size() << " dimensions.";
|
||||
|
@ -631,6 +656,8 @@ AbstractBasePtr InferImplCSR2COO(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, kCSRArgsSize);
|
||||
auto indptr = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
CheckSparseIndicesDtypeInt32(indptr->element()->BuildType(), "Indptr");
|
||||
|
||||
auto nnz = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
MS_EXCEPTION_IF_NULL(indptr);
|
||||
MS_EXCEPTION_IF_NULL(nnz);
|
||||
|
@ -665,7 +692,7 @@ AbstractBasePtr InferImplCOO2CSR(const AnalysisEnginePtr &, const PrimitivePtr &
|
|||
auto height = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||
MS_EXCEPTION_IF_NULL(row_indices);
|
||||
MS_EXCEPTION_IF_NULL(height);
|
||||
|
||||
CheckSparseIndicesDtypeInt32(row_indices->element()->BuildType(), "row_indices");
|
||||
MS_EXCEPTION_IF_NULL(height->BuildValue());
|
||||
ShapeVector out_shape;
|
||||
if (height->BuildValue()->isa<Int32Imm>() || height->BuildValue()->isa<Int64Imm>()) {
|
||||
|
|
|
@ -126,7 +126,8 @@ TypePtr TypeIdToType(TypeId id) {
|
|||
{kObjectTypeIOMonad, kIOMonadType},
|
||||
{kTypeUnknown, kTypeNone},
|
||||
{kMetaTypeProblem, kTypeNone},
|
||||
{kObjectTypeCSRTensorType, kCSRTensorType}};
|
||||
{kObjectTypeCSRTensorType, kCSRTensorType},
|
||||
{kObjectTypeCOOTensorType, kCOOTensorType}};
|
||||
const auto &it = type_id_to_type.find(id);
|
||||
if (it == type_id_to_type.end()) {
|
||||
MS_LOG(EXCEPTION) << "Not support the type: " << GetExcptionTypeString(id);
|
||||
|
|
|
@ -1930,8 +1930,7 @@ def csr_abs(x):
|
|||
|
||||
|
||||
def csr_mv(x, dense_vector):
|
||||
"""Implementation of `abs` for CSRTensor."""
|
||||
check_value_type('dense_vector', dense_vector, (Tensor,), 'CSRTensor.mv')
|
||||
"""Implementation of `mv` for CSRTensor."""
|
||||
return F.csr_mv(x, dense_vector)
|
||||
|
||||
|
||||
|
|
|
@ -2413,18 +2413,13 @@ class SparseTensor(COOTensor_):
|
|||
(3, 4)
|
||||
"""
|
||||
|
||||
def __init__(self, indices=None, values=None, shape=None, coo_tensor=None):
|
||||
def __init__(self, indices, values, shape):
|
||||
"Init COOTensor"
|
||||
print("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " +
|
||||
"Please use 'COOTensor' instead.")
|
||||
if indices is None and values is None and shape is None and coo_tensor is not None:
|
||||
if not isinstance(coo_tensor, (COOTensor, COOTensor_)):
|
||||
raise TypeError("If only one input provided, it must be a COOTensor.")
|
||||
COOTensor_.__init__(self, coo_tensor)
|
||||
else:
|
||||
if not (isinstance(indices, Tensor) and isinstance(values, Tensor) and isinstance(shape, tuple)):
|
||||
raise TypeError("Inputs must follow: COOTensor(indices, values, shape).")
|
||||
COOTensor_.__init__(self, indices, values, shape)
|
||||
if not (isinstance(indices, Tensor) and isinstance(values, Tensor) and isinstance(shape, tuple)):
|
||||
raise TypeError("Inputs must follow: COOTensor(indices, values, shape).")
|
||||
COOTensor_.__init__(self, indices, values, shape)
|
||||
|
||||
@property
|
||||
def indices(self):
|
||||
|
|
|
@ -1023,7 +1023,7 @@ def print_info(info):
|
|||
|
||||
def make_sparse_tensor(indices, values, dense_shape):
|
||||
"""Call make_coo_tensor in this function."""
|
||||
print_info("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " +
|
||||
print_info("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " + \
|
||||
"Please use 'COOTensor' instead.")
|
||||
return make_coo_tensor(indices, values, dense_shape)
|
||||
|
||||
|
|
|
@ -59,6 +59,60 @@ def test_make_coo():
|
|||
compare_coo(coo3, coo2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_coo_tensor_with_control_if():
|
||||
"""
|
||||
Feature: Test COOTensor in if.
|
||||
Description: Test COOTensor computation in while loop.
|
||||
Expectation: Success.
|
||||
"""
|
||||
class COOTensorValuesDouble(nn.Cell):
|
||||
|
||||
def construct(self, x):
|
||||
indices = x.indices
|
||||
values = x.values * 2
|
||||
shape = x.shape
|
||||
return COOTensor(indices, values, shape)
|
||||
|
||||
class COOTensorValuesAdd2(nn.Cell):
|
||||
|
||||
def construct(self, x):
|
||||
indices = x.indices
|
||||
values = x.values + 2
|
||||
shape = x.shape
|
||||
return COOTensor(indices, values, shape)
|
||||
|
||||
class COOTensorWithControlIf(nn.Cell):
|
||||
def __init__(self, shape):
|
||||
super(COOTensorWithControlIf, self).__init__()
|
||||
self.op1 = COOTensorValuesDouble()
|
||||
self.op2 = COOTensorValuesAdd2()
|
||||
self.shape = shape
|
||||
|
||||
def construct(self, a, b, indices, values):
|
||||
x = COOTensor(indices, values, self.shape)
|
||||
if a > b:
|
||||
x = self.op1(x)
|
||||
else:
|
||||
x = self.op2(x)
|
||||
return x.indices, x.values, x.shape
|
||||
|
||||
a = Tensor(0, mstype.int32)
|
||||
b = Tensor(2, mstype.int32)
|
||||
indices = Tensor([[0, 1], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=mstype.float32)
|
||||
shape = (3, 4)
|
||||
net = COOTensorWithControlIf(shape)
|
||||
out = net(a, b, indices, values)
|
||||
assert np.allclose(out[0].asnumpy(), indices.asnumpy(), .0, .0)
|
||||
assert np.allclose(out[1].asnumpy(), values.asnumpy() + 2, .0, .0)
|
||||
assert out[2] == shape
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
|
|
Loading…
Reference in New Issue