convert the implementation of Slice and StridedSlice CPU operators to nnacl

buxue 2021-04-25 21:05:17 +08:00
parent 9edf30bd05
commit 4a703e437e
10 changed files with 419 additions and 269 deletions

View File

@@ -226,27 +226,14 @@ bool EltWiseGradCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inpu
{"GeLUGrad", &EltWiseGradCPUKernel<T>::GeluGrad}, {"AsinGrad", &EltWiseGradCPUKernel<T>::AsinGrad},
{"ACosGrad", &EltWiseGradCPUKernel<T>::ACosGrad}, {"AtanGrad", &EltWiseGradCPUKernel<T>::AtanGrad},
{"AsinhGrad", &EltWiseGradCPUKernel<T>::AsinhGrad}, {"AcoshGrad", &EltWiseGradCPUKernel<T>::AcoshGrad}};
T *input1 = reinterpret_cast<T *>(inputs[0]->addr);
T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
const auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
size_t count = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0;
size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num;
std::vector<common::Task> tasks;
size_t start = 0;
size_t once_compute_size = (count + thread_num - 1) / thread_num;
while (start < count) {
size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size);
auto block = [&, start, end]() {
elt_map.at(kernel_name_)(this, input1, input2, output, start, end);
return common::SUCCESS;
};
tasks.emplace_back(block);
start += once_compute_size;
}
common::ThreadPool::GetInstance().SyncRun(tasks);
CPUKernelUtils::ParallelFor(
std::bind(elt_map.at(kernel_name_), this, input1, input2, output, std::placeholders::_1, std::placeholders::_2),
count);
return true;
}
} // namespace kernel
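
The hand-rolled task-splitting loop removed above is now folded into CPUKernelUtils::ParallelFor, whose definition lives in the shared CPU kernel utilities and is not part of this diff. As a rough sketch of the contract it fulfils — the 128-element block heuristic and thread-pool calls are carried over from the removed code, but the helper body itself is an assumption here:

#include <algorithm>
#include <cmath>
#include <functional>
#include <vector>
#include "common/thread_pool.h"

// Sketch: split [0, count) into blocks and run func(start, end) on the pool.
void ParallelFor(const std::function<void(size_t, size_t)> &func, size_t count) {
  auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
  const float block_size = 128.0;
  size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num;
  thread_num = thread_num > 0 ? thread_num : 1;  // guard against count == 0
  size_t once_compute_size = (count + thread_num - 1) / thread_num;
  std::vector<common::Task> tasks;
  for (size_t start = 0; start < count; start += once_compute_size) {
    size_t end = std::min(start + once_compute_size, count);
    tasks.emplace_back([&func, start, end]() {
      func(start, end);
      return common::SUCCESS;
    });
  }
  common::ThreadPool::GetInstance().SyncRun(tasks);
}

With such a helper, the per-kernel boilerplate reduces to the single std::bind call shown in the new Launch body.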

View File

@@ -72,7 +72,7 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p
if (in_data == NULL || out_data == NULL || param == NULL) {
return NNACL_NULL_PTR;
}
if (param->num_axes_ > DIMENSION_6D) {
if (param->num_axes_ > DIMENSION_8D) {
return NNACL_PARAM_INVALID;
}
@@ -107,6 +107,10 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p
*((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset);
} else if (param->data_type == kDataTypeInt) {
*((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset);
} else if (param->data_type == kDataTypeFloat64) {
*((double *)out_data + out_offset) = *((double *)in_data + in_offset);
} else if (param->data_type == kDataTypeBool) {
*((bool *)out_data + out_offset) = *((bool *)in_data + in_offset);
#ifdef ENABLE_ARM64
} else if (param->data_type == kDataTypeFloat16) {
*((float16_t *)out_data + out_offset) = *((float16_t *)in_data + in_offset);
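
With the float64 and bool branches in place, DoStridedSlice can be driven directly on double data. A minimal, hedged usage sketch — the include path and the NNACL_OK return code are assumptions inferred from how the StridedSlice CPU kernel later in this commit consumes the API:

#include <cstdio>
#include "nnacl/fp32/strided_slice_fp32.h"  // assumed to declare DoStridedSlice

int main() {
  double in[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  double out[4] = {0};
  StridedSliceParameter param;
  param.data_type = kDataTypeFloat64;
  param.num_axes_ = DIMENSION_8D;
  param.in_shape_length_ = DIMENSION_8D;
  for (int i = 0; i < DIMENSION_8D; ++i) {  // pad every dimension to size 1 first
    param.in_shape_[i] = 1;
    param.begins_[i] = 0;
    param.ends_[i] = 1;
    param.strides_[i] = 1;
  }
  param.in_shape_[0] = 8;  // real data lives in dim 0, mirroring InitSliceParam below
  param.ends_[0] = 8;
  param.strides_[0] = 2;   // take every second element
  if (DoStridedSlice(in, out, &param) == NNACL_OK) {
    printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // prints: 0 2 4 6
  }
  return 0;
}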

View File

@@ -69,7 +69,8 @@ typedef enum LiteDataType {
kDataTypeFloat16,
kDataTypeInt,
kDataTypeInt8,
KDataTypeBool,
kDataTypeBool,
kDataTypeFloat64
} LiteDataType;
typedef enum DataOrder {

View File

@@ -46,26 +46,26 @@ void ReduceCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
if constexpr (std::is_same<T, bool>::value) {
if (kernel_name == "ReduceAll") {
reduce_type_ = ReduceType::ReduceAll;
reduce_type_ = kReduceAll;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out &= input[pos]; };
} else if (kernel_name == "ReduceAny") {
reduce_type_ = ReduceType::ReduceAny;
reduce_type_ = kReduceAny;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out |= input[pos]; };
} else {
MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name_ << " for bool.";
}
} else {
if (kernel_name == "ReduceMax") {
reduce_type_ = ReduceType::ReduceMax;
reduce_type_ = kReduceMax;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::max(input[pos], *out); };
} else if (kernel_name == "ReduceMin") {
reduce_type_ = ReduceType::ReduceMin;
reduce_type_ = kReduceMin;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::min(input[pos], *out); };
} else if (kernel_name == "ReduceSum") {
reduce_type_ = ReduceType::ReduceSum;
reduce_type_ = kReduceSum;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; };
} else if (kernel_name == "ReduceMean") {
reduce_type_ = ReduceType::ReduceMean;
reduce_type_ = kReduceMean;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; };
} else {
MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name;
@@ -86,7 +86,7 @@ bool ReduceCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
for (size_t i = 1; i < input_size; ++i) {
reduce_func_(input_addr, i, output_addr);
}
if (reduce_type_ == ReduceType::ReduceMean) {
if (reduce_type_ == kReduceMean) {
*output_addr /= input_size;
}
} else {
@@ -126,7 +126,7 @@ bool ReduceCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
reduce_func_(input_addr, iter.GetPos(), &output_addr[i]);
iter.GenNextPos();
}
if (reduce_type_ == ReduceType::ReduceMean) {
if (reduce_type_ == kReduceMean) {
output_addr[i] /= stride;
}
}
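
ReduceMean reuses the ReduceSum accumulator and only differs in the final division by the reduced element count, which is exactly what the two kReduceMean branches above do. A standalone sketch of that accumulate-then-divide pattern:

#include <cstddef>

// Same shape as the kernel's reduce loop: seed with the first element,
// fold the rest in with the ReduceSum lambda body, then post-divide.
float MeanViaSum(const float *input, size_t input_size) {
  float out = input[0];
  for (size_t i = 1; i < input_size; ++i) {
    out += input[i];  // identical to the ReduceSum reduce_func_
  }
  return out / input_size;  // the extra step that turns a sum into a mean
}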

View File

@@ -24,8 +24,6 @@
namespace mindspore {
namespace kernel {
enum class ReduceType { ReduceAll, ReduceAny, ReduceMax, ReduceMin, ReduceSum, ReduceMean };
template <typename T>
class ReduceCPUKernel : public CPUKernel {
public:
@@ -36,6 +34,7 @@ class ReduceCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
enum ReduceType { kReduceAll, kReduceAny, kReduceMax, kReduceMin, kReduceSum, kReduceMean };
std::vector<size_t> input_shape_;
std::vector<int64_t> axis_;
ReduceType reduce_type_;

View File

@@ -13,231 +13,107 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include "backend/kernel_compiler/cpu/slice_cpu_kernel.h"
#include <algorithm>
#include <unordered_map>
#include "common/thread_pool.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 8;
int NormalizeBeginPos(int begin_pos, int dim_len) {
if (begin_pos < 0) {
int normal_pos = begin_pos + dim_len;
return std::max(normal_pos, 0);
}
return std::min(begin_pos, dim_len - 1);
}
void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
std::vector<int64_t> begin_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
(void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
[](const int64_t &value) { return static_cast<int>(value); });
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
MS_EXCEPTION_IF_NULL(prim);
auto strides = prim->GetAttr(STRIDES);
if (strides != nullptr) {
std::vector<int64_t> strides_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
std::vector<int64_t> end_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
(void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(end_me.begin(), end_me.end(), std::back_inserter(end_),
[](const int64_t &value) { return static_cast<int>(value); });
TransArg();
ClipBegin();
} else {
std::vector<int> sizes;
std::vector<int64_t> sizes_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE);
(void)std::transform(sizes_me.begin(), sizes_me.end(), std::back_inserter(sizes),
[](const int64_t &value) { return static_cast<int>(value); });
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) {
MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
}
ClipBegin();
for (size_t i = 0; i < sizes.size(); ++i) {
while (sizes[i] < 0) {
sizes[i] = sizes[i] + SizeToInt(input_shape_[i]);
}
strides_.emplace_back(1);
end_.emplace_back(begin_[i] + sizes[i]);
}
static const std::unordered_map<TypeId, int> type_size_map = {{kNumberTypeBool, sizeof(bool)},
{kNumberTypeInt32, sizeof(int)},
{kNumberTypeFloat32, sizeof(float)},
{kNumberTypeFloat64, sizeof(double)}};
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() > DIMENSION_8D || input_shape.empty()) {
MS_LOG(EXCEPTION) << "Slice only support 1D to 8D input tensor, but got " << input_shape.size() << "D.";
}
ExpandAllMemberDims();
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
auto size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE);
auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
if (begin.size() != input_shape.size() || size.size() != input_shape.size()) {
MS_LOG(EXCEPTION) << "Slice requires the length of begin and size must be equal to input dimension.";
}
InitSliceParam(input_shape, begin, size);
TypeId dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
auto size_pair = type_size_map.find(dtype);
if (size_pair == type_size_map.end()) {
MS_LOG(EXCEPTION) << "Slice supports bool, int32, float32 and float64 input tensor, but got "
<< TypeIdToType(dtype)->ToString();
}
data_size_ = size_pair->second;
}
void SliceCPUKernel::ClipBegin() {
for (size_t i = 0; i < begin_.size(); i++) {
if (begin_[i] < 0) {
auto k = begin_[i] + SizeToInt(input_shape_[i]);
begin_[i] = k < 0 ? 0 : k;
}
if (begin_[i] > SizeToInt(input_shape_[i])) {
begin_[i] = SizeToInt(input_shape_[i]);
}
}
}
void SliceCPUKernel::ParallelRun(void *input_addr, void *output_addr, int thread_num) {
std::vector<common::Task> tasks;
int thread_index = 0;
while (thread_index < thread_num) {
auto block = [&, thread_index]() {
DoSlice(input_addr, output_addr, &slice_param_, thread_index, data_size_);
return common::SUCCESS;
};
tasks.emplace_back(block);
thread_index++;
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}
void SliceCPUKernel::ExpandAllMemberDims() {
auto input_len = input_shape_.size();
if (input_len < 4) {
for (size_t i = 0; i < 4 - input_len; ++i) {
input_shape_.insert(input_shape_.begin(), 1);
begin_.insert(begin_.begin(), 0);
strides_.insert(strides_.begin(), 1);
end_.insert(end_.begin(), 1);
}
}
for (size_t i = 0; i < 4; ++i) {
if (SignOfStride(i)) {
int ax = (end_[i] - begin_[i]) * SignOfStride(i);
if (ax < 0) {
ax = 0;
}
output_shape_.push_back(IntToSize(ax));
void SliceCPUKernel::InitSliceParam(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &size) {
for (size_t i = 0; i < DIMENSION_8D; i++) {
if (i < input_shape.size()) {
int dim_len = SizeToInt(input_shape[i]);
int begin_pos = LongToInt(begin[i]);
int slice_size = LongToInt(size[i]);
if (slice_size <= 0) {
MS_LOG(EXCEPTION) << "Slice requires the slice size of each dimension to be greater than 0.";
}
slice_param_.shape_[i] = dim_len;
slice_param_.size_[i] = slice_size;
slice_param_.begin_[i] = NormalizeBeginPos(begin_pos, dim_len);
int end = slice_param_.begin_[i] + slice_param_.size_[i];
slice_param_.end_[i] = std::min(end, dim_len);
} else {
slice_param_.shape_[i] = 1;
slice_param_.begin_[i] = 0;
slice_param_.size_[i] = 1;
slice_param_.end_[i] = 1;
}
}
slice_param_.param_length_ = DIMENSION_8D;
size_t max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
slice_param_.op_parameter_.thread_num_ = std::min(slice_param_.size_[1], SizeToInt(max_thread_num));
}
bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
bool ret{true};
if (dtype_ == kNumberTypeInt32) {
ret = LaunchKernel<int>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) {
ret = LaunchKernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeBool) {
ret = LaunchKernel<bool>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
ret = LaunchKernel<double>(inputs, outputs);
if (outputs[0]->size == 0) {
return true;
}
auto input_addr = inputs[0]->addr;
auto output_addr = outputs[0]->addr;
int thread_num = slice_param_.op_parameter_.thread_num_;
if (parallel_ && thread_num >= 2) {
ParallelRun(input_addr, output_addr, thread_num);
} else {
MS_LOG(ERROR) << "Slice op only support input_x bool,int32,float32 and float64";
return false;
}
return ret;
}
template <typename T>
bool SliceCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
T *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
int signstride[4] = {SignOfStride(0), SignOfStride(1), SignOfStride(2), SignOfStride(3)};
size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1],
begin_[2] * input_element_num_[2]};
size_t in_step_size[3] = {strides_[0] * input_element_num_[0], strides_[1] * input_element_num_[1],
strides_[2] * input_element_num_[2]};
auto in_n_offset = in_start_offset[0];
auto out_n_offset = 0;
for (int i = begin_[0]; signstride[0] * i < signstride[0] * end_[0];
i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) {
if (can_copy_memory[0]) {
CopyDataToOutput<T>(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0], 0);
continue;
}
auto in_c_offset = in_start_offset[1];
auto out_c_offset = 0;
for (int j = begin_[1]; signstride[1] * j < signstride[1] * end_[1];
j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) {
if (can_copy_memory[1]) {
CopyDataToOutput<T>(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset,
input_element_num_[1], 1);
continue;
}
auto in_h_offset = in_start_offset[2];
auto out_h_offset = 0;
for (int k = begin_[2]; signstride[2] * k < signstride[2] * end_[2];
k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) {
if (can_copy_memory[2]) {
CopyDataToOutput<T>(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
out_n_offset + out_c_offset + out_h_offset, input_element_num_[2], 2);
continue;
}
for (int m = begin_[3]; signstride[3] * m < signstride[3] * end_[3]; m += strides_[3]) {
*output_addr++ = input_addr[in_n_offset + in_c_offset + in_h_offset + m];
}
}
}
}
return true;
}
bool SliceCPUKernel::CanCopyMemoryOnAxis(size_t dim) const {
for (size_t i = dim + 1; i < 4; ++i) {
if (begin_[i] != 0 || end_[i] != SizeToInt(input_shape_[i]) || strides_[i] != 1) {
return false;
}
DoSliceNoParallel(input_addr, output_addr, &slice_param_, data_size_);
}
return true;
}
int SliceCPUKernel::SignOfStride(size_t axis) const {
if (strides_[axis] > 0) {
return 1;
}
return -1;
}
template <typename T>
void SliceCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
size_t copy_num, int id) const {
T *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto in_buff_size = inputs[0]->size;
T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
auto out_buff_size = outputs[0]->size;
if ((in_offset + copy_num) * sizeof(T) > in_buff_size) {
MS_LOG(EXCEPTION) << "input memory out of bounds.";
}
if ((out_offset + copy_num) * sizeof(T) > out_buff_size) {
MS_LOG(EXCEPTION) << id << " output memory out of bounds.";
}
size_t buff_size = out_buff_size - out_offset * sizeof(T);
size_t copy_size = copy_num * sizeof(T);
if (buff_size < copy_size) {
MS_LOG(EXCEPTION) << "output buffer is not enough. memcpy failed!";
}
auto ret = memcpy_s(output_addr + out_offset, copy_size, input_addr + in_offset, copy_size);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret;
}
}
void SliceCPUKernel::TransArg() {
if (strides_.size() != end_.size() || strides_.size() != input_shape_.size()) {
MS_LOG(EXCEPTION) << "stride|end|input size must be equal";
}
for (size_t i = 0; i < strides_.size(); ++i) {
if (strides_[i] == 0) {
MS_LOG(EXCEPTION) << "slice stride cannot be zero";
}
if (end_[i] == 0 && begin_[i] < 0) {
end_[i] = end_[i] + SizeToInt(input_shape_[i]);
}
if (end_[i] < 0) {
end_[i] = end_[i] + SizeToInt(input_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(input_shape_[i]);
}
if (end_[i] > SizeToInt(input_shape_[i])) {
end_[i] = SizeToInt(input_shape_[i]);
}
}
}
void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SliceCPUKernel needs 1 inputs.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output.";
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel olny support 4d or lower.";
}
if (input_shape.size() == 0) {
MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported.";
}
}
} // namespace kernel
} // namespace mindspore
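
NormalizeBeginPos is easiest to read with concrete values; a small hypothetical check against the definition above:

#include <cassert>

void CheckNormalizeBeginPos() {
  assert(NormalizeBeginPos(0, 4) == 0);   // in-range positions pass through
  assert(NormalizeBeginPos(-1, 4) == 3);  // negative indices count from the end
  assert(NormalizeBeginPos(-9, 4) == 0);  // overly negative clamps to 0
  assert(NormalizeBeginPos(10, 4) == 3);  // overly large clamps to dim_len - 1
}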

View File

@@ -13,12 +13,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "nnacl/base/slice_base.h"
namespace mindspore {
namespace kernel {
@@ -33,41 +37,20 @@ class SliceCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;
private:
template <typename T>
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
template <typename T>
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num,
int id) const;
void ExpandAllMemberDims();
bool CanCopyMemoryOnAxis(size_t dim) const;
int SignOfStride(size_t axis) const;
void CheckParam(const CNodePtr &kernel_node) const;
void TransArg();
void ClipBegin();
std::vector<int> begin_;
std::vector<int> end_;
std::vector<int> strides_;
std::vector<size_t> input_shape_;
std::vector<size_t> input_element_num_;
std::vector<size_t> output_shape_;
std::vector<size_t> output_element_num_;
TypeId dtype_{kTypeUnknown};
void InitSliceParam(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
const std::vector<int64_t> &size);
void ParallelRun(void *input_addr, void *output_addr, int thread_num);
bool parallel_{true};
int data_size_{4};
SliceParameter slice_param_;
};
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SliceCPUKernel);
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SliceCPUKernel);
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SliceCPUKernel);
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SliceCPUKernel);
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SliceCPUKernel);
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SliceCPUKernel);
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SliceCPUKernel);
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SliceCPUKernel);
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SliceCPUKernel);
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SliceCPUKernel);
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SliceCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,226 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/stridedslice_cpu_kernel.h"
#include <utility>
#include <functional>
#include <algorithm>
#include <unordered_map>
#include "common/thread_pool.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
enum PosType { kBegin, kEnd };
int NormalizePos(int pos, int dim_len, PosType pos_type) {
if (pos < 0) {
int normal_pos = pos + dim_len;
normal_pos = std::max(normal_pos, 0);
return normal_pos;
}
int max_pos = pos_type == kBegin ? dim_len - 1 : dim_len;
return std::min(pos, max_pos);
}
void StridedSliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (input_shape_.size() > DIMENSION_8D || input_shape_.empty()) {
MS_LOG(EXCEPTION) << "StridedSlice only support 1D to 8D input tensor, but got " << input_shape_.size() << "D.";
}
auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
auto end = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
if (begin.size() != end.size() || begin.size() != stride.size() || begin.size() > input_shape_.size()) {
MS_LOG(EXCEPTION)
<< "StridedSlice requires the lengths of begin, stride and end to be equal and no greater than the input dimension.";
}
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
InitSliceParam(begin, end, stride);
parallel_ = MatchParallelPattern();
if (parallel_) {
InitParallelParam();
}
}
bool StridedSliceCPUKernel::MatchParallelPattern() {
// This function checks whether input and output differ in exactly one
// dimension. If so, the copy can be parallelized along that axis.
// Example 1:
// input shape info: [1, 80, 46, 40]
// output shape info: [1, 80, 20, 40]
// Example 2:
// input shape info: [1, 46, 40]
// output shape info: [1, 20, 40]
if (input_shape_.size() != output_shape_.size()) {
return false;
}
std::vector<int> axis_list;
for (size_t i = 0; i < input_shape_.size(); ++i) {
if (input_shape_[i] != output_shape_[i]) {
axis_list.emplace_back(i);
}
}
if (axis_list.size() == 1) {
split_axis_ = axis_list.front();
return true;
}
return false;
}
void StridedSliceCPUKernel::InitParallelParam() {
outer_ = SizeToInt(
std::accumulate(input_shape_.begin(), input_shape_.begin() + split_axis_, size_t(1), std::multiplies<size_t>()));
inner_ = SizeToInt(
std::accumulate(input_shape_.begin() + split_axis_ + 1, input_shape_.end(), size_t(1), std::multiplies<size_t>()));
int max_thread_num = SizeToInt(common::ThreadPool::GetInstance().GetSyncRunThreadNum());
int thread_num = 1;
if (outer_ == 1) {
parallel_strategy_ = kOnSplitAxis;
thread_num = std::min(SizeToInt(output_shape_[split_axis_]), max_thread_num);
cal_num_per_thread_ = UP_DIV(output_shape_[split_axis_], thread_num);
} else {
parallel_strategy_ = kOnOuter;
thread_num = std::min(outer_, max_thread_num);
cal_num_per_thread_ = UP_DIV(outer_, thread_num);
}
slice_param_.op_parameter_.thread_num_ = thread_num;
}
void StridedSliceCPUKernel::InitSliceParam(const std::vector<int64_t> &begin, const std::vector<int64_t> &end,
const std::vector<int64_t> &stride) {
static const std::unordered_map<TypeId, std::pair<LiteDataType, int>> type_convert_map = {
{kNumberTypeBool, {kDataTypeBool, sizeof(bool)}},
{kNumberTypeInt32, {kDataTypeInt, sizeof(int)}},
{kNumberTypeFloat32, {kDataTypeFloat, sizeof(float)}},
{kNumberTypeFloat64, {kDataTypeFloat64, sizeof(double)}}};
auto type_pair = type_convert_map.find(dtype_);
if (type_pair == type_convert_map.end()) {
MS_LOG(EXCEPTION) << "StridedSlice supports bool, int32, float32 and float64 input tensor, but got "
<< TypeIdToType(dtype_)->ToString();
}
data_size_ = type_pair->second.second;
slice_param_.data_type = type_pair->second.first;
for (size_t i = 0; i < DIMENSION_8D; i++) {
if (i < begin.size()) {
int dim_len = SizeToInt(input_shape_[i]);
int begin_pos = LongToInt(begin[i]);
int end_pos = LongToInt(end[i]);
int stride_size = LongToInt(stride[i]);
if (stride_size == 0) {
MS_LOG(EXCEPTION) << "StridedSlice requires the each dimension slice stride can't be 0.";
}
slice_param_.in_shape_[i] = dim_len;
slice_param_.strides_[i] = stride_size;
slice_param_.begins_[i] = NormalizePos(begin_pos, dim_len, kBegin);
slice_param_.ends_[i] = NormalizePos(end_pos, dim_len, kEnd);
if (slice_param_.ends_[i] <= slice_param_.begins_[i] && slice_param_.strides_[i] > 0) {
slice_param_.ends_[i] = slice_param_.begins_[i] + 1;
}
if (slice_param_.ends_[i] >= slice_param_.begins_[i] && slice_param_.strides_[i] < 0) {
slice_param_.ends_[i] = slice_param_.begins_[i] - 1;
}
} else if (i < input_shape_.size()) {
int dim_len = SizeToInt(input_shape_[i]);
slice_param_.in_shape_[i] = dim_len;
slice_param_.begins_[i] = 0;
slice_param_.ends_[i] = dim_len;
slice_param_.strides_[i] = 1;
} else {
slice_param_.in_shape_[i] = 1;
slice_param_.begins_[i] = 0;
slice_param_.ends_[i] = 1;
slice_param_.strides_[i] = 1;
}
}
slice_param_.in_shape_length_ = DIMENSION_8D;
slice_param_.num_axes_ = DIMENSION_8D;
}
int StridedSliceCPUKernel::RunTaskOnOuter(uint8_t *input_addr, uint8_t *output_addr, int start_pos) {
int begin_index = slice_param_.begins_[split_axis_];
int inner_size = inner_ * data_size_;
uint8_t *cur_in_ptr = input_addr + (start_pos * input_shape_[split_axis_] + begin_index) * inner_size;
uint8_t *cur_out_ptr = output_addr + start_pos * output_shape_[split_axis_] * inner_size;
int cur_outer = outer_ - start_pos;
if (cur_outer <= 0) {
return common::SUCCESS;
}
cur_outer = cur_outer > cal_num_per_thread_ ? cal_num_per_thread_ : cur_outer;
FastStride(cur_in_ptr, cur_out_ptr, output_shape_[split_axis_], slice_param_.strides_[split_axis_], cur_outer,
inner_size, input_shape_[split_axis_] * inner_size);
return common::SUCCESS;
}
int StridedSliceCPUKernel::RunTaskOnSplitAxis(uint8_t *input_addr, uint8_t *output_addr, int start_pos) {
int begin_index = slice_param_.begins_[split_axis_];
int inner_size = inner_ * data_size_;
uint8_t *cur_in_ptr = input_addr + (start_pos * slice_param_.strides_[split_axis_] + begin_index) * inner_size;
uint8_t *cur_out_ptr = output_addr + start_pos * inner_size;
int cal_axis_num = output_shape_[split_axis_] - start_pos;
if (cal_axis_num <= 0) {
return common::SUCCESS;
}
cal_axis_num = cal_axis_num > cal_num_per_thread_ ? cal_num_per_thread_ : cal_axis_num;
FastStride(cur_in_ptr, cur_out_ptr, cal_axis_num, slice_param_.strides_[split_axis_], 1, inner_size, 0);
return common::SUCCESS;
}
void StridedSliceCPUKernel::ParallelRun(uint8_t *input_addr, uint8_t *output_addr, int thread_num) {
int thread_index = 0;
std::vector<common::Task> tasks;
std::function<int(StridedSliceCPUKernel *, uint8_t *, uint8_t *, int)> execute_func;
if (parallel_strategy_ == kOnOuter) {
execute_func = &StridedSliceCPUKernel::RunTaskOnOuter;
} else if (parallel_strategy_ == kOnSplitAxis) {
execute_func = &StridedSliceCPUKernel::RunTaskOnSplitAxis;
} else {
MS_LOG(EXCEPTION) << "Not supported parallel execute strategy for StridedSlice.";
}
while (thread_index < thread_num) {
tasks.emplace_back(std::bind(execute_func, this, input_addr, output_addr, thread_index * cal_num_per_thread_));
thread_index++;
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}
bool StridedSliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (outputs[0]->size == 0) {
return true;
}
auto input_addr = reinterpret_cast<uint8_t *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<uint8_t *>(outputs[0]->addr);
int thread_num = slice_param_.op_parameter_.thread_num_;
if (parallel_ && thread_num >= 2) {
ParallelRun(input_addr, output_addr, thread_num);
} else {
DoStridedSlice(input_addr, output_addr, &slice_param_);
}
return true;
}
} // namespace kernel
} // namespace mindspore
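
FastStride comes from nnacl and is not shown in this commit. A hypothetical sketch of the copy loop both RunTask* functions rely on, inferred from the call sites above; the real signature and body may differ:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Assumed behaviour: for each of `outer` rows, copy `count` blocks of
// `inner_size` bytes spaced `stride` blocks apart; `in_offset` is the byte
// distance between consecutive input rows (0 when there is a single row).
void FastStrideSketch(const uint8_t *input, uint8_t *output, int count, int stride,
                      int outer, size_t inner_size, size_t in_offset) {
  for (int i = 0; i < outer; ++i) {
    const uint8_t *cur_in = input;
    for (int j = 0; j < count; ++j) {
      memcpy(output, cur_in, inner_size);
      cur_in += static_cast<ptrdiff_t>(stride) * static_cast<ptrdiff_t>(inner_size);  // stride may be negative
      output += inner_size;
    }
    input += in_offset;
  }
}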

View File

@@ -0,0 +1,74 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_STRIDEDSLICE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_STRIDEDSLICE_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "nnacl/fp32/strided_slice_fp32.h"
namespace mindspore {
namespace kernel {
class StridedSliceCPUKernel : public CPUKernel {
public:
StridedSliceCPUKernel() = default;
~StridedSliceCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
enum ParallelStrategy { kOnSplitAxis, kOnOuter };
void InitSliceParam(const std::vector<int64_t> &begin, const std::vector<int64_t> &end,
const std::vector<int64_t> &stride);
bool MatchParallelPattern();
void InitParallelParam();
void ParallelRun(uint8_t *input_addr, uint8_t *output_addr, int thread_num);
int RunTaskOnOuter(uint8_t *input_addr, uint8_t *output_addr, int start_pos);
int RunTaskOnSplitAxis(uint8_t *input_addr, uint8_t *output_addr, int start_pos);
TypeId dtype_;
int data_size_{4};
int split_axis_{-1};
int inner_{1};
int outer_{1};
int cal_num_per_thread_{1};
bool parallel_{false};
ParallelStrategy parallel_strategy_{kOnSplitAxis};
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
StridedSliceParameter slice_param_;
};
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
StridedSliceCPUKernel);
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
StridedSliceCPUKernel);
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
StridedSliceCPUKernel);
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
StridedSliceCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_STRIDEDSLICE_CPU_KERNEL_H_
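
To make the strategy selection concrete, a hypothetical walk-through of InitParallelParam using Example 1 from the MatchParallelPattern comment in the .cc file:

// input [1, 80, 46, 40] -> output [1, 80, 20, 40]: only dim 2 differs,
// so split_axis_ = 2.
// outer_ = 1 * 80 = 80   (product of dims before the split axis)
// inner_ = 40            (product of dims after the split axis)
// outer_ > 1             -> parallel_strategy_ = kOnOuter
// thread_num = min(outer_, max_thread_num)
// cal_num_per_thread_ = UP_DIV(outer_, thread_num), i.e. each task strides
// RunTaskOnOuter over at most that many outer rows.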

View File

@@ -64,7 +64,7 @@ int ReduceCPUKernel::CallReduceUnit(int task_id) {
}
reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_),
static_cast<float *>(dst_data_), task_id, context_->thread_num_);
} else if (data_type_ == KDataTypeBool) {
} else if (data_type_ == kDataTypeBool) {
if (!bool_reducer_) {
MS_LOG(ERROR) << "function bool_reducer_ is null.";
return RET_NULL_PTR;
@@ -96,7 +96,7 @@ int ReduceCPUKernel::Run() {
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
data_type_ = kDataTypeFloat;
} else if (in_tensors().at(0)->data_type() == kNumberTypeBool) {
data_type_ = KDataTypeBool;
data_type_ = kDataTypeBool;
} else {
data_type_ = kDataTypeInt;
}
@@ -183,7 +183,7 @@ int ReduceCPUKernel::MallocTmpBuffer() {
void *buffer = nullptr;
if (data_type_ == kDataTypeFloat) {
buffer = context_->allocator->Malloc(size * sizeof(float));
} else if (data_type_ == KDataTypeBool) {
} else if (data_type_ == kDataTypeBool) {
buffer = context_->allocator->Malloc(size * sizeof(bool));
} else {
buffer = context_->allocator->Malloc(size * sizeof(int));