fix sub-int8 bug

This commit is contained in:
xuanyue 2022-03-21 15:45:24 +08:00
parent b724ceaa8e
commit c4cebac212
1 changed files with 22 additions and 5 deletions

View File

@ -138,11 +138,28 @@ int SubInt8Run(void *cdata, int task_id, float, float) {
int SubInt8CPUKernel::Run() {
if (broadcast_) {
ArithmeticParameter tile_para;
tile_para.ndim_ = out_tensors_.at(0)->shape().size();
for (size_t i = 0; i < tile_para.ndim_; i++) {
tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
auto out_shape = out_tensors_[FIRST_INPUT]->shape();
tile_para.ndim_ = out_shape.size();
auto in_shape0 = in_tensors_[FIRST_INPUT]->shape();
MS_CHECK_TRUE_MSG(out_shape.size() >= in_shape0.size(), RET_ERROR,
"Sub first-input shape size is larger than out.");
for (size_t i = 0; i < out_shape.size() - in_shape0.size(); ++i) {
tile_para.in_shape0_[i] = 1;
}
for (size_t i = 0; i < in_shape0.size(); ++i) {
tile_para.in_shape0_[i + out_shape.size() - in_shape0.size()] = in_shape0[i];
}
auto in_shape1 = in_tensors_[SECOND_INPUT]->shape();
MS_CHECK_TRUE_MSG(out_shape.size() >= in_shape1.size(), RET_ERROR,
"Sub second-input shape size is larger than out.");
for (size_t i = 0; i < out_shape.size() - in_shape1.size(); ++i) {
tile_para.in_shape1_[i] = 1;
}
for (size_t i = 0; i < in_shape1.size(); ++i) {
tile_para.in_shape1_[i + out_shape.size() - in_shape1.size()] = in_shape1[i];
}
for (size_t i = 0; i < out_shape.size(); ++i) {
tile_para.out_shape_[i] = out_shape[i];
}
tile0_data_ = static_cast<int8_t *>(ms_context_->allocator->Malloc(out_tensors_.at(0)->Size()));
if (tile0_data_ == nullptr) {