forked from mindspore-Ecosystem/mindspore
!3114 add coo_tensor
Merge pull request !3114 from riemann_penn/coo_tensor
This commit is contained in:
commit
4a19e6b8cb
|
@ -17,7 +17,7 @@
|
||||||
"""Resources for ast tree parse."""
|
"""Resources for ast tree parse."""
|
||||||
import ast
|
import ast
|
||||||
import math
|
import math
|
||||||
from mindspore import IndexedSlices
|
from mindspore import IndexedSlices, SparseTensor
|
||||||
from mindspore.ops.composite import multitype_ops
|
from mindspore.ops.composite import multitype_ops
|
||||||
from mindspore.ops import functional as F, composite as C
|
from mindspore.ops import functional as F, composite as C
|
||||||
from . import standard_method as M
|
from . import standard_method as M
|
||||||
|
@ -140,4 +140,5 @@ convert_object_map = {
|
||||||
|
|
||||||
# user defined
|
# user defined
|
||||||
IndexedSlices: F.make_indexed_slices,
|
IndexedSlices: F.make_indexed_slices,
|
||||||
|
SparseTensor: F.make_sparse_tensor,
|
||||||
}
|
}
|
||||||
|
|
|
@ -124,6 +124,8 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
|
||||||
// Do Nothing
|
// Do Nothing
|
||||||
} else if (type->isa<UndeterminedType>()) {
|
} else if (type->isa<UndeterminedType>()) {
|
||||||
// Do Nothing
|
// Do Nothing
|
||||||
|
} else if (type->isa<SparseTensorType>()) {
|
||||||
|
// Do Nothing
|
||||||
} else if (type->isa<Tuple>()) {
|
} else if (type->isa<Tuple>()) {
|
||||||
TuplePtr tuple_type = dyn_cast<Tuple>(type);
|
TuplePtr tuple_type = dyn_cast<Tuple>(type);
|
||||||
type_proto->set_data_type(irpb::DT_TUPLE);
|
type_proto->set_data_type(irpb::DT_TUPLE);
|
||||||
|
|
|
@ -803,6 +803,18 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
|
||||||
abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
|
abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
|
||||||
abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
|
abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
|
||||||
if (a_tuple == nullptr || b_tuple == nullptr) {
|
if (a_tuple == nullptr || b_tuple == nullptr) {
|
||||||
|
TypePtrList types;
|
||||||
|
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
|
||||||
|
[](const AbstractBasePtr &arg) -> TypePtr {
|
||||||
|
MS_EXCEPTION_IF_NULL(arg);
|
||||||
|
return arg->BuildType();
|
||||||
|
});
|
||||||
|
auto stub = GenerateStubFunc(types);
|
||||||
|
if (stub != nullptr) {
|
||||||
|
MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
|
||||||
|
<< ", function: " << stub->ToString();
|
||||||
|
return stub;
|
||||||
|
}
|
||||||
MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", "
|
MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", "
|
||||||
<< args_spec_list[1]->ToString();
|
<< args_spec_list[1]->ToString();
|
||||||
}
|
}
|
||||||
|
|
|
@ -119,42 +119,6 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
|
||||||
return py::none();
|
return py::none();
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
|
|
||||||
auto context = MsContext::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
bool enable_sparse = context->enable_sparse();
|
|
||||||
if (!enable_sparse) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<AnfNodePtr> parameters;
|
|
||||||
ParameterPtr undetermined_param = nullptr;
|
|
||||||
auto stub = std::make_shared<FuncGraph>();
|
|
||||||
for (size_t i = 0; i < types.size(); ++i) {
|
|
||||||
auto param = stub->add_parameter();
|
|
||||||
parameters.push_back(param);
|
|
||||||
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
|
|
||||||
undetermined_param = param;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (undetermined_param != nullptr) {
|
|
||||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
|
||||||
for (size_t i = 0; i < types.size(); ++i) {
|
|
||||||
if (types[i]->type_id() == kObjectTypeFunction) {
|
|
||||||
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
|
|
||||||
inputs.push_back(stub->NewCNode(call_prim));
|
|
||||||
} else {
|
|
||||||
inputs.push_back(parameters[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto stub_output = stub->NewCNode(inputs);
|
|
||||||
stub->set_output(stub_output);
|
|
||||||
stub->set_stub(true);
|
|
||||||
return stub;
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
|
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
|
||||||
auto py_fn = SignMatch(types);
|
auto py_fn = SignMatch(types);
|
||||||
std::ostringstream buffer;
|
std::ostringstream buffer;
|
||||||
|
|
|
@ -283,6 +283,11 @@ const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeInd
|
||||||
const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
|
const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
|
||||||
const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
|
const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
|
||||||
const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
|
const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
|
||||||
const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
|
|
||||||
|
// SparseTensor
|
||||||
|
const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
|
||||||
|
const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
|
||||||
|
const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
|
||||||
|
const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
|
||||||
} // namespace prim
|
} // namespace prim
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -292,7 +292,12 @@ extern const PrimitivePtr kPrimMakeIndexedSlices;
|
||||||
extern const PrimitivePtr kPrimIndexedSlicesGetValues;
|
extern const PrimitivePtr kPrimIndexedSlicesGetValues;
|
||||||
extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
|
extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
|
||||||
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
|
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
|
||||||
extern const PrimitivePtr kPrimIsIndexedSlices;
|
|
||||||
|
// SparseTensor
|
||||||
|
extern const PrimitivePtr kPrimMakeSparseTensor;
|
||||||
|
extern const PrimitivePtr kPrimSparseTensorGetValues;
|
||||||
|
extern const PrimitivePtr kPrimSparseTensorGetIndices;
|
||||||
|
extern const PrimitivePtr kPrimSparseTensorGetDenseShape;
|
||||||
|
|
||||||
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
|
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
|
||||||
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
|
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";
|
||||||
|
|
|
@ -349,6 +349,26 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
|
||||||
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||||
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
|
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
|
||||||
|
|
||||||
|
auto indices_dtype = indices->element()->BuildType();
|
||||||
|
if (!indices_dtype->isa<Int>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
|
||||||
|
}
|
||||||
|
auto indices_shp = indices->shape()->shape();
|
||||||
|
if (indices_shp.size() != 1) {
|
||||||
|
MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
|
||||||
|
<< " dimension tensor";
|
||||||
|
}
|
||||||
|
auto values_shp = values->shape()->shape();
|
||||||
|
if (indices_shp[0] != values_shp[0]) {
|
||||||
|
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
|
||||||
|
<< values_shp[0] << ", but got " << indices_shp[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto elem_type : dense_shape->ElementsType()) {
|
||||||
|
if (!elem_type->isa<Int>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
|
||||||
|
}
|
||||||
|
}
|
||||||
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
|
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(dense_shape_value);
|
MS_EXCEPTION_IF_NULL(dense_shape_value);
|
||||||
auto shp = dense_shape_value->value();
|
auto shp = dense_shape_value->value();
|
||||||
|
@ -358,6 +378,12 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
|
||||||
auto elem = GetValue<int>(e);
|
auto elem = GetValue<int>(e);
|
||||||
return elem;
|
return elem;
|
||||||
});
|
});
|
||||||
|
for (auto dense_shape_elem : dense_shape_vec) {
|
||||||
|
if (dense_shape_elem < 0) {
|
||||||
|
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
|
||||||
|
<< dense_shape_value->ToString();
|
||||||
|
}
|
||||||
|
}
|
||||||
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
|
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
|
||||||
ret->set_indices(indices);
|
ret->set_indices(indices);
|
||||||
ret->set_values(values);
|
ret->set_values(values);
|
||||||
|
@ -395,16 +421,89 @@ AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, c
|
||||||
return indexed_slices->dense_shape();
|
return indexed_slices->dense_shape();
|
||||||
}
|
}
|
||||||
|
|
||||||
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list) {
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: two tensors and a tuple.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 3);
|
||||||
|
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
|
||||||
|
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
|
||||||
|
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
|
||||||
|
|
||||||
|
auto indices_dtype = indices->element()->BuildType();
|
||||||
|
if (!indices_dtype->isa<Int>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
|
||||||
|
}
|
||||||
|
auto indices_shp = indices->shape()->shape();
|
||||||
|
if (indices_shp.size() != 2) {
|
||||||
|
MS_EXCEPTION(TypeError) << "Indices must be a 2 dimension tensor, but got a " << indices_shp.size()
|
||||||
|
<< " dimension tensor";
|
||||||
|
}
|
||||||
|
auto values_shp = values->shape()->shape();
|
||||||
|
if (values_shp.size() != 1) {
|
||||||
|
MS_EXCEPTION(TypeError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
|
||||||
|
<< " dimension tensor";
|
||||||
|
}
|
||||||
|
if (indices_shp[0] != values_shp[0]) {
|
||||||
|
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
|
||||||
|
<< values_shp[0] << ", but got " << indices_shp[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto elem_type : dense_shape->ElementsType()) {
|
||||||
|
if (!elem_type->isa<Int>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(dense_shape_value);
|
||||||
|
auto shp = dense_shape_value->value();
|
||||||
|
std::vector<int> dense_shape_vec;
|
||||||
|
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
|
||||||
|
[](const ValuePtr &e) -> int {
|
||||||
|
auto elem = GetValue<int>(e);
|
||||||
|
return elem;
|
||||||
|
});
|
||||||
|
for (auto dense_shape_elem : dense_shape_vec) {
|
||||||
|
if (dense_shape_elem < 0) {
|
||||||
|
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
|
||||||
|
<< dense_shape_value->ToString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto ret = std::make_shared<AbstractSparseTensor>(values->element()->BuildType(), dense_shape_vec);
|
||||||
|
ret->set_indices(indices);
|
||||||
|
ret->set_values(values);
|
||||||
|
ret->set_dense_shape(dense_shape);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: two tensors and a tuple.
|
||||||
const std::string op_name = primitive->name();
|
const std::string op_name = primitive->name();
|
||||||
CheckArgsSize(op_name, args_spec_list, 1);
|
CheckArgsSize(op_name, args_spec_list, 1);
|
||||||
bool ret = false;
|
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
|
||||||
if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
|
MS_EXCEPTION_IF_NULL(sparse_tensor->values());
|
||||||
ret = true;
|
return sparse_tensor->values();
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
|
|
||||||
return std::make_shared<AbstractScalar>(ret);
|
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: two tensors and a tuple.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 1);
|
||||||
|
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
|
||||||
|
MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
|
||||||
|
return sparse_tensor->indices();
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: two tensors and a tuple.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 1);
|
||||||
|
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
|
||||||
|
MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape());
|
||||||
|
return sparse_tensor->dense_shape();
|
||||||
}
|
}
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -264,7 +264,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re
|
||||||
return IsPrimitiveCNode(user.first, prim);
|
return IsPrimitiveCNode(user.first, prim);
|
||||||
});
|
});
|
||||||
if (cnode == users.end()) {
|
if (cnode == users.end()) {
|
||||||
MS_LOG(EXCEPTION) << "Fail to find cnode.";
|
MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
|
||||||
}
|
}
|
||||||
auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
|
auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,7 @@
|
||||||
#include "frontend/optimizer/irpass/transpose_eliminate.h"
|
#include "frontend/optimizer/irpass/transpose_eliminate.h"
|
||||||
#include "frontend/optimizer/opt.h"
|
#include "frontend/optimizer/opt.h"
|
||||||
#include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
|
#include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
|
||||||
|
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
@ -159,6 +160,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
||||||
indexed_slices_eliminate_ = MakeSubstitution(
|
indexed_slices_eliminate_ = MakeSubstitution(
|
||||||
std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
|
std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
|
||||||
{prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
|
{prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
|
||||||
|
|
||||||
|
// SparseTensor Eliminate
|
||||||
|
sparse_tensor_eliminate_ = MakeSubstitution(
|
||||||
|
std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
|
||||||
|
{prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
|
||||||
}
|
}
|
||||||
|
|
||||||
ResolveIRPassLib::ResolveIRPassLib() {
|
ResolveIRPassLib::ResolveIRPassLib() {
|
||||||
|
|
|
@ -107,6 +107,9 @@ class OptimizeIRPassLib {
|
||||||
|
|
||||||
// IndexedSlices Eliminate
|
// IndexedSlices Eliminate
|
||||||
SubstitutionPtr indexed_slices_eliminate_;
|
SubstitutionPtr indexed_slices_eliminate_;
|
||||||
|
|
||||||
|
// SparseTensor Eliminate
|
||||||
|
SubstitutionPtr sparse_tensor_eliminate_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// the collection of irpass for resolve action
|
// the collection of irpass for resolve action
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
|
||||||
|
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "frontend/optimizer/irpass.h"
|
||||||
|
#include "frontend/optimizer/optimizer.h"
|
||||||
|
#include "ir/visitor.h"
|
||||||
|
#include "frontend/operator/ops.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace irpass {
|
||||||
|
// {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}}
|
||||||
|
// {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}}
|
||||||
|
// {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}}
|
||||||
|
class SparseTensorEliminater : public AnfVisitor {
|
||||||
|
public:
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
|
Reset();
|
||||||
|
AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node);
|
||||||
|
|
||||||
|
if (is_match_) {
|
||||||
|
return tuple_->input(1);
|
||||||
|
}
|
||||||
|
AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node);
|
||||||
|
|
||||||
|
if (is_match_) {
|
||||||
|
return tuple_->input(2);
|
||||||
|
}
|
||||||
|
AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node);
|
||||||
|
|
||||||
|
if (is_match_) {
|
||||||
|
return tuple_->input(3);
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Visit(const CNodePtr &cnode) override {
|
||||||
|
if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) {
|
||||||
|
tuple_ = cnode;
|
||||||
|
is_match_ = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reset() {
|
||||||
|
tuple_ = nullptr;
|
||||||
|
is_match_ = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool is_match_{false};
|
||||||
|
CNodePtr tuple_{nullptr};
|
||||||
|
};
|
||||||
|
} // namespace irpass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
|
|
@ -157,6 +157,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
irpass.make_ref_eliminate_,
|
irpass.make_ref_eliminate_,
|
||||||
irpass.get_ref_param_eliminate_,
|
irpass.get_ref_param_eliminate_,
|
||||||
irpass.indexed_slices_eliminate_,
|
irpass.indexed_slices_eliminate_,
|
||||||
|
irpass.sparse_tensor_eliminate_,
|
||||||
});
|
});
|
||||||
OptPassGroupMap map({
|
OptPassGroupMap map({
|
||||||
{"b_1", b_1},
|
{"b_1", b_1},
|
||||||
|
|
|
@ -179,6 +179,12 @@ MethodMap &GetMethodMap() {
|
||||||
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
|
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
|
||||||
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape
|
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape
|
||||||
}},
|
}},
|
||||||
|
{kObjectTypeSparseTensorType,
|
||||||
|
{
|
||||||
|
{"values", prim::kPrimSparseTensorGetValues}, // F.sparse_tensor_get_values
|
||||||
|
{"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices
|
||||||
|
{"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape
|
||||||
|
}},
|
||||||
{kObjectTypeJTagged, {}},
|
{kObjectTypeJTagged, {}},
|
||||||
{kObjectTypeSymbolicKeyType, {}},
|
{kObjectTypeSymbolicKeyType, {}},
|
||||||
{kObjectTypeEnvType, {}}};
|
{kObjectTypeEnvType, {}}};
|
||||||
|
|
|
@ -138,7 +138,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
||||||
{prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
|
{prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
|
||||||
{prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
|
{prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
|
||||||
{prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
|
{prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
|
||||||
{prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}},
|
// SparseTensor
|
||||||
|
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
|
||||||
|
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
|
||||||
|
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
|
||||||
|
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
|
||||||
};
|
};
|
||||||
return prim_eval_implement_map;
|
return prim_eval_implement_map;
|
||||||
}
|
}
|
||||||
|
|
|
@ -358,7 +358,13 @@ AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, cons
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list);
|
||||||
|
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const AbstractBasePtrList &args_spec_list);
|
const AbstractBasePtrList &args_spec_list);
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -36,6 +36,7 @@ using mindspore::abstract::AbstractIndexedSlices;
|
||||||
using mindspore::abstract::AbstractJTagged;
|
using mindspore::abstract::AbstractJTagged;
|
||||||
using mindspore::abstract::AbstractList;
|
using mindspore::abstract::AbstractList;
|
||||||
using mindspore::abstract::AbstractScalar;
|
using mindspore::abstract::AbstractScalar;
|
||||||
|
using mindspore::abstract::AbstractSparseTensor;
|
||||||
using mindspore::abstract::AbstractTensor;
|
using mindspore::abstract::AbstractTensor;
|
||||||
using mindspore::abstract::AbstractTuple;
|
using mindspore::abstract::AbstractTuple;
|
||||||
using mindspore::abstract::AbstractType;
|
using mindspore::abstract::AbstractType;
|
||||||
|
@ -95,7 +96,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
|
||||||
|
|
||||||
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
|
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
|
||||||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
|
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
|
||||||
ptrBase->isa<abstract::AbstractRefKey>()) {
|
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,10 +17,10 @@ from . import dtype
|
||||||
from .api import ms_function
|
from .api import ms_function
|
||||||
from .dtype import *
|
from .dtype import *
|
||||||
from .parameter import Parameter, ParameterTuple
|
from .parameter import Parameter, ParameterTuple
|
||||||
from .tensor import MetaTensor, Tensor, IndexedSlices
|
from .tensor import MetaTensor, Tensor, IndexedSlices, SparseTensor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MetaTensor", "Tensor", "IndexedSlices", # tensor
|
"MetaTensor", "Tensor", "IndexedSlices", "SparseTensor", # tensor
|
||||||
'ms_function', # api
|
'ms_function', # api
|
||||||
'Parameter', 'ParameterTuple', # parameter
|
'Parameter', 'ParameterTuple', # parameter
|
||||||
"dtype"
|
"dtype"
|
||||||
|
|
|
@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
|
||||||
from . import dtype as mstype
|
from . import dtype as mstype
|
||||||
from ._register_for_tensor import tensor_operator_registry
|
from ._register_for_tensor import tensor_operator_registry
|
||||||
|
|
||||||
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']
|
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices', 'SparseTensor']
|
||||||
np_types = (np.int8, np.int16, np.int32, np.int64,
|
np_types = (np.int8, np.int16, np.int32, np.int64,
|
||||||
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
|
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
|
||||||
np.float32, np.float64, np.bool_)
|
np.float32, np.float64, np.bool_)
|
||||||
|
@ -211,3 +211,7 @@ class Tensor(Tensor_):
|
||||||
class IndexedSlices:
|
class IndexedSlices:
|
||||||
def __init__(self, indices, values, dense_shape):
|
def __init__(self, indices, values, dense_shape):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
class SparseTensor:
|
||||||
|
def __init__(self, indices, values, dense_shape):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
|
@ -1093,5 +1093,64 @@ std::string AbstractIndexedSlices::ToString() const {
|
||||||
<< ", dense_shape: " << dense_shape_->ToString();
|
<< ", dense_shape: " << dense_shape_->ToString();
|
||||||
return buffer.str();
|
return buffer.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SparseTensor
|
||||||
|
TypePtr AbstractSparseTensor::BuildType() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(element());
|
||||||
|
TypePtr element_type = element()->BuildType();
|
||||||
|
return std::make_shared<SparseTensorType>(element_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr AbstractSparseTensor::Clone() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(element());
|
||||||
|
auto clone = std::make_shared<AbstractSparseTensor>(element()->Clone());
|
||||||
|
ShapePtr shp = shape();
|
||||||
|
clone->set_shape(shp->Clone());
|
||||||
|
clone->set_value(GetValueTrack());
|
||||||
|
clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
|
||||||
|
clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
|
||||||
|
clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
|
||||||
|
return clone;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr AbstractSparseTensor::Broaden() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(element());
|
||||||
|
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
|
||||||
|
auto shp = shape();
|
||||||
|
broaden->set_shape(shp->Clone());
|
||||||
|
broaden->set_value(kAnyValue);
|
||||||
|
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
|
||||||
|
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
|
||||||
|
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
|
||||||
|
return broaden;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr AbstractSparseTensor::BroadenWithShape() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(element());
|
||||||
|
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
|
||||||
|
auto shp = shape()->Clone();
|
||||||
|
shp->Broaden();
|
||||||
|
broaden->set_shape(shp);
|
||||||
|
broaden->set_value(kAnyValue);
|
||||||
|
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
|
||||||
|
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
|
||||||
|
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
|
||||||
|
return broaden;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string AbstractSparseTensor::ToString() const {
|
||||||
|
std::ostringstream buffer;
|
||||||
|
BaseShapePtr shape_track = GetShapeTrack();
|
||||||
|
MS_EXCEPTION_IF_NULL(shape_track);
|
||||||
|
MS_EXCEPTION_IF_NULL(element());
|
||||||
|
auto value_track = GetValueTrack();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_track);
|
||||||
|
buffer << type_name() << "("
|
||||||
|
<< "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
|
||||||
|
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
|
||||||
|
<< ", indices: " << indices_->ToString() << ", values" << values_->ToString()
|
||||||
|
<< ", dense_shape: " << dense_shape_->ToString();
|
||||||
|
return buffer.str();
|
||||||
|
}
|
||||||
} // namespace abstract
|
} // namespace abstract
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -604,10 +604,39 @@ class AbstractIndexedSlices : public AbstractUndetermined {
|
||||||
MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
|
MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
|
||||||
|
|
||||||
const AbstractTensorPtr indices() const { return indices_; }
|
const AbstractTensorPtr indices() const { return indices_; }
|
||||||
const AbstractTensorPtr values() const { return values_; }
|
|
||||||
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
|
|
||||||
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
|
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
|
||||||
|
const AbstractTensorPtr values() const { return values_; }
|
||||||
void set_values(const AbstractTensorPtr &values) { values_ = values; }
|
void set_values(const AbstractTensorPtr &values) { values_ = values; }
|
||||||
|
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
|
||||||
|
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
|
||||||
|
TypePtr BuildType() const override;
|
||||||
|
AbstractBasePtr Clone() const override;
|
||||||
|
AbstractBasePtr Broaden() const override;
|
||||||
|
AbstractBasePtr BroadenWithShape() const;
|
||||||
|
|
||||||
|
std::string ToString() const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
AbstractTensorPtr indices_;
|
||||||
|
AbstractTensorPtr values_;
|
||||||
|
AbstractTuplePtr dense_shape_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// SparseTensor
|
||||||
|
class AbstractSparseTensor : public AbstractUndetermined {
|
||||||
|
public:
|
||||||
|
explicit AbstractSparseTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
|
||||||
|
: AbstractUndetermined(element, shape) {}
|
||||||
|
AbstractSparseTensor(const TypePtr &element_type, const std::vector<int> &shape)
|
||||||
|
: AbstractUndetermined(element_type, shape) {}
|
||||||
|
~AbstractSparseTensor() override = default;
|
||||||
|
MS_DECLARE_PARENT(AbstractSparseTensor, AbstractUndetermined)
|
||||||
|
|
||||||
|
const AbstractTensorPtr indices() const { return indices_; }
|
||||||
|
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
|
||||||
|
const AbstractTensorPtr values() const { return values_; }
|
||||||
|
void set_values(const AbstractTensorPtr &values) { values_ = values; }
|
||||||
|
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
|
||||||
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
|
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
|
||||||
TypePtr BuildType() const override;
|
TypePtr BuildType() const override;
|
||||||
AbstractBasePtr Clone() const override;
|
AbstractBasePtr Clone() const override;
|
||||||
|
|
|
@ -67,6 +67,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Type)
|
||||||
ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
|
ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
|
||||||
ABSTRACT_REPORT_NAME_TRAITS(Class)
|
ABSTRACT_REPORT_NAME_TRAITS(Class)
|
||||||
ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
|
ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
|
||||||
|
ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
|
||||||
ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
|
ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
@ -221,6 +221,48 @@ bool IndexedSlicesType::operator==(const Type &other) const {
|
||||||
return *element_type_ == *other_elem_type;
|
return *element_type_ == *other_elem_type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TypePtr SparseTensorType::DeepCopy() const {
|
||||||
|
MS_EXCEPTION_IF_NULL(element_type_);
|
||||||
|
if (IsGeneric()) {
|
||||||
|
return std::make_shared<SparseTensorType>();
|
||||||
|
}
|
||||||
|
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SparseTensorType::ToReprString() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "SparseTensor";
|
||||||
|
}
|
||||||
|
return "SparseTensor[" + element_type_->ToReprString() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SparseTensorType::ToString() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "SparseTensor";
|
||||||
|
}
|
||||||
|
return "SparseTensor[" + element_type_->ToString() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SparseTensorType::DumpText() const {
|
||||||
|
if (element_type_ == nullptr) {
|
||||||
|
return "SparseTensor";
|
||||||
|
}
|
||||||
|
return "SparseTensor[" + element_type_->DumpText() + "]";
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SparseTensorType::operator==(const Type &other) const {
|
||||||
|
if (!IsSameObjectType(*this, other)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
|
||||||
|
if (element_type_ == nullptr && other_elem_type == nullptr) {
|
||||||
|
return true;
|
||||||
|
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return *element_type_ == *other_elem_type;
|
||||||
|
}
|
||||||
|
|
||||||
Function::Function() : Object(kObjectTypeFunction) {
|
Function::Function() : Object(kObjectTypeFunction) {
|
||||||
args_ = std::vector<TypePtr>();
|
args_ = std::vector<TypePtr>();
|
||||||
retval_ = nullptr;
|
retval_ = nullptr;
|
||||||
|
|
|
@ -177,6 +177,29 @@ class IndexedSlicesType : public Object {
|
||||||
};
|
};
|
||||||
using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
|
using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
|
||||||
|
|
||||||
|
class SparseTensorType : public Object {
|
||||||
|
public:
|
||||||
|
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
|
||||||
|
explicit SparseTensorType(const TypePtr &ele)
|
||||||
|
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
|
||||||
|
~SparseTensorType() override = default;
|
||||||
|
MS_DECLARE_PARENT(SparseTensorType, Object)
|
||||||
|
|
||||||
|
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
|
||||||
|
const TypePtr element() const { return element_type_; }
|
||||||
|
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
|
||||||
|
|
||||||
|
TypePtr DeepCopy() const override;
|
||||||
|
std::string ToString() const override;
|
||||||
|
std::string ToReprString() const override;
|
||||||
|
std::string DumpText() const override;
|
||||||
|
bool operator==(const Type &other) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
TypePtr element_type_;
|
||||||
|
};
|
||||||
|
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
|
||||||
|
|
||||||
class Function : public Object {
|
class Function : public Object {
|
||||||
public:
|
public:
|
||||||
Function();
|
Function();
|
||||||
|
|
|
@ -117,6 +117,8 @@ const char *ObjectIdLabel(const TypeId &v) {
|
||||||
return "kObjectTypeTensorType";
|
return "kObjectTypeTensorType";
|
||||||
case kObjectTypeIndexedSlicesType:
|
case kObjectTypeIndexedSlicesType:
|
||||||
return "kObjectTypeIndexedSlicesType";
|
return "kObjectTypeIndexedSlicesType";
|
||||||
|
case kObjectTypeSparseTensorType:
|
||||||
|
return "kObjectTypeSparseTensorType";
|
||||||
case kObjectTypeUndeterminedType:
|
case kObjectTypeUndeterminedType:
|
||||||
return "kObjectTypeUndeterminedType";
|
return "kObjectTypeUndeterminedType";
|
||||||
case kObjectTypeDictionary:
|
case kObjectTypeDictionary:
|
||||||
|
|
|
@ -51,6 +51,7 @@ enum TypeId : int {
|
||||||
kObjectTypeKeyword,
|
kObjectTypeKeyword,
|
||||||
kObjectTypeTensorType,
|
kObjectTypeTensorType,
|
||||||
kObjectTypeIndexedSlicesType,
|
kObjectTypeIndexedSlicesType,
|
||||||
|
kObjectTypeSparseTensorType,
|
||||||
kObjectTypeUndeterminedType,
|
kObjectTypeUndeterminedType,
|
||||||
kObjectTypeClass,
|
kObjectTypeClass,
|
||||||
kObjectTypeDictionary,
|
kObjectTypeDictionary,
|
||||||
|
|
|
@ -207,6 +207,23 @@ TypePtr IndexedSlicesStrToType(const std::string &type_name) {
|
||||||
return std::make_shared<IndexedSlicesType>(element_type);
|
return std::make_shared<IndexedSlicesType>(element_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TypePtr SparseTensorStrToType(const std::string &type_name) {
|
||||||
|
if (type_name == "SparseTensor") {
|
||||||
|
return std::make_shared<SparseTensorType>();
|
||||||
|
}
|
||||||
|
auto start = type_name.find_first_of('[') + 1;
|
||||||
|
auto end = type_name.find_last_of(']');
|
||||||
|
if (start >= type_name.size()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto element_str = type_name.substr(start, end - start);
|
||||||
|
auto element_type = StringToType(element_str);
|
||||||
|
if (element_type == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return std::make_shared<SparseTensorType>(element_type);
|
||||||
|
}
|
||||||
|
|
||||||
TypePtr UndeterminedStrToType(const std::string &type_name) {
|
TypePtr UndeterminedStrToType(const std::string &type_name) {
|
||||||
if (type_name == "Undetermined") {
|
if (type_name == "Undetermined") {
|
||||||
return std::make_shared<UndeterminedType>();
|
return std::make_shared<UndeterminedType>();
|
||||||
|
@ -349,6 +366,8 @@ TypePtr StringToType(const std::string &type_name) {
|
||||||
type = UndeterminedStrToType(type_name);
|
type = UndeterminedStrToType(type_name);
|
||||||
} else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
|
} else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
|
||||||
type = IndexedSlicesStrToType(type_name);
|
type = IndexedSlicesStrToType(type_name);
|
||||||
|
} else if (type_name.compare(0, strlen("SparseTensor"), "SparseTensor") == 0) {
|
||||||
|
type = SparseTensorStrToType(type_name);
|
||||||
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
|
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
|
||||||
type = ListStrToType(type_name);
|
type = ListStrToType(type_name);
|
||||||
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
|
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
|
||||||
|
@ -428,6 +447,7 @@ const TypePtr kTypeEnv = std::make_shared<EnvType>();
|
||||||
const TypePtr kTypeType = std::make_shared<TypeType>();
|
const TypePtr kTypeType = std::make_shared<TypeType>();
|
||||||
const TypePtr kTensorType = std::make_shared<TensorType>();
|
const TypePtr kTensorType = std::make_shared<TensorType>();
|
||||||
const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
|
const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
|
||||||
|
const TypePtr kSparseTensorType = std::make_shared<SparseTensorType>();
|
||||||
const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
|
const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
|
||||||
const TypePtr kString = std::make_shared<String>();
|
const TypePtr kString = std::make_shared<String>();
|
||||||
const TypePtr kList = std::make_shared<List>();
|
const TypePtr kList = std::make_shared<List>();
|
||||||
|
|
|
@ -139,6 +139,8 @@ REGISTER_PYBIND_DEFINE(
|
||||||
}));
|
}));
|
||||||
(void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
|
(void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
|
||||||
.def(py::init());
|
.def(py::init());
|
||||||
|
(void)py::class_<SparseTensorType, Type, std::shared_ptr<SparseTensorType>>(m_sub, "SparseTensorType")
|
||||||
|
.def(py::init());
|
||||||
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
|
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
|
||||||
.def(py::init());
|
.def(py::init());
|
||||||
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
|
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
|
||||||
|
|
|
@ -17,9 +17,49 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "ir/meta_func_graph.h"
|
#include "ir/meta_func_graph.h"
|
||||||
|
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||||
|
#include "pipeline/jit/static_analysis/abstract_function.h"
|
||||||
|
#include "utils/context/ms_context.h"
|
||||||
|
#include "frontend/operator/ops.h"
|
||||||
|
|
||||||
// namespace to support intermediate representation definition
|
// namespace to support intermediate representation definition
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
FuncGraphPtr MetaFuncGraph::GenerateStubFunc(const TypePtrList &types) {
|
||||||
|
auto context = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
|
bool enable_sparse = context->enable_sparse();
|
||||||
|
if (!enable_sparse) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> parameters;
|
||||||
|
ParameterPtr undetermined_param = nullptr;
|
||||||
|
auto stub = std::make_shared<FuncGraph>();
|
||||||
|
for (size_t i = 0; i < types.size(); ++i) {
|
||||||
|
auto param = stub->add_parameter();
|
||||||
|
parameters.push_back(param);
|
||||||
|
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
|
||||||
|
undetermined_param = param;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (undetermined_param != nullptr) {
|
||||||
|
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||||
|
for (size_t i = 0; i < types.size(); ++i) {
|
||||||
|
if (types[i]->type_id() == kObjectTypeFunction) {
|
||||||
|
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
|
||||||
|
inputs.push_back(stub->NewCNode(call_prim));
|
||||||
|
} else {
|
||||||
|
inputs.push_back(parameters[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto stub_output = stub->NewCNode(inputs);
|
||||||
|
stub->set_output(stub_output);
|
||||||
|
stub->set_stub(true);
|
||||||
|
return stub;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
|
FuncGraphPtr MetaFuncGraph::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) {
|
||||||
TypePtrList types;
|
TypePtrList types;
|
||||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
|
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
|
||||||
|
|
|
@ -79,6 +79,7 @@ class MetaFuncGraph : public FuncGraphBase {
|
||||||
std::shared_ptr<Derived> shared_from_base() {
|
std::shared_ptr<Derived> shared_from_base() {
|
||||||
return std::static_pointer_cast<Derived>(shared_from_this());
|
return std::static_pointer_cast<Derived>(shared_from_this());
|
||||||
}
|
}
|
||||||
|
FuncGraphPtr GenerateStubFunc(const TypePtrList &types);
|
||||||
std::string name_;
|
std::string name_;
|
||||||
std::vector<Signature> signatures_;
|
std::vector<Signature> signatures_;
|
||||||
std::unordered_map<TypePtrList, FuncGraphPtr, TypeListHasher, TypeListEqual> cache_;
|
std::unordered_map<TypePtrList, FuncGraphPtr, TypeListHasher, TypeListEqual> cache_;
|
||||||
|
|
|
@ -40,18 +40,12 @@ class ParamValue {
|
||||||
const std::string &name() const { return name_; }
|
const std::string &name() const { return name_; }
|
||||||
void set_name(const std::string &name) { name_ = name; }
|
void set_name(const std::string &name) { name_ = name; }
|
||||||
|
|
||||||
const std::string &sparse_grad() const { return sparse_grad_; }
|
|
||||||
void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
|
|
||||||
|
|
||||||
bool requires_grad() const { return requires_grad_; }
|
bool requires_grad() const { return requires_grad_; }
|
||||||
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
|
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
|
||||||
|
|
||||||
bool layerwise_parallel() const { return layerwise_parallel_; }
|
bool layerwise_parallel() const { return layerwise_parallel_; }
|
||||||
void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }
|
void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }
|
||||||
|
|
||||||
bool has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
|
|
||||||
void set_has_indexed_slices_grad(bool b) { has_indexed_slices_grad_ = b; }
|
|
||||||
|
|
||||||
// Whether the parameter clone from other parameter.
|
// Whether the parameter clone from other parameter.
|
||||||
bool cloned() const { return cloned_; }
|
bool cloned() const { return cloned_; }
|
||||||
|
|
||||||
|
@ -81,10 +75,8 @@ class ParamValue {
|
||||||
private:
|
private:
|
||||||
tensor::MetaTensorPtr value_;
|
tensor::MetaTensorPtr value_;
|
||||||
std::string name_{"Parameter"};
|
std::string name_{"Parameter"};
|
||||||
std::string sparse_grad_;
|
|
||||||
bool requires_grad_{true};
|
bool requires_grad_{true};
|
||||||
bool layerwise_parallel_{false};
|
bool layerwise_parallel_{false};
|
||||||
bool has_indexed_slices_grad_{false};
|
|
||||||
bool be_cloned_{false};
|
bool be_cloned_{false};
|
||||||
bool cloned_{false};
|
bool cloned_{false};
|
||||||
std::vector<int32_t> be_cloned_index_;
|
std::vector<int32_t> be_cloned_index_;
|
||||||
|
|
|
@ -29,14 +29,10 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
|
||||||
.def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
|
.def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
|
||||||
.def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
|
.def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
|
||||||
&ParamValue::set_layerwise_parallel)
|
&ParamValue::set_layerwise_parallel)
|
||||||
.def_property("has_indexed_slices_grad", &ParamValue::has_indexed_slices_grad,
|
|
||||||
&ParamValue::set_has_indexed_slices_grad)
|
|
||||||
.def_property("sparse_grad", &ParamValue::sparse_grad, &ParamValue::set_sparse_grad)
|
|
||||||
.def(py::pickle(
|
.def(py::pickle(
|
||||||
[](const ParamValue &p) { // __getstate__
|
[](const ParamValue &p) { // __getstate__
|
||||||
return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
|
return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
|
||||||
p.layerwise_parallel(), p.has_indexed_slices_grad(),
|
p.layerwise_parallel());
|
||||||
p.sparse_grad());
|
|
||||||
},
|
},
|
||||||
[](const py::tuple &t) { // __setstate__
|
[](const py::tuple &t) { // __setstate__
|
||||||
if (t.size() != 6) {
|
if (t.size() != 6) {
|
||||||
|
@ -47,8 +43,6 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
|
||||||
p->set_name(t[1].cast<std::string>());
|
p->set_name(t[1].cast<std::string>());
|
||||||
p->set_requires_grad(t[2].cast<bool>());
|
p->set_requires_grad(t[2].cast<bool>());
|
||||||
p->set_layerwise_parallel(t[3].cast<bool>());
|
p->set_layerwise_parallel(t[3].cast<bool>());
|
||||||
p->set_has_indexed_slices_grad(t[4].cast<bool>());
|
|
||||||
p->set_sparse_grad(t[5].cast<std::string>());
|
|
||||||
return p;
|
return p;
|
||||||
}));
|
}));
|
||||||
}));
|
}));
|
||||||
|
|
|
@ -159,6 +159,10 @@ indexed_slices_get_values = Primitive('IndexedSlicesGetValues')
|
||||||
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
indexed_slices_get_indices = Primitive('IndexedSlicesGetIndices')
|
||||||
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
|
indexed_slices_get_dense_shape = Primitive('IndexedSlicesGetDenseShape')
|
||||||
|
|
||||||
|
make_sparse_tensor = Primitive('MakeSparseTensor')
|
||||||
|
sparse_tensor_get_values = Primitive('SparseTensorGetValues')
|
||||||
|
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
|
||||||
|
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
|
||||||
|
|
||||||
tensor_operator_registry.register('__add__', tensor_add)
|
tensor_operator_registry.register('__add__', tensor_add)
|
||||||
tensor_operator_registry.register('__sub__', tensor_sub)
|
tensor_operator_registry.register('__sub__', tensor_sub)
|
||||||
|
|
|
@ -616,5 +616,18 @@ TEST_F(TestOptLib, test_indexed_slices) {
|
||||||
ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
|
ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
|
||||||
ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
|
ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TestOptLib, test_sparse_tensor) {
|
||||||
|
FuncGraphPtr before_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_indices");
|
||||||
|
FuncGraphPtr after_get_indices = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_indices");
|
||||||
|
FuncGraphPtr before_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_values");
|
||||||
|
FuncGraphPtr after_get_values = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_values");
|
||||||
|
FuncGraphPtr before_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "before_get_dense_shape");
|
||||||
|
FuncGraphPtr after_get_dense_shape = getPyFun.CallAndParseRet("test_sparse_tensor", "after_get_dense_shape");
|
||||||
|
auto patterns = std::vector<SubstitutionPtr>({irpass.sparse_tensor_eliminate_});
|
||||||
|
ASSERT_TRUE(CheckOpt(before_get_indices, after_get_indices, patterns));
|
||||||
|
ASSERT_TRUE(CheckOpt(before_get_values, after_get_values, patterns));
|
||||||
|
ASSERT_TRUE(CheckOpt(before_get_dense_shape, after_get_dense_shape, patterns));
|
||||||
|
}
|
||||||
} // namespace opt
|
} // namespace opt
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1163,3 +1163,38 @@ def test_indexed_slices(tag):
|
||||||
return z
|
return z
|
||||||
|
|
||||||
return fns[tag]
|
return fns[tag]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparse_tensor(tag):
|
||||||
|
""" test_add_zero """
|
||||||
|
fns = FnDict()
|
||||||
|
make_sparse_tensor = Primitive('MakeSparseTensor')
|
||||||
|
sparse_tensor_get_values = Primitive('SparseTensorGetValues')
|
||||||
|
sparse_tensor_get_indices = Primitive('SparseTensorGetIndices')
|
||||||
|
sparse_tensor_get_dense_shape = Primitive('SparseTensorGetDenseShape')
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before_get_indices(x, y, z):
|
||||||
|
return sparse_tensor_get_indices(make_sparse_tensor(x, y, z))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def after_get_indices(x, y, z):
|
||||||
|
return x
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before_get_values(x, y, z):
|
||||||
|
return sparse_tensor_get_values(make_sparse_tensor(x, y, z))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def after_get_values(x, y, z):
|
||||||
|
return y
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before_get_dense_shape(x, y, z):
|
||||||
|
return sparse_tensor_get_dense_shape(make_sparse_tensor(x, y, z))
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def after_get_dense_shape(x, y, z):
|
||||||
|
return z
|
||||||
|
|
||||||
|
return fns[tag]
|
||||||
|
|
|
@ -35,6 +35,9 @@ from mindspore._checkparam import Validator as validator
|
||||||
from mindspore._checkparam import Rel
|
from mindspore._checkparam import Rel
|
||||||
from mindspore.nn import Optimizer
|
from mindspore.nn import Optimizer
|
||||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||||
|
from mindspore.nn.optim import Momentum
|
||||||
|
from mindspore.train import Model
|
||||||
|
from ....dataset_mock import MindData
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||||
|
|
||||||
|
@ -47,6 +50,40 @@ size_op = P.Size()
|
||||||
invert_permutation = P.InvertPermutation()
|
invert_permutation = P.InvertPermutation()
|
||||||
logical_and = P.LogicalAnd()
|
logical_and = P.LogicalAnd()
|
||||||
|
|
||||||
|
def get_axis(x):
|
||||||
|
shape = shape_op(x)
|
||||||
|
length = F.tuple_len(shape)
|
||||||
|
perm = F.make_range(0, length)
|
||||||
|
return perm
|
||||||
|
|
||||||
|
class MSELoss(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(MSELoss, self).__init__()
|
||||||
|
self.reduce_sum = P.ReduceSum()
|
||||||
|
self.square = P.Square()
|
||||||
|
self.reduce_mean = P.ReduceMean()
|
||||||
|
|
||||||
|
def construct(self, data, label):
|
||||||
|
diff = data - label
|
||||||
|
return self.reduce_mean(self.square(diff), get_axis(diff))
|
||||||
|
|
||||||
|
|
||||||
|
class MindDataSet(MindData):
|
||||||
|
def __init__(self, dataset_types, dataset_shapes):
|
||||||
|
super(MindDataSet, self).__init__(size=2, batch_size=32,
|
||||||
|
np_types=dataset_types,
|
||||||
|
output_shapes=dataset_shapes,
|
||||||
|
input_indexs=(0, 1))
|
||||||
|
def __next__(self):
|
||||||
|
if self._size < self._iter_num:
|
||||||
|
raise StopIteration
|
||||||
|
self._iter_num += 1
|
||||||
|
lst = []
|
||||||
|
for shape_, type_ in zip(self._output_shapes, self._np_types):
|
||||||
|
lst.append(Tensor(np.ones(shape_).astype(type_)))
|
||||||
|
return tuple(lst)
|
||||||
|
|
||||||
|
|
||||||
@constexpr
|
@constexpr
|
||||||
def _generate_shape_index(out_shape, indices_shape, axis):
|
def _generate_shape_index(out_shape, indices_shape, axis):
|
||||||
out_rank = len(out_shape)
|
out_rank = len(out_shape)
|
||||||
|
@ -189,8 +226,8 @@ def test_indexed_slices_make_indexed_slices():
|
||||||
def construct(self, indices, values):
|
def construct(self, indices, values):
|
||||||
ret = (IndexedSlices(indices, values, self.dense_shape),)
|
ret = (IndexedSlices(indices, values, self.dense_shape),)
|
||||||
return ret[0]
|
return ret[0]
|
||||||
indices = Tensor([[0, 0], [1, 2]])
|
indices = Tensor([1, 2])
|
||||||
values = Tensor([1, 2], dtype=ms.float32)
|
values = Tensor([[0, 0], [1, 2]], dtype=ms.float32)
|
||||||
MakeIndexedSlices()(indices, values)
|
MakeIndexedSlices()(indices, values)
|
||||||
|
|
||||||
|
|
||||||
|
@ -202,8 +239,8 @@ def test_indexed_slices_attr():
|
||||||
def construct(self, indices, values):
|
def construct(self, indices, values):
|
||||||
x = IndexedSlices(indices, values, self.dense_shape)
|
x = IndexedSlices(indices, values, self.dense_shape)
|
||||||
return x.values(), x.indices(), x.dense_shape()
|
return x.values(), x.indices(), x.dense_shape()
|
||||||
indices = Tensor([[0, 0], [1, 2]])
|
indices = Tensor([0])
|
||||||
values = Tensor([1, 2], dtype=ms.float32)
|
values = Tensor([[1, 2]], dtype=ms.float32)
|
||||||
IndexedSlicesGetAttr()(indices, values)
|
IndexedSlicesGetAttr()(indices, values)
|
||||||
|
|
||||||
|
|
||||||
|
@ -279,3 +316,29 @@ def test_indexed_slices_env_get():
|
||||||
net_with_loss = WithLossCell(net, loss)
|
net_with_loss = WithLossCell(net, loss)
|
||||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||||
train_network(inputs, label)
|
train_network(inputs, label)
|
||||||
|
|
||||||
|
|
||||||
|
def test_indexed_slices_model_train():
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, in_features, out_features):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
|
||||||
|
self.add = P.TensorAdd()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.flag = True
|
||||||
|
|
||||||
|
def construct(self, inputs, label):
|
||||||
|
x = self.add(inputs, self.weight)
|
||||||
|
if self.flag:
|
||||||
|
x = self.cast(x, mstype.float32)
|
||||||
|
return x
|
||||||
|
|
||||||
|
dataset_types = (np.float32, np.float32)
|
||||||
|
dataset_shapes = ((16, 16), (16, 16))
|
||||||
|
dataset = MindDataSet(dataset_types, dataset_shapes)
|
||||||
|
net = Net(16, 16)
|
||||||
|
net.set_train()
|
||||||
|
|
||||||
|
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
|
||||||
|
model = Model(net, optimizer=optimizer)
|
||||||
|
model.train(2, dataset, dataset_sink_mode=False)
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
@File : test_sparse_tensor.py
|
||||||
|
@Author:
|
||||||
|
@Date : 2020-07-16
|
||||||
|
@Desc : test mindspore sparse_tensor's operation
|
||||||
|
"""
|
||||||
|
import mindspore as ms
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
from mindspore import Tensor, SparseTensor, context
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||||
|
|
||||||
|
def test_sparse_tensor_make_sparse_tensor():
|
||||||
|
class MakeSparseTensor(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(MakeSparseTensor, self).__init__()
|
||||||
|
self.dense_shape = (3, 4)
|
||||||
|
def construct(self, indices, values):
|
||||||
|
ret = (SparseTensor(indices, values, self.dense_shape),)
|
||||||
|
return ret[0]
|
||||||
|
indices = Tensor([[0, 1], [1, 2]])
|
||||||
|
values = Tensor([1, 2], dtype=ms.float32)
|
||||||
|
MakeSparseTensor()(indices, values)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparse_tensor_attr():
|
||||||
|
grad_op = C.GradOperation('get_all', get_all=True)
|
||||||
|
class GradWrap(nn.Cell):
|
||||||
|
def __init__(self, network):
|
||||||
|
super(GradWrap, self).__init__()
|
||||||
|
self.network = network
|
||||||
|
def construct(self, input1, input2):
|
||||||
|
gout = grad_op(self.network)(input1, input2)
|
||||||
|
return gout
|
||||||
|
|
||||||
|
class SparseTensorGetAttr(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(SparseTensorGetAttr, self).__init__()
|
||||||
|
self.dense_shape = (3, 4)
|
||||||
|
def construct(self, indices, values):
|
||||||
|
x = SparseTensor(indices, values, self.dense_shape)
|
||||||
|
return x.values(), x.indices(), x.dense_shape()
|
||||||
|
|
||||||
|
indices = Tensor([[0, 1], [1, 2]])
|
||||||
|
values = Tensor([1, 2], dtype=ms.float32)
|
||||||
|
SparseTensorGetAttr()(indices, values)
|
Loading…
Reference in New Issue