!45226 fix some bprop attr
Merge pull request !45226 from r1chardf1d0/bpe2
This commit is contained in:
commit
6dbfda8c83
|
@ -622,4 +622,20 @@ bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_type
|
|||
return IsIdentidityOrSubclass(check_type, accept);
|
||||
});
|
||||
}
|
||||
|
||||
ShapeVector PoolToNHWC(const ShapeVector &v) {
|
||||
ShapeVector new_v(v);
|
||||
new_v[kIndex1] = v[kIndex2];
|
||||
new_v[kIndex2] = v[kIndex3];
|
||||
new_v[kIndex3] = v[kIndex1];
|
||||
return new_v;
|
||||
}
|
||||
ShapeVector ConvToNHWC(const ShapeVector &v) {
|
||||
ShapeVector new_v(v);
|
||||
new_v[kIndex0] = v[kIndex1];
|
||||
new_v[kIndex1] = v[kIndex2];
|
||||
new_v[kIndex2] = v[kIndex3];
|
||||
new_v[kIndex3] = 1;
|
||||
return new_v;
|
||||
}
|
||||
} // namespace mindspore::expander::bprop
|
||||
|
|
|
@ -79,5 +79,7 @@ NodePtr ArgminOrArgmaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const int
|
|||
TypeId PromoteBinaryDtype(TypeId t1, TypeId t2);
|
||||
NodePtr LGamma(const BpropIRBuilder *ib, const NodePtr &x);
|
||||
bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types);
|
||||
ShapeVector PoolToNHWC(const ShapeVector &v);
|
||||
ShapeVector ConvToNHWC(const ShapeVector &v);
|
||||
} // namespace mindspore::expander::bprop
|
||||
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_
|
||||
|
|
|
@ -26,13 +26,21 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
auto dout = ib->GetInput(kIndex3);
|
||||
auto x_shape = ib->GetShape(x);
|
||||
auto w_shape = ib->GetShape(w);
|
||||
auto format = GetValue<std::string>(ib->GetAttr("format"));
|
||||
auto dilation = GetValue<ShapeVector>(ib->GetAttr("dilation"));
|
||||
auto stride = GetValue<ShapeVector>(ib->GetAttr("stride"));
|
||||
if (format == "NHWC") {
|
||||
dilation = ConvToNHWC(dilation);
|
||||
stride = ConvToNHWC(stride);
|
||||
}
|
||||
auto dx = ib->Emit(kConv2DBackpropInputOpName, {dout, w, ib->Value<ShapeVector>(x_shape)},
|
||||
{{"mode", ib->GetAttr("mode")},
|
||||
{"dilation", ib->GetAttr("dilation")},
|
||||
{"stride", ib->GetAttr("stride")},
|
||||
{"dilation", MakeValue(dilation)},
|
||||
{"stride", MakeValue(stride)},
|
||||
{"group", ib->GetAttr("group")},
|
||||
{"groups", ib->GetAttr("group")},
|
||||
{"format", ib->GetAttr("format")},
|
||||
{"data_format", ib->GetAttr("data_format")},
|
||||
{"out_channel", ib->GetAttr("out_channel")},
|
||||
{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||
|
@ -40,11 +48,12 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
{"pad_list", ib->GetAttr("pad_list")}});
|
||||
auto dw = ib->Emit("Conv2DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)},
|
||||
{{"mode", ib->GetAttr("mode")},
|
||||
{"dilation", ib->GetAttr("dilation")},
|
||||
{"stride", ib->GetAttr("stride")},
|
||||
{"dilation", MakeValue(dilation)},
|
||||
{"stride", MakeValue(stride)},
|
||||
{"group", ib->GetAttr("group")},
|
||||
{"groups", ib->GetAttr("group")},
|
||||
{"format", ib->GetAttr("format")},
|
||||
{"data_format", ib->GetAttr("data_format")},
|
||||
{"out_channel", ib->GetAttr("out_channel")},
|
||||
{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||
|
@ -52,15 +61,22 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
{"pad_list", ib->GetAttr("pad_list")}});
|
||||
return {dx, dw};
|
||||
});
|
||||
|
||||
REG_BPROP_BUILDER(kMaxPoolOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto format = GetValue<std::string>(ib->GetAttr("format"));
|
||||
auto kernel_size = GetValue<ShapeVector>(ib->GetAttr("kernel_size"));
|
||||
auto strides = GetValue<ShapeVector>(ib->GetAttr("strides"));
|
||||
if (format == "NHWC") {
|
||||
kernel_size = PoolToNHWC(kernel_size);
|
||||
strides = PoolToNHWC(strides);
|
||||
}
|
||||
auto dx = ib->Emit(kMaxPoolGradOpName, {x, out, dout},
|
||||
{{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
{"strides", ib->GetAttr("strides")},
|
||||
{{"kernel_size", MakeValue(kernel_size)},
|
||||
{"strides", MakeValue(strides)},
|
||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||
{"data_format", ib->GetAttr("data_format")},
|
||||
{"format", ib->GetAttr("format")}});
|
||||
return {dx};
|
||||
});
|
||||
|
@ -205,6 +221,7 @@ REG_BPROP_BUILDER("DeformableOffsets").SetBody([](const BpropIRBuilder *ib) -> N
|
|||
{"ksize", ib->GetAttr("ksize")},
|
||||
{"dilations", ib->GetAttr("dilations")},
|
||||
{"format", ib->GetAttr("format")},
|
||||
{"data_format", ib->GetAttr("format")},
|
||||
{"deformable_groups", ib->GetAttr("deformable_groups")},
|
||||
{"modulated", ib->GetAttr("modulated")}});
|
||||
return {ib->TupleGetItem(out_grad, 0), ib->TupleGetItem(out_grad, 1)};
|
||||
|
@ -479,6 +496,8 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
{"dilation", MakeValue(dilation)},
|
||||
{"group", ib->GetAttr("groups")},
|
||||
{"groups", ib->GetAttr("groups")},
|
||||
{"offset_x", MakeValue<int64_t>(0)},
|
||||
{"format", ib->GetAttr("format")},
|
||||
{"data_format", ib->GetAttr("format")}});
|
||||
auto dw = ib->Emit("Conv3DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)},
|
||||
{{"out_channel", ib->GetAttr("in_channel")},
|
||||
|
@ -493,6 +512,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
{"dilation", ib->GetAttr("dilations")},
|
||||
{"group", ib->GetAttr("groups")},
|
||||
{"groups", ib->GetAttr("groups")},
|
||||
{"format", ib->GetAttr("format")},
|
||||
{"data_format", ib->GetAttr("format")}});
|
||||
return {dx, dw};
|
||||
});
|
||||
|
@ -518,6 +538,7 @@ REG_BPROP_BUILDER("MaxPoolGradGrad").SetBody([](const BpropIRBuilder *ib) -> Nod
|
|||
{{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
{"strides", ib->GetAttr("strides")},
|
||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||
{"data_format", MakeValue("NCHW")},
|
||||
{"format", MakeValue("NCHW")}});
|
||||
return {dx1, dx2, dgrad};
|
||||
});
|
||||
|
@ -553,6 +574,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
dgrad = ib->Emit("MaxPoolGradGrad", {x1, x2, dout},
|
||||
{{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
{"strides", ib->GetAttr("strides")},
|
||||
{"data_format", MakeValue("NCHW")},
|
||||
{"format", MakeValue("NCHW")},
|
||||
{"pad_mode", ib->GetAttr("pad_mode")}});
|
||||
} else {
|
||||
|
@ -565,6 +587,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
|||
{{"kernel_size", MakeValue(kernel_size)},
|
||||
{"strides", MakeValue(strides)},
|
||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||
{"data_format", MakeValue("NCHW")},
|
||||
{"format", MakeValue("NCHW")}});
|
||||
auto ind = ib->TupleGetItem(tmp, 1);
|
||||
auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32));
|
||||
|
@ -668,10 +691,18 @@ REG_BPROP_BUILDER("AvgPool").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
|||
auto x = ib->GetInput(kIndex0);
|
||||
auto out = ib->GetInput(kIndex1);
|
||||
auto dout = ib->GetInput(kIndex2);
|
||||
auto format = GetValue<std::string>(ib->GetAttr("format"));
|
||||
auto kernel_size = GetValue<ShapeVector>(ib->GetAttr("kernel_size"));
|
||||
auto strides = GetValue<ShapeVector>(ib->GetAttr("strides"));
|
||||
if (format == "NHWC") {
|
||||
kernel_size = PoolToNHWC(kernel_size);
|
||||
strides = PoolToNHWC(strides);
|
||||
}
|
||||
auto dx = ib->Emit("AvgPoolGrad", {x, out, dout},
|
||||
{{"kernel_size", ib->GetAttr("kernel_size")},
|
||||
{"strides", ib->GetAttr("strides")},
|
||||
{{"kernel_size", MakeValue(kernel_size)},
|
||||
{"strides", MakeValue(strides)},
|
||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||
{"data_format", ib->GetAttr("data_format")},
|
||||
{"format", ib->GetAttr("format")}});
|
||||
return {dx};
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue