!3114 add coo_tensor

Merge pull request !3114 from riemann_penn/coo_tensor
mindspore-ci-bot 2020-07-18 10:24:58 +08:00 committed by Gitee
commit 4a19e6b8cb
36 changed files with 652 additions and 76 deletions

View File

@@ -17,7 +17,7 @@
 """Resources for ast tree parse."""
 import ast
 import math
-from mindspore import IndexedSlices
+from mindspore import IndexedSlices, SparseTensor
 from mindspore.ops.composite import multitype_ops
 from mindspore.ops import functional as F, composite as C
 from . import standard_method as M
@@ -140,4 +140,5 @@ convert_object_map = {
     # user defined
     IndexedSlices: F.make_indexed_slices,
+    SparseTensor: F.make_sparse_tensor,
 }
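
With this mapping in place, a SparseTensor constructor written inside a graph-mode cell lowers to F.make_sparse_tensor. A minimal sketch of what that enables, mirroring the test added at the end of this commit (the cell name MakeCOO is illustrative only):

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, SparseTensor, context

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class MakeCOO(nn.Cell):
    def __init__(self):
        super(MakeCOO, self).__init__()
        self.dense_shape = (3, 4)

    def construct(self, indices, values):
        # The parser rewrites this constructor call to F.make_sparse_tensor
        # through the convert_object_map entry above.
        return SparseTensor(indices, values, self.dense_shape)

MakeCOO()(Tensor([[0, 1], [1, 2]]), Tensor([1, 2], dtype=ms.float32))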

View File

@@ -124,6 +124,8 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
     // Do Nothing
   } else if (type->isa<UndeterminedType>()) {
     // Do Nothing
+  } else if (type->isa<SparseTensorType>()) {
+    // Do Nothing
   } else if (type->isa<Tuple>()) {
     TuplePtr tuple_type = dyn_cast<Tuple>(type);
     type_proto->set_data_type(irpb::DT_TUPLE);

View File

@@ -803,6 +803,18 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
   abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
   abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
   if (a_tuple == nullptr || b_tuple == nullptr) {
+    TypePtrList types;
+    (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
+                         [](const AbstractBasePtr &arg) -> TypePtr {
+                           MS_EXCEPTION_IF_NULL(arg);
+                           return arg->BuildType();
+                         });
+    auto stub = GenerateStubFunc(types);
+    if (stub != nullptr) {
+      MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
+                    << ", function: " << stub->ToString();
+      return stub;
+    }
     MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", "
                       << args_spec_list[1]->ToString();
   }

View File

@@ -119,42 +119,6 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
   return py::none();
 }
-
-FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse = context->enable_sparse();
-  if (!enable_sparse) {
-    return nullptr;
-  }
-
-  std::vector<AnfNodePtr> parameters;
-  ParameterPtr undetermined_param = nullptr;
-  auto stub = std::make_shared<FuncGraph>();
-  for (size_t i = 0; i < types.size(); ++i) {
-    auto param = stub->add_parameter();
-    parameters.push_back(param);
-    if (types[i]->type_id() == kObjectTypeUndeterminedType) {
-      undetermined_param = param;
-    }
-  }
-  if (undetermined_param != nullptr) {
-    std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
-    for (size_t i = 0; i < types.size(); ++i) {
-      if (types[i]->type_id() == kObjectTypeFunction) {
-        std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
-        inputs.push_back(stub->NewCNode(call_prim));
-      } else {
-        inputs.push_back(parameters[i]);
-      }
-    }
-    auto stub_output = stub->NewCNode(inputs);
-    stub->set_output(stub_output);
-    stub->set_stub(true);
-    return stub;
-  }
-  return nullptr;
-}
-
 FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
   auto py_fn = SignMatch(types);
   std::ostringstream buffer;

View File

@@ -283,6 +283,11 @@ const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeInd
 const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
 const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
 const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
-const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
+
+// SparseTensor
+const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
+const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
+const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
+const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
 }  // namespace prim
 }  // namespace mindspore

View File

@@ -292,7 +292,12 @@ extern const PrimitivePtr kPrimMakeIndexedSlices;
 extern const PrimitivePtr kPrimIndexedSlicesGetValues;
 extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
 extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
-extern const PrimitivePtr kPrimIsIndexedSlices;
+
+// SparseTensor
+extern const PrimitivePtr kPrimMakeSparseTensor;
+extern const PrimitivePtr kPrimSparseTensorGetValues;
+extern const PrimitivePtr kPrimSparseTensorGetIndices;
+extern const PrimitivePtr kPrimSparseTensorGetDenseShape;

 // attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
 const char SWITCH_UNROLL_FLAG[] = "unroll_flag";

View File

@@ -349,6 +349,26 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
   auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
   auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
+
+  auto indices_dtype = indices->element()->BuildType();
+  if (!indices_dtype->isa<Int>()) {
+    MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
+  }
+  auto indices_shp = indices->shape()->shape();
+  if (indices_shp.size() != 1) {
+    MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
+                            << " dimension tensor";
+  }
+  auto values_shp = values->shape()->shape();
+  if (indices_shp[0] != values_shp[0]) {
+    MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
+                            << values_shp[0] << ", but got " << indices_shp[0];
+  }
+  for (auto elem_type : dense_shape->ElementsType()) {
+    if (!elem_type->isa<Int>()) {
+      MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
+    }
+  }
   auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
   MS_EXCEPTION_IF_NULL(dense_shape_value);
   auto shp = dense_shape_value->value();
@@ -358,6 +378,12 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
     auto elem = GetValue<int>(e);
     return elem;
   });
+  for (auto dense_shape_elem : dense_shape_vec) {
+    if (dense_shape_elem < 0) {
+      MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
+                              << dense_shape_value->ToString();
+    }
+  }
   auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
   ret->set_indices(indices);
   ret->set_values(values);
@@ -395,16 +421,89 @@ AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, c
   return indexed_slices->dense_shape();
 }

-AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                         const AbstractBasePtrList &args_spec_list) {
-  const std::string op_name = primitive->name();
-  CheckArgsSize(op_name, args_spec_list, 1);
-  bool ret = false;
-  if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
-    ret = true;
-  }
-  MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
-  return std::make_shared<AbstractScalar>(ret);
+AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                          const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors and a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
+
+  auto indices_dtype = indices->element()->BuildType();
+  if (!indices_dtype->isa<Int>()) {
+    MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
+  }
+  auto indices_shp = indices->shape()->shape();
+  if (indices_shp.size() != 2) {
+    MS_EXCEPTION(TypeError) << "Indices must be a 2 dimension tensor, but got a " << indices_shp.size()
+                            << " dimension tensor";
+  }
+  auto values_shp = values->shape()->shape();
+  if (values_shp.size() != 1) {
+    MS_EXCEPTION(TypeError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
+                            << " dimension tensor";
+  }
+  if (indices_shp[0] != values_shp[0]) {
+    MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
+                            << values_shp[0] << ", but got " << indices_shp[0];
+  }
+  for (auto elem_type : dense_shape->ElementsType()) {
+    if (!elem_type->isa<Int>()) {
+      MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
+    }
+  }
+  auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(dense_shape_value);
+  auto shp = dense_shape_value->value();
+  std::vector<int> dense_shape_vec;
+  (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
+                       [](const ValuePtr &e) -> int {
+                         auto elem = GetValue<int>(e);
+                         return elem;
+                       });
+  for (auto dense_shape_elem : dense_shape_vec) {
+    if (dense_shape_elem < 0) {
+      MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
+                              << dense_shape_value->ToString();
+    }
+  }
+  auto ret = std::make_shared<AbstractSparseTensor>(values->element()->BuildType(), dense_shape_vec);
+  ret->set_indices(indices);
+  ret->set_values(values);
+  ret->set_dense_shape(dense_shape);
+  return ret;
+}
+
+AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one sparse tensor.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(sparse_tensor->values());
+  return sparse_tensor->values();
+}
+
+AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one sparse tensor.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
+  return sparse_tensor->indices();
+}
+
+AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                   const AbstractBasePtrList &args_spec_list) {
+  // Inputs: one sparse tensor.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape());
+  return sparse_tensor->dense_shape();
 }
 }  // namespace abstract
 }  // namespace mindspore
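
To summarize the contract InferImplMakeSparseTensor enforces above, in COO terms: indices is a 2-D integer tensor of shape [N, ndims], values is a 1-D tensor of shape [N], and dense_shape is a tuple of integers with no negative entries. A hedged Python illustration (tensor values taken from the tests later in this commit; the checks fire during graph compilation, not at tensor construction):

import mindspore as ms
from mindspore import Tensor

indices = Tensor([[0, 1], [1, 2]])         # 2-D int tensor of shape [N, ndims]; a 1-D tensor raises TypeError
values = Tensor([1, 2], dtype=ms.float32)  # 1-D tensor; values.shape[0] must equal indices.shape[0]
dense_shape = (3, 4)                       # tuple of ints, one per dense dimension; negatives are rejected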

View File

@@ -264,7 +264,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re
     return IsPrimitiveCNode(user.first, prim);
   });
   if (cnode == users.end()) {
-    MS_LOG(EXCEPTION) << "Fail to find cnode.";
+    MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
   }
   auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;

View File

@@ -43,6 +43,7 @@
 #include "frontend/optimizer/irpass/transpose_eliminate.h"
 #include "frontend/optimizer/opt.h"
 #include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
+#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"

 namespace mindspore {
 namespace opt {
@@ -159,6 +160,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   indexed_slices_eliminate_ = MakeSubstitution(
     std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
     {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
+
+  // SparseTensor Eliminate
+  sparse_tensor_eliminate_ = MakeSubstitution(
+    std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
+    {prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
 }

 ResolveIRPassLib::ResolveIRPassLib() {
ResolveIRPassLib::ResolveIRPassLib() { ResolveIRPassLib::ResolveIRPassLib() {

View File

@@ -107,6 +107,9 @@ class OptimizeIRPassLib {
   // IndexedSlices Eliminate
   SubstitutionPtr indexed_slices_eliminate_;
+
+  // SparseTensor Eliminate
+  SubstitutionPtr sparse_tensor_eliminate_;
 };

 // the collection of irpass for resolve action

View File

@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
+
+#include <vector>
+#include <algorithm>
+
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}}
+// {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}}
+// {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}}
+class SparseTensorEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(1);
+    }
+    AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(2);
+    }
+    AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(3);
+    }
+    return nullptr;
+  }
+
+  void Visit(const CNodePtr &cnode) override {
+    if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) {
+      tuple_ = cnode;
+      is_match_ = true;
+    }
+  }
+
+  void Reset() {
+    tuple_ = nullptr;
+    is_match_ = false;
+  }
+
+ private:
+  bool is_match_{false};
+  CNodePtr tuple_{nullptr};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
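
In CNode terms, input 0 of a MakeSparseTensor call is the primitive itself, so inputs 1, 2, and 3 are indices, values, and dense_shape; the pass folds each getter applied directly to a MakeSparseTensor node into the matching input. A sketch of the before/after graphs in the notation of the Python IR tests added later in this commit:

from mindspore.ops import Primitive

make_sparse_tensor = Primitive('MakeSparseTensor')
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')

def before(x, y, z):
    # GetIndices applied directly to MakeSparseTensor(x, y, z) ...
    return sparse_tensor_get_indices(make_sparse_tensor(x, y, z))

def after(x, y, z):
    # ... is rewritten to the first constructor input.
    return x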

View File

@@ -157,6 +157,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.make_ref_eliminate_,
     irpass.get_ref_param_eliminate_,
     irpass.indexed_slices_eliminate_,
+    irpass.sparse_tensor_eliminate_,
   });
   OptPassGroupMap map({
     {"b_1", b_1},

View File

@@ -179,6 +179,12 @@ MethodMap &GetMethodMap() {
      {"indices", prim::kPrimIndexedSlicesGetIndices},          // F.indexed_slices_get_indices
      {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape},   // F.indexed_slices_get_dense_shape
    }},
+  {kObjectTypeSparseTensorType,
+   {
+     {"values", prim::kPrimSparseTensorGetValues},             // F.sparse_tensor_get_values
+     {"indices", prim::kPrimSparseTensorGetIndices},           // F.sparse_tensor_get_indices
+     {"dense_shape", prim::kPrimSparseTensorGetDenseShape},    // F.sparse_tensor_get_dense_shape
+   }},
   {kObjectTypeJTagged, {}},
   {kObjectTypeSymbolicKeyType, {}},
   {kObjectTypeEnvType, {}}};
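
These entries are what make attribute-style access work in graph mode: calling x.values() on a SparseTensor resolves through the method map to SparseTensorGetValues, and likewise for indices and dense_shape. A minimal sketch, adapted from the test at the end of this commit:

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, SparseTensor, context

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class SparseTensorGetAttr(nn.Cell):
    def __init__(self):
        super(SparseTensorGetAttr, self).__init__()
        self.dense_shape = (3, 4)

    def construct(self, indices, values):
        x = SparseTensor(indices, values, self.dense_shape)
        # Each call resolves to SparseTensorGet{Values,Indices,DenseShape}.
        return x.values(), x.indices(), x.dense_shape()

SparseTensorGetAttr()(Tensor([[0, 1], [1, 2]]), Tensor([1, 2], dtype=ms.float32))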

View File

@@ -138,7 +138,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
     {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
     {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
-    {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}},
+    // SparseTensor
+    {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
+    {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
+    {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
+    {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
   };
   return prim_eval_implement_map;
 }

View File

@@ -358,7 +358,13 @@ AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, cons
                                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                     const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                          const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                    const AbstractBasePtrList &args_spec_list);
 }  // namespace abstract
 }  // namespace mindspore

View File

@@ -36,6 +36,7 @@ using mindspore::abstract::AbstractIndexedSlices;
 using mindspore::abstract::AbstractJTagged;
 using mindspore::abstract::AbstractList;
 using mindspore::abstract::AbstractScalar;
+using mindspore::abstract::AbstractSparseTensor;
 using mindspore::abstract::AbstractTensor;
 using mindspore::abstract::AbstractTuple;
 using mindspore::abstract::AbstractType;
@@ -95,7 +96,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
   if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
       ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
-      ptrBase->isa<abstract::AbstractRefKey>()) {
+      ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
     return;
   }

View File

@@ -17,10 +17,10 @@ from . import dtype
 from .api import ms_function
 from .dtype import *
 from .parameter import Parameter, ParameterTuple
-from .tensor import MetaTensor, Tensor, IndexedSlices
+from .tensor import MetaTensor, Tensor, IndexedSlices, SparseTensor

 __all__ = [
-    "MetaTensor", "Tensor", "IndexedSlices",  # tensor
+    "MetaTensor", "Tensor", "IndexedSlices", "SparseTensor",  # tensor
     'ms_function',  # api
     'Parameter', 'ParameterTuple',  # parameter
     "dtype"

View File

@@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
 from . import dtype as mstype
 from ._register_for_tensor import tensor_operator_registry

-__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']
+__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices', 'SparseTensor']
 np_types = (np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
             np.float32, np.float64, np.bool_)
@@ -211,3 +211,7 @@ class Tensor(Tensor_):
 class IndexedSlices:
     def __init__(self, indices, values, dense_shape):
         raise NotImplementedError
+
+
+class SparseTensor:
+    def __init__(self, indices, values, dense_shape):
+        raise NotImplementedError
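
Like IndexedSlices above it, the Python-level SparseTensor is only a placeholder at this point in the project: real construction happens in the graph compiler, so calling the class eagerly raises NotImplementedError. A quick sanity check:

import mindspore as ms
from mindspore import Tensor, SparseTensor

try:
    SparseTensor(Tensor([[0, 1]]), Tensor([1.0], dtype=ms.float32), (3, 4))
except NotImplementedError:
    pass  # usable only inside a graph (ms_function / Cell.construct)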

View File

@@ -1093,5 +1093,64 @@ std::string AbstractIndexedSlices::ToString() const {
          << ", dense_shape: " << dense_shape_->ToString();
   return buffer.str();
 }
+
+// SparseTensor
+TypePtr AbstractSparseTensor::BuildType() const {
+  MS_EXCEPTION_IF_NULL(element());
+  TypePtr element_type = element()->BuildType();
+  return std::make_shared<SparseTensorType>(element_type);
+}
+
+AbstractBasePtr AbstractSparseTensor::Clone() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto clone = std::make_shared<AbstractSparseTensor>(element()->Clone());
+  ShapePtr shp = shape();
+  clone->set_shape(shp->Clone());
+  clone->set_value(GetValueTrack());
+  clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return clone;
+}
+
+AbstractBasePtr AbstractSparseTensor::Broaden() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
+  auto shp = shape();
+  broaden->set_shape(shp->Clone());
+  broaden->set_value(kAnyValue);
+  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return broaden;
+}
+
+AbstractBasePtr AbstractSparseTensor::BroadenWithShape() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
+  auto shp = shape()->Clone();
+  shp->Broaden();
+  broaden->set_shape(shp);
+  broaden->set_value(kAnyValue);
+  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return broaden;
+}
+
+std::string AbstractSparseTensor::ToString() const {
+  std::ostringstream buffer;
+  BaseShapePtr shape_track = GetShapeTrack();
+  MS_EXCEPTION_IF_NULL(shape_track);
+  MS_EXCEPTION_IF_NULL(element());
+  auto value_track = GetValueTrack();
+  MS_EXCEPTION_IF_NULL(value_track);
+  buffer << type_name() << "("
+         << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
+         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
+         << ", indices: " << indices_->ToString() << ", values" << values_->ToString()
+         << ", dense_shape: " << dense_shape_->ToString();
+  return buffer.str();
+}
 }  // namespace abstract
 }  // namespace mindspore

View File

@@ -604,10 +604,39 @@ class AbstractIndexedSlices : public AbstractUndetermined {
   MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)

   const AbstractTensorPtr indices() const { return indices_; }
-  const AbstractTensorPtr values() const { return values_; }
-  const AbstractTuplePtr dense_shape() const { return dense_shape_; }
   void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
+  const AbstractTensorPtr values() const { return values_; }
   void set_values(const AbstractTensorPtr &values) { values_ = values; }
+  const AbstractTuplePtr dense_shape() const { return dense_shape_; }
+  void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
+
+  TypePtr BuildType() const override;
+  AbstractBasePtr Clone() const override;
+  AbstractBasePtr Broaden() const override;
+  AbstractBasePtr BroadenWithShape() const;
+  std::string ToString() const override;
+
+ private:
+  AbstractTensorPtr indices_;
+  AbstractTensorPtr values_;
+  AbstractTuplePtr dense_shape_;
+};
+
+// SparseTensor
+class AbstractSparseTensor : public AbstractUndetermined {
+ public:
+  explicit AbstractSparseTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
+      : AbstractUndetermined(element, shape) {}
+  AbstractSparseTensor(const TypePtr &element_type, const std::vector<int> &shape)
+      : AbstractUndetermined(element_type, shape) {}
+  ~AbstractSparseTensor() override = default;
+  MS_DECLARE_PARENT(AbstractSparseTensor, AbstractUndetermined)
+
+  const AbstractTensorPtr indices() const { return indices_; }
+  void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
+  const AbstractTensorPtr values() const { return values_; }
+  void set_values(const AbstractTensorPtr &values) { values_ = values; }
+  const AbstractTuplePtr dense_shape() const { return dense_shape_; }
   void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
   TypePtr BuildType() const override;
   AbstractBasePtr Clone() const override;
View File

@@ -67,6 +67,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Type)
 ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
 ABSTRACT_REPORT_NAME_TRAITS(Class)
 ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
+ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
 ABSTRACT_REPORT_NAME_TRAITS(Sequeue)

 template <typename T>
template <typename T> template <typename T>

View File

@@ -221,6 +221,48 @@ bool IndexedSlicesType::operator==(const Type &other) const {
   return *element_type_ == *other_elem_type;
 }
+
+TypePtr SparseTensorType::DeepCopy() const {
+  MS_EXCEPTION_IF_NULL(element_type_);
+  if (IsGeneric()) {
+    return std::make_shared<SparseTensorType>();
+  }
+  return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
+}
+
+std::string SparseTensorType::ToReprString() const {
+  if (element_type_ == nullptr) {
+    return "SparseTensor";
+  }
+  return "SparseTensor[" + element_type_->ToReprString() + "]";
+}
+
+std::string SparseTensorType::ToString() const {
+  if (element_type_ == nullptr) {
+    return "SparseTensor";
+  }
+  return "SparseTensor[" + element_type_->ToString() + "]";
+}
+
+std::string SparseTensorType::DumpText() const {
+  if (element_type_ == nullptr) {
+    return "SparseTensor";
+  }
+  return "SparseTensor[" + element_type_->DumpText() + "]";
+}
+
+bool SparseTensorType::operator==(const Type &other) const {
+  if (!IsSameObjectType(*this, other)) {
+    return false;
+  }
+  auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
+  if (element_type_ == nullptr && other_elem_type == nullptr) {
+    return true;
+  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
+    return false;
+  }
+  return *element_type_ == *other_elem_type;
+}
+
 Function::Function() : Object(kObjectTypeFunction) {
   args_ = std::vector<TypePtr>();
   retval_ = nullptr;

View File

@@ -177,6 +177,29 @@ class IndexedSlicesType : public Object {
 };
 using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;

+class SparseTensorType : public Object {
+ public:
+  SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
+  explicit SparseTensorType(const TypePtr &ele)
+      : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
+  ~SparseTensorType() override = default;
+  MS_DECLARE_PARENT(SparseTensorType, Object)
+
+  TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
+  const TypePtr element() const { return element_type_; }
+  void set_element(const TypePtr &element_type) { element_type_ = element_type; }
+
+  TypePtr DeepCopy() const override;
+  std::string ToString() const override;
+  std::string ToReprString() const override;
+  std::string DumpText() const override;
+  bool operator==(const Type &other) const override;
+
+ private:
+  TypePtr element_type_;
+};
+using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
+
 class Function : public Object {
  public:
   Function();

View File

@@ -117,6 +117,8 @@ const char *ObjectIdLabel(const TypeId &v) {
       return "kObjectTypeTensorType";
     case kObjectTypeIndexedSlicesType:
       return "kObjectTypeIndexedSlicesType";
+    case kObjectTypeSparseTensorType:
+      return "kObjectTypeSparseTensorType";
     case kObjectTypeUndeterminedType:
       return "kObjectTypeUndeterminedType";
     case kObjectTypeDictionary:

View File

@@ -51,6 +51,7 @@ enum TypeId : int {
   kObjectTypeKeyword,
   kObjectTypeTensorType,
   kObjectTypeIndexedSlicesType,
+  kObjectTypeSparseTensorType,
   kObjectTypeUndeterminedType,
   kObjectTypeClass,
   kObjectTypeDictionary,

View File

@@ -207,6 +207,23 @@ TypePtr IndexedSlicesStrToType(const std::string &type_name) {
   return std::make_shared<IndexedSlicesType>(element_type);
 }
+
+TypePtr SparseTensorStrToType(const std::string &type_name) {
+  if (type_name == "SparseTensor") {
+    return std::make_shared<SparseTensorType>();
+  }
+  auto start = type_name.find_first_of('[') + 1;
+  auto end = type_name.find_last_of(']');
+  if (start >= type_name.size()) {
+    return nullptr;
+  }
+  auto element_str = type_name.substr(start, end - start);
+  auto element_type = StringToType(element_str);
+  if (element_type == nullptr) {
+    return nullptr;
+  }
+  return std::make_shared<SparseTensorType>(element_type);
+}
+
 TypePtr UndeterminedStrToType(const std::string &type_name) {
   if (type_name == "Undetermined") {
     return std::make_shared<UndeterminedType>();
@@ -349,6 +366,8 @@ TypePtr StringToType(const std::string &type_name) {
     type = UndeterminedStrToType(type_name);
   } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
     type = IndexedSlicesStrToType(type_name);
+  } else if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) {
+    type = SparseTensorStrToType(type_name);
   } else if (type_name.compare(0, strlen("List"), "List") == 0) {
     type = ListStrToType(type_name);
   } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
@@ -428,6 +447,7 @@ const TypePtr kTypeEnv = std::make_shared<EnvType>();
 const TypePtr kTypeType = std::make_shared<TypeType>();
 const TypePtr kTensorType = std::make_shared<TensorType>();
 const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
+const TypePtr kSparseTensorType = std::make_shared<SparseTensorType>();
 const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
 const TypePtr kString = std::make_shared<String>();
 const TypePtr kList = std::make_shared<List>();
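
SparseTensorStrToType mirrors IndexedSlicesStrToType: a bare "SparseTensor" parses as the generic type, otherwise the element type is taken from between the outermost brackets and parsed recursively. A hedged Python re-statement of just the string handling, for illustration (parse_sparse_tensor_type is not a MindSpore API):

def parse_sparse_tensor_type(type_name):
    """Return the element-type string, or None for the generic form."""
    if type_name == "SparseTensor":
        return None  # generic SparseTensorType, no element type
    start = type_name.find('[') + 1
    end = type_name.rfind(']')
    if start >= len(type_name):
        raise ValueError(type_name)  # malformed; the C++ returns nullptr
    return type_name[start:end]  # fed back into StringToType in the C++

assert parse_sparse_tensor_type("SparseTensor[Float32]") == "Float32"
assert parse_sparse_tensor_type("SparseTensor") is None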

View File

@@ -139,6 +139,8 @@ REGISTER_PYBIND_DEFINE(
                          }));
   (void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
     .def(py::init());
+  (void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType")
+    .def(py::init());
   (void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
     .def(py::init());
   (void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")

View File

@@ -17,9 +17,49 @@
  */

 #include "ir/meta_func_graph.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"
+#include "pipeline/jit/static_analysis/abstract_function.h"
+#include "utils/context/ms_context.h"
+#include "frontend/operator/ops.h"

 // namespace to support intermediate representation definition
 namespace mindspore {
+FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse = context->enable_sparse();
+  if (!enable_sparse) {
+    return nullptr;
+  }
+
+  std::vector<AnfNodePtr> parameters;
+  ParameterPtr undetermined_param = nullptr;
+  auto stub = std::make_shared<FuncGraph>();
+  for (size_t i = 0; i < types.size(); ++i) {
+    auto param = stub->add_parameter();
+    parameters.push_back(param);
+    if (types[i]->type_id() == kObjectTypeUndeterminedType) {
+      undetermined_param = param;
+    }
+  }
+  if (undetermined_param != nullptr) {
+    std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
+    for (size_t i = 0; i < types.size(); ++i) {
+      if (types[i]->type_id() == kObjectTypeFunction) {
+        std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
+        inputs.push_back(stub->NewCNode(call_prim));
+      } else {
+        inputs.push_back(parameters[i]);
+      }
+    }
+    auto stub_output = stub->NewCNode(inputs);
+    stub->set_output(stub_output);
+    stub->set_stub(true);
+    return stub;
+  }
+  return nullptr;
+}
+
 FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
   TypePtrList types;
   (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),

View File

@@ -79,6 +79,7 @@ class MetaFuncGraph : public FuncGraphBase {
   std::shared_ptr<Derived> shared_from_base() {
     return std::static_pointer_cast<Derived>(shared_from_this());
   }
+  FuncGraphPtr GenerateStubFunc(const TypePtrList &types);
   std::string name_;
   std::vector<Signature> signatures_;
   std::unordered_map<TypePtrList, FuncGraphPtr, TypeListHasher, TypeListEqual> cache_;

View File

@@ -40,18 +40,12 @@ class ParamValue {
   const std::string &name() const { return name_; }
   void set_name(const std::string &name) { name_ = name; }

-  const std::string &sparse_grad() const { return sparse_grad_; }
-  void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
-
   bool requires_grad() const { return requires_grad_; }
   void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }

   bool layerwise_parallel() const { return layerwise_parallel_; }
   void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }

-  bool has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
-  void set_has_indexed_slices_grad(bool b) { has_indexed_slices_grad_ = b; }
-
   // Whether the parameter clone from other parameter.
   bool cloned() const { return cloned_; }
@@ -81,10 +75,8 @@ class ParamValue {
  private:
   tensor::MetaTensorPtr value_;
   std::string name_{"Parameter"};
-  std::string sparse_grad_;
   bool requires_grad_{true};
   bool layerwise_parallel_{false};
-  bool has_indexed_slices_grad_{false};
   bool be_cloned_{false};
   bool cloned_{false};
   std::vector<int32_t> be_cloned_index_;

View File

@@ -29,14 +29,10 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
       .def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
       .def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
                     &ParamValue::set_layerwise_parallel)
-      .def_property("has_indexed_slices_grad", &ParamValue::has_indexed_slices_grad,
-                    &ParamValue::set_has_indexed_slices_grad)
-      .def_property("sparse_grad", &ParamValue::sparse_grad, &ParamValue::set_sparse_grad)
       .def(py::pickle(
         [](const ParamValue &p) {  // __getstate__
-          return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
-                                p.layerwise_parallel(), p.has_indexed_slices_grad(),
-                                p.sparse_grad());
+          return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
+                                p.layerwise_parallel());
         },
         [](const py::tuple &t) {  // __setstate__
           if (t.size() != 6) {
@@ -47,8 +43,6 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
           p->set_name(t[1].cast<std::string>());
           p->set_requires_grad(t[2].cast<bool>());
           p->set_layerwise_parallel(t[3].cast<bool>());
-          p->set_has_indexed_slices_grad(t[4].cast<bool>());
-          p->set_sparse_grad(t[5].cast<std::string>());
           return p;
         }));
   }));

View File

@@ -159,6 +159,10 @@ indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
 indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
 indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
+make_sparse_tensor = Primitive('MakeSparseTensor')
+sparse_tensor_get_values = Primitive('SparseTensorGetValues')
+sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
+sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')

 tensor_operator_registry.register('__add__', tensor_add)
 tensor_operator_registry.register('__sub__', tensor_sub)

View File

@@ -616,5 +616,18 @@ TEST_F(TestOptLib, test_indexed_slices) {
   ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
   ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
 }
+
+TEST_F(TestOptLib, test_sparse_tensor) {
+  FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_indices");
+  FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_indices");
+  FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_values");
+  FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_values");
+  FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_dense_shape");
+  FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_dense_shape");
+  auto patterns = std::vector<SubstitutionPtr>({irpass.sparse_tensor_eliminate_});
+  ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
+  ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
+  ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
+}
 }  // namespace opt
 }  // namespace mindspore

View File

@@ -1163,3 +1163,38 @@ def test_indexed_slices(tag):
         return z

     return fns[tag]
+
+
+def test_sparse_tensor(tag):
+    """ test_sparse_tensor """
+    fns = FnDict()
+    make_sparse_tensor = Primitive('MakeSparseTensor')
+    sparse_tensor_get_values = Primitive('SparseTensorGetValues')
+    sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
+    sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
+
+    @fns
+    def before_get_indices(x, y, z):
+        return sparse_tensor_get_indices(make_sparse_tensor(x, y, z))
+
+    @fns
+    def after_get_indices(x, y, z):
+        return x
+
+    @fns
+    def before_get_values(x, y, z):
+        return sparse_tensor_get_values(make_sparse_tensor(x, y, z))
+
+    @fns
+    def after_get_values(x, y, z):
+        return y
+
+    @fns
+    def before_get_dense_shape(x, y, z):
+        return sparse_tensor_get_dense_shape(make_sparse_tensor(x, y, z))
+
+    @fns
+    def after_get_dense_shape(x, y, z):
+        return z
+
+    return fns[tag]

View File

@@ -35,6 +35,9 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.nn import Optimizer
 from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.nn.optim import Momentum
+from mindspore.train import Model
+from ....dataset_mock import MindData

 context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
@@ -47,6 +50,40 @@ size_op = P.Size()
 invert_permutation = P.InvertPermutation()
 logical_and = P.LogicalAnd()

+def get_axis(x):
+    shape = shape_op(x)
+    length = F.tuple_len(shape)
+    perm = F.make_range(0, length)
+    return perm
+
+class MSELoss(nn.Cell):
+    def __init__(self):
+        super(MSELoss, self).__init__()
+        self.reduce_sum = P.ReduceSum()
+        self.square = P.Square()
+        self.reduce_mean = P.ReduceMean()
+
+    def construct(self, data, label):
+        diff = data - label
+        return self.reduce_mean(self.square(diff), get_axis(diff))
+
+class MindDataSet(MindData):
+    def __init__(self, dataset_types, dataset_shapes):
+        super(MindDataSet, self).__init__(size=2, batch_size=32,
+                                          np_types=dataset_types,
+                                          output_shapes=dataset_shapes,
+                                          input_indexs=(0, 1))
+
+    def __next__(self):
+        if self._size < self._iter_num:
+            raise StopIteration
+        self._iter_num += 1
+        lst = []
+        for shape_, type_ in zip(self._output_shapes, self._np_types):
+            lst.append(Tensor(np.ones(shape_).astype(type_)))
+        return tuple(lst)
+
 @constexpr
 def _generate_shape_index(out_shape, indices_shape, axis):
     out_rank = len(out_shape)
@@ -189,8 +226,8 @@ def test_indexed_slices_make_indexed_slices():
         def construct(self, indices, values):
             ret = (IndexedSlices(indices, values, self.dense_shape),)
             return ret[0]
-    indices = Tensor([[0, 0], [1, 2]])
-    values = Tensor([1, 2], dtype=ms.float32)
+    indices = Tensor([1, 2])
+    values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
     MakeIndexedSlices()(indices, values)
@@ -202,8 +239,8 @@ def test_indexed_slices_attr():
         def construct(self, indices, values):
             x = IndexedSlices(indices, values, self.dense_shape)
             return x.values(), x.indices(), x.dense_shape()
-    indices = Tensor([[0, 0], [1, 2]])
-    values = Tensor([1, 2], dtype=ms.float32)
+    indices = Tensor([0])
+    values = Tensor([[1, 2]], dtype=ms.float32)
     IndexedSlicesGetAttr()(indices, values)
@@ -279,3 +316,29 @@ def test_indexed_slices_env_get():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     train_network(inputs, label)
+
+def test_indexed_slices_model_train():
+    class Net(nn.Cell):
+        def __init__(self, in_features, out_features):
+            super(Net, self).__init__()
+            self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
+            self.add = P.TensorAdd()
+            self.cast = P.Cast()
+            self.flag = True
+
+        def construct(self, inputs, label):
+            x = self.add(inputs, self.weight)
+            if self.flag:
+                x = self.cast(x, mstype.float32)
+            return x
+
+    dataset_types = (np.float32, np.float32)
+    dataset_shapes = ((16, 16), (16, 16))
+    dataset = MindDataSet(dataset_types, dataset_shapes)
+    net = Net(16, 16)
+    net.set_train()
+
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    model = Model(net, optimizer=optimizer)
+    model.train(2, dataset, dataset_sink_mode=False)

View File

@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+@File  : test_sparse_tensor.py
+@Author:
+@Date  : 2020-07-16
+@Desc  : test mindspore sparse_tensor's operation
+"""
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.ops import composite as C
+from mindspore import Tensor, SparseTensor, context
+
+context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
+
+
+def test_sparse_tensor_make_sparse_tensor():
+    class MakeSparseTensor(nn.Cell):
+        def __init__(self):
+            super(MakeSparseTensor, self).__init__()
+            self.dense_shape = (3, 4)
+        def construct(self, indices, values):
+            ret = (SparseTensor(indices, values, self.dense_shape),)
+            return ret[0]
+    indices = Tensor([[0, 1], [1, 2]])
+    values = Tensor([1, 2], dtype=ms.float32)
+    MakeSparseTensor()(indices, values)
+
+
+def test_sparse_tensor_attr():
+    grad_op = C.GradOperation('get_all', get_all=True)
+    class GradWrap(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap, self).__init__()
+            self.network = network
+        def construct(self, input1, input2):
+            gout = grad_op(self.network)(input1, input2)
+            return gout
+
+    class SparseTensorGetAttr(nn.Cell):
+        def __init__(self):
+            super(SparseTensorGetAttr, self).__init__()
+            self.dense_shape = (3, 4)
+        def construct(self, indices, values):
+            x = SparseTensor(indices, values, self.dense_shape)
+            return x.values(), x.indices(), x.dense_shape()
+
+    indices = Tensor([[0, 1], [1, 2]])
+    values = Tensor([1, 2], dtype=ms.float32)
+    SparseTensorGetAttr()(indices, values)