remove code warning

This commit is contained in:
zhujingxuan 2022-07-25 21:21:18 +08:00
parent efc85e972a
commit 8aae722162
2 changed files with 43 additions and 9 deletions

View File

@ -50,6 +50,8 @@ int LinSpaceCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
batch_num_ = std::accumulate(input_shape.begin(), input_shape.end(), int64_t(1), std::multiplies{});
batch_num_ = (batch_num_ == 0) ? 1 : batch_num_;
multi_dims_ = (batch_num_ != 1);
const auto dtype_size = abstract::TypeIdSize(inputs.at(kIndex0)->GetDtype());
// Deal with workspace_size_list_
@ -74,27 +76,56 @@ bool LinSpaceCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &i
return true;
}
if (multi_dims_) {
return LaunchVmapKernel<T>(inputs, workspace, outputs);
}
auto start = *reinterpret_cast<T *>(inputs[kIndex0]->addr);
auto stop = *reinterpret_cast<T *>(inputs[kIndex1]->addr);
auto output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
const auto step = ((stop - start) / (num - 1));
auto task = [output, start, step](size_t start_index, size_t end_index) {
for (size_t i = start_index; i < end_index; i++) {
output[i] = start + step * i;
}
};
ParallelLaunchAutoSearch(task, num, this, &parallel_search_info_);
return true;
}
template <typename T>
bool LinSpaceCpuKernelMod::LaunchVmapKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
auto starts = reinterpret_cast<T *>(inputs[kIndex0]->addr);
auto stops = reinterpret_cast<T *>(inputs[kIndex1]->addr);
const int64_t num = *reinterpret_cast<int64_t *>(inputs[kIndex2]->addr);
auto add_values = reinterpret_cast<T *>(workspace[kIndex0]->addr);
auto steps = reinterpret_cast<T *>(workspace[kIndex0]->addr);
auto output = reinterpret_cast<T *>(outputs[kIndex0]->addr);
for (int64_t i = 0; i < batch_num_; ++i) {
add_values[i] = ((stops[i] - starts[i]) / (num - 1));
steps[i] = ((stops[i] - starts[i]) / (num - 1));
}
size_t num_t = LongToSize(num);
// Run parallel both on batch and also the calculated axis
auto task = [output, num, starts, add_values](size_t start, size_t end) {
for (size_t index = start; index < end; index++) {
const size_t batch = index / num;
const size_t i = index % num;
output[index] = starts[batch] + add_values[batch] * i;
auto task = [output, num_t, starts, steps](size_t start, size_t end) {
while (start < end) {
const size_t batch = start / num_t;
const size_t offset = batch * num_t;
for (size_t i = start; i < (batch + 1) * num_t; ++i) {
output[i] = starts[batch] + steps[batch] * (i - offset);
}
start = (batch + 1) * num_t;
}
};
ParallelLaunchAutoSearch(task, batch_num_ * num, this, &parallel_search_info_);
return true;
}

View File

@ -51,8 +51,11 @@ class LinSpaceCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs);
template <typename T>
bool LaunchVmapKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs);
int64_t batch_num_{0};
bool multi_dims_{false};
};
} // namespace kernel
} // namespace mindspore