add space_to_batch_nd for opencl

This commit is contained in:
wandongdong 2020-09-29 05:55:04 -07:00
parent d86df26db8
commit 9b2771b444
5 changed files with 426 additions and 0 deletions

View File

@ -0,0 +1,40 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
__kernel void space_to_batch_nd_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 src_size,
int4 dst_size, int2 block_size, int4 paddings) {
int X = get_global_id(0); // c
int Y = get_global_id(1); // w
int Z = get_global_id(2); // h
if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) {
return;
}
for (int i = 0; i < block_size.x; ++i) {
for (int j = 0; j < block_size.y; ++j) {
int w_org = Y * block_size.y + j - paddings.z;
int h_org = Z * block_size.x + i - paddings.x;
FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
res_data = READ_IMAGE(src_data, smp_zero, (int2)(w_org * dst_size.x + X, h_org));
WRITE_IMAGE(dst_data, (int2)(Y * dst_size.x + X, (i * block_size.y + j) * dst_size.z + Z), res_data);
}
}
}
__kernel void space_to_batch_nd_NC4HW4(__read_only image2d_t src_data, __write_only image2d_t dst_data, int4 src_size,
int4 dst_size, int2 block_size, int4 paddings) {
int X = get_global_id(0); // c
int Y = get_global_id(1); // w
int Z = get_global_id(2); // h
if (X >= dst_size.x || Y >= dst_size.y || Y >= dst_size.z) {
return;
}
for (int i = 0; i < block_size.x; ++i) {
for (int j = 0; j < block_size.y; ++j) {
int w_org = Y * block_size.y + j - paddings.z;
int h_org = Z * block_size.x + i - paddings.x;
FLT4 res_data = (FLT4)(0.0f, 0.0f, 0.0f, 0.0f);
if (w_org >= 0 && w_org < src_size.y && h_org >= 0 && h_org < src_size.z) {
res_data = READ_IMAGE(src_data, smp_zero, (int2)(h_org * src_size.y + Y, X));
}
WRITE_IMAGE(dst_data, (int2)(Z * dst_size.y + Y, (i * block_size.y + j) * dst_size.x + X), res_data);
}
}
}

View File

@ -40,6 +40,9 @@ int GatherOpenCLKernel::Init() {
out_ori_format_ = out_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(op_format_);
out_tensors_[0]->SetFormat(op_format_);
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
#else
if (in_format == schema::Format_NC4HW4) {
kernel_name += "_NC4HW4";
} else {
@ -50,6 +53,7 @@ int GatherOpenCLKernel::Init() {
std::string program_name = "gather";
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
// init indices_data_
auto indices_tensor = in_tensors_.at(1);
int indices_num = indices_tensor->ElementsNum();

View File

@ -0,0 +1,150 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstring>
#include <string>
#include <algorithm>
#include <set>
#include <utility>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/kernel/space_to_batch_nd.h"
#include "src/runtime/kernel/opencl/cl/space_to_batch_nd.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_SpaceToBatchND;
namespace mindspore::kernel {
int SpaceToBatchNDOpenCLKernel::Init() {
std::string kernel_name = "space_to_batch_nd";
auto in_format = op_format_;
if (in_tensors_[0]->shape().size() != 4 && out_tensors_[0]->shape().size() != 4) {
MS_LOG(ERROR) << "input/output shape size must be 4, actual: " << in_tensors_[0]->shape().size() << ", "
<< out_tensors_[0]->shape().size();
return RET_ERROR;
}
if (in_format != schema::Format_NHWC4 && in_format != schema::Format_NC4HW4) {
MS_LOG(ERROR) << "input format(" << in_format << ") "
<< "format not support!";
return RET_ERROR;
}
auto *param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
param->need_paddings_ = (param->paddings_[0] | param->paddings_[1] | param->paddings_[2] | param->paddings_[3]);
param->padded_in_shape_[kNHWC_N] = in_tensors_[0]->shape().at(kNHWC_N);
param->padded_in_shape_[kNHWC_H] = in_tensors_[0]->shape().at(kNHWC_H) + param->paddings_[0] + param->paddings_[1];
param->padded_in_shape_[kNHWC_W] = in_tensors_[0]->shape().at(kNHWC_W) + param->paddings_[2] + param->paddings_[3];
param->padded_in_shape_[kNHWC_C] = in_tensors_[0]->shape().at(kNHWC_C);
if (param->block_sizes_[0] < 1 || param->block_sizes_[1] < 1) {
MS_LOG(ERROR) << "block_sizes_ must > 1, actual " << param->block_sizes_[0] << ", " << param->block_sizes_[1];
return RET_ERROR;
}
if (param->padded_in_shape_[kNHWC_H] % param->block_sizes_[0] ||
param->padded_in_shape_[kNHWC_W] % param->block_sizes_[1]) {
MS_LOG(ERROR) << "padded shape must be multiple of block!";
return RET_ERROR;
}
in_ori_format_ = in_tensors_[0]->GetFormat();
out_ori_format_ = out_tensors_[0]->GetFormat();
in_tensors_[0]->SetFormat(op_format_);
out_tensors_[0]->SetFormat(op_format_);
#ifdef PROGRAM_WITH_IL
kernel_ = ocl_runtime_->GetKernelFromBinary(kernel_name);
#else
if (in_format == schema::Format_NC4HW4) {
kernel_name += "_NC4HW4";
} else {
kernel_name += "_NHWC4";
}
std::set<std::string> build_options;
std::string source = space_to_batch_nd_source;
std::string program_name = "space_to_batch_nd";
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
return RET_OK;
}
int SpaceToBatchNDOpenCLKernel::InitBuffer() { return RET_OK; }
int SpaceToBatchNDOpenCLKernel::ReSize() { return RET_OK; }
int SpaceToBatchNDOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
size_t im_dst_x, im_dst_y;
if (in_tensors_[0]->GetFormat() == schema::Format::Format_NHWC4) {
im_dst_x = out_tensors_[0]->Width() * CO4;
im_dst_y = out_tensors_[0]->Height() * out_tensors_[0]->Batch();
} else {
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * CO4;
im_dst_x = out_tensors_[0]->Width();
}
size_t img_dtype = CL_FLOAT;
auto enable_fp16_ = ocl_runtime_->GetFp16Enable();
if (enable_fp16_) {
img_dtype = CL_HALF_FLOAT;
}
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = std::move(vec);
return RET_OK;
}
int SpaceToBatchNDOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
auto param = reinterpret_cast<SpaceToBatchParameter *>(this->op_parameter_);
auto input_shape = in_tensors_[0]->shape();
auto output_shape = out_tensors_[0]->shape();
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
size_t CI4 = UP_DIV(in_tensors_[0]->Channel(), C4NUM);
cl_int4 src_size = {(cl_int)CI4, in_tensors_[0]->Width(), in_tensors_[0]->Height(), in_tensors_[0]->Batch()};
cl_int4 dst_size = {(cl_int)CO4, out_tensors_[0]->Width(), out_tensors_[0]->Height(), out_tensors_[0]->Batch()};
cl_int2 block_size = {param->block_sizes_[0], param->block_sizes_[1]};
cl_int4 paddings = {param->paddings_[0], param->paddings_[1], param->paddings_[2], param->paddings_[3]};
std::vector<size_t> local = {1, 1, 1};
std::vector<size_t> global = {(size_t)dst_size.s[0], (size_t)dst_size.s[1], (size_t)dst_size.s[2]};
int arg_cn = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_[0]->data_c(), lite::opencl::MemType::IMG);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_[0]->data_c(), lite::opencl::MemType::IMG);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, src_size);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, dst_size);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, block_size);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, paddings);
ocl_runtime_->RunKernel(kernel_, global, local, nullptr);
return RET_OK;
}
kernel::LiteKernel *OpenCLSpaceToBatchNDKernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter, const lite::InnerContext *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) SpaceToBatchNDOpenCLKernel(opParameter, inputs, outputs);
if (kernel == nullptr) {
MS_LOG(ERROR) << "Kernel " << opParameter->name_ << " new failed.";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Kernel " << opParameter->name_ << " init failed.";
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SpaceToBatchND, OpenCLSpaceToBatchNDKernelCreator);
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_SpaceToBatchND, OpenCLSpaceToBatchNDKernelCreator);
} // namespace mindspore::kernel

View File

@ -0,0 +1,48 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPACE_TO_BATCH_ND_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SPACE_TO_BATCH_ND_H_
#include <vector>
#include "src/runtime/kernel/opencl/opencl_kernel.h"
#include "nnacl/fp32/space_to_batch.h"
namespace mindspore::kernel {
class SpaceToBatchNDOpenCLKernel : public OpenCLKernel {
public:
explicit SpaceToBatchNDOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
~SpaceToBatchNDOpenCLKernel() override{};
int Init() override;
int ReSize() override;
int Run() override;
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
int InitBuffer();
private:
cl::Kernel kernel_;
};
} // namespace mindspore::kernel
#endif

View File

@ -0,0 +1,184 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <memory>
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/space_to_batch_nd.h"
namespace mindspore {
class TestSpaceToBatchNDOpenCL : public mindspore::CommonTest {
public:
TestSpaceToBatchNDOpenCL() {}
};
template <typename T>
void test_main_space_to_batch_nd(void *input_data, void *correct_data, const std::vector<int> &input_shape,
SpaceToBatchParameter *param, TypeId data_type, schema::Format format) {
MS_LOG(INFO) << " begin test ";
auto ocl_runtime_wrap = lite::opencl::OpenCLRuntimeWrapper();
auto ocl_runtime = ocl_runtime_wrap.GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
std::vector<int> output_shape = input_shape;
output_shape[0] = input_shape[0] * param->block_sizes_[0] * param->block_sizes_[1];
output_shape[1] = (input_shape[1] + param->paddings_[0] + param->paddings_[1]) / param->block_sizes_[0];
output_shape[2] = (input_shape[2] + +param->paddings_[2] + param->paddings_[3]) / param->block_sizes_[1];
auto tensor_a = lite::Tensor(TypeId(data_type), input_shape, format);
auto tensor_c = lite::Tensor(TypeId(data_type), output_shape, format);
std::vector<lite::Tensor *> inputs{&tensor_a};
std::vector<lite::Tensor *> outputs{&tensor_c};
size_t input_size = tensor_a.Size();
auto *pkernel =
new (std::nothrow) kernel::SpaceToBatchNDOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
if (pkernel == nullptr) {
MS_LOG(INFO) << "new SpaceToBatchNDOpenCLKernel failed ";
return;
}
pkernel->Init();
// to do allocate memory for inputs and outputs
for (auto &input_tensor : inputs) {
input_tensor->MallocData(allocator);
}
MS_LOG(INFO) << " initialize sub_graph ";
std::vector<kernel::LiteKernel *> kernels{pkernel};
auto *sub_graph = new (std::nothrow) kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
if (sub_graph == nullptr) {
delete pkernel;
MS_LOG(INFO) << " new SubGraphOpenCLKernel failed ";
return;
}
sub_graph->Init();
MS_LOG(INFO) << " init tensors ";
T *input_ptr = reinterpret_cast<T *>(inputs[0]->MutableData());
memcpy(input_ptr, input_data, input_size);
std::cout << "==================input data================" << std::endl;
for (auto i = 0; i < inputs[0]->ElementsNum(); ++i) {
std::cout << input_ptr[i] << ", ";
}
std::cout << std::endl;
sub_graph->Run();
auto *output_data = reinterpret_cast<T *>(outputs[0]->MutableData());
std::cout << "==================output data================" << std::endl;
for (auto i = 0; i < outputs[0]->ElementsNum(); ++i) {
std::cout << output_data[i] << ", ";
}
std::cout << std::endl;
std::cout << "==================correct data================" << std::endl;
for (auto i = 0; i < outputs[0]->ElementsNum(); ++i) {
std::cout << static_cast<T *>(correct_data)[i] << ", ";
}
std::cout << std::endl;
CommonTest::CompareOutputData<T>(output_data, static_cast<T *>(correct_data), outputs[0]->ElementsNum(), 0.0001);
delete sub_graph;
}
TEST_F(TestSpaceToBatchNDOpenCL, NHWC4H2W2Pad2222) {
std::vector<int> input_shape{1, 6, 6, 4};
SpaceToBatchParameter *param = std::make_unique<SpaceToBatchParameter>().release();
if (param == nullptr) {
return;
}
param->block_sizes_[0] = 2;
param->block_sizes_[1] = 2;
param->paddings_[0] = 2;
param->paddings_[1] = 2;
param->paddings_[2] = 2;
param->paddings_[3] = 2;
float input_data[] = {172, 47, 117, 192, 67, 251, 195, 103, 9, 211, 21, 242, 36, 87, 70, 216, 88, 140,
58, 193, 230, 39, 87, 174, 88, 81, 165, 25, 77, 72, 9, 148, 115, 208, 243, 197,
254, 79, 175, 192, 82, 99, 216, 177, 243, 29, 147, 147, 142, 167, 32, 193, 9, 185,
127, 32, 31, 202, 244, 151, 163, 254, 203, 114, 183, 28, 34, 128, 128, 164, 53, 133,
38, 232, 244, 17, 79, 132, 105, 42, 186, 31, 120, 1, 65, 231, 169, 57, 35, 102,
119, 11, 174, 82, 91, 128, 142, 99, 53, 140, 121, 170, 84, 203, 68, 6, 196, 47,
127, 244, 131, 204, 100, 180, 232, 78, 143, 148, 227, 186, 23, 207, 141, 117, 85, 48,
49, 69, 169, 163, 192, 95, 197, 94, 0, 113, 178, 36, 162, 48, 93, 131, 98, 42};
float correct_data[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 172, 47, 117, 192, 9, 211, 21, 242, 88, 140, 58, 193, 0, 0, 0, 0, 0, 0, 0, 0, 142, 167,
32, 193, 31, 202, 244, 151, 183, 28, 34, 128, 0, 0, 0, 0, 0, 0, 0, 0, 142, 99, 53, 140, 68,
6, 196, 47, 100, 180, 232, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 67, 251, 195, 103, 36, 87, 70, 216, 230, 39, 87, 174, 0, 0,
0, 0, 0, 0, 0, 0, 9, 185, 127, 32, 163, 254, 203, 114, 128, 164, 53, 133, 0, 0, 0, 0, 0,
0, 0, 0, 121, 170, 84, 203, 127, 244, 131, 204, 143, 148, 227, 186, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 88, 81, 165, 25, 115, 208,
243, 197, 82, 99, 216, 177, 0, 0, 0, 0, 0, 0, 0, 0, 38, 232, 244, 17, 186, 31, 120, 1, 35,
102, 119, 11, 0, 0, 0, 0, 0, 0, 0, 0, 23, 207, 141, 117, 169, 163, 192, 95, 178, 36, 162, 48,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 77, 72, 9, 148, 254, 79, 175, 192, 243, 29, 147, 147, 0, 0, 0, 0, 0, 0, 0, 0, 79,
132, 105, 42, 65, 231, 169, 57, 174, 82, 91, 128, 0, 0, 0, 0, 0, 0, 0, 0, 85, 48, 49, 69,
197, 94, 0, 113, 93, 131, 98, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NHWC;
test_main_space_to_batch_nd<float>(input_data, correct_data, input_shape, param, data_type, format);
}
TEST_F(TestSpaceToBatchNDOpenCL, Nc4HW4H2W2Pad2222) {
std::vector<int> input_shape{1, 6, 6, 4};
SpaceToBatchParameter *param = std::make_unique<SpaceToBatchParameter>().release();
if (param == nullptr) {
return;
}
param->block_sizes_[0] = 2;
param->block_sizes_[1] = 2;
param->paddings_[0] = 2;
param->paddings_[1] = 2;
param->paddings_[2] = 2;
param->paddings_[3] = 2;
float input_data[] = {172, 47, 117, 192, 67, 251, 195, 103, 9, 211, 21, 242, 36, 87, 70, 216, 88, 140,
58, 193, 230, 39, 87, 174, 88, 81, 165, 25, 77, 72, 9, 148, 115, 208, 243, 197,
254, 79, 175, 192, 82, 99, 216, 177, 243, 29, 147, 147, 142, 167, 32, 193, 9, 185,
127, 32, 31, 202, 244, 151, 163, 254, 203, 114, 183, 28, 34, 128, 128, 164, 53, 133,
38, 232, 244, 17, 79, 132, 105, 42, 186, 31, 120, 1, 65, 231, 169, 57, 35, 102,
119, 11, 174, 82, 91, 128, 142, 99, 53, 140, 121, 170, 84, 203, 68, 6, 196, 47,
127, 244, 131, 204, 100, 180, 232, 78, 143, 148, 227, 186, 23, 207, 141, 117, 85, 48,
49, 69, 169, 163, 192, 95, 197, 94, 0, 113, 178, 36, 162, 48, 93, 131, 98, 42};
float correct_data[] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 172, 47, 117, 192, 9, 211, 21, 242, 88, 140, 58, 193, 0, 0, 0, 0, 0, 0, 0, 0, 142, 167,
32, 193, 31, 202, 244, 151, 183, 28, 34, 128, 0, 0, 0, 0, 0, 0, 0, 0, 142, 99, 53, 140, 68,
6, 196, 47, 100, 180, 232, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 67, 251, 195, 103, 36, 87, 70, 216, 230, 39, 87, 174, 0, 0,
0, 0, 0, 0, 0, 0, 9, 185, 127, 32, 163, 254, 203, 114, 128, 164, 53, 133, 0, 0, 0, 0, 0,
0, 0, 0, 121, 170, 84, 203, 127, 244, 131, 204, 143, 148, 227, 186, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 88, 81, 165, 25, 115, 208,
243, 197, 82, 99, 216, 177, 0, 0, 0, 0, 0, 0, 0, 0, 38, 232, 244, 17, 186, 31, 120, 1, 35,
102, 119, 11, 0, 0, 0, 0, 0, 0, 0, 0, 23, 207, 141, 117, 169, 163, 192, 95, 178, 36, 162, 48,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 77, 72, 9, 148, 254, 79, 175, 192, 243, 29, 147, 147, 0, 0, 0, 0, 0, 0, 0, 0, 79,
132, 105, 42, 65, 231, 169, 57, 174, 82, 91, 128, 0, 0, 0, 0, 0, 0, 0, 0, 85, 48, 49, 69,
197, 94, 0, 113, 93, 131, 98, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0};
TypeId data_type = kNumberTypeFloat32;
schema::Format format = schema::Format_NCHW;
test_main_space_to_batch_nd<float>(input_data, correct_data, input_shape, param, data_type, format);
}
} // namespace mindspore