forked from mindspore-Ecosystem/mindspore
!3114 add coo_tensor
Merge pull request !3114 from riemann_penn/coo_tensor
This commit is contained in:
commit
4a19e6b8cb
|
@ -17,7 +17,7 @@
|
|||
"""Resources for ast tree parse."""
|
||||
import ast
|
||||
import math
|
||||
from mindspore import IndexedSlices
|
||||
from mindspore import IndexedSlices, SparseTensor
|
||||
from mindspore.ops.composite import multitype_ops
|
||||
from mindspore.ops import functional as F, composite as C
|
||||
from . import standard_method as M
|
||||
|
@ -140,4 +140,5 @@ convert_object_map = {
|
|||
|
||||
# user defined
|
||||
IndexedSlices: F.make_indexed_slices,
|
||||
SparseTensor: F.make_sparse_tensor,
|
||||
}
|
||||
|
|
|
@ -124,6 +124,8 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
|
|||
// Do Nothing
|
||||
} else if (type->isa<UndeterminedType>()) {
|
||||
// Do Nothing
|
||||
} else if (type->isa<SparseTensorType>()) {
|
||||
// Do Nothing
|
||||
} else if (type->isa<Tuple>()) {
|
||||
TuplePtr tuple_type = dyn_cast<Tuple>(type);
|
||||
type_proto->set_data_type(irpb::DT_TUPLE);
|
||||
|
|
|
@ -803,6 +803,18 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
|
|||
abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
|
||||
abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
|
||||
if (a_tuple == nullptr || b_tuple == nullptr) {
|
||||
TypePtrList types;
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
|
||||
[](const AbstractBasePtr &arg) -> TypePtr {
|
||||
MS_EXCEPTION_IF_NULL(arg);
|
||||
return arg->BuildType();
|
||||
});
|
||||
auto stub = GenerateStubFunc(types);
|
||||
if (stub != nullptr) {
|
||||
MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
|
||||
<< ", function: " << stub->ToString();
|
||||
return stub;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", "
|
||||
<< args_spec_list[1]->ToString();
|
||||
}
|
||||
|
|
|
@ -119,42 +119,6 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
|
|||
return py::none();
|
||||
}
|
||||
|
||||
FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
if (!enable_sparse) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
ParameterPtr undetermined_param = nullptr;
|
||||
auto stub = std::make_shared<FuncGraph>();
|
||||
for (size_t i = 0; i < types.size(); ++i) {
|
||||
auto param = stub->add_parameter();
|
||||
parameters.push_back(param);
|
||||
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
|
||||
undetermined_param = param;
|
||||
}
|
||||
}
|
||||
if (undetermined_param != nullptr) {
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
for (size_t i = 0; i < types.size(); ++i) {
|
||||
if (types[i]->type_id() == kObjectTypeFunction) {
|
||||
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
|
||||
inputs.push_back(stub->NewCNode(call_prim));
|
||||
} else {
|
||||
inputs.push_back(parameters[i]);
|
||||
}
|
||||
}
|
||||
auto stub_output = stub->NewCNode(inputs);
|
||||
stub->set_output(stub_output);
|
||||
stub->set_stub(true);
|
||||
return stub;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
|
||||
auto py_fn = SignMatch(types);
|
||||
std::ostringstream buffer;
|
||||
|
|
|
@ -283,6 +283,11 @@ const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeInd
|
|||
const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
|
||||
const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
|
||||
const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
|
||||
const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
|
||||
|
||||
// SparseTensor
|
||||
const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
|
||||
const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
|
||||
const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
|
||||
const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -292,7 +292,12 @@ extern const PrimitivePtr kPrimMakeIndexedSlices;
|
|||
extern const PrimitivePtr kPrimIndexedSlicesGetValues;
|
||||
extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
|
||||
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
|
||||
extern const PrimitivePtr kPrimIsIndexedSlices;
|
||||
|
||||
// SparseTensor
|
||||
extern const PrimitivePtr kPrimMakeSparseTensor;
|
||||
extern const PrimitivePtr kPrimSparseTensorGetValues;
|
||||
extern const PrimitivePtr kPrimSparseTensorGetIndices;
|
||||
extern const PrimitivePtr kPrimSparseTensorGetDenseShape;
|
||||
|
||||
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
|
||||
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
|
||||
|
|
|
@ -349,6 +349,26 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
|
|||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
|
||||
|
||||
auto indices_dtype = indices->element()->BuildType();
|
||||
if (!indices_dtype->isa<Int>()) {
|
||||
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
|
||||
}
|
||||
auto indices_shp = indices->shape()->shape();
|
||||
if (indices_shp.size() != 1) {
|
||||
MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
|
||||
<< " dimension tensor";
|
||||
}
|
||||
auto values_shp = values->shape()->shape();
|
||||
if (indices_shp[0] != values_shp[0]) {
|
||||
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
|
||||
<< values_shp[0] << ", but got " << indices_shp[0];
|
||||
}
|
||||
|
||||
for (auto elem_type : dense_shape->ElementsType()) {
|
||||
if (!elem_type->isa<Int>()) {
|
||||
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
|
||||
}
|
||||
}
|
||||
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_value);
|
||||
auto shp = dense_shape_value->value();
|
||||
|
@ -358,6 +378,12 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
|
|||
auto elem = GetValue<int>(e);
|
||||
return elem;
|
||||
});
|
||||
for (auto dense_shape_elem : dense_shape_vec) {
|
||||
if (dense_shape_elem < 0) {
|
||||
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
|
||||
<< dense_shape_value->ToString();
|
||||
}
|
||||
}
|
||||
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
|
||||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
|
@ -395,16 +421,89 @@ AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, c
|
|||
return indexed_slices->dense_shape();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 3);
|
||||
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
|
||||
|
||||
auto indices_dtype = indices->element()->BuildType();
|
||||
if (!indices_dtype->isa<Int>()) {
|
||||
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
|
||||
}
|
||||
auto indices_shp = indices->shape()->shape();
|
||||
if (indices_shp.size() != 2) {
|
||||
MS_EXCEPTION(TypeError) << "Indices must be a 2 dimension tensor, but got a " << indices_shp.size()
|
||||
<< " dimension tensor";
|
||||
}
|
||||
auto values_shp = values->shape()->shape();
|
||||
if (values_shp.size() != 1) {
|
||||
MS_EXCEPTION(TypeError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
|
||||
<< " dimension tensor";
|
||||
}
|
||||
if (indices_shp[0] != values_shp[0]) {
|
||||
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
|
||||
<< values_shp[0] << ", but got " << indices_shp[0];
|
||||
}
|
||||
|
||||
for (auto elem_type : dense_shape->ElementsType()) {
|
||||
if (!elem_type->isa<Int>()) {
|
||||
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
|
||||
}
|
||||
}
|
||||
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
|
||||
MS_EXCEPTION_IF_NULL(dense_shape_value);
|
||||
auto shp = dense_shape_value->value();
|
||||
std::vector<int> dense_shape_vec;
|
||||
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
|
||||
[](const ValuePtr &e) -> int {
|
||||
auto elem = GetValue<int>(e);
|
||||
return elem;
|
||||
});
|
||||
for (auto dense_shape_elem : dense_shape_vec) {
|
||||
if (dense_shape_elem < 0) {
|
||||
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
|
||||
<< dense_shape_value->ToString();
|
||||
}
|
||||
}
|
||||
auto ret = std::make_shared<AbstractSparseTensor>(values->element()->BuildType(), dense_shape_vec);
|
||||
ret->set_indices(indices);
|
||||
ret->set_values(values);
|
||||
ret->set_dense_shape(dense_shape);
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
bool ret = false;
|
||||
if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
|
||||
ret = true;
|
||||
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(sparse_tensor->values());
|
||||
return sparse_tensor->values();
|
||||
}
|
||||
MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
|
||||
return std::make_shared<AbstractScalar>(ret);
|
||||
|
||||
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
|
||||
return sparse_tensor->indices();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: two tensors and a tuple.
|
||||
const std::string op_name = primitive->name();
|
||||
CheckArgsSize(op_name, args_spec_list, 1);
|
||||
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
|
||||
MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape());
|
||||
return sparse_tensor->dense_shape();
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -264,7 +264,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re
|
|||
return IsPrimitiveCNode(user.first, prim);
|
||||
});
|
||||
if (cnode == users.end()) {
|
||||
MS_LOG(EXCEPTION) << "Fail to find cnode.";
|
||||
MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
|
||||
}
|
||||
auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
|
||||
|
||||
|
|
|
@ -43,6 +43,7 @@
|
|||
#include "frontend/optimizer/irpass/transpose_eliminate.h"
|
||||
#include "frontend/optimizer/opt.h"
|
||||
#include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -159,6 +160,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
indexed_slices_eliminate_ = MakeSubstitution(
|
||||
std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
|
||||
{prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
|
||||
|
||||
// SparseTensor Eliminate
|
||||
sparse_tensor_eliminate_ = MakeSubstitution(
|
||||
std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
|
||||
{prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
|
||||
}
|
||||
|
||||
ResolveIRPassLib::ResolveIRPassLib() {
|
||||
|
|
|
@ -107,6 +107,9 @@ class OptimizeIRPassLib {
|
|||
|
||||
// IndexedSlices Eliminate
|
||||
SubstitutionPtr indexed_slices_eliminate_;
|
||||
|
||||
// SparseTensor Eliminate
|
||||
SubstitutionPtr sparse_tensor_eliminate_;
|
||||
};
|
||||
|
||||
// the collection of irpass for resolve action
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}}
|
||||
// {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}}
|
||||
// {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}}
|
||||
class SparseTensorEliminater : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
Reset();
|
||||
AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node);
|
||||
|
||||
if (is_match_) {
|
||||
return tuple_->input(1);
|
||||
}
|
||||
AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node);
|
||||
|
||||
if (is_match_) {
|
||||
return tuple_->input(2);
|
||||
}
|
||||
AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node);
|
||||
|
||||
if (is_match_) {
|
||||
return tuple_->input(3);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Visit(const CNodePtr &cnode) override {
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) {
|
||||
tuple_ = cnode;
|
||||
is_match_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
tuple_ = nullptr;
|
||||
is_match_ = false;
|
||||
}
|
||||
|
||||
private:
|
||||
bool is_match_{false};
|
||||
CNodePtr tuple_{nullptr};
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
|
|
@ -157,6 +157,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
|||
irpass.make_ref_eliminate_,
|
||||
irpass.get_ref_param_eliminate_,
|
||||
irpass.indexed_slices_eliminate_,
|
||||
irpass.sparse_tensor_eliminate_,
|
||||
});
|
||||
OptPassGroupMap map({
|
||||
{"b_1", b_1},
|
||||
|
|
|
@ -179,6 +179,12 @@ MethodMap &GetMethodMap() {
|
|||
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
|
||||
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape
|
||||
}},
|
||||
{kObjectTypeSparseTensorType,
|
||||
{
|
||||
{"values", prim::kPrimSparseTensorGetValues}, // F.sparse_tensor_get_values
|
||||
{"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices
|
||||
{"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape
|
||||
}},
|
||||
{kObjectTypeJTagged, {}},
|
||||
{kObjectTypeSymbolicKeyType, {}},
|
||||
{kObjectTypeEnvType, {}}};
|
||||
|
|
|
@ -138,7 +138,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
|
||||
{prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
|
||||
{prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
|
||||
{prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}},
|
||||
// SparseTensor
|
||||
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
|
||||
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
|
||||
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
|
||||
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
|
||||
};
|
||||
return prim_eval_implement_map;
|
||||
}
|
||||
|
|
|
@ -358,7 +358,13 @@ AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, cons
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -36,6 +36,7 @@ using mindspore::abstract::AbstractIndexedSlices;
|
|||
using mindspore::abstract::AbstractJTagged;
|
||||
using mindspore::abstract::AbstractList;
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractSparseTensor;
|
||||
using mindspore::abstract::AbstractTensor;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
using mindspore::abstract::AbstractType;
|
||||
|
@ -95,7 +96,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
|
|||
|
||||
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
|
||||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
|
||||
ptrBase->isa<abstract::AbstractRefKey>()) {
|
||||
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,10 +17,10 @@ from . import dtype
|
|||
from .api import ms_function
|
||||
from .dtype import *
|
||||
from .parameter import Parameter, ParameterTuple
|
||||
from .tensor import MetaTensor, Tensor, IndexedSlices
|
||||
from .tensor import MetaTensor, Tensor, IndexedSlices, SparseTensor
|
||||
|
||||
__all__ = [
|
||||
"MetaTensor", "Tensor", "IndexedSlices", # tensor
|
||||
"MetaTensor", "Tensor", "IndexedSlices", "SparseTensor", # tensor
|
||||
'ms_function', # api
|
||||
'Parameter', 'ParameterTuple', # parameter
|
||||
"dtype"
|
||||
|
|
|
@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
|
|||
from . import dtype as mstype
|
||||
from ._register_for_tensor import tensor_operator_registry
|
||||
|
||||
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']
|
||||
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices', 'SparseTensor']
|
||||
np_types = (np.int8, np.int16, np.int32, np.int64,
|
||||
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
|
||||
np.float32, np.float64, np.bool_)
|
||||
|
@ -211,3 +211,7 @@ class Tensor(Tensor_):
|
|||
class IndexedSlices:
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
raise NotImplementedError
|
||||
|
||||
class SparseTensor:
|
||||
def __init__(self, indices, values, dense_shape):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1093,5 +1093,64 @@ std::string AbstractIndexedSlices::ToString() const {
|
|||
<< ", dense_shape: " << dense_shape_->ToString();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
// SparseTensor
|
||||
TypePtr AbstractSparseTensor::BuildType() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
TypePtr element_type = element()->BuildType();
|
||||
return std::make_shared<SparseTensorType>(element_type);
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractSparseTensor::Clone() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto clone = std::make_shared<AbstractSparseTensor>(element()->Clone());
|
||||
ShapePtr shp = shape();
|
||||
clone->set_shape(shp->Clone());
|
||||
clone->set_value(GetValueTrack());
|
||||
clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
|
||||
clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
|
||||
clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
|
||||
return clone;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractSparseTensor::Broaden() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
|
||||
auto shp = shape();
|
||||
broaden->set_shape(shp->Clone());
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
|
||||
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
|
||||
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
AbstractBasePtr AbstractSparseTensor::BroadenWithShape() const {
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
|
||||
auto shp = shape()->Clone();
|
||||
shp->Broaden();
|
||||
broaden->set_shape(shp);
|
||||
broaden->set_value(kAnyValue);
|
||||
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
|
||||
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
|
||||
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
|
||||
return broaden;
|
||||
}
|
||||
|
||||
std::string AbstractSparseTensor::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
BaseShapePtr shape_track = GetShapeTrack();
|
||||
MS_EXCEPTION_IF_NULL(shape_track);
|
||||
MS_EXCEPTION_IF_NULL(element());
|
||||
auto value_track = GetValueTrack();
|
||||
MS_EXCEPTION_IF_NULL(value_track);
|
||||
buffer << type_name() << "("
|
||||
<< "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
|
||||
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
|
||||
<< ", indices: " << indices_->ToString() << ", values" << values_->ToString()
|
||||
<< ", dense_shape: " << dense_shape_->ToString();
|
||||
return buffer.str();
|
||||
}
|
||||
} // namespace abstract
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -604,10 +604,39 @@ class AbstractIndexedSlices : public AbstractUndetermined {
|
|||
MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
|
||||
|
||||
const AbstractTensorPtr indices() const { return indices_; }
|
||||
const AbstractTensorPtr values() const { return values_; }
|
||||
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
|
||||
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
|
||||
const AbstractTensorPtr values() const { return values_; }
|
||||
void set_values(const AbstractTensorPtr &values) { values_ = values; }
|
||||
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
|
||||
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
|
||||
TypePtr BuildType() const override;
|
||||
AbstractBasePtr Clone() const override;
|
||||
AbstractBasePtr Broaden() const override;
|
||||
AbstractBasePtr BroadenWithShape() const;
|
||||
|
||||
std::string ToString() const override;
|
||||
|
||||
private:
|
||||
AbstractTensorPtr indices_;
|
||||
AbstractTensorPtr values_;
|
||||
AbstractTuplePtr dense_shape_;
|
||||
};
|
||||
|
||||
// SparseTensor
|
||||
class AbstractSparseTensor : public AbstractUndetermined {
|
||||
public:
|
||||
explicit AbstractSparseTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
|
||||
: AbstractUndetermined(element, shape) {}
|
||||
AbstractSparseTensor(const TypePtr &element_type, const std::vector<int> &shape)
|
||||
: AbstractUndetermined(element_type, shape) {}
|
||||
~AbstractSparseTensor() override = default;
|
||||
MS_DECLARE_PARENT(AbstractSparseTensor, AbstractUndetermined)
|
||||
|
||||
const AbstractTensorPtr indices() const { return indices_; }
|
||||
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
|
||||
const AbstractTensorPtr values() const { return values_; }
|
||||
void set_values(const AbstractTensorPtr &values) { values_ = values; }
|
||||
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
|
||||
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
|
||||
TypePtr BuildType() const override;
|
||||
AbstractBasePtr Clone() const override;
|
||||
|
|
|
@ -67,6 +67,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Type)
|
|||
ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(Class)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
|
||||
ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
|
||||
|
||||
template <typename T>
|
||||
|
|
|
@ -221,6 +221,48 @@ bool IndexedSlicesType::operator==(const Type &other) const {
|
|||
return *element_type_ == *other_elem_type;
|
||||
}
|
||||
|
||||
TypePtr SparseTensorType::DeepCopy() const {
|
||||
MS_EXCEPTION_IF_NULL(element_type_);
|
||||
if (IsGeneric()) {
|
||||
return std::make_shared<SparseTensorType>();
|
||||
}
|
||||
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
|
||||
}
|
||||
|
||||
std::string SparseTensorType::ToReprString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "SparseTensor";
|
||||
}
|
||||
return "SparseTensor[" + element_type_->ToReprString() + "]";
|
||||
}
|
||||
|
||||
std::string SparseTensorType::ToString() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "SparseTensor";
|
||||
}
|
||||
return "SparseTensor[" + element_type_->ToString() + "]";
|
||||
}
|
||||
|
||||
std::string SparseTensorType::DumpText() const {
|
||||
if (element_type_ == nullptr) {
|
||||
return "SparseTensor";
|
||||
}
|
||||
return "SparseTensor[" + element_type_->DumpText() + "]";
|
||||
}
|
||||
|
||||
bool SparseTensorType::operator==(const Type &other) const {
|
||||
if (!IsSameObjectType(*this, other)) {
|
||||
return false;
|
||||
}
|
||||
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
|
||||
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||
return true;
|
||||
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return *element_type_ == *other_elem_type;
|
||||
}
|
||||
|
||||
Function::Function() : Object(kObjectTypeFunction) {
|
||||
args_ = std::vector<TypePtr>();
|
||||
retval_ = nullptr;
|
||||
|
|
|
@ -177,6 +177,29 @@ class IndexedSlicesType : public Object {
|
|||
};
|
||||
using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
|
||||
|
||||
class SparseTensorType : public Object {
|
||||
public:
|
||||
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
|
||||
explicit SparseTensorType(const TypePtr &ele)
|
||||
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||
~SparseTensorType() override = default;
|
||||
MS_DECLARE_PARENT(SparseTensorType, Object)
|
||||
|
||||
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
|
||||
const TypePtr element() const { return element_type_; }
|
||||
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||
|
||||
TypePtr DeepCopy() const override;
|
||||
std::string ToString() const override;
|
||||
std::string ToReprString() const override;
|
||||
std::string DumpText() const override;
|
||||
bool operator==(const Type &other) const override;
|
||||
|
||||
private:
|
||||
TypePtr element_type_;
|
||||
};
|
||||
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
|
||||
|
||||
class Function : public Object {
|
||||
public:
|
||||
Function();
|
||||
|
|
|
@ -117,6 +117,8 @@ const char *ObjectIdLabel(const TypeId &v) {
|
|||
return "kObjectTypeTensorType";
|
||||
case kObjectTypeIndexedSlicesType:
|
||||
return "kObjectTypeIndexedSlicesType";
|
||||
case kObjectTypeSparseTensorType:
|
||||
return "kObjectTypeSparseTensorType";
|
||||
case kObjectTypeUndeterminedType:
|
||||
return "kObjectTypeUndeterminedType";
|
||||
case kObjectTypeDictionary:
|
||||
|
|
|
@ -51,6 +51,7 @@ enum TypeId : int {
|
|||
kObjectTypeKeyword,
|
||||
kObjectTypeTensorType,
|
||||
kObjectTypeIndexedSlicesType,
|
||||
kObjectTypeSparseTensorType,
|
||||
kObjectTypeUndeterminedType,
|
||||
kObjectTypeClass,
|
||||
kObjectTypeDictionary,
|
||||
|
|
|
@ -207,6 +207,23 @@ TypePtr IndexedSlicesStrToType(const std::string &type_name) {
|
|||
return std::make_shared<IndexedSlicesType>(element_type);
|
||||
}
|
||||
|
||||
TypePtr SparseTensorStrToType(const std::string &type_name) {
|
||||
if (type_name == "SparseTensor") {
|
||||
return std::make_shared<SparseTensorType>();
|
||||
}
|
||||
auto start = type_name.find_first_of('[') + 1;
|
||||
auto end = type_name.find_last_of(']');
|
||||
if (start >= type_name.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
auto element_str = type_name.substr(start, end - start);
|
||||
auto element_type = StringToType(element_str);
|
||||
if (element_type == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return std::make_shared<SparseTensorType>(element_type);
|
||||
}
|
||||
|
||||
TypePtr UndeterminedStrToType(const std::string &type_name) {
|
||||
if (type_name == "Undetermined") {
|
||||
return std::make_shared<UndeterminedType>();
|
||||
|
@ -349,6 +366,8 @@ TypePtr StringToType(const std::string &type_name) {
|
|||
type = UndeterminedStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
|
||||
type = IndexedSlicesStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) {
|
||||
type = SparseTensorStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
|
||||
type = ListStrToType(type_name);
|
||||
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
|
||||
|
@ -428,6 +447,7 @@ const TypePtr kTypeEnv = std::make_shared<EnvType>();
|
|||
const TypePtr kTypeType = std::make_shared<TypeType>();
|
||||
const TypePtr kTensorType = std::make_shared<TensorType>();
|
||||
const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
|
||||
const TypePtr kSparseTensorType = std::make_shared<SparseTensorType>();
|
||||
const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
|
||||
const TypePtr kString = std::make_shared<String>();
|
||||
const TypePtr kList = std::make_shared<List>();
|
||||
|
|
|
@ -139,6 +139,8 @@ REGISTER_PYBIND_DEFINE(
|
|||
}));
|
||||
(void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
|
||||
.def(py::init());
|
||||
(void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType")
|
||||
.def(py::init());
|
||||
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
|
||||
.def(py::init());
|
||||
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
|
||||
|
|
|
@ -17,9 +17,49 @@
|
|||
*/
|
||||
|
||||
#include "ir/meta_func_graph.h"
|
||||
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
||||
// namespace to support intermediate representation definition
|
||||
namespace mindspore {
|
||||
FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool enable_sparse = context->enable_sparse();
|
||||
if (!enable_sparse) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> parameters;
|
||||
ParameterPtr undetermined_param = nullptr;
|
||||
auto stub = std::make_shared<FuncGraph>();
|
||||
for (size_t i = 0; i < types.size(); ++i) {
|
||||
auto param = stub->add_parameter();
|
||||
parameters.push_back(param);
|
||||
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
|
||||
undetermined_param = param;
|
||||
}
|
||||
}
|
||||
if (undetermined_param != nullptr) {
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
for (size_t i = 0; i < types.size(); ++i) {
|
||||
if (types[i]->type_id() == kObjectTypeFunction) {
|
||||
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
|
||||
inputs.push_back(stub->NewCNode(call_prim));
|
||||
} else {
|
||||
inputs.push_back(parameters[i]);
|
||||
}
|
||||
}
|
||||
auto stub_output = stub->NewCNode(inputs);
|
||||
stub->set_output(stub_output);
|
||||
stub->set_stub(true);
|
||||
return stub;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
|
||||
TypePtrList types;
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
|
||||
|
|
|
@ -79,6 +79,7 @@ class MetaFuncGraph : public FuncGraphBase {
|
|||
std::shared_ptr<Derived> shared_from_base() {
|
||||
return std::static_pointer_cast<Derived>(shared_from_this());
|
||||
}
|
||||
FuncGraphPtr GenerateStubFunc(const TypePtrList &types);
|
||||
std::string name_;
|
||||
std::vector<Signature> signatures_;
|
||||
std::unordered_map<TypePtrList, FuncGraphPtr, TypeListHasher, TypeListEqual> cache_;
|
||||
|
|
|
@ -40,18 +40,12 @@ class ParamValue {
|
|||
const std::string &name() const { return name_; }
|
||||
void set_name(const std::string &name) { name_ = name; }
|
||||
|
||||
const std::string &sparse_grad() const { return sparse_grad_; }
|
||||
void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
|
||||
|
||||
bool requires_grad() const { return requires_grad_; }
|
||||
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
|
||||
|
||||
bool layerwise_parallel() const { return layerwise_parallel_; }
|
||||
void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }
|
||||
|
||||
bool has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
|
||||
void set_has_indexed_slices_grad(bool b) { has_indexed_slices_grad_ = b; }
|
||||
|
||||
// Whether the parameter clone from other parameter.
|
||||
bool cloned() const { return cloned_; }
|
||||
|
||||
|
@ -81,10 +75,8 @@ class ParamValue {
|
|||
private:
|
||||
tensor::MetaTensorPtr value_;
|
||||
std::string name_{"Parameter"};
|
||||
std::string sparse_grad_;
|
||||
bool requires_grad_{true};
|
||||
bool layerwise_parallel_{false};
|
||||
bool has_indexed_slices_grad_{false};
|
||||
bool be_cloned_{false};
|
||||
bool cloned_{false};
|
||||
std::vector<int32_t> be_cloned_index_;
|
||||
|
|
|
@ -29,14 +29,10 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
|
|||
.def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
|
||||
.def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
|
||||
&ParamValue::set_layerwise_parallel)
|
||||
.def_property("has_indexed_slices_grad", &ParamValue::has_indexed_slices_grad,
|
||||
&ParamValue::set_has_indexed_slices_grad)
|
||||
.def_property("sparse_grad", &ParamValue::sparse_grad, &ParamValue::set_sparse_grad)
|
||||
.def(py::pickle(
|
||||
[](const ParamValue &p) { // __getstate__
|
||||
return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
|
||||
p.layerwise_parallel(), p.has_indexed_slices_grad(),
|
||||
p.sparse_grad());
|
||||
p.layerwise_parallel());
|
||||
},
|
||||
[](const py::tuple &t) { // __setstate__
|
||||
if (t.size() != 6) {
|
||||
|
@ -47,8 +43,6 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
|
|||
p->set_name(t[1].cast<std::string>());
|
||||
p->set_requires_grad(t[2].cast<bool>());
|
||||
p->set_layerwise_parallel(t[3].cast<bool>());
|
||||
p->set_has_indexed_slices_grad(t[4].cast<bool>());
|
||||
p->set_sparse_grad(t[5].cast<std::string>());
|
||||
return p;
|
||||
}));
|
||||
}));
|
||||
|
|
|
@ -159,6 +159,10 @@ indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
|
|||
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
||||
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
|
||||
|
||||
make_sparse_tensor = Primitive('MakeSparseTensor')
|
||||
sparse_tensor_get_values = Primitive('SparseTensorGetValues')
|
||||
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
|
||||
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
|
||||
|
||||
tensor_operator_registry.register('__add__', tensor_add)
|
||||
tensor_operator_registry.register('__sub__', tensor_sub)
|
||||
|
|
|
@ -616,5 +616,18 @@ TEST_F(TestOptLib, test_indexed_slices) {
|
|||
ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
|
||||
}
|
||||
|
||||
TEST_F(TestOptLib, test_sparse_tensor) {
|
||||
FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_indices");
|
||||
FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_indices");
|
||||
FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_values");
|
||||
FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_values");
|
||||
FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_dense_shape");
|
||||
FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_dense_shape");
|
||||
auto patterns = std::vector<SubstitutionPtr>({irpass.sparse_tensor_eliminate_});
|
||||
ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
|
||||
ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1163,3 +1163,38 @@ def test_indexed_slices(tag):
|
|||
return z
|
||||
|
||||
return fns[tag]
|
||||
|
||||
|
||||
def test_sparse_tensor(tag):
|
||||
""" test_add_zero """
|
||||
fns = FnDict()
|
||||
make_sparse_tensor = Primitive('MakeSparseTensor')
|
||||
sparse_tensor_get_values = Primitive('SparseTensorGetValues')
|
||||
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
|
||||
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
|
||||
|
||||
@fns
|
||||
def before_get_indices(x, y, z):
|
||||
return sparse_tensor_get_indices(make_sparse_tensor(x, y, z))
|
||||
|
||||
@fns
|
||||
def after_get_indices(x, y, z):
|
||||
return x
|
||||
|
||||
@fns
|
||||
def before_get_values(x, y, z):
|
||||
return sparse_tensor_get_values(make_sparse_tensor(x, y, z))
|
||||
|
||||
@fns
|
||||
def after_get_values(x, y, z):
|
||||
return y
|
||||
|
||||
@fns
|
||||
def before_get_dense_shape(x, y, z):
|
||||
return sparse_tensor_get_dense_shape(make_sparse_tensor(x, y, z))
|
||||
|
||||
@fns
|
||||
def after_get_dense_shape(x, y, z):
|
||||
return z
|
||||
|
||||
return fns[tag]
|
||||
|
|
|
@ -35,6 +35,9 @@ from mindspore._checkparam import Validator as validator
|
|||
from mindspore._checkparam import Rel
|
||||
from mindspore.nn import Optimizer
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Momentum
|
||||
from mindspore.train import Model
|
||||
from ....dataset_mock import MindData
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||
|
||||
|
@ -47,6 +50,40 @@ size_op = P.Size()
|
|||
invert_permutation = P.InvertPermutation()
|
||||
logical_and = P.LogicalAnd()
|
||||
|
||||
def get_axis(x):
|
||||
shape = shape_op(x)
|
||||
length = F.tuple_len(shape)
|
||||
perm = F.make_range(0, length)
|
||||
return perm
|
||||
|
||||
class MSELoss(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MSELoss, self).__init__()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
self.square = P.Square()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
|
||||
def construct(self, data, label):
|
||||
diff = data - label
|
||||
return self.reduce_mean(self.square(diff), get_axis(diff))
|
||||
|
||||
|
||||
class MindDataSet(MindData):
|
||||
def __init__(self, dataset_types, dataset_shapes):
|
||||
super(MindDataSet, self).__init__(size=2, batch_size=32,
|
||||
np_types=dataset_types,
|
||||
output_shapes=dataset_shapes,
|
||||
input_indexs=(0, 1))
|
||||
def __next__(self):
|
||||
if self._size < self._iter_num:
|
||||
raise StopIteration
|
||||
self._iter_num += 1
|
||||
lst = []
|
||||
for shape_, type_ in zip(self._output_shapes, self._np_types):
|
||||
lst.append(Tensor(np.ones(shape_).astype(type_)))
|
||||
return tuple(lst)
|
||||
|
||||
|
||||
@constexpr
|
||||
def _generate_shape_index(out_shape, indices_shape, axis):
|
||||
out_rank = len(out_shape)
|
||||
|
@ -189,8 +226,8 @@ def test_indexed_slices_make_indexed_slices():
|
|||
def construct(self, indices, values):
|
||||
ret = (IndexedSlices(indices, values, self.dense_shape),)
|
||||
return ret[0]
|
||||
indices = Tensor([[0, 0], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
indices = Tensor([1, 2])
|
||||
values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
|
||||
MakeIndexedSlices()(indices, values)
|
||||
|
||||
|
||||
|
@ -202,8 +239,8 @@ def test_indexed_slices_attr():
|
|||
def construct(self, indices, values):
|
||||
x = IndexedSlices(indices, values, self.dense_shape)
|
||||
return x.values(), x.indices(), x.dense_shape()
|
||||
indices = Tensor([[0, 0], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
indices = Tensor([0])
|
||||
values = Tensor([[1, 2]], dtype=ms.float32)
|
||||
IndexedSlicesGetAttr()(indices, values)
|
||||
|
||||
|
||||
|
@ -279,3 +316,29 @@ def test_indexed_slices_env_get():
|
|||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
train_network(inputs, label)
|
||||
|
||||
|
||||
def test_indexed_slices_model_train():
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, in_features, out_features):
|
||||
super(Net, self).__init__()
|
||||
self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
|
||||
self.add = P.TensorAdd()
|
||||
self.cast = P.Cast()
|
||||
self.flag = True
|
||||
|
||||
def construct(self, inputs, label):
|
||||
x = self.add(inputs, self.weight)
|
||||
if self.flag:
|
||||
x = self.cast(x, mstype.float32)
|
||||
return x
|
||||
|
||||
dataset_types = (np.float32, np.float32)
|
||||
dataset_shapes = ((16, 16), (16, 16))
|
||||
dataset = MindDataSet(dataset_types, dataset_shapes)
|
||||
net = Net(16, 16)
|
||||
net.set_train()
|
||||
|
||||
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||
model = Model(net, optimizer=optimizer)
|
||||
model.train(2, dataset, dataset_sink_mode=False)
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
@File : test_sparse_tensor.py
|
||||
@Author:
|
||||
@Date : 2020-07-16
|
||||
@Desc : test mindspore sparse_tensor's operation
|
||||
"""
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import Tensor, SparseTensor, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||
|
||||
def test_sparse_tensor_make_sparse_tensor():
|
||||
class MakeSparseTensor(nn.Cell):
|
||||
def __init__(self):
|
||||
super(MakeSparseTensor, self).__init__()
|
||||
self.dense_shape = (3, 4)
|
||||
def construct(self, indices, values):
|
||||
ret = (SparseTensor(indices, values, self.dense_shape),)
|
||||
return ret[0]
|
||||
indices = Tensor([[0, 1], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
MakeSparseTensor()(indices, values)
|
||||
|
||||
|
||||
def test_sparse_tensor_attr():
|
||||
grad_op = C.GradOperation('get_all', get_all=True)
|
||||
class GradWrap(nn.Cell):
|
||||
def __init__(self, network):
|
||||
super(GradWrap, self).__init__()
|
||||
self.network = network
|
||||
def construct(self, input1, input2):
|
||||
gout = grad_op(self.network)(input1, input2)
|
||||
return gout
|
||||
|
||||
class SparseTensorGetAttr(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SparseTensorGetAttr, self).__init__()
|
||||
self.dense_shape = (3, 4)
|
||||
def construct(self, indices, values):
|
||||
x = SparseTensor(indices, values, self.dense_shape)
|
||||
return x.values(), x.indices(), x.dense_shape()
|
||||
|
||||
indices = Tensor([[0, 1], [1, 2]])
|
||||
values = Tensor([1, 2], dtype=ms.float32)
|
||||
SparseTensorGetAttr()(indices, values)
|
Loading…
Reference in New Issue