diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/pdist_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/pdist_cpu_kernel.cc
index cbc79f35662..e02874ef429 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/pdist_cpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/pdist_cpu_kernel.cc
@@ -22,64 +22,41 @@ namespace kernel {
 namespace {
 constexpr size_t kPdistInputsNum = 1;
 constexpr size_t kPdistOutputsNum = 1;
-constexpr size_t kPdistInputDimsMin = 2;
+constexpr float P_ZERO = 0.0;
+constexpr float P_ONE = 1.0;
+constexpr float P_TWO = 2.0;
 constexpr int64_t GRAIN_SIZE = 2048;
 }  // namespace
 
-template <typename T>
-void PdistZeroNormalcompute(const T *in1, const T *in2, T *output, size_t col, float p) {
-  double res = 0;
-  for (size_t i = 0; i < col; i++) {
-    res += (in1[i] != in2[1]);
-  }
-  *output = static_cast<T>(res);
-}
+struct zdist_calc {
+  static inline double map(const double &diff, const float &p) { return std::min(ceil(diff), 1.0); }
+  static inline double red(const double &agg, const double &up) { return agg + up; }
+  static inline double finish(const double &agg, const float &p) { return agg; }
+};
 
-template <typename T>
-void PdistInfNormalcompute(const T *in1, const T *in2, T *output, size_t col, float p) {
-  double res = 0;
-  for (size_t i = 0; i < col; i++) {
-    double x = static_cast<double>(in1[i]);
-    double y = static_cast<double>(in2[i]);
-    res = std::max(std::abs(x - y), res);
-  }
-  *output = static_cast<T>(res);
-}
+struct odist_calc {
+  static inline double map(const double &diff, const float &p) { return diff; }
+  static inline double red(const double &agg, const double &up) { return agg + up; }
+  static inline double finish(const double &agg, const float &p) { return agg; }
+};
 
-template <typename T>
-void PdistOneNormalcompute(const T *in1, const T *in2, T *output, size_t col, float p) {
-  double res = 0;
-  for (size_t i = 0; i < col; i++) {
-    double x = static_cast<double>(in1[i]);
-    double y = static_cast<double>(in2[i]);
-    res += std::abs(x - y);
-  }
-  *output = static_cast<T>(res);
-}
+struct tdist_calc {
+  static inline double map(const double &diff, const float &p) { return diff * diff; }
+  static inline double red(const double &agg, const double &up) { return agg + up; }
+  static inline double finish(const double &agg, const float &p) { return std::sqrt(agg); }
+};
 
-template <typename T>
-void PdistTwoNormalcompute(const T *in1, const T *in2, T *output, size_t col, float p) {
-  double res = 0;
-  for (size_t i = 0; i < col; i++) {
-    double x = static_cast<double>(in1[i]);
-    double y = static_cast<double>(in2[i]);
-    auto temp = x - y;
-    res += temp * temp;
-  }
-  *output = static_cast<T>(std::sqrt(res));
-}
+struct idist_calc {
+  static inline double map(const double &diff, const float &p) { return diff; }
+  static inline double red(const double &agg, const double &up) { return std::max(agg, up); }
+  static inline double finish(const double &agg, const float &p) { return agg; }
+};
 
-template <typename T>
-void PdistPNormalcompute(const T *in1, const T *in2, T *output, size_t col, float p) {
-  double res = 0;
-  for (size_t i = 0; i < col; i++) {
-    double x = static_cast<double>(in1[i]);
-    double y = static_cast<double>(in2[i]);
-    res += std::pow(std::abs(x - y), p);
-  }
-  res = std::pow(res, 1.0 / p);
-  *output = static_cast<T>(res);
-}
+struct pdist_calc {
+  static inline double map(const double &diff, const float &p) { return std::pow(diff, p); }
+  static inline double red(const double &agg, const double &up) { return agg + up; }
+  static inline double finish(const double &agg, const float &p) { return std::pow(agg, 1.0 / p); }
+};
 
 bool PdistCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                              const std::vector<KernelTensorPtr> &outputs) {
@@ -90,31 +67,12 @@ bool PdistCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::ve
   }
   kernel_name_ = kernel_ptr->name();
   p_ = kernel_ptr->get_p();
-  if (inputs.size() != kPdistInputsNum || outputs.size() != kPdistOutputsNum) {
-    MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kPdistInputsNum << " and "
-                  << kPdistOutputsNum << ", but get " << inputs.size() << " and " << outputs.size();
-    return false;
-  }
+
   auto input_shape = inputs[0]->GetShapeVector();
   auto input_dim_ = input_shape.size();
   h_ = input_shape[input_dim_ - kIndex2];
   w_ = input_shape[input_dim_ - kIndex1];
-
-  auto input_dtype_ = inputs[0]->GetDtype();
-  switch (input_dtype_) {
-    case kNumberTypeFloat64:
-      kernel_func_ = &PdistCpuKernelMod::LaunchKernel<double>;
-      break;
-    case kNumberTypeFloat32:
-      kernel_func_ = &PdistCpuKernelMod::LaunchKernel<float>;
-      break;
-    case kNumberTypeFloat16:
-      kernel_func_ = &PdistCpuKernelMod::LaunchKernel<float16>;
-      break;
-    default:
-      MS_LOG(ERROR) << "Pdist kernel does not support " << TypeIdToString(input_dtype_);
-      return false;
-  }
+  dtype_ = inputs[0]->GetDtype();
   return true;
 }
 
@@ -128,22 +86,22 @@ int PdistCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::v
   return 0;
 }
 
-template <typename T>
-bool PdistCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
-                                     const std::vector<AddressPtr> &outputs) {
-  auto input_size = inputs[0]->size / sizeof(T);
+template <typename T, typename F>
+bool PdistCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
+  if (h_ == 1) {
+    return true;
+  }
   auto output_size = outputs[0]->size / sizeof(T);
-  const auto *input_start = GetDeviceAddress<T>(inputs, kIndex0);
-  const auto *input_end = input_start + input_size;
+  const auto *input = GetDeviceAddress<T>(inputs, kIndex0);
   auto *output = GetDeviceAddress<T>(outputs, kIndex0);
   int64_t combs = h_ * (h_ - 1) / 2;
   int64_t one_size = h_ * w_;
   int64_t temp = one_size - w_;
-  auto task = [this, input_start, input_end, output, combs, one_size, temp](size_t start, size_t end) {
+  auto task = [this, input, output, combs, one_size, temp](size_t start, size_t end) {
     int64_t l = start / combs;
     int64_t k = start % combs;
     double h2 = h_ - .5;
-    int64_t i = static_cast<int64_t>((h2 - sqrtf(h2 * h2 - 2 * k - 1)));
+    int64_t i = static_cast<int64_t>((h2 - std::sqrt(h2 * h2 - 2 * k - 1)));
     int64_t j = k - h_ * i + i * (i + 1) / 2 + i + 1;
     i = i * w_;
     j = j * w_;
@@ -151,19 +109,15 @@ bool PdistCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
     const T *const res_end = output + end;
 
     while (res != res_end) {
-      const T *input_i = input_start + l * one_size + i;
-      const T *input_j = input_start + l * one_size + j;
-      if (p_ == 0.0) {
-        PdistZeroNormalcompute(input_i, input_j, res, w_, p_);
-      } else if (p_ == 1.0) {
-        PdistOneNormalcompute(input_i, input_j, res, w_, p_);
-      } else if (p_ == 2.0) {
-        PdistTwoNormalcompute(input_i, input_j, res, w_, p_);
-      } else if (std::isinf(p_)) {
-        PdistInfNormalcompute(input_i, input_j, res, w_, p_);
-      } else {
-        PdistPNormalcompute(input_i, input_j, res, w_, p_);
+      const T *input_i = input + l * one_size + i;
+      const T *input_j = input + l * one_size + j;
+      double agg = 0;
+      for (size_t x = 0; x < w_; x++) {
+        double a = static_cast<double>(*(input_i + x));
+        double b = static_cast<double>(*(input_j + x));
+        agg = F::red(agg, F::map(std::abs(a - b), p_));
       }
+      *res = static_cast<T>(F::finish(agg, p_));
      res += 1;
      j += w_;
      if (j == one_size) {
@@ -181,6 +135,38 @@ bool PdistCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inpu
   return true;
 }
 
+template <typename T>
+void PdistCpuKernelMod::Apply_pdist(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
+  if (p_ == P_ZERO) {
+    LaunchKernel<T, zdist_calc>(inputs, outputs);
+  } else if (p_ == P_ONE) {
+    LaunchKernel<T, odist_calc>(inputs, outputs);
+  } else if (p_ == P_TWO) {
+    LaunchKernel<T, tdist_calc>(inputs, outputs);
+  } else if (std::isinf(p_)) {
+    LaunchKernel<T, idist_calc>(inputs, outputs);
+  } else {
+    LaunchKernel<T, pdist_calc>(inputs, outputs);
+  }
+}
+
+bool PdistCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
+                               const std::vector<AddressPtr> &outputs) {
+  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kPdistInputsNum, kernel_name_);
+  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kPdistOutputsNum, kernel_name_);
+  if (dtype_ == kNumberTypeFloat64) {
+    Apply_pdist<double>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat32) {
+    Apply_pdist<float>(inputs, outputs);
+  } else if (dtype_ == kNumberTypeFloat16) {
+    Apply_pdist<float16>(inputs, outputs);
+  } else {
+    MS_LOG(ERROR) << "Pdist kernel does not support " << TypeIdToString(dtype_);
+    return false;
+  }
+  return true;
+}
+
 std::vector<KernelAttr> PdistCpuKernelMod::GetOpSupport() {
   std::vector<KernelAttr> support_list = {
     KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/pdist_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/pdist_cpu_kernel.h
index 494c8382409..52a0ff1b579 100644
--- a/mindspore/ccsrc/plugin/device/cpu/kernel/pdist_cpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/cpu/kernel/pdist_cpu_kernel.h
@@ -38,22 +38,21 @@ class PdistCpuKernelMod : public NativeCpuKernelMod {
              const std::vector<KernelTensorPtr> &outputs,
              const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs) override {
-    return kernel_func_(this, inputs, outputs);
-  }
+              const std::vector<AddressPtr> &outputs) override;
 
   std::vector<KernelAttr> GetOpSupport() override;
 
  private:
+  template <typename T, typename F>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
+
   template <typename T>
-  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
-  using PdistKernel = std::function<bool(PdistCpuKernelMod *, const std::vector<AddressPtr> &,
-                                         const std::vector<AddressPtr> &)>;
-  PdistKernel kernel_func_;
+  void Apply_pdist(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
 
   size_t h_;
   size_t w_;
   float p_;
+  TypeId dtype_{kTypeUnknown};
 };
 }  // namespace kernel
 }  // namespace mindspore
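For reference, a minimal standalone sketch (not part of the patch) of the map/red/finish protocol the refactor introduces: each calc struct maps an element-wise |a - b|, folds it into an accumulator, and finalizes the aggregate, so a single templated inner loop covers every value of p and the per-element if/else chain from the old LaunchKernel disappears. The functor bodies below mirror tdist_calc and idist_calc from the diff; the row_distance helper, main(), and the sample rows are hypothetical, for illustration only.

// Standalone illustration of the *_calc functor protocol used above.
// row_distance() mirrors the inner loop of the new templated LaunchKernel.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

struct tdist_calc {  // p == 2: accumulate squared differences, then take sqrt
  static inline double map(const double &diff, const float &p) { return diff * diff; }
  static inline double red(const double &agg, const double &up) { return agg + up; }
  static inline double finish(const double &agg, const float &p) { return std::sqrt(agg); }
};

struct idist_calc {  // p == inf: running maximum of |a - b|
  static inline double map(const double &diff, const float &p) { return diff; }
  static inline double red(const double &agg, const double &up) { return std::max(agg, up); }
  static inline double finish(const double &agg, const float &p) { return agg; }
};

// Distance between two rows of width w, parameterized by a calc functor F.
template <typename F>
double row_distance(const double *a, const double *b, size_t w, float p) {
  double agg = 0;
  for (size_t x = 0; x < w; x++) {
    agg = F::red(agg, F::map(std::abs(a[x] - b[x]), p));
  }
  return F::finish(agg, p);
}

int main() {
  const std::vector<double> row0 = {1.0, 2.0, 3.0};
  const std::vector<double> row1 = {4.0, 6.0, 3.0};
  // Euclidean distance: sqrt(3^2 + 4^2 + 0^2) = 5
  std::printf("p=2:   %f\n", row_distance<tdist_calc>(row0.data(), row1.data(), row0.size(), 2.0f));
  // Chebyshev distance: max(3, 4, 0) = 4
  std::printf("p=inf: %f\n", row_distance<idist_calc>(row0.data(), row1.data(), row0.size(), INFINITY));
  return 0;
}

With this structure the choice of norm is resolved once per Launch (in Apply_pdist), and the static inline functor calls are expected to inline away inside the parallel task, rather than being re-branched for every output element as in the removed PdistXxxNormalcompute helpers.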