From af40b005fb64d4a90ab96e524967e13839869eb4 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Wed, 21 Oct 2020 11:36:59 +0800 Subject: [PATCH] onehot support 3 inputs --- .../src/runtime/kernel/arm/fp32/one_hot.cc | 60 ++++++++++++------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc index ef48273ab69..bf10b5e0796 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/one_hot.cc @@ -31,14 +31,15 @@ using mindspore::schema::PrimitiveType_OneHot; namespace mindspore::kernel { namespace { constexpr size_t kInputNum = 4; +constexpr size_t kInputNumOpt = 3; constexpr size_t kOutputNum = 1; } // namespace int OneHotCPUKernel::Init() { // indices depth on_value off_value - if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) { - MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << in_tensors_.size() - << ", output size should be" << kOutputNum << ", got " << out_tensors_.size(); + if ((in_tensors_.size() != kInputNum && in_tensors_.size() != kInputNumOpt) || out_tensors_.size() != kOutputNum) { + MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << " or " << kInputNumOpt << ", got " + << in_tensors_.size() << ", output size should be" << kOutputNum << ", got " << out_tensors_.size(); return RET_ERROR; } if (context_ == nullptr) { @@ -132,27 +133,42 @@ int OneHotCPUKernel::GetParams() { } one_hot_param->depth_ = *depth; - auto on_value_tensor = in_tensors_.at(2); - if (on_value_tensor == nullptr) { - MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr"; - return RET_NULL_PTR; - } - const float *on_value = static_cast(on_value_tensor->MutableData()); - if (on_value == nullptr) { - return RET_NULL_PTR; - } - one_hot_param->on_value_ = *on_value; + if (in_tensors_.size() == kInputNum) { + auto on_value_tensor = in_tensors_.at(2); + if (on_value_tensor == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr"; + return RET_NULL_PTR; + } + const float *on_value = static_cast(on_value_tensor->MutableData()); + if (on_value == nullptr) { + return RET_NULL_PTR; + } + one_hot_param->on_value_ = *on_value; - auto off_value_tensor = in_tensors_.at(3); - if (off_value_tensor == nullptr) { - MS_LOG(ERROR) << "OneHot inputs[3] off_value nullptr"; - return RET_NULL_PTR; + auto off_value_tensor = in_tensors_.at(3); + if (off_value_tensor == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[3] off_value nullptr"; + return RET_NULL_PTR; + } + const float *off_value = static_cast(off_value_tensor->MutableData()); + if (off_value == nullptr) { + return RET_NULL_PTR; + } + one_hot_param->off_value_ = *off_value; + } else { + auto off_on_tensor = in_tensors_.at(2); + if (off_on_tensor == nullptr) { + MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr"; + return RET_NULL_PTR; + } + const int64_t *off_on_values = static_cast(off_on_tensor->MutableData()); + if (off_on_values == nullptr) { + MS_LOG(ERROR) << "OneHot input[2] data is nullptr"; + return RET_NULL_PTR; + } + one_hot_param->off_value_ = static_cast(off_on_values[0]); + one_hot_param->on_value_ = static_cast(off_on_values[1]); } - const float *off_value = static_cast(off_value_tensor->MutableData()); - if (off_value == nullptr) { - return RET_NULL_PTR; - } - one_hot_param->off_value_ = *off_value; one_hot_param->outer_size_ = outer_size_; one_hot_param->inner_size_ = inner_size_;