add bprop expander for dynamic

This commit is contained in:
TronZhang 2023-03-03 16:29:32 +08:00
parent 05b6f800b0
commit 4d228ecd3b
3 changed files with 142 additions and 32 deletions

View File

@ -209,7 +209,7 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
return node; return node;
} else if (abs->isa<abstract::AbstractSequence>()) { } else if (abs->isa<abstract::AbstractSequence>()) {
auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>(); auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
if (sequence_abs->empty()) { if (!sequence_abs->dynamic_len() && sequence_abs->empty()) {
return node; return node;
} }
return Emit(prim::kSequenceZerosLike, {node}); return Emit(prim::kSequenceZerosLike, {node});

View File

@ -273,6 +273,8 @@ void RegArrayBpropExpanderOps1() {
REGISTER_EXPANDER_BPROP_IMPL(Argmax); REGISTER_EXPANDER_BPROP_IMPL(Argmax);
REGISTER_EXPANDER_BPROP_IMPL(Argmin); REGISTER_EXPANDER_BPROP_IMPL(Argmin);
REGISTER_EXPANDER_BPROP_IMPL(BatchToSpace); REGISTER_EXPANDER_BPROP_IMPL(BatchToSpace);
REGISTER_EXPANDER_BPROP_IMPL(BatchToSpaceND);
REGISTER_EXPANDER_BPROP_IMPL(BroadcastTo);
REGISTER_EXPANDER_BPROP_IMPL(CheckNumerics); REGISTER_EXPANDER_BPROP_IMPL(CheckNumerics);
REGISTER_EXPANDER_BPROP_IMPL(Col2Im); REGISTER_EXPANDER_BPROP_IMPL(Col2Im);
REGISTER_EXPANDER_BPROP_IMPL(ConjugateTranspose); REGISTER_EXPANDER_BPROP_IMPL(ConjugateTranspose);
@ -280,35 +282,54 @@ void RegArrayBpropExpanderOps1() {
REGISTER_EXPANDER_BPROP_IMPL(Diag); REGISTER_EXPANDER_BPROP_IMPL(Diag);
REGISTER_EXPANDER_BPROP_IMPL(DiagPart); REGISTER_EXPANDER_BPROP_IMPL(DiagPart);
REGISTER_EXPANDER_BPROP_IMPL(DType); REGISTER_EXPANDER_BPROP_IMPL(DType);
REGISTER_EXPANDER_BPROP_IMPL(DynamicShape);
REGISTER_EXPANDER_BPROP_IMPL(Fill); REGISTER_EXPANDER_BPROP_IMPL(Fill);
REGISTER_EXPANDER_BPROP_IMPL(Fills); REGISTER_EXPANDER_BPROP_IMPL(Fills);
REGISTER_EXPANDER_BPROP_IMPL(Flatten);
REGISTER_EXPANDER_BPROP_IMPL(GatherD); REGISTER_EXPANDER_BPROP_IMPL(GatherD);
REGISTER_EXPANDER_BPROP_IMPL(GatherDGrad);
REGISTER_EXPANDER_BPROP_IMPL(GatherDGradV2);
REGISTER_EXPANDER_BPROP_IMPL(GatherNd);
REGISTER_EXPANDER_BPROP_IMPL(Identity); REGISTER_EXPANDER_BPROP_IMPL(Identity);
REGISTER_EXPANDER_BPROP_IMPL(IdentityN); REGISTER_EXPANDER_BPROP_IMPL(IdentityN);
REGISTER_EXPANDER_BPROP_IMPL(MaskedSelect); REGISTER_EXPANDER_BPROP_IMPL(MaskedSelect);
REGISTER_EXPANDER_BPROP_IMPL(MatrixDiagV3); REGISTER_EXPANDER_BPROP_IMPL(MatrixDiagV3);
REGISTER_EXPANDER_BPROP_IMPL(NonZero); REGISTER_EXPANDER_BPROP_IMPL(NonZero);
REGISTER_EXPANDER_BPROP_IMPL(OnesLike); REGISTER_EXPANDER_BPROP_IMPL(OnesLike);
REGISTER_EXPANDER_BPROP_IMPL(Pack);
REGISTER_EXPANDER_BPROP_IMPL(Padding);
REGISTER_EXPANDER_BPROP_IMPL(Range); REGISTER_EXPANDER_BPROP_IMPL(Range);
REGISTER_EXPANDER_BPROP_IMPL(Rank); REGISTER_EXPANDER_BPROP_IMPL(Rank);
REGISTER_EXPANDER_BPROP_IMPL(ResizeNearestNeighbor);
REGISTER_EXPANDER_BPROP_IMPL(ReverseSequence); REGISTER_EXPANDER_BPROP_IMPL(ReverseSequence);
REGISTER_EXPANDER_BPROP_IMPL(ReverseV2); REGISTER_EXPANDER_BPROP_IMPL(ReverseV2);
REGISTER_EXPANDER_BPROP_IMPL(ScatterMax); REGISTER_EXPANDER_BPROP_IMPL(ScatterMax);
REGISTER_EXPANDER_BPROP_IMPL(ScatterMin); REGISTER_EXPANDER_BPROP_IMPL(ScatterMin);
REGISTER_EXPANDER_BPROP_IMPL(ScatterNd);
REGISTER_EXPANDER_BPROP_IMPL(SegmentMax); REGISTER_EXPANDER_BPROP_IMPL(SegmentMax);
REGISTER_EXPANDER_BPROP_IMPL(SegmentMin); REGISTER_EXPANDER_BPROP_IMPL(SegmentMin);
REGISTER_EXPANDER_BPROP_IMPL(Select); REGISTER_EXPANDER_BPROP_IMPL(Select);
REGISTER_EXPANDER_BPROP_IMPL(Shape);
REGISTER_EXPANDER_BPROP_IMPL(Slice); REGISTER_EXPANDER_BPROP_IMPL(Slice);
REGISTER_EXPANDER_BPROP_IMPL(SpaceToBatch); REGISTER_EXPANDER_BPROP_IMPL(SpaceToBatch);
REGISTER_EXPANDER_BPROP_IMPL(SpaceToBatchND);
REGISTER_EXPANDER_BPROP_IMPL(SpaceToDepth); REGISTER_EXPANDER_BPROP_IMPL(SpaceToDepth);
REGISTER_EXPANDER_BPROP_IMPL(Split); REGISTER_EXPANDER_BPROP_IMPL(Split);
REGISTER_EXPANDER_BPROP_IMPL(SplitV); REGISTER_EXPANDER_BPROP_IMPL(SplitV);
REGISTER_EXPANDER_BPROP_IMPL(Squeeze);
REGISTER_EXPANDER_BPROP_IMPL(Stack);
REGISTER_EXPANDER_BPROP_IMPL(StridedSliceV2);
REGISTER_EXPANDER_BPROP_IMPL(StridedSliceGrad);
}
void RegArrayBpropExpanderOps2() {
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterAdd); REGISTER_EXPANDER_BPROP_IMPL(TensorScatterAdd);
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterDiv); REGISTER_EXPANDER_BPROP_IMPL(TensorScatterDiv);
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterElement); REGISTER_EXPANDER_BPROP_IMPL(TensorScatterElement);
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterMul); REGISTER_EXPANDER_BPROP_IMPL(TensorScatterMul);
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterSub); REGISTER_EXPANDER_BPROP_IMPL(TensorScatterSub);
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterUpdate); REGISTER_EXPANDER_BPROP_IMPL(TensorScatterUpdate);
REGISTER_EXPANDER_BPROP_IMPL(TensorShape);
REGISTER_EXPANDER_BPROP_IMPL(Tril); REGISTER_EXPANDER_BPROP_IMPL(Tril);
REGISTER_EXPANDER_BPROP_IMPL(Triu); REGISTER_EXPANDER_BPROP_IMPL(Triu);
REGISTER_EXPANDER_BPROP_IMPL(ZerosLike); REGISTER_EXPANDER_BPROP_IMPL(ZerosLike);
@ -320,9 +341,10 @@ void RegArrayBpropExpanderOps1() {
REGISTER_EXPANDER_BPROP_IMPL(Tile); REGISTER_EXPANDER_BPROP_IMPL(Tile);
REGISTER_EXPANDER_BPROP_IMPL(StridedSlice); REGISTER_EXPANDER_BPROP_IMPL(StridedSlice);
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentSum); REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentSum);
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentProd);
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentMin);
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentMax);
} }
void RegArrayBpropExpanderOps2() {}
void RegClipBpropExpanderOps() {} void RegClipBpropExpanderOps() {}
void RegCommBpropExpanderOps() {} void RegCommBpropExpanderOps() {}
void RegInnerBpropExpanderOps() {} void RegInnerBpropExpanderOps() {}

View File

@ -343,6 +343,26 @@ NodePtrList ConcatBpropStatic(const BpropIRBuilder *ib, const NodePtr &dout, con
} }
return {ib->MakeTuple(res)}; return {ib->MakeTuple(res)};
} }
NodePtrList StackBpropFunc(const BpropIRBuilder *ib) {
auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2);
auto num = ib->GetAttr("num");
auto ret = ib->Emit("Unstack", {dout}, {{"num", num}, {"axis", ib->GetAttr("axis")}});
auto x_abs = x->get()->abstract();
MS_EXCEPTION_IF_NULL(x_abs);
bool is_list = x_abs->isa<abstract::AbstractList>();
if (is_list) {
NodePtrList res;
auto num_v = LongToSize(GetValue<int64_t>(num));
for (size_t i = 0; i < num_v; ++i) {
res.push_back(ib->TupleGetItem(ret, i));
}
return {ib->MakeList(res)};
}
return {ret};
}
} // namespace } // namespace
REG_BPROP_BUILDERS_BEGIN(GradArrayOps) REG_BPROP_BUILDERS_BEGIN(GradArrayOps)
@ -532,17 +552,8 @@ REG_BPROP_BUILDER("Range").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUN
return {ib->ZerosLike(start), ib->ZerosLike(limit), ib->ZerosLike(delta)}; return {ib->ZerosLike(start), ib->ZerosLike(limit), ib->ZerosLike(delta)};
}); });
REG_BPROP_BUILDER("Pack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("Pack").SetUnusedInputs({i1}).SetBody(StackBpropFunc);
auto dout = ib->GetInput(kIndex2); REG_BPROP_BUILDER("Stack").SetUnusedInputs({i1}).SetBody(StackBpropFunc);
auto ret = ib->Emit("Unstack", {dout}, {{"num", ib->GetAttr("num")}, {"axis", ib->GetAttr("axis")}});
return {ret};
});
REG_BPROP_BUILDER("Stack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto dout = ib->GetInput(kIndex2);
auto ret = ib->Emit("Unstack", {dout}, {{"num", ib->GetAttr("num")}, {"axis", ib->GetAttr("axis")}});
return {ret};
});
REG_BPROP_BUILDER("ReverseV2").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("ReverseV2").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
@ -626,11 +637,34 @@ REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i0, i1}).SetBody(BOD
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto x_shape = ib->GetShape(x); auto x_shape = ib->GetShape(x);
ShapeVector new_shape; NodePtr shape;
for (size_t i = 2; i < x_shape.size(); i++) { if (!IsDynamic(x_shape)) {
new_shape.push_back(x_shape[i]); ShapeVector new_shape;
for (size_t i = 2; i < x_shape.size(); i++) {
new_shape.push_back(x_shape[i]);
}
shape = ib->EmitValue(MakeValue(new_shape));
} else {
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
auto shape = inputs[0];
ShapeVector res;
for (size_t i = 2; i < shape.size(); ++i) {
res.push_back(shape[i]);
}
return {res};
};
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) -> ShapeVector {
if (!invalid_indices.empty()) {
return ShapeVector(-1);
}
auto rank = SizeToLong(inputs[0].size());
return ShapeVector{rank > 2 ? (rank - 2) : 0};
};
shape = ib->ShapeCalc({x}, shape_func, infer_func, {})[0];
} }
auto shape = ib->EmitValue(MakeValue(new_shape));
auto out = ib->Emit("ResizeNearestNeighborGrad", {dout, shape}, {{"align_corners", ib->GetAttr("align_corners")}}); auto out = ib->Emit("ResizeNearestNeighborGrad", {dout, shape}, {{"align_corners", ib->GetAttr("align_corners")}});
return {out}; return {out};
}); });
@ -639,7 +673,13 @@ REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto indices = ib->GetInput(kIndex1); auto indices = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex3); auto dout = ib->GetInput(kIndex3);
auto shp = ib->EmitValue(MakeValue(ib->GetShape(x))); auto x_shp = ib->GetShape(x);
NodePtr shp;
if (IsDynamic(x_shp)) {
shp = ib->Emit("TensorShape", {x});
} else {
shp = ib->EmitValue(MakeValue(x_shp));
}
return {ib->Emit("ScatterNd", {indices, dout, shp}), ib->ZerosLike(indices)}; return {ib->Emit("ScatterNd", {indices, dout, shp}), ib->ZerosLike(indices)};
}); });
@ -674,8 +714,11 @@ REG_BPROP_BUILDER("TensorScatterUpdate").SetUnusedInputs({i0, i3}).SetBody(BODYF
REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto dx = ib->Reshape(dout, ib->GetShape(x)); auto x_shape = ib->GetShape(x);
return {dx}; if (IsDynamic(x_shape)) {
return {ib->Reshape(dout, ib->Emit("Shape", {x}))};
}
return {ib->Reshape(dout, x_shape)};
}); });
REG_BPROP_BUILDER("Reshape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("Reshape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
@ -989,15 +1032,46 @@ REG_BPROP_BUILDER("BroadcastTo").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib)
auto x_shape = ib->GetShape(x); auto x_shape = ib->GetShape(x);
auto dout_shape = ib->GetShape(dout); auto dout_shape = ib->GetShape(dout);
if (x_shape == dout_shape) {
bool input_dynamic = IsDynamic(x_shape) || IsDynamic(dout_shape);
if (!input_dynamic && x_shape == dout_shape) {
return {dout}; return {dout};
} }
auto tuple_out = BroadcastGradientArgs(broadcast_shape, x_shape); auto dout_dtype = ib->GetDtype(dout)->type_id();
MS_EXCEPTION_IF_CHECK_FAIL(!tuple_out.empty(), "BroadcastGradientArgs out should not be empty!");
auto reduction_axes = tuple_out[kIndex1]; NodePtr dx;
auto reduced_grad = ib->ReduceSum(dout, reduction_axes, true); if (!input_dynamic && !IsDynamic(broadcast_shape)) {
auto dx = ib->Reshape(reduced_grad, x_shape); auto tuple_out = BroadcastGradientArgs(broadcast_shape, x_shape);
MS_EXCEPTION_IF_CHECK_FAIL(!tuple_out.empty(), "BroadcastGradientArgs out should not be empty!");
auto reduction_axes = tuple_out[kIndex1];
NodePtr reduced_grad;
if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
auto dout_cast = ib->Cast(dout, kFloat32);
reduced_grad = ib->ReduceSum(dout_cast, reduction_axes, true);
reduced_grad = ib->Cast(reduced_grad, ib->GetDtype(dout));
} else {
reduced_grad = ib->ReduceSum(dout, reduction_axes, true);
}
dx = ib->Reshape(reduced_grad, x_shape);
} else {
auto x_shape_node = ib->Emit("TensorShape", {x});
auto broadcast_shape_node = ib->Emit("TensorShape", {dout});
auto brod = ib->Emit("DynamicBroadcastGradientArgs", {broadcast_shape_node, x_shape_node});
auto reduction_axes = ib->TupleGetItem(brod, 1);
NodePtr reduced_grad;
if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
auto dout_cast = ib->Cast(dout, kFloat32);
reduced_grad = ib->Emit("ReduceSum", {dout_cast, reduction_axes},
{{"keep_dims", MakeValue(true)}, {"skip_mode", MakeValue(true)}});
reduced_grad = ib->Cast(reduced_grad, ib->GetDtype(dout));
} else {
reduced_grad =
ib->Emit("ReduceSum", {dout, reduction_axes}, {{"keep_dims", MakeValue(true)}, {"skip_mode", MakeValue(true)}});
}
dx = ib->Reshape(reduced_grad, x_shape_node);
}
return {dx}; return {dx};
}); });
@ -1068,6 +1142,9 @@ REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto shapex = ib->GetShape(x); auto shapex = ib->GetShape(x);
if (IsDynamic(shapex)) {
return {ib->Reshape(dout, ib->Emit("Shape", {x}))};
}
return {ib->Reshape(dout, shapex)}; return {ib->Reshape(dout, shapex)};
}); });
@ -1075,10 +1152,15 @@ REG_BPROP_BUILDER("Padding").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto shp = ib->GetShape(x); auto shp = ib->GetShape(x);
std::vector<int64_t> begin; if (!IsDynamic(shp)) {
(void)begin.insert(begin.end(), shp.size(), 0); std::vector<int64_t> begin(shp.size(), 0);
auto dx = ib->Emit("Slice", {dout, ib->Value<ShapeVector>(begin), ib->Value<ShapeVector>(shp)}); auto dx = ib->Emit("Slice", {dout, ib->Value<ShapeVector>(begin), ib->Value<ShapeVector>(shp)});
return {dx}; return {dx};
}
auto shape_node = ib->Emit("Shape", {x});
auto begin_node = ib->ZerosLike(shape_node);
return {ib->Emit("Slice", {dout, begin_node, shape_node})};
}); });
REG_BPROP_BUILDER("Transpose").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) { REG_BPROP_BUILDER("Transpose").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
@ -1245,7 +1327,13 @@ REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(i
auto end = ib->GetInput(kIndex2); auto end = ib->GetInput(kIndex2);
auto strides = ib->GetInput(kIndex3); auto strides = ib->GetInput(kIndex3);
auto dout = ib->GetInput(kIndex5); auto dout = ib->GetInput(kIndex5);
auto x_shape = ib->Tensor(ib->GetShape(x)); auto x_shape_vec = ib->GetShape(x);
NodePtr x_shape;
if (IsDynamic(x_shape_vec)) {
x_shape = ib->Emit("TensorShape", {x});
} else {
x_shape = ib->Tensor(x_shape_vec);
}
auto dx = ib->Emit("StridedSliceV2Grad", {x_shape, begin, end, strides, dout}, auto dx = ib->Emit("StridedSliceV2Grad", {x_shape, begin, end, strides, dout},
{{"begin_mask", ib->GetAttr("begin_mask")}, {{"begin_mask", ib->GetAttr("begin_mask")},
{"end_mask", ib->GetAttr("end_mask")}, {"end_mask", ib->GetAttr("end_mask")},