add fp16 kernel

sunsuodong 2021-11-28 01:27:32 -08:00
parent ab26378451
commit 533a6574b2
9 changed files with 108 additions and 132 deletions

View File

@@ -18,12 +18,18 @@
#include <string.h>
#include "nnacl/errorcode.h"
int DoScatterND(float *output_ptr, const float *update, int *output_unit_offsets, int unit_size, int num_units) {
if (output_ptr == NULL || update == NULL || output_unit_offsets == NULL || unit_size <= 0 || num_units < 0) {
int DoScatterND(void *output, const void *update, int *output_unit_offsets, ScatterNDParameter *param, int task_id) {
if (param->op_parameter.thread_num_ == 0) {
return NNACL_ERR;
}
for (int i = 0; i < num_units; i++) {
(void)memcpy(output_ptr + output_unit_offsets[i], update + unit_size * i, (size_t)(unit_size) * sizeof(float));
int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_);
int begin = unit_per_thread * task_id;
int end = MSMIN(begin + unit_per_thread, param->num_unit);
int data_type_len = param->data_type_len;
for (int i = begin; i < end; i++) {
(void)memcpy((int8_t *)output + output_unit_offsets[i] * data_type_len,
(int8_t *)update + i * param->unit_size * data_type_len, param->unit_size * data_type_len);
}
return NNACL_OK;
}
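The rewritten kernel above is byte-oriented: it no longer assumes float data, stepping through output and update in units of data_type_len bytes (2 for fp16, 4 for fp32) and splitting the num_unit copies evenly across threads. A minimal standalone sketch of the partitioning arithmetic, with the nnacl macros spelled out (the wrapper name scatter_slice is hypothetical):

#include <stdint.h>
#include <string.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y)) /* ceiling division */
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

/* Copy the units assigned to one task; task_id lies in [0, thread_num). */
static void scatter_slice(int8_t *output, const int8_t *update, const int *offsets,
                          int num_unit, int unit_size, int data_type_len,
                          int thread_num, int task_id) {
  int unit_per_thread = UP_DIV(num_unit, thread_num);
  int begin = unit_per_thread * task_id;
  int end = MSMIN(begin + unit_per_thread, num_unit); /* the last task may get fewer units */
  for (int i = begin; i < end; i++) {
    memcpy(output + offsets[i] * data_type_len,
           update + i * unit_size * data_type_len,
           (size_t)unit_size * data_type_len);
  }
}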

View File

@@ -19,10 +19,17 @@
#include "nnacl/op_base.h"
typedef struct ScatterNDParameter {
OpParameter op_parameter;
int num_unit;
int unit_size;
int data_type_len;
} ScatterNDParameter;
#ifdef __cplusplus
extern "C" {
#endif
int DoScatterND(float *output_ptr, const float *update, int *output_unit_offsets, int unit_size, int num_units);
int DoScatterND(void *output, const void *update, int *output_unit_offsets, ScatterNDParameter *param, int task_id);
#ifdef __cplusplus
}
#endif
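Since DoScatterND derives its slice from task_id and thread_num_, a caller hands it the whole problem plus a task index. A hypothetical serial driver, for illustration only (the runtime instead issues one call per task through ParallelLaunch; this assumes the header above and nnacl/errorcode.h are included):

static int RunScatterNDSerial(void *output, const void *update, int *offsets,
                              ScatterNDParameter *param) {
  for (int task_id = 0; task_id < param->op_parameter.thread_num_; task_id++) {
    int ret = DoScatterND(output, update, offsets, param, task_id);
    if (ret != NNACL_OK) {
      return ret; /* propagate the first failing slice */
    }
  }
  return NNACL_OK;
}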

View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32/scatter_nd_fp32.h"
using mindspore::schema::PrimitiveType_ScatterNd;
namespace mindspore {
@@ -22,14 +23,14 @@ OpParameter *PopulateScatterNDParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
auto *param = static_cast<ScatterNDParameter *>(malloc(sizeof(ScatterNDParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ScatterNDParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
memset(param, 0, sizeof(ScatterNDParameter));
param->type_ = primitive->value_type();
param->op_parameter.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_ScatterNd, PopulateScatterNDParameter, SCHEMA_CUR)
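The switch from sizeof(OpParameter) to sizeof(ScatterNDParameter) is not cosmetic: the kernels cast the allocated parameter to ScatterNDParameter and write num_unit, unit_size, and data_type_len, so the old, smaller allocation would now be a heap overflow. The round-trip cast back to OpParameter * is well-defined only because op_parameter is the first member; a small sketch of that layout invariant (struct bodies abbreviated to the fields used here):

#include <assert.h>
#include <stddef.h>

typedef struct OpParameter { int type_; int thread_num_; } OpParameter; /* simplified mirror */
typedef struct ScatterNDParameter {
  OpParameter op_parameter; /* must stay the first member */
  int num_unit;
  int unit_size;
  int data_type_len;
} ScatterNDParameter;

/* A ScatterNDParameter * and a pointer to its embedded OpParameter share the
 * same address, which is what makes the reinterpret casts above legal. */
static_assert(offsetof(ScatterNDParameter, op_parameter) == 0, "op_parameter must be first");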

View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32/scatter_nd_fp32.h"
using mindspore::schema::PrimitiveType_ScatterNdUpdate;
namespace mindspore {
@@ -22,14 +23,14 @@ OpParameter *PopulateScatterNDUpdateParameter(const void *prim) {
auto primitive = static_cast<const schema::Primitive *>(prim);
MS_ASSERT(primitive != nullptr);
auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
auto *param = reinterpret_cast<ScatterNDParameter *>(malloc(sizeof(ScatterNDParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ScatterNDParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(OpParameter));
memset(param, 0, sizeof(ScatterNDParameter));
param->type_ = primitive->value_type();
param->op_parameter.type_ = primitive->value_type();
return reinterpret_cast<OpParameter *>(param);
}
REG_POPULATE(PrimitiveType_ScatterNdUpdate, PopulateScatterNDUpdateParameter, SCHEMA_CUR)

View File

@@ -16,19 +16,20 @@
#include "schema/model_v0_generated.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/fp32/scatter_nd_fp32.h"
namespace mindspore {
namespace lite {
namespace {
OpParameter *PopulateScatterNDParameter(const void *prim) {
OpParameter *scatter_nd_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (scatter_nd_param == nullptr) {
auto *param = static_cast<ScatterNDParameter *>(malloc(sizeof(ScatterNDParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ScatterNDParameter failed.";
return nullptr;
}
memset(scatter_nd_param, 0, sizeof(OpParameter));
scatter_nd_param->type_ = schema::PrimitiveType_ScatterNd;
return reinterpret_cast<OpParameter *>(scatter_nd_param);
memset(param, 0, sizeof(ScatterNDParameter));
param->op_parameter.type_ = schema::PrimitiveType_ScatterNd;
return reinterpret_cast<OpParameter *>(param);
}
} // namespace

View File

@@ -43,19 +43,17 @@ int ScatterNDCPUKernel::Prepare() {
}
int ScatterNDCPUKernel::ReSize() {
auto shape = in_tensors_.at(kScatterShapeIndex);
auto indices = in_tensors_.at(kScatterIndicesIndex);
auto update = in_tensors_.at(kScatterUpdateIndex);
update_ptr_ = reinterpret_cast<float *>(update->MutableData());
MS_ASSERT(update_ptr_ != nullptr);
output_ptr_ = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
MS_ASSERT(output_ptr_ != nullptr);
auto indices = in_tensors_[kScatterIndicesIndex];
auto update = in_tensors_[kScatterUpdateIndex];
auto shape = in_tensors_[kScatterShapeIndex];
CHECK_NULL_RETURN(indices);
CHECK_NULL_RETURN(update);
CHECK_NULL_RETURN(shape);
// check indices shape
auto shape_rank = shape->ElementsNum();
auto shape_data = reinterpret_cast<int *>(shape->MutableData());
MS_ASSERT(shape_data != nullptr);
auto shape_data = reinterpret_cast<int *>(shape->data());
CHECK_NULL_RETURN(shape_data);
auto indice_unit_rank = indices->shape().back();
if (indice_unit_rank > shape_rank) {
MS_LOG(ERROR) << "Value of last dimension of indices is greater than shape rank.";
@@ -82,7 +80,7 @@ int ScatterNDCPUKernel::ReSize() {
return RET_ERROR;
}
}
for (size_t i = 0; i < shape->ElementsNum() - (indices_shape.size() - 1); i++) {
for (size_t i = 0; i < shape_rank - (indices_shape.size() - 1); i++) {
if (update_shape.at(i + indices_shape.size() - 1) != shape_data[i + indices_shape.size() - 1]) {
MS_LOG(ERROR) << "Value of " << i + indices_shape.size() - 1
<< " th dimension of indices is not equal to the corresbonding dimension of shape.";
@@ -90,82 +88,66 @@ int ScatterNDCPUKernel::ReSize() {
}
}
// calculate unit_size_
unit_size_ = 1;
// calculate unit_size
param_->unit_size = 1;
for (int i = indices_shape.size() - 1; i < update_rank; i++) {
unit_size_ *= update_shape.at(i);
param_->unit_size *= update_shape.at(i);
}
// calculate offsets
int out_stride = 1;
out_strides_.push_back(1);
std::vector<int> out_strides;
out_strides.push_back(1);
for (int i = indice_unit_rank - 2; i >= 0; i--) {
out_stride *= shape_data[i + 1];
out_strides_.push_back(out_stride);
out_strides.push_back(out_stride);
}
num_unit_ = 1;
num_unit_ *= update_shape.at(indices_shape.size() - 2);
param_->num_unit = 1;
param_->num_unit *= update_shape.at(indices_shape.size() - C2NUM);
for (int i = indices_shape.size() - 3; i >= 0; i--) {
num_unit_ *= update_shape.at(i);
param_->num_unit *= update_shape.at(i);
}
int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
int *indices_ptr = reinterpret_cast<int *>(indices->data());
CHECK_NULL_RETURN(indices_ptr);
output_unit_offsets_.clear();
for (int i = 0; i < num_unit_; i++) {
for (int i = 0; i < param_->num_unit; i++) {
int tmp_stride = 0;
for (int j = 0; j < indice_unit_rank; j++) {
tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_.at(j) * unit_size_;
tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size;
}
output_unit_offsets_.push_back(tmp_stride);
}
thread_n_num_ = MSMIN(op_parameter_->thread_num_, num_unit_);
if (thread_n_num_ == 0) {
return RET_ERROR;
}
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
return RET_OK;
}
int ScatterNDCPUKernel::ScatterND(int task_id) {
int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
if (num_unit_thread <= 0) {
return RET_OK;
}
int offset = task_id * thread_n_stride_;
MS_LOG(ERROR) << "offset " << offset;
auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset,
unit_size_, num_unit_thread);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ScatterND error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
void *update_data = in_tensors_[kScatterUpdateIndex]->data();
auto output_tensor = out_tensors_[kOutputIndex];
void *output_data = output_tensor->data();
CHECK_NULL_RETURN(update_data);
CHECK_NULL_RETURN(output_data);
param_->data_type_len = output_tensor->data_type() == kNumberTypeFloat16 ? FP16_DATA_TYPE_LEN : sizeof(float);
return DoScatterND(output_data, update_data, output_unit_offsets_.data(), param_, task_id);
}
int ScatterNDRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
CHECK_NULL_RETURN(cdata);
auto g_kernel = reinterpret_cast<ScatterNDCPUKernel *>(cdata);
MS_ASSERT(g_kernel != nullptr);
auto ret = g_kernel->ScatterND(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ScatterNDRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
auto kernel = static_cast<ScatterNDCPUKernel *>(cdata);
CHECK_NULL_RETURN(kernel);
return kernel->ScatterND(task_id);
}
int ScatterNDCPUKernel::Run() {
auto ret = ParallelLaunch(this->ms_context_, ScatterNDRun, this, thread_n_num_);
auto ret = ParallelLaunch(ms_context_, ScatterNDRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ScatterND error error_code[" << ret << "]";
return RET_ERROR;
MS_LOG(ERROR) << "ScatterNDRun failed, ret: " << ret;
}
return RET_OK;
return ret;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterNd, LiteKernelCreator<ScatterNDCPUKernel>)
#ifdef ENABLE_FP16
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ScatterNd, LiteKernelCreator<ScatterNDCPUKernel>)
#endif
} // namespace mindspore::kernel
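Registering the same LiteKernelCreator<ScatterNDCPUKernel> for both kNumberTypeFloat32 and kNumberTypeFloat16 works because the kernel body never dereferences typed data; per run it only picks an element width from the output tensor's type and lets the byte-wise DoScatterND do the rest. A standalone sketch of that selection (assuming FP16_DATA_TYPE_LEN is 2, i.e. sizeof(uint16_t); the is_fp16 flag stands in for the data_type() comparison):

#include <stdio.h>

#define FP16_DATA_TYPE_LEN 2 /* assumption: fp16 elements are two bytes */

static int DataTypeLen(int is_fp16) {
  return is_fp16 ? FP16_DATA_TYPE_LEN : (int)sizeof(float);
}

int main(void) {
  printf("fp16: %d bytes, fp32: %d bytes\n", DataTypeLen(1), DataTypeLen(0)); /* 2 and 4 */
  return 0;
}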

View File

@@ -22,12 +22,13 @@
#include "nnacl/fp32/scatter_nd_fp32.h"
namespace mindspore::kernel {
class ScatterNDCPUKernel : public InnerKernel {
public:
explicit ScatterNDCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) {}
: InnerKernel(parameter, inputs, outputs, ctx) {
param_ = reinterpret_cast<ScatterNDParameter *>(parameter);
}
~ScatterNDCPUKernel() override = default;
int Prepare() override;
@@ -36,13 +37,7 @@ class ScatterNDCPUKernel : public InnerKernel {
int ScatterND(int task_id);
private:
int thread_n_num_ = 1;
int thread_n_stride_ = 1;
int num_unit_ = 1;
int unit_size_ = 1;
float *output_ptr_ = nullptr;
float *update_ptr_ = nullptr;
std::vector<int> out_strides_;
ScatterNDParameter *param_ = nullptr;
std::vector<int> output_unit_offsets_;
};
} // namespace mindspore::kernel

View File

@@ -16,7 +16,6 @@
#include "src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h"
#include <cstring>
#include "src/runtime/kernel/arm/fp32/scatter_nd_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
@@ -48,9 +47,7 @@ int ScatterNdUpdateCPUKernel::ReSize() {
auto indices = in_tensors_.at(kScatterIndicesIndex);
auto update = in_tensors_.at(kScatterUpdateIndex);
auto output = out_tensors_.front();
output_ptr_ = reinterpret_cast<float *>(output->MutableData());
MS_ASSERT(output_ptr_ != nullptr);
output_ptr_ = output->data();
// check indices shape
int input_rank = static_cast<int>(input->shape().size());
@@ -69,55 +66,50 @@ int ScatterNdUpdateCPUKernel::ReSize() {
int update_rank = static_cast<int>(update->shape().size());
auto indices_shape = indices->shape();
auto update_shape = update->shape();
unit_size_ = 1;
param_->unit_size = 1;
for (int i = indices_shape.size() - 1; i < update_rank; i++) {
unit_size_ *= update_shape.at(i);
param_->unit_size *= update_shape.at(i);
}
// calculate offsets
int out_stride = 1;
out_strides_.push_back(1);
std::vector<int> out_strides;
out_strides.push_back(1);
for (int i = indice_unit_rank - 2; i >= 0; i--) {
out_stride *= input->shape()[i + 1];
out_strides_.push_back(out_stride);
out_strides.push_back(out_stride);
}
std::reverse(out_strides_.begin(), out_strides_.end());
std::reverse(out_strides.begin(), out_strides.end());
num_unit_ = 1;
num_unit_ *= update_shape.at(indices_shape.size() - 2);
param_->num_unit = 1;
param_->num_unit *= update_shape.at(indices_shape.size() - C2NUM);
for (int i = indices_shape.size() - 3; i >= 0; i--) {
num_unit_ *= update_shape.at(i);
param_->num_unit *= update_shape.at(i);
}
int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
MS_ASSERT(indices_ptr != nullptr);
output_unit_offsets_.clear();
for (int i = 0; i < num_unit_; i++) {
for (int i = 0; i < param_->num_unit; i++) {
int tmp_stride = 0;
for (int j = 0; j < indice_unit_rank; j++) {
tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_.at(j) * unit_size_;
tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size;
}
output_unit_offsets_.push_back(tmp_stride);
}
thread_n_num_ = MSMIN(op_parameter_->thread_num_, num_unit_);
if (thread_n_num_ == 0) {
return RET_ERROR;
}
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
return RET_OK;
}
int ScatterNdUpdateCPUKernel::ScatterNdUpdate(int task_id) {
int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
if (num_unit_thread <= 0) {
return RET_OK;
}
int offset = task_id * thread_n_stride_;
auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset,
unit_size_, num_unit_thread);
void *update_data = in_tensors_[kScatterUpdateIndex]->data();
auto output_tensor = out_tensors_[kOutputIndex];
void *output_data = output_tensor->data();
CHECK_NULL_RETURN(update_data);
CHECK_NULL_RETURN(output_data);
param_->data_type_len = output_tensor->data_type() == kNumberTypeFloat16 ? FP16_DATA_TYPE_LEN : sizeof(float);
auto ret = DoScatterND(output_data, update_data, output_unit_offsets_.data(), param_, task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ScatterNdUpdate error task_id[" << task_id << "] error_code[" << ret << "]";
MS_LOG(ERROR) << "DoScatterND failed, ret: " << ret;
return RET_ERROR;
}
in_tensors_.at(kScatterUpdateInputIndex)->IncRefCount();
@@ -125,14 +117,9 @@ int ScatterNdUpdateCPUKernel::ScatterNdUpdate(int task_id) {
}
int ScatterNdUpdateRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
auto g_kernel = reinterpret_cast<ScatterNdUpdateCPUKernel *>(cdata);
MS_ASSERT(g_kernel != nullptr);
auto ret = g_kernel->ScatterNdUpdate(task_id);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ScatterNdUpdateRun error task_id[" << task_id << "] error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
auto kernel = static_cast<ScatterNdUpdateCPUKernel *>(cdata);
CHECK_NULL_RETURN(kernel);
return kernel->ScatterNdUpdate(task_id);
}
int ScatterNdUpdateCPUKernel::Run() {
@@ -147,25 +134,23 @@ int ScatterNdUpdateCPUKernel::Run() {
in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count());
out_tensor->set_data(in_tensor->data());
out_tensor->set_own_data(in_tensor->own_data());
output_ptr_ = reinterpret_cast<float *>(out_tensor->data());
output_ptr_ = out_tensor->data();
}
auto indices = in_tensors_.at(kScatterIndicesIndex);
if (!indices->IsConst() && ReSize() != RET_OK) {
MS_LOG(ERROR) << "ScatterNdUpdate resize failed.";
return RET_ERROR;
}
auto update = in_tensors_.at(kScatterUpdateIndex);
update_ptr_ = reinterpret_cast<float *>(update->MutableData());
MS_ASSERT(update_ptr_ != nullptr);
auto ret = ParallelLaunch(this->ms_context_, ScatterNdUpdateRun, this, thread_n_num_);
auto ret = ParallelLaunch(ms_context_, ScatterNdUpdateRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ScatterNdUpdate error error_code[" << ret << "]";
return RET_ERROR;
}
return RET_OK;
return ret;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterNdUpdate, LiteKernelCreator<ScatterNdUpdateCPUKernel>)
#ifdef ENABLE_FP16
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ScatterNdUpdate, LiteKernelCreator<ScatterNdUpdateCPUKernel>)
#endif
} // namespace mindspore::kernel
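In both kernels, ReSize() flattens each row of the indices tensor into a single element offset using row-major strides (outermost stride first, hence the std::reverse above), so the hot loop reduces to memcpys; DoScatterND later scales those element offsets by data_type_len into byte offsets. A worked example of the offset arithmetic under hypothetical shapes (output shape [2, 3], indices rows (0, 1) and (1, 2), unit_size 1):

#include <stdio.h>

int main(void) {
  const int indice_unit_rank = 2, num_unit = 2, unit_size = 1;
  const int indices[] = {0, 1, 1, 2}; /* two index rows: (0, 1) and (1, 2) */
  const int out_strides[] = {3, 1};   /* strides of a [2, 3] output, outermost first */
  for (int i = 0; i < num_unit; i++) {
    int offset = 0;
    for (int j = 0; j < indice_unit_rank; j++) {
      offset += indices[i * indice_unit_rank + j] * out_strides[j] * unit_size;
    }
    printf("unit %d -> element offset %d\n", i, offset); /* prints 1, then 5 */
  }
  return 0;
}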

View File

@@ -19,6 +19,7 @@
#include <vector>
#include "src/inner_kernel.h"
#include "nnacl/fp32/scatter_nd_fp32.h"
namespace mindspore::kernel {
@@ -26,7 +27,9 @@ class ScatterNdUpdateCPUKernel : public InnerKernel {
public:
explicit ScatterNdUpdateCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: InnerKernel(parameter, inputs, outputs, ctx) {}
: InnerKernel(parameter, inputs, outputs, ctx) {
param_ = reinterpret_cast<ScatterNDParameter *>(parameter);
}
~ScatterNdUpdateCPUKernel() override = default;
int Prepare() override;
@@ -35,13 +38,8 @@ class ScatterNdUpdateCPUKernel : public InnerKernel {
int ScatterNdUpdate(int task_id);
private:
int thread_n_num_ = 1;
int thread_n_stride_ = 1;
int num_unit_ = 1;
int unit_size_ = 1;
float *output_ptr_ = nullptr;
float *update_ptr_ = nullptr;
std::vector<int> out_strides_;
ScatterNDParameter *param_ = nullptr;
void *output_ptr_ = nullptr;
std::vector<int> output_unit_offsets_;
};
} // namespace mindspore::kernel