forked from mindspore-Ecosystem/mindspore

Add bprop for sparse_tensor

parent abcee8e586
commit 4d4e23fd9e
@@ -198,6 +198,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGra
 * │   └── MapPy
 * ├── Tail
 * ├── MakeTupleGradient
 * ├── MakeListGradient
 * ├── GradOperation
 * └── TupleAdd
 */
@@ -241,6 +242,8 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_
    // do nothing
  } else if (meta_func_graph->isa<prim::MakeTupleGradient>()) {
    // do nothing
  } else if (meta_func_graph->isa<prim::MakeListGradient>()) {
    // do nothing
  } else if (meta_func_graph->isa<prim::TupleAdd>()) {
    // do nothing
  } else if (meta_func_graph->isa<prim::TupleSlice>()) {
@@ -490,6 +490,47 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
  return fg;
}

FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  int list_size = SizeToInt(args_spec_list.size());

  std::ostringstream ss;
  ss << "▶make_list_" << list_size;
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->debug_info()->set_name(ss.str());

  std::vector<AnfNodePtr> params;
  params.push_back(NewValueNode(prim::kPrimMakeList));
  for (int i = 0; i < list_size; ++i) {
    params.push_back(fg->add_parameter());
  }

  // Make fprop's first result: make_list's forward result.
  AnfNodePtr out = fg->NewCNode(params);

  // Make fprop's second result: make_list's backward function.
  FuncGraphPtr b = std::make_shared<FuncGraph>();

  ss.str("");  // reset the buffer for reuse; clear() alone only resets the error flags
  ss << "◀make_list_" << list_size;
  b->debug_info()->set_name(ss.str());
  AnfNodePtr dout = b->add_parameter();

  std::vector<AnfNodePtr> grads;
  grads.push_back(NewValueNode(prim::kPrimMakeTuple));
  grads.push_back(NewValueNode(newenv));
  for (int i = 0; i < list_size; ++i) {
    grads.push_back(b->NewCNode({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
  }

  b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  b->set_output(b->NewCNode(grads));

  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
  (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
  return fg;
}

GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
    : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
  if (get_by_list) {
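Note: read as a whole, ▶make_list_n returns the forward list together with a backward graph ◀make_list_n that unpacks dout element-wise. A minimal plain-Python sketch for n = 2, for orientation only (the empty tuple standing in for `newenv` is an assumption, not MindSpore source):

def make_list_fprop(x0, x1):
    # forward: kPrimMakeList
    out = [x0, x1]

    def bprop(dout):
        # backward: (newenv, list_getitem(dout, 0), list_getitem(dout, 1))
        env = ()  # hypothetical stand-in for the newenv sentinel
        return (env, dout[0], dout[1])

    return out, bprop

out, bprop = make_list_fprop(1.0, 2.0)
assert bprop([0.1, 0.2]) == ((), 0.1, 0.2)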
@@ -121,6 +121,16 @@ class MakeTupleGradient : public MetaFuncGraph {
};
using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;

class MakeListGradient : public MetaFuncGraph {
 public:
  explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {}
  ~MakeListGradient() override = default;
  MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; }
};
using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;

class GradOperation : public MetaFuncGraph {
 public:
  explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
@@ -463,6 +463,10 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
    auto elem = GetValue<int>(e);
    return elem;
  });
  if (IntToSize(indices_shp[1]) != dense_shape_vec.size()) {
    MS_EXCEPTION(TypeError) << "The size of dense_shape must be equal to the second dimension of indices "
                            << indices_shp[1] << ", but got " << dense_shape_vec.size();
  }
  for (auto dense_shape_elem : dense_shape_vec) {
    if (dense_shape_elem < 0) {
      MS_EXCEPTION(TypeError) << "Each element of dense_shape must be positive, but got "
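Note: the rule being enforced above is that for indices of shape (N, D), dense_shape must supply exactly D positive extents. A plain-Python restatement (hypothetical helper, not the MindSpore API; "positive" follows the error message):

import numpy as np

def check_sparse_args(indices, dense_shape):
    # indices: (N, D) coordinate array; dense_shape: tuple of D extents
    _, d = indices.shape
    if d != len(dense_shape):
        raise TypeError(f"dense_shape must have {d} entries, but got {len(dense_shape)}")
    if any(s <= 0 for s in dense_shape):
        raise TypeError(f"dense_shape entries must be positive, but got {dense_shape}")

check_sparse_args(np.array([[0, 1], [1, 2]]), (3, 4))  # passes: D == 2, extents positive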
@@ -88,6 +88,12 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
    return meta;
  }

  if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
    MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
    bprop_registry_meta_[prim::kPrimMakeList] = meta;
    return meta;
  }

  MS_LOG(EXCEPTION) << "Failed to find bprop function for " << prim->name() << ".";
}

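Note: KMetaFuncGraph builds the meta gradient once per primitive and memoizes it in bprop_registry_meta_. A minimal Python sketch of that lookup-or-create pattern (all names hypothetical):

_meta_cache = {}

def k_meta_func_graph(prim_name):
    # return the cached meta gradient, creating it on first request
    if prim_name in _meta_cache:
        return _meta_cache[prim_name]
    if prim_name == "make_list":
        meta = "make_list_gradient"  # stands in for MakeListGradient(...)
        _meta_cache[prim_name] = meta
        return meta
    raise RuntimeError(f"Failed to find bprop function for {prim_name}.")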
@@ -103,6 +109,8 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
    return fprop;
  } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
    return nullptr;
  } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
    return nullptr;
  }

  FuncGraphPtr bprop_fg = nullptr;
@@ -59,6 +59,15 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
                 [](const AbstractAttribute &item) { return item.second; });
    return std::make_shared<AbstractTuple>(baselist);
  }

  return nullptr;
}

static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
  if (t == nullptr) {
    return nullptr;
  }

  if (t->isa<AbstractList>()) {
    auto abs_list = dyn_cast<AbstractList>(t);
    return std::make_shared<AbstractTuple>(abs_list->elements());
@@ -358,7 +367,41 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
      new_node = EraseMakeKeywordArgNode(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
      new_node = EraseExtractKeywordArg(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
    }

    if (new_node != nullptr) {
      new_node->set_abstract(node->abstract());
      MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
      (void)manager->Replace(node, new_node);
      changed = true;
    }
  }

  for (auto &node : manager->all_nodes()) {
    auto ret = Reabs(node->abstract());
    if (ret) {
      MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
                    << ret->ToString();
      node->set_abstract(ret);
      changed = true;
    }
  }
  return changed;
}

bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(root);

  bool changed = false;

  // Since `manager->Replace(...)` modifies the member `all_nodes_`, `all_node` must be a copy, not a reference.
  AnfNodeSet all_node = manager->all_nodes();
  for (auto &node : all_node) {
    MS_EXCEPTION_IF_NULL(node);
    auto cnode = node->cast<CNodePtr>();
    AnfNodePtr new_node = nullptr;
    if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
      new_node = ConvertMakeListToMakeTuple(cnode);
    } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
      new_node = ConvertListGetItemToTupleGetItem(cnode);
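Note: the copy in CleanList matters because Replace mutates the manager's node set while it is being walked. The same hazard and fix, in a plain-Python sketch:

# Mutating a container during iteration is unsafe, so iterate over a snapshot
# instead (mirrors `AnfNodeSet all_node = manager->all_nodes();`).
nodes = {"make_list", "list_getitem", "add"}
for node in list(nodes):  # snapshot via list(...)
    if node == "make_list" or node.startswith("list"):
        nodes.discard(node)  # safe: we are not iterating `nodes` itself
        nodes.add(node.replace("list", "tuple"))
print(sorted(nodes))  # ['add', 'make_tuple', 'tuple_getitem']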
@@ -377,7 +420,7 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
  }

  for (auto &node : manager->all_nodes()) {
    auto ret = Reabs(node->abstract());
    auto ret = AdaptAbs(node->abstract());
    if (ret) {
      MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
                    << ret->ToString();
@@ -32,6 +32,7 @@ namespace opt {

// Remove the class type from graphs
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);

// Remove most uses of tuples from the graph
// tuples that are returned will be kept
@@ -69,6 +69,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
  return true;
}

bool CleanListPass(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res->func_graph());

  FuncGraphPtr func_graph = res->func_graph();
  bool changed = opt::CleanList(func_graph, res->manager());

  abstract::AbstractBasePtrList args_spec;
  auto parameters = func_graph->parameters();
  (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
                       [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
  if (changed) {
    FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
    res->set_func_graph(new_fg);
  }
  res->set_args_spec(args_spec);
  return true;
}

namespace {
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig a_1 = opt::OptPassConfig({
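Note: CleanListPass follows the usual pass contract, run the rewrite and renormalize only when the graph actually changed. A plain-Python outline of that contract (all names hypothetical, not the C++ API):

def clean_list_pass(resource, clean_list, renormalize):
    # run the rewrite, then re-infer only if something changed
    graph = resource["func_graph"]
    changed = clean_list(graph)                 # rewrite lists into tuples
    args_spec = list(graph.get("params", []))   # collect parameter abstracts
    if changed:
        resource["func_graph"] = renormalize(graph, args_spec)
    resource["args_spec"] = args_spec
    return True

res = {"func_graph": {"params": ["x"]}}
print(clean_list_pass(res, lambda g: True, lambda g, a: g))  # True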
@@ -100,6 +118,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {

    // Safe inlining
    irpass.inline_,
    irpass.sparse_tensor_eliminate_,
  });
  opt::OptPassConfig a_2 = opt::OptPassConfig({
    irpass.merge_addn_,
@@ -157,7 +176,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
    irpass.make_ref_eliminate_,
    irpass.get_ref_param_eliminate_,
    irpass.indexed_slices_eliminate_,
    irpass.sparse_tensor_eliminate_,
  });
  OptPassGroupMap map({
    {"b_1", b_1},
@@ -322,19 +340,23 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
  return true;
}

std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup},
                                   {"simplify_data_structures", SimplifyDataStructuresPass},
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                   {"opt_a", OptPassAGroup},
                                   {"clean_list", CleanListPass},
                                   {"opt_b", OptPassBGroup},
                                   {"cconv", CconvPass},
                                   {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
                                   {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
                                   {"add_control_depend", AddControlDependPass}};

std::vector<PassItem> kGePasses = {
    {"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass},
    {"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass},
    {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup},
    {"cconv", CconvPass}};
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                   {"opt_a", OptPassAGroup},
                                   {"clean_list", CleanListPass},
                                   {"opt_b", OptPassBGroup},
                                   {"add_control_depend", AddControlDependPass},
                                   {"opt_control", ControlGroup},
                                   {"opt_prepare", PrepareGroup},
                                   {"cconv", CconvPass}};

std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
}  // namespace pipeline
@@ -22,6 +22,7 @@
#include "ir/func_graph_cloner.h"
#include "abstract/utils.h"
#include "debug/trace.h"
#include "utils/context/ms_context.h"

namespace mindspore {
namespace abstract {
@@ -373,9 +374,16 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
  // parameters. (sense_f, sense_x, ...)(*bprop_f) (sense_y)
  AbstractBasePtrList bparams;
  bparams.push_back(SensitivityTransform(orig_func_));
  (void)std::transform(
      args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
      [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  bool enable_sparse = context->enable_sparse();
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
                       [&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
                         if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
                           return std::make_shared<AbstractUndetermined>();
                         }
                         return SensitivityTransform(arg_spec);
                       });
  AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
  AbstractFunctionPtr bprop =
      std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
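Note: the new transform widens tensor sensitivities to "undetermined" when sparse mode is on, so a dense-tensor slot can receive a sparse gradient. A plain-Python sketch of that conditional mapping (names and string tags hypothetical):

def widen_sensitivities(arg_specs, enable_sparse):
    # with sparse enabled, leave a Tensor slot's sensitivity type open
    # instead of fixing it to Tensor
    def transform(spec):
        if enable_sparse and spec == "Tensor":
            return "Undetermined"
        return spec  # identity stand-in for SensitivityTransform

    return [transform(s) for s in arg_specs]

print(widen_sensitivities(["Tensor", "int"], enable_sparse=True))
# ['Undetermined', 'int']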
@@ -116,6 +116,11 @@ def bprop_tuple_getitem(data, idx, out, dout):
    """Backpropagator for primitive `tuple_getitem`."""
    return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)


@bprops.register("list_getitem")
def bprop_list_getitem(data, idx, out, dout):
    """Backpropagator for primitive `list_getitem`."""
    return F.list_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)


@bprops.register("identity")
def bprop_identity(x, out, dout):
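Note: the list_getitem backpropagator scatters dout back into a zeroed copy of the list. In plain Python, mirroring the zeros_like/list_setitem semantics rather than calling the MindSpore ops:

def bprop_list_getitem_sketch(data, idx, dout):
    # gradient w.r.t. the list: zeros everywhere except dout at position idx;
    # gradient w.r.t. the index: zero (indices are not differentiable)
    d_data = [0 * x for x in data]  # zeros_like each element
    d_data[idx] = dout              # list_setitem
    return d_data, 0

print(bprop_list_getitem_sketch([3.0, 4.0, 5.0], 1, 1.0))
# ([0.0, 1.0, 0.0], 0)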
@@ -17,6 +17,7 @@

from functools import reduce
import numpy as np
import mindspore as ms
from mindspore.ops import _selected_grad_ops as SG
from .. import functional as F
from .. import operations as P
@@ -33,6 +34,7 @@ shape_op = P.Shape()
reduce_sum = P.ReduceSum()
reshape = P.Reshape()
tile = P.Tile()
is_sub_class = P.IsSubClass()


def binop_grad_common(x, y, dx, dy):
@@ -990,6 +992,12 @@ def get_bprop_scalar_addn(self):
    """Generate bprop for AddN"""

    def bprop(x, out, dout):
        if is_sub_class(F.typeof(x), ms.list_):
            dx = []
            for _ in range(len(x)):
                dx.append(dout)
            return (dx,)

        dx = ()
        for _ in range(len(x)):
            dx = dx + (dout,)
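Note: AddN's gradient fans dout out to every addend, and the new branch keeps the container type (list in, list out). A plain-Python equivalent of the logic above:

def addn_bprop_sketch(x, dout):
    # d(sum(x_i))/dx_i == 1, so every input receives dout unchanged;
    # the gradient container mirrors the input container (list vs tuple)
    if isinstance(x, list):
        return ([dout] * len(x),)
    return (tuple(dout for _ in x),)

print(addn_bprop_sketch([1.0, 2.0], 0.5))  # ([0.5, 0.5],)
print(addn_bprop_sketch((1.0, 2.0), 0.5))  # ((0.5, 0.5),)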
@@ -16,6 +16,7 @@ import numpy as np

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.composite as C
from mindspore import Tensor
from mindspore.ops import operations as P

@@ -45,3 +46,17 @@ def test_net():
    add = Net()
    output = add(x, y)
    assert output == expect


def test_grad_addn_with_list():
    grad_op = C.GradOperation('get_all', get_all=True)

    class AddN(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add_n = P.AddN()

        def construct(self, a, b):
            return self.add_n([a, b])

    inp = Tensor(np.ones([128, 96]).astype(np.float32))
    grad_op(AddN())(inp, inp)
@@ -252,7 +252,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
            self.network = network
        def construct(self, x, y):
            grad = grad_all(self.network)(x, y)
            return grad, grad[0], grad[1]
            return grad[0].indices(), grad[0].values(), grad[0].dense_shape()
    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
@@ -276,7 +276,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
            weights = self.weights
            grad = grad_by_list(self.network, weights)(x)
            x = grad[0]
            return x, x.values(), x.indices(), x.dense_shape()
            return x.values(), x.indices(), x.dense_shape()
    class SparseGatherV2(nn.Cell):
        def __init__(self):
            super(SparseGatherV2, self).__init__()
@@ -18,6 +18,9 @@
@Date : 2020-07-16
@Desc : test mindspore sparse_tensor's operation
"""
import numpy as np
import pytest

import mindspore as ms
import mindspore.nn as nn
from mindspore.ops import composite as C
@@ -25,17 +28,20 @@ from mindspore import Tensor, SparseTensor, context

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)


class MakeSparseTensor(nn.Cell):
    def __init__(self, dense_shape):
        super(MakeSparseTensor, self).__init__()
        self.dense_shape = dense_shape
    def construct(self, indices, values):
        ret = (SparseTensor(indices, values, self.dense_shape),)
        return ret[0]


def test_sparse_tensor_make_sparse_tensor():
    class MakeSparseTensor(nn.Cell):
        def __init__(self):
            super(MakeSparseTensor, self).__init__()
            self.dense_shape = (3, 4)
        def construct(self, indices, values):
            ret = (SparseTensor(indices, values, self.dense_shape),)
            return ret[0]
    indices = Tensor([[0, 1], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    MakeSparseTensor()(indices, values)
    MakeSparseTensor((3, 4))(indices, values)


def test_sparse_tensor_attr():
@@ -59,3 +65,20 @@ def test_sparse_tensor_attr():
    indices = Tensor([[0, 1], [1, 2]])
    values = Tensor([1, 2], dtype=ms.float32)
    SparseTensorGetAttr()(indices, values)
    grad_op(SparseTensorGetAttr())(indices, values)


def test_sparse_tensor_indices_dim_greater_than_dense_shape_dim():
    indices = Tensor(np.array([[0, 0, 0], [0, 0, 1]], dtype=np.int32))
    values = Tensor(np.array([100, 200], dtype=np.float32))
    dense_shape = (2, 2)
    with pytest.raises(TypeError):
        MakeSparseTensor(dense_shape)(indices, values)


def test_sparse_tensor_indices_dim_less_than_dense_shape_dim():
    indices = Tensor(np.array([[0, 0], [0, 1]], dtype=np.int32))
    values = Tensor(np.array([100, 200], dtype=np.float32))
    dense_shape = (2, 2, 2)
    with pytest.raises(TypeError):
        MakeSparseTensor(dense_shape)(indices, values)
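Note: for reference, a SparseTensor's indices of shape (N, D) select N cells of a D-dimensional dense array with extent dense_shape, each holding the matching value, which is why D must equal len(dense_shape) in the tests above. A numpy sketch of that denotation (illustrative, not the MindSpore API):

import numpy as np

def to_dense(indices, values, dense_shape):
    # scatter `values` into a zero array at the coordinates in `indices`
    dense = np.zeros(dense_shape, dtype=values.dtype)
    for coord, v in zip(indices, values):
        dense[tuple(coord)] = v
    return dense

print(to_dense(np.array([[0, 1], [1, 2]]), np.array([1.0, 2.0]), (3, 4)))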