forked from mindspore-Ecosystem/mindspore
!31636 [lite]fix sub-int8 bug
Merge pull request !31636 from 徐安越/master_core
This commit is contained in:
commit
94ae4b2592
|
@ -138,11 +138,28 @@ int SubInt8Run(void *cdata, int task_id, float, float) {
|
||||||
int SubInt8CPUKernel::Run() {
|
int SubInt8CPUKernel::Run() {
|
||||||
if (broadcast_) {
|
if (broadcast_) {
|
||||||
ArithmeticParameter tile_para;
|
ArithmeticParameter tile_para;
|
||||||
tile_para.ndim_ = out_tensors_.at(0)->shape().size();
|
auto out_shape = out_tensors_[FIRST_INPUT]->shape();
|
||||||
for (size_t i = 0; i < tile_para.ndim_; i++) {
|
tile_para.ndim_ = out_shape.size();
|
||||||
tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
|
auto in_shape0 = in_tensors_[FIRST_INPUT]->shape();
|
||||||
tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
|
MS_CHECK_TRUE_MSG(out_shape.size() >= in_shape0.size(), RET_ERROR,
|
||||||
tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
|
"Sub first-input shape size is larger than out.");
|
||||||
|
for (size_t i = 0; i < out_shape.size() - in_shape0.size(); ++i) {
|
||||||
|
tile_para.in_shape0_[i] = 1;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < in_shape0.size(); ++i) {
|
||||||
|
tile_para.in_shape0_[i + out_shape.size() - in_shape0.size()] = in_shape0[i];
|
||||||
|
}
|
||||||
|
auto in_shape1 = in_tensors_[SECOND_INPUT]->shape();
|
||||||
|
MS_CHECK_TRUE_MSG(out_shape.size() >= in_shape1.size(), RET_ERROR,
|
||||||
|
"Sub second-input shape size is larger than out.");
|
||||||
|
for (size_t i = 0; i < out_shape.size() - in_shape1.size(); ++i) {
|
||||||
|
tile_para.in_shape1_[i] = 1;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < in_shape1.size(); ++i) {
|
||||||
|
tile_para.in_shape1_[i + out_shape.size() - in_shape1.size()] = in_shape1[i];
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < out_shape.size(); ++i) {
|
||||||
|
tile_para.out_shape_[i] = out_shape[i];
|
||||||
}
|
}
|
||||||
tile0_data_ = static_cast<int8_t *>(ms_context_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
tile0_data_ = static_cast<int8_t *>(ms_context_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||||
if (tile0_data_ == nullptr) {
|
if (tile0_data_ == nullptr) {
|
||||||
|
|
Loading…
Reference in New Issue