!45226 fix some bprop attr
Merge pull request !45226 from r1chardf1d0/bpe2
This commit is contained in:
commit
6dbfda8c83
|
@ -622,4 +622,20 @@ bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_type
|
||||||
return IsIdentidityOrSubclass(check_type, accept);
|
return IsIdentidityOrSubclass(check_type, accept);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ShapeVector PoolToNHWC(const ShapeVector &v) {
|
||||||
|
ShapeVector new_v(v);
|
||||||
|
new_v[kIndex1] = v[kIndex2];
|
||||||
|
new_v[kIndex2] = v[kIndex3];
|
||||||
|
new_v[kIndex3] = v[kIndex1];
|
||||||
|
return new_v;
|
||||||
|
}
|
||||||
|
ShapeVector ConvToNHWC(const ShapeVector &v) {
|
||||||
|
ShapeVector new_v(v);
|
||||||
|
new_v[kIndex0] = v[kIndex1];
|
||||||
|
new_v[kIndex1] = v[kIndex2];
|
||||||
|
new_v[kIndex2] = v[kIndex3];
|
||||||
|
new_v[kIndex3] = 1;
|
||||||
|
return new_v;
|
||||||
|
}
|
||||||
} // namespace mindspore::expander::bprop
|
} // namespace mindspore::expander::bprop
|
||||||
|
|
|
@ -79,5 +79,7 @@ NodePtr ArgminOrArgmaxGrad(const BpropIRBuilder *ib, const NodePtr &x, const int
|
||||||
TypeId PromoteBinaryDtype(TypeId t1, TypeId t2);
|
TypeId PromoteBinaryDtype(TypeId t1, TypeId t2);
|
||||||
NodePtr LGamma(const BpropIRBuilder *ib, const NodePtr &x);
|
NodePtr LGamma(const BpropIRBuilder *ib, const NodePtr &x);
|
||||||
bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types);
|
bool CheckType(const TypePtr &check_type, const std::set<TypePtr> &template_types);
|
||||||
|
ShapeVector PoolToNHWC(const ShapeVector &v);
|
||||||
|
ShapeVector ConvToNHWC(const ShapeVector &v);
|
||||||
} // namespace mindspore::expander::bprop
|
} // namespace mindspore::expander::bprop
|
||||||
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_
|
#endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_BPROP_EXPANDER_COMMON_UTILS_H_
|
||||||
|
|
|
@ -26,13 +26,21 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
||||||
auto dout = ib->GetInput(kIndex3);
|
auto dout = ib->GetInput(kIndex3);
|
||||||
auto x_shape = ib->GetShape(x);
|
auto x_shape = ib->GetShape(x);
|
||||||
auto w_shape = ib->GetShape(w);
|
auto w_shape = ib->GetShape(w);
|
||||||
|
auto format = GetValue<std::string>(ib->GetAttr("format"));
|
||||||
|
auto dilation = GetValue<ShapeVector>(ib->GetAttr("dilation"));
|
||||||
|
auto stride = GetValue<ShapeVector>(ib->GetAttr("stride"));
|
||||||
|
if (format == "NHWC") {
|
||||||
|
dilation = ConvToNHWC(dilation);
|
||||||
|
stride = ConvToNHWC(stride);
|
||||||
|
}
|
||||||
auto dx = ib->Emit(kConv2DBackpropInputOpName, {dout, w, ib->Value<ShapeVector>(x_shape)},
|
auto dx = ib->Emit(kConv2DBackpropInputOpName, {dout, w, ib->Value<ShapeVector>(x_shape)},
|
||||||
{{"mode", ib->GetAttr("mode")},
|
{{"mode", ib->GetAttr("mode")},
|
||||||
{"dilation", ib->GetAttr("dilation")},
|
{"dilation", MakeValue(dilation)},
|
||||||
{"stride", ib->GetAttr("stride")},
|
{"stride", MakeValue(stride)},
|
||||||
{"group", ib->GetAttr("group")},
|
{"group", ib->GetAttr("group")},
|
||||||
{"groups", ib->GetAttr("group")},
|
{"groups", ib->GetAttr("group")},
|
||||||
{"format", ib->GetAttr("format")},
|
{"format", ib->GetAttr("format")},
|
||||||
|
{"data_format", ib->GetAttr("data_format")},
|
||||||
{"out_channel", ib->GetAttr("out_channel")},
|
{"out_channel", ib->GetAttr("out_channel")},
|
||||||
{"kernel_size", ib->GetAttr("kernel_size")},
|
{"kernel_size", ib->GetAttr("kernel_size")},
|
||||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||||
|
@ -40,11 +48,12 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
||||||
{"pad_list", ib->GetAttr("pad_list")}});
|
{"pad_list", ib->GetAttr("pad_list")}});
|
||||||
auto dw = ib->Emit("Conv2DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)},
|
auto dw = ib->Emit("Conv2DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)},
|
||||||
{{"mode", ib->GetAttr("mode")},
|
{{"mode", ib->GetAttr("mode")},
|
||||||
{"dilation", ib->GetAttr("dilation")},
|
{"dilation", MakeValue(dilation)},
|
||||||
{"stride", ib->GetAttr("stride")},
|
{"stride", MakeValue(stride)},
|
||||||
{"group", ib->GetAttr("group")},
|
{"group", ib->GetAttr("group")},
|
||||||
{"groups", ib->GetAttr("group")},
|
{"groups", ib->GetAttr("group")},
|
||||||
{"format", ib->GetAttr("format")},
|
{"format", ib->GetAttr("format")},
|
||||||
|
{"data_format", ib->GetAttr("data_format")},
|
||||||
{"out_channel", ib->GetAttr("out_channel")},
|
{"out_channel", ib->GetAttr("out_channel")},
|
||||||
{"kernel_size", ib->GetAttr("kernel_size")},
|
{"kernel_size", ib->GetAttr("kernel_size")},
|
||||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||||
|
@ -52,15 +61,22 @@ REG_BPROP_BUILDER(kConv2DOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
||||||
{"pad_list", ib->GetAttr("pad_list")}});
|
{"pad_list", ib->GetAttr("pad_list")}});
|
||||||
return {dx, dw};
|
return {dx, dw};
|
||||||
});
|
});
|
||||||
|
|
||||||
REG_BPROP_BUILDER(kMaxPoolOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
REG_BPROP_BUILDER(kMaxPoolOpName).SetBody([](const BpropIRBuilder *ib) -> NodePtrList {
|
||||||
auto x = ib->GetInput(kIndex0);
|
auto x = ib->GetInput(kIndex0);
|
||||||
auto out = ib->GetInput(kIndex1);
|
auto out = ib->GetInput(kIndex1);
|
||||||
auto dout = ib->GetInput(kIndex2);
|
auto dout = ib->GetInput(kIndex2);
|
||||||
|
auto format = GetValue<std::string>(ib->GetAttr("format"));
|
||||||
|
auto kernel_size = GetValue<ShapeVector>(ib->GetAttr("kernel_size"));
|
||||||
|
auto strides = GetValue<ShapeVector>(ib->GetAttr("strides"));
|
||||||
|
if (format == "NHWC") {
|
||||||
|
kernel_size = PoolToNHWC(kernel_size);
|
||||||
|
strides = PoolToNHWC(strides);
|
||||||
|
}
|
||||||
auto dx = ib->Emit(kMaxPoolGradOpName, {x, out, dout},
|
auto dx = ib->Emit(kMaxPoolGradOpName, {x, out, dout},
|
||||||
{{"kernel_size", ib->GetAttr("kernel_size")},
|
{{"kernel_size", MakeValue(kernel_size)},
|
||||||
{"strides", ib->GetAttr("strides")},
|
{"strides", MakeValue(strides)},
|
||||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||||
|
{"data_format", ib->GetAttr("data_format")},
|
||||||
{"format", ib->GetAttr("format")}});
|
{"format", ib->GetAttr("format")}});
|
||||||
return {dx};
|
return {dx};
|
||||||
});
|
});
|
||||||
|
@ -205,6 +221,7 @@ REG_BPROP_BUILDER("DeformableOffsets").SetBody([](const BpropIRBuilder *ib) -> N
|
||||||
{"ksize", ib->GetAttr("ksize")},
|
{"ksize", ib->GetAttr("ksize")},
|
||||||
{"dilations", ib->GetAttr("dilations")},
|
{"dilations", ib->GetAttr("dilations")},
|
||||||
{"format", ib->GetAttr("format")},
|
{"format", ib->GetAttr("format")},
|
||||||
|
{"data_format", ib->GetAttr("format")},
|
||||||
{"deformable_groups", ib->GetAttr("deformable_groups")},
|
{"deformable_groups", ib->GetAttr("deformable_groups")},
|
||||||
{"modulated", ib->GetAttr("modulated")}});
|
{"modulated", ib->GetAttr("modulated")}});
|
||||||
return {ib->TupleGetItem(out_grad, 0), ib->TupleGetItem(out_grad, 1)};
|
return {ib->TupleGetItem(out_grad, 0), ib->TupleGetItem(out_grad, 1)};
|
||||||
|
@ -479,6 +496,8 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetBody([](const BpropIRBuilder *ib) -> Nod
|
||||||
{"dilation", MakeValue(dilation)},
|
{"dilation", MakeValue(dilation)},
|
||||||
{"group", ib->GetAttr("groups")},
|
{"group", ib->GetAttr("groups")},
|
||||||
{"groups", ib->GetAttr("groups")},
|
{"groups", ib->GetAttr("groups")},
|
||||||
|
{"offset_x", MakeValue<int64_t>(0)},
|
||||||
|
{"format", ib->GetAttr("format")},
|
||||||
{"data_format", ib->GetAttr("format")}});
|
{"data_format", ib->GetAttr("format")}});
|
||||||
auto dw = ib->Emit("Conv3DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)},
|
auto dw = ib->Emit("Conv3DBackpropFilter", {dout, x, ib->Value<ShapeVector>(w_shape)},
|
||||||
{{"out_channel", ib->GetAttr("in_channel")},
|
{{"out_channel", ib->GetAttr("in_channel")},
|
||||||
|
@ -493,6 +512,7 @@ REG_BPROP_BUILDER("Conv3DTranspose").SetBody([](const BpropIRBuilder *ib) -> Nod
|
||||||
{"dilation", ib->GetAttr("dilations")},
|
{"dilation", ib->GetAttr("dilations")},
|
||||||
{"group", ib->GetAttr("groups")},
|
{"group", ib->GetAttr("groups")},
|
||||||
{"groups", ib->GetAttr("groups")},
|
{"groups", ib->GetAttr("groups")},
|
||||||
|
{"format", ib->GetAttr("format")},
|
||||||
{"data_format", ib->GetAttr("format")}});
|
{"data_format", ib->GetAttr("format")}});
|
||||||
return {dx, dw};
|
return {dx, dw};
|
||||||
});
|
});
|
||||||
|
@ -518,6 +538,7 @@ REG_BPROP_BUILDER("MaxPoolGradGrad").SetBody([](const BpropIRBuilder *ib) -> Nod
|
||||||
{{"kernel_size", ib->GetAttr("kernel_size")},
|
{{"kernel_size", ib->GetAttr("kernel_size")},
|
||||||
{"strides", ib->GetAttr("strides")},
|
{"strides", ib->GetAttr("strides")},
|
||||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||||
|
{"data_format", MakeValue("NCHW")},
|
||||||
{"format", MakeValue("NCHW")}});
|
{"format", MakeValue("NCHW")}});
|
||||||
return {dx1, dx2, dgrad};
|
return {dx1, dx2, dgrad};
|
||||||
});
|
});
|
||||||
|
@ -553,6 +574,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
||||||
dgrad = ib->Emit("MaxPoolGradGrad", {x1, x2, dout},
|
dgrad = ib->Emit("MaxPoolGradGrad", {x1, x2, dout},
|
||||||
{{"kernel_size", ib->GetAttr("kernel_size")},
|
{{"kernel_size", ib->GetAttr("kernel_size")},
|
||||||
{"strides", ib->GetAttr("strides")},
|
{"strides", ib->GetAttr("strides")},
|
||||||
|
{"data_format", MakeValue("NCHW")},
|
||||||
{"format", MakeValue("NCHW")},
|
{"format", MakeValue("NCHW")},
|
||||||
{"pad_mode", ib->GetAttr("pad_mode")}});
|
{"pad_mode", ib->GetAttr("pad_mode")}});
|
||||||
} else {
|
} else {
|
||||||
|
@ -565,6 +587,7 @@ REG_BPROP_BUILDER("MaxPoolGrad").SetBody([](const BpropIRBuilder *ib) -> NodePtr
|
||||||
{{"kernel_size", MakeValue(kernel_size)},
|
{{"kernel_size", MakeValue(kernel_size)},
|
||||||
{"strides", MakeValue(strides)},
|
{"strides", MakeValue(strides)},
|
||||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||||
|
{"data_format", MakeValue("NCHW")},
|
||||||
{"format", MakeValue("NCHW")}});
|
{"format", MakeValue("NCHW")}});
|
||||||
auto ind = ib->TupleGetItem(tmp, 1);
|
auto ind = ib->TupleGetItem(tmp, 1);
|
||||||
auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32));
|
auto batch = ib->Tensor(Range(b), TypeIdToType(TypeId::kNumberTypeInt32));
|
||||||
|
@ -668,10 +691,18 @@ REG_BPROP_BUILDER("AvgPool").SetBody([](const BpropIRBuilder *ib) -> NodePtrList
|
||||||
auto x = ib->GetInput(kIndex0);
|
auto x = ib->GetInput(kIndex0);
|
||||||
auto out = ib->GetInput(kIndex1);
|
auto out = ib->GetInput(kIndex1);
|
||||||
auto dout = ib->GetInput(kIndex2);
|
auto dout = ib->GetInput(kIndex2);
|
||||||
|
auto format = GetValue<std::string>(ib->GetAttr("format"));
|
||||||
|
auto kernel_size = GetValue<ShapeVector>(ib->GetAttr("kernel_size"));
|
||||||
|
auto strides = GetValue<ShapeVector>(ib->GetAttr("strides"));
|
||||||
|
if (format == "NHWC") {
|
||||||
|
kernel_size = PoolToNHWC(kernel_size);
|
||||||
|
strides = PoolToNHWC(strides);
|
||||||
|
}
|
||||||
auto dx = ib->Emit("AvgPoolGrad", {x, out, dout},
|
auto dx = ib->Emit("AvgPoolGrad", {x, out, dout},
|
||||||
{{"kernel_size", ib->GetAttr("kernel_size")},
|
{{"kernel_size", MakeValue(kernel_size)},
|
||||||
{"strides", ib->GetAttr("strides")},
|
{"strides", MakeValue(strides)},
|
||||||
{"pad_mode", ib->GetAttr("pad_mode")},
|
{"pad_mode", ib->GetAttr("pad_mode")},
|
||||||
|
{"data_format", ib->GetAttr("data_format")},
|
||||||
{"format", ib->GetAttr("format")}});
|
{"format", ib->GetAttr("format")}});
|
||||||
return {dx};
|
return {dx};
|
||||||
});
|
});
|
||||||
|
|
Loading…
Reference in New Issue