diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py
index e60b70eface..e2b83331f52 100644
--- a/mindspore/_extends/parse/resources.py
+++ b/mindspore/_extends/parse/resources.py
@@ -17,7 +17,7 @@
 """Resources for ast tree parse."""
 import ast
 import math
-from mindspore import IndexedSlices
+from mindspore import IndexedSlices, SparseTensor
 from mindspore.ops.composite import multitype_ops
 from mindspore.ops import functional as F, composite as C
 from . import standard_method as M
@@ -140,4 +140,5 @@ convert_object_map = {
     # user defined
     IndexedSlices: F.make_indexed_slices,
+    SparseTensor: F.make_sparse_tensor,
 }
diff --git a/mindspore/ccsrc/debug/dump_proto.cc b/mindspore/ccsrc/debug/dump_proto.cc
index 35cdfafe26e..9172d11471a 100644
--- a/mindspore/ccsrc/debug/dump_proto.cc
+++ b/mindspore/ccsrc/debug/dump_proto.cc
@@ -124,6 +124,8 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
     // Do Nothing
   } else if (type->isa<IndexedSlicesType>()) {
     // Do Nothing
+  } else if (type->isa<SparseTensorType>()) {
+    // Do Nothing
   } else if (type->isa<Tuple>()) {
     TuplePtr tuple_type = dyn_cast<Tuple>(type);
     type_proto->set_data_type(irpb::DT_TUPLE);
diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc
index 7d2573e50ab..0586572dd1f 100644
--- a/mindspore/ccsrc/frontend/operator/composite/composite.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc
@@ -803,6 +803,18 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
   abstract::AbstractTuplePtr a_tuple = dyn_cast<abstract::AbstractTuple>(abs_a);
   abstract::AbstractTuplePtr b_tuple = dyn_cast<abstract::AbstractTuple>(abs_b);
   if (a_tuple == nullptr || b_tuple == nullptr) {
+    TypePtrList types;
+    (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
+                         [](const AbstractBasePtr &arg) -> TypePtr {
+                           MS_EXCEPTION_IF_NULL(arg);
+                           return arg->BuildType();
+                         });
+    auto stub = GenerateStubFunc(types);
+    if (stub != nullptr) {
+      MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd, function: "
+                    << stub->ToString();
+      return stub;
+    }
     MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple, but " << args_spec_list[0]->ToString() << ", "
                       << args_spec_list[1]->ToString();
   }
diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc
index ba0d3d9ebb8..16aa6f654bd 100644
--- a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc
@@ -119,42 +119,6 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
   return py::none();
 }
 
-FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool enable_sparse = context->enable_sparse();
-  if (!enable_sparse) {
-    return nullptr;
-  }
-
-  std::vector<ParameterPtr> parameters;
-  ParameterPtr undetermined_param = nullptr;
-  auto stub = std::make_shared<FuncGraph>();
-  for (size_t i = 0; i < types.size(); ++i) {
-    auto param = stub->add_parameter();
-    parameters.push_back(param);
-    if (types[i]->type_id() == kObjectTypeUndeterminedType) {
-      undetermined_param = param;
-    }
-  }
-  if (undetermined_param != nullptr) {
-    std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
-    for (size_t i = 0; i < types.size(); ++i) {
-      if (types[i]->type_id() == kObjectTypeFunction) {
-        std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
-        inputs.push_back(stub->NewCNode(call_prim));
-      } else {
-        inputs.push_back(parameters[i]);
-      }
-    }
-    auto stub_output = stub->NewCNode(inputs);
-    stub->set_output(stub_output);
-    stub->set_stub(true);
-    return stub;
-  }
-  return nullptr;
-}
-
 FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
   auto py_fn = SignMatch(types);
   std::ostringstream buffer;
diff --git a/mindspore/ccsrc/frontend/operator/ops.cc b/mindspore/ccsrc/frontend/operator/ops.cc
index 5c7672ee3c6..bf3d55678e1 100755
--- a/mindspore/ccsrc/frontend/operator/ops.cc
+++ b/mindspore/ccsrc/frontend/operator/ops.cc
@@ -283,6 +283,11 @@
 const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
 const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
 const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
 const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
-const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
+
+// SparseTensor
+const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
+const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
+const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
+const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
 }  // namespace prim
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h
index 0dea045a6ea..d57b681ff26 100755
--- a/mindspore/ccsrc/frontend/operator/ops.h
+++ b/mindspore/ccsrc/frontend/operator/ops.h
@@ -292,7 +292,12 @@
 extern const PrimitivePtr kPrimMakeIndexedSlices;
 extern const PrimitivePtr kPrimIndexedSlicesGetValues;
 extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
 extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
-extern const PrimitivePtr kPrimIsIndexedSlices;
+
+// SparseTensor
+extern const PrimitivePtr kPrimMakeSparseTensor;
+extern const PrimitivePtr kPrimSparseTensorGetValues;
+extern const PrimitivePtr kPrimSparseTensorGetIndices;
+extern const PrimitivePtr kPrimSparseTensorGetDenseShape;
 
 // attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
 const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
diff --git a/mindspore/ccsrc/frontend/operator/prim_others.cc b/mindspore/ccsrc/frontend/operator/prim_others.cc
index 530ad6a10c9..25f41860f68 100644
--- a/mindspore/ccsrc/frontend/operator/prim_others.cc
+++ b/mindspore/ccsrc/frontend/operator/prim_others.cc
@@ -349,6 +349,26 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
   auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
   auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
 
+  auto indices_dtype = indices->element()->BuildType();
+  if (!indices_dtype->isa<Int>()) {
+    MS_EXCEPTION(TypeError) << "The dtype of indices must be Int, but got " << indices_dtype->ToString();
+  }
+  auto indices_shp = indices->shape()->shape();
+  if (indices_shp.size() != 1) {
+    MS_EXCEPTION(TypeError) << "Indices must be a 1-dimensional tensor, but got a " << indices_shp.size()
+                            << "-dimensional tensor";
+  }
+  auto values_shp = values->shape()->shape();
+  if (indices_shp[0] != values_shp[0]) {
+    MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same as the first dimension of values "
+                            << values_shp[0] << ", but got " << indices_shp[0];
+  }
+
+  for (auto elem_type : dense_shape->ElementsType()) {
+    if (!elem_type->isa<Int>()) {
+      MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
+    }
+  }
   auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
   MS_EXCEPTION_IF_NULL(dense_shape_value);
   auto shp = dense_shape_value->value();
@@ -358,6 +378,12 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
                          auto elem = GetValue<int>(e);
                          return elem;
                        });
+  for (auto dense_shape_elem : dense_shape_vec) {
+    if (dense_shape_elem < 0) {
+      MS_EXCEPTION(TypeError) << "The element of dense_shape must be non-negative, but got "
+                              << dense_shape_value->ToString();
+    }
+  }
   auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
   ret->set_indices(indices);
   ret->set_values(values);
@@ -395,16 +421,89 @@ AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, c
   return indexed_slices->dense_shape();
 }
 
-AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                         const AbstractBasePtrList &args_spec_list) {
+AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                          const AbstractBasePtrList &args_spec_list) {
+  // Inputs: two tensors and a tuple.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
+
+  auto indices_dtype = indices->element()->BuildType();
+  if (!indices_dtype->isa<Int>()) {
+    MS_EXCEPTION(TypeError) << "The dtype of indices must be Int, but got " << indices_dtype->ToString();
+  }
+  auto indices_shp = indices->shape()->shape();
+  if (indices_shp.size() != 2) {
+    MS_EXCEPTION(TypeError) << "Indices must be a 2-dimensional tensor, but got a " << indices_shp.size()
+                            << "-dimensional tensor";
+  }
+  auto values_shp = values->shape()->shape();
+  if (values_shp.size() != 1) {
+    MS_EXCEPTION(TypeError) << "Values must be a 1-dimensional tensor, but got a " << values_shp.size()
+                            << "-dimensional tensor";
+  }
+  if (indices_shp[0] != values_shp[0]) {
+    MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same as the first dimension of values "
+                            << values_shp[0] << ", but got " << indices_shp[0];
+  }
+
+  for (auto elem_type : dense_shape->ElementsType()) {
+    if (!elem_type->isa<Int>()) {
+      MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
+    }
+  }
+  auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(dense_shape_value);
+  auto shp = dense_shape_value->value();
+  std::vector<int> dense_shape_vec;
+  (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
+                       [](const ValuePtr &e) -> int {
+                         auto elem = GetValue<int>(e);
+                         return elem;
+                       });
+  for (auto dense_shape_elem : dense_shape_vec) {
+    if (dense_shape_elem < 0) {
+      MS_EXCEPTION(TypeError) << "The element of dense_shape must be non-negative, but got "
+                              << dense_shape_value->ToString();
+    }
+  }
+  auto ret = std::make_shared<AbstractSparseTensor>(values->element()->BuildType(), dense_shape_vec);
+  ret->set_indices(indices);
+  ret->set_values(values);
+  ret->set_dense_shape(dense_shape);
+  return ret;
+}
+
+AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a sparse tensor.
   const std::string op_name = primitive->name();
   CheckArgsSize(op_name, args_spec_list, 1);
-  bool ret = false;
-  if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
-    ret = true;
-  }
-  MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
-  return std::make_shared<AbstractScalar>(ret);
+  auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(sparse_tensor->values());
+  return sparse_tensor->values();
+}
+
+AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a sparse tensor.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
+  return sparse_tensor->indices();
+}
+
+AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                   const AbstractBasePtrList &args_spec_list) {
+  // Inputs: a sparse tensor.
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 1);
+  auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape());
+  return sparse_tensor->dense_shape();
 }
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
index 5ca2ca6c43d..aa76d279d53 100644
--- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
+++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
@@ -264,7 +264,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re
     return IsPrimitiveCNode(user.first, prim);
   });
   if (cnode == users.end()) {
-    MS_LOG(EXCEPTION) << "Fail to find cnode.";
+    MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
   }
   auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc
index efc3795a4cc..23321074f7f 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -43,6 +43,7 @@
 #include "frontend/optimizer/irpass/transpose_eliminate.h"
 #include "frontend/optimizer/opt.h"
 #include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
+#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
 
 namespace mindspore {
 namespace opt {
@@ -159,6 +160,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   indexed_slices_eliminate_ = MakeSubstitution(
     std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
    {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
+
+  // SparseTensor Eliminate
+  sparse_tensor_eliminate_ = MakeSubstitution(
+    std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
+    {prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
 }
 
 ResolveIRPassLib::ResolveIRPassLib() {
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h
index 4af8c0789dc..718302a1e01 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.h
@@ -107,6 +107,9 @@ class OptimizeIRPassLib {
 
   // IndexedSlices Eliminate
   SubstitutionPtr indexed_slices_eliminate_;
+
+  // SparseTensor Eliminate
+  SubstitutionPtr sparse_tensor_eliminate_;
 };
 
 // the collection of irpass for resolve action
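The substitution registered above rewrites a SparseTensor getter applied directly to a MakeSparseTensor node into the corresponding constructor input. A minimal sketch of the before/after graphs, written in the same style as the opt_test.py helpers added later in this diff (illustrative only, not part of the patch):

```python
from mindspore.ops import Primitive

make_sparse_tensor = Primitive('MakeSparseTensor')
sparse_tensor_get_values = Primitive('SparseTensorGetValues')

def before_get_values(x, y, z):
    # MakeSparseTensor(x, y, z) packs (indices, values, dense_shape);
    # the getter immediately unpacks constructor input 2 (values).
    return sparse_tensor_get_values(make_sparse_tensor(x, y, z))

def after_get_values(x, y, z):
    # After sparse_tensor_eliminate, the getter/constructor pair is gone.
    return y
```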
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h
new file mode 100644
index 00000000000..ac8f2449f3c
--- /dev/null
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/sparse_tensor_eliminate.h
@@ -0,0 +1,75 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
+#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
+
+#include <vector>
+#include <algorithm>
+
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+#include "ir/visitor.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}}
+// {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}}
+// {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}}
+class SparseTensorEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    Reset();
+    AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(1);
+    }
+    AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(2);
+    }
+    AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node);
+    if (is_match_) {
+      return tuple_->input(3);
+    }
+    return nullptr;
+  }
+
+  void Visit(const CNodePtr &cnode) override {
+    if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) {
+      tuple_ = cnode;
+      is_match_ = true;
+    }
+  }
+
+  void Reset() {
+    tuple_ = nullptr;
+    is_match_ = false;
+  }
+
+ private:
+  bool is_match_{false};
+  CNodePtr tuple_{nullptr};
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index bb9a517556e..f3a03658a2e 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -157,6 +157,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.make_ref_eliminate_,
     irpass.get_ref_param_eliminate_,
     irpass.indexed_slices_eliminate_,
+    irpass.sparse_tensor_eliminate_,
   });
   OptPassGroupMap map({
     {"b_1", b_1},
diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc
index ece128b77b7..16d4a00346e 100644
--- a/mindspore/ccsrc/pipeline/jit/resource.cc
+++ b/mindspore/ccsrc/pipeline/jit/resource.cc
@@ -179,6 +179,12 @@ MethodMap &GetMethodMap() {
      {"indices", prim::kPrimIndexedSlicesGetIndices},         // F.indexed_slices_get_indices
      {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape},  // F.indexed_slices_get_dense_shape
    }},
+  {kObjectTypeSparseTensorType,
+   {
+     {"values", prim::kPrimSparseTensorGetValues},            // F.sparse_tensor_get_values
+     {"indices", prim::kPrimSparseTensorGetIndices},          // F.sparse_tensor_get_indices
+     {"dense_shape", prim::kPrimSparseTensorGetDenseShape},   // F.sparse_tensor_get_dense_shape
+   }},
   {kObjectTypeJTagged, {}},
   {kObjectTypeSymbolicKeyType, {}},
   {kObjectTypeEnvType, {}}};
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
index 9f3011d1187..90d4aaa125f 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
@@ -138,7 +138,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
     {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
     {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
-    {prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}},
+    // SparseTensor
+    {prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
+    {prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
+    {prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
+    {prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
   };
   return prim_eval_implement_map;
 }
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
index 692fbe66e88..b931bf6b7e8 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.h
@@ -358,8 +358,14 @@
 AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                  const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                                     const AbstractBasePtrList &args_spec_list);
-AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                         const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                          const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                                   const AbstractBasePtrList &args_spec_list);
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc
index 04aa6efd05b..53164f8ac0e 100644
--- a/mindspore/ccsrc/pipeline/jit/validator.cc
+++ b/mindspore/ccsrc/pipeline/jit/validator.cc
@@ -36,6 +36,7 @@
 using mindspore::abstract::AbstractIndexedSlices;
 using mindspore::abstract::AbstractJTagged;
 using mindspore::abstract::AbstractList;
 using mindspore::abstract::AbstractScalar;
+using mindspore::abstract::AbstractSparseTensor;
 using mindspore::abstract::AbstractTensor;
 using mindspore::abstract::AbstractTuple;
 using mindspore::abstract::AbstractType;
@@ -95,7 +96,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
   if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
       ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractScalar>() || ptrBase->isa<AbstractTensor>() ||
-      ptrBase->isa<AbstractIndexedSlices>()) {
+      ptrBase->isa<AbstractIndexedSlices>() || ptrBase->isa<AbstractSparseTensor>()) {
     return;
   }
diff --git a/mindspore/common/__init__.py b/mindspore/common/__init__.py
index c896805d75a..570e0368c53 100644
--- a/mindspore/common/__init__.py
+++ b/mindspore/common/__init__.py
@@ -17,10 +17,10 @@
 from . import dtype
 from .api import ms_function
 from .dtype import *
 from .parameter import Parameter, ParameterTuple
-from .tensor import MetaTensor, Tensor, IndexedSlices
+from .tensor import MetaTensor, Tensor, IndexedSlices, SparseTensor
 
 __all__ = [
-    "MetaTensor", "Tensor", "IndexedSlices",  # tensor
+    "MetaTensor", "Tensor", "IndexedSlices", "SparseTensor",  # tensor
     'ms_function',  # api
     'Parameter', 'ParameterTuple',  # parameter
     "dtype"
diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py
index 64a8eb46373..dde82186809 100644
--- a/mindspore/common/tensor.py
+++ b/mindspore/common/tensor.py
@@ -21,7 +21,7 @@
 from .._checkparam import check_type, check_typename
 from . import dtype as mstype
 from ._register_for_tensor import tensor_operator_registry
 
-__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']
+__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices', 'SparseTensor']
 
 np_types = (np.int8, np.int16, np.int32, np.int64,
             np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
             np.float32, np.float64, np.bool_)
@@ -211,3 +211,7 @@
 class IndexedSlices:
     def __init__(self, indices, values, dense_shape):
         raise NotImplementedError
+
+
+class SparseTensor:
+    def __init__(self, indices, values, dense_shape):
+        raise NotImplementedError
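Like IndexedSlices above it, the Python SparseTensor class is a graph-mode-only front end: calling it eagerly raises NotImplementedError, while inside a compiled `construct` the parser replaces it with F.make_sparse_tensor via convert_object_map (see the resources.py hunk at the top of this diff). A hedged construction sketch:

```python
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, SparseTensor, context

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class MakeSparse(nn.Cell):
    def __init__(self):
        super(MakeSparse, self).__init__()
        self.dense_shape = (3, 4)  # dense shape of the 2-D sparse tensor

    def construct(self, indices, values):
        # Lowered to the MakeSparseTensor primitive during parsing.
        return SparseTensor(indices, values, self.dense_shape)

indices = Tensor([[0, 1], [1, 2]])         # one (row, col) pair per value
values = Tensor([1, 2], dtype=ms.float32)  # values[i] lives at indices[i]
MakeSparse()(indices, values)

# Outside a compiled graph the class is deliberately unusable:
# SparseTensor(indices, values, (3, 4)) raises NotImplementedError.
```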
diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc
index 7bef3829a61..fb16cf0161c 100644
--- a/mindspore/core/abstract/abstract_value.cc
+++ b/mindspore/core/abstract/abstract_value.cc
@@ -1093,5 +1093,64 @@ std::string AbstractIndexedSlices::ToString() const {
          << ", dense_shape: " << dense_shape_->ToString();
   return buffer.str();
 }
+
+// SparseTensor
+TypePtr AbstractSparseTensor::BuildType() const {
+  MS_EXCEPTION_IF_NULL(element());
+  TypePtr element_type = element()->BuildType();
+  return std::make_shared<SparseTensorType>(element_type);
+}
+
+AbstractBasePtr AbstractSparseTensor::Clone() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto clone = std::make_shared<AbstractSparseTensor>(element()->Clone());
+  ShapePtr shp = shape();
+  clone->set_shape(shp->Clone());
+  clone->set_value(GetValueTrack());
+  clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return clone;
+}
+
+AbstractBasePtr AbstractSparseTensor::Broaden() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
+  auto shp = shape();
+  broaden->set_shape(shp->Clone());
+  broaden->set_value(kAnyValue);
+  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return broaden;
+}
+
+AbstractBasePtr AbstractSparseTensor::BroadenWithShape() const {
+  MS_EXCEPTION_IF_NULL(element());
+  auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
+  auto shp = shape()->Clone();
+  shp->Broaden();
+  broaden->set_shape(shp);
+  broaden->set_value(kAnyValue);
+  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
+  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
+  return broaden;
+}
+
+std::string AbstractSparseTensor::ToString() const {
+  std::ostringstream buffer;
+  BaseShapePtr shape_track = GetShapeTrack();
+  MS_EXCEPTION_IF_NULL(shape_track);
+  MS_EXCEPTION_IF_NULL(element());
+  auto value_track = GetValueTrack();
+  MS_EXCEPTION_IF_NULL(value_track);
+  buffer << type_name() << "("
+         << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
+         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
+         << ", indices: " << indices_->ToString() << ", values: " << values_->ToString()
+         << ", dense_shape: " << dense_shape_->ToString();
+  return buffer.str();
+}
 }  // namespace abstract
 }  // namespace mindspore
diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h
index d922f93e70b..5f2ca8f3f32 100644
--- a/mindspore/core/abstract/abstract_value.h
+++ b/mindspore/core/abstract/abstract_value.h
@@ -604,10 +604,39 @@ class AbstractIndexedSlices : public AbstractUndetermined {
   MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
 
   const AbstractTensorPtr indices() const { return indices_; }
-  const AbstractTensorPtr values() const { return values_; }
-  const AbstractTuplePtr dense_shape() const { return dense_shape_; }
   void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
+  const AbstractTensorPtr values() const { return values_; }
   void set_values(const AbstractTensorPtr &values) { values_ = values; }
+  const AbstractTuplePtr dense_shape() const { return dense_shape_; }
   void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
   TypePtr BuildType() const override;
   AbstractBasePtr Clone() const override;
   AbstractBasePtr Broaden() const override;
   AbstractBasePtr BroadenWithShape() const;
 
   std::string ToString() const override;
 
  private:
   AbstractTensorPtr indices_;
   AbstractTensorPtr values_;
   AbstractTuplePtr dense_shape_;
 };
+
+// SparseTensor
+class AbstractSparseTensor : public AbstractUndetermined {
+ public:
+  explicit AbstractSparseTensor(const AbstractBasePtr &element,
+                                const BaseShapePtr &shape = std::make_shared<Shape>())
+      : AbstractUndetermined(element, shape) {}
+  AbstractSparseTensor(const TypePtr &element_type, const std::vector<int> &shape)
+      : AbstractUndetermined(element_type, shape) {}
+  ~AbstractSparseTensor() override = default;
+  MS_DECLARE_PARENT(AbstractSparseTensor, AbstractUndetermined)
+
+  const AbstractTensorPtr indices() const { return indices_; }
+  void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
+  const AbstractTensorPtr values() const { return values_; }
+  void set_values(const AbstractTensorPtr &values) { values_ = values; }
+  const AbstractTuplePtr dense_shape() const { return dense_shape_; }
+  void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
+  TypePtr BuildType() const override;
+  AbstractBasePtr Clone() const override;
+  AbstractBasePtr Broaden() const override;
+  AbstractBasePtr BroadenWithShape() const;
+
+  std::string ToString() const override;
+
+ private:
+  AbstractTensorPtr indices_;
+  AbstractTensorPtr values_;
+  AbstractTuplePtr dense_shape_;
+};
diff --git a/mindspore/core/abstract/param_validator.h b/mindspore/core/abstract/param_validator.h
index 434235abda3..e08d4fc8e85 100644
--- a/mindspore/core/abstract/param_validator.h
+++ b/mindspore/core/abstract/param_validator.h
@@ -67,6 +67,7 @@
 ABSTRACT_REPORT_NAME_TRAITS(Type)
 ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
 ABSTRACT_REPORT_NAME_TRAITS(Class)
 ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
+ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
 ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
 
 template <typename T>
diff --git a/mindspore/core/ir/dtype.cc b/mindspore/core/ir/dtype.cc
index 71a78bdcf67..89ab2ac0fa4 100644
--- a/mindspore/core/ir/dtype.cc
+++ b/mindspore/core/ir/dtype.cc
@@ -221,6 +221,48 @@ bool IndexedSlicesType::operator==(const Type &other) const {
   return *element_type_ == *other_elem_type;
 }
 
+TypePtr SparseTensorType::DeepCopy() const {
+  MS_EXCEPTION_IF_NULL(element_type_);
+  if (IsGeneric()) {
+    return std::make_shared<SparseTensorType>();
+  }
+  return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
+}
+
+std::string SparseTensorType::ToReprString() const {
+  if (element_type_ == nullptr) {
+    return "SparseTensor";
+  }
+  return "SparseTensor[" + element_type_->ToReprString() + "]";
+}
+
+std::string SparseTensorType::ToString() const {
+  if (element_type_ == nullptr) {
+    return "SparseTensor";
+  }
+  return "SparseTensor[" + element_type_->ToString() + "]";
+}
+
+std::string SparseTensorType::DumpText() const {
+  if (element_type_ == nullptr) {
+    return "SparseTensor";
+  }
+  return "SparseTensor[" + element_type_->DumpText() + "]";
+}
+
+bool SparseTensorType::operator==(const Type &other) const {
+  if (!IsSameObjectType(*this, other)) {
+    return false;
+  }
+  auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
+  if (element_type_ == nullptr && other_elem_type == nullptr) {
+    return true;
+  } else if (element_type_ == nullptr || other_elem_type == nullptr) {
+    return false;
+  }
+  return *element_type_ == *other_elem_type;
+}
+
 Function::Function() : Object(kObjectTypeFunction) {
   args_ = std::vector<TypePtr>();
   retval_ = nullptr;
diff --git a/mindspore/core/ir/dtype.h b/mindspore/core/ir/dtype.h
index dc277c031c6..0ff152a4f46 100644
--- a/mindspore/core/ir/dtype.h
+++ b/mindspore/core/ir/dtype.h
@@ -177,6 +177,29 @@ class IndexedSlicesType : public Object {
 };
 using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
 
+class SparseTensorType : public Object {
+ public:
+  SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
+  explicit SparseTensorType(const TypePtr &ele)
+      : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
+  ~SparseTensorType() override = default;
+  MS_DECLARE_PARENT(SparseTensorType, Object)
+
+  TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
+  const TypePtr element() const { return element_type_; }
+  void set_element(const TypePtr &element_type) { element_type_ = element_type; }
+
+  TypePtr DeepCopy() const override;
+  std::string ToString() const override;
+  std::string ToReprString() const override;
+  std::string DumpText() const override;
+  bool operator==(const Type &other) const override;
+
+ private:
+  TypePtr element_type_;
+};
+using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
+
 class Function : public Object {
  public:
   Function();
diff --git a/mindspore/core/ir/dtype/type.cc b/mindspore/core/ir/dtype/type.cc
index 754876a366a..39586602e75 100644
--- a/mindspore/core/ir/dtype/type.cc
+++ b/mindspore/core/ir/dtype/type.cc
@@ -117,6 +117,8 @@ const char *ObjectIdLabel(const TypeId &v) {
       return "kObjectTypeTensorType";
     case kObjectTypeIndexedSlicesType:
       return "kObjectTypeIndexedSlicesType";
+    case kObjectTypeSparseTensorType:
+      return "kObjectTypeSparseTensorType";
     case kObjectTypeUndeterminedType:
       return "kObjectTypeUndeterminedType";
     case kObjectTypeDictionary:
diff --git a/mindspore/core/ir/dtype/type_id.h b/mindspore/core/ir/dtype/type_id.h
index 6fb2a354c17..960c2f320d2 100644
--- a/mindspore/core/ir/dtype/type_id.h
+++ b/mindspore/core/ir/dtype/type_id.h
@@ -51,6 +51,7 @@ enum TypeId : int {
   kObjectTypeKeyword,
   kObjectTypeTensorType,
   kObjectTypeIndexedSlicesType,
+  kObjectTypeSparseTensorType,
   kObjectTypeUndeterminedType,
   kObjectTypeClass,
   kObjectTypeDictionary,
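The compile-time checks in InferImplMakeSparseTensor (prim_others.cc earlier in this diff) reject malformed constructor arguments before execution. A hedged sketch of one failing case; the exact exception surface may differ, but the intent is the indices/values first-dimension check:

```python
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, SparseTensor, context

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class MismatchedSparse(nn.Cell):
    def __init__(self):
        super(MismatchedSparse, self).__init__()
        self.dense_shape = (3, 4)

    def construct(self, indices, values):
        return SparseTensor(indices, values, self.dense_shape)

# indices describes two non-zero positions but values supplies only one,
# so inference should fail with "The first dimension of indices must be
# the same as the first dimension of values ...".
indices = Tensor([[0, 1], [1, 2]])
values = Tensor([1], dtype=ms.float32)
try:
    MismatchedSparse()(indices, values)
except TypeError as err:
    print(err)
```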
diff --git a/mindspore/core/ir/dtype_extends.cc b/mindspore/core/ir/dtype_extends.cc
index 099748217ed..9038646ceb7 100644
--- a/mindspore/core/ir/dtype_extends.cc
+++ b/mindspore/core/ir/dtype_extends.cc
@@ -207,6 +207,23 @@ TypePtr IndexedSlicesStrToType(const std::string &type_name) {
   return std::make_shared<IndexedSlicesType>(element_type);
 }
 
+TypePtr SparseTensorStrToType(const std::string &type_name) {
+  if (type_name == "SparseTensor") {
+    return std::make_shared<SparseTensorType>();
+  }
+  auto start = type_name.find_first_of('[') + 1;
+  auto end = type_name.find_last_of(']');
+  if (start >= type_name.size()) {
+    return nullptr;
+  }
+  auto element_str = type_name.substr(start, end - start);
+  auto element_type = StringToType(element_str);
+  if (element_type == nullptr) {
+    return nullptr;
+  }
+  return std::make_shared<SparseTensorType>(element_type);
+}
+
 TypePtr UndeterminedStrToType(const std::string &type_name) {
   if (type_name == "Undetermined") {
     return std::make_shared<UndeterminedType>();
@@ -349,6 +366,8 @@ TypePtr StringToType(const std::string &type_name) {
     type = UndeterminedStrToType(type_name);
   } else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
     type = IndexedSlicesStrToType(type_name);
+  } else if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) {
+    type = SparseTensorStrToType(type_name);
   } else if (type_name.compare(0, strlen("List"), "List") == 0) {
     type = ListStrToType(type_name);
   } else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
@@ -428,6 +447,7 @@
 const TypePtr kTypeEnv = std::make_shared<EnvType>();
 const TypePtr kTypeType = std::make_shared<TypeType>();
 const TypePtr kTensorType = std::make_shared<TensorType>();
 const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
+const TypePtr kSparseTensorType = std::make_shared<SparseTensorType>();
 const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
 const TypePtr kString = std::make_shared<String>();
 const TypePtr kList = std::make_shared<List>();
diff --git a/mindspore/core/ir/dtype_py.cc b/mindspore/core/ir/dtype_py.cc
index b1e2151b6dd..7577a39f7a6 100644
--- a/mindspore/core/ir/dtype_py.cc
+++ b/mindspore/core/ir/dtype_py.cc
@@ -139,6 +139,8 @@ REGISTER_PYBIND_DEFINE(
                          }));
   (void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
     .def(py::init());
+  (void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType")
+    .def(py::init());
   (void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
     .def(py::init());
   (void)py::class_<Function, Object, std::shared_ptr<Function>>(m_sub, "Function")
diff --git a/mindspore/core/ir/meta_func_graph.cc b/mindspore/core/ir/meta_func_graph.cc
index c0cf9d4d2f2..7953931e8f4 100644
--- a/mindspore/core/ir/meta_func_graph.cc
+++ b/mindspore/core/ir/meta_func_graph.cc
@@ -17,9 +17,49 @@
  */
 
 #include "ir/meta_func_graph.h"
+#include "pipeline/jit/static_analysis/static_analysis.h"
+#include "pipeline/jit/static_analysis/abstract_function.h"
+#include "utils/context/ms_context.h"
+#include "frontend/operator/ops.h"
 
 // namespace to support intermediate representation definition
 namespace mindspore {
+FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse = context->enable_sparse();
+  if (!enable_sparse) {
+    return nullptr;
+  }
+
+  std::vector<ParameterPtr> parameters;
+  ParameterPtr undetermined_param = nullptr;
+  auto stub = std::make_shared<FuncGraph>();
+  for (size_t i = 0; i < types.size(); ++i) {
+    auto param = stub->add_parameter();
+    parameters.push_back(param);
+    if (types[i]->type_id() == kObjectTypeUndeterminedType) {
+      undetermined_param = param;
+    }
+  }
+  if (undetermined_param != nullptr) {
+    std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
+    for (size_t i = 0; i < types.size(); ++i) {
+      if (types[i]->type_id() == kObjectTypeFunction) {
+        std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
+        inputs.push_back(stub->NewCNode(call_prim));
+      } else {
+        inputs.push_back(parameters[i]);
+      }
+    }
+    auto stub_output = stub->NewCNode(inputs);
+    stub->set_output(stub_output);
+    stub->set_stub(true);
+    return stub;
+  }
+  return nullptr;
+}
+
 FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
   TypePtrList types;
   (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
diff --git a/mindspore/core/ir/meta_func_graph.h b/mindspore/core/ir/meta_func_graph.h
index 933c3f700d8..df1edc312de 100644
--- a/mindspore/core/ir/meta_func_graph.h
+++ b/mindspore/core/ir/meta_func_graph.h
@@ -79,6 +79,7 @@ class MetaFuncGraph : public FuncGraphBase {
   std::shared_ptr<MetaFuncGraph> shared_from_base() {
     return std::static_pointer_cast<MetaFuncGraph>(shared_from_this());
   }
+  FuncGraphPtr GenerateStubFunc(const TypePtrList &types);
   std::string name_;
   std::vector<Signature> signatures_;
   std::unordered_map<TypePtrList, FuncGraphPtr, TypeListHasher, TypeListEqual> cache_;
diff --git a/mindspore/core/ir/param_value.h b/mindspore/core/ir/param_value.h
index 00b79ae91ca..36026ce97fc 100644
--- a/mindspore/core/ir/param_value.h
+++ b/mindspore/core/ir/param_value.h
@@ -40,18 +40,12 @@ class ParamValue {
   const std::string &name() const { return name_; }
   void set_name(const std::string &name) { name_ = name; }
 
-  const std::string &sparse_grad() const { return sparse_grad_; }
-  void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
-
   bool requires_grad() const { return requires_grad_; }
   void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
 
   bool layerwise_parallel() const { return layerwise_parallel_; }
   void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }
 
-  bool has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
-  void set_has_indexed_slices_grad(bool b) { has_indexed_slices_grad_ = b; }
-
   // Whether the parameter clone from other parameter.
   bool cloned() const { return cloned_; }
 
@@ -81,10 +75,8 @@ class ParamValue {
  private:
   tensor::MetaTensorPtr value_;
   std::string name_{"Parameter"};
-  std::string sparse_grad_;
   bool requires_grad_{true};
   bool layerwise_parallel_{false};
-  bool has_indexed_slices_grad_{false};
   bool be_cloned_{false};
   bool cloned_{false};
   std::vector<int32_t> be_cloned_index_;
diff --git a/mindspore/core/ir/param_value_py.cc b/mindspore/core/ir/param_value_py.cc
index fb4b313c228..c976d41cd21 100644
--- a/mindspore/core/ir/param_value_py.cc
+++ b/mindspore/core/ir/param_value_py.cc
@@ -29,14 +29,10 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
                            .def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
                            .def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
                                          &ParamValue::set_layerwise_parallel)
-                           .def_property("has_indexed_slices_grad", &ParamValue::has_indexed_slices_grad,
-                                         &ParamValue::set_has_indexed_slices_grad)
-                           .def_property("sparse_grad", &ParamValue::sparse_grad, &ParamValue::set_sparse_grad)
                            .def(py::pickle(
                              [](const ParamValue &p) {  // __getstate__
                                return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
-                                                     p.layerwise_parallel(), p.has_indexed_slices_grad(),
-                                                     p.sparse_grad());
+                                                     p.layerwise_parallel());
                              },
                              [](const py::tuple &t) {  // __setstate__
-                               if (t.size() != 6) {
+                               if (t.size() != 4) {
                                  std::runtime_error("Invalid state!");
                                }
                                auto p = std::make_shared<ParamValue>();
                                p->set_name(t[1].cast<std::string>());
                                p->set_requires_grad(t[2].cast<bool>());
                                p->set_layerwise_parallel(t[3].cast<bool>());
-                               p->set_has_indexed_slices_grad(t[4].cast<bool>());
-                               p->set_sparse_grad(t[5].cast<std::string>());
                                return p;
                              }));
                        }));
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index 2be011cb773..36294fa4cdf 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -159,6 +159,10 @@
 indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
 indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
 indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
+make_sparse_tensor = Primitive('MakeSparseTensor')
+sparse_tensor_get_values = Primitive('SparseTensorGetValues')
+sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
+sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
 
 tensor_operator_registry.register('__add__', tensor_add)
 tensor_operator_registry.register('__sub__', tensor_sub)
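The functional aliases above give graph code direct access to the new primitives. A hedged sketch of using them inside a cell; combined with the eliminate pass registered earlier, the build-then-read round trip below should reduce to the identity on `values`:

```python
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import functional as F

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)

class RoundTrip(nn.Cell):
    """Build a sparse tensor, then immediately read its values back."""
    def __init__(self):
        super(RoundTrip, self).__init__()
        self.dense_shape = (3, 4)

    def construct(self, indices, values):
        st = F.make_sparse_tensor(indices, values, self.dense_shape)
        # sparse_tensor_eliminate folds this getter/constructor pair away.
        return F.sparse_tensor_get_values(st)

indices = Tensor([[0, 1], [1, 2]])
values = Tensor([1, 2], dtype=ms.float32)
out = RoundTrip()(indices, values)  # expected to equal `values`
```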
diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc
index 751b301283c..c0d5523edc2 100644
--- a/tests/ut/cpp/optimizer/lib_test.cc
+++ b/tests/ut/cpp/optimizer/lib_test.cc
@@ -616,5 +616,18 @@ TEST_F(TestOptLib, test_indexed_slices) {
   ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
   ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
 }
+
+TEST_F(TestOptLib, test_sparse_tensor) {
+  FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_indices");
+  FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_indices");
+  FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_values");
+  FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_values");
+  FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_dense_shape");
+  FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_dense_shape");
+  auto patterns = std::vector<SubstitutionPtr>({irpass.sparse_tensor_eliminate_});
+  ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
+  ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
+  ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
index 16c557adbe2..369dfb3316d 100644
--- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
@@ -1163,3 +1163,38 @@ def test_indexed_slices(tag):
         return z
 
     return fns[tag]
+
+
+def test_sparse_tensor(tag):
+    """ test_sparse_tensor """
+    fns = FnDict()
+    make_sparse_tensor = Primitive('MakeSparseTensor')
+    sparse_tensor_get_values = Primitive('SparseTensorGetValues')
+    sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
+    sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
+
+    @fns
+    def before_get_indices(x, y, z):
+        return sparse_tensor_get_indices(make_sparse_tensor(x, y, z))
+
+    @fns
+    def after_get_indices(x, y, z):
+        return x
+
+    @fns
+    def before_get_values(x, y, z):
+        return sparse_tensor_get_values(make_sparse_tensor(x, y, z))
+
+    @fns
+    def after_get_values(x, y, z):
+        return y
+
+    @fns
+    def before_get_dense_shape(x, y, z):
+        return sparse_tensor_get_dense_shape(make_sparse_tensor(x, y, z))
+
+    @fns
+    def after_get_dense_shape(x, y, z):
+        return z
+
+    return fns[tag]
diff --git a/tests/ut/python/ir/test_indexed_slices.py b/tests/ut/python/ir/test_indexed_slices.py
index 36dfe464cb4..ff0cfa1da5f 100644
--- a/tests/ut/python/ir/test_indexed_slices.py
+++ b/tests/ut/python/ir/test_indexed_slices.py
@@ -35,6 +35,9 @@
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from mindspore.nn import Optimizer
 from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.nn.optim import Momentum
+from mindspore.train import Model
+from ....dataset_mock import MindData
 
 context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
 
@@ -47,6 +50,40 @@
 size_op = P.Size()
 invert_permutation = P.InvertPermutation()
 logical_and = P.LogicalAnd()
 
+def get_axis(x):
+    shape = shape_op(x)
+    length = F.tuple_len(shape)
+    perm = F.make_range(0, length)
+    return perm
+
+class MSELoss(nn.Cell):
+    def __init__(self):
+        super(MSELoss, self).__init__()
+        self.reduce_sum = P.ReduceSum()
+        self.square = P.Square()
+        self.reduce_mean = P.ReduceMean()
+
+    def construct(self, data, label):
+        diff = data - label
+        return self.reduce_mean(self.square(diff), get_axis(diff))
+
+
+class MindDataSet(MindData):
+    def __init__(self, dataset_types, dataset_shapes):
+        super(MindDataSet, self).__init__(size=2, batch_size=32,
+                                          np_types=dataset_types,
+                                          output_shapes=dataset_shapes,
+                                          input_indexs=(0, 1))
+    def __next__(self):
+        if self._size < self._iter_num:
+            raise StopIteration
+        self._iter_num += 1
+        lst = []
+        for shape_, type_ in zip(self._output_shapes, self._np_types):
+            lst.append(Tensor(np.ones(shape_).astype(type_)))
+        return tuple(lst)
+
+
 @constexpr
 def _generate_shape_index(out_shape, indices_shape, axis):
     out_rank = len(out_shape)
@@ -189,8 +226,8 @@ def test_indexed_slices_make_indexed_slices():
         def construct(self, indices, values):
             ret = (IndexedSlices(indices, values, self.dense_shape),)
             return ret[0]
-    indices = Tensor([[0, 0], [1, 2]])
-    values = Tensor([1, 2], dtype=ms.float32)
+    indices = Tensor([1, 2])
+    values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
     MakeIndexedSlices()(indices, values)
 
 
@@ -202,8 +239,8 @@ def test_indexed_slices_attr():
         def construct(self, indices, values):
             x = IndexedSlices(indices, values, self.dense_shape)
             return x.values(), x.indices(), x.dense_shape()
-    indices = Tensor([[0, 0], [1, 2]])
-    values = Tensor([1, 2], dtype=ms.float32)
+    indices = Tensor([0])
+    values = Tensor([[1, 2]], dtype=ms.float32)
     IndexedSlicesGetAttr()(indices, values)
 
 
@@ -279,3 +316,29 @@ def test_indexed_slices_env_get():
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     train_network(inputs, label)
+
+
+def test_indexed_slices_model_train():
+    class Net(nn.Cell):
+        def __init__(self, in_features, out_features):
+            super(Net, self).__init__()
+            self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
+            self.add = P.TensorAdd()
+            self.cast = P.Cast()
+            self.flag = True
+
+        def construct(self, inputs, label):
+            x = self.add(inputs, self.weight)
+            if self.flag:
+                x = self.cast(x, mstype.float32)
+            return x
+
+    dataset_types = (np.float32, np.float32)
+    dataset_shapes = ((16, 16), (16, 16))
+    dataset = MindDataSet(dataset_types, dataset_shapes)
+    net = Net(16, 16)
+    net.set_train()
+
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    model = Model(net, optimizer=optimizer)
+    model.train(2, dataset, dataset_sink_mode=False)
diff --git a/tests/ut/python/ir/test_sparse_tensor.py b/tests/ut/python/ir/test_sparse_tensor.py
new file mode 100644
index 00000000000..3f8ca8b184c
--- /dev/null
+++ b/tests/ut/python/ir/test_sparse_tensor.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""
+@File  : test_sparse_tensor.py
+@Author:
+@Date  : 2020-07-16
+@Desc  : test mindspore sparse_tensor's operation
+"""
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore.ops import composite as C
+from mindspore import Tensor, SparseTensor, context
+
+context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
+
+
+def test_sparse_tensor_make_sparse_tensor():
+    class MakeSparseTensor(nn.Cell):
+        def __init__(self):
+            super(MakeSparseTensor, self).__init__()
+            self.dense_shape = (3, 4)
+        def construct(self, indices, values):
+            ret = (SparseTensor(indices, values, self.dense_shape),)
+            return ret[0]
+    indices = Tensor([[0, 1], [1, 2]])
+    values = Tensor([1, 2], dtype=ms.float32)
+    MakeSparseTensor()(indices, values)
+
+
+def test_sparse_tensor_attr():
+    grad_op = C.GradOperation('get_all', get_all=True)
+    class GradWrap(nn.Cell):
+        def __init__(self, network):
+            super(GradWrap, self).__init__()
+            self.network = network
+        def construct(self, input1, input2):
+            gout = grad_op(self.network)(input1, input2)
+            return gout
+
+    class SparseTensorGetAttr(nn.Cell):
+        def __init__(self):
+            super(SparseTensorGetAttr, self).__init__()
+            self.dense_shape = (3, 4)
+        def construct(self, indices, values):
+            x = SparseTensor(indices, values, self.dense_shape)
+            return x.values(), x.indices(), x.dense_shape()
+
+    indices = Tensor([[0, 1], [1, 2]])
+    values = Tensor([1, 2], dtype=ms.float32)
+    SparseTensorGetAttr()(indices, values)
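A natural companion to the two tests above would exercise the dense_shape validation added in InferImplMakeSparseTensor; a hedged sketch, where the test name and the exact exception type surfaced to Python are assumptions:

```python
def test_sparse_tensor_wrong_dense_shape():
    """Hypothetical: a non-integer dense_shape element should be rejected."""
    class MakeBadSparseTensor(nn.Cell):
        def __init__(self):
            super(MakeBadSparseTensor, self).__init__()
            self.dense_shape = (3.0, 4)  # float element: "element type ... must be Int"
        def construct(self, indices, values):
            return SparseTensor(indices, values, self.dense_shape)
    indices = Tensor([[0, 1], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    try:
        MakeBadSparseTensor()(indices, values)
    except TypeError:
        pass
```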