!46449 bprop meta graph

Merge pull request !46449 from lianliguang/bprop-mindir
This commit is contained in:
i-robot 2023-01-06 06:20:56 +00:00 committed by Gitee
commit 2daad1f347
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 160 additions and 196 deletions

View File

@ -63,5 +63,33 @@ FuncGraphPtr TransposeBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
return fg;
}
REGISTER_PRIMITIVE_BPROP_IMPL(Transpose, prim::kPrimTranspose, TransposeBprop, 2);
FuncGraphPtr CastBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
constexpr size_t expected_arg_size = 4;
auto fg = NewGraph(input_abs);
// x, out, dout
auto &parameters = fg->parameters();
CheckArgSize(parameters, input_abs, primal, expected_arg_size);
auto x = parameters[kIndex0];
auto t = parameters[kIndex1];
auto dout = parameters[kIndex3];
const auto zeros_like_node = ZerosLikeFunction(fg, t);
const auto cast = Cast(fg);
const auto dtype = DType();
AnfNodePtr return_node;
auto get_dtype = NewNode(fg, {dtype, x});
if (input_abs[kIndex3]->isa<abstract::AbstractRowTensor>()) {
auto row_tensor_values = NewNode(fg, {RowTensorGetValues(), dout});
auto value = NewNode(fg, {cast, row_tensor_values, get_dtype});
auto indices = NewNode(fg, {RowTensorGetIndices(), dout});
auto dense_shape = NewNode(fg, {RowTensorGetDenseShape(), dout});
return_node = NewNode(fg, {MakeRowTensor(), indices, value, dense_shape});
} else {
return_node = NewNode(fg, {cast, dout, get_dtype});
}
fg->set_output(NewNode(fg, {MakeTuple(), return_node, zeros_like_node}));
return fg;
}
REGISTER_PRIMITIVE_BPROP_IMPL(Cast, prim::kPrimCast, CastBprop, 2);
} // namespace graph_bprop
} // namespace mindspore

View File

@ -70,5 +70,54 @@ FuncGraphPtr SubBprop(const PrimitivePtr &primal, const AbstractBasePtrList &inp
return fg;
}
REGISTER_PRIMITIVE_BPROP_IMPL(Sub, prim::kPrimSub, SubBprop, 2);
FuncGraphPtr AddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs);
constexpr size_t expected_arg_size = 4;
const auto &parameters = fg->parameters();
CheckArgSize(parameters, input_abs, primal, expected_arg_size);
fg->set_output(
BinopGradCommon(fg, parameters[kIndex0], parameters[kIndex1], parameters[kIndex3], parameters[kIndex3]));
return fg;
}
REGISTER_PRIMITIVE_BPROP_IMPL(Add, prim::kPrimAdd, AddBprop, 2);
FuncGraphPtr AssignAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs);
constexpr size_t expected_arg_size = 4;
const auto &parameters = fg->parameters();
CheckArgSize(parameters, input_abs, primal, expected_arg_size);
auto x = parameters[kIndex0];
auto y = parameters[kIndex1];
auto out1 = ZerosLikeFunction(fg, x);
auto out2 = ZerosLikeFunction(fg, y);
fg->set_output(NewNode(fg, {MakeTuple(), out1, out2}));
return fg;
}
REGISTER_PRIMITIVE_BPROP_IMPL(AssignAdd, prim::kPrimAssignAdd, AssignAddBprop, 2);
FuncGraphPtr NegBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto neg_grad = Neg();
auto fg = NewGraph(input_abs);
constexpr size_t expected_arg_size = 3;
const auto &parameters = fg->parameters();
CheckArgSize(parameters, input_abs, primal, expected_arg_size);
auto dx = NewNode(fg, {neg_grad, parameters[kIndex2]});
fg->set_output(NewNode(fg, {MakeTuple(), dx}));
return fg;
}
REGISTER_PRIMITIVE_BPROP_IMPL(Neg, prim::kPrimNeg, NegBprop, 1);
FuncGraphPtr LogicalOrBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs);
constexpr size_t expected_arg_size = 4;
const auto &parameters = fg->parameters();
CheckArgSize(parameters, input_abs, primal, expected_arg_size);
auto dx = ZerosLikeFunction(fg, parameters[kIndex0]);
auto dy = ZerosLikeFunction(fg, parameters[kIndex1]);
fg->set_output(NewNode(fg, {MakeTuple(), dx, dy}));
return fg;
}
REGISTER_PRIMITIVE_BPROP_IMPL(LogicalOr, prim::kPrimLogicalOr, LogicalOrBprop, 2);
} // namespace graph_bprop
} // namespace mindspore

View File

@ -141,5 +141,34 @@ FuncGraphPtr BatchNormBprop(const PrimitivePtr &primal, const AbstractBasePtrLis
return fg;
}
REGISTER_PRIMITIVE_BPROP_IMPL(BatchNorm, prim::kPrimBatchNorm, BatchNormBprop, 5);
FuncGraphPtr BiasAddBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs);
auto format = GetAttr<std::string>(primal, "format");
// x, out, dout
constexpr size_t expected_arg_size = 4;
const auto &parameters = fg->parameters();
CheckArgSize(parameters, input_abs, primal, expected_arg_size);
auto dout = parameters[kIndex3];
auto bais_add_grad = NewNode(fg, {BiasAddGrad(format), dout});
fg->set_output(NewNode(fg, {MakeTuple(), dout, bais_add_grad}));
return fg;
}
REGISTER_PRIMITIVE_BPROP_IMPL(BiasAdd, prim::kPrimBiasAdd, BiasAddBprop, 2);
FuncGraphPtr GeLUBprop(const PrimitivePtr &primal, const AbstractBasePtrList &input_abs) {
auto fg = NewGraph(input_abs);
// x, out, dout
constexpr size_t expected_arg_size = 3;
const auto &parameters = fg->parameters();
CheckArgSize(parameters, input_abs, primal, expected_arg_size);
auto x = parameters[kIndex0];
auto out = parameters[kIndex1];
auto dout = parameters[kIndex2];
auto dx = NewNode(fg, {GeLUGrad(), dout, x, out});
fg->set_output(NewNode(fg, {MakeTuple(), dx}));
return fg;
}
REGISTER_PRIMITIVE_BPROP_IMPL(GeLU, prim::kPrimGeLU, GeLUBprop, 1);
} // namespace graph_bprop
} // namespace mindspore

View File

@ -157,11 +157,20 @@ AnfNodePtr Add() { return NewValueNode(prim::GetPythonOps("add", "mindspore.ops.
AnfNodePtr Mod() { return NewValueNode(prim::GetPythonOps("mod", "mindspore.ops.composite.multitype_ops.mod_impl")); }
AnfNodePtr Mul(const FuncGraphPtr &fg) {
return fg->NewCNodeInOrder({GetClassType("mindspore.ops.operations.math_ops", "Mul")});
}
AnfNodePtr ZerosLikeFunction(const FuncGraphPtr &fg, const AnfNodePtr &input) {
return fg->NewCNodeInOrder(
{NewValueNode(prim::GetPythonOps("zeros_like", "mindspore.ops.composite.multitype_ops.zeros_like_impl")), input});
}
AnfNodePtr BiasAddGrad(const string &format) {
auto prim = NewPrimitive(prim::kPrimBiasAddGrad, {{"format", MakeValue(format)}});
return NewValueNode(prim);
}
AnfNodePtr MatMul(const FuncGraphPtr &fg, bool transpose_a, bool transpose_b) {
return fg->NewCNodeInOrder({GetClassType("mindspore.ops.operations.math_ops", "MatMul"), NewValueNode(transpose_a),
NewValueNode(transpose_b)});
@ -171,12 +180,22 @@ AnfNodePtr Conj() { return NewValueNode(prim::kPrimConj); }
AnfNodePtr ReluGrad() { return NewValueNode(prim::kPrimReluGrad); }
AnfNodePtr GeLUGrad() { return NewValueNode(prim::kPrimGeLUGrad); }
AnfNodePtr MakeTuple() { return NewValueNode(prim::kPrimMakeTuple); }
AnfNodePtr TensorShape() { return NewValueNode(prim::kPrimTensorShape); }
AnfNodePtr Shape() { return NewValueNode(prim::kPrimShape); }
AnfNodePtr RowTensorGetValues() { return NewValueNode(prim::kPrimRowTensorGetValues); }
AnfNodePtr RowTensorGetIndices() { return NewValueNode(prim::kPrimRowTensorGetIndices); }
AnfNodePtr RowTensorGetDenseShape() { return NewValueNode(prim::kPrimRowTensorGetDenseShape); }
AnfNodePtr MakeRowTensor() { return NewValueNode(prim::kPrimMakeRowTensor); }
AnfNodePtr Cast(const FuncGraphPtr &fg) {
return fg->NewCNodeInOrder({GetClassType("mindspore.ops.operations.array_ops", "Cast")});
}
@ -484,5 +503,7 @@ ValuePtr GetPadModStr(const ValuePtr &value, bool upper) {
(void)std::transform(str.begin(), str.end(), str.begin(), toupper);
return MakeValue(str);
}
AnfNodePtr DType() { return NewValueNode(prim::kPrimDType); }
} // namespace graph_bprop
} // namespace mindspore

View File

@ -26,13 +26,19 @@ namespace mindspore {
namespace graph_bprop {
// Ops.
AnfNodePtr Add();
AnfNodePtr Mul(const FuncGraphPtr &fg);
AnfNodePtr Mod();
AnfNodePtr MatMul(const FuncGraphPtr &fg, bool transpose_a = false, bool transpose_b = false);
AnfNodePtr Conj();
AnfNodePtr ReluGrad();
AnfNodePtr GeLUGrad();
AnfNodePtr MakeTuple();
AnfNodePtr TensorShape();
AnfNodePtr Shape();
AnfNodePtr RowTensorGetValues();
AnfNodePtr RowTensorGetIndices();
AnfNodePtr RowTensorGetDenseShape();
AnfNodePtr MakeRowTensor();
AnfNodePtr Cast(const FuncGraphPtr &fg);
AnfNodePtr ReduceProd(const FuncGraphPtr &fg);
AnfNodePtr ExpandDims(const FuncGraphPtr &fg);
@ -50,6 +56,8 @@ AnfNodePtr Reshape(const FuncGraphPtr &fg);
AnfNodePtr DynamicBroadcastGradientArgs();
AnfNodePtr MaxPoolGrad(const FuncGraphPtr &fg, const PrimitivePtr &primal);
AnfNodePtr BatchNormGrad(const FuncGraphPtr &fg, const PrimitivePtr &primal);
AnfNodePtr DType();
AnfNodePtr BiasAddGrad(const string &format);
// Common methods.
AnfNodePtr ZerosLikeFunction(const FuncGraphPtr &fg, const AnfNodePtr &input);

View File

@ -18,7 +18,6 @@
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations.array_ops import Fills, NonZero
@ -97,47 +96,6 @@ def get_bprop_dtype(self):
return bprop
dout_cast = C.MultitypeFuncGraph("dout_cast")
@dout_cast.register("Tensor", "Tensor")
def dout_cast_tensor(dout, x):
"""Casts dout to the dtype of x for Tensor."""
cast = P.Cast()
get_dtype = P.DType()
dx = cast(dout, get_dtype(x))
return dx
@dout_cast.register("Number", "Number")
def dout_cast_number(dout, x):
"""Casts dout to the dtype of x for Number."""
cast = P.Cast()
get_dtype = P.DType()
dx = cast(dout, get_dtype(x))
return dx
@dout_cast.register("RowTensor", "Tensor")
def dout_cast_row_tensor(dout, x):
"""Casts dout values to the dtype of x for RowTensor."""
cast = P.Cast()
get_dtype = P.DType()
values = cast(dout.values, get_dtype(x))
return RowTensorInner(dout.indices, values, dout.dense_shape)
@bprop_getters.register(P.Cast)
def get_bprop_cast(self):
"""Generate bprop for Cast"""
def bprop(x, t, out, dout):
dx = dout_cast(dout, x)
return dx, zeros_like(t)
return bprop
@bprop_getters.register(P.Shape)
def get_bprop_shape(self):
"""Generate bprop for Shape"""

View File

@ -333,16 +333,6 @@ def bprop_batchmatmul(self):
return bprop
@bprop_getters.register(P.Add)
def get_bprop_add(self):
"""Grad definition for `Add` operation."""
def bprop(x, y, out, dout):
return binop_grad_common(x, y, dout, dout)
return bprop
@bprop_getters.register(P.TensorAdd)
def get_bprop_tensor_add(self):
"""Grad definition for `Add` operation."""
@ -369,18 +359,6 @@ def get_bprop_matrix_inverse(self):
return bprop
@bprop_getters.register(P.Neg)
def get_bprop_neg(self):
"""Grad definition for `Neg` operation."""
neg_grad = P.Neg()
def bprop(x, out, dout):
dx = neg_grad(dout)
return (dx,)
return bprop
@bprop_getters.register(P.Mul)
def get_bprop_mul(self):
"""Grad definition for `Mul` operation."""
@ -1208,16 +1186,6 @@ def get_bprop_logical_and(self):
return bprop
@bprop_getters.register(P.LogicalOr)
def get_bprop_logical_or(self):
"""Grad definition for `LogicalOr` operation."""
def bprop(x, y, out, dout):
return zeros_like(x), zeros_like(y)
return bprop
@bprop_getters.register(P.NPUAllocFloatStatus)
def get_bprop_npu_alloc_float_status(self):
"""Grad definition for `NPUAllocFloatStatus` operation."""

View File

@ -30,17 +30,6 @@ from mindspore.ops.operations import _rl_inner_ops as rl_ops
from mindspore.ops._utils.utils import range_op, get_1d_shape
@bprop_getters.register(P.BiasAdd)
def get_bprop_bias_add(self):
"""Grad definition for `BiasAdd` operation."""
bias_grad = G.BiasAddGrad(self.data_format)
def bprop(x, w, out, dout):
return dout, bias_grad(dout)
return bprop
@constexpr
def bias_add_gradgrad_helper(shape, bias_shape, data_format):
"""Helper function of BiasGradGrad to calculate expanded shape."""
@ -796,19 +785,6 @@ def get_bprop_tanh_grad(self):
return bprop
@bprop_getters.register(P.Gelu)
@bprop_getters.register(P.GeLU)
def get_bprop_gelu(self):
"""Grad definition for `GeLU` operation."""
input_grad = G.GeLUGrad()
def bprop(x, out, dout):
dx = input_grad(dout, x, out)
return (dx,)
return bprop
@bprop_getters.register(P.FastGeLU)
def get_bprop_fast_gelu(self):
"""Grad definition for `FastGeLU` operation."""

View File

@ -1,19 +0,0 @@
0.1.1 MindSpore*2.0.0:½
}'get_bprop_assign_add.1231:[CNode]1232:1'get_bprop_assign_add.1231:[CNode]1232:1"REF::bprop.1233:Default/bprop.1233-op958get_bprop_assign_add.1231*
get_bprop_assign_add.1231:self*
get_bprop_assign_add.1231:x*
get_bprop_assign_add.1231:y*
get_bprop_assign_add.1231:out*
get_bprop_assign_add.1231:dout2)
'get_bprop_assign_add.1231:[CNode]1232:1:@864154f4834e62d84d34aab9399558528d5e734f6725d5daf7fbc1907cb32a1aJ/grad_math_ops.pyB¶
²
get_bprop_assign_add.1231:xbprop.1233:[CNode]1234:2bprop.1233:[CNode]1234:2".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op959
²
get_bprop_assign_add.1231:ybprop.1233:[CNode]1235:3bprop.1233:[CNode]1235:3".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op960
¡
bprop.1233:[CNode]1234:2
bprop.1233:[CNode]1235:3bprop.1233:[CNode]1236:4bprop.1233:[CNode]1236:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op961
bprop.12332
bprop.1233:[CNode]1236:4Pb&
S-Prim-MakeTuple:5S-Prim-MakeTupleh

View File

@ -1,24 +0,0 @@
0.1.1 MindSpore*2.0.0:Ó
g get_bprop_bias_add.3:[CNode]12:1 get_bprop_bias_add.3:[CNode]12:1" REF::bprop.5:Default/bprop.5-op0
A [ValueNode]13 [ValueNode]13"Constant*
value* data_format€

get_bprop_bias_add.3:self
[ValueNode]13get_bprop_bias_add.3:[CNode]1:2get_bprop_bias_add.3:[CNode]1:2"REF::getattr:3:Default/getattr-op1
ù
get_bprop_bias_add.3:[CNode]1:2 get_bprop_bias_add.3:bias_grad:4 get_bprop_bias_add.3:bias_grad:4">REF::ClassType::mindspore.ops.operations._grad_ops.BiasAddGrad:RDefault/MindIRClassType:class 'mindspore.ops.operations._grad_ops.BiasAddGrad'-op2get_bprop_bias_add.3*
get_bprop_bias_add.3:self*
get_bprop_bias_add.3:x*
get_bprop_bias_add.3:w*
get_bprop_bias_add.3:out*
get_bprop_bias_add.3:dout2"
get_bprop_bias_add.3:[CNode]12:1:@bb60fd260d3ef4ddcc66def38153515987d8927fb4701af4a701ca11958fb768J/grad_nn_ops.pyBŸ
m
get_bprop_bias_add.3:doutbprop.5:[CNode]2:5bprop.5:[CNode]2:5"%REF::get_bprop_bias_add.3:bias_grad:4:3
Ž
get_bprop_bias_add.3:dout
bprop.5:[CNode]2:5bprop.5:[CNode]4:6bprop.5:[CNode]4:6"REF::S-Prim-MakeTuple:7:Default/S-Prim-MakeTuple-op4bprop.52
bprop.5:[CNode]4:6Pb
getattr:3getattrb&
S-Prim-MakeTuple:7S-Prim-MakeTupleh

View File

@ -1,19 +0,0 @@
0.1.1 MindSpore*2.0.0:Ü
]get_bprop_cast.6:[CNode]7:1get_bprop_cast.6:[CNode]7:1" REF::bprop.8:Default/bprop.8-op3get_bprop_cast.6*
get_bprop_cast.6:self*
get_bprop_cast.6:x*
get_bprop_cast.6:t*
get_bprop_cast.6:out*
get_bprop_cast.6:dout2
get_bprop_cast.6:[CNode]7:1:@2a049f3579950913c6ea42bb677f44470016652aa549a6dee2350ea48d50f039J/grad_array_ops.pyBË

get_bprop_cast.6:dout
get_bprop_cast.6:x bprop.8:dx:2 bprop.8:dx:2"REF::MetaFuncGraph::dout_cast:Default/S-Prim-dout_cast-op4

get_bprop_cast.6:tbprop.8:[CNode]9:3bprop.8:[CNode]9:3".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:-Default/S-Prim-hyper_map[zeros_like_leaf]-op5
ƒ
bprop.8:dx:2
bprop.8:[CNode]9:3bprop.8:[CNode]10:4bprop.8:[CNode]10:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op6bprop.82
bprop.8:[CNode]10:4Pb&
S-Prim-MakeTuple:5S-Prim-MakeTupleh

View File

@ -1,17 +0,0 @@
0.1.1 MindSpore*2.0.0:
]get_bprop_gelu.1:[CNode]2:1get_bprop_gelu.1:[CNode]2:1" REF::bprop.3:Default/bprop.3-op0
Ìget_bprop_gelu.1:input_grad:2get_bprop_gelu.1:input_grad:2";REF::ClassType::mindspore.ops.operations._grad_ops.GeLUGrad:ODefault/MindIRClassType:class 'mindspore.ops.operations._grad_ops.GeLUGrad'-op1get_bprop_gelu.1*
get_bprop_gelu.1:self*
get_bprop_gelu.1:x*
get_bprop_gelu.1:out*
get_bprop_gelu.1:dout2
get_bprop_gelu.1:[CNode]2:1:@ba5f19ce240c22c5c190fb2987a36f33482caaadfb86668976a5653115fc4e93J/grad_nn_ops.pyB•

get_bprop_gelu.1:dout
get_bprop_gelu.1:x
get_bprop_gelu.1:out bprop.3:dx:3 bprop.3:dx:3""REF::get_bprop_gelu.1:input_grad:2:2
m
bprop.3:dx:3bprop.3:[CNode]4:4bprop.3:[CNode]4:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op3bprop.32
bprop.3:[CNode]4:4Pb&
S-Prim-MakeTuple:5S-Prim-MakeTupleh

View File

@ -1,19 +0,0 @@
0.1.1 MindSpore*2.0.0:½
}'get_bprop_logical_or.1225:[CNode]1226:1'get_bprop_logical_or.1225:[CNode]1226:1"REF::bprop.1227:Default/bprop.1227-op954get_bprop_logical_or.1225*
get_bprop_logical_or.1225:self*
get_bprop_logical_or.1225:x*
get_bprop_logical_or.1225:y*
get_bprop_logical_or.1225:out*
get_bprop_logical_or.1225:dout2)
'get_bprop_logical_or.1225:[CNode]1226:1:@906051cca7d6d4b88a09a10b80bb5f0541066115667786dd7364cba0508be483J/grad_math_ops.pyB¶
²
get_bprop_logical_or.1225:xbprop.1227:[CNode]1228:2bprop.1227:[CNode]1228:2".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op955
²
get_bprop_logical_or.1225:ybprop.1227:[CNode]1229:3bprop.1227:[CNode]1229:3".REF::MetaFuncGraph::hyper_map[zeros_like_leaf]:/Default/S-Prim-hyper_map[zeros_like_leaf]-op956
¡
bprop.1227:[CNode]1228:2
bprop.1227:[CNode]1229:3bprop.1227:[CNode]1230:4bprop.1227:[CNode]1230:4"REF::S-Prim-MakeTuple:5:Default/S-Prim-MakeTuple-op957
bprop.12272
bprop.1227:[CNode]1230:4Pb&
S-Prim-MakeTuple:5S-Prim-MakeTupleh

View File

@ -1972,3 +1972,28 @@ def test_upsample_trilinear_3d():
"""
op = ops.UpsampleTrilinear3D(output_size=[4, 64, 48])
grad_compile_(Tensor(input_data=np.random.randn(2, 3, 4, 512, 256)), op)
def test_add():
"""
Feature: Bprop pre-compilation.
Description: Compile the backward graph for the add op.
Expectation: Load the bprop mindir successfully.
"""
x = Tensor(np.ones([1]).astype(np.int32) * 100)
y = Tensor(np.ones([1]).astype(np.int32) * 100)
add = Net(P.Add())
grad = GradNet(add)
grad.compile(x, y)
def test_neg():
"""
Feature: Bprop pre-compilation.
Description: Compile the backward graph for the batch neg op.
Expectation: Load the bprop mindir successfully.
"""
x = Tensor(np.array([1, 2, -1, 2, 0, -3.5]), mindspore.float32)
neg_net = Net(ops.Neg())
grad = GradNet(neg_net)
grad.compile(x)