pad_multithread
This commit is contained in:
parent
b9b64615f2
commit
f58a83356b
|
@ -88,7 +88,7 @@ bool PadCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const s
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
bool PadCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const {
|
||||
bool PadCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
|
||||
const auto *inputs_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto *outputs_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
if (memset_s(outputs_addr, outputs[0]->size, 0, outputs[0]->size) != EOK) {
|
||||
|
@ -96,18 +96,21 @@ bool PadCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std
|
|||
return false;
|
||||
}
|
||||
|
||||
for (size_t gt_id = 0; gt_id < input_size_; ++gt_id) {
|
||||
size_t linear_index = gt_id;
|
||||
size_t padded_linear_index = 0;
|
||||
for (size_t i = input_rank_; i >= 1; i--) {
|
||||
size_t unravel_dimension = input_shape_[i - 1];
|
||||
size_t unraveled_index = linear_index % unravel_dimension;
|
||||
padded_linear_index += ((unraveled_index + flattened_paddings_[kPadElemSize * (i - 1)]) * strides_[i - 1]);
|
||||
linear_index -= unraveled_index;
|
||||
linear_index /= unravel_dimension;
|
||||
auto task = [&inputs_addr, &outputs_addr, this](size_t start, size_t end) {
|
||||
for (size_t gt_id = start; gt_id < end; ++gt_id) {
|
||||
size_t linear_index = gt_id;
|
||||
size_t padded_linear_index = 0;
|
||||
for (size_t i = input_rank_; i >= 1; i--) {
|
||||
size_t unravel_dimension = input_shape_[i - 1];
|
||||
size_t unraveled_index = linear_index % unravel_dimension;
|
||||
padded_linear_index += ((unraveled_index + flattened_paddings_[kPadElemSize * (i - 1)]) * strides_[i - 1]);
|
||||
linear_index -= unraveled_index;
|
||||
linear_index /= unravel_dimension;
|
||||
}
|
||||
outputs_addr[padded_linear_index] = inputs_addr[gt_id];
|
||||
}
|
||||
outputs_addr[padded_linear_index] = inputs_addr[gt_id];
|
||||
}
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, input_size_, this, &parallel_search_info_);
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
|
|
|
@ -37,7 +37,7 @@ class PadCPUKernel : public CPUKernel {
|
|||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const;
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
|
||||
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
std::vector<std::vector<int64_t>> paddings_;
|
||||
|
|
Loading…
Reference in New Issue