!9926 optimize list setitem in bprop

From: @zhangbuxue
Reviewed-by: @zh_qh
Signed-off-by: @zh_qh

commit dd134d7554
@@ -25,7 +25,7 @@
 #include "frontend/optimizer/irpass/inline.h"
 #include "frontend/optimizer/irpass/incorporate_call.h"
 #include "frontend/optimizer/irpass/incorporate_getitem.h"
-#include "frontend/optimizer/irpass/item_tuple_eliminate.h"
+#include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
 #include "frontend/optimizer/irpass/mark_interface_fusion.h"
 #include "frontend/optimizer/irpass/merge_addn.h"
 #include "frontend/optimizer/irpass/accumulaten_eliminate.h"
@@ -67,8 +67,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
     MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);

   // ops eliminate
-  item_tuple_eliminate_ = MakeSubstitution(std::make_shared<ItemTupleEliminater>(), "item_tuple_eliminate",
-                                           {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem});
+  item_tuple_or_list_eliminate_ = MakeSubstitution(
+    std::make_shared<ItemTupleOrListEliminater>(), "item_tuple_or_list_eliminate",
+    {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem});
   tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile);
   cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
   reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
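The last argument to MakeSubstitution lists the primitives that may sit at the root of a matching node, so adding prim::kPrimListSetItem here is what makes the optimizer consider ListSetItem nodes at all. A minimal Python sketch of that kind of trigger table (illustrative names, not MindSpore API):

    # Hypothetical model of substitution registration: a rewriter only
    # runs on nodes whose head primitive it was registered for.
    substitutions = {}

    def make_substitution(rewriter, name, trigger_prims):
        for prim in trigger_prims:
            substitutions.setdefault(prim, []).append((name, rewriter))

    def try_rewrite(head_prim, node):
        for _name, rewriter in substitutions.get(head_prim, []):
            new_node = rewriter(node)
            if new_node is not None:
                return new_node
        return node  # no substitution fired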
@@ -39,7 +39,7 @@ class OptimizeIRPassLib {
   SubstitutionPtr adjust_all_reduce_mul_add_;

   // ops eliminate
-  SubstitutionPtr item_tuple_eliminate_;
+  SubstitutionPtr item_tuple_or_list_eliminate_;
   SubstitutionPtr tile_eliminate_;
   SubstitutionPtr cast_eliminate_;
   SubstitutionPtr reshape_eliminate_;
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
-#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
+#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_
+#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_

 #include <algorithm>
 #include <memory>
@@ -33,6 +33,7 @@ namespace irpass {
 // (a, b, c, ...)[0] => a
 // (a, b, c, ...)[1] => b
 // {prim::kPrimTupleGetItem, {prim::kPrimMakeTuple, Xs}, C}
+// {prim::kPrimListGetItem, {prim::kPrimMakeList, Xs}, C}
 class GetitemEliminater : public AnfVisitor {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@@ -54,7 +55,7 @@ class GetitemEliminater : public AnfVisitor {

   void Visit(const ValueNodePtr &vnode) override {
     if (tuple_ != nullptr && IsValueNode<Int64Imm>(vnode)) {
-      int64_t idx = GetValue<int64_t>(vnode->value());
+      auto idx = GetValue<int64_t>(vnode->value());
       if (idx < 0) {
         idx = idx + tuple_->size() - 1;
       }
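The size() - 1 in the normalization reflects CNode layout: input 0 of the MakeTuple/MakeList CNode is the primitive itself, so the real element count is size() - 1 and a negative index lands exactly as in Python. A one-line model:

    def normalize_index(idx, num_elements):
        # Mirrors `idx = idx + tuple_->size() - 1`, where size() counts
        # the primitive input plus num_elements real elements.
        return idx + num_elements if idx < 0 else idx

    assert normalize_index(-1, 3) == 2  # (a, b, c)[-1] -> c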
@@ -80,6 +81,7 @@ class GetitemEliminater : public AnfVisitor {
 // (a, b, c, ...)[0] => a
 // (a, b, c, ...)[1] => b
 // {prim::kPrimTupleGetItem, C1, C}
+// {prim::kPrimListGetItem, C1, C}
 class GetitemConstEliminater : public AnfVisitor {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@@ -124,11 +126,13 @@ class GetitemConstEliminater : public AnfVisitor {
 // setitem((a, b, c, ...), 0, z) => (z, b, c, ...)
 // setitem((a, b, c, ...), 1, z) => (a, z, c, ...)
 // {prim::kPrimTupleSetItem, {prim::kPrimMakeTuple, Xs}, C, Z}
+// {prim::kPrimListSetItem, {prim::kPrimMakeList, Xs}, C, Z}
 class SetitemEliminater : public AnfVisitor {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
     Reset();
     AnfVisitor::Match(prim::kPrimTupleSetItem, {IsCNode, IsVNode, IsNode})(node);
+    AnfVisitor::Match(prim::kPrimListSetItem, {IsCNode, IsVNode, IsNode})(node);

     auto fg = node->func_graph();
     if (fg != nullptr && z_ != nullptr) {
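Semantically, the rule turns a write into a freshly constructed tuple or list into a rebuilt literal with one slot replaced, which is what lets the bprop graph avoid a runtime setitem. A pure-Python sketch of the rewrite, operating on plain sequences instead of IR nodes:

    def eliminate_setitem(elements, idx, z):
        # setitem((a, b, c, ...), idx, z) => literal with slot idx replaced by z
        out = list(elements)
        out[idx] = z
        return tuple(out)

    assert eliminate_setitem(('a', 'b', 'c'), 1, 'z') == ('a', 'z', 'c')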
@@ -178,11 +182,13 @@ class SetitemEliminater : public AnfVisitor {
 };

 // {prim::kPrimTupleGetItem, {prim::kPrimTupleSetItem, Y, C1, X}, C2}
+// {prim::kPrimListGetItem, {prim::kPrimListSetItem, Y, C1, X}, C2}
 class GetSetitemEliminater : public AnfVisitor {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
     Reset();
     AnfVisitor::Match(prim::kPrimTupleGetItem, {IsCNode, IsVNode})(node);
+    AnfVisitor::Match(prim::kPrimListGetItem, {IsCNode, IsVNode})(node);

     auto fg = node->func_graph();
     if (fg != nullptr && key1_ >= 0 && key2_ >= 0) {
@@ -195,7 +201,7 @@ class GetSetitemEliminater : public AnfVisitor {
   }

   void Visit(const CNodePtr &cnode) override {
-    if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem)) {
+    if (IsPrimitiveCNode(cnode, prim::kPrimTupleSetItem) || IsPrimitiveCNode(cnode, prim::kPrimListSetItem)) {
       if (cnode->size() < 4) {
         return;
       }
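The folding logic behind this matcher: a read at index C2 from a container last written at index C1 either sees the written value (C1 == C2) or is unaffected by the write. A sketch in Python (names illustrative):

    def eliminate_get_of_set(y, key1, x, key2):
        # getitem(setitem(Y, C1, X), C2) => X               if C1 == C2
        #                                => getitem(Y, C2)  otherwise
        return x if key1 == key2 else y[key2]

    y = ('a', 'b', 'c')
    assert eliminate_get_of_set(y, 1, 'z', 1) == 'z'
    assert eliminate_get_of_set(y, 1, 'z', 0) == 'a'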
@@ -239,6 +245,8 @@ class GetSetitemEliminater : public AnfVisitor {

 // {prim::kPrimTupleGetItem, {prim::kPrimDepend, X, Y}, C} ->
 // {prim::kPrimDepend, {prim::kPrimTupleGetItem, X, C}, Y}
+// {prim::kPrimListGetItem, {prim::kPrimDepend, X, Y}, C} ->
+// {prim::kPrimDepend, {prim::kPrimListGetItem, X, C}, Y}
 class GetitemDependReorder : public AnfVisitor {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
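Depend(X, Y) evaluates to X while keeping Y ordered before uses of the result; pushing the getitem inside lets it reach a MakeTuple/MakeList producer and fold, with the ordering constraint re-attached to the extracted element. A toy model of the algebra (a sketch, not full MindSpore semantics):

    def depend(x, y):
        # Toy Depend: y is kept only for ordering, x is the value.
        _ = y
        return x

    x = ('a', 'b', 'c')
    y = object()  # stands in for a side-effecting producer
    # getitem(Depend(x, y), 1) == Depend(getitem(x, 1), y)
    assert depend(x, y)[1] == depend(x[1], y)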
@@ -274,9 +282,9 @@ class GetitemDependReorder : public AnfVisitor {
   AnfNodePtr x_{nullptr}, y_{nullptr}, c_{nullptr};
 };

-class ItemTupleEliminater : public OptimizerCaller {
+class ItemTupleOrListEliminater : public OptimizerCaller {
  public:
-  ItemTupleEliminater()
+  ItemTupleOrListEliminater()
       : get_item_eliminater_(std::make_shared<GetitemEliminater>()),
         get_item_const_eliminater_(std::make_shared<GetitemConstEliminater>()),
         set_item_eliminater_(std::make_shared<SetitemEliminater>()),
@@ -288,7 +296,7 @@ class ItemTupleEliminater : public OptimizerCaller {
     eliminaters_.emplace_back(get_set_item_eliminater_);
     eliminaters_.emplace_back(get_item_depend_reorder_);
   }
-  ~ItemTupleEliminater() = default;
+  ~ItemTupleOrListEliminater() = default;

   AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
     AnfNodePtr new_node;
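The renamed class adds no matching logic of its own; it is an ordered dispatcher that tries each sub-eliminater and returns the first rewrite that fires. The pattern, rendered as a minimal Python sketch:

    class FirstMatchCaller:
        """Try rewriters in registration order; return the first hit."""

        def __init__(self, *rewriters):
            self.rewriters = list(rewriters)

        def __call__(self, node):
            for rewrite in self.rewriters:
                new_node = rewrite(node)
                if new_node is not None:
                    return new_node
            return None  # nothing matched; leave the node alone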
@@ -309,4 +317,4 @@ class ItemTupleEliminater : public OptimizerCaller {
 }  // namespace irpass
 }  // namespace opt
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_ELIMINATE_H_
+#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_ITEM_TUPLE_OR_LIST_ELIMINATE_H_
@@ -100,7 +100,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.specialize_transform_,

     // Miscellaneous
-    irpass.item_tuple_eliminate_,
+    irpass.item_tuple_or_list_eliminate_,
     irpass.env_get_item_eliminate_,
     irpass.cast_eliminate_,
     irpass.reshape_eliminate_,
@@ -188,8 +188,9 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp
 }

 OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
-  opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining
-                                               irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_});
+  opt::OptPassConfig d_1 =
+    opt::OptPassConfig({// Safe inlining
+                        irpass.call_graph_tuple_transform_, irpass.item_tuple_or_list_eliminate_});

   OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
@@ -198,7 +199,7 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib

 OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig b_1 = opt::OptPassConfig(
-    {irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, irpass.float_tuple_getitem_switch_,
+    {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_,
     irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_,
     irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_,
     irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_});
@@ -232,7 +232,7 @@ def ms_function(fn=None, obj=None, input_signature=None):
         equal to the case when `fn` is not None.

     Examples:
-        >>> from mindspore.ops import functional as F
+        >>> from mindspore.ops import functional as F
         ...
         >>> def tensor_add(x, y):
         ...     z = x + y
@@ -360,7 +360,7 @@ TEST_F(TestOptLib, test_tuple_getitem) {
   FuncGraphPtr after_2 = std::make_shared<FuncGraph>();
   after_2->set_output(value_node_2);

-  auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
+  auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});
   ASSERT_TRUE(CheckOpt(make_get_0, after_0, patterns));
   ASSERT_TRUE(CheckOpt(make_get_1, after_1, patterns));
   ASSERT_TRUE(CheckOpt(make_get_const, after_2, patterns));
@@ -372,7 +372,7 @@ TEST_F(TestOptLib, test_tuple_setitem) {
   FuncGraphPtr after_0 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_0");
   FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_setitem", "after_1");

-  auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
+  auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});

   ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
   ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));
@@ -384,7 +384,7 @@ TEST_F(TestOptLib, test_tuple_get_set_item) {
   FuncGraphPtr before_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "before_0");
   FuncGraphPtr after_1 = getPyFun.CallAndParseRet("test_tuple_get_set_item", "after_0");

-  auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_eliminate_});
+  auto patterns = std::vector<SubstitutionPtr>({irpass.item_tuple_or_list_eliminate_});

   ASSERT_TRUE(CheckOpt(before_0, after_0, patterns));
   ASSERT_TRUE(CheckOpt(before_1, after_1, patterns));
@@ -13,9 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """ test enumerate"""
+
+import numpy as np
+
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import context
+from mindspore.ops import operations as P
+from mindspore.ops import composite as C

 context.set_context(mode=context.GRAPH_MODE)

@@ -168,3 +173,60 @@ def test_list_index_3D_parameter():

     net = Net()
     net(Tensor(0))
+
+
+def test_const_list_index_3D_bprop():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.value = [[1], [2, 2], [[3, 3], [3, 3]]]
+            self.relu = P.ReLU()
+
+        def construct(self, input_x):
+            list_x = self.value
+            list_x[2][0][1] = input_x
+            return self.relu(list_x[2][0][1])
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
+
+        def construct(self, x, sens):
+            return self.grad_all_with_sens(self.net)(x, sens)
+
+    net = Net()
+    grad_net = GradNet(net)
+    x = Tensor(np.arange(2 * 3).reshape(2, 3))
+    sens = Tensor(np.arange(2 * 3).reshape(2, 3))
+    grad_net(x, sens)
+
+
+def test_parameter_list_index_3D_bprop():
+    class Net(nn.Cell):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.value = [[1], [2, 2], [[3, 3], [3, 3]]]
+            self.relu = P.ReLU()
+
+        def construct(self, x, value):
+            list_value = [[x], [x, x], [[x, x], [x, x]]]
+            list_value[2][0][1] = value
+            return self.relu(list_value[2][0][1])
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+            self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
+
+        def construct(self, x, value, sens):
+            return self.grad_all_with_sens(self.net)(x, value, sens)
+
+    net = Net()
+    grad_net = GradNet(net)
+    x = Tensor(np.arange(2 * 3).reshape(2, 3))
+    value = Tensor(np.ones((2, 3), np.int64))
+    sens = Tensor(np.arange(2 * 3).reshape(2, 3))
+    grad_net(x, value, sens)
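Both new tests drive the full pipeline (parse, optimize with item_tuple_or_list_eliminate_, differentiate) but assert nothing about the gradient values; they pass as long as compilation and execution succeed. A follow-up check one could add to the first test, assuming grad_all_with_sens returns one gradient per network input:

    grads = grad_net(x, sens)
    # Each gradient should match the shape of the corresponding input.
    assert grads[0].shape == x.shape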