!45226 fix some bprop attr

Merge pull request !45226 from r1chardf1d0/bpe2
This commit is contained in:
i-robot 2022-11-08 07:55:55 +00:00 committed by Gitee
commit 6dbfda8c83
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 58 additions and 9 deletions

View File

@ -622,4 +622,20 @@ bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_type
return IsIdentidityOrSubclass(check_type, accept);
});
}
ShapeVector PoolToNHWC(const ShapeVector &v) {
ShapeVector new_v(v);
new_v[kIndex1] = v[kIndex2];
new_v[kIndex2] = v[kIndex3];
new_v[kIndex3] = v[kIndex1];
return new_v;
}
ShapeVector ConvToNHWC(const ShapeVector &v) {
ShapeVector new_v(v);
new_v[kIndex0] = v[kIndex1];
new_v[kIndex1] = v[kIndex2];
new_v[kIndex2] = v[kIndex3];
new_v[kIndex3] = 1;
return new_v;
}
} // namespace mindspore::expander::bprop

View File

@ -79,5 +79,7 @@ NodePtr ArgminOrArgmaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const int
TypeId PromoteBinaryDtype(TypeId t1, TypeId t2);
NodePtr LGamma(const BpropIRBuilder *ib, const NodePtr &x);
bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types);
ShapeVector PoolToNHWC(const ShapeVector &v);
ShapeVector ConvToNHWC(const ShapeVector &v);
} // namespace mindspore::expander::bprop
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_

View File

@ -26,13 +26,21 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
auto dout = ib->GetInput(kIndex3);
auto x_shape = ib->GetShape(x);
auto w_shape = ib->GetShape(w);
auto format = GetValue<std::string>(ib->GetAttr("format"));
auto dilation = GetValue<ShapeVector>(ib->GetAttr("dilation"));
auto stride = GetValue<ShapeVector>(ib->GetAttr("stride"));
if (format == "NHWC") {
dilation = ConvToNHWC(dilation);
stride = ConvToNHWC(stride);
}
auto dx = ib->Emit(kConv2DBackpropInputOpName, {dout, w, ib->Value<ShapeVector>(x_shape)},
{{"mode", ib->GetAttr("mode")},
{"dilation", ib->GetAttr("dilation")},
{"stride", ib->GetAttr("stride")},
{"dilation", MakeValue(dilation)},
{"stride", MakeValue(stride)},
{"group", ib->GetAttr("group")},
{"groups", ib->GetAttr("group")},
{"format", ib->GetAttr("format")},
{"data_format", ib->GetAttr("data_format")},
{"out_channel", ib->GetAttr("out_channel")},
{"kernel_size", ib->GetAttr("kernel_size")},
{"pad_mode", ib->GetAttr("pad_mode")},
@ -40,11 +48,12 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
{"pad_list", ib->GetAttr("pad_list")}});
auto dw = ib->Emit("Conv2DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)},
{{"mode", ib->GetAttr("mode")},
{"dilation", ib->GetAttr("dilation")},
{"stride", ib->GetAttr("stride")},
{"dilation", MakeValue(dilation)},
{"stride", MakeValue(stride)},
{"group", ib->GetAttr("group")},
{"groups", ib->GetAttr("group")},
{"format", ib->GetAttr("format")},
{"data_format", ib->GetAttr("data_format")},
{"out_channel", ib->GetAttr("out_channel")},
{"kernel_size", ib->GetAttr("kernel_size")},
{"pad_mode", ib->GetAttr("pad_mode")},
@ -52,15 +61,22 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
{"pad_list", ib->GetAttr("pad_list")}});
return {dx, dw};
});
REG_BPROP_BUILDER(kMaxPoolOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
auto x = ib->GetInput(kIndex0);
auto out = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex2);
auto format = GetValue<std::string>(ib->GetAttr("format"));
auto kernel_size = GetValue<ShapeVector>(ib->GetAttr("kernel_size"));
auto strides = GetValue<ShapeVector>(ib->GetAttr("strides"));
if (format == "NHWC") {
kernel_size = PoolToNHWC(kernel_size);
strides = PoolToNHWC(strides);
}
auto dx = ib->Emit(kMaxPoolGradOpName, {x, out, dout},
{{"kernel_size", ib->GetAttr("kernel_size")},
{"strides", ib->GetAttr("strides")},
{{"kernel_size", MakeValue(kernel_size)},
{"strides", MakeValue(strides)},
{"pad_mode", ib->GetAttr("pad_mode")},
{"data_format", ib->GetAttr("data_format")},
{"format", ib->GetAttr("format")}});
return {dx};
});
@ -205,6 +221,7 @@ REG_BPROP_BUILDER("DeformableOffsets").SetBody([](const BpropIRBuilder *ib) -> N
{"ksize", ib->GetAttr("ksize")},
{"dilations", ib->GetAttr("dilations")},
{"format", ib->GetAttr("format")},
{"data_format", ib->GetAttr("format")},
{"deformable_groups", ib->GetAttr("deformable_groups")},
{"modulated", ib->GetAttr("modulated")}});
return {ib->TupleGetItem(out_grad, 0), ib->TupleGetItem(out_grad, 1)};
@ -479,6 +496,8 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetBody([](const BpropIRBuilder *ib) -> Nod
{"dilation", MakeValue(dilation)},
{"group", ib->GetAttr("groups")},
{"groups", ib->GetAttr("groups")},
{"offset_x", MakeValue<int64_t>(0)},
{"format", ib->GetAttr("format")},
{"data_format", ib->GetAttr("format")}});
auto dw = ib->Emit("Conv3DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)},
{{"out_channel", ib->GetAttr("in_channel")},
@ -493,6 +512,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetBody([](const BpropIRBuilder *ib) -> Nod
{"dilation", ib->GetAttr("dilations")},
{"group", ib->GetAttr("groups")},
{"groups", ib->GetAttr("groups")},
{"format", ib->GetAttr("format")},
{"data_format", ib->GetAttr("format")}});
return {dx, dw};
});
@ -518,6 +538,7 @@ REG_BPROP_BUILDER("MaxPoolGradGrad").SetBody([](const BpropIRBuilder *ib) -> Nod
{{"kernel_size", ib->GetAttr("kernel_size")},
{"strides", ib->GetAttr("strides")},
{"pad_mode", ib->GetAttr("pad_mode")},
{"data_format", MakeValue("NCHW")},
{"format", MakeValue("NCHW")}});
return {dx1, dx2, dgrad};
});
@ -553,6 +574,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
dgrad = ib->Emit("MaxPoolGradGrad", {x1, x2, dout},
{{"kernel_size", ib->GetAttr("kernel_size")},
{"strides", ib->GetAttr("strides")},
{"data_format", MakeValue("NCHW")},
{"format", MakeValue("NCHW")},
{"pad_mode", ib->GetAttr("pad_mode")}});
} else {
@ -565,6 +587,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
{{"kernel_size", MakeValue(kernel_size)},
{"strides", MakeValue(strides)},
{"pad_mode", ib->GetAttr("pad_mode")},
{"data_format", MakeValue("NCHW")},
{"format", MakeValue("NCHW")}});
auto ind = ib->TupleGetItem(tmp, 1);
auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32));
@ -668,10 +691,18 @@ REG_BPROP_BUILDER("AvgPool").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
auto x = ib->GetInput(kIndex0);
auto out = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex2);
auto format = GetValue<std::string>(ib->GetAttr("format"));
auto kernel_size = GetValue<ShapeVector>(ib->GetAttr("kernel_size"));
auto strides = GetValue<ShapeVector>(ib->GetAttr("strides"));
if (format == "NHWC") {
kernel_size = PoolToNHWC(kernel_size);
strides = PoolToNHWC(strides);
}
auto dx = ib->Emit("AvgPoolGrad", {x, out, dout},
{{"kernel_size", ib->GetAttr("kernel_size")},
{"strides", ib->GetAttr("strides")},
{{"kernel_size", MakeValue(kernel_size)},
{"strides", MakeValue(strides)},
{"pad_mode", ib->GetAttr("pad_mode")},
{"data_format", ib->GetAttr("data_format")},
{"format", ib->GetAttr("format")}});
return {dx};
});