forked from mindspore-Ecosystem/mindspore
commit c4a1ab80a3
nnacl/fp16/one_hot_fp16.c (new file):
@@ -0,0 +1,50 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/fp16/one_hot_fp16.h"
#include "nnacl/errorcode.h"
int OneHotToFp16(const int *indices, float16_t on_value, float16_t off_value, float16_t *output,
                 const OneHotParameter *one_hot_param, const int tid, const int thread_num) {
  if (indices == NULL || one_hot_param == NULL || output == NULL) {
    return NNACL_NULL_PTR;
  }
  if (thread_num == 0) {
    return NNACL_PARAM_INVALID;
  }

  int outer_size = one_hot_param->outer_size_;
  int inner_size = one_hot_param->inner_size_;
  int depth = one_hot_param->depth_;
  int i, j, k;
  for (i = tid; i < outer_size; i += thread_num) {
    float16_t *output_ptr = output + i * depth * inner_size;
    for (k = 0; k < depth; k++) {  // output layout: outer_size * depth * inner_size
      const int *indices_ptr = indices + i * inner_size;
      for (j = 0; j < inner_size; j++) {
        *output_ptr = off_value;
        int index = *(indices_ptr++);
        if (one_hot_param->support_neg_index_ && index < 0) {
          index += depth;
        }
        if (index == k) {
          *output_ptr = on_value;
        }
        output_ptr++;
      }
    }
  }
  return NNACL_OK;
}
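For orientation, here is a minimal single-threaded usage sketch of the new kernel. This is not part of the patch: it assumes an ARM toolchain where float16_t is available and the nnacl headers are on the include path, and the shapes and values are invented for illustration.

```c
#include <stdio.h>
#include "nnacl/fp16/one_hot_fp16.h"
#include "nnacl/errorcode.h"

int main(void) {
  // One-hot encode 4 scalar indices with depth 3: output holds 4 * 3 * 1 values.
  const int indices[4] = {0, 2, -1, 1};
  float16_t output[12];

  OneHotParameter param = {0};
  param.depth_ = 3;
  param.outer_size_ = 4;            // product of indices dims before the one-hot axis
  param.inner_size_ = 1;            // product of indices dims from the axis onward
  param.support_neg_index_ = true;  // -1 wraps to depth - 1 == 2

  // tid = 0 with thread_num = 1: a single "thread" covers every outer slice.
  int ret = OneHotToFp16(indices, (float16_t)1.0f, (float16_t)0.0f, output, &param, 0, 1);
  if (ret != NNACL_OK) {
    return 1;
  }
  for (int i = 0; i < 12; ++i) {
    printf("%.0f ", (float)output[i]);  // prints: 1 0 0  0 0 1  0 0 1  0 1 0
  }
  return 0;
}
```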
nnacl/fp16/one_hot_fp16.h (new file):
@@ -0,0 +1,35 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NNACL_FP16_ONE_HOT_H_
#define MINDSPORE_NNACL_FP16_ONE_HOT_H_

#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
#include "nnacl/one_hot_parameter.h"

#ifdef __cplusplus
extern "C" {
#endif
int OneHotToFp16(const int *indices, float16_t on_value, float16_t off_value, float16_t *output,
                 const OneHotParameter *one_hot_param, const int tid, const int thread_num);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_NNACL_FP16_ONE_HOT_H_
nnacl/fp32/one_hot_fp32.c:
@@ -17,8 +17,8 @@
 #include "nnacl/fp32/one_hot_fp32.h"
 #include "nnacl/errorcode.h"
 
-int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_param, const int tid,
-           const int thread_num) {
+int OneHotToFp32(const int *indices, float on_value, float off_value, float *output,
+                 const OneHotParameter *one_hot_param, const int tid, const int thread_num) {
   if (indices == NULL || one_hot_param == NULL || output == NULL) {
     return NNACL_NULL_PTR;
   }
@@ -29,8 +29,6 @@ int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_par
   int outer_size = one_hot_param->outer_size_;
   int inner_size = one_hot_param->inner_size_;
   int depth = one_hot_param->depth_;
-  float on_value = one_hot_param->on_value_;
-  float off_value = one_hot_param->off_value_;
   int i, j, k;
   for (i = tid; i < outer_size; i += thread_num) {
     float *output_ptr = output + i * depth * inner_size;
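The rename is behavioral as well as cosmetic: on_value and off_value leave OneHotParameter and become explicit arguments, so every call site changes shape, and the fp16 kernel can take float16_t values without widening the shared struct. A hedged before/after sketch of a call site (the wrapper function and values are invented for illustration):

```c
#include "nnacl/fp32/one_hot_fp32.h"

/* Before this patch the values traveled inside the parameter struct:
 *   param.on_value_ = 1.0f;
 *   param.off_value_ = 0.0f;
 *   OneHot(indices, output, &param, tid, thread_num);
 */

// After this patch the values are passed explicitly at the call site.
int CallSiteAfterPatch(const int *indices, float *output, const OneHotParameter *param,
                       int tid, int thread_num) {
  return OneHotToFp32(indices, /*on_value=*/1.0f, /*off_value=*/0.0f, output, param, tid, thread_num);
}
```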
nnacl/fp32/one_hot_fp32.h:
@@ -21,25 +21,13 @@
 #include <arm_neon.h>
 #endif
 #include "nnacl/op_base.h"
-
-typedef struct OneHotParameter {
-  // Primitive parameter
-  OpParameter op_parameter_;
-  int axis_;
-  // other parameter
-  int depth_;
-  float on_value_;
-  float off_value_;
-  int outer_size_;
-  int inner_size_;
-  bool support_neg_index_;  // if true, support neg index in indices tensor; if false, set off_value on neg index.
-} OneHotParameter;
+#include "nnacl/one_hot_parameter.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif
-int OneHot(const int *indices, float *output, const OneHotParameter *one_hot_param, const int tid,
-           const int thread_num);
+int OneHotToFp32(const int *indices, float on_value, float off_value, float *output,
+                 const OneHotParameter *one_hot_param, const int tid, const int thread_num);
 #ifdef __cplusplus
 }
 #endif
nnacl/one_hot_parameter.h (new file):
@@ -0,0 +1,32 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_NNACL_ONE_HOT_PARAMETER_H_
#define MINDSPORE_NNACL_ONE_HOT_PARAMETER_H_
#include "nnacl/op_base.h"

typedef struct OneHotParameter {
  // Primitive parameter
  OpParameter op_parameter_;
  int axis_;
  // other parameter
  int depth_;
  int outer_size_;
  int inner_size_;
  bool support_neg_index_;  // if true, support neg index in indices tensor; if false, set off_value on neg index.
} OneHotParameter;

#endif  // MINDSPORE_NNACL_ONE_HOT_PARAMETER_H_
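The diff never shows where outer_size_ and inner_size_ come from (that happens in the kernel's shape-preparation code, outside this patch), but the indexing in OneHotToFp32/OneHotToFp16 implies the usual convention: dims of the indices tensor before the one-hot axis multiply into outer_size_, dims from the axis onward into inner_size_. A hedged sketch; FillOneHotSizes is invented for illustration, not part of the patch:

```c
#include "nnacl/one_hot_parameter.h"

// Illustrative helper: for indices with shape dims[0..ndim) and the one-hot
// axis inserted at `axis`, the kernels address the output as
// outer_size_ * depth_ * inner_size_ elements.
static void FillOneHotSizes(OneHotParameter *param, const int *dims, int ndim, int axis) {
  int outer = 1;
  int inner = 1;
  for (int d = 0; d < axis; ++d) {
    outer *= dims[d];  // dims before the axis
  }
  for (int d = axis; d < ndim; ++d) {
    inner *= dims[d];  // dims from the axis onward
  }
  param->outer_size_ = outer;
  param->inner_size_ = inner;
}
```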
@@ -74,7 +74,7 @@ Status ModelImpl::Build(const std::string &model_path, ModelType model_type,
   auto session = std::shared_ptr<session::LiteSession>(lite::LiteSession::CreateSession(model_path, &lite_context));
   if (session == nullptr) {
     MS_LOG(ERROR) << "Allocate session failed.";
-    return kLiteNullptr;
+    return kLiteError;
   }
 
   session_.swap(session);
src/runtime/kernel/arm/base/one_hot_base.cc:
@@ -14,8 +14,11 @@
  * limitations under the License.
  */
 
-#include "src/runtime/kernel/arm/fp32/one_hot_fp32.h"
+#include "src/runtime/kernel/arm/base/one_hot_base.h"
 #include "nnacl/fp32/one_hot_fp32.h"
+#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+#include "nnacl/fp16/one_hot_fp16.h"
+#endif
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
@@ -96,25 +99,31 @@ int RunOneHot(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
 }
 
 int OneHotCPUKernel::OneHotImpl(int task_id) {
-  auto indices_data = static_cast<int *>(in_tensors_.at(0)->MutableData());
+  auto indices_data = static_cast<int *>(in_tensors_.at(0)->data());
   auto output = out_tensors_.at(0);
   if (output == nullptr) {
     MS_LOG(ERROR) << "OneHot output nullptr";
     return RET_NULL_PTR;
   }
-  auto output_data = static_cast<float *>(output->MutableData());
-
-  auto ret = GetParams();
-  if (ret != RET_OK) {
-    return ret;
-  }
   auto one_hot_param = reinterpret_cast<OneHotParameter *>(op_parameter_);
 
-  ret = OneHot(indices_data, output_data, one_hot_param, task_id, thread_num_);
-  return ret;
+  if (output->data_type() == kNumberTypeFloat32) {
+    auto output_data = reinterpret_cast<float *>(output->data());
+    auto ret = OneHotToFp32(indices_data, on_value_, off_value_, output_data, one_hot_param, task_id, thread_num_);
+    return ret;
+#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+  } else if (output->data_type() == kNumberTypeFloat16) {
+    auto output_data = reinterpret_cast<float16_t *>(output->data());
+    auto ret = OneHotToFp16(indices_data, on_value_, off_value_, output_data, one_hot_param, task_id, thread_num_);
+    return ret;
+#endif
+  } else {
+    MS_LOG(ERROR) << "OneHot output datatype is unsupported: " << output->data_type();
+    return RET_ERROR;
+  }
 }
 
-int OneHotCPUKernel::GetParams() {
+int OneHotCPUKernel::InitParamsAndOnOffValue() {
   auto one_hot_param = reinterpret_cast<OneHotParameter *>(op_parameter_);
   if (one_hot_param == nullptr) {
     MS_LOG(ERROR) << "cast OneHotParameter nullptr";
@@ -126,7 +135,7 @@ int OneHotCPUKernel::GetParams() {
     MS_LOG(ERROR) << "OneHot inputs[1] depth nullptr";
     return RET_NULL_PTR;
   }
-  const int *depth = static_cast<int *>(depth_tensor->MutableData());
+  const int *depth = reinterpret_cast<int *>(depth_tensor->MutableData());
   if (depth == nullptr) {
     return RET_NULL_PTR;
   }
@@ -135,42 +144,19 @@ int OneHotCPUKernel::GetParams() {
   if (in_tensors_.size() == kInputNum) {
     // 4 inputs: indices, depth, on_value, off_value
     one_hot_param->support_neg_index_ = false;
-    auto on_value_tensor = in_tensors_.at(2);
-    if (on_value_tensor == nullptr) {
-      MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr";
+    auto ret = InitOnOffValueForFourInputs();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Init on off value failed";
       return RET_NULL_PTR;
     }
-    const float *on_value = static_cast<float *>(on_value_tensor->MutableData());
-    if (on_value == nullptr) {
-      return RET_NULL_PTR;
-    }
-    one_hot_param->on_value_ = *on_value;
-
-    auto off_value_tensor = in_tensors_.at(3);
-    if (off_value_tensor == nullptr) {
-      MS_LOG(ERROR) << "OneHot inputs[3] off_value nullptr";
-      return RET_NULL_PTR;
-    }
-    const float *off_value = static_cast<float *>(off_value_tensor->MutableData());
-    if (off_value == nullptr) {
-      return RET_NULL_PTR;
-    }
-    one_hot_param->off_value_ = *off_value;
   } else {
     // 3 inputs: indices, depth, off_on_value
     one_hot_param->support_neg_index_ = true;
-    auto off_on_tensor = in_tensors_.at(2);
-    if (off_on_tensor == nullptr) {
-      MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr";
+    auto ret = InitOnOffValueForThreeInputs();
+    if (ret != RET_OK) {
+      MS_LOG(ERROR) << "Init on off value failed";
       return RET_NULL_PTR;
     }
-    const float *off_on_values = static_cast<float *>(off_on_tensor->MutableData());  // need to support int type
-    if (off_on_values == nullptr) {
-      MS_LOG(ERROR) << "OneHot input[2] data is nullptr";
-      return RET_NULL_PTR;
-    }
-    one_hot_param->off_value_ = static_cast<float>(off_on_values[0]);
-    one_hot_param->on_value_ = static_cast<float>(off_on_values[1]);
   }
 
   one_hot_param->outer_size_ = outer_size_;
@@ -179,7 +165,94 @@ int OneHotCPUKernel::GetParams() {
   return RET_OK;
 }
 
+int OneHotCPUKernel::InitOnOffValueForFourInputs() {
+  auto on_value_tensor = in_tensors_.at(2);
+  if (on_value_tensor == nullptr) {
+    MS_LOG(ERROR) << "OneHot on_value tensor is nullptr";
+    return RET_NULL_PTR;
+  }
+  if (on_value_tensor->data_type() == kNumberTypeFloat32) {
+    const auto *on_value = reinterpret_cast<float *>(on_value_tensor->data());
+    if (on_value == nullptr) {
+      return RET_NULL_PTR;
+    }
+    this->on_value_ = *on_value;
+#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+  } else if (on_value_tensor->data_type() == kNumberTypeFloat16) {
+    const auto *on_value = reinterpret_cast<float16_t *>(on_value_tensor->data());
+    if (on_value == nullptr) {
+      return RET_NULL_PTR;
+    }
+    this->on_value_ = *on_value;
+#endif
+  } else {
+    MS_LOG(ERROR) << "OneHot on value datatype is unsupported: " << on_value_tensor->data_type();
+    return RET_NULL_PTR;
+  }
+
+  auto off_value_tensor = in_tensors_.at(3);
+  if (off_value_tensor == nullptr) {
+    MS_LOG(ERROR) << "OneHot off_value tensor is nullptr";
+    return RET_NULL_PTR;
+  }
+
+  if (off_value_tensor->data_type() == kNumberTypeFloat32) {
+    const auto *off_value = reinterpret_cast<float *>(off_value_tensor->data());
+    if (off_value == nullptr) {
+      return RET_NULL_PTR;
+    }
+    this->off_value_ = *off_value;
+#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+  } else if (off_value_tensor->data_type() == kNumberTypeFloat16) {
+    const auto *off_value = reinterpret_cast<float16_t *>(off_value_tensor->data());
+    if (off_value == nullptr) {
+      return RET_NULL_PTR;
+    }
+    this->off_value_ = *off_value;
+#endif
+  } else {
+    MS_LOG(ERROR) << "OneHot off value datatype is unsupported: " << off_value_tensor->data_type();
+    return RET_NULL_PTR;
+  }
+  return RET_OK;
+}
+
+int OneHotCPUKernel::InitOnOffValueForThreeInputs() {
+  auto off_on_tensor = in_tensors_.at(2);
+  if (off_on_tensor == nullptr) {
+    MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr";
+    return RET_NULL_PTR;
+  }
+
+  if (off_on_tensor->data_type() == kNumberTypeFloat32) {
+    const auto *off_on_values = reinterpret_cast<float *>(off_on_tensor->data());
+    if (off_on_values == nullptr) {
+      return RET_NULL_PTR;
+    }
+    this->off_value_ = off_on_values[0];
+    this->on_value_ = off_on_values[1];
+#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+  } else if (off_on_tensor->data_type() == kNumberTypeFloat16) {
+    const auto *off_on_values = reinterpret_cast<float16_t *>(off_on_tensor->data());
+    if (off_on_values == nullptr) {
+      return RET_NULL_PTR;
+    }
+    this->off_value_ = off_on_values[0];
+    this->on_value_ = off_on_values[1];
+#endif
+  } else {
+    MS_LOG(ERROR) << "OneHot off value datatype is unsupported: " << off_on_tensor->data_type();
+    return RET_NULL_PTR;
+  }
+  return RET_OK;
+}
+
 int OneHotCPUKernel::Run() {
+  auto ret = InitParamsAndOnOffValue();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "OneHot init param failed:" << ret;
+    return ret;
+  }
   int error_code = ParallelLaunch(this->ms_context_, RunOneHot, this, op_parameter_->thread_num_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "OneHot function error error_code[" << error_code << "]";
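The split makes the two op flavors easier to see: with four inputs (indices, depth, on_value, off_value) negative indices are left to fall through to off_value, while the three-input form (indices, depth, off_on_value) packs {off_value, on_value} into one tensor, in that order, and wraps negative indices. A minimal C sketch of the three-input unpacking feeding the fp32 kernel; the driver function is invented for illustration, single-threaded so tid = 0 and thread_num = 1:

```c
#include "nnacl/fp32/one_hot_fp32.h"

// Hypothetical driver: off_on_values packs {off_value, on_value}, in that order.
int RunThreeInputOneHotFp32(const int *indices, const float *off_on_values, float *output,
                            OneHotParameter *param) {
  param->support_neg_index_ = true;  // the 3-input form wraps index -1 to depth - 1
  float off_value = off_on_values[0];
  float on_value = off_on_values[1];
  // The kernel normally fans this out across threads via ParallelLaunch;
  // tid = 0 with thread_num = 1 covers all outer slices in one call.
  return OneHotToFp32(indices, on_value, off_value, output, param, 0, 1);
}
```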
src/runtime/kernel/arm/base/one_hot_base.h:
@@ -34,13 +34,22 @@ class OneHotCPUKernel : public InnerKernel {
   int OneHotImpl(int task_id);
 
  private:
-  int GetParams();
+  int InitParamsAndOnOffValue();
+  int InitOnOffValueForThreeInputs();
+  int InitOnOffValueForFourInputs();
 
  private:
   int thread_num_ = 1;
   int axis_ = 0;
   int outer_size_ = 0;
   int inner_size_ = 0;
+#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
+  float16_t on_value_ = 0.;
+  float16_t off_value_ = 0.;
+#else
+  float on_value_ = 0.;
+  float off_value_ = 0.;
+#endif
 };
 }  // namespace mindspore::kernel
@@ -15,8 +15,10 @@
 */
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
#ifdef ENABLE_FP16
#include "nnacl/fp16/cast_fp16.h"
#endif
#endif
#include "nnacl/nnacl_common.h"

#ifdef __cplusplus