!49817 Add support for some array ops to the bprop expander in graph mode.

Merge pull request !49817 from TronZhang/bprop_op_part3
This commit is contained in:
i-robot 2023-03-07 02:37:19 +00:00 committed by Gitee
commit ad9f51dd23
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 142 additions and 32 deletions

View File

@ -209,7 +209,7 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
return node;
} else if (abs->isa<abstract::AbstractSequence>()) {
auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
if (sequence_abs->empty()) {
if (!sequence_abs->dynamic_len() && sequence_abs->empty()) {
return node;
}
return Emit(prim::kSequenceZerosLike, {node});

View File

@ -273,6 +273,8 @@ void RegArrayBpropExpanderOps1() {
REGISTER_EXPANDER_BPROP_IMPL(Argmax);
REGISTER_EXPANDER_BPROP_IMPL(Argmin);
REGISTER_EXPANDER_BPROP_IMPL(BatchToSpace);
REGISTER_EXPANDER_BPROP_IMPL(BatchToSpaceND);
REGISTER_EXPANDER_BPROP_IMPL(BroadcastTo);
REGISTER_EXPANDER_BPROP_IMPL(CheckNumerics);
REGISTER_EXPANDER_BPROP_IMPL(Col2Im);
REGISTER_EXPANDER_BPROP_IMPL(ConjugateTranspose);
@ -280,35 +282,54 @@ void RegArrayBpropExpanderOps1() {
REGISTER_EXPANDER_BPROP_IMPL(Diag);
REGISTER_EXPANDER_BPROP_IMPL(DiagPart);
REGISTER_EXPANDER_BPROP_IMPL(DType);
REGISTER_EXPANDER_BPROP_IMPL(DynamicShape);
REGISTER_EXPANDER_BPROP_IMPL(Fill);
REGISTER_EXPANDER_BPROP_IMPL(Fills);
REGISTER_EXPANDER_BPROP_IMPL(Flatten);
REGISTER_EXPANDER_BPROP_IMPL(GatherD);
REGISTER_EXPANDER_BPROP_IMPL(GatherDGrad);
REGISTER_EXPANDER_BPROP_IMPL(GatherDGradV2);
REGISTER_EXPANDER_BPROP_IMPL(GatherNd);
REGISTER_EXPANDER_BPROP_IMPL(Identity);
REGISTER_EXPANDER_BPROP_IMPL(IdentityN);
REGISTER_EXPANDER_BPROP_IMPL(MaskedSelect);
REGISTER_EXPANDER_BPROP_IMPL(MatrixDiagV3);
REGISTER_EXPANDER_BPROP_IMPL(NonZero);
REGISTER_EXPANDER_BPROP_IMPL(OnesLike);
REGISTER_EXPANDER_BPROP_IMPL(Pack);
REGISTER_EXPANDER_BPROP_IMPL(Padding);
REGISTER_EXPANDER_BPROP_IMPL(Range);
REGISTER_EXPANDER_BPROP_IMPL(Rank);
REGISTER_EXPANDER_BPROP_IMPL(ResizeNearestNeighbor);
REGISTER_EXPANDER_BPROP_IMPL(ReverseSequence);
REGISTER_EXPANDER_BPROP_IMPL(ReverseV2);
REGISTER_EXPANDER_BPROP_IMPL(ScatterMax);
REGISTER_EXPANDER_BPROP_IMPL(ScatterMin);
REGISTER_EXPANDER_BPROP_IMPL(ScatterNd);
REGISTER_EXPANDER_BPROP_IMPL(SegmentMax);
REGISTER_EXPANDER_BPROP_IMPL(SegmentMin);
REGISTER_EXPANDER_BPROP_IMPL(Select);
REGISTER_EXPANDER_BPROP_IMPL(Shape);
REGISTER_EXPANDER_BPROP_IMPL(Slice);
REGISTER_EXPANDER_BPROP_IMPL(SpaceToBatch);
REGISTER_EXPANDER_BPROP_IMPL(SpaceToBatchND);
REGISTER_EXPANDER_BPROP_IMPL(SpaceToDepth);
REGISTER_EXPANDER_BPROP_IMPL(Split);
REGISTER_EXPANDER_BPROP_IMPL(SplitV);
REGISTER_EXPANDER_BPROP_IMPL(Squeeze);
REGISTER_EXPANDER_BPROP_IMPL(Stack);
REGISTER_EXPANDER_BPROP_IMPL(StridedSliceV2);
REGISTER_EXPANDER_BPROP_IMPL(StridedSliceGrad);
}
void RegArrayBpropExpanderOps2() {
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterAdd);
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterDiv);
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterElement);
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterMul);
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterSub);
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterUpdate);
REGISTER_EXPANDER_BPROP_IMPL(TensorShape);
REGISTER_EXPANDER_BPROP_IMPL(Tril);
REGISTER_EXPANDER_BPROP_IMPL(Triu);
REGISTER_EXPANDER_BPROP_IMPL(ZerosLike);
@ -320,9 +341,10 @@ void RegArrayBpropExpanderOps1() {
REGISTER_EXPANDER_BPROP_IMPL(Tile);
REGISTER_EXPANDER_BPROP_IMPL(StridedSlice);
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentSum);
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentProd);
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentMin);
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentMax);
}
void RegArrayBpropExpanderOps2() {}
// NOTE(review): the three functions below have empty bodies, so calling them currently does nothing.
// Presumably they are the per-category registration entry points (clip / comm / inner ops) that mirror
// the array-op registration functions in this file -- confirm against the callers.
void RegClipBpropExpanderOps() {}
void RegCommBpropExpanderOps() {}
void RegInnerBpropExpanderOps() {}

View File

@ -343,6 +343,26 @@ NodePtrList ConcatBpropStatic(const BpropIRBuilder *ib, const NodePtr &dout, con
}
return {ib->MakeTuple(res)};
}
// Backward builder body used for the "Pack" and "Stack" registrations.
// Emits an "Unstack" of the incoming gradient (input index 2), forwarding the "num" and "axis" attributes.
// If the abstract of input 0 is an AbstractList, the unstacked items are re-packed with MakeList;
// otherwise the "Unstack" node itself is returned as the single gradient.
NodePtrList StackBpropFunc(const BpropIRBuilder *ib) {
  auto input_x = ib->GetInput(kIndex0);
  auto grad_out = ib->GetInput(kIndex2);
  auto num_attr = ib->GetAttr("num");
  auto unstacked = ib->Emit("Unstack", {grad_out}, {{"num", num_attr}, {"axis", ib->GetAttr("axis")}});
  auto input_abs = input_x->get()->abstract();
  MS_EXCEPTION_IF_NULL(input_abs);
  // Non-list input: the Unstack output is handed back unchanged.
  if (!input_abs->isa<abstract::AbstractList>()) {
    return {unstacked};
  }
  // List input: pull each of the "num" items out of the Unstack output and rebuild a list from them.
  const auto item_count = LongToSize(GetValue<int64_t>(num_attr));
  NodePtrList list_items;
  for (size_t idx = 0; idx != item_count; ++idx) {
    list_items.push_back(ib->TupleGetItem(unstacked, idx));
  }
  return {ib->MakeList(list_items)};
}
} // namespace
REG_BPROP_BUILDERS_BEGIN(GradArrayOps)
@ -532,17 +552,8 @@ REG_BPROP_BUILDER("Range").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUN
return {ib->ZerosLike(start), ib->ZerosLike(limit), ib->ZerosLike(delta)};
});
// "Pack" and "Stack" share a single backward builder, StackBpropFunc, and each is registered exactly once.
// Input 0 is deliberately not listed as unused: StackBpropFunc reads its abstract to decide whether the
// gradient is returned as the raw "Unstack" output or re-packed into a list.
REG_BPROP_BUILDER("Pack").SetUnusedInputs({i1}).SetBody(StackBpropFunc);
REG_BPROP_BUILDER("Stack").SetUnusedInputs({i1}).SetBody(StackBpropFunc);
REG_BPROP_BUILDER("ReverseV2").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto dout = ib->GetInput(kIndex2);
@ -626,11 +637,34 @@ REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i0, i1}).SetBody(BOD
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto x_shape = ib->GetShape(x);
ShapeVector new_shape;
for (size_t i = 2; i < x_shape.size(); i++) {
new_shape.push_back(x_shape[i]);
NodePtr shape;
if (!IsDynamic(x_shape)) {
ShapeVector new_shape;
for (size_t i = 2; i < x_shape.size(); i++) {
new_shape.push_back(x_shape[i]);
}
shape = ib->EmitValue(MakeValue(new_shape));
} else {
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
auto shape = inputs[0];
ShapeVector res;
for (size_t i = 2; i < shape.size(); ++i) {
res.push_back(shape[i]);
}
return {res};
};
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) -> ShapeVector {
if (!invalid_indices.empty()) {
return ShapeVector(-1);
}
auto rank = SizeToLong(inputs[0].size());
return ShapeVector{rank > 2 ? (rank - 2) : 0};
};
shape = ib->ShapeCalc({x}, shape_func, infer_func, {})[0];
}
auto shape = ib->EmitValue(MakeValue(new_shape));
auto out = ib->Emit("ResizeNearestNeighborGrad", {dout, shape}, {{"align_corners", ib->GetAttr("align_corners")}});
return {out};
});
@ -639,7 +673,13 @@ REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto indices = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3);
auto shp = ib->EmitValue(MakeValue(ib->GetShape(x)));
auto x_shp = ib->GetShape(x);
NodePtr shp;
if (IsDynamic(x_shp)) {
shp = ib->Emit("TensorShape", {x});
} else {
shp = ib->EmitValue(MakeValue(x_shp));
}
return {ib->Emit("ScatterNd", {indices, dout, shp}), ib->ZerosLike(indices)};
});
@ -674,8 +714,11 @@ REG_BPROP_BUILDER("TensorScatterUpdate").SetUnusedInputs({i0, i3}).SetBody(BODYF
// Backward rule for "Flatten": the gradient w.r.t. x is `dout` reshaped back to the shape of x.
REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
  auto x = ib->GetInput(kIndex0);
  auto dout = ib->GetInput(kIndex2);
  auto x_shape = ib->GetShape(x);
  if (IsDynamic(x_shape)) {
    // IsDynamic reported the shape vector as dynamic, so it cannot be used as a constant reshape target;
    // reshape to the output of a "Shape" node emitted on x instead.
    return {ib->Reshape(dout, ib->Emit("Shape", {x}))};
  }
  // Otherwise reshape directly to the shape vector obtained from GetShape.
  return {ib->Reshape(dout, x_shape)};
});
REG_BPROP_BUILDER("Reshape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
@ -989,15 +1032,46 @@ REG_BPROP_BUILDER("BroadcastTo").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib)
auto x_shape = ib->GetShape(x);
auto dout_shape = ib->GetShape(dout);
if (x_shape == dout_shape) {
bool input_dynamic = IsDynamic(x_shape) || IsDynamic(dout_shape);
if (!input_dynamic && x_shape == dout_shape) {
return {dout};
}
auto tuple_out = BroadcastGradientArgs(broadcast_shape, x_shape);
MS_EXCEPTION_IF_CHECK_FAIL(!tuple_out.empty(), "BroadcastGradientArgs out should not be empty!");
auto reduction_axes = tuple_out[kIndex1];
auto reduced_grad = ib->ReduceSum(dout, reduction_axes, true);
auto dx = ib->Reshape(reduced_grad, x_shape);
auto dout_dtype = ib->GetDtype(dout)->type_id();
NodePtr dx;
if (!input_dynamic && !IsDynamic(broadcast_shape)) {
auto tuple_out = BroadcastGradientArgs(broadcast_shape, x_shape);
MS_EXCEPTION_IF_CHECK_FAIL(!tuple_out.empty(), "BroadcastGradientArgs out should not be empty!");
auto reduction_axes = tuple_out[kIndex1];
NodePtr reduced_grad;
if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
auto dout_cast = ib->Cast(dout, kFloat32);
reduced_grad = ib->ReduceSum(dout_cast, reduction_axes, true);
reduced_grad = ib->Cast(reduced_grad, ib->GetDtype(dout));
} else {
reduced_grad = ib->ReduceSum(dout, reduction_axes, true);
}
dx = ib->Reshape(reduced_grad, x_shape);
} else {
auto x_shape_node = ib->Emit("TensorShape", {x});
auto broadcast_shape_node = ib->Emit("TensorShape", {dout});
auto brod = ib->Emit("DynamicBroadcastGradientArgs", {broadcast_shape_node, x_shape_node});
auto reduction_axes = ib->TupleGetItem(brod, 1);
NodePtr reduced_grad;
if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
auto dout_cast = ib->Cast(dout, kFloat32);
reduced_grad = ib->Emit("ReduceSum", {dout_cast, reduction_axes},
{{"keep_dims", MakeValue(true)}, {"skip_mode", MakeValue(true)}});
reduced_grad = ib->Cast(reduced_grad, ib->GetDtype(dout));
} else {
reduced_grad =
ib->Emit("ReduceSum", {dout, reduction_axes}, {{"keep_dims", MakeValue(true)}, {"skip_mode", MakeValue(true)}});
}
dx = ib->Reshape(reduced_grad, x_shape_node);
}
return {dx};
});
@ -1068,6 +1142,9 @@ REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto shapex = ib->GetShape(x);
if (IsDynamic(shapex)) {
return {ib->Reshape(dout, ib->Emit("Shape", {x}))};
}
return {ib->Reshape(dout, shapex)};
});
@ -1075,10 +1152,15 @@ REG_BPROP_BUILDER("Padding").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto shp = ib->GetShape(x);
std::vector<int64_t> begin;
(void)begin.insert(begin.end(), shp.size(), 0);
auto dx = ib->Emit("Slice", {dout, ib->Value<ShapeVector>(begin), ib->Value<ShapeVector>(shp)});
return {dx};
if (!IsDynamic(shp)) {
std::vector<int64_t> begin(shp.size(), 0);
auto dx = ib->Emit("Slice", {dout, ib->Value<ShapeVector>(begin), ib->Value<ShapeVector>(shp)});
return {dx};
}
auto shape_node = ib->Emit("Shape", {x});
auto begin_node = ib->ZerosLike(shape_node);
return {ib->Emit("Slice", {dout, begin_node, shape_node})};
});
REG_BPROP_BUILDER("Transpose").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
@ -1245,7 +1327,13 @@ REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(i
auto end = ib->GetInput(kIndex2);
auto strides = ib->GetInput(kIndex3);
auto dout = ib->GetInput(kIndex5);
auto x_shape = ib->Tensor(ib->GetShape(x));
auto x_shape_vec = ib->GetShape(x);
NodePtr x_shape;
if (IsDynamic(x_shape_vec)) {
x_shape = ib->Emit("TensorShape", {x});
} else {
x_shape = ib->Tensor(x_shape_vec);
}
auto dx = ib->Emit("StridedSliceV2Grad", {x_shape, begin, end, strides, dout},
{{"begin_mask", ib->GetAttr("begin_mask")},
{"end_mask", ib->GetAttr("end_mask")},