forked from mindspore-Ecosystem/mindspore
!10710 add ut for primitve c of conv2d
From: @lianliguang Reviewed-by: @kisnwang,@chujinjin Signed-off-by: @chujinjin
This commit is contained in:
commit
4bb54e5ffe
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ir/dtype/tensor_type.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
|
@ -110,8 +111,8 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve
|
|||
MS_EXCEPTION_IF_NULL(conv_prim);
|
||||
auto prim_name = conv_prim->name();
|
||||
CheckAndConvertUtils::CheckInRange("conv2d_infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->BuildShape(), prim_name);
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
|
@ -181,7 +182,7 @@ TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
|
|||
types.emplace("w", input_args[1]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
if (infer_type == kNumberTypeInt8) {
|
||||
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
|
||||
return TypeIdToType(kNumberTypeInt32);
|
||||
}
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
|
|
|
@ -33,7 +33,29 @@ TEST_F(TestConv2d, test_cops_conv2d) {
|
|||
conv_2d->Init(64, {7, 7});
|
||||
auto tensor_x = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{32, 3, 224, 224});
|
||||
auto tensor_w = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, std::vector<int64_t>{64, 3, 7, 7});
|
||||
conv_2d->Infer({tensor_w->ToAbstract(), tensor_w->ToAbstract()});
|
||||
MS_EXCEPTION_IF_NULL(tensor_x);
|
||||
MS_EXCEPTION_IF_NULL(tensor_w);
|
||||
auto conv_abstract = conv_2d->Infer({tensor_x->ToAbstract(), tensor_w->ToAbstract()});
|
||||
MS_EXCEPTION_IF_NULL(conv_abstract);
|
||||
EXPECT_EQ(conv_abstract->isa<abstract::AbstractTensor>(), true);
|
||||
auto shape_ptr = conv_abstract->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
EXPECT_EQ(shape_ptr->isa<abstract::Shape>(), true);
|
||||
auto conv_shape = shape_ptr->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(conv_shape);
|
||||
auto shape_vec = conv_shape->shape();
|
||||
auto type = conv_abstract->BuildType();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
EXPECT_EQ(type->isa<TensorType>(), true);
|
||||
auto tensor_type = type->cast<TensorTypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_type);
|
||||
auto elem_type = tensor_type->element();
|
||||
EXPECT_EQ(elem_type->type_id(), kNumberTypeFloat32);
|
||||
EXPECT_EQ(shape_vec.size(), 4);
|
||||
EXPECT_EQ(shape_vec[0], 32);
|
||||
EXPECT_EQ(shape_vec[1], 64);
|
||||
EXPECT_EQ(shape_vec[2], 218);
|
||||
EXPECT_EQ(shape_vec[3], 218);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue