forked from mindspore-Ecosystem/mindspore
reduce mean fp16
This commit is contained in:
parent
5f27ff4afe
commit
5787dd642d
|
@ -196,4 +196,5 @@ kernel::LiteKernel *CpuReduceInt8KernelCreator(const std::vector<lite::tensor::T
|
|||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Reduce, CpuReduceFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mean, CpuMeanFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Reduce, CpuReduceInt8KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Mean, CpuReduceInt8KernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -0,0 +1,216 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp16/reduce_fp16.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "nnacl/fp16/reduce_fp16.h"
|
||||
#include "src/runtime/kernel/arm/base/reduce_base.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_NULL_PTR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_Mean;
|
||||
using mindspore::schema::PrimitiveType_Reduce;
|
||||
using mindspore::schema::ReduceMode;
|
||||
using mindspore::schema::ReduceMode_ReduceMax;
|
||||
using mindspore::schema::ReduceMode_ReduceMean;
|
||||
using mindspore::schema::ReduceMode_ReduceMin;
|
||||
using mindspore::schema::ReduceMode_ReduceProd;
|
||||
using mindspore::schema::ReduceMode_ReduceSum;
|
||||
using mindspore::schema::ReduceMode_ReduceSumSquare;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int ReduceFp16CPUKernel::Init() {
|
||||
auto ret = ReduceBaseCPUKernel::Init();
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
if (mode_ != static_cast<int>(ReduceMode_ReduceMean)) {
|
||||
MS_LOG(ERROR) << "Reduce fp16 only support ReduceMode_ReduceMean";
|
||||
return RET_ERROR;
|
||||
}
|
||||
reducer_ = ReduceMean;
|
||||
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int ReduceFp16CPUKernel::ReSize() {
|
||||
if (fp16_input_ != nullptr) {
|
||||
free(fp16_input_);
|
||||
fp16_input_ = nullptr;
|
||||
}
|
||||
auto ele_num = in_tensors_.at(0)->ElementsNum();
|
||||
fp16_input_ = reinterpret_cast<float16_t *>(malloc(sizeof(float16_t) * ele_num));
|
||||
if (fp16_input_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc fp16_src_data_ falied";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return MallocTmpBuffer();
|
||||
}
|
||||
|
||||
int ReduceFp16CPUKernel::CallReduceUnit(int task_id) {
|
||||
auto ret = reducer_(outer_size_, inner_size_, axis_size_, fp16_src_data_, tmp_shape_.data(), fp16_dst_data_, task_id,
|
||||
context_->thread_num_);
|
||||
return ret;
|
||||
}
|
||||
|
||||
int ReduceImpl(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
|
||||
auto reduce = reinterpret_cast<ReduceFp16CPUKernel *>(cdata);
|
||||
auto error_code = reduce->CallReduceUnit(task_id);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "Reduce Run error task_id[" << task_id << "] error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ReduceFp16CPUKernel::Run() {
|
||||
auto prepare_ret = Prepare();
|
||||
if (prepare_ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
|
||||
return prepare_ret;
|
||||
}
|
||||
tmp_shape_ = in_tensors_.at(0)->shape();
|
||||
src_data_ = static_cast<float *>(in_tensors_.at(0)->Data());
|
||||
auto ele_num = in_tensors_.at(0)->ElementsNum();
|
||||
Float32ToFloat16(src_data_, fp16_input_, ele_num);
|
||||
fp16_src_data_ = fp16_input_;
|
||||
for (int i = 0; i < data_buffers_.size(); ++i) {
|
||||
fp16_dst_data_ = data_buffers_[i];
|
||||
int axis = axes_[i];
|
||||
outer_size_ = 1;
|
||||
for (int j = 0; j < axis; j++) {
|
||||
outer_size_ *= tmp_shape_[j];
|
||||
}
|
||||
inner_size_ = 1;
|
||||
for (int k = axis + 1; k < static_cast<int>(tmp_shape_.size()); k++) {
|
||||
inner_size_ *= tmp_shape_[k];
|
||||
}
|
||||
axis_size_ = tmp_shape_[axis];
|
||||
auto error_code = LiteBackendParallelLaunch(ReduceImpl, this, context_->thread_num_);
|
||||
if (error_code != RET_OK) {
|
||||
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
tmp_shape_[axis] = 1;
|
||||
fp16_src_data_ = fp16_dst_data_;
|
||||
}
|
||||
|
||||
dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
|
||||
Float16ToFloat32(fp16_dst_data_, dst_data_, out_tensors_.at(0)->ElementsNum());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ReduceFp16CPUKernel::MallocTmpBuffer() {
|
||||
for (auto buffer : data_buffers_) {
|
||||
if (buffer != nullptr) {
|
||||
free(buffer);
|
||||
buffer = nullptr;
|
||||
}
|
||||
}
|
||||
data_buffers_.clear();
|
||||
|
||||
auto input_shape = in_tensors_.at(0)->shape();
|
||||
for (auto i = 0; i < num_axes_; i++) {
|
||||
int axis = axes_[i];
|
||||
size_t size = 1;
|
||||
for (auto j = 0; j < input_shape.size(); j++) {
|
||||
if (static_cast<size_t>(axis) != j) {
|
||||
size *= input_shape[j];
|
||||
}
|
||||
}
|
||||
float16_t *buffer = reinterpret_cast<float16_t *>(malloc(size * sizeof(float16_t)));
|
||||
if (buffer == nullptr) {
|
||||
MS_LOG(ERROR) << "Malloc data failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
data_buffers_.emplace_back(buffer);
|
||||
input_shape[axis] = 1;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuReduceFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_Reduce);
|
||||
if (opParameter == nullptr) {
|
||||
MS_LOG(ERROR) << "Reduce opParameter nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
if (desc.type != schema::PrimitiveType_Reduce) {
|
||||
MS_LOG(ERROR) << "Reduce op desc.type should be PrimitiveType_Reduce, got " << desc.type;
|
||||
return nullptr;
|
||||
}
|
||||
auto *kernel = new (std::nothrow) ReduceFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "Reduce new ReduceCPUKernel failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuMeanFp16KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||
OpParameter *opParameter, const lite::Context *ctx,
|
||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||
MS_ASSERT(opParameter != nullptr);
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_Mean);
|
||||
if (opParameter == nullptr) {
|
||||
MS_LOG(ERROR) << "Reduce opParameter nullptr";
|
||||
return nullptr;
|
||||
}
|
||||
if (desc.type != schema::PrimitiveType_Mean) {
|
||||
MS_LOG(ERROR) << "Reduce op desc.type should be PrimitiveType_Mean, got " << desc.type;
|
||||
return nullptr;
|
||||
}
|
||||
auto *kernel = new (std::nothrow) ReduceFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "Reduce new ReduceCPUKernel failed.";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Reduce, CpuReduceFp16KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mean, CpuMeanFp16KernelCreator)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_REDUCE_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_REDUCE_FP16_H_
|
||||
|
||||
#include <arm_neon.h>
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/reduce_base.h"
|
||||
#include "ir/anf.h"
|
||||
using mindspore::schema::ReduceMode;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class ReduceFp16CPUKernel : public ReduceBaseCPUKernel {
|
||||
typedef int (*Reducer)(const int outer_size, const int inner_size, const int axis_size, const float16_t *src_data,
|
||||
const int *src_shape, float16_t *dst_data, const int tid, const int thread_num);
|
||||
|
||||
public:
|
||||
ReduceFp16CPUKernel(OpParameter *param, const std::vector<lite::tensor::Tensor *> &inputs,
|
||||
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
|
||||
const lite::Primitive *primitive)
|
||||
: ReduceBaseCPUKernel(param, inputs, outputs, ctx, primitive) {}
|
||||
~ReduceFp16CPUKernel() {
|
||||
for (auto i = 0; i < data_buffers_.size(); i++) {
|
||||
float16_t *buffer = data_buffers_[i];
|
||||
if (buffer != nullptr) {
|
||||
free(buffer);
|
||||
buffer = nullptr;
|
||||
}
|
||||
}
|
||||
if (fp16_input_ != nullptr) {
|
||||
free(fp16_input_);
|
||||
fp16_input_ = nullptr;
|
||||
}
|
||||
src_data_ = nullptr;
|
||||
dst_data_ = nullptr;
|
||||
}
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
int CallReduceUnit(int task_id);
|
||||
|
||||
private:
|
||||
Reducer reducer_ = nullptr;
|
||||
std::vector<float16_t *> data_buffers_;
|
||||
const float *src_data_ = nullptr;
|
||||
float *dst_data_ = nullptr;
|
||||
float16_t *fp16_input_ = nullptr;
|
||||
const float16_t *fp16_src_data_ = nullptr;
|
||||
float16_t *fp16_dst_data_ = nullptr;
|
||||
|
||||
private:
|
||||
int MallocTmpBuffer();
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_REDUCE_FP16_H_
|
|
@ -19,7 +19,7 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/fp32/reduce.h"
|
||||
#include "nnacl/fp32/reduce.h"
|
||||
#include "src/runtime/kernel/arm/base/reduce_base.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <float.h>
|
||||
#include "nnacl/fp16/reduce_fp16.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int ReduceMean(const int outer_size, const int inner_size, const int axis_size, const float16_t *src_data,
|
||||
const int *src_shape, float16_t *dst_data, const int tid, const int thread_num) {
|
||||
if (src_data == NULL || src_shape == NULL || dst_data == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
int i, j, k;
|
||||
for (j = tid; j < outer_size; j += thread_num) {
|
||||
const float16_t *outer_src = src_data + j * axis_size * inner_size;
|
||||
float16_t *outer_dst = dst_data + j * inner_size;
|
||||
for (k = 0; k < inner_size; k++) {
|
||||
const float16_t *inner_src = outer_src + k;
|
||||
float16_t *inner_dst = outer_dst + k;
|
||||
float16_t tmp = 0.0;
|
||||
for (i = 0; i < axis_size; i++) {
|
||||
tmp += inner_src[i * inner_size];
|
||||
}
|
||||
*inner_dst = tmp / (float16_t)axis_size;
|
||||
}
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_REDUCE_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_REDUCE_FP16_H_
|
||||
#include "nnacl/op_base.h"
|
||||
#include "src/runtime/kernel/arm/nnacl/reduce_parameter.h"
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int ReduceMean(const int outer_size, const int inner_size, const int axis_size, const float16_t *src_data,
|
||||
const int *src_shape, float16_t *dst_data, const int tid, const int thread_num);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_FP16_REDUCE_FP16_H_
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <memory>
|
||||
#include "utils/log_adapter.h"
|
||||
#include "common/common_test.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "mindspore/lite/src/runtime/kernel/arm/fp16/reduce_fp16.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
||||
class TestReduceFp16 : public mindspore::CommonTest {
|
||||
public:
|
||||
TestReduceFp16() = default;
|
||||
void Prepare(const std::vector<int> &input_shape, const std::vector<int> &output_shape, float *input_data,
|
||||
float *output_data, const int num_axis, const int *axes, const int thread_num);
|
||||
|
||||
void TearDown() override;
|
||||
|
||||
public:
|
||||
float err_tol = 1e-5;
|
||||
lite::tensor::Tensor in_tensor_;
|
||||
lite::tensor::Tensor out_tensor_;
|
||||
std::vector<lite::tensor::Tensor *> inputs_{&in_tensor_};
|
||||
std::vector<lite::tensor::Tensor *> outputs_{&out_tensor_};
|
||||
ReduceParameter param_ = {{}};
|
||||
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat16, schema::PrimitiveType_Reduce};
|
||||
lite::Context ctx_ = lite::Context();
|
||||
kernel::KernelCreator creator_ = nullptr;
|
||||
kernel::LiteKernel *kernel_ = nullptr;
|
||||
};
|
||||
|
||||
void TestReduceFp16::TearDown() {
|
||||
in_tensor_.SetData(nullptr);
|
||||
out_tensor_.SetData(nullptr);
|
||||
}
|
||||
|
||||
void TestReduceFp16::Prepare(const std::vector<int> &input_shape, const std::vector<int> &output_shape,
|
||||
float *input_data, float *output_data, const int num_axis, const int *axes,
|
||||
const int thread_num) {
|
||||
in_tensor_.set_data_type(kNumberTypeFloat32);
|
||||
in_tensor_.set_shape(input_shape);
|
||||
out_tensor_.set_data_type(kNumberTypeFloat32);
|
||||
out_tensor_.set_shape(output_shape);
|
||||
in_tensor_.SetData(input_data);
|
||||
out_tensor_.SetData(output_data);
|
||||
|
||||
bool keep_axis = false;
|
||||
|
||||
int mode = static_cast<int>(schema::ReduceMode_ReduceMean);
|
||||
ReduceParameter param_ = {{}};
|
||||
param_.keep_dims_ = keep_axis;
|
||||
for (auto i = 0; i < num_axis; i++) {
|
||||
param_.axes_[i] = axes[i];
|
||||
}
|
||||
param_.num_axes_ = num_axis;
|
||||
param_.mode_ = mode;
|
||||
|
||||
desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat16, schema::PrimitiveType_Reduce};
|
||||
ctx_ = lite::Context();
|
||||
ctx_.thread_num_ = thread_num;
|
||||
creator_ = lite::KernelRegistry::GetInstance()->GetCreator(desc);
|
||||
ASSERT_NE(creator_, nullptr);
|
||||
kernel_ = creator_(inputs_, outputs_, reinterpret_cast<OpParameter *>(¶m_), &ctx_, desc, nullptr);
|
||||
ASSERT_NE(kernel_, nullptr);
|
||||
}
|
||||
TEST_F(TestReduceFp16, Mean) {
|
||||
float in[96] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
|
||||
16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0,
|
||||
32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0,
|
||||
48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0,
|
||||
64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0,
|
||||
80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0};
|
||||
float out[24] = {0.0f};
|
||||
float correct[24] = {18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
|
||||
66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0};
|
||||
|
||||
std::vector<int> input_shape = {2, 4, 4, 3};
|
||||
std::vector<int> output_shape = {2, 1, 4, 3};
|
||||
|
||||
int axes[] = {3};
|
||||
int num_axis = 1;
|
||||
int thread_num = 1;
|
||||
Prepare(input_shape, output_shape, in, out, num_axis, axes, thread_num);
|
||||
CompareOutputData(out, correct, 24, 1e-3);
|
||||
}
|
||||
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue