forked from mindspore-Ecosystem/mindspore
!30262 [lite]optimize reduce op
Merge pull request !30262 from 徐安越/master1
This commit is contained in:
commit
279cacd1f2
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -103,6 +103,8 @@ int ReduceBaseCPUKernel::Prepare() {
|
|||
MS_CHECK_FALSE_MSG(op_parameter_->thread_num_ == 0, RET_ERROR, "thread_num_ should not be 0");
|
||||
if (in_tensors_.size() > 1) {
|
||||
auto axes_tensor = in_tensors_.at(1);
|
||||
MS_CHECK_TRUE_MSG(axes_tensor != nullptr, RET_ERROR, "axes-tensor is a nullptr.");
|
||||
MS_CHECK_TRUE_MSG(axes_tensor->IsConst(), RET_ERROR, "axes-tensor must be a constant.");
|
||||
MS_CHECK_FALSE_MSG((axes_tensor->data_type() != kNumberTypeInt && axes_tensor->data_type() != kNumberTypeInt32),
|
||||
RET_ERROR, "The data type of axes tensor should be int32");
|
||||
num_axes_ = axes_tensor->ElementsNum();
|
||||
|
@ -173,8 +175,48 @@ int ReduceBaseCPUKernel::ReSize() {
|
|||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
DecideIfOnlyCopy();
|
||||
CalculateTmpBufferSize();
|
||||
CalculateInnerOuterSize();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void ReduceBaseCPUKernel::DecideIfOnlyCopy() {
|
||||
auto in_shape = in_tensors_[FIRST_INPUT]->shape();
|
||||
if (mode_ == schema::ReduceMode_ReduceSumSquare || mode_ == schema::ReduceMode_ReduceASum ||
|
||||
mode_ == schema::ReduceMode_ReduceAll) {
|
||||
only_copy_ = false;
|
||||
return;
|
||||
}
|
||||
if (std::all_of(axes_, axes_ + num_axes_, [&in_shape](int axis) { return in_shape[axis] == 1; })) {
|
||||
only_copy_ = true;
|
||||
} else {
|
||||
only_copy_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
int ReduceBaseCPUKernel::CopyInputToOutput() {
|
||||
auto in_tensor = in_tensors().front();
|
||||
CHECK_NULL_RETURN(in_tensor);
|
||||
auto out_tensor = out_tensors().front();
|
||||
CHECK_NULL_RETURN(out_tensor);
|
||||
if (in_tensor->allocator() == nullptr || in_tensor->allocator() != out_tensor->allocator() ||
|
||||
in_tensor->allocator() != ms_context_->allocator || op_parameter_->is_train_session_ ||
|
||||
((in_tensor->IsGraphInput() || in_tensor->IsGraphOutput()) && out_tensor->IsGraphOutput())) {
|
||||
CHECK_NULL_RETURN(out_tensor->data());
|
||||
CHECK_NULL_RETURN(in_tensor->data());
|
||||
MS_CHECK_FALSE(in_tensor->Size() == 0, RET_ERROR);
|
||||
if (in_tensor->data() != out_tensor->data()) {
|
||||
memcpy(out_tensor->data(), in_tensor->data(), in_tensor->Size());
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
out_tensor->FreeData();
|
||||
out_tensor->ResetRefCount();
|
||||
in_tensor->allocator()->IncRefCount(in_tensor->data(), out_tensor->ref_count());
|
||||
out_tensor->set_data(in_tensor->data());
|
||||
out_tensor->set_own_data(in_tensor->own_data());
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -38,11 +38,14 @@ class ReduceBaseCPUKernel : public InnerKernel {
|
|||
|
||||
void CalculateTmpBufferSize();
|
||||
void CalculateInnerOuterSize();
|
||||
// Sets only_copy_: always false for ReduceSumSquare / ReduceASum / ReduceAll, otherwise
// true only when every axis listed in axes_ has extent 1 in the first input's shape.
void DecideIfOnlyCopy();
|
||||
// Makes the first output tensor hold the first input tensor's data, either by memcpy
// or by pointing the output at the input's buffer; returns RET_OK on success.
int CopyInputToOutput();
|
||||
|
||||
int axes_[MAX_SHAPE_SIZE] = {0};
|
||||
int num_axes_{0};
|
||||
int mode_{0};
|
||||
bool reduce_to_end_{false};
|
||||
bool only_copy_{false};
|
||||
|
||||
std::vector<size_t> buffer_sizes_;
|
||||
std::vector<int> outer_sizes_;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -95,6 +95,9 @@ int ReduceImpl(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
|
|||
}
|
||||
|
||||
int ReduceCPUKernel::Run() {
|
||||
if (only_copy_) {
|
||||
return CopyInputToOutput();
|
||||
}
|
||||
data_type_ = in_tensors().at(0)->data_type();
|
||||
auto ret = MallocTmpBuffer();
|
||||
if (ret != RET_OK) {
|
||||
|
|
Loading…
Reference in New Issue