!45226 fix some bprop attr

Merge pull request !45226 from r1chardf1d0/bpe2
This commit is contained in:
i-robot 2022-11-08 07:55:55 +00:00 committed by Gitee
commit 6dbfda8c83
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 58 additions and 9 deletions

View File

@ -622,4 +622,20 @@ bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_type
return IsIdentidityOrSubclass(check_type, accept); return IsIdentidityOrSubclass(check_type, accept);
}); });
} }
ShapeVector PoolToNHWC(const ShapeVector &v) {
ShapeVector new_v(v);
new_v[kIndex1] = v[kIndex2];
new_v[kIndex2] = v[kIndex3];
new_v[kIndex3] = v[kIndex1];
return new_v;
}
ShapeVector ConvToNHWC(const ShapeVector &v) {
ShapeVector new_v(v);
new_v[kIndex0] = v[kIndex1];
new_v[kIndex1] = v[kIndex2];
new_v[kIndex2] = v[kIndex3];
new_v[kIndex3] = 1;
return new_v;
}
} // namespace mindspore::expander::bprop } // namespace mindspore::expander::bprop

View File

@ -79,5 +79,7 @@ NodePtr ArgminOrArgmaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const int
TypeId PromoteBinaryDtype(TypeId t1, TypeId t2); TypeId PromoteBinaryDtype(TypeId t1, TypeId t2);
NodePtr LGamma(const BpropIRBuilder *ib, const NodePtr &x); NodePtr LGamma(const BpropIRBuilder *ib, const NodePtr &x);
bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types); bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types);
ShapeVector PoolToNHWC(const ShapeVector &v);
ShapeVector ConvToNHWC(const ShapeVector &v);
} // namespace mindspore::expander::bprop } // namespace mindspore::expander::bprop
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_ #endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_

View File

@ -26,13 +26,21 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
auto dout = ib->GetInput(kIndex3); auto dout = ib->GetInput(kIndex3);
auto x_shape = ib->GetShape(x); auto x_shape = ib->GetShape(x);
auto w_shape = ib->GetShape(w); auto w_shape = ib->GetShape(w);
auto format = GetValue<std::string>(ib->GetAttr("format"));
auto dilation = GetValue<ShapeVector>(ib->GetAttr("dilation"));
auto stride = GetValue<ShapeVector>(ib->GetAttr("stride"));
if (format == "NHWC") {
dilation = ConvToNHWC(dilation);
stride = ConvToNHWC(stride);
}
auto dx = ib->Emit(kConv2DBackpropInputOpName, {dout, w, ib->Value<ShapeVector>(x_shape)}, auto dx = ib->Emit(kConv2DBackpropInputOpName, {dout, w, ib->Value<ShapeVector>(x_shape)},
{{"mode", ib->GetAttr("mode")}, {{"mode", ib->GetAttr("mode")},
{"dilation", ib->GetAttr("dilation")}, {"dilation", MakeValue(dilation)},
{"stride", ib->GetAttr("stride")}, {"stride", MakeValue(stride)},
{"group", ib->GetAttr("group")}, {"group", ib->GetAttr("group")},
{"groups", ib->GetAttr("group")}, {"groups", ib->GetAttr("group")},
{"format", ib->GetAttr("format")}, {"format", ib->GetAttr("format")},
{"data_format", ib->GetAttr("data_format")},
{"out_channel", ib->GetAttr("out_channel")}, {"out_channel", ib->GetAttr("out_channel")},
{"kernel_size", ib->GetAttr("kernel_size")}, {"kernel_size", ib->GetAttr("kernel_size")},
{"pad_mode", ib->GetAttr("pad_mode")}, {"pad_mode", ib->GetAttr("pad_mode")},
@ -40,11 +48,12 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
{"pad_list", ib->GetAttr("pad_list")}}); {"pad_list", ib->GetAttr("pad_list")}});
auto dw = ib->Emit("Conv2DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)}, auto dw = ib->Emit("Conv2DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)},
{{"mode", ib->GetAttr("mode")}, {{"mode", ib->GetAttr("mode")},
{"dilation", ib->GetAttr("dilation")}, {"dilation", MakeValue(dilation)},
{"stride", ib->GetAttr("stride")}, {"stride", MakeValue(stride)},
{"group", ib->GetAttr("group")}, {"group", ib->GetAttr("group")},
{"groups", ib->GetAttr("group")}, {"groups", ib->GetAttr("group")},
{"format", ib->GetAttr("format")}, {"format", ib->GetAttr("format")},
{"data_format", ib->GetAttr("data_format")},
{"out_channel", ib->GetAttr("out_channel")}, {"out_channel", ib->GetAttr("out_channel")},
{"kernel_size", ib->GetAttr("kernel_size")}, {"kernel_size", ib->GetAttr("kernel_size")},
{"pad_mode", ib->GetAttr("pad_mode")}, {"pad_mode", ib->GetAttr("pad_mode")},
@ -52,15 +61,22 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
{"pad_list", ib->GetAttr("pad_list")}}); {"pad_list", ib->GetAttr("pad_list")}});
return {dx, dw}; return {dx, dw};
}); });
REG_BPROP_BUILDER(kMaxPoolOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtrList { REG_BPROP_BUILDER(kMaxPoolOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto out = ib->GetInput(kIndex1); auto out = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto format = GetValue<std::string>(ib->GetAttr("format"));
auto kernel_size = GetValue<ShapeVector>(ib->GetAttr("kernel_size"));
auto strides = GetValue<ShapeVector>(ib->GetAttr("strides"));
if (format == "NHWC") {
kernel_size = PoolToNHWC(kernel_size);
strides = PoolToNHWC(strides);
}
auto dx = ib->Emit(kMaxPoolGradOpName, {x, out, dout}, auto dx = ib->Emit(kMaxPoolGradOpName, {x, out, dout},
{{"kernel_size", ib->GetAttr("kernel_size")}, {{"kernel_size", MakeValue(kernel_size)},
{"strides", ib->GetAttr("strides")}, {"strides", MakeValue(strides)},
{"pad_mode", ib->GetAttr("pad_mode")}, {"pad_mode", ib->GetAttr("pad_mode")},
{"data_format", ib->GetAttr("data_format")},
{"format", ib->GetAttr("format")}}); {"format", ib->GetAttr("format")}});
return {dx}; return {dx};
}); });
@ -205,6 +221,7 @@ REG_BPROP_BUILDER("DeformableOffsets").SetBody([](const BpropIRBuilder *ib) -> N
{"ksize", ib->GetAttr("ksize")}, {"ksize", ib->GetAttr("ksize")},
{"dilations", ib->GetAttr("dilations")}, {"dilations", ib->GetAttr("dilations")},
{"format", ib->GetAttr("format")}, {"format", ib->GetAttr("format")},
{"data_format", ib->GetAttr("format")},
{"deformable_groups", ib->GetAttr("deformable_groups")}, {"deformable_groups", ib->GetAttr("deformable_groups")},
{"modulated", ib->GetAttr("modulated")}}); {"modulated", ib->GetAttr("modulated")}});
return {ib->TupleGetItem(out_grad, 0), ib->TupleGetItem(out_grad, 1)}; return {ib->TupleGetItem(out_grad, 0), ib->TupleGetItem(out_grad, 1)};
@ -479,6 +496,8 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetBody([](const BpropIRBuilder *ib) -> Nod
{"dilation", MakeValue(dilation)}, {"dilation", MakeValue(dilation)},
{"group", ib->GetAttr("groups")}, {"group", ib->GetAttr("groups")},
{"groups", ib->GetAttr("groups")}, {"groups", ib->GetAttr("groups")},
{"offset_x", MakeValue<int64_t>(0)},
{"format", ib->GetAttr("format")},
{"data_format", ib->GetAttr("format")}}); {"data_format", ib->GetAttr("format")}});
auto dw = ib->Emit("Conv3DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)}, auto dw = ib->Emit("Conv3DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)},
{{"out_channel", ib->GetAttr("in_channel")}, {{"out_channel", ib->GetAttr("in_channel")},
@ -493,6 +512,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetBody([](const BpropIRBuilder *ib) -> Nod
{"dilation", ib->GetAttr("dilations")}, {"dilation", ib->GetAttr("dilations")},
{"group", ib->GetAttr("groups")}, {"group", ib->GetAttr("groups")},
{"groups", ib->GetAttr("groups")}, {"groups", ib->GetAttr("groups")},
{"format", ib->GetAttr("format")},
{"data_format", ib->GetAttr("format")}}); {"data_format", ib->GetAttr("format")}});
return {dx, dw}; return {dx, dw};
}); });
@ -518,6 +538,7 @@ REG_BPROP_BUILDER("MaxPoolGradGrad").SetBody([](const BpropIRBuilder *ib) -> Nod
{{"kernel_size", ib->GetAttr("kernel_size")}, {{"kernel_size", ib->GetAttr("kernel_size")},
{"strides", ib->GetAttr("strides")}, {"strides", ib->GetAttr("strides")},
{"pad_mode", ib->GetAttr("pad_mode")}, {"pad_mode", ib->GetAttr("pad_mode")},
{"data_format", MakeValue("NCHW")},
{"format", MakeValue("NCHW")}}); {"format", MakeValue("NCHW")}});
return {dx1, dx2, dgrad}; return {dx1, dx2, dgrad};
}); });
@ -553,6 +574,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
dgrad = ib->Emit("MaxPoolGradGrad", {x1, x2, dout}, dgrad = ib->Emit("MaxPoolGradGrad", {x1, x2, dout},
{{"kernel_size", ib->GetAttr("kernel_size")}, {{"kernel_size", ib->GetAttr("kernel_size")},
{"strides", ib->GetAttr("strides")}, {"strides", ib->GetAttr("strides")},
{"data_format", MakeValue("NCHW")},
{"format", MakeValue("NCHW")}, {"format", MakeValue("NCHW")},
{"pad_mode", ib->GetAttr("pad_mode")}}); {"pad_mode", ib->GetAttr("pad_mode")}});
} else { } else {
@ -565,6 +587,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
{{"kernel_size", MakeValue(kernel_size)}, {{"kernel_size", MakeValue(kernel_size)},
{"strides", MakeValue(strides)}, {"strides", MakeValue(strides)},
{"pad_mode", ib->GetAttr("pad_mode")}, {"pad_mode", ib->GetAttr("pad_mode")},
{"data_format", MakeValue("NCHW")},
{"format", MakeValue("NCHW")}}); {"format", MakeValue("NCHW")}});
auto ind = ib->TupleGetItem(tmp, 1); auto ind = ib->TupleGetItem(tmp, 1);
auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32)); auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32));
@ -668,10 +691,18 @@ REG_BPROP_BUILDER("AvgPool").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
auto x = ib->GetInput(kIndex0); auto x = ib->GetInput(kIndex0);
auto out = ib->GetInput(kIndex1); auto out = ib->GetInput(kIndex1);
auto dout = ib->GetInput(kIndex2); auto dout = ib->GetInput(kIndex2);
auto format = GetValue<std::string>(ib->GetAttr("format"));
auto kernel_size = GetValue<ShapeVector>(ib->GetAttr("kernel_size"));
auto strides = GetValue<ShapeVector>(ib->GetAttr("strides"));
if (format == "NHWC") {
kernel_size = PoolToNHWC(kernel_size);
strides = PoolToNHWC(strides);
}
auto dx = ib->Emit("AvgPoolGrad", {x, out, dout}, auto dx = ib->Emit("AvgPoolGrad", {x, out, dout},
{{"kernel_size", ib->GetAttr("kernel_size")}, {{"kernel_size", MakeValue(kernel_size)},
{"strides", ib->GetAttr("strides")}, {"strides", MakeValue(strides)},
{"pad_mode", ib->GetAttr("pad_mode")}, {"pad_mode", ib->GetAttr("pad_mode")},
{"data_format", ib->GetAttr("data_format")},
{"format", ib->GetAttr("format")}}); {"format", ib->GetAttr("format")}});
return {dx}; return {dx};
}); });