pad_multithread

This commit is contained in:
wanyiming 2021-12-02 16:45:31 +08:00
parent b9b64615f2
commit f58a83356b
2 changed files with 16 additions and 13 deletions

View File

@ -88,7 +88,7 @@ bool PadCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const s
}
template <typename T>
bool PadCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const {
bool PadCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
const auto *inputs_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto *outputs_addr = reinterpret_cast<T *>(outputs[0]->addr);
if (memset_s(outputs_addr, outputs[0]->size, 0, outputs[0]->size) != EOK) {
@ -96,18 +96,21 @@ bool PadCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std
return false;
}
for (size_t gt_id = 0; gt_id < input_size_; ++gt_id) {
size_t linear_index = gt_id;
size_t padded_linear_index = 0;
for (size_t i = input_rank_; i >= 1; i--) {
size_t unravel_dimension = input_shape_[i - 1];
size_t unraveled_index = linear_index % unravel_dimension;
padded_linear_index += ((unraveled_index + flattened_paddings_[kPadElemSize * (i - 1)]) * strides_[i - 1]);
linear_index -= unraveled_index;
linear_index /= unravel_dimension;
auto task = [&inputs_addr, &outputs_addr, this](size_t start, size_t end) {
for (size_t gt_id = start; gt_id < end; ++gt_id) {
size_t linear_index = gt_id;
size_t padded_linear_index = 0;
for (size_t i = input_rank_; i >= 1; i--) {
size_t unravel_dimension = input_shape_[i - 1];
size_t unraveled_index = linear_index % unravel_dimension;
padded_linear_index += ((unraveled_index + flattened_paddings_[kPadElemSize * (i - 1)]) * strides_[i - 1]);
linear_index -= unraveled_index;
linear_index /= unravel_dimension;
}
outputs_addr[padded_linear_index] = inputs_addr[gt_id];
}
outputs_addr[padded_linear_index] = inputs_addr[gt_id];
}
};
ParallelLaunchAutoSearch(task, input_size_, this, &parallel_search_info_);
return true;
}
} // namespace kernel

View File

@ -37,7 +37,7 @@ class PadCPUKernel : public CPUKernel {
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
TypeId dtype_{kTypeUnknown};
std::vector<std::vector<int64_t>> paddings_;