!31636 [lite]fix sub-int8 bug

Merge pull request !31636 from 徐安越/master_core
This commit is contained in:
i-robot 2022-03-22 02:15:03 +00:00 committed by Gitee
commit 94ae4b2592
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 22 additions and 5 deletions

View File

@ -138,11 +138,28 @@ int SubInt8Run(void *cdata, int task_id, float, float) {
int SubInt8CPUKernel::Run() { int SubInt8CPUKernel::Run() {
if (broadcast_) { if (broadcast_) {
ArithmeticParameter tile_para; ArithmeticParameter tile_para;
tile_para.ndim_ = out_tensors_.at(0)->shape().size(); auto out_shape = out_tensors_[FIRST_INPUT]->shape();
for (size_t i = 0; i < tile_para.ndim_; i++) { tile_para.ndim_ = out_shape.size();
tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i); auto in_shape0 = in_tensors_[FIRST_INPUT]->shape();
tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i); MS_CHECK_TRUE_MSG(out_shape.size() >= in_shape0.size(), RET_ERROR,
tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i); "Sub first-input shape size is larger than out.");
for (size_t i = 0; i < out_shape.size() - in_shape0.size(); ++i) {
tile_para.in_shape0_[i] = 1;
}
for (size_t i = 0; i < in_shape0.size(); ++i) {
tile_para.in_shape0_[i + out_shape.size() - in_shape0.size()] = in_shape0[i];
}
auto in_shape1 = in_tensors_[SECOND_INPUT]->shape();
MS_CHECK_TRUE_MSG(out_shape.size() >= in_shape1.size(), RET_ERROR,
"Sub second-input shape size is larger than out.");
for (size_t i = 0; i < out_shape.size() - in_shape1.size(); ++i) {
tile_para.in_shape1_[i] = 1;
}
for (size_t i = 0; i < in_shape1.size(); ++i) {
tile_para.in_shape1_[i + out_shape.size() - in_shape1.size()] = in_shape1[i];
}
for (size_t i = 0; i < out_shape.size(); ++i) {
tile_para.out_shape_[i] = out_shape[i];
} }
tile0_data_ = static_cast<int8_t *>(ms_context_->allocator->Malloc(out_tensors_.at(0)->Size())); tile0_data_ = static_cast<int8_t *>(ms_context_->allocator->Malloc(out_tensors_.at(0)->Size()));
if (tile0_data_ == nullptr) { if (tile0_data_ == nullptr) {