fix fp16 multi threads bug

This commit is contained in:
wangzhe 2021-02-08 10:00:52 +08:00
parent 8bd95ed281
commit f6a7c01d73
13 changed files with 34 additions and 4 deletions

View File

@ -139,6 +139,7 @@ int MergeCPUKernel::Run() {
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Merge, LiteKernelCreator<MergeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Merge, LiteKernelCreator<MergeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Merge, LiteKernelCreator<MergeCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Merge, LiteKernelCreator<MergeCPUKernel>)
} // namespace mindspore::kernel

View File

@ -41,11 +41,12 @@ int ReshapeBaseCPUKernel::ReSize() {
int ReshapeBaseCPUKernel::RunImpl(int task_id) {
size_t start_index = task_id * cal_max_num_per_thread_;
auto cur_in_ptr = input_ptr_ + start_index;
auto cur_out_ptr = output_ptr_ + start_index;
if (start_index > in_tensors_.front()->Size()) {
if (start_index >= in_tensors_.front()->Size()) {
return RET_OK;
}
auto cur_in_ptr = input_ptr_ + start_index;
auto cur_out_ptr = output_ptr_ + start_index;
size_t data_size = in_tensors_.front()->Size() - start_index;
data_size = data_size > cal_max_num_per_thread_ ? cal_max_num_per_thread_ : data_size;
memcpy(cur_out_ptr, cur_in_ptr, data_size);

View File

@ -55,7 +55,7 @@ int StackBaseCPUKernel::ReSize() {
axis_ = param->axis_ < 0 ? param->axis_ + input0_shape.size() + 1 : param->axis_;
auto input_nums = in_tensors_.size();
if (input_nums == 1) {
copy_size_ = in_tensors_.front()->Size();
copy_size_ = in_tensors_.front()->ElementsNum() * data_type_size_;
} else {
MS_ASSERT(input_nums > 1);
copy_size_ = GetCopyNum(input0_shape, axis_, input0_shape.size()) * data_type_size_;

View File

@ -93,6 +93,7 @@ int SwitchCPUKernel::Run() {
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Switch, LiteKernelCreator<SwitchCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Switch, LiteKernelCreator<SwitchCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Switch, LiteKernelCreator<SwitchCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Switch, LiteKernelCreator<SwitchCPUKernel>)
} // namespace mindspore::kernel

View File

@ -51,6 +51,9 @@ int ActivationFp16CPUKernel::DoActivation(int task_id) {
int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
if (count <= 0) {
return RET_OK;
}
int error_code;
if (type_ == schema::ActivationType_RELU) {

View File

@ -126,6 +126,9 @@ int ArithmeticCompareFP16CPUKernel::DoArithmetic(int task_id) {
int cur_offset = stride_per_thread * task_id;
int cur_count = param_->broadcasting_ ? MSMIN(stride_per_thread, outside_ - cur_offset)
: MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset);
if (cur_count <= 0) {
return RET_OK;
}
int ret = RET_OK;
if (param_->broadcasting_) {

View File

@ -169,6 +169,9 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
int cur_offset = stride_per_thread * task_id;
int cur_count = param_->broadcasting_ ? MSMIN(stride_per_thread, outside_ - cur_offset)
: MSMIN(stride_per_thread, param_->out_elements_num_ - cur_offset);
if (cur_count <= 0) {
return RET_OK;
}
int ret = RET_OK;
if (param_->broadcasting_) {

View File

@ -88,6 +88,9 @@ int GatherFp16CPUKernel::DoGather(int task_id) {
}
int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
int count = MSMIN(stride, outer_size - stride * task_id);
if (count <= 0) {
return RET_OK;
}
auto thread_stride = stride * task_id;
int8_t *int8_in = nullptr;
if (input_tensor->data_type() == kNumberTypeFloat32) {

View File

@ -54,6 +54,9 @@ int ActivationCPUKernel::DoActivation(int task_id) {
int stride = UP_DIV(length, thread_count_);
int count = MSMIN(stride, length - stride * task_id);
if (count <= 0) {
return RET_OK;
}
auto ret = RET_OK;

View File

@ -68,6 +68,9 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
MS_ASSERT(thread_count_ != 0);
int stride = UP_DIV(element_num, thread_count_);
int count = MSMIN(stride, element_num - stride * task_id);
if (count <= 0) {
return RET_OK;
}
if (func_fp32_ == nullptr) {
MS_LOG(ERROR) << "func_fp32_ function is nullptr!";

View File

@ -288,6 +288,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
MS_ASSERT(thread_count_ != 0);
int stride = UP_DIV(element_num, thread_count_);
int count = MSMIN(stride, element_num - stride * task_id);
if (count <= 0) {
return RET_OK;
}
if (arithmetic_run_ == nullptr) {
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";

View File

@ -57,6 +57,9 @@ int GatherCPUKernel::DoGather(int task_id) {
}
int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
int count = MSMIN(stride, outer_size - stride * task_id);
if (count <= 0) {
return RET_OK;
}
auto thread_stride = stride * task_id;
int8_t *int8_in = reinterpret_cast<int8_t *>(input_tensor->data_c());

View File

@ -57,6 +57,9 @@ int PowerCPUKernel::RunImpl(int task_id) {
auto size = in_tensors_.at(0)->ElementsNum();
int stride = UP_DIV(size, thread_count_);
int len = MSMIN(stride, size - stride * task_id);
if (len <= 0) {
return RET_OK;
}
float *exp_addr = nullptr;
bool broadcast = true;
if (in_tensors_.size() == 2) {