forked from mindspore-Ecosystem/mindspore
!11682 [ms][lite][cpu] add tensorlist and bias_add fp16 ops
From: @lzkcode
commit 645507ba1b

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,8 +19,14 @@
 #include "nnacl/common_func.h"
 #include "nnacl/nnacl_utils.h"
 
-void TileOneDimensionFp16(float16_t *inData, float16_t *outData, int dim, size_t ndim, int *inShape, int *inStrides,
-                          int *outStrides, int *multiple) {
+int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1,
+                     float16_t *out, int size, ArithmeticParameter *param) {
+  TileDimensionsFp16(in0, in1, tile_in0, tile_in1, param);
+  return ElementAddFp16(tile_in0, tile_in1, out, size);
+}
+
+void TileOneDimensionFp16(const float16_t *inData, float16_t *outData, int dim, size_t ndim, const int *inShape,
+                          const int *inStrides, const int *outStrides, const int *multiple) {
   int srcDimSize = inShape[dim];
   if (dim == ndim - 1) {
     for (int i = 0; i < multiple[dim]; i++) {
@@ -37,7 +43,7 @@ void TileOneDimensionFp16(float16_t *inData, float16_t *outData, int dim, size_t
   }
 }
 
-void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
+void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
                         ArithmeticParameter *param) {
   CalcMultiplesAndStrides(param);
   TileOneDimensionFp16(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_,
@@ -219,6 +225,12 @@ int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int
     float16x8_t vout = vaddq_f16(vin0, vin1);
     vst1q_f16(output + index, vout);
   }
+  for (; index <= element_size - 4; index += C4NUM) {
+    float16x4_t vin0 = vld1_f16(input0 + index);
+    float16x4_t vin1 = vld1_f16(input1 + index);
+    float16x4_t vout = vadd_f16(vin0, vin1);
+    vst1_f16(output + index, vout);
+  }
 #endif
   for (; index < element_size; index++) {
     output[index] = input0[index] + input1[index];
@@ -270,6 +282,14 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
     vout = vmaxq_f16(vout, zeros);
     vst1q_f16(output + index, vout);
   }
+  float16x4_t zeros1 = vdup_n_f16(0.0f);
+  for (; index <= element_size - 4; index += C4NUM) {
+    float16x4_t vin0 = vld1_f16(input0 + index);
+    float16x4_t vin1 = vld1_f16(input1 + index);
+    float16x4_t vout = vadd_f16(vin0, vin1);
+    vout = vmax_f16(vout, zeros1);
+    vst1_f16(output + index, vout);
+  }
 #endif
   for (; index < element_size; index++) {
     float16_t res = input0[index] + input1[index];
@@ -328,6 +348,15 @@ int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
     vout = vminq_f16(vmaxq_f16(vout, zeros), bounds);
     vst1q_f16(output + index, vout);
   }
+  float16x4_t zeros1 = vdup_n_f16(0.0);
+  float16x4_t bounds1 = vdup_n_f16(6.0);
+  for (; index <= element_size - 4; index += C4NUM) {
+    float16x4_t vin0 = vld1_f16(input0 + index);
+    float16x4_t vin1 = vld1_f16(input1 + index);
+    float16x4_t vout = vadd_f16(vin0, vin1);
+    vout = vmin_f16(vmax_f16(vout, zeros1), bounds1);
+    vst1_f16(output + index, vout);
+  }
 #endif
   for (; index < element_size; index++) {
     output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6);

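Note on the ElementAdd hunks above: the existing NEON path handled eight fp16 lanes per iteration and then fell straight back to scalar code, so a tail of four to seven elements ran element by element. The added loops insert a four-lane pass between the two. A minimal standalone sketch of the resulting three-stage pattern — my naming, assuming an ARMv8.2-A target with the fp16 vector extension, not the nnacl code itself:

#include <arm_neon.h>

/* Stage 1: 8 fp16 lanes per step; stage 2: one 4-lane pass for a medium
 * tail; stage 3: scalar loop for the last 0-3 elements. */
void element_add_fp16_sketch(const float16_t *a, const float16_t *b, float16_t *out, int n) {
  int i = 0;
  for (; i <= n - 8; i += 8) {
    vst1q_f16(out + i, vaddq_f16(vld1q_f16(a + i), vld1q_f16(b + i)));
  }
  for (; i <= n - 4; i += 4) {
    vst1_f16(out + i, vadd_f16(vld1_f16(a + i), vld1_f16(b + i)));
  }
  for (; i < n; i++) {
    out[i] = a[i] + b[i];
  }
}
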
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,6 +26,12 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
+
+void TileOneDimensionFp16(const float16_t *inData, float16_t *outData, int dim, size_t ndim, const int *inShape,
+                          const int *inStrides, const int *outStrides, const int *multiple);
+void TileDimensionsFp16(const float16_t *data0, const float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
+                        ArithmeticParameter *param);
+
 int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
                       ArithmeticParameter *param);
 int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
@@ -84,6 +90,8 @@ int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
 int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
 int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
 int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
+int BroadcastAddFp16(const float16_t *in0, const float16_t *in1, float16_t *tile_in0, float16_t *tile_in1,
+                     float16_t *out, int size, ArithmeticParameter *param);
 
 int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
 int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
@@ -111,8 +119,6 @@ int ElementLessEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output,
 int ElementGreaterFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
 int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, uint8_t *output, int element_size);
 
-void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
-                        ArithmeticParameter *param);
 #ifdef __cplusplus
 }
 #endif

@@ -125,11 +125,6 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
   MS_ASSERT(inputs_.at(1) != nullptr);
   MS_ASSERT(inputs_.at(2) != nullptr);
   auto input0 = reinterpret_cast<TensorList *>(inputs_.at(0));
-  if (input0->tensors_data_type() != GetElementDType()) {
-    MS_LOG(ERROR) << "op dtype: " << GetElementDType()
-                  << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type();
-    return RET_ERROR;
-  }
   auto get_index = inputs_.at(1);
   MS_ASSERT(get_index != nullptr);
   if (get_index->ElementsNum() != 1) {
@@ -184,7 +179,7 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
       MS_LOG(ERROR) << "element_shape_ is not fullyDefined!";
       return RET_ERROR;
     }
-    output->set_data_type(GetElementDType());
+    output->set_data_type(input0->data_type());
     output->set_shape(element_shape_);
   }
   output->set_format(input0->GetTensor(index_)->format());

@@ -0,0 +1,106 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+#include "include/errorcode.h"
+#include "schema/model_generated.h"
+#include "src/runtime/kernel/arm/fp16/bias_fp16.h"
+#include "src/kernel_registry.h"
+
+using mindspore::kernel::KERNEL_ARCH::kCPU;
+using mindspore::lite::KernelRegistrar;
+using mindspore::lite::RET_ERROR;
+using mindspore::lite::RET_NULL_PTR;
+using mindspore::lite::RET_OK;
+using mindspore::schema::PrimitiveType_BiasAdd;
+
+namespace mindspore::kernel {
+
+int BiasCPUFp16Kernel::ReSize() {
+  auto dims = in_tensors_.at(0)->shape();
+  bias_param_->ndim_ = dims.size();
+  if (bias_param_->ndim_ < 1 || bias_param_->ndim_ > 5) {
+    MS_LOG(ERROR) << "input shape is invalid";
+    return RET_ERROR;
+  }
+  for (size_t i = 0; i < bias_param_->ndim_; i++) {
+    bias_param_->in_shape0_[i] = dims[i];
+    bias_param_->in_shape1_[i] = 1;
+    bias_param_->out_shape_[i] = dims[i];
+  }
+  bias_param_->in_shape1_[bias_param_->ndim_ - 1] = dims[bias_param_->ndim_ - 1];
+  return RET_OK;
+}
+
+int BiasCPUFp16Kernel::Run() {
+  auto in = reinterpret_cast<float16_t *>(in_tensors_.at(0)->MutableData());
+  auto out = reinterpret_cast<float16_t *>(out_tensors_.at(0)->MutableData());
+  size_t data_size = in_tensors_.at(0)->ElementsNum();
+  MS_ASSERT(context_->allocator != nullptr);
+  auto *tile_in = reinterpret_cast<float16_t *>(context_->allocator->Malloc(data_size * sizeof(float16_t)));
+  auto *tile_bias = reinterpret_cast<float16_t *>(context_->allocator->Malloc(data_size * sizeof(float16_t)));
+  if (tile_in == nullptr || tile_bias == nullptr) {
+    MS_LOG(ERROR) << "Memory allocation failed";
+    context_->allocator->Free(tile_in);
+    context_->allocator->Free(tile_bias);
+    return RET_NULL_PTR;
+  }
+  BroadcastAddFp16(in, bias_data_, tile_in, tile_bias, out, data_size, bias_param_);
+  context_->allocator->Free(tile_in);
+  context_->allocator->Free(tile_bias);
+  return RET_OK;
+}
+
+BiasCPUFp16Kernel::~BiasCPUFp16Kernel() {
+  if ((bias_data_type_ == kNumberTypeFloat || bias_data_type_ == kNumberTypeFloat32) && bias_data_ != nullptr) {
+    free(bias_data_);
+    bias_data_ = nullptr;
+  }
+}
+
+int BiasCPUFp16Kernel::Init() {
+  auto bias_tensor = in_tensors_.at(1);
+  MS_ASSERT(bias_tensor != nullptr);
+  bias_data_type_ = bias_tensor->data_type();
+  if (bias_data_type_ == kNumberTypeFloat || bias_data_type_ == kNumberTypeFloat32) {
+    bias_data_ = reinterpret_cast<float16_t *>(malloc(bias_tensor->ElementsNum() * sizeof(float16_t)));
+    if (bias_data_ == nullptr) {
+      MS_LOG(ERROR) << "bias_data_ is nullptr";
+      return RET_NULL_PTR;
+    }
+    auto *bias = reinterpret_cast<float *>(bias_tensor->MutableData());
+    if (bias == nullptr) {
+      MS_LOG(ERROR) << "bias is nullptr!";
+      return RET_NULL_PTR;
+    }
+    for (int i = 0; i < bias_tensor->ElementsNum(); ++i) {
+      bias_data_[i] = (float16_t)(bias[i]);
+    }
+  } else {
+    bias_data_ = reinterpret_cast<float16_t *>(bias_tensor->MutableData());
+    if (bias_data_ == nullptr) {
+      MS_LOG(ERROR) << "bias_data_ is nullptr";
+      return RET_NULL_PTR;
+    }
+  }
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_BiasAdd, LiteKernelCreator<BiasCPUFp16Kernel>)
+}  // namespace mindspore::kernel

@@ -0,0 +1,45 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_BIAS_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_BIAS_H_
+#include <vector>
+#include "src/lite_kernel.h"
+#include "nnacl/fp16/arithmetic_fp16.h"
+
+namespace mindspore::kernel {
+class BiasCPUFp16Kernel : public LiteKernel {
+ public:
+  BiasCPUFp16Kernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
+                    const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
+                    const mindspore::lite::PrimitiveC *primitive)
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+    bias_param_ = reinterpret_cast<ArithmeticParameter *>(parameter);
+  }
+  ~BiasCPUFp16Kernel() override;
+
+  int Init() override;
+  int ReSize() override;
+  int Run() override;
+
+ private:
+  ArithmeticParameter *bias_param_ = nullptr;
+  float16_t *bias_data_ = nullptr;
+  TypeId bias_data_type_;
+};
+}  // namespace mindspore::kernel
+
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_BIAS_H_

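The new BiasCPUFp16Kernel above leans on BroadcastAddFp16: ReSize() sets in_shape1_ to all ones except the last dimension, so tiling the bias and adding is equivalent to broadcasting it across every leading dimension. Stripped of the tiling buffers, the arithmetic reduces to this naive reference loop — a sketch with invented names, written in float for portability:

#include <stddef.h>

/* Reference semantics of the bias add: view the input as (outer x c),
 * where c is the innermost dimension, and add bias[j] to column j of
 * every row. */
void bias_add_ref(const float *in, const float *bias, float *out, size_t outer, size_t c) {
  for (size_t i = 0; i < outer; i++) {
    for (size_t j = 0; j < c; j++) {
      out[i * c + j] = in[i * c + j] + bias[j];
    }
  }
}
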
@@ -53,25 +53,24 @@ int TensorListFromTensorCPUKernel::IsCompatibleShape() {
 }
 
 int TensorListFromTensorCPUKernel::Init() {
-  input0_ = in_tensors_[0];  // row tensor
-  input1_ = in_tensors_[1];  // element_shape tensor
-  output0_ = out_tensors_[0];
-  return IsCompatibleShape();
-}
-
-int TensorListFromTensorCPUKernel::ReSize() {
-  auto ret = this->Init();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Init kernel failed!";
-    return ret;
+#ifdef ENABLE_FP16
+  if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) {
+    dtype_ = kNumberTypeFloat16;
   }
+#endif
   return RET_OK;
 }
 
+int TensorListFromTensorCPUKernel::ReSize() { return RET_OK; }
+
 int TensorListFromTensorCPUKernel::Run() {
+  input0_ = in_tensors_[0];   // row tensor
+  input1_ = in_tensors_[1];   // element_shape tensor
+  output0_ = out_tensors_[0];
+  if (IsCompatibleShape() != RET_OK) {
+    MS_LOG(ERROR) << "IsNotCompatibleShape!";
+    return RET_ERROR;
+  }
   if (input0_->shape().size() == 0) {
     MS_LOG(ERROR) << "input0_->shape().size():" << input0_->shape().size() << " must be greater than 0";
   }
@@ -86,7 +85,9 @@ int TensorListFromTensorCPUKernel::Run() {
     return RET_ERROR;
   }
   int devision_dim0 = input0_->ElementsNum() / dim0;
-  auto in_ptr = reinterpret_cast<float *>(input0_->data_c());
+  auto data_offset = devision_dim0 * lite::DataTypeSize(dtype_);
+  auto in_data = reinterpret_cast<char *>(input0_->data_c());
+  MS_ASSERT(in_data != nullptr);
   // copy data from input0(tensor) to output(tensorlist) vector<*tensor>
   for (int i = 0; i < dim0; ++i) {
     auto out_ptr = output0->GetTensor(i);
@@ -96,37 +97,17 @@ int TensorListFromTensorCPUKernel::Run() {
                     << " must be equal to devision_dim0:" << devision_dim0;
       return RET_ERROR;
     }
-    memcpy(reinterpret_cast<float *>(out_ptr->MutableData()), in_ptr, devision_dim0 * sizeof(float));
-    in_ptr += devision_dim0;
+    auto out_data = out_ptr->MutableData();
+    MS_ASSERT(out_data != nullptr);
+    memcpy(out_data, in_data, data_offset);
+    in_data += data_offset;
   }
   return RET_OK;
 }
 
-kernel::LiteKernel *CpuTensorListFromTensorFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                                             const std::vector<lite::Tensor *> &outputs,
-                                                             OpParameter *op_parameter, const lite::InnerContext *ctx,
-                                                             const kernel::KernelKey &desc,
-                                                             const mindspore::lite::PrimitiveC *primitive) {
-  if (op_parameter == nullptr) {
-    MS_LOG(ERROR) << "Input op_parameter is nullptr!";
-    return nullptr;
-  }
-  if (ctx == nullptr) {
-    MS_LOG(ERROR) << "Input context is nullptr!";
-    free(op_parameter);
-    return nullptr;
-  }
-  MS_ASSERT(desc.type == schema::PrimitiveType_TensorListFromTensor);
-  op_parameter->thread_num_ = ctx->thread_num_;
-  auto *kernel = new (std::nothrow) TensorListFromTensorCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
-  if (kernel == nullptr) {
-    MS_LOG(ERROR) << "new TensorListFromTensorCPUKernel fail!";
-    free(op_parameter);
-    return nullptr;
-  }
-  return kernel;
-}
-
-REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListFromTensor, CpuTensorListFromTensorFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListFromTensor, CpuTensorListFromTensorFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListFromTensor,
+           LiteKernelCreator<TensorListFromTensorCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListFromTensor, LiteKernelCreator<TensorListFromTensorCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListFromTensor,
+           LiteKernelCreator<TensorListFromTensorCPUKernel>)
 }  // namespace mindspore::kernel

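The TensorListFromTensor Run() rewrite above replaces float-typed pointer arithmetic with byte offsets computed from lite::DataTypeSize(dtype_), so one kernel serves fp32, fp16, and int32 lists alike. The pattern in isolation — a sketch with hypothetical names, where elem_size stands in for the DataTypeSize lookup:

#include <stddef.h>
#include <string.h>

/* Split a contiguous source buffer into `rows` equal chunks, byte-wise;
 * nothing in the loop depends on the element type. */
void split_rows_sketch(const void *src, void *const *dst, int rows, int elems_per_row, size_t elem_size) {
  const char *in = (const char *)src;
  size_t row_bytes = (size_t)elems_per_row * elem_size;  /* elem_size = 2 for fp16, 4 for fp32 */
  for (int i = 0; i < rows; i++) {
    memcpy(dst[i], in, row_bytes);
    in += row_bytes;
  }
}
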
@@ -21,6 +21,7 @@
 #include "src/lite_kernel.h"
 #include "src/tensorlist.h"
 #include "schema/model_generated.h"
+#include "nnacl/tensorlist_parameter.h"
 
 namespace mindspore::kernel {
 class TensorListFromTensorCPUKernel : public LiteKernel {
@@ -28,7 +29,8 @@ class TensorListFromTensorCPUKernel : public LiteKernel {
   TensorListFromTensorCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                 const mindspore::lite::PrimitiveC *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive),
+        dtype_(reinterpret_cast<TensorListParameter *>(parameter)->element_dtype_) {}
   ~TensorListFromTensorCPUKernel() = default;
 
   int Init() override;
@@ -41,6 +43,7 @@ class TensorListFromTensorCPUKernel : public LiteKernel {
   lite::Tensor *output0_ = nullptr;
   lite::Tensor *input0_ = nullptr;
   lite::Tensor *input1_ = nullptr;
+  TypeId dtype_ = kTypeUnknown;
 };
 }  // namespace mindspore::kernel

@@ -31,11 +31,11 @@ namespace mindspore::kernel {
 int TensorListGetItemCPUKernel::Init() {
   MS_ASSERT(in_tensors_.size() >= 2);
   MS_ASSERT(in_tensors_.at(0) != nullptr);
-  auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
-  if (dtype_ != input0->tensors_data_type()) {
-    MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type();
-    return RET_ERROR;
+#ifdef ENABLE_FP16
+  if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) {
+    dtype_ = kNumberTypeFloat16;
   }
+#endif
   return RET_OK;
 }
 
@@ -45,6 +45,10 @@ int TensorListGetItemCPUKernel::Run() {
   MS_ASSERT(in_tensors_.at(1) != nullptr);
   MS_ASSERT(out_tensors_.at(0) != nullptr);
   auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
+  if (dtype_ != input0->tensors_data_type()) {
+    MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type();
+    return RET_ERROR;
+  }
   MS_ASSERT(in_tensors_.at(1)->data_c() != nullptr);
   index_ = reinterpret_cast<int *>(in_tensors_.at(1)->data_c())[0];
   int dim0 = input0->ElementsNum() - 1;
@@ -66,8 +70,7 @@ int TensorListGetItemCPUKernel::Run() {
       return RET_ERROR;
     }
   } else {
-    // reset 0 and dtype = dtype_
-    // TODO(DT_VARIANT): dtype = DT_VARIANT is not handle
+    // reset data buffer is zero
     auto out_data = out_tensors_[0]->data_c();
     if (out_data == nullptr) {
       MS_LOG(ERROR) << "data of out_tensors_[0] is nullptr";
@@ -80,30 +83,7 @@ int TensorListGetItemCPUKernel::Run() {
 
 int TensorListGetItemCPUKernel::ReSize() { return RET_OK; }
 
-kernel::LiteKernel *CpuTensorListGetItemFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                                          const std::vector<lite::Tensor *> &outputs,
-                                                          OpParameter *op_parameter, const lite::InnerContext *ctx,
-                                                          const kernel::KernelKey &desc,
-                                                          const mindspore::lite::PrimitiveC *primitive) {
-  if (op_parameter == nullptr) {
-    MS_LOG(ERROR) << "Input op_parameter is nullptr!";
-    return nullptr;
-  }
-  if (ctx == nullptr) {
-    MS_LOG(ERROR) << "Input context is nullptr!";
-    free(op_parameter);
-    return nullptr;
-  }
-  MS_ASSERT(desc.type == schema::PrimitiveType_TensorListGetItem);
-  auto *kernel = new (std::nothrow) TensorListGetItemCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
-  if (kernel == nullptr) {
-    MS_LOG(ERROR) << "new TensorListGetItemCPUKernel fail!";
-    free(op_parameter);
-    return nullptr;
-  }
-  return kernel;
-}
-
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListGetItem, CpuTensorListGetItemFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListGetItem, CpuTensorListGetItemFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListGetItem, LiteKernelCreator<TensorListGetItemCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListGetItem, LiteKernelCreator<TensorListGetItemCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListGetItem, LiteKernelCreator<TensorListGetItemCPUKernel>)
 }  // namespace mindspore::kernel

@@ -27,7 +27,14 @@ using mindspore::schema::PrimitiveType_TensorListReserve;
 
 namespace mindspore::kernel {
 
-int TensorListReserveCPUKernel::Init() { return RET_OK; }
+int TensorListReserveCPUKernel::Init() {
+#ifdef ENABLE_FP16
+  if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && element_dtype_ == kNumberTypeFloat32) {
+    element_dtype_ = kNumberTypeFloat16;
+  }
+#endif
+  return RET_OK;
+}
 
 int TensorListReserveCPUKernel::Run() {
   auto input0 = in_tensors_.at(0);
@@ -48,5 +55,6 @@ int TensorListReserveCPUKernel::Run() {
 int TensorListReserveCPUKernel::ReSize() { return RET_OK; }
 
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListReserve, LiteKernelCreator<TensorListReserveCPUKernel>)
 }  // namespace mindspore::kernel

@@ -28,7 +28,14 @@ using mindspore::schema::PrimitiveType_TensorListSetItem;
 
 namespace mindspore::kernel {
 
-int TensorListSetItemCPUKernel::Init() { return RET_OK; }
+int TensorListSetItemCPUKernel::Init() {
+#ifdef ENABLE_FP16
+  if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) {
+    dtype_ = kNumberTypeFloat16;
+  }
+#endif
+  return RET_OK;
+}
 
 int TensorListSetItemCPUKernel::CheckParam() {
   if (dtype_ != kTypeUnknown && dtype_ != input0_->tensors_data_type()) {
@@ -143,5 +150,6 @@ int TensorListSetItemCPUKernel::Run() {
 int TensorListSetItemCPUKernel::ReSize() { return RET_OK; }
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListSetItem, LiteKernelCreator<TensorListSetItemCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListSetItem, LiteKernelCreator<TensorListSetItemCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListSetItem, LiteKernelCreator<TensorListSetItemCPUKernel>)
 }  // namespace mindspore::kernel

@@ -60,6 +60,11 @@ int TensorListStackCPUKernel::Init() {
   MS_ASSERT(input0_ != nullptr);
   output0_ = out_tensors_[0];
   MS_ASSERT(output0_ != nullptr);
+#ifdef ENABLE_FP16
+  if (lite::IsSupportFloat16() && context_->IsCpuFloat16Enabled() && dtype_ == kNumberTypeFloat32) {
+    dtype_ = kNumberTypeFloat16;
+  }
+#endif
   return RET_OK;
 }
 
@@ -159,17 +164,21 @@ int TensorListStackCPUKernel::Run() {
     MS_LOG(ERROR) << "out_tensors_[0]->ElementsNum():" << out_ele_num << " must be equal to in_ele_num:" << in_ele_num;
     return RET_ERROR;
   }
-  auto out_ptr = reinterpret_cast<float *>(output0_->MutableData());
+  auto out_data = reinterpret_cast<char *>(output0_->MutableData());
+  auto unknown_type_offset = TypeUnknownSize * lite::DataTypeSize(dtype_);
+  MS_ASSERT(out_data != nullptr);
   for (int i = 0; i < num_element_; ++i) {
     auto in_ptr = input0_->GetTensor(i);
     MS_ASSERT(in_ptr != nullptr);
     if (in_ptr->data_type() != kTypeUnknown) {
-      int in_size = in_ptr->ElementsNum();
-      memcpy(out_ptr, in_ptr->data_c(), in_size * sizeof(float));
-      out_ptr += in_size;
+      int data_size = in_ptr->ElementsNum() * lite::DataTypeSize(dtype_);
+      auto in_data = in_ptr->data_c();
+      MS_ASSERT(in_data != nullptr);
+      memcpy(out_data, in_data, data_size);
+      out_data += data_size;
     } else {
-      memset(out_ptr, 0, TypeUnknownSize * sizeof(float));
-      out_ptr += TypeUnknownSize;
+      memset(out_data, 0, unknown_type_offset);
+      out_data += unknown_type_offset;
     }
   }
   return RET_OK;
@@ -178,5 +187,6 @@ int TensorListStackCPUKernel::Run() {
 int TensorListStackCPUKernel::ReSize() { return RET_OK; }
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListStack, LiteKernelCreator<TensorListStackCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_TensorListStack, LiteKernelCreator<TensorListStackCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_TensorListStack, LiteKernelCreator<TensorListStackCPUKernel>)
 }  // namespace mindspore::kernel

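Each tensorlist kernel's Init() above now applies the same ENABLE_FP16 guard: when the CPU supports fp16 and the session enables fp16 inference, a declared fp32 element dtype is promoted to fp16 so the list matches the dtype the runtime will actually feed it. The decision in isolation — a sketch with placeholder enum values, not the real TypeId constants:

/* Placeholder dtype tags; the real code uses mindspore's TypeId enum. */
enum dtype_sketch { DTYPE_FP32 = 1, DTYPE_FP16 = 2 };

/* Promote fp32 to fp16 only when hardware support and the fp16 session
 * setting line up; every other dtype passes through unchanged. */
enum dtype_sketch promote_dtype(enum dtype_sketch dtype, int cpu_supports_fp16, int fp16_enabled) {
  if (cpu_supports_fp16 && fp16_enabled && dtype == DTYPE_FP32) {
    return DTYPE_FP16;
  }
  return dtype;
}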