forked from mindspore-Ecosystem/mindspore
convert the implementation of Slice and StridedSlice CPU operators to nnacl
This commit is contained in:
parent
9edf30bd05
commit
4a703e437e
|
@ -226,27 +226,14 @@ bool EltWiseGradCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inpu
|
|||
{"GeLUGrad", &EltWiseGradCPUKernel<T>::GeluGrad}, {"AsinGrad", &EltWiseGradCPUKernel<T>::AsinGrad},
|
||||
{"ACosGrad", &EltWiseGradCPUKernel<T>::ACosGrad}, {"AtanGrad", &EltWiseGradCPUKernel<T>::AtanGrad},
|
||||
{"AsinhGrad", &EltWiseGradCPUKernel<T>::AsinhGrad}, {"AcoshGrad", &EltWiseGradCPUKernel<T>::AcoshGrad}};
|
||||
T *input1 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
T *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
const auto *input1 = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
const auto *input2 = reinterpret_cast<T *>(inputs[1]->addr);
|
||||
auto *output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
size_t count = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
|
||||
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||
const float block_size = 128.0;
|
||||
size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num;
|
||||
std::vector<common::Task> tasks;
|
||||
size_t start = 0;
|
||||
size_t once_compute_size = (count + thread_num - 1) / thread_num;
|
||||
while (start < count) {
|
||||
size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size);
|
||||
auto block = [&, start, end]() {
|
||||
elt_map.at(kernel_name_)(this, input1, input2, output, start, end);
|
||||
return common::SUCCESS;
|
||||
};
|
||||
tasks.emplace_back(block);
|
||||
start += once_compute_size;
|
||||
}
|
||||
common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
CPUKernelUtils::ParallelFor(
|
||||
std::bind(elt_map.at(kernel_name_), this, input1, input2, output, std::placeholders::_1, std::placeholders::_2),
|
||||
count);
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -72,7 +72,7 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p
|
|||
if (in_data == NULL || out_data == NULL || param == NULL) {
|
||||
return NNACL_NULL_PTR;
|
||||
}
|
||||
if (param->num_axes_ > DIMENSION_6D) {
|
||||
if (param->num_axes_ > DIMENSION_8D) {
|
||||
return NNACL_PARAM_INVALID;
|
||||
}
|
||||
|
||||
|
@ -107,6 +107,10 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p
|
|||
*((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset);
|
||||
} else if (param->data_type == kDataTypeInt) {
|
||||
*((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset);
|
||||
} else if (param->data_type == kDataTypeFloat64) {
|
||||
*((double *)out_data + out_offset) = *((double *)in_data + in_offset);
|
||||
} else if (param->data_type == kDataTypeBool) {
|
||||
*((bool *)out_data + out_offset) = *((bool *)in_data + in_offset);
|
||||
#ifdef ENABLE_ARM64
|
||||
} else if (param->data_type == kDataTypeFloat16) {
|
||||
*((float16_t *)out_data + out_offset) = *((float16_t *)in_data + in_offset);
|
||||
|
|
|
@ -69,7 +69,8 @@ typedef enum LiteDataType {
|
|||
kDataTypeFloat16,
|
||||
kDataTypeInt,
|
||||
kDataTypeInt8,
|
||||
KDataTypeBool,
|
||||
kDataTypeBool,
|
||||
kDataTypeFloat64
|
||||
} LiteDataType;
|
||||
|
||||
typedef enum DataOrder {
|
||||
|
|
|
@ -46,26 +46,26 @@ void ReduceCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
|
||||
if constexpr (std::is_same<T, bool>::value) {
|
||||
if (kernel_name == "ReduceAll") {
|
||||
reduce_type_ = ReduceType::ReduceAll;
|
||||
reduce_type_ = kReduceAll;
|
||||
reduce_func_ = [](const T *input, size_t pos, T *out) { *out &= input[pos]; };
|
||||
} else if (kernel_name == "ReduceAny") {
|
||||
reduce_type_ = ReduceType::ReduceAny;
|
||||
reduce_type_ = kReduceAny;
|
||||
reduce_func_ = [](const T *input, size_t pos, T *out) { *out |= input[pos]; };
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name_ << " for bool.";
|
||||
}
|
||||
} else {
|
||||
if (kernel_name == "ReduceMax") {
|
||||
reduce_type_ = ReduceType::ReduceMax;
|
||||
reduce_type_ = kReduceMax;
|
||||
reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::max(input[pos], *out); };
|
||||
} else if (kernel_name == "ReduceMin") {
|
||||
reduce_type_ = ReduceType::ReduceMin;
|
||||
reduce_type_ = kReduceMin;
|
||||
reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::min(input[pos], *out); };
|
||||
} else if (kernel_name == "ReduceSum") {
|
||||
reduce_type_ = ReduceType::ReduceSum;
|
||||
reduce_type_ = kReduceSum;
|
||||
reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; };
|
||||
} else if (kernel_name == "ReduceMean") {
|
||||
reduce_type_ = ReduceType::ReduceMean;
|
||||
reduce_type_ = kReduceMean;
|
||||
reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; };
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name;
|
||||
|
@ -86,7 +86,7 @@ bool ReduceCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
for (size_t i = 1; i < input_size; ++i) {
|
||||
reduce_func_(input_addr, i, output_addr);
|
||||
}
|
||||
if (reduce_type_ == ReduceType::ReduceMean) {
|
||||
if (reduce_type_ == kReduceMean) {
|
||||
*output_addr /= input_size;
|
||||
}
|
||||
} else {
|
||||
|
@ -126,7 +126,7 @@ bool ReduceCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|||
reduce_func_(input_addr, iter.GetPos(), &output_addr[i]);
|
||||
iter.GenNextPos();
|
||||
}
|
||||
if (reduce_type_ == ReduceType::ReduceMean) {
|
||||
if (reduce_type_ == kReduceMean) {
|
||||
output_addr[i] /= stride;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,8 +24,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
enum class ReduceType { ReduceAll, ReduceAny, ReduceMax, ReduceMin, ReduceSum, ReduceMean };
|
||||
|
||||
template <typename T>
|
||||
class ReduceCPUKernel : public CPUKernel {
|
||||
public:
|
||||
|
@ -36,6 +34,7 @@ class ReduceCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
enum ReduceType { kReduceAll, kReduceAny, kReduceMax, kReduceMin, kReduceSum, kReduceMean };
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<int64_t> axis_;
|
||||
ReduceType reduce_type_;
|
||||
|
|
|
@ -13,231 +13,107 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <algorithm>
|
||||
|
||||
#include "backend/kernel_compiler/cpu/slice_cpu_kernel.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "common/thread_pool.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr int MAX_DIMS = 8;
|
||||
int NormalizeBeginPos(int begin_pos, int dim_len) {
|
||||
if (begin_pos < 0) {
|
||||
int normal_pos = begin_pos + dim_len;
|
||||
return std::max(normal_pos, 0);
|
||||
}
|
||||
return std::min(begin_pos, dim_len - 1);
|
||||
}
|
||||
|
||||
void SliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
CheckParam(kernel_node);
|
||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
std::vector<int64_t> begin_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
|
||||
(void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
|
||||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto strides = prim->GetAttr(STRIDES);
|
||||
if (strides != nullptr) {
|
||||
std::vector<int64_t> strides_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
|
||||
std::vector<int64_t> end_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
|
||||
(void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_),
|
||||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
(void)std::transform(end_me.begin(), end_me.end(), std::back_inserter(end_),
|
||||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
TransArg();
|
||||
ClipBegin();
|
||||
} else {
|
||||
std::vector<int> sizes;
|
||||
std::vector<int64_t> sizes_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE);
|
||||
(void)std::transform(sizes_me.begin(), sizes_me.end(), std::back_inserter(sizes),
|
||||
[](const int64_t &value) { return static_cast<int>(value); });
|
||||
if (sizes.size() != input_shape_.size() || begin_.size() != input_shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "begin|size|input size must be equal";
|
||||
}
|
||||
ClipBegin();
|
||||
for (size_t i = 0; i < sizes.size(); ++i) {
|
||||
while (sizes[i] < 0) {
|
||||
sizes[i] = sizes[i] + SizeToInt(input_shape_[i]);
|
||||
}
|
||||
strides_.emplace_back(1);
|
||||
end_.emplace_back(begin_[i] + sizes[i]);
|
||||
}
|
||||
static const std::unordered_map<TypeId, int> type_size_map = {{kNumberTypeBool, sizeof(bool)},
|
||||
{kNumberTypeInt32, sizeof(int)},
|
||||
{kNumberTypeFloat32, sizeof(float)},
|
||||
{kNumberTypeFloat64, sizeof(double)}};
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (input_shape.size() > DIMENSION_8D || input_shape.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Slice only support 1D to 8D input tensor, but got " << input_shape.size() << "D.";
|
||||
}
|
||||
ExpandAllMemberDims();
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
||||
auto size = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE);
|
||||
auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
|
||||
if (begin.size() != input_shape.size() || size.size() != input_shape.size()) {
|
||||
MS_LOG(EXCEPTION) << "Slice requires the length of begin and size must be equal to input dimension.";
|
||||
}
|
||||
InitSliceParam(input_shape, begin, size);
|
||||
|
||||
TypeId dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
auto size_pair = type_size_map.find(dtype);
|
||||
if (size_pair == type_size_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Slice supports bool, int32, float32 and float64 input tensor, but got "
|
||||
<< TypeIdToType(dtype)->ToString();
|
||||
}
|
||||
data_size_ = size_pair->second;
|
||||
}
|
||||
void SliceCPUKernel::ClipBegin() {
|
||||
for (size_t i = 0; i < begin_.size(); i++) {
|
||||
if (begin_[i] < 0) {
|
||||
auto k = begin_[i] + SizeToInt(input_shape_[i]);
|
||||
begin_[i] = k < 0 ? 0 : k;
|
||||
}
|
||||
if (begin_[i] > SizeToInt(input_shape_[i])) {
|
||||
begin_[i] = SizeToInt(input_shape_[i]);
|
||||
}
|
||||
|
||||
void SliceCPUKernel::ParallelRun(void *input_addr, void *output_addr, int thread_num) {
|
||||
std::vector<common::Task> tasks;
|
||||
int thread_index = 0;
|
||||
while (thread_index < thread_num) {
|
||||
auto block = [&, thread_index]() {
|
||||
DoSlice(input_addr, output_addr, &slice_param_, thread_index, data_size_);
|
||||
return common::SUCCESS;
|
||||
};
|
||||
tasks.emplace_back(block);
|
||||
thread_index++;
|
||||
}
|
||||
common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
}
|
||||
void SliceCPUKernel::ExpandAllMemberDims() {
|
||||
auto input_len = input_shape_.size();
|
||||
if (input_len < 4) {
|
||||
for (size_t i = 0; i < 4 - input_len; ++i) {
|
||||
input_shape_.insert(input_shape_.begin(), 1);
|
||||
begin_.insert(begin_.begin(), 0);
|
||||
strides_.insert(strides_.begin(), 1);
|
||||
end_.insert(end_.begin(), 1);
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < 4; ++i) {
|
||||
if (SignOfStride(i)) {
|
||||
int ax = (end_[i] - begin_[i]) * SignOfStride(i);
|
||||
if (ax < 0) {
|
||||
ax = 0;
|
||||
|
||||
void SliceCPUKernel::InitSliceParam(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
|
||||
const std::vector<int64_t> &size) {
|
||||
for (size_t i = 0; i < DIMENSION_8D; i++) {
|
||||
if (i < input_shape.size()) {
|
||||
int dim_len = SizeToInt(input_shape[i]);
|
||||
int begin_pos = LongToInt(begin[i]);
|
||||
int slice_size = LongToInt(size[i]);
|
||||
if (slice_size <= 0) {
|
||||
MS_LOG(EXCEPTION) << "Slice requires the each dimension slice size must be greater than 0.";
|
||||
}
|
||||
output_shape_.push_back(IntToSize(ax));
|
||||
slice_param_.shape_[i] = dim_len;
|
||||
slice_param_.size_[i] = slice_size;
|
||||
slice_param_.begin_[i] = NormalizeBeginPos(begin_pos, dim_len);
|
||||
int end = slice_param_.begin_[i] + slice_param_.size_[i];
|
||||
slice_param_.end_[i] = std::min(end, dim_len);
|
||||
} else {
|
||||
slice_param_.shape_[i] = 1;
|
||||
slice_param_.begin_[i] = 0;
|
||||
slice_param_.size_[i] = 1;
|
||||
slice_param_.end_[i] = 1;
|
||||
}
|
||||
}
|
||||
slice_param_.param_length_ = DIMENSION_8D;
|
||||
|
||||
size_t max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||
slice_param_.op_parameter_.thread_num_ = std::min(slice_param_.size_[1], SizeToInt(max_thread_num));
|
||||
}
|
||||
|
||||
bool SliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
bool ret{true};
|
||||
if (dtype_ == kNumberTypeInt32) {
|
||||
ret = LaunchKernel<int>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat32) {
|
||||
ret = LaunchKernel<float>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeBool) {
|
||||
ret = LaunchKernel<bool>(inputs, outputs);
|
||||
} else if (dtype_ == kNumberTypeFloat64) {
|
||||
ret = LaunchKernel<double>(inputs, outputs);
|
||||
if (outputs[0]->size == 0) {
|
||||
return true;
|
||||
}
|
||||
auto input_addr = inputs[0]->addr;
|
||||
auto output_addr = outputs[0]->addr;
|
||||
int thread_num = slice_param_.op_parameter_.thread_num_;
|
||||
if (parallel_ && thread_num >= 2) {
|
||||
ParallelRun(input_addr, output_addr, thread_num);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Slice op only support input_x bool,int32,float32 and float64";
|
||||
return false;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool SliceCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
T *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
bool can_copy_memory[3] = {CanCopyMemoryOnAxis(0), CanCopyMemoryOnAxis(1), CanCopyMemoryOnAxis(2)};
|
||||
int signstride[4] = {SignOfStride(0), SignOfStride(1), SignOfStride(2), SignOfStride(3)};
|
||||
size_t in_start_offset[3] = {begin_[0] * input_element_num_[0], begin_[1] * input_element_num_[1],
|
||||
begin_[2] * input_element_num_[2]};
|
||||
size_t in_step_size[3] = {strides_[0] * input_element_num_[0], strides_[1] * input_element_num_[1],
|
||||
strides_[2] * input_element_num_[2]};
|
||||
|
||||
auto in_n_offset = in_start_offset[0];
|
||||
auto out_n_offset = 0;
|
||||
for (int i = begin_[0]; signstride[0] * i < signstride[0] * end_[0];
|
||||
i += strides_[0], in_n_offset += in_step_size[0], out_n_offset += output_element_num_[0]) {
|
||||
if (can_copy_memory[0]) {
|
||||
CopyDataToOutput<T>(inputs, in_n_offset, outputs, out_n_offset, input_element_num_[0], 0);
|
||||
continue;
|
||||
}
|
||||
auto in_c_offset = in_start_offset[1];
|
||||
auto out_c_offset = 0;
|
||||
for (int j = begin_[1]; signstride[1] * j < signstride[1] * end_[1];
|
||||
j += strides_[1], in_c_offset += in_step_size[1], out_c_offset += output_element_num_[1]) {
|
||||
if (can_copy_memory[1]) {
|
||||
CopyDataToOutput<T>(inputs, in_n_offset + in_c_offset, outputs, out_n_offset + out_c_offset,
|
||||
input_element_num_[1], 1);
|
||||
continue;
|
||||
}
|
||||
auto in_h_offset = in_start_offset[2];
|
||||
auto out_h_offset = 0;
|
||||
for (int k = begin_[2]; signstride[2] * k < signstride[2] * end_[2];
|
||||
k += strides_[2], in_h_offset += in_step_size[2], out_h_offset += output_element_num_[2]) {
|
||||
if (can_copy_memory[2]) {
|
||||
CopyDataToOutput<T>(inputs, in_n_offset + in_c_offset + in_h_offset, outputs,
|
||||
out_n_offset + out_c_offset + out_h_offset, input_element_num_[2], 2);
|
||||
continue;
|
||||
}
|
||||
for (int m = begin_[3]; signstride[3] * m < signstride[3] * end_[3]; m += strides_[3]) {
|
||||
*output_addr++ = input_addr[in_n_offset + in_c_offset + in_h_offset + m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SliceCPUKernel::CanCopyMemoryOnAxis(size_t dim) const {
|
||||
for (size_t i = dim + 1; i < 4; ++i) {
|
||||
if (begin_[i] != 0 || end_[i] != SizeToInt(input_shape_[i]) || strides_[i] != 1) {
|
||||
return false;
|
||||
}
|
||||
DoSliceNoParallel(input_addr, output_addr, &slice_param_, data_size_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int SliceCPUKernel::SignOfStride(size_t axis) const {
|
||||
if (strides_[axis] > 0) {
|
||||
return 1;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SliceCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
|
||||
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset,
|
||||
size_t copy_num, int id) const {
|
||||
T *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto in_buff_size = inputs[0]->size;
|
||||
T *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
auto out_buff_size = outputs[0]->size;
|
||||
|
||||
if ((in_offset + copy_num) * sizeof(T) > in_buff_size) {
|
||||
MS_LOG(EXCEPTION) << "input memory out of bounds.";
|
||||
}
|
||||
if ((out_offset + copy_num) * sizeof(T) > out_buff_size) {
|
||||
MS_LOG(EXCEPTION) << id << " output memory out of bounds.";
|
||||
}
|
||||
|
||||
size_t buff_size = out_buff_size - out_offset * sizeof(T);
|
||||
size_t copy_size = copy_num * sizeof(T);
|
||||
if (buff_size < copy_size) {
|
||||
MS_LOG(EXCEPTION) << "output buffer is not enough. memcpy failed!";
|
||||
}
|
||||
auto ret = memcpy_s(output_addr + out_offset, copy_size, input_addr + in_offset, copy_size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(EXCEPTION) << "memcpy failed. ret:" << ret;
|
||||
}
|
||||
}
|
||||
|
||||
void SliceCPUKernel::TransArg() {
|
||||
if (strides_.size() != end_.size() || strides_.size() != input_shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "stride|end|input size must be equal";
|
||||
}
|
||||
for (size_t i = 0; i < strides_.size(); ++i) {
|
||||
if (strides_[i] == 0) {
|
||||
MS_LOG(EXCEPTION) << "slice stride cannot be zero";
|
||||
}
|
||||
if (end_[i] == 0 && begin_[i] < 0) {
|
||||
end_[i] = end_[i] + SizeToInt(input_shape_[i]);
|
||||
}
|
||||
if (end_[i] < 0) {
|
||||
end_[i] = end_[i] + SizeToInt(input_shape_[i]) < 0 ? 0 : end_[i] + SizeToInt(input_shape_[i]);
|
||||
}
|
||||
if (end_[i] > SizeToInt(input_shape_[i])) {
|
||||
end_[i] = SizeToInt(input_shape_[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SliceCPUKernel::CheckParam(const CNodePtr &kernel_node) const {
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but SliceCPUKernel needs 1 inputs.";
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but SliceCPUKernel needs 1 output.";
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (input_shape.size() > MAX_DIMS) {
|
||||
MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but SliceCPUKernel olny support 4d or lower.";
|
||||
}
|
||||
if (input_shape.size() == 0) {
|
||||
MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", scalar is not supported.";
|
||||
}
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -13,12 +13,16 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "nnacl/base/slice_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
@ -33,41 +37,20 @@ class SliceCPUKernel : public CPUKernel {
|
|||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);
|
||||
template <typename T>
|
||||
void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t in_offset,
|
||||
const std::vector<kernel::AddressPtr> &outputs, size_t out_offset, size_t copy_num,
|
||||
int id) const;
|
||||
void ExpandAllMemberDims();
|
||||
bool CanCopyMemoryOnAxis(size_t dim) const;
|
||||
int SignOfStride(size_t axis) const;
|
||||
void CheckParam(const CNodePtr &kernel_node) const;
|
||||
void TransArg();
|
||||
void ClipBegin();
|
||||
std::vector<int> begin_;
|
||||
std::vector<int> end_;
|
||||
std::vector<int> strides_;
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> input_element_num_;
|
||||
std::vector<size_t> output_shape_;
|
||||
std::vector<size_t> output_element_num_;
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
void InitSliceParam(const std::vector<size_t> &input_shape, const std::vector<int64_t> &begin,
|
||||
const std::vector<int64_t> &size);
|
||||
void ParallelRun(void *input_addr, void *output_addr, int thread_num);
|
||||
|
||||
bool parallel_{true};
|
||||
int data_size_{4};
|
||||
SliceParameter slice_param_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
MS_REG_CPU_KERNEL(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SliceCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,226 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/stridedslice_cpu_kernel.h"
|
||||
|
||||
#include <utility>
|
||||
#include <functional>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "common/thread_pool.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
enum PosType { kBegin, kEnd };
|
||||
|
||||
int NormalizePos(int pos, int dim_len, PosType pos_type) {
|
||||
if (pos < 0) {
|
||||
int normal_pos = pos + dim_len;
|
||||
normal_pos = std::max(normal_pos, 0);
|
||||
return normal_pos;
|
||||
}
|
||||
int max_pos = pos_type == kBegin ? dim_len - 1 : dim_len;
|
||||
return std::min(pos, max_pos);
|
||||
}
|
||||
|
||||
void StridedSliceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
if (input_shape_.size() > DIMENSION_8D || input_shape_.empty()) {
|
||||
MS_LOG(EXCEPTION) << "StridedSlice only support 1D to 8D input tensor, but got " << input_shape_.size() << "D.";
|
||||
}
|
||||
|
||||
auto begin = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
|
||||
auto end = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
|
||||
auto stride = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
|
||||
if (begin.size() != end.size() || begin.size() != stride.size() || begin.size() > input_shape_.size()) {
|
||||
MS_LOG(EXCEPTION)
|
||||
<< "StridedSLice requires the length of begin, stride and end must be equal and less than input dimension.";
|
||||
}
|
||||
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
||||
InitSliceParam(begin, end, stride);
|
||||
|
||||
parallel_ = MatchParallelPattern();
|
||||
if (parallel_) {
|
||||
InitParallelParam();
|
||||
}
|
||||
}
|
||||
|
||||
bool StridedSliceCPUKernel::MatchParallelPattern() {
|
||||
// This function is seeking if that the number of only one dimension
|
||||
// is different between input and output. If so, we can do some trick.
|
||||
// Example 1:
|
||||
// input shape info: [1, 80, 46, 40]
|
||||
// output shape info: [1, 80, 20, 40]
|
||||
// Example 2:
|
||||
// input shape info: [1, 46, 40]
|
||||
// output shape info: [1, 20, 40]
|
||||
if (input_shape_.size() != output_shape_.size()) {
|
||||
return false;
|
||||
}
|
||||
std::vector<int> axis_list;
|
||||
for (size_t i = 0; i < input_shape_.size(); ++i) {
|
||||
if (input_shape_[i] != output_shape_[i]) {
|
||||
axis_list.emplace_back(i);
|
||||
}
|
||||
}
|
||||
if (axis_list.size() == 1) {
|
||||
split_axis_ = axis_list.front();
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void StridedSliceCPUKernel::InitParallelParam() {
|
||||
outer_ = SizeToInt(
|
||||
std::accumulate(input_shape_.begin(), input_shape_.begin() + split_axis_, size_t(1), std::multiplies<size_t>()));
|
||||
inner_ = SizeToInt(
|
||||
std::accumulate(input_shape_.begin() + split_axis_ + 1, input_shape_.end(), size_t(1), std::multiplies<size_t>()));
|
||||
|
||||
int max_thread_num = SizeToInt(common::ThreadPool::GetInstance().GetSyncRunThreadNum());
|
||||
int thread_num = 1;
|
||||
if (outer_ == 1) {
|
||||
parallel_strategy_ = kOnSplitAxis;
|
||||
thread_num = std::min(SizeToInt(output_shape_[split_axis_]), max_thread_num);
|
||||
cal_num_per_thread_ = UP_DIV(output_shape_[split_axis_], thread_num);
|
||||
} else {
|
||||
parallel_strategy_ = kOnOuter;
|
||||
thread_num = std::min(outer_, max_thread_num);
|
||||
cal_num_per_thread_ = UP_DIV(outer_, thread_num);
|
||||
}
|
||||
slice_param_.op_parameter_.thread_num_ = thread_num;
|
||||
}
|
||||
|
||||
void StridedSliceCPUKernel::InitSliceParam(const std::vector<int64_t> &begin, const std::vector<int64_t> &end,
|
||||
const std::vector<int64_t> &stride) {
|
||||
static const std::unordered_map<TypeId, std::pair<LiteDataType, int>> type_convert_map = {
|
||||
{kNumberTypeBool, {kDataTypeBool, sizeof(bool)}},
|
||||
{kNumberTypeInt32, {kDataTypeInt, sizeof(int)}},
|
||||
{kNumberTypeFloat32, {kDataTypeFloat, sizeof(float)}},
|
||||
{kNumberTypeFloat64, {kDataTypeFloat64, sizeof(double)}}};
|
||||
|
||||
auto type_pair = type_convert_map.find(dtype_);
|
||||
if (type_pair == type_convert_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "StridedSlice supports bool, int32, float32 and float64 input tensor, but got "
|
||||
<< TypeIdToType(dtype_)->ToString();
|
||||
}
|
||||
data_size_ = type_pair->second.second;
|
||||
slice_param_.data_type = type_pair->second.first;
|
||||
|
||||
for (size_t i = 0; i < DIMENSION_8D; i++) {
|
||||
if (i < begin.size()) {
|
||||
int dim_len = SizeToInt(input_shape_[i]);
|
||||
int begin_pos = LongToInt(begin[i]);
|
||||
int end_pos = LongToInt(end[i]);
|
||||
int stride_size = LongToInt(stride[i]);
|
||||
if (stride_size == 0) {
|
||||
MS_LOG(EXCEPTION) << "StridedSlice requires the each dimension slice stride can't be 0.";
|
||||
}
|
||||
slice_param_.in_shape_[i] = dim_len;
|
||||
slice_param_.strides_[i] = stride_size;
|
||||
slice_param_.begins_[i] = NormalizePos(begin_pos, dim_len, kBegin);
|
||||
slice_param_.ends_[i] = NormalizePos(end_pos, dim_len, kEnd);
|
||||
if (slice_param_.ends_[i] <= slice_param_.begins_[i] && slice_param_.strides_[i] > 0) {
|
||||
slice_param_.ends_[i] = slice_param_.begins_[i] + 1;
|
||||
}
|
||||
if (slice_param_.ends_[i] >= slice_param_.begins_[i] && slice_param_.strides_[i] < 0) {
|
||||
slice_param_.ends_[i] = slice_param_.begins_[i] - 1;
|
||||
}
|
||||
} else if (i < input_shape_.size()) {
|
||||
int dim_len = SizeToInt(input_shape_[i]);
|
||||
slice_param_.in_shape_[i] = dim_len;
|
||||
slice_param_.begins_[i] = 0;
|
||||
slice_param_.ends_[i] = dim_len;
|
||||
slice_param_.strides_[i] = 1;
|
||||
} else {
|
||||
slice_param_.in_shape_[i] = 1;
|
||||
slice_param_.begins_[i] = 0;
|
||||
slice_param_.ends_[i] = 1;
|
||||
slice_param_.strides_[i] = 1;
|
||||
}
|
||||
}
|
||||
slice_param_.in_shape_length_ = DIMENSION_8D;
|
||||
slice_param_.num_axes_ = DIMENSION_8D;
|
||||
}
|
||||
|
||||
int StridedSliceCPUKernel::RunTaskOnOuter(uint8_t *input_addr, uint8_t *output_addr, int start_pos) {
|
||||
int begin_index = slice_param_.begins_[split_axis_];
|
||||
int inner_size = inner_ * data_size_;
|
||||
uint8_t *cur_in_ptr = input_addr + (start_pos * input_shape_[split_axis_] + begin_index) * inner_size;
|
||||
uint8_t *cur_out_ptr = output_addr + start_pos * output_shape_[split_axis_] * inner_size;
|
||||
int cur_outer = outer_ - start_pos;
|
||||
if (cur_outer <= 0) {
|
||||
return common::SUCCESS;
|
||||
}
|
||||
cur_outer = cur_outer > cal_num_per_thread_ ? cal_num_per_thread_ : cur_outer;
|
||||
FastStride(cur_in_ptr, cur_out_ptr, output_shape_[split_axis_], slice_param_.strides_[split_axis_], cur_outer,
|
||||
inner_size, input_shape_[split_axis_] * inner_size);
|
||||
return common::SUCCESS;
|
||||
}
|
||||
|
||||
int StridedSliceCPUKernel::RunTaskOnSplitAxis(uint8_t *input_addr, uint8_t *output_addr, int start_pos) {
|
||||
int begin_index = slice_param_.begins_[split_axis_];
|
||||
int inner_size = inner_ * data_size_;
|
||||
uint8_t *cur_in_ptr = input_addr + (start_pos * slice_param_.strides_[split_axis_] + begin_index) * inner_size;
|
||||
uint8_t *cur_out_ptr = output_addr + start_pos * inner_size;
|
||||
int cal_axis_num = output_shape_[split_axis_] - start_pos;
|
||||
if (cal_axis_num <= 0) {
|
||||
return common::SUCCESS;
|
||||
}
|
||||
cal_axis_num = cal_axis_num > cal_num_per_thread_ ? cal_num_per_thread_ : cal_axis_num;
|
||||
FastStride(cur_in_ptr, cur_out_ptr, cal_axis_num, slice_param_.strides_[split_axis_], 1, inner_size, 0);
|
||||
return common::SUCCESS;
|
||||
}
|
||||
|
||||
void StridedSliceCPUKernel::ParallelRun(uint8_t *input_addr, uint8_t *output_addr, int thread_num) {
|
||||
int thread_index = 0;
|
||||
std::vector<common::Task> tasks;
|
||||
std::function<int(StridedSliceCPUKernel *, uint8_t *, uint8_t *, int)> execute_func;
|
||||
if (parallel_strategy_ == kOnOuter) {
|
||||
execute_func = &StridedSliceCPUKernel::RunTaskOnOuter;
|
||||
} else if (parallel_strategy_ == kOnSplitAxis) {
|
||||
execute_func = &StridedSliceCPUKernel::RunTaskOnSplitAxis;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Not supported parallel execute strategy for StridedSlice.";
|
||||
}
|
||||
|
||||
while (thread_index < thread_num) {
|
||||
tasks.emplace_back(std::bind(execute_func, this, input_addr, output_addr, thread_index * cal_num_per_thread_));
|
||||
thread_index++;
|
||||
}
|
||||
common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
}
|
||||
|
||||
bool StridedSliceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (outputs[0]->size == 0) {
|
||||
return true;
|
||||
}
|
||||
auto input_addr = reinterpret_cast<uint8_t *>(inputs[0]->addr);
|
||||
auto output_addr = reinterpret_cast<uint8_t *>(outputs[0]->addr);
|
||||
int thread_num = slice_param_.op_parameter_.thread_num_;
|
||||
if (parallel_ && thread_num >= 2) {
|
||||
ParallelRun(input_addr, output_addr, thread_num);
|
||||
} else {
|
||||
DoStridedSlice(input_addr, output_addr, &slice_param_);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,74 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "nnacl/fp32/strided_slice_fp32.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class StridedSliceCPUKernel : public CPUKernel {
|
||||
public:
|
||||
StridedSliceCPUKernel() = default;
|
||||
~StridedSliceCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
enum ParallelStrategy { kOnSplitAxis, kOnOuter };
|
||||
|
||||
void InitSliceParam(const std::vector<int64_t> &begin, const std::vector<int64_t> &end,
|
||||
const std::vector<int64_t> &stride);
|
||||
bool MatchParallelPattern();
|
||||
void InitParallelParam();
|
||||
void ParallelRun(uint8_t *input_addr, uint8_t *output_addr, int thread_num);
|
||||
int RunTaskOnOuter(uint8_t *input_addr, uint8_t *output_addr, int start_pos);
|
||||
int RunTaskOnSplitAxis(uint8_t *input_addr, uint8_t *output_addr, int start_pos);
|
||||
|
||||
TypeId dtype_;
|
||||
int data_size_{4};
|
||||
int split_axis_{-1};
|
||||
int inner_{1};
|
||||
int outer_{1};
|
||||
int cal_num_per_thread_{1};
|
||||
bool parallel_{false};
|
||||
ParallelStrategy parallel_strategy_{kOnSplitAxis};
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
StridedSliceParameter slice_param_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
StridedSliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
StridedSliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
StridedSliceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(StridedSlice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
StridedSliceCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SLICE_CPU_KERNEL_H_
|
|
@ -64,7 +64,7 @@ int ReduceCPUKernel::CallReduceUnit(int task_id) {
|
|||
}
|
||||
reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_),
|
||||
static_cast<float *>(dst_data_), task_id, context_->thread_num_);
|
||||
} else if (data_type_ == KDataTypeBool) {
|
||||
} else if (data_type_ == kDataTypeBool) {
|
||||
if (!bool_reducer_) {
|
||||
MS_LOG(ERROR) << "function bool_reducer_ is null.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -96,7 +96,7 @@ int ReduceCPUKernel::Run() {
|
|||
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
|
||||
data_type_ = kDataTypeFloat;
|
||||
} else if (in_tensors().at(0)->data_type() == kNumberTypeBool) {
|
||||
data_type_ = KDataTypeBool;
|
||||
data_type_ = kDataTypeBool;
|
||||
} else {
|
||||
data_type_ = kDataTypeInt;
|
||||
}
|
||||
|
@ -183,7 +183,7 @@ int ReduceCPUKernel::MallocTmpBuffer() {
|
|||
void *buffer = nullptr;
|
||||
if (data_type_ == kDataTypeFloat) {
|
||||
buffer = context_->allocator->Malloc(size * sizeof(float));
|
||||
} else if (data_type_ == KDataTypeBool) {
|
||||
} else if (data_type_ == kDataTypeBool) {
|
||||
buffer = context_->allocator->Malloc(size * sizeof(bool));
|
||||
} else {
|
||||
buffer = context_->allocator->Malloc(size * sizeof(int));
|
||||
|
|
Loading…
Reference in New Issue