diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc
index 273a6f64581..580f243ebe0 100644
--- a/mindspore/ccsrc/debug/anf_ir_utils.cc
+++ b/mindspore/ccsrc/debug/anf_ir_utils.cc
@@ -198,6 +198,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGra
  * │   └── MapPy
  * ├── Tail
  * ├── MakeTupleGradient
+ * ├── MakeListGradient
  * ├── GradOperation
  * └── TupleAdd
  */
@@ -241,6 +242,8 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_
     // do nothing
   } else if (meta_func_graph->isa<prim::MakeTupleGradient>()) {
     // do nothing
+  } else if (meta_func_graph->isa<prim::MakeListGradient>()) {
+    // do nothing
   } else if (meta_func_graph->isa<prim::GradOperation>()) {
     // do nothing
   } else if (meta_func_graph->isa<prim::TupleAdd>()) {
diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc
index 0586572dd1f..5fcbe258ba9 100644
--- a/mindspore/ccsrc/frontend/operator/composite/composite.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc
@@ -490,6 +490,47 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
   return fg;
 }
 
+FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
+  int list_size = SizeToInt(args_spec_list.size());
+
+  std::ostringstream ss;
+  ss << "▶make_list_" << list_size;
+  FuncGraphPtr fg = std::make_shared<FuncGraph>();
+  fg->debug_info()->set_name(ss.str());
+
+  std::vector<AnfNodePtr> params;
+  params.push_back(NewValueNode(prim::kPrimMakeList));
+  for (int i = 0; i < list_size; ++i) {
+    params.push_back(fg->add_parameter());
+  }
+
+  // make fprop's first result, make_list's forward result.
+  AnfNodePtr out = fg->NewCNode(params);
+
+  // make fprop's second result, make_list's backward function.
+  FuncGraphPtr b = std::make_shared<FuncGraph>();
+
+  ss.clear();
+  ss << "◀make_list_" << list_size;
+  b->debug_info()->set_name(ss.str());
+  AnfNodePtr dout = b->add_parameter();
+
+  std::vector<AnfNodePtr> grads;
+  grads.push_back(NewValueNode(prim::kPrimMakeTuple));
+  grads.push_back(NewValueNode(newenv));
+  for (int i = 0; i < list_size; ++i) {
+    grads.push_back(b->NewCNode({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
+  }
+
+  b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  b->set_output(b->NewCNode(grads));
+
+  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
+  (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
+  return fg;
+}
+
 GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param)
     : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) {
   if (get_by_list) {
diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.h b/mindspore/ccsrc/frontend/operator/composite/composite.h
index d0597e9befc..21a4588958f 100644
--- a/mindspore/ccsrc/frontend/operator/composite/composite.h
+++ b/mindspore/ccsrc/frontend/operator/composite/composite.h
@@ -121,6 +121,16 @@ class MakeTupleGradient : public MetaFuncGraph {
 };
 using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
 
+class MakeListGradient : public MetaFuncGraph {
+ public:
+  explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {}
+  ~MakeListGradient() override = default;
+  MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph)
+  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
+  friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; }
+};
+using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
+
 class GradOperation : public MetaFuncGraph {
  public:
   explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
diff --git a/mindspore/ccsrc/frontend/operator/prim_others.cc b/mindspore/ccsrc/frontend/operator/prim_others.cc
index 25f41860f68..f33462b571d 100644
--- a/mindspore/ccsrc/frontend/operator/prim_others.cc
+++ b/mindspore/ccsrc/frontend/operator/prim_others.cc
@@ -463,6 +463,10 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi
     auto elem = GetValue<int>(e);
     return elem;
   });
+  if (IntToSize(indices_shp[1]) != dense_shape_vec.size()) {
+    MS_EXCEPTION(TypeError) << "The size of dense_shape must be equal to the second dimension of indices "
+                            << indices_shp[1] << ", but got " << dense_shape_vec.size();
+  }
   for (auto dense_shape_elem : dense_shape_vec) {
     if (dense_shape_elem < 0) {
       MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
index aa76d279d53..105840874a5 100644
--- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
+++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
@@ -88,6 +88,12 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
     return meta;
   }
 
+  if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
+    MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
+    bprop_registry_meta_[prim::kPrimMakeList] = meta;
+    return meta;
+  }
+
   MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
 }
 
@@ -103,6 +109,8 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
     return fprop;
   } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
     return nullptr;
+  } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
+    return nullptr;
   }
 
   FuncGraphPtr bprop_fg = nullptr;
diff --git a/mindspore/ccsrc/frontend/optimizer/clean.cc b/mindspore/ccsrc/frontend/optimizer/clean.cc
index e35760ceaf3..2597ac67c20 100644
--- a/mindspore/ccsrc/frontend/optimizer/clean.cc
+++ b/mindspore/ccsrc/frontend/optimizer/clean.cc
@@ -59,6 +59,15 @@ static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
                    [](const AbstractAttribute &item) { return item.second; });
     return std::make_shared<AbstractTuple>(baselist);
   }
+
+  return nullptr;
+}
+
+static AbstractBasePtr AdaptAbs(const AbstractBasePtr &t) {
+  if (t == nullptr) {
+    return nullptr;
+  }
+
   if (t->isa<AbstractList>()) {
     auto abs_list = dyn_cast<AbstractList>(t);
     return std::make_shared<AbstractTuple>(abs_list->elements());
@@ -358,7 +367,41 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
       new_node = EraseMakeKeywordArgNode(cnode);
     } else if (IsPrimitiveCNode(node, prim::kPrimExtractKeywordArg)) {
       new_node = EraseExtractKeywordArg(cnode);
-    } else if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
+    }
+
+    if (new_node != nullptr) {
+      new_node->set_abstract(node->abstract());
+      MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString();
+      (void)manager->Replace(node, new_node);
+      changed = true;
+    }
+  }
+
+  for (auto &node : manager->all_nodes()) {
+    auto ret = Reabs(node->abstract());
+    if (ret) {
+      MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
+                    << ret->ToString();
+      node->set_abstract(ret);
+      changed = true;
+    }
+  }
+  return changed;
+}
+
+bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) {
+  MS_EXCEPTION_IF_NULL(manager);
+  manager->AddFuncGraph(root);
+
+  bool changed = false;
+
+  // Since `manager->Replace(...)` modifies member `all_nodes_`, `all_node` can't be a ref var
+  AnfNodeSet all_node = manager->all_nodes();
+  for (auto &node : all_node) {
+    MS_EXCEPTION_IF_NULL(node);
+    auto cnode = node->cast<CNodePtr>();
+    AnfNodePtr new_node = nullptr;
+    if (IsPrimitiveCNode(node, prim::kPrimMakeList)) {
       new_node = ConvertMakeListToMakeTuple(cnode);
     } else if (IsPrimitiveCNode(node, prim::kPrimListGetItem)) {
       new_node = ConvertListGetItemToTupleGetItem(cnode);
@@ -377,7 +420,7 @@ bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr
   }
 
   for (auto &node : manager->all_nodes()) {
-    auto ret = Reabs(node->abstract());
+    auto ret = AdaptAbs(node->abstract());
     if (ret) {
       MS_LOG(DEBUG) << "Replace " << node->DebugString() << "'s abstract " << node->abstract()->ToString() << " with "
                    << ret->ToString();
diff --git a/mindspore/ccsrc/frontend/optimizer/clean.h b/mindspore/ccsrc/frontend/optimizer/clean.h
index 32736a89515..3f44511d817 100644
--- a/mindspore/ccsrc/frontend/optimizer/clean.h
+++ b/mindspore/ccsrc/frontend/optimizer/clean.h
@@ -32,6 +32,7 @@ namespace opt {
 
 // Remove the class type from graphs
 bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
+bool CleanList(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager);
 
 // Remove most uses of tuples from the graph
 // tuples that are returned will be kept
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index f3a03658a2e..27108380d90 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -69,6 +69,24 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) {
   return true;
 }
 
+bool CleanListPass(const ResourcePtr &res) {
+  MS_EXCEPTION_IF_NULL(res->func_graph());
+
+  FuncGraphPtr func_graph = res->func_graph();
+  bool changed = opt::CleanList(func_graph, res->manager());
+
+  abstract::AbstractBasePtrList args_spec;
+  auto parameters = func_graph->parameters();
+  (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
+                       [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
+  if (changed) {
+    FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
+    res->set_func_graph(new_fg);
+  }
+  res->set_args_spec(args_spec);
+  return true;
+}
+
 namespace {
 OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig a_1 = opt::OptPassConfig({
@@ -100,6 +118,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
 
     // Safe inlining
     irpass.inline_,
+    irpass.sparse_tensor_eliminate_,
   });
   opt::OptPassConfig a_2 = opt::OptPassConfig({
     irpass.merge_addn_,
@@ -157,7 +176,6 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.make_ref_eliminate_,
     irpass.get_ref_param_eliminate_,
     irpass.indexed_slices_eliminate_,
-    irpass.sparse_tensor_eliminate_,
   });
   OptPassGroupMap map({
     {"b_1", b_1},
@@ -322,19 +340,23 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
   return true;
 }
 
-std::vector<PassItem> kVmPasses = {{"opt_a", OptPassAGroup},
-                                   {"simplify_data_structures", SimplifyDataStructuresPass},
+std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
+                                   {"opt_a", OptPassAGroup},
+                                   {"clean_list", CleanListPass},
                                    {"opt_b", OptPassBGroup},
                                    {"cconv", CconvPass},
                                    {"opt_graph_kernel_a", OptPassGraphKernelGroupA},
                                    {"opt_graph_kernel_b", OptPassGraphKernelGroupB},
                                    {"add_control_depend", AddControlDependPass}};
 
-std::vector<PassItem> kGePasses = {
-    {"opt_a", OptPassAGroup},      {"simplify_data_structures", SimplifyDataStructuresPass},
-    {"opt_b", OptPassBGroup},      {"add_control_depend", AddControlDependPass},
-    {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup},
-    {"cconv", CconvPass}};
+std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
+                                   {"opt_a", OptPassAGroup},
+                                   {"clean_list", CleanListPass},
+                                   {"opt_b", OptPassBGroup},
+                                   {"add_control_depend", AddControlDependPass},
+                                   {"opt_control", ControlGroup},
+                                   {"opt_prepare", PrepareGroup},
+                                   {"cconv", CconvPass}};
 
 std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
 }  // namespace pipeline
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
index 3e820eed3a6..eb50974e548 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
@@ -22,6 +22,7 @@
 #include "ir/func_graph_cloner.h"
 #include "abstract/utils.h"
 #include "debug/trace.h"
+#include "utils/context/ms_context.h"
 
 namespace mindspore {
 namespace abstract {
@@ -373,9 +374,16 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
   // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
   AbstractBasePtrList bparams;
   bparams.push_back(SensitivityTransform(orig_func_));
-  (void)std::transform(
-    args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
-    [](const AbstractBasePtr &arg_spec) -> AbstractBasePtr { return SensitivityTransform(arg_spec); });
+  auto context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context);
+  bool enable_sparse = context->enable_sparse();
+  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(bparams),
+                       [&enable_sparse](const AbstractBasePtr &arg_spec) -> AbstractBasePtr {
+                         if (enable_sparse && arg_spec->isa<AbstractTensor>()) {
+                           return std::make_shared<AbstractUndetermined>();
+                         }
+                         return SensitivityTransform(arg_spec);
+                       });
   AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
   AbstractFunctionPtr bprop =
     std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
diff --git a/mindspore/ops/_grad/grad_implementations.py b/mindspore/ops/_grad/grad_implementations.py
index 87566b11104..5dfe6e60068 100644
--- a/mindspore/ops/_grad/grad_implementations.py
+++ b/mindspore/ops/_grad/grad_implementations.py
@@ -116,6 +116,11 @@ def bprop_tuple_getitem(data, idx, out, dout):
     """Backpropagator for primitive `tuple_getitem`."""
     return F.tuple_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
 
+@bprops.register("list_getitem")
+def bprop_list_getitem(data, idx, out, dout):
+    """Backpropagator for primitive `list_getitem`."""
+    return F.list_setitem(C.zeros_like(data), idx, dout), C.zeros_like(idx)
+
 
 @bprops.register("identity")
 def bprop_identity(x, out, dout):
diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py
index bae9034c620..7dd09685cc2 100755
--- a/mindspore/ops/_grad/grad_math_ops.py
+++ b/mindspore/ops/_grad/grad_math_ops.py
@@ -17,6 +17,7 @@
 
 from functools import reduce
 import numpy as np
+import mindspore as ms
 from mindspore.ops import _selected_grad_ops as SG
 from .. import functional as F
 from .. import operations as P
@@ -33,6 +34,7 @@ shape_op = P.Shape()
 reduce_sum = P.ReduceSum()
 reshape = P.Reshape()
 tile = P.Tile()
+is_sub_class = P.IsSubClass()
 
 
 def binop_grad_common(x, y, dx, dy):
@@ -990,6 +992,12 @@ def get_bprop_scalar_addn(self):
     """Generate bprop for AddN"""
 
     def bprop(x, out, dout):
+        if is_sub_class(F.typeof(x), ms.list_):
+            dx = []
+            for _ in range(len(x)):
+                dx.append(dout)
+            return (dx,)
+
         dx = ()
         for _ in range(len(x)):
             dx = dx + (dout,)
diff --git a/tests/st/ops/ascend/test_addn.py b/tests/st/ops/ascend/test_addn.py
index 6d0d5b5be00..fa97fcc973b 100644
--- a/tests/st/ops/ascend/test_addn.py
+++ b/tests/st/ops/ascend/test_addn.py
@@ -16,6 +16,7 @@
 import numpy as np
 
 import mindspore.context as context
 import mindspore.nn as nn
+import mindspore.ops.composite as C
 from mindspore import Tensor
 from mindspore.ops import operations as P
@@ -45,3 +46,17 @@ def test_net():
     add = Net()
     output = add(x, y)
     assert output == expect
+
+
+def test_grad_addn_with_list():
+    grad_op = C.GradOperation('get_all', get_all=True)
+    class AddN(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.add_n = P.AddN()
+
+        def construct(self, a, b):
+            return self.add_n([a, b])
+
+    inp = Tensor(np.ones([128, 96]).astype(np.float32))
+    grad_op(AddN())(inp, inp)
diff --git a/tests/ut/python/ir/test_indexed_slices.py b/tests/ut/python/ir/test_indexed_slices.py
index ff0cfa1da5f..a9ed2fd95c6 100644
--- a/tests/ut/python/ir/test_indexed_slices.py
+++ b/tests/ut/python/ir/test_indexed_slices.py
@@ -252,7 +252,7 @@ def test_indexed_slices_sparse_gatherv2_grad_all():
             self.network = network
         def construct(self, x, y):
             grad = grad_all(self.network)(x, y)
-            return grad, grad[0], grad[1]
+            return grad[0].indices(), grad[0].values(), grad[0].dense_shape()
     class SparseGatherV2(nn.Cell):
         def __init__(self):
             super(SparseGatherV2, self).__init__()
@@ -276,7 +276,7 @@ def test_indexed_slices_sparse_gatherv2_grad_with_pram():
             weights = self.weights
             grad = grad_by_list(self.network, weights)(x)
             x = grad[0]
-            return x, x.values(), x.indices(), x.dense_shape()
+            return x.values(), x.indices(), x.dense_shape()
     class SparseGatherV2(nn.Cell):
         def __init__(self):
             super(SparseGatherV2, self).__init__()
diff --git a/tests/ut/python/ir/test_sparse_tensor.py b/tests/ut/python/ir/test_sparse_tensor.py
index 3f8ca8b184c..a3096268ca6 100644
--- a/tests/ut/python/ir/test_sparse_tensor.py
+++ b/tests/ut/python/ir/test_sparse_tensor.py
@@ -18,6 +18,9 @@
 @Date : 2020-07-16
 @Desc : test mindspore sparse_tensor's operation
 """
+import numpy as np
+import pytest
+
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore.ops import composite as C
@@ -25,17 +28,20 @@ from mindspore import Tensor, SparseTensor, context
 
 context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)
 
+
+class MakeSparseTensor(nn.Cell):
+    def __init__(self, dense_shape):
+        super(MakeSparseTensor, self).__init__()
+        self.dense_shape = dense_shape
+    def construct(self, indices, values):
+        ret = (SparseTensor(indices, values, self.dense_shape),)
+        return ret[0]
+
+
 def test_sparse_tensor_make_sparse_tensor():
-    class MakeSparseTensor(nn.Cell):
-        def __init__(self):
-            super(MakeSparseTensor, self).__init__()
-            self.dense_shape = (3, 4)
-        def construct(self, indices, values):
-            ret = (SparseTensor(indices, values, self.dense_shape),)
-            return ret[0]
     indices = Tensor([[0, 1], [1, 2]])
     values = Tensor([1, 2], dtype=ms.float32)
-    MakeSparseTensor()(indices, values)
+    MakeSparseTensor((3, 4))(indices, values)
 
 
 def test_sparse_tensor_attr():
@@ -59,3 +65,20 @@ def test_sparse_tensor_attr():
     indices = Tensor([[0, 1], [1, 2]])
     values = Tensor([1, 2], dtype=ms.float32)
     SparseTensorGetAttr()(indices, values)
+    grad_op(SparseTensorGetAttr())(indices, values)
+
+
+def test_sparse_tensor_indices_dim_greater_than_dense_shape_dim():
+    indices = Tensor(np.array([[0, 0, 0], [0, 0, 1]], dtype=np.int32))
+    values = Tensor(np.array([100, 200], dtype=np.float32))
+    dense_shape = (2, 2)
+    with pytest.raises(TypeError):
+        MakeSparseTensor(dense_shape)(indices, values)
+
+
+def test_sparse_tensor_indices_dim_less_than_dense_shape_dim():
+    indices = Tensor(np.array([[0, 0], [0, 1]], dtype=np.int32))
+    values = Tensor(np.array([100, 200], dtype=np.float32))
+    dense_shape = (2, 2, 2)
+    with pytest.raises(TypeError):
+        MakeSparseTensor(dense_shape)(indices, values)
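
Taken together, these changes let automatic differentiation flow through list constructs: MakeListGradient generates the fprop/bprop pair for kPrimMakeList (mirroring MakeTupleGradient), list_getitem gets a registered bprop, and the list-to-tuple lowering is moved out of simplify_data_structures into the new clean_list pass that runs after opt_a, so make_list nodes are still present when gradients are built. Below is a minimal usage sketch modeled on the new test_grad_addn_with_list test; the ListAddN name, tensor shapes, and all-ones inputs are illustrative, not part of the patch.

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.composite as C
    from mindspore import Tensor, context
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE)

    # get_all=True asks GradOperation for gradients w.r.t. every network input.
    grad_all = C.GradOperation('get_all', get_all=True)


    class ListAddN(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add_n = P.AddN()

        def construct(self, a, b):
            # The [a, b] literal becomes a make_list node; its bprop is now
            # supplied by MakeListGradient (see KPrim::KMetaFuncGraph above).
            return self.add_n([a, b])


    x = Tensor(np.ones([2, 3]).astype(np.float32))
    y = Tensor(np.ones([2, 3]).astype(np.float32))
    da, db = grad_all(ListAddN())(x, y)  # each gradient is an all-ones tensor

On the sparse side, the new check in InferImplMakeSparseTensor surfaces in Python as a TypeError whenever len(dense_shape) differs from the second dimension of indices, which is exactly what the two new test_sparse_tensor_indices_dim_* tests assert with pytest.raises(TypeError).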