!10710 add ut for primitve c of conv2d

From: @lianliguang
Reviewed-by: @kisnwang,@chujinjin
Signed-off-by: @chujinjin
This commit is contained in:
mindspore-ci-bot 2020-12-29 19:13:21 +08:00 committed by Gitee
commit 4bb54e5ffe
2 changed files with 27 additions and 4 deletions

View File

@ -20,6 +20,7 @@
#include <memory>
#include <set>
#include <vector>
#include "ir/dtype/tensor_type.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
@ -110,8 +111,8 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
MS_EXCEPTION_IF_NULL(conv_prim);
auto prim_name = conv_prim->name();
CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
@ -181,7 +182,7 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
types.emplace("w", input_args[1]->BuildType());
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
if (infer_type == kNumberTypeInt8) {
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
return TypeIdToType(kNumberTypeInt32);
}
return TypeIdToType(infer_type);
}

View File

@ -33,7 +33,29 @@ TEST_F(TestConv2d, test_cops_conv2d) {
conv_2d->Init(64, {7, 7});
auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{32, 3, 224, 224});
auto tensor_w = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{64, 3, 7, 7});
conv_2d->Infer({tensor_w->ToAbstract(), tensor_w->ToAbstract()});
MS_EXCEPTION_IF_NULL(tensor_x);
MS_EXCEPTION_IF_NULL(tensor_w);
auto conv_abstract = conv_2d->Infer({tensor_x->ToAbstract(), tensor_w->ToAbstract()});
MS_EXCEPTION_IF_NULL(conv_abstract);
EXPECT_EQ(conv_abstract->isa<abstract::AbstractTensor>(), true);
auto shape_ptr = conv_abstract->BuildShape();
MS_EXCEPTION_IF_NULL(shape_ptr);
EXPECT_EQ(shape_ptr->isa<abstract::Shape>(), true);
auto conv_shape = shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(conv_shape);
auto shape_vec = conv_shape->shape();
auto type = conv_abstract->BuildType();
MS_EXCEPTION_IF_NULL(type);
EXPECT_EQ(type->isa<TensorType>(), true);
auto tensor_type = type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto elem_type = tensor_type->element();
EXPECT_EQ(elem_type->type_id(), kNumberTypeFloat32);
EXPECT_EQ(shape_vec.size(), 4);
EXPECT_EQ(shape_vec[0], 32);
EXPECT_EQ(shape_vec[1], 64);
EXPECT_EQ(shape_vec[2], 218);
EXPECT_EQ(shape_vec[3], 218);
}
} // namespace mindspore