forked from mindspore-Ecosystem/mindspore
!45643 fix conv3dtranspose error log
Merge pull request !45643 from 胡彬/conv3d-transpose-fix
This commit is contained in:
commit
ab52c856f5
|
@ -259,6 +259,16 @@ class Conv3DTransposeInfer : public abstract::OpInferBase {
|
|||
int64_t pad_mode;
|
||||
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode);
|
||||
|
||||
if ((w_shape[kAxis2] != abstract::Shape::kShapeDimAny && w_shape[kAxis2] != kernel_size[kAxis0]) ||
|
||||
(w_shape[kAxis3] != abstract::Shape::kShapeDimAny && w_shape[kAxis3] != kernel_size[kAxis1]) ||
|
||||
(w_shape[kAxis4] != abstract::Shape::kShapeDimAny && w_shape[kAxis4] != kernel_size[kAxis2])) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the dimension 'DHW' of input 'weight' must be "
|
||||
<< " equal to the shape of 'kernel size', but got 'DHW' of input 'weight': ("
|
||||
<< w_shape[kAxis2] << ", " << w_shape[kAxis3] << ", " << w_shape[kAxis4]
|
||||
<< "), and 'kernel size': (" << kernel_size[kAxis0] << ", " << w_shape[kAxis1] << ", "
|
||||
<< w_shape[kAxis2] << ").";
|
||||
}
|
||||
|
||||
int64_t d_out = abstract::Shape::kShapeDimAny;
|
||||
int64_t w_out = abstract::Shape::kShapeDimAny;
|
||||
int64_t h_out = abstract::Shape::kShapeDimAny;
|
||||
|
|
Loading…
Reference in New Issue