commit 064582a845
@@ -14,63 +14,71 @@
  * limitations under the License.
  */
 #include <cmath>
-#include <functional>
-#include <map>
+#include <algorithm>
 #include "plugin/device/cpu/kernel/pdist_cpu_kernel.h"
-#include "plugin/device/cpu/hal/device/cpu_device_address.h"
-#include "mindspore/core/ops/pdist.h"
-#include "abstract/utils.h"
 
 namespace mindspore {
 namespace kernel {
 namespace {
 constexpr size_t kPdistInputsNum = 1;
 constexpr size_t kPdistOutputsNum = 1;
 constexpr size_t kPdistInputDimsMin = 2;
+constexpr int64_t GRAIN_SIZE = 2048;
 }  // namespace
 
 template <typename T>
-void PdistZeroNormalcompute(const T *input, T *output, size_t start_x, size_t start_y, float p, size_t col,
-                            size_t idx) {
+void PdistZeroNormalcompute(const T *in1, const T *in2, T *output, size_t col, float p) {
   double res = 0;
   for (size_t i = 0; i < col; i++) {
-    res += (input[start_x + i] == input[start_y + i]) ? 0 : 1;
+    res += (in1[i] != in2[i]);
   }
-  output[idx] = static_cast<T>(res);
+  *output = static_cast<T>(res);
 }
 
 template <typename T>
-void PdistInfNormalcompute(const T *input, T *output, size_t start_x, size_t start_y, float p, size_t col, size_t idx) {
+void PdistInfNormalcompute(const T *in1, const T *in2, T *output, size_t col, float p) {
   double res = 0;
   for (size_t i = 0; i < col; i++) {
-    double x = static_cast<double>(input[start_x + i]);
-    double y = static_cast<double>(input[start_y + i]);
+    double x = static_cast<double>(in1[i]);
+    double y = static_cast<double>(in2[i]);
     res = std::max(std::abs(x - y), res);
   }
-  output[idx] = static_cast<T>(res);
+  *output = static_cast<T>(res);
 }
 
 template <typename T>
-void PdistOneNormalcompute(const T *input, T *output, size_t start_x, size_t start_y, float p, size_t col, size_t idx) {
+void PdistOneNormalcompute(const T *in1, const T *in2, T *output, size_t col, float p) {
   double res = 0;
   for (size_t i = 0; i < col; i++) {
-    double x = static_cast<double>(input[start_x + i]);
-    double y = static_cast<double>(input[start_y + i]);
+    double x = static_cast<double>(in1[i]);
+    double y = static_cast<double>(in2[i]);
     res += std::abs(x - y);
   }
-  output[idx] = static_cast<T>(res);
+  *output = static_cast<T>(res);
 }
 
 template <typename T>
-void PdistNormalcompute(const T *input, T *output, size_t start_x, size_t start_y, float p, size_t col, size_t idx) {
+void PdistTwoNormalcompute(const T *in1, const T *in2, T *output, size_t col, float p) {
   double res = 0;
   for (size_t i = 0; i < col; i++) {
-    double x = static_cast<double>(input[start_x + i]);
-    double y = static_cast<double>(input[start_y + i]);
+    double x = static_cast<double>(in1[i]);
+    double y = static_cast<double>(in2[i]);
+    auto temp = x - y;
+    res += temp * temp;
+  }
+  *output = static_cast<T>(std::sqrt(res));
+}
+
+template <typename T>
+void PdistPNormalcompute(const T *in1, const T *in2, T *output, size_t col, float p) {
+  double res = 0;
+  for (size_t i = 0; i < col; i++) {
+    double x = static_cast<double>(in1[i]);
+    double y = static_cast<double>(in2[i]);
     res += std::pow(std::abs(x - y), p);
   }
   res = std::pow(res, 1.0 / p);
-  output[idx] = static_cast<T>(res);
+  *output = static_cast<T>(res);
 }
 
 bool PdistCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
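Note (reviewer sketch, not part of this commit): the helpers above are the special cases of the p-norm distance that Pdist dispatches on: p = 0 counts coordinates that differ, p = 1 is the Manhattan distance, p = 2 the Euclidean distance, p = inf the Chebyshev distance, and any other p the general Minkowski form. A minimal scalar baseline that folds the same dispatch into one function, handy for spot-checking PdistZero/One/Two/Inf/PNormalcompute, could look like this (all names below are illustrative, not kernel code):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

// Distance between two equal-length rows under the same special-case dispatch as
// the kernel helpers (p = 1 and p = 2 go through std::pow here for brevity; the
// kernel special-cases them to avoid the pow call).
double ReferencePairDistance(const std::vector<double> &a, const std::vector<double> &b, float p) {
  double res = 0;
  for (size_t i = 0; i < a.size(); ++i) {
    double d = std::abs(a[i] - b[i]);
    if (p == 0.0f) {
      res += (a[i] != b[i]);   // "zero norm": number of differing coordinates
    } else if (std::isinf(p)) {
      res = std::max(res, d);  // Chebyshev distance
    } else {
      res += std::pow(d, p);   // Minkowski accumulation
    }
  }
  if (p != 0.0f && !std::isinf(p)) {
    res = std::pow(res, 1.0 / p);
  }
  return res;
}

int main() {
  std::vector<double> x = {1.0, 2.0, 3.0};
  std::vector<double> y = {4.0, 2.0, -1.0};
  for (float p : {0.0f, 1.0f, 2.0f, 3.0f, std::numeric_limits<float>::infinity()}) {
    std::cout << "p = " << p << " -> " << ReferencePairDistance(x, y, p) << "\n";
  }
  return 0;
}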
@@ -88,9 +96,10 @@ bool PdistCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::ve
     return false;
   }
   auto input_shape = inputs[0]->GetShapeVector();
-  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_), LongToSize);
-  input_dim_ = input_shape_.size();
-  input_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies<size_t>());
+  auto input_dim_ = input_shape.size();
+  h_ = input_shape[input_dim_ - kIndex2];
+  w_ = input_shape[input_dim_ - kIndex1];
 
   auto input_dtype_ = inputs[0]->GetDtype();
   switch (input_dtype_) {
     case kNumberTypeFloat64:
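Note (illustration only, not kernel code): with the input viewed as (..., h_, w_), Pdist compares the h_ rows of each trailing h_ x w_ matrix pairwise, so the output holds h_ * (h_ - 1) / 2 distances per batch slice; w_ sets the length of each row but does not change the count. A quick sketch of that bookkeeping for an assumed shape vector:

#include <cstdint>
#include <iostream>
#include <vector>

// Number of output elements for an input shape (..., n, m): one distance per
// unordered row pair, repeated for every leading batch index.
int64_t PdistOutputElements(const std::vector<int64_t> &shape) {
  const size_t dim = shape.size();   // requires dim >= 2
  const int64_t n = shape[dim - 2];  // rows compared pairwise (h_ in the kernel)
  int64_t batch = 1;
  for (size_t i = 0; i + 2 < dim; ++i) {
    batch *= shape[i];
  }
  return batch * n * (n - 1) / 2;
}

int main() {
  std::cout << PdistOutputElements({4, 5, 3}) << "\n";  // 4 batches * C(5, 2) = 40
  return 0;
}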
@@ -122,30 +131,53 @@ int PdistCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::v
 template <typename T>
 bool PdistCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &outputs) {
-  const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
-  auto output = reinterpret_cast<T *>(outputs[0]->addr);
-  auto col = input_shape_[input_dim_ - 1];
-  auto temp = input_shape_[input_dim_ - 1] * input_shape_[input_dim_ - 2];
-  auto task = [this, &input, &output, col, temp](size_t start, size_t end) {
-    size_t idx = 0;
-    for (size_t i = start; i < end; i = i + temp) {
-      for (size_t j = i; j < i + temp; j = j + col) {
-        for (size_t k = j + col; k < i + temp; k = k + col) {
-          if (p_ == 0.0) {
-            PdistZeroNormalcompute(input, output, j, k, p_, col, idx);
-          } else if (std::isinf(p_)) {
-            PdistInfNormalcompute(input, output, j, k, p_, col, idx);
-          } else if (p_ == 1.0) {
-            PdistOneNormalcompute(input, output, j, k, p_, col, idx);
-          } else {
-            PdistNormalcompute(input, output, j, k, p_, col, idx);
-          }
-          idx++;
+  auto input_size = inputs[0]->size / sizeof(T);
+  auto output_size = outputs[0]->size / sizeof(T);
+  const auto *input_start = GetDeviceAddress<T>(inputs, kIndex0);
+  const auto *input_end = input_start + input_size;
+  auto *output = GetDeviceAddress<T>(outputs, kIndex0);
+  int64_t combs = h_ * (h_ - 1) / 2;
+  int64_t one_size = h_ * w_;
+  int64_t temp = one_size - w_;
+  auto task = [this, input_start, input_end, output, combs, one_size, temp](size_t start, size_t end) {
+    int64_t l = start / combs;
+    int64_t k = start % combs;
+    double h2 = h_ - .5;
+    int64_t i = static_cast<int64_t>((h2 - sqrtf(h2 * h2 - 2 * k - 1)));
+    int64_t j = k - h_ * i + i * (i + 1) / 2 + i + 1;
+    i = i * w_;
+    j = j * w_;
+    T *res = output + start;
+    const T *const res_end = output + end;
+
+    while (res != res_end) {
+      const T *input_i = input_start + l * one_size + i;
+      const T *input_j = input_start + l * one_size + j;
+      if (p_ == 0.0) {
+        PdistZeroNormalcompute(input_i, input_j, res, w_, p_);
+      } else if (p_ == 1.0) {
+        PdistOneNormalcompute(input_i, input_j, res, w_, p_);
+      } else if (p_ == 2.0) {
+        PdistTwoNormalcompute(input_i, input_j, res, w_, p_);
+      } else if (std::isinf(p_)) {
+        PdistInfNormalcompute(input_i, input_j, res, w_, p_);
+      } else {
+        PdistPNormalcompute(input_i, input_j, res, w_, p_);
+      }
+      res += 1;
+      j += w_;
+      if (j == one_size) {
+        i += w_;
+        j = i + w_;
+        if (i == temp) {
+          i = 0;
+          j = w_;
+          l += 1;
         }
       }
     }
   };
-  ParallelLaunchAutoSearch(task, input_size_, this, &parallel_search_info_, pool_);
+  ParallelLaunch(task, output_size, GRAIN_SIZE / w_, this);
   return true;
 }
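Note (reviewer sketch, not part of this commit): the new LaunchKernel hands each task a contiguous range [start, end) of output slots. For the first slot it recovers the batch index l = start / combs and, from k = start % combs, the row pair (i, j) by inverting the triangular numbering: with h2 = h_ - 0.5, i = floor(h2 - sqrt(h2 * h2 - 2 * k - 1)) and j = k - h_ * i + i * (i + 1) / 2 + i + 1; after that it only walks j, i and l forward. A small standalone check of that closed form against brute-force pair enumeration (my sketch, using std::sqrt where the kernel uses sqrtf):

#include <cassert>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  for (int64_t n = 2; n <= 64; ++n) {  // n plays the role of h_
    int64_t k = 0;                     // flat index into the n * (n - 1) / 2 output slots
    for (int64_t expect_i = 0; expect_i < n; ++expect_i) {
      for (int64_t expect_j = expect_i + 1; expect_j < n; ++expect_j, ++k) {
        double n2 = n - 0.5;
        // Same closed form as the kernel: invert the triangular numbering.
        int64_t i = static_cast<int64_t>(n2 - std::sqrt(n2 * n2 - 2 * k - 1));
        int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
        assert(i == expect_i && j == expect_j);
      }
    }
  }
  std::cout << "closed-form (i, j) matches enumeration for n = 2..64\n";
  return 0;
}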
@@ -20,8 +20,10 @@
 #include <vector>
 #include <memory>
 #include <map>
+#include <functional>
 #include "plugin/device/cpu/kernel/cpu_kernel.h"
 #include "plugin/factory/ms_factory.h"
+#include "mindspore/core/ops/pdist.h"
 
 namespace mindspore {
 namespace kernel {

@@ -49,9 +51,8 @@ class PdistCpuKernelMod : public NativeCpuKernelMod {
                              const std::vector<kernel::AddressPtr> &)>;
   PdistKernel kernel_func_;

-  std::vector<size_t> input_shape_;
-  size_t input_size_;
-  size_t input_dim_;
+  size_t h_;
+  size_t w_;
   float p_;
 };
 }  // namespace kernel