forked from mindspore-Ecosystem/mindspore
fix sub-int8 bug
This commit is contained in:
parent
b724ceaa8e
commit
c4cebac212
|
@ -138,11 +138,28 @@ int SubInt8Run(void *cdata, int task_id, float, float) {
|
|||
int SubInt8CPUKernel::Run() {
|
||||
if (broadcast_) {
|
||||
ArithmeticParameter tile_para;
|
||||
tile_para.ndim_ = out_tensors_.at(0)->shape().size();
|
||||
for (size_t i = 0; i < tile_para.ndim_; i++) {
|
||||
tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
|
||||
tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
|
||||
tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
|
||||
auto out_shape = out_tensors_[FIRST_INPUT]->shape();
|
||||
tile_para.ndim_ = out_shape.size();
|
||||
auto in_shape0 = in_tensors_[FIRST_INPUT]->shape();
|
||||
MS_CHECK_TRUE_MSG(out_shape.size() >= in_shape0.size(), RET_ERROR,
|
||||
"Sub first-input shape size is larger than out.");
|
||||
for (size_t i = 0; i < out_shape.size() - in_shape0.size(); ++i) {
|
||||
tile_para.in_shape0_[i] = 1;
|
||||
}
|
||||
for (size_t i = 0; i < in_shape0.size(); ++i) {
|
||||
tile_para.in_shape0_[i + out_shape.size() - in_shape0.size()] = in_shape0[i];
|
||||
}
|
||||
auto in_shape1 = in_tensors_[SECOND_INPUT]->shape();
|
||||
MS_CHECK_TRUE_MSG(out_shape.size() >= in_shape1.size(), RET_ERROR,
|
||||
"Sub second-input shape size is larger than out.");
|
||||
for (size_t i = 0; i < out_shape.size() - in_shape1.size(); ++i) {
|
||||
tile_para.in_shape1_[i] = 1;
|
||||
}
|
||||
for (size_t i = 0; i < in_shape1.size(); ++i) {
|
||||
tile_para.in_shape1_[i + out_shape.size() - in_shape1.size()] = in_shape1[i];
|
||||
}
|
||||
for (size_t i = 0; i < out_shape.size(); ++i) {
|
||||
tile_para.out_shape_[i] = out_shape[i];
|
||||
}
|
||||
tile0_data_ = static_cast<int8_t *>(ms_context_->allocator->Malloc(out_tensors_.at(0)->Size()));
|
||||
if (tile0_data_ == nullptr) {
|
||||
|
|
Loading…
Reference in New Issue