!49614 [MS][LITE] tile fusion thread cut

Merge pull request !49614 from Greatpan/tilefuison_thread_cut
i-robot 2023-03-02 07:44:51 +00:00 committed by Gitee
commit b834aa3a45
3 changed files with 51 additions and 27 deletions

@@ -38,32 +38,42 @@ int TileCPUKernel::Prepare() {
return ReSize();
}
int TileCPUKernel::DoubleInputScenes() {
CHECK_NULL_RETURN(in_tensors_.at(1));
if (in_tensors_[1]->data() == nullptr) {
resize_done_ = false;
return RET_OK;
}
if (in_tensors_[1]->ElementsNum() > static_cast<int>(in_tensors_[0]->shape().size())) {
MS_LOG(ERROR) << "tile's input1 data_num cannot be larger than input0's shape_size.";
return RET_ERROR;
}
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type()
<< " must be kNumberTypeInt32 or kNumberTypeInt!";
return RET_ERROR;
}
CHECK_NULL_RETURN(in_tensors_[1]->data());
auto input1_addr = reinterpret_cast<int *>(in_tensors_[1]->data());
for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
if (input1_addr[i] <= 0) {
MS_LOG(ERROR) << "Tile input1 data must be greater than 0";
return RET_ERROR;
}
tile_parameter_->dims_[i] = i;
tile_parameter_->multiples_[i] = input1_addr[i];
}
return RET_OK;
}
int TileCPUKernel::ReSize() {
auto ret = RET_OK;
CHECK_NULL_RETURN(tile_parameter_);
if (in_tensors_.size() == kDoubleInputsSize) {
CHECK_NULL_RETURN(in_tensors_.at(1));
if (in_tensors_[1]->data() == nullptr) {
resize_done_ = false;
return RET_OK;
}
if (in_tensors_[1]->ElementsNum() > static_cast<int>(in_tensors_[0]->shape().size())) {
MS_LOG(ERROR) << "tile's input1 data_num cannot be larger than input0's shape_size.";
return RET_ERROR;
}
if (in_tensors_[1]->data_type() != kNumberTypeInt && in_tensors_[1]->data_type() != kNumberTypeInt32) {
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type()
<< " must be kNumberTypeInt32 or kNumberTypeInt!";
return RET_ERROR;
}
CHECK_NULL_RETURN(in_tensors_[1]->data());
auto input1_addr = reinterpret_cast<int *>(in_tensors_[1]->data());
for (int i = 0; i < in_tensors_[1]->ElementsNum(); ++i) {
if (input1_addr[i] <= 0) {
MS_LOG(ERROR) << "Tile input1 data must be greater than 0";
return RET_ERROR;
}
tile_parameter_->dims_[i] = i;
tile_parameter_->multiples_[i] = input1_addr[i];
ret = DoubleInputScenes();
if (ret != RET_OK) {
return ret;
}
}
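
A note outside the diff: the checks factored out into DoubleInputScenes() encode the Tile contract that the second input carries one positive integer multiple per axis of the first input, and never more multiples than the first input has dimensions. A minimal standalone C++ sketch (hypothetical names, not the MindSpore Lite API) of what those validated multiples mean for the output shape:

#include <cstdio>
#include <vector>

// Sketch only: expand an input shape by per-axis multiples, mirroring the contract
// DoubleInputScenes() validates (multiples count <= input rank, every multiple > 0).
std::vector<int> TileOutputShape(const std::vector<int> &in_shape, const std::vector<int> &multiples) {
  std::vector<int> out_shape = in_shape;
  for (size_t i = 0; i < multiples.size(); ++i) {
    out_shape[i] *= multiples[i];  // axis i is repeated multiples[i] times
  }
  return out_shape;
}

int main() {
  for (int dim : TileOutputShape({2, 3}, {2, 2})) {  // {2, 3} tiled by {2, 2} -> {4, 6}
    printf("%d ", dim);
  }
  printf("\n");
  return 0;
}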
@@ -97,7 +107,19 @@ int TileCPUKernel::ReSize() {
MS_LOG(ERROR) << "tile not support data type: " << data_type;
return RET_ERROR;
}
return FillOneDimTileParam();
ret = FillOneDimTileParam();
if (ret != RET_OK) {
return ret;
}
if (one_dim_tile_) {
if (UpdateThreadNumPass(TC_TYPE(schema::PrimitiveType_TileFusion, 0), 0, 0, tile_parameter_->fast_outer_size_) !=
RET_OK) {
return RET_ERROR;
}
}
return RET_OK;
}
int SimpleTile(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
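
For context, outside the diff: the UpdateThreadNumPass call added above is the "thread cut" of the title. On the one-dimensional fast path the kernel now asks the thread cost model, keyed by TC_TYPE(schema::PrimitiveType_TileFusion, 0) with fast_outer_size_ as the unit count, to shrink thread_num_ when the workload is too small to keep every thread busy. The standalone sketch below shows that general idea; the helper name and threshold are assumptions, not the actual UpdateThreadNumPass:

#include <algorithm>
#include <cstdint>

// Hypothetical sketch of a cost-based thread cut: estimate total work from a per-unit
// cost and a unit count, then cap the thread count so each thread gets a minimum amount
// of work. The names and the threshold are illustrative assumptions only.
int CutThreadNum(float per_unit_cost, int64_t unit_num, int max_thread_num) {
  constexpr float kMinCostPerThread = 100000.0f;  // assumed parallelization threshold
  float total_cost = per_unit_cost * static_cast<float>(unit_num);
  int suggested = static_cast<int>(total_cost / kMinCostPerThread) + 1;
  return std::max(1, std::min(max_thread_num, suggested));
}

// A tiny unit count then yields a single thread while a large one keeps max_thread_num,
// e.g. CutThreadNum(259.0625f, 100, 4) == 1 and CutThreadNum(259.0625f, 100000, 4) == 4.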
@@ -140,8 +162,8 @@ int TileCPUKernel::FillOneDimTileParam() {
}
int TileCPUKernel::SimpleTileImpl(int task_id) {
CHECK_LESS_RETURN(static_cast<size_t>(op_parameter_->thread_num_), 1);
size_t unit = UP_DIV(tile_parameter_->fast_outer_size_, static_cast<size_t>(op_parameter_->thread_num_));
CHECK_LESS_RETURN(static_cast<size_t>(thread_num_), 1);
size_t unit = UP_DIV(tile_parameter_->fast_outer_size_, static_cast<size_t>(thread_num_));
if (unit == 0 && task_id > 0) {
return RET_OK;
}
@@ -153,7 +175,7 @@ int TileCPUKernel::SimpleTileImpl(int task_id) {
}
int TileCPUKernel::RunSimpleTile() {
auto ret = ParallelLaunch(this->ms_context_, SimpleTile, this, op_parameter_->thread_num_);
auto ret = ParallelLaunch(this->ms_context_, SimpleTile, this, thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "RunSimpleTile error code[" << ret << "]";
return ret;
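
Outside the diff: downstream of the cut only the divisor changes. SimpleTileImpl splits fast_outer_size_ into UP_DIV(fast_outer_size_, thread_num_) units per task, and RunSimpleTile launches thread_num_ tasks instead of op_parameter_->thread_num_. A standalone C++ sketch of that range computation (the loop body is a placeholder, not the tile copy):

#include <algorithm>
#include <cstddef>
#include <cstdio>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))  // ceiling division, mirroring the UP_DIV used above

// Task `task_id` handles [task_id * unit, min((task_id + 1) * unit, total)),
// the same partitioning SimpleTileImpl applies to fast_outer_size_.
void SimpleTileTask(size_t total, int thread_num, int task_id) {
  size_t unit = UP_DIV(total, static_cast<size_t>(thread_num));
  size_t begin = unit * static_cast<size_t>(task_id);
  size_t end = std::min(begin + unit, total);
  for (size_t i = begin; i < end; ++i) {
    // copy / tile one outer slice here (placeholder)
  }
  printf("task %d handles [%zu, %zu)\n", task_id, begin, end);
}

int main() {
  const size_t fast_outer_size = 10;
  const int thread_num = 3;  // the cut thread count, not op_parameter_->thread_num_
  for (int t = 0; t < thread_num; ++t) {
    SimpleTileTask(fast_outer_size, thread_num, t);
  }
  return 0;
}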

@@ -38,6 +38,7 @@ class TileCPUKernel : public LiteKernel {
private:
int RunSimpleTile();
int FillOneDimTileParam();
int DoubleInputScenes();
bool one_dim_tile_ = false;
uint8_t *input_addr_ = nullptr;
uint8_t *output_addr_ = nullptr;

@@ -71,6 +71,7 @@ const std::map<int32_t, float> kernel_compute_cost_map_ = {
{TC_TYPE(schema::PrimitiveType_LayerNormFusion, 0), 507.812f}, // dataNum about 0.5k
{TC_TYPE(schema::PrimitiveType_OneHot, 0), 136.562f}, // dataNum about 1.5k
{TC_TYPE(schema::PrimitiveType_TileFusion, 0), 259.0625f}, // dataNum about 0.8k
{TC_TYPE(schema::PrimitiveType_ReduceFusion, schema::ReduceMode_ReduceAll), 66.5625f}, // dataNum about 3k
{TC_TYPE(schema::PrimitiveType_ReduceFusion, schema::ReduceMode_ReduceASum), 206.5625f}, // dataNum about 1k
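
Outside the diff: the new TileFusion row is the empirical per-unit compute cost that the thread-cut pass looks up through TC_TYPE(schema::PrimitiveType_TileFusion, 0); the trailing comment notes the data size ("dataNum about 0.8k") associated with the entry. A simplified C++ sketch of the lookup pattern; the key packing and the fallback value are assumptions, not the actual cost model code:

#include <cstdint>
#include <map>

// Simplified stand-in for a per-kernel cost table: the key packs (primitive type,
// variant) into an int32_t and the value is an empirical per-unit compute cost.
// The packing and the fallback below are illustrative assumptions, not TC_TYPE itself.
using CostKey = int32_t;

inline CostKey MakeCostKey(int prim_type, int variant) {
  return static_cast<CostKey>((prim_type << 16) | variant);  // assumed packing
}

float LookupComputeCost(const std::map<CostKey, float> &cost_map, int prim_type, int variant) {
  auto it = cost_map.find(MakeCostKey(prim_type, variant));
  return it != cost_map.end() ? it->second : 1.0f;  // assumed fallback for untuned kernels
}

// Usage sketch: register the TileFusion entry from this commit and query it
// (kTileFusionType is a hypothetical enum value standing in for the schema constant).
// std::map<CostKey, float> cost_map{{MakeCostKey(kTileFusionType, 0), 259.0625f}};
// float cost = LookupComputeCost(cost_map, kTileFusionType, 0);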