forked from mindspore-Ecosystem/mindspore
!10340 [MS][LITE][GPU]matmul support 3d
From: @chenzupeng Reviewed-by: Signed-off-by:
This commit is contained in:
commit
03b52b88b8
|
@ -43,9 +43,9 @@ int MatMulOpenCLKernel::CheckSpecs() {
|
|||
}
|
||||
transposeB = param->b_transpose_;
|
||||
enable_fp16_ = ocl_runtime_->GetFp16Enable();
|
||||
if (in_tensors_[0]->shape().size() != out_tensors_[0]->shape().size() ||
|
||||
(in_tensors_[0]->shape().size() != 2 && in_tensors_[0]->shape().size() != 4)) {
|
||||
MS_LOG(ERROR) << "matmul only support input shape size=2 or 4.";
|
||||
if (in_tensors_[0]->shape().size() != out_tensors_[0]->shape().size() || in_tensors_[0]->shape().size() < 2 ||
|
||||
in_tensors_[0]->shape().size() > 4) {
|
||||
MS_LOG(ERROR) << "matmul only support input shape size= 2, 3 or 4.";
|
||||
return mindspore::lite::RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
|
@ -58,7 +58,7 @@ int MatMulOpenCLKernel::Prepare() {
|
|||
inShape[MAX_DIMS - dims + i] = in_tensors_[0]->shape()[i];
|
||||
outShape[MAX_DIMS - dims + i] = out_tensors_[0]->shape()[i];
|
||||
}
|
||||
std::map<int, std::string> dims2str = {{2, "_2d"}, {4, "_4d"}};
|
||||
std::map<int, std::string> dims2str = {{2, "_2d"}, {3, "_4d"}, {4, "_4d"}};
|
||||
kernel_name += dims2str[dims];
|
||||
#ifdef PROGRAM_WITH_IL
|
||||
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
|
||||
|
|
|
@ -70,4 +70,23 @@ TEST_F(TestOpenCL_MatMul, 4D) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TestOpenCL_MatMul, 3D) {
|
||||
int a = 2;
|
||||
int m = 2;
|
||||
int ci = 5;
|
||||
int co = 3;
|
||||
std::vector<int> input_shape = {a, m, ci};
|
||||
std::vector<int> output_shape = {a, m, co};
|
||||
std::vector<int> weight_shape = {a, co, ci};
|
||||
float input_data[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
|
||||
float weight_data[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30};
|
||||
float output_data[] = {15, 40, 65, 15, 40, 65, 90, 115, 140, 90, 115, 140};
|
||||
|
||||
for (auto fp16_enable : {false, true}) {
|
||||
auto *param = CreateParameter();
|
||||
TestMain({{input_shape, input_data, VAR}, {weight_shape, weight_data, CONST_TENSOR}}, {output_shape, output_data},
|
||||
param, fp16_enable);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite::opencl::test
|
||||
|
|
Loading…
Reference in New Issue