forked from mindspore-Ecosystem/mindspore
add fp16 kernel
This commit is contained in:
parent
ab26378451
commit
533a6574b2
|
@ -18,12 +18,18 @@
|
|||
#include <string.h>
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int DoScatterND(float *output_ptr, const float *update, int *output_unit_offsets, int unit_size, int num_units) {
|
||||
if (output_ptr == NULL || update == NULL || output_unit_offsets == NULL || unit_size <= 0 || num_units < 0) {
|
||||
int DoScatterND(void *output, const void *update, int *output_unit_offsets, ScatterNDParameter *param, int task_id) {
|
||||
if (param->op_parameter.thread_num_ == 0) {
|
||||
return NNACL_ERR;
|
||||
}
|
||||
for (int i = 0; i < num_units; i++) {
|
||||
(void)memcpy(output_ptr + output_unit_offsets[i], update + unit_size * i, (size_t)(unit_size) * sizeof(float));
|
||||
int unit_per_thread = UP_DIV(param->num_unit, param->op_parameter.thread_num_);
|
||||
int begin = unit_per_thread * task_id;
|
||||
int end = MSMIN(begin + unit_per_thread, param->num_unit);
|
||||
|
||||
int data_type_len = param->data_type_len;
|
||||
for (int i = begin; i < end; i++) {
|
||||
(void)memcpy((int8_t *)output + output_unit_offsets[i] * data_type_len,
|
||||
(int8_t *)update + i * param->unit_size * data_type_len, param->unit_size * data_type_len);
|
||||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
|
|
@ -19,10 +19,17 @@
|
|||
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
// Parameters consumed by DoScatterND() for both the ScatterNd and ScatterNdUpdate kernels.
// The populate functions allocate this struct and return it as an OpParameter *, and the
// kernels cast that pointer back to ScatterNDParameter *.
typedef struct ScatterNDParameter {
|
||||
// Common operator header. It stays the first member because the struct is passed around
// as an OpParameter * and cast back. DoScatterND() reads op_parameter.thread_num_ to
// split the units between tasks.
OpParameter op_parameter;
|
||||
// Total number of units to copy; DoScatterND() divides the range [0, num_unit) across tasks.
int num_unit;
|
||||
// Number of elements in one unit; each copy moves unit_size * data_type_len bytes.
int unit_size;
|
||||
// Size of one element in bytes. The kernels set it to FP16_DATA_TYPE_LEN for float16
// output tensors and to sizeof(float) otherwise.
int data_type_len;
|
||||
} ScatterNDParameter;
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
int DoScatterND(float *output_ptr, const float *update, int *output_unit_offsets, int unit_size, int num_units);
|
||||
int DoScatterND(void *output, const void *update, int *output_unit_offsets, ScatterNDParameter *param, int task_id);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "nnacl/fp32/scatter_nd_fp32.h"
|
||||
using mindspore::schema::PrimitiveType_ScatterNd;
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -22,14 +23,14 @@ OpParameter *PopulateScatterNDParameter(const void *prim) {
|
|||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
|
||||
auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
||||
auto *param = static_cast<ScatterNDParameter *>(malloc(sizeof(ScatterNDParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ScatterNDParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(OpParameter));
|
||||
memset(param, 0, sizeof(ScatterNDParameter));
|
||||
|
||||
param->type_ = primitive->value_type();
|
||||
param->op_parameter.type_ = primitive->value_type();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
REG_POPULATE(PrimitiveType_ScatterNd, PopulateScatterNDParameter, SCHEMA_CUR)
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "nnacl/fp32/scatter_nd_fp32.h"
|
||||
using mindspore::schema::PrimitiveType_ScatterNdUpdate;
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -22,14 +23,14 @@ OpParameter *PopulateScatterNDUpdateParameter(const void *prim) {
|
|||
auto primitive = static_cast<const schema::Primitive *>(prim);
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
|
||||
auto *param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
||||
auto *param = reinterpret_cast<ScatterNDParameter *>(malloc(sizeof(ScatterNDParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ScatterNDParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(param, 0, sizeof(OpParameter));
|
||||
memset(param, 0, sizeof(ScatterNDParameter));
|
||||
|
||||
param->type_ = primitive->value_type();
|
||||
param->op_parameter.type_ = primitive->value_type();
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
REG_POPULATE(PrimitiveType_ScatterNdUpdate, PopulateScatterNDUpdateParameter, SCHEMA_CUR)
|
||||
|
|
|
@ -16,19 +16,20 @@
|
|||
|
||||
#include "schema/model_v0_generated.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "nnacl/fp32/scatter_nd_fp32.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
OpParameter *PopulateScatterNDParameter(const void *prim) {
|
||||
OpParameter *scatter_nd_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
|
||||
if (scatter_nd_param == nullptr) {
|
||||
auto *param = static_cast<ScatterNDParameter *>(malloc(sizeof(ScatterNDParameter)));
|
||||
if (param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc ScatterNDParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(scatter_nd_param, 0, sizeof(OpParameter));
|
||||
scatter_nd_param->type_ = schema::PrimitiveType_ScatterNd;
|
||||
return reinterpret_cast<OpParameter *>(scatter_nd_param);
|
||||
memset(param, 0, sizeof(ScatterNDParameter));
|
||||
param->op_parameter.type_ = schema::PrimitiveType_ScatterNd;
|
||||
return reinterpret_cast<OpParameter *>(param);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
|
|
@ -43,19 +43,17 @@ int ScatterNDCPUKernel::Prepare() {
|
|||
}
|
||||
|
||||
int ScatterNDCPUKernel::ReSize() {
|
||||
auto shape = in_tensors_.at(kScatterShapeIndex);
|
||||
auto indices = in_tensors_.at(kScatterIndicesIndex);
|
||||
auto update = in_tensors_.at(kScatterUpdateIndex);
|
||||
|
||||
update_ptr_ = reinterpret_cast<float *>(update->MutableData());
|
||||
MS_ASSERT(update_ptr_ != nullptr);
|
||||
output_ptr_ = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
|
||||
MS_ASSERT(output_ptr_ != nullptr);
|
||||
auto indices = in_tensors_[kScatterIndicesIndex];
|
||||
auto update = in_tensors_[kScatterUpdateIndex];
|
||||
auto shape = in_tensors_[kScatterShapeIndex];
|
||||
CHECK_NULL_RETURN(indices);
|
||||
CHECK_NULL_RETURN(update);
|
||||
CHECK_NULL_RETURN(shape);
|
||||
|
||||
// check indices shape
|
||||
auto shape_rank = shape->ElementsNum();
|
||||
auto shape_data = reinterpret_cast<int *>(shape->MutableData());
|
||||
MS_ASSERT(shape_data != nullptr);
|
||||
auto shape_data = reinterpret_cast<int *>(shape->data());
|
||||
CHECK_NULL_RETURN(shape_data);
|
||||
auto indice_unit_rank = indices->shape().back();
|
||||
if (indice_unit_rank > shape_rank) {
|
||||
MS_LOG(ERROR) << "Value of last dimension of indices is greater than shape rank.";
|
||||
|
@ -82,7 +80,7 @@ int ScatterNDCPUKernel::ReSize() {
|
|||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < shape->ElementsNum() - (indices_shape.size() - 1); i++) {
|
||||
for (size_t i = 0; i < shape_rank - (indices_shape.size() - 1); i++) {
|
||||
if (update_shape.at(i + indices_shape.size() - 1) != shape_data[i + indices_shape.size() - 1]) {
|
||||
MS_LOG(ERROR) << "Value of " << i + indices_shape.size() - 1
|
||||
<< " th dimension of indices is not equal to the corresbonding dimension of shape.";
|
||||
|
@ -90,82 +88,66 @@ int ScatterNDCPUKernel::ReSize() {
|
|||
}
|
||||
}
|
||||
|
||||
// calculate unit_size_
|
||||
unit_size_ = 1;
|
||||
// calculate unit_size
|
||||
param_->unit_size = 1;
|
||||
for (int i = indices_shape.size() - 1; i < update_rank; i++) {
|
||||
unit_size_ *= update_shape.at(i);
|
||||
param_->unit_size *= update_shape.at(i);
|
||||
}
|
||||
|
||||
// calculate offsets
|
||||
int out_stride = 1;
|
||||
out_strides_.push_back(1);
|
||||
std::vector<int> out_strides;
|
||||
out_strides.push_back(1);
|
||||
for (int i = indice_unit_rank - 2; i >= 0; i--) {
|
||||
out_stride *= shape_data[i + 1];
|
||||
out_strides_.push_back(out_stride);
|
||||
out_strides.push_back(out_stride);
|
||||
}
|
||||
|
||||
num_unit_ = 1;
|
||||
num_unit_ *= update_shape.at(indices_shape.size() - 2);
|
||||
param_->num_unit = 1;
|
||||
param_->num_unit *= update_shape.at(indices_shape.size() - C2NUM);
|
||||
for (int i = indices_shape.size() - 3; i >= 0; i--) {
|
||||
num_unit_ *= update_shape.at(i);
|
||||
param_->num_unit *= update_shape.at(i);
|
||||
}
|
||||
|
||||
int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
|
||||
int *indices_ptr = reinterpret_cast<int *>(indices->data());
|
||||
CHECK_NULL_RETURN(indices_ptr);
|
||||
output_unit_offsets_.clear();
|
||||
for (int i = 0; i < num_unit_; i++) {
|
||||
for (int i = 0; i < param_->num_unit; i++) {
|
||||
int tmp_stride = 0;
|
||||
for (int j = 0; j < indice_unit_rank; j++) {
|
||||
tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_.at(j) * unit_size_;
|
||||
tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size;
|
||||
}
|
||||
output_unit_offsets_.push_back(tmp_stride);
|
||||
}
|
||||
|
||||
thread_n_num_ = MSMIN(op_parameter_->thread_num_, num_unit_);
|
||||
if (thread_n_num_ == 0) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ScatterNDCPUKernel::ScatterND(int task_id) {
|
||||
int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
|
||||
if (num_unit_thread <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
int offset = task_id * thread_n_stride_;
|
||||
MS_LOG(ERROR) << "offset " << offset;
|
||||
auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset,
|
||||
unit_size_, num_unit_thread);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ScatterND error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
void *update_data = in_tensors_[kScatterUpdateIndex]->data();
|
||||
auto output_tensor = out_tensors_[kOutputIndex];
|
||||
void *output_data = output_tensor->data();
|
||||
CHECK_NULL_RETURN(update_data);
|
||||
CHECK_NULL_RETURN(output_data);
|
||||
param_->data_type_len = output_tensor->data_type() == kNumberTypeFloat16 ? FP16_DATA_TYPE_LEN : sizeof(float);
|
||||
return DoScatterND(output_data, update_data, output_unit_offsets_.data(), param_, task_id);
|
||||
}
|
||||
|
||||
int ScatterNDRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||
CHECK_NULL_RETURN(cdata);
|
||||
auto g_kernel = reinterpret_cast<ScatterNDCPUKernel *>(cdata);
|
||||
MS_ASSERT(g_kernel != nullptr);
|
||||
auto ret = g_kernel->ScatterND(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ScatterNDRun error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
auto kernel = static_cast<ScatterNDCPUKernel *>(cdata);
|
||||
CHECK_NULL_RETURN(kernel);
|
||||
return kernel->ScatterND(task_id);
|
||||
}
|
||||
|
||||
int ScatterNDCPUKernel::Run() {
|
||||
auto ret = ParallelLaunch(this->ms_context_, ScatterNDRun, this, thread_n_num_);
|
||||
auto ret = ParallelLaunch(ms_context_, ScatterNDRun, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ScatterND error error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
MS_LOG(ERROR) << "ScatterNDRun failed, ret: " << ret;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
return ret;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterNd, LiteKernelCreator<ScatterNDCPUKernel>)
|
||||
#ifdef ENABLE_FP16
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ScatterNd, LiteKernelCreator<ScatterNDCPUKernel>)
|
||||
#endif
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -22,12 +22,13 @@
|
|||
#include "nnacl/fp32/scatter_nd_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
class ScatterNDCPUKernel : public InnerKernel {
|
||||
public:
|
||||
explicit ScatterNDCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||
: InnerKernel(parameter, inputs, outputs, ctx) {}
|
||||
: InnerKernel(parameter, inputs, outputs, ctx) {
|
||||
param_ = reinterpret_cast<ScatterNDParameter *>(parameter);
|
||||
}
|
||||
~ScatterNDCPUKernel() override = default;
|
||||
|
||||
int Prepare() override;
|
||||
|
@ -36,13 +37,7 @@ class ScatterNDCPUKernel : public InnerKernel {
|
|||
int ScatterND(int task_id);
|
||||
|
||||
private:
|
||||
int thread_n_num_ = 1;
|
||||
int thread_n_stride_ = 1;
|
||||
int num_unit_ = 1;
|
||||
int unit_size_ = 1;
|
||||
float *output_ptr_ = nullptr;
|
||||
float *update_ptr_ = nullptr;
|
||||
std::vector<int> out_strides_;
|
||||
ScatterNDParameter *param_ = nullptr;
|
||||
std::vector<int> output_unit_offsets_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include "src/runtime/kernel/arm/fp32/scatter_nd_update_fp32.h"
|
||||
#include <cstring>
|
||||
#include "src/runtime/kernel/arm/fp32/scatter_nd_fp32.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
|
@ -48,9 +47,7 @@ int ScatterNdUpdateCPUKernel::ReSize() {
|
|||
auto indices = in_tensors_.at(kScatterIndicesIndex);
|
||||
auto update = in_tensors_.at(kScatterUpdateIndex);
|
||||
auto output = out_tensors_.front();
|
||||
|
||||
output_ptr_ = reinterpret_cast<float *>(output->MutableData());
|
||||
MS_ASSERT(output_ptr_ != nullptr);
|
||||
output_ptr_ = output->data();
|
||||
|
||||
// check indices shape
|
||||
int input_rank = static_cast<int>(input->shape().size());
|
||||
|
@ -69,55 +66,50 @@ int ScatterNdUpdateCPUKernel::ReSize() {
|
|||
int update_rank = static_cast<int>(update->shape().size());
|
||||
auto indices_shape = indices->shape();
|
||||
auto update_shape = update->shape();
|
||||
unit_size_ = 1;
|
||||
param_->unit_size = 1;
|
||||
for (int i = indices_shape.size() - 1; i < update_rank; i++) {
|
||||
unit_size_ *= update_shape.at(i);
|
||||
param_->unit_size *= update_shape.at(i);
|
||||
}
|
||||
|
||||
// calculate offsets
|
||||
int out_stride = 1;
|
||||
out_strides_.push_back(1);
|
||||
std::vector<int> out_strides;
|
||||
out_strides.push_back(1);
|
||||
for (int i = indice_unit_rank - 2; i >= 0; i--) {
|
||||
out_stride *= input->shape()[i + 1];
|
||||
out_strides_.push_back(out_stride);
|
||||
out_strides.push_back(out_stride);
|
||||
}
|
||||
std::reverse(out_strides_.begin(), out_strides_.end());
|
||||
std::reverse(out_strides.begin(), out_strides.end());
|
||||
|
||||
num_unit_ = 1;
|
||||
num_unit_ *= update_shape.at(indices_shape.size() - 2);
|
||||
param_->num_unit = 1;
|
||||
param_->num_unit *= update_shape.at(indices_shape.size() - C2NUM);
|
||||
for (int i = indices_shape.size() - 3; i >= 0; i--) {
|
||||
num_unit_ *= update_shape.at(i);
|
||||
param_->num_unit *= update_shape.at(i);
|
||||
}
|
||||
|
||||
int *indices_ptr = reinterpret_cast<int *>(indices->MutableData());
|
||||
MS_ASSERT(indices_ptr != nullptr);
|
||||
output_unit_offsets_.clear();
|
||||
for (int i = 0; i < num_unit_; i++) {
|
||||
for (int i = 0; i < param_->num_unit; i++) {
|
||||
int tmp_stride = 0;
|
||||
for (int j = 0; j < indice_unit_rank; j++) {
|
||||
tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides_.at(j) * unit_size_;
|
||||
tmp_stride += indices_ptr[i * indice_unit_rank + j] * out_strides.at(j) * param_->unit_size;
|
||||
}
|
||||
output_unit_offsets_.push_back(tmp_stride);
|
||||
}
|
||||
|
||||
thread_n_num_ = MSMIN(op_parameter_->thread_num_, num_unit_);
|
||||
if (thread_n_num_ == 0) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
thread_n_stride_ = UP_DIV(num_unit_, thread_n_num_);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int ScatterNdUpdateCPUKernel::ScatterNdUpdate(int task_id) {
|
||||
int num_unit_thread = MSMIN(thread_n_stride_, num_unit_ - task_id * thread_n_stride_);
|
||||
if (num_unit_thread <= 0) {
|
||||
return RET_OK;
|
||||
}
|
||||
int offset = task_id * thread_n_stride_;
|
||||
auto ret = DoScatterND(output_ptr_, update_ptr_ + offset * unit_size_, output_unit_offsets_.data() + offset,
|
||||
unit_size_, num_unit_thread);
|
||||
void *update_data = in_tensors_[kScatterUpdateIndex]->data();
|
||||
auto output_tensor = out_tensors_[kOutputIndex];
|
||||
void *output_data = output_tensor->data();
|
||||
CHECK_NULL_RETURN(update_data);
|
||||
CHECK_NULL_RETURN(output_data);
|
||||
param_->data_type_len = output_tensor->data_type() == kNumberTypeFloat16 ? FP16_DATA_TYPE_LEN : sizeof(float);
|
||||
auto ret = DoScatterND(output_data, update_data, output_unit_offsets_.data(), param_, task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ScatterNdUpdate error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
MS_LOG(ERROR) << "DoScatterND failed, ret: " << ret;
|
||||
return RET_ERROR;
|
||||
}
|
||||
in_tensors_.at(kScatterUpdateInputIndex)->IncRefCount();
|
||||
|
@ -125,14 +117,9 @@ int ScatterNdUpdateCPUKernel::ScatterNdUpdate(int task_id) {
|
|||
}
|
||||
|
||||
int ScatterNdUpdateRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
||||
auto g_kernel = reinterpret_cast<ScatterNdUpdateCPUKernel *>(cdata);
|
||||
MS_ASSERT(g_kernel != nullptr);
|
||||
auto ret = g_kernel->ScatterNdUpdate(task_id);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ScatterNdUpdateRun error task_id[" << task_id << "] error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
auto kernel = static_cast<ScatterNdUpdateCPUKernel *>(cdata);
|
||||
CHECK_NULL_RETURN(kernel);
|
||||
return kernel->ScatterNdUpdate(task_id);
|
||||
}
|
||||
|
||||
int ScatterNdUpdateCPUKernel::Run() {
|
||||
|
@ -147,25 +134,23 @@ int ScatterNdUpdateCPUKernel::Run() {
|
|||
in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count());
|
||||
out_tensor->set_data(in_tensor->data());
|
||||
out_tensor->set_own_data(in_tensor->own_data());
|
||||
output_ptr_ = reinterpret_cast<float *>(out_tensor->data());
|
||||
output_ptr_ = out_tensor->data();
|
||||
}
|
||||
auto indices = in_tensors_.at(kScatterIndicesIndex);
|
||||
if (!indices->IsConst() && ReSize() != RET_OK) {
|
||||
MS_LOG(ERROR) << "ScatterNdUpdate resize failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto update = in_tensors_.at(kScatterUpdateIndex);
|
||||
update_ptr_ = reinterpret_cast<float *>(update->MutableData());
|
||||
MS_ASSERT(update_ptr_ != nullptr);
|
||||
|
||||
auto ret = ParallelLaunch(this->ms_context_, ScatterNdUpdateRun, this, thread_n_num_);
|
||||
auto ret = ParallelLaunch(ms_context_, ScatterNdUpdateRun, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "ScatterNdUpdate error error_code[" << ret << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
return ret;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ScatterNdUpdate, LiteKernelCreator<ScatterNdUpdateCPUKernel>)
|
||||
#ifdef ENABLE_FP16
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ScatterNdUpdate, LiteKernelCreator<ScatterNdUpdateCPUKernel>)
|
||||
#endif
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include "src/inner_kernel.h"
|
||||
#include "nnacl/fp32/scatter_nd_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
|
@ -26,7 +27,9 @@ class ScatterNdUpdateCPUKernel : public InnerKernel {
|
|||
public:
|
||||
explicit ScatterNdUpdateCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||
: InnerKernel(parameter, inputs, outputs, ctx) {}
|
||||
: InnerKernel(parameter, inputs, outputs, ctx) {
|
||||
param_ = reinterpret_cast<ScatterNDParameter *>(parameter);
|
||||
}
|
||||
~ScatterNdUpdateCPUKernel() override = default;
|
||||
|
||||
int Prepare() override;
|
||||
|
@ -35,13 +38,8 @@ class ScatterNdUpdateCPUKernel : public InnerKernel {
|
|||
int ScatterNdUpdate(int task_id);
|
||||
|
||||
private:
|
||||
int thread_n_num_ = 1;
|
||||
int thread_n_stride_ = 1;
|
||||
int num_unit_ = 1;
|
||||
int unit_size_ = 1;
|
||||
float *output_ptr_ = nullptr;
|
||||
float *update_ptr_ = nullptr;
|
||||
std::vector<int> out_strides_;
|
||||
ScatterNDParameter *param_ = nullptr;
|
||||
void *output_ptr_ = nullptr;
|
||||
std::vector<int> output_unit_offsets_;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
|
Loading…
Reference in New Issue