forked from mindspore-Ecosystem/mindspore
add bprop expander for dynamic
This commit is contained in:
parent
05b6f800b0
commit
4d228ecd3b
|
@ -209,7 +209,7 @@ NodePtr Emitter::ZerosLike(const NodePtr &node) const {
|
|||
return node;
|
||||
} else if (abs->isa<abstract::AbstractSequence>()) {
|
||||
auto sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
|
||||
if (sequence_abs->empty()) {
|
||||
if (!sequence_abs->dynamic_len() && sequence_abs->empty()) {
|
||||
return node;
|
||||
}
|
||||
return Emit(prim::kSequenceZerosLike, {node});
|
||||
|
|
|
@ -273,6 +273,8 @@ void RegArrayBpropExpanderOps1() {
|
|||
REGISTER_EXPANDER_BPROP_IMPL(Argmax);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Argmin);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(BatchToSpace);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(BatchToSpaceND);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(BroadcastTo);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(CheckNumerics);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Col2Im);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(ConjugateTranspose);
|
||||
|
@ -280,35 +282,54 @@ void RegArrayBpropExpanderOps1() {
|
|||
REGISTER_EXPANDER_BPROP_IMPL(Diag);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(DiagPart);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(DType);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(DynamicShape);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Fill);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Fills);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Flatten);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(GatherD);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(GatherDGrad);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(GatherDGradV2);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(GatherNd);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Identity);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(IdentityN);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(MaskedSelect);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(MatrixDiagV3);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(NonZero);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(OnesLike);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Pack);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Padding);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Range);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Rank);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(ResizeNearestNeighbor);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(ReverseSequence);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(ReverseV2);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(ScatterMax);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(ScatterMin);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(ScatterNd);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(SegmentMax);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(SegmentMin);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Select);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Shape);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Slice);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(SpaceToBatch);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(SpaceToBatchND);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(SpaceToDepth);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Split);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(SplitV);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Squeeze);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Stack);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(StridedSliceV2);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(StridedSliceGrad);
|
||||
}
|
||||
|
||||
void RegArrayBpropExpanderOps2() {
|
||||
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterAdd);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterDiv);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterElement);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterMul);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterSub);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(TensorScatterUpdate);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(TensorShape);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Tril);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(Triu);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(ZerosLike);
|
||||
|
@ -320,9 +341,10 @@ void RegArrayBpropExpanderOps1() {
|
|||
REGISTER_EXPANDER_BPROP_IMPL(Tile);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(StridedSlice);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentSum);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentProd);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentMin);
|
||||
REGISTER_EXPANDER_BPROP_IMPL(UnsortedSegmentMax);
|
||||
}
|
||||
|
||||
void RegArrayBpropExpanderOps2() {}
|
||||
void RegClipBpropExpanderOps() {}
|
||||
void RegCommBpropExpanderOps() {}
|
||||
void RegInnerBpropExpanderOps() {}
|
||||
|
|
|
@ -343,6 +343,26 @@ NodePtrList ConcatBpropStatic(const BpropIRBuilder *ib, const NodePtr &dout, con
|
|||
}
|
||||
return {ib->MakeTuple(res)};
|
||||
}
|
||||
|
||||
NodePtrList StackBpropFunc(const BpropIRBuilder *ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto num = ib->GetAttr("num");
|
||||
auto ret = ib->Emit("Unstack", {dout}, {{"num", num}, {"axis", ib->GetAttr("axis")}});
|
||||
|
||||
auto x_abs = x->get()->abstract();
|
||||
MS_EXCEPTION_IF_NULL(x_abs);
|
||||
bool is_list = x_abs->isa<abstract::AbstractList>();
|
||||
if (is_list) {
|
||||
NodePtrList res;
|
||||
auto num_v = LongToSize(GetValue<int64_t>(num));
|
||||
for (size_t i = 0; i < num_v; ++i) {
|
||||
res.push_back(ib->TupleGetItem(ret, i));
|
||||
}
|
||||
return {ib->MakeList(res)};
|
||||
}
|
||||
return {ret};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
REG_BPROP_BUILDERS_BEGIN(GradArrayOps)
|
||||
|
@ -532,17 +552,8 @@ REG_BPROP_BUILDER("Range").SetUnusedInputs({i0, i1, i2, i3, i4}).SetBody(BODYFUN
|
|||
return {ib->ZerosLike(start), ib->ZerosLike(limit), ib->ZerosLike(delta)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Pack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto ret = ib->Emit("Unstack", {dout}, {{"num", ib->GetAttr("num")}, {"axis", ib->GetAttr("axis")}});
|
||||
return {ret};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Stack").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto ret = ib->Emit("Unstack", {dout}, {{"num", ib->GetAttr("num")}, {"axis", ib->GetAttr("axis")}});
|
||||
return {ret};
|
||||
});
|
||||
REG_BPROP_BUILDER("Pack").SetUnusedInputs({i1}).SetBody(StackBpropFunc);
|
||||
REG_BPROP_BUILDER("Stack").SetUnusedInputs({i1}).SetBody(StackBpropFunc);
|
||||
|
||||
REG_BPROP_BUILDER("ReverseV2").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
|
@ -626,11 +637,34 @@ REG_BPROP_BUILDER("ResizeNearestNeighbor").SetUnusedInputs({i0, i1}).SetBody(BOD
|
|||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto x_shape = ib->GetShape(x);
|
||||
ShapeVector new_shape;
|
||||
for (size_t i = 2; i < x_shape.size(); i++) {
|
||||
new_shape.push_back(x_shape[i]);
|
||||
NodePtr shape;
|
||||
if (!IsDynamic(x_shape)) {
|
||||
ShapeVector new_shape;
|
||||
for (size_t i = 2; i < x_shape.size(); i++) {
|
||||
new_shape.push_back(x_shape[i]);
|
||||
}
|
||||
shape = ib->EmitValue(MakeValue(new_shape));
|
||||
} else {
|
||||
auto shape_func = [](const ShapeArray &inputs) -> ShapeArray {
|
||||
auto shape = inputs[0];
|
||||
ShapeVector res;
|
||||
for (size_t i = 2; i < shape.size(); ++i) {
|
||||
res.push_back(shape[i]);
|
||||
}
|
||||
return {res};
|
||||
};
|
||||
|
||||
auto infer_func = [](const ShapeArray &inputs, const std::unordered_set<size_t> &invalid_indices) -> ShapeVector {
|
||||
if (!invalid_indices.empty()) {
|
||||
return ShapeVector(-1);
|
||||
}
|
||||
auto rank = SizeToLong(inputs[0].size());
|
||||
return ShapeVector{rank > 2 ? (rank - 2) : 0};
|
||||
};
|
||||
|
||||
shape = ib->ShapeCalc({x}, shape_func, infer_func, {})[0];
|
||||
}
|
||||
auto shape = ib->EmitValue(MakeValue(new_shape));
|
||||
|
||||
auto out = ib->Emit("ResizeNearestNeighborGrad", {dout, shape}, {{"align_corners", ib->GetAttr("align_corners")}});
|
||||
return {out};
|
||||
});
|
||||
|
@ -639,7 +673,13 @@ REG_BPROP_BUILDER("GatherNd").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
|||
auto x = ib->GetInput(kIndex0);
|
||||
auto indices = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex3);
|
||||
auto shp = ib->EmitValue(MakeValue(ib->GetShape(x)));
|
||||
auto x_shp = ib->GetShape(x);
|
||||
NodePtr shp;
|
||||
if (IsDynamic(x_shp)) {
|
||||
shp = ib->Emit("TensorShape", {x});
|
||||
} else {
|
||||
shp = ib->EmitValue(MakeValue(x_shp));
|
||||
}
|
||||
return {ib->Emit("ScatterNd", {indices, dout, shp}), ib->ZerosLike(indices)};
|
||||
});
|
||||
|
||||
|
@ -674,8 +714,11 @@ REG_BPROP_BUILDER("TensorScatterUpdate").SetUnusedInputs({i0, i3}).SetBody(BODYF
|
|||
REG_BPROP_BUILDER("Flatten").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto dx = ib->Reshape(dout, ib->GetShape(x));
|
||||
return {dx};
|
||||
auto x_shape = ib->GetShape(x);
|
||||
if (IsDynamic(x_shape)) {
|
||||
return {ib->Reshape(dout, ib->Emit("Shape", {x}))};
|
||||
}
|
||||
return {ib->Reshape(dout, x_shape)};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Reshape").SetUnusedInputs({i0, i1, i2}).SetBody(BODYFUNC(ib) {
|
||||
|
@ -989,15 +1032,46 @@ REG_BPROP_BUILDER("BroadcastTo").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib)
|
|||
|
||||
auto x_shape = ib->GetShape(x);
|
||||
auto dout_shape = ib->GetShape(dout);
|
||||
if (x_shape == dout_shape) {
|
||||
|
||||
bool input_dynamic = IsDynamic(x_shape) || IsDynamic(dout_shape);
|
||||
if (!input_dynamic && x_shape == dout_shape) {
|
||||
return {dout};
|
||||
}
|
||||
|
||||
auto tuple_out = BroadcastGradientArgs(broadcast_shape, x_shape);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(!tuple_out.empty(), "BroadcastGradientArgs out should not be empty!");
|
||||
auto reduction_axes = tuple_out[kIndex1];
|
||||
auto reduced_grad = ib->ReduceSum(dout, reduction_axes, true);
|
||||
auto dx = ib->Reshape(reduced_grad, x_shape);
|
||||
auto dout_dtype = ib->GetDtype(dout)->type_id();
|
||||
|
||||
NodePtr dx;
|
||||
if (!input_dynamic && !IsDynamic(broadcast_shape)) {
|
||||
auto tuple_out = BroadcastGradientArgs(broadcast_shape, x_shape);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(!tuple_out.empty(), "BroadcastGradientArgs out should not be empty!");
|
||||
auto reduction_axes = tuple_out[kIndex1];
|
||||
NodePtr reduced_grad;
|
||||
if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
|
||||
auto dout_cast = ib->Cast(dout, kFloat32);
|
||||
reduced_grad = ib->ReduceSum(dout_cast, reduction_axes, true);
|
||||
reduced_grad = ib->Cast(reduced_grad, ib->GetDtype(dout));
|
||||
} else {
|
||||
reduced_grad = ib->ReduceSum(dout, reduction_axes, true);
|
||||
}
|
||||
dx = ib->Reshape(reduced_grad, x_shape);
|
||||
} else {
|
||||
auto x_shape_node = ib->Emit("TensorShape", {x});
|
||||
auto broadcast_shape_node = ib->Emit("TensorShape", {dout});
|
||||
auto brod = ib->Emit("DynamicBroadcastGradientArgs", {broadcast_shape_node, x_shape_node});
|
||||
auto reduction_axes = ib->TupleGetItem(brod, 1);
|
||||
NodePtr reduced_grad;
|
||||
if (dout_dtype == kNumberTypeInt16 || dout_dtype == kNumberTypeInt32 || dout_dtype == kNumberTypeInt64) {
|
||||
auto dout_cast = ib->Cast(dout, kFloat32);
|
||||
reduced_grad = ib->Emit("ReduceSum", {dout_cast, reduction_axes},
|
||||
{{"keep_dims", MakeValue(true)}, {"skip_mode", MakeValue(true)}});
|
||||
reduced_grad = ib->Cast(reduced_grad, ib->GetDtype(dout));
|
||||
} else {
|
||||
reduced_grad =
|
||||
ib->Emit("ReduceSum", {dout, reduction_axes}, {{"keep_dims", MakeValue(true)}, {"skip_mode", MakeValue(true)}});
|
||||
}
|
||||
dx = ib->Reshape(reduced_grad, x_shape_node);
|
||||
}
|
||||
|
||||
return {dx};
|
||||
});
|
||||
|
||||
|
@ -1068,6 +1142,9 @@ REG_BPROP_BUILDER("Squeeze").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
|||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto shapex = ib->GetShape(x);
|
||||
if (IsDynamic(shapex)) {
|
||||
return {ib->Reshape(dout, ib->Emit("Shape", {x}))};
|
||||
}
|
||||
return {ib->Reshape(dout, shapex)};
|
||||
});
|
||||
|
||||
|
@ -1075,10 +1152,15 @@ REG_BPROP_BUILDER("Padding").SetUnusedInputs({i0, i1}).SetBody(BODYFUNC(ib) {
|
|||
auto x = ib->GetInput(kIndex0);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto shp = ib->GetShape(x);
|
||||
std::vector<int64_t> begin;
|
||||
(void)begin.insert(begin.end(), shp.size(), 0);
|
||||
auto dx = ib->Emit("Slice", {dout, ib->Value<ShapeVector>(begin), ib->Value<ShapeVector>(shp)});
|
||||
return {dx};
|
||||
if (!IsDynamic(shp)) {
|
||||
std::vector<int64_t> begin(shp.size(), 0);
|
||||
auto dx = ib->Emit("Slice", {dout, ib->Value<ShapeVector>(begin), ib->Value<ShapeVector>(shp)});
|
||||
return {dx};
|
||||
}
|
||||
|
||||
auto shape_node = ib->Emit("Shape", {x});
|
||||
auto begin_node = ib->ZerosLike(shape_node);
|
||||
return {ib->Emit("Slice", {dout, begin_node, shape_node})};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER("Transpose").SetUnusedInputs({i0, i2}).SetBody(BODYFUNC(ib) {
|
||||
|
@ -1245,7 +1327,13 @@ REG_BPROP_BUILDER("StridedSliceV2").SetUnusedInputs({i0, i4}).SetBody(BODYFUNC(i
|
|||
auto end = ib->GetInput(kIndex2);
|
||||
auto strides = ib->GetInput(kIndex3);
|
||||
auto dout = ib->GetInput(kIndex5);
|
||||
auto x_shape = ib->Tensor(ib->GetShape(x));
|
||||
auto x_shape_vec = ib->GetShape(x);
|
||||
NodePtr x_shape;
|
||||
if (IsDynamic(x_shape_vec)) {
|
||||
x_shape = ib->Emit("TensorShape", {x});
|
||||
} else {
|
||||
x_shape = ib->Tensor(x_shape_vec);
|
||||
}
|
||||
auto dx = ib->Emit("StridedSliceV2Grad", {x_shape, begin, end, strides, dout},
|
||||
{{"begin_mask", ib->GetAttr("begin_mask")},
|
||||
{"end_mask", ib->GetAttr("end_mask")},
|
||||
|
|
Loading…
Reference in New Issue