!6819 [MSLITE][Develop] fix split fp16 kernel
Merge pull request !6819 from sunsuodong/fix_split_fp16
Commit b446d66d11
src/runtime/kernel/arm/fp16/split_fp16.cc:
@@ -13,10 +13,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "nnacl/fp16/cast_fp16.h"
-#include "nnacl/fp16/split_fp16.h"
 #include "src/runtime/kernel/arm/fp16/split_fp16.h"
+#include "src/runtime/kernel/arm/fp16/common_fp16.h"
 #include "src/runtime/kernel/arm/base/split_base.h"
+#include "nnacl/fp16/split_fp16.h"
 #include "nnacl/split.h"
 #include "nnacl/split_parameter.h"
 #include "src/kernel_registry.h"
@@ -36,9 +36,10 @@ int SplitFp16CPUKernel::Init() {
   if (ret != RET_OK) {
     return ret;
   }
 
   output_ptr_.resize(param->num_split_);
+  for (size_t i = 0; i < output_ptr_.size(); i++) {
+    output_ptr_[i] = nullptr;
+  }
   if (!InferShapeDone()) {
     return RET_OK;
   }
@@ -79,48 +80,37 @@ int SplitFp16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare failed.";
     return RET_ERROR;
   }
-  auto in_tensor = in_tensors_.front();
-  if (in_tensor->data_type() == kNumberTypeFloat32) {
-    input_ptr_ =
-      reinterpret_cast<float16_t *>(context_->allocator->Malloc(in_tensor->ElementsNum() * sizeof(float16_t)));
-    if (input_ptr_ == nullptr) {
-      MS_LOG(ERROR) << "malloc input_ptr_ failed.";
-      return RET_ERROR;
-    }
-    Float32ToFloat16(reinterpret_cast<float *>(in_tensor->MutableData()), input_ptr_, in_tensor->ElementsNum());
-  } else {
-    input_ptr_ = reinterpret_cast<float16_t *>(in_tensor->MutableData());
+  input_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
+  if (input_ptr_ == nullptr) {
+    MS_LOG(ERROR) << "input or output is nullptr";
+    return RET_ERROR;
   }
   for (int i = 0; i < param->num_split_; i++) {
-    if (in_tensor->data_type() == kNumberTypeFloat32) {
-      output_ptr_[i] = reinterpret_cast<float16_t *>(
-        context_->allocator->Malloc(out_tensors_.at(i)->ElementsNum() * sizeof(float16_t)));
-      if (output_ptr_[i] == nullptr) {
-        MS_LOG(ERROR) << "malloc output_ptr_[" << i << "]" << " failed.";
-        return RET_ERROR;
-      }
-      Float32ToFloat16(reinterpret_cast<float *>(out_tensors_.at(i)->MutableData()), output_ptr_[i],
-                       out_tensors_.at(i)->ElementsNum());
-    } else {
-      output_ptr_[i] = reinterpret_cast<float16_t *>(out_tensors_.at(i)->MutableData());
+    output_ptr_[i] = MallocOutputFp16(out_tensors_.at(i), context_);
+    if (output_ptr_[i] == nullptr) {
+      FreeInputAndOutput();
+      MS_LOG(ERROR) << "input or output is nullptr";
+      return RET_ERROR;
     }
   }
   ret = ParallelLaunch(this->context_->thread_pool_, SplitRun, this, thread_n_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "split error error_code[" << ret << "]";
-    return RET_ERROR;
   }
-  if (in_tensor->data_type() == kNumberTypeFloat32) {
+  return ret;
+}
+
+void SplitFp16CPUKernel::FreeInputAndOutput() {
+  if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
     context_->allocator->Free(input_ptr_);
     input_ptr_ = nullptr;
   }
   for (int i = 0; i < param->num_split_; i++) {
-    if (in_tensor->data_type() == kNumberTypeFloat32) {
+    if (out_tensors_.at(i)->data_type() == kNumberTypeFloat32) {
       context_->allocator->Free(output_ptr_[i]);
       output_ptr_[i] = nullptr;
     }
   }
-  return RET_OK;
 }
 
 kernel::LiteKernel *CpuSplitFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
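Note: the patch replaces the hand-rolled fp32-to-fp16 handling in Run() with the ConvertInputFp32toFp16 and MallocOutputFp16 helpers, presumably provided by the newly included common_fp16.h. Their definitions are not part of this diff; the sketch below only illustrates, under assumption, what such helpers would roughly do, inferred from the code they replace (the parameter types and exact signatures are guesses, not the actual common_fp16.h API).

// Hypothetical sketch, inferred from the removed code above; not the actual
// contents of src/runtime/kernel/arm/fp16/common_fp16.h. Float32ToFloat16
// would come from nnacl/fp16/cast_fp16.h.
float16_t *ConvertInputFp32toFp16(lite::Tensor *input, const lite::InnerContext *ctx) {
  if (input->data_type() == kNumberTypeFloat32) {
    // fp32 input: allocate a scratch fp16 buffer and convert into it.
    auto *fp16_ptr =
      reinterpret_cast<float16_t *>(ctx->allocator->Malloc(input->ElementsNum() * sizeof(float16_t)));
    if (fp16_ptr == nullptr) {
      return nullptr;
    }
    Float32ToFloat16(reinterpret_cast<float *>(input->MutableData()), fp16_ptr, input->ElementsNum());
    return fp16_ptr;
  }
  // Already fp16: hand back the tensor's own buffer.
  return reinterpret_cast<float16_t *>(input->MutableData());
}

float16_t *MallocOutputFp16(lite::Tensor *output, const lite::InnerContext *ctx) {
  if (output->data_type() == kNumberTypeFloat32) {
    // fp32 output: give the kernel a scratch fp16 buffer to write into.
    return reinterpret_cast<float16_t *>(ctx->allocator->Malloc(output->ElementsNum() * sizeof(float16_t)));
  }
  return reinterpret_cast<float16_t *>(output->MutableData());
}

Either way the caller gets a float16_t pointer, and, as the new FreeInputAndOutput() above shows, only the buffers that were malloc'ed for fp32 tensors need to be freed afterwards.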
src/runtime/kernel/arm/fp16/split_fp16.h:
@@ -37,8 +37,9 @@ class SplitFp16CPUKernel : public SplitBaseCPUKernel {
   int Split(int task_id);
 
  private:
-  float16_t *input_ptr_;
+  float16_t *input_ptr_ = nullptr;
   std::vector<float16_t *> output_ptr_;
+  void FreeInputAndOutput();
 };
 } // namespace mindspore::kernel
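For readability, here is the post-patch Run() / FreeInputAndOutput() pair assembled from the hunks above. The opening of Run() (the Prepare() call) is not visible in the diff and is filled in by assumption, as are the surrounding declarations (param, thread_n_num_, SplitRun); anything the kernel does with the fp16 buffers after ParallelLaunch, such as converting results back to fp32, sits outside the context shown here.

// Assembled from the diff; lines marked "assumed" are not shown in the hunks.
int SplitFp16CPUKernel::Run() {
  auto ret = Prepare();  // assumed from the "Prepare failed." context line
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Prepare failed.";
    return RET_ERROR;
  }
  // Input: converted to (or reused as) fp16 by the common helper.
  input_ptr_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
  if (input_ptr_ == nullptr) {
    MS_LOG(ERROR) << "input or output is nullptr";
    return RET_ERROR;
  }
  // Outputs: one fp16 buffer per split output.
  for (int i = 0; i < param->num_split_; i++) {
    output_ptr_[i] = MallocOutputFp16(out_tensors_.at(i), context_);
    if (output_ptr_[i] == nullptr) {
      FreeInputAndOutput();  // release whatever was already allocated
      MS_LOG(ERROR) << "input or output is nullptr";
      return RET_ERROR;
    }
  }
  ret = ParallelLaunch(this->context_->thread_pool_, SplitRun, this, thread_n_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "split error error_code[" << ret << "]";
  }
  return ret;
}

void SplitFp16CPUKernel::FreeInputAndOutput() {
  // Only scratch buffers malloc'ed for fp32 tensors are freed; fp16 tensors
  // hand out their own data pointers.
  if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
    context_->allocator->Free(input_ptr_);
    input_ptr_ = nullptr;
  }
  for (int i = 0; i < param->num_split_; i++) {
    if (out_tensors_.at(i)->data_type() == kNumberTypeFloat32) {
      context_->allocator->Free(output_ptr_[i]);
      output_ptr_[i] = nullptr;
    }
  }
}

With input_ptr_ now default-initialized to nullptr in the header and every output_ptr_ slot cleared in Init(), the cleanup path never sees an uninitialized pointer when allocation fails partway through the output loop, which is presumably the point of those two changes.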