forked from mindspore-Ecosystem/mindspore
Add bprop for sparse_tensor
This commit is contained in:
parent
abcee8e586
commit
4d4e23fd9e
|
@ -198,6 +198,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
|
||||||
* │ └── MapPy
|
* │ └── MapPy
|
||||||
* ├── Tail
|
* ├── Tail
|
||||||
* ├── MakeTupleGradient
|
* ├── MakeTupleGradient
|
||||||
|
* ├── MakeListGradient
|
||||||
* ├── GradOperation
|
* ├── GradOperation
|
||||||
* └── TupleAdd
|
* └── TupleAdd
|
||||||
*/
|
*/
|
||||||
|
@ -241,6 +242,8 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_
|
||||||
// do nothing
|
// do nothing
|
||||||
} else if (meta_func_graph->isa<prim::MakeTupleGradient>()) {
|
} else if (meta_func_graph->isa<prim::MakeTupleGradient>()) {
|
||||||
// do nothing
|
// do nothing
|
||||||
|
} else if (meta_func_graph->isa<prim::MakeListGradient>()) {
|
||||||
|
// do nothing
|
||||||
} else if (meta_func_graph->isa<prim::TupleAdd>()) {
|
} else if (meta_func_graph->isa<prim::TupleAdd>()) {
|
||||||
// do nothing
|
// do nothing
|
||||||
} else if (meta_func_graph->isa<prim::TupleSlice>()) {
|
} else if (meta_func_graph->isa<prim::TupleSlice>()) {
|
||||||
|
|
|
@ -490,6 +490,47 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
|
||||||
return fg;
|
return fg;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
|
int list_size = SizeToInt(args_spec_list.size());
|
||||||
|
|
||||||
|
std::ostringstream ss;
|
||||||
|
ss << "▶make_list_" << list_size;
|
||||||
|
FuncGraphPtr fg = std::make_shared<FuncGraph>();
|
||||||
|
fg->debug_info()->set_name(ss.str());
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> params;
|
||||||
|
params.push_back(NewValueNode(prim::kPrimMakeList));
|
||||||
|
for (int i = 0; i < list_size; ++i) {
|
||||||
|
params.push_back(fg->add_parameter());
|
||||||
|
}
|
||||||
|
|
||||||
|
// make fprob first result, maketuple's forward result.
|
||||||
|
AnfNodePtr out = fg->NewCNode(params);
|
||||||
|
|
||||||
|
// make fprob second result, maketuple's backward function.
|
||||||
|
FuncGraphPtr b = std::make_shared<FuncGraph>();
|
||||||
|
|
||||||
|
ss.clear();
|
||||||
|
ss << "◀make_list_" << list_size;
|
||||||
|
b->debug_info()->set_name(ss.str());
|
||||||
|
AnfNodePtr dout = b->add_parameter();
|
||||||
|
|
||||||
|
std::vector<AnfNodePtr> grads;
|
||||||
|
grads.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||||
|
grads.push_back(NewValueNode(newenv));
|
||||||
|
for (int i = 0; i < list_size; ++i) {
|
||||||
|
grads.push_back(b->NewCNode({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
|
||||||
|
}
|
||||||
|
|
||||||
|
b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||||
|
b->set_output(b->NewCNode(grads));
|
||||||
|
|
||||||
|
fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||||
|
fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
|
||||||
|
(void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
|
||||||
|
return fg;
|
||||||
|
}
|
||||||
|
|
||||||
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
|
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
|
||||||
: MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
|
: MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
|
||||||
if (get_by_list) {
|
if (get_by_list) {
|
||||||
|
|
|
@ -121,6 +121,16 @@ class MakeTupleGradient : public MetaFuncGraph {
|
||||||
};
|
};
|
||||||
using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
|
using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
|
||||||
|
|
||||||
|
class MakeListGradient : public MetaFuncGraph {
|
||||||
|
public:
|
||||||
|
explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {}
|
||||||
|
~MakeListGradient() override = default;
|
||||||
|
MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph)
|
||||||
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
|
friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
|
};
|
||||||
|
using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
|
||||||
|
|
||||||
class GradOperation : public MetaFuncGraph {
|
class GradOperation : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
|
explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
|
||||||
|
|
|
@ -463,6 +463,10 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
|
||||||
auto elem = GetValue<int>(e);
|
auto elem = GetValue<int>(e);
|
||||||
return elem;
|
return elem;
|
||||||
});
|
});
|
||||||
|
if (IntToSize(indices_shp[1]) != dense_shape_vec.size()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "The size of dense_shape must be equal with the second dimension of indices "
|
||||||
|
<< indices_shp[1] << ", but got " << dense_shape_vec.size();
|
||||||
|
}
|
||||||
for (auto dense_shape_elem : dense_shape_vec) {
|
for (auto dense_shape_elem : dense_shape_vec) {
|
||||||
if (dense_shape_elem < 0) {
|
if (dense_shape_elem < 0) {
|
||||||
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
|
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
|
||||||
|
|
|
@ -88,6 +88,12 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
|
||||||
return meta;
|
return meta;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
|
||||||
|
MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
|
||||||
|
bprop_registry_meta_[prim::kPrimMakeList] = meta;
|
||||||
|
return meta;
|
||||||
|
}
|
||||||
|
|
||||||
MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
|
MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -103,6 +109,8 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
||||||
return fprop;
|
return fprop;
|
||||||
} else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
} else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
} else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr bprop_fg = nullptr;
|
FuncGraphPtr bprop_fg = nullptr;
|
||||||
|
|
|
@ -59,6 +59,15 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
|
||||||
[](const AbstractAttribute &item) { return item.second; });
|
[](const AbstractAttribute &item) { return item.second; });
|
||||||
return std::make_shared<AbstractTuple>(baselist);
|
return std::make_shared<AbstractTuple>(baselist);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
|
||||||
|
if (t == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
if (t->isa<AbstractList>()) {
|
if (t->isa<AbstractList>()) {
|
||||||
auto abs_list = dyn_cast<AbstractList>(t);
|
auto abs_list = dyn_cast<AbstractList>(t);
|
||||||
return std::make_shared<AbstractTuple>(abs_list->elements());
|
return std::make_shared<AbstractTuple>(abs_list->elements());
|
||||||
|
@ -358,7 +367,41 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
|
||||||
new_node = EraseMakeKeywordArgNode(cnode);
|
new_node = EraseMakeKeywordArgNode(cnode);
|
||||||
} else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
|
} else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
|
||||||
new_node = EraseExtractKeywordArg(cnode);
|
new_node = EraseExtractKeywordArg(cnode);
|
||||||
} else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
|
}
|
||||||
|
|
||||||
|
if (new_node != nullptr) {
|
||||||
|
new_node->set_abstract(node->abstract());
|
||||||
|
MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
|
||||||
|
(void)manager->Replace(node, new_node);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &node : manager->all_nodes()) {
|
||||||
|
auto ret = Reabs(node->abstract());
|
||||||
|
if (ret) {
|
||||||
|
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
|
||||||
|
<< ret->ToString();
|
||||||
|
node->set_abstract(ret);
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
|
||||||
|
MS_EXCEPTION_IF_NULL(manager);
|
||||||
|
manager->AddFuncGraph(root);
|
||||||
|
|
||||||
|
bool changed = false;
|
||||||
|
|
||||||
|
// Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var
|
||||||
|
AnfNodeSet all_node = manager->all_nodes();
|
||||||
|
for (auto &node : all_node) {
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto cnode = node->cast<CNodePtr>();
|
||||||
|
AnfNodePtr new_node = nullptr;
|
||||||
|
if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
|
||||||
new_node = ConvertMakeListToMakeTuple(cnode);
|
new_node = ConvertMakeListToMakeTuple(cnode);
|
||||||
} else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
|
} else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
|
||||||
new_node = ConvertListGetItemToTupleGetItem(cnode);
|
new_node = ConvertListGetItemToTupleGetItem(cnode);
|
||||||
|
@ -377,7 +420,7 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &node : manager->all_nodes()) {
|
for (auto &node : manager->all_nodes()) {
|
||||||
auto ret = Reabs(node->abstract());
|
auto ret = AdaptAbs(node->abstract());
|
||||||
if (ret) {
|
if (ret) {
|
||||||
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
|
MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
|
||||||
<< ret->ToString();
|
<< ret->ToString();
|
||||||
|
|
|
@ -32,6 +32,7 @@ namespace opt {
|
||||||
|
|
||||||
// Remove the class type from graphs
|
// Remove the class type from graphs
|
||||||
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||||
|
bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
|
||||||
|
|
||||||
// Remove most uses of tuples from the graph
|
// Remove most uses of tuples from the graph
|
||||||
// tuples that are returned will be kept
|
// tuples that are returned will be kept
|
||||||
|
|
|
@ -69,6 +69,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool CleanListPass(const ResourcePtr &res) {
|
||||||
|
MS_EXCEPTION_IF_NULL(res->func_graph());
|
||||||
|
|
||||||
|
FuncGraphPtr func_graph = res->func_graph();
|
||||||
|
bool changed = opt::CleanList(func_graph, res->manager());
|
||||||
|
|
||||||
|
abstract::AbstractBasePtrList args_spec;
|
||||||
|
auto parameters = func_graph->parameters();
|
||||||
|
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
|
||||||
|
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
|
||||||
|
if (changed) {
|
||||||
|
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
|
||||||
|
res->set_func_graph(new_fg);
|
||||||
|
}
|
||||||
|
res->set_args_spec(args_spec);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
opt::OptPassConfig a_1 = opt::OptPassConfig({
|
opt::OptPassConfig a_1 = opt::OptPassConfig({
|
||||||
|
@ -100,6 +118,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
|
|
||||||
// Safe inlining
|
// Safe inlining
|
||||||
irpass.inline_,
|
irpass.inline_,
|
||||||
|
irpass.sparse_tensor_eliminate_,
|
||||||
});
|
});
|
||||||
opt::OptPassConfig a_2 = opt::OptPassConfig({
|
opt::OptPassConfig a_2 = opt::OptPassConfig({
|
||||||
irpass.merge_addn_,
|
irpass.merge_addn_,
|
||||||
|
@ -157,7 +176,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||||
irpass.make_ref_eliminate_,
|
irpass.make_ref_eliminate_,
|
||||||
irpass.get_ref_param_eliminate_,
|
irpass.get_ref_param_eliminate_,
|
||||||
irpass.indexed_slices_eliminate_,
|
irpass.indexed_slices_eliminate_,
|
||||||
irpass.sparse_tensor_eliminate_,
|
|
||||||
});
|
});
|
||||||
OptPassGroupMap map({
|
OptPassGroupMap map({
|
||||||
{"b_1", b_1},
|
{"b_1", b_1},
|
||||||
|
@ -322,19 +340,23 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup},
|
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||||
{"simplify_data_structures", SimplifyDataStructuresPass},
|
{"opt_a", OptPassAGroup},
|
||||||
|
{"clean_list", CleanListPass},
|
||||||
{"opt_b", OptPassBGroup},
|
{"opt_b", OptPassBGroup},
|
||||||
{"cconv", CconvPass},
|
{"cconv", CconvPass},
|
||||||
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
|
||||||
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
|
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
|
||||||
{"add_control_depend", AddControlDependPass}};
|
{"add_control_depend", AddControlDependPass}};
|
||||||
|
|
||||||
std::vector<PassItem> kGePasses = {
|
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
|
||||||
{"opt_a", OptPassAGroup}, {"simplify_data_structures", SimplifyDataStructuresPass},
|
{"opt_a", OptPassAGroup},
|
||||||
{"opt_b", OptPassBGroup}, {"add_control_depend", AddControlDependPass},
|
{"clean_list", CleanListPass},
|
||||||
{"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup},
|
{"opt_b", OptPassBGroup},
|
||||||
{"cconv", CconvPass}};
|
{"add_control_depend", AddControlDependPass},
|
||||||
|
{"opt_control", ControlGroup},
|
||||||
|
{"opt_prepare", PrepareGroup},
|
||||||
|
{"cconv", CconvPass}};
|
||||||
|
|
||||||
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
|
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
|
||||||
} // namespace pipeline
|
} // namespace pipeline
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "ir/func_graph_cloner.h"
|
#include "ir/func_graph_cloner.h"
|
||||||
#include "abstract/utils.h"
|
#include "abstract/utils.h"
|
||||||
#include "debug/trace.h"
|
#include "debug/trace.h"
|
||||||
|
#include "utils/context/ms_context.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace abstract {
|
namespace abstract {
|
||||||
|
@ -373,9 +374,16 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
|
||||||
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
|
// parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
|
||||||
AbstractBasePtrList bparams;
|
AbstractBasePtrList bparams;
|
||||||
bparams.push_back(SensitivityTransform(orig_func_));
|
bparams.push_back(SensitivityTransform(orig_func_));
|
||||||
(void)std::transform(
|
auto context = MsContext::GetInstance();
|
||||||
args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
|
MS_EXCEPTION_IF_NULL(context);
|
||||||
[](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
|
bool enable_sparse = context->enable_sparse();
|
||||||
|
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
|
||||||
|
[&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
|
||||||
|
if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
|
||||||
|
return std::make_shared<AbstractUndetermined>();
|
||||||
|
}
|
||||||
|
return SensitivityTransform(arg_spec);
|
||||||
|
});
|
||||||
AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
|
AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
|
||||||
AbstractFunctionPtr bprop =
|
AbstractFunctionPtr bprop =
|
||||||
std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
|
std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
|
||||||
|
|
|
@ -116,6 +116,11 @@ def bprop_tuple_getitem(data, idx, out, dout):
|
||||||
"""Backpropagator for primitive `tuple_getitem`."""
|
"""Backpropagator for primitive `tuple_getitem`."""
|
||||||
return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
|
return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
|
||||||
|
|
||||||
|
@bprops.register("list_getitem")
|
||||||
|
def bprop_list_getitem(data, idx, out, dout):
|
||||||
|
"""Backpropagator for primitive `list_getitem`."""
|
||||||
|
return F.list_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
|
||||||
|
|
||||||
|
|
||||||
@bprops.register("identity")
|
@bprops.register("identity")
|
||||||
def bprop_identity(x, out, dout):
|
def bprop_identity(x, out, dout):
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import mindspore as ms
|
||||||
from mindspore.ops import _selected_grad_ops as SG
|
from mindspore.ops import _selected_grad_ops as SG
|
||||||
from .. import functional as F
|
from .. import functional as F
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
|
@ -33,6 +34,7 @@ shape_op = P.Shape()
|
||||||
reduce_sum = P.ReduceSum()
|
reduce_sum = P.ReduceSum()
|
||||||
reshape = P.Reshape()
|
reshape = P.Reshape()
|
||||||
tile = P.Tile()
|
tile = P.Tile()
|
||||||
|
is_sub_class = P.IsSubClass()
|
||||||
|
|
||||||
|
|
||||||
def binop_grad_common(x, y, dx, dy):
|
def binop_grad_common(x, y, dx, dy):
|
||||||
|
@ -990,6 +992,12 @@ def get_bprop_scalar_addn(self):
|
||||||
"""Generate bprop for AddN"""
|
"""Generate bprop for AddN"""
|
||||||
|
|
||||||
def bprop(x, out, dout):
|
def bprop(x, out, dout):
|
||||||
|
if is_sub_class(F.typeof(x), ms.list_):
|
||||||
|
dx = []
|
||||||
|
for _ in range(len(x)):
|
||||||
|
dx.append(dout)
|
||||||
|
return (dx,)
|
||||||
|
|
||||||
dx = ()
|
dx = ()
|
||||||
for _ in range(len(x)):
|
for _ in range(len(x)):
|
||||||
dx = dx + (dout,)
|
dx = dx + (dout,)
|
||||||
|
|
|
@ -16,6 +16,7 @@ import numpy as np
|
||||||
|
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
|
import mindspore.ops.composite as C
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.ops import operations as P
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
@ -45,3 +46,17 @@ def test_net():
|
||||||
add = Net()
|
add = Net()
|
||||||
output = add(x, y)
|
output = add(x, y)
|
||||||
assert output == expect
|
assert output == expect
|
||||||
|
|
||||||
|
|
||||||
|
def test_grad_addn_with_list():
|
||||||
|
grad_op = C.GradOperation('get_all', get_all=True)
|
||||||
|
class AddN(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.add_n = P.AddN()
|
||||||
|
|
||||||
|
def construct(self, a, b):
|
||||||
|
return self.add_n([a, b])
|
||||||
|
|
||||||
|
inp = Tensor(np.ones([128, 96]).astype(np.float32))
|
||||||
|
grad_op(AddN())(inp, inp)
|
||||||
|
|
|
@ -252,7 +252,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
|
||||||
self.network = network
|
self.network = network
|
||||||
def construct(self, x, y):
|
def construct(self, x, y):
|
||||||
grad = grad_all(self.network)(x, y)
|
grad = grad_all(self.network)(x, y)
|
||||||
return grad, grad[0], grad[1]
|
return grad[0].indices(), grad[0].values(), grad[0].dense_shape()
|
||||||
class SparseGatherV2(nn.Cell):
|
class SparseGatherV2(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SparseGatherV2, self).__init__()
|
super(SparseGatherV2, self).__init__()
|
||||||
|
@ -276,7 +276,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
|
||||||
weights = self.weights
|
weights = self.weights
|
||||||
grad = grad_by_list(self.network, weights)(x)
|
grad = grad_by_list(self.network, weights)(x)
|
||||||
x = grad[0]
|
x = grad[0]
|
||||||
return x, x.values(), x.indices(), x.dense_shape()
|
return x.values(), x.indices(), x.dense_shape()
|
||||||
class SparseGatherV2(nn.Cell):
|
class SparseGatherV2(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(SparseGatherV2, self).__init__()
|
super(SparseGatherV2, self).__init__()
|
||||||
|
|
|
@ -18,6 +18,9 @@
|
||||||
@Date : 2020-07-16
|
@Date : 2020-07-16
|
||||||
@Desc : test mindspore sparse_tensor's operation
|
@Desc : test mindspore sparse_tensor's operation
|
||||||
"""
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import mindspore as ms
|
import mindspore as ms
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore.ops import composite as C
|
from mindspore.ops import composite as C
|
||||||
|
@ -25,17 +28,20 @@ from mindspore import Tensor, SparseTensor, context
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
|
||||||
|
|
||||||
|
|
||||||
|
class MakeSparseTensor(nn.Cell):
|
||||||
|
def __init__(self, dense_shape):
|
||||||
|
super(MakeSparseTensor, self).__init__()
|
||||||
|
self.dense_shape = dense_shape
|
||||||
|
def construct(self, indices, values):
|
||||||
|
ret = (SparseTensor(indices, values, self.dense_shape),)
|
||||||
|
return ret[0]
|
||||||
|
|
||||||
|
|
||||||
def test_sparse_tensor_make_sparse_tensor():
|
def test_sparse_tensor_make_sparse_tensor():
|
||||||
class MakeSparseTensor(nn.Cell):
|
|
||||||
def __init__(self):
|
|
||||||
super(MakeSparseTensor, self).__init__()
|
|
||||||
self.dense_shape = (3, 4)
|
|
||||||
def construct(self, indices, values):
|
|
||||||
ret = (SparseTensor(indices, values, self.dense_shape),)
|
|
||||||
return ret[0]
|
|
||||||
indices = Tensor([[0, 1], [1, 2]])
|
indices = Tensor([[0, 1], [1, 2]])
|
||||||
values = Tensor([1, 2], dtype=ms.float32)
|
values = Tensor([1, 2], dtype=ms.float32)
|
||||||
MakeSparseTensor()(indices, values)
|
MakeSparseTensor((3, 4))(indices, values)
|
||||||
|
|
||||||
|
|
||||||
def test_sparse_tensor_attr():
|
def test_sparse_tensor_attr():
|
||||||
|
@ -59,3 +65,20 @@ def test_sparse_tensor_attr():
|
||||||
indices = Tensor([[0, 1], [1, 2]])
|
indices = Tensor([[0, 1], [1, 2]])
|
||||||
values = Tensor([1, 2], dtype=ms.float32)
|
values = Tensor([1, 2], dtype=ms.float32)
|
||||||
SparseTensorGetAttr()(indices, values)
|
SparseTensorGetAttr()(indices, values)
|
||||||
|
grad_op(SparseTensorGetAttr())(indices, values)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparse_tensor_indices_dim_greater_than_dense_shape_dim():
|
||||||
|
indices = Tensor(np.array([[0, 0, 0], [0, 0, 1]], dtype=np.int32))
|
||||||
|
values = Tensor(np.array([100, 200], dtype=np.float32))
|
||||||
|
dense_shape = (2, 2)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
MakeSparseTensor(dense_shape)(indices, values)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparse_tensor_indices_dim_less_than_dense_shape_dim():
|
||||||
|
indices = Tensor(np.array([[0, 0], [0, 1]], dtype=np.int32))
|
||||||
|
values = Tensor(np.array([100, 200], dtype=np.float32))
|
||||||
|
dense_shape = (2, 2, 2)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
MakeSparseTensor(dense_shape)(indices, values)
|
||||||
|
|
Loading…
Reference in New Issue