!45643 fix conv3dtranspose error log

Merge pull request !45643 from 胡彬/conv3d-transpose-fix
This commit is contained in:
i-robot 2022-11-17 10:49:51 +00:00 committed by Gitee
commit ab52c856f5
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 10 additions and 0 deletions

View File

@ -259,6 +259,16 @@ class Conv3DTransposeInfer : public abstract::OpInferBase {
int64_t pad_mode;
CheckAndConvertUtils::GetPadModEnumValue(primitive->GetAttr(kPadMode), &pad_mode);
if ((w_shape[kAxis2] != abstract::Shape::kShapeDimAny && w_shape[kAxis2] != kernel_size[kAxis0]) ||
(w_shape[kAxis3] != abstract::Shape::kShapeDimAny && w_shape[kAxis3] != kernel_size[kAxis1]) ||
(w_shape[kAxis4] != abstract::Shape::kShapeDimAny && w_shape[kAxis4] != kernel_size[kAxis2])) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the dimension 'DHW' of input 'weight' must be "
<< " equal to the shape of 'kernel size', but got 'DHW' of input 'weight': ("
<< w_shape[kAxis2] << ", " << w_shape[kAxis3] << ", " << w_shape[kAxis4]
<< "), and 'kernel size': (" << kernel_size[kAxis0] << ", " << w_shape[kAxis1] << ", "
<< w_shape[kAxis2] << ").";
}
int64_t d_out = abstract::Shape::kShapeDimAny;
int64_t w_out = abstract::Shape::kShapeDimAny;
int64_t h_out = abstract::Shape::kShapeDimAny;