add new GPU operator quantile

This commit is contained in:
parent f411ab0e6a
commit dfdab165de
@@ -0,0 +1,155 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <limits>
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/quantile_impl.cuh"

// Rounds v up to the next power of two (e.g. 5 -> 8, 8 -> 8), as required by the bitonic sort below.
int RoundUpPower2(int v) {
  v--;
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  v++;
  return v;
}

template <typename T>
__inline__ __device__ void Swap(T *lhs, T *rhs) {
  T tmp = lhs[0];
  lhs[0] = rhs[0];
  rhs[0] = tmp;
}

// For each output element, locates the two neighboring sorted values around the q-th
// position and linearly interpolates between them. Slices flagged as containing NaN
// (flag value 2) produce NaN.
template <typename T>
__global__ void DoQuantile(const T *input, const T *q, T *out, T *sort, const int dim, const int x, const int y,
                           const int z, const int each_q_elements, const int output_elements, const int ceil_p_2,
                           int *nan) {
  for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < output_elements; index += blockDim.x * gridDim.x) {
    size_t q_index = index / each_q_elements;
    size_t start = static_cast<size_t>((index % each_q_elements) / z) * ceil_p_2 * z + (index % each_q_elements) % z;
    T iq = q[q_index];
    int iqy_int = static_cast<int>(iq * static_cast<T>(y - 1));
    T iqy_T = static_cast<T>(iq * static_cast<T>(y - 1));
    int step = z * iqy_int;
    int input_index = start + step;
    if (nan[index % each_q_elements] == 2) {
      out[index] = NAN;
    } else {
      out[index] = static_cast<T>(sort[input_index] +
                                  (iqy_T - static_cast<T>(iqy_int)) * (sort[input_index + z] - sort[input_index]));
    }
  }
}

// Copies the input into the padded sort buffer; padding slots beyond y are filled
// with the maximum value so they stay at the tail after sorting.
template <typename T>
__global__ void Copy(const T *input, T *sort, const int x, const int ceil_p_2, const int y, const int z) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < x * ceil_p_2 * z; pos += blockDim.x * gridDim.x) {
    size_t input_x = static_cast<size_t>(pos / (ceil_p_2 * z));
    size_t input_y = static_cast<size_t>(pos % (ceil_p_2 * z) / z);
    size_t input_z = pos % z;
    sort[pos] = input_y < y ? input[input_x * y * z + input_y * z + input_z] : std::numeric_limits<T>::max();
  }
}

// In-place bitonic sort: each block sorts one length-ceil_power2 slice with stride
// `step`. ceil_power2 must be a power of two.
template <typename T>
__global__ void BitonicSort(const int ceil_power2, T *rank_buff, const int clip_num, const int step) {
  for (size_t clip_i = blockIdx.x; clip_i < clip_num; clip_i += gridDim.x) {
    T *rank_buff_offset = rank_buff + static_cast<size_t>(clip_i / step) * ceil_power2 * step + clip_i % step;
    for (size_t i = 2; i <= ceil_power2; i <<= 1) {
      for (size_t j = (i >> 1); j > 0; j >>= 1) {
        for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
          size_t tid_comp = tid ^ j;
          if (tid_comp > tid) {
            if ((tid & i) == 0) {
              if (rank_buff_offset[tid * step] > rank_buff_offset[tid_comp * step]) {
                Swap(&rank_buff_offset[tid * step], &rank_buff_offset[tid_comp * step]);
              }
            } else {
              if (rank_buff_offset[tid * step] < rank_buff_offset[tid_comp * step]) {
                Swap(&rank_buff_offset[tid * step], &rank_buff_offset[tid_comp * step]);
              }
            }
          }
        }
        __syncthreads();
      }
    }
  }
}

// Sets *flag_in to 1 if any quantile value lies outside [0, 1].
template <typename T>
__global__ void QuantileKernelCheck(int num, const T *q, int *flag_in) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
    if (q[i] < 0 || q[i] > 1) {
      *flag_in = 1;
      return;
    }
  }
}

// Marks each (x, z) slice that contains at least one NaN with flag value 2.
template <typename T>
__global__ void QuantileKernelCheckNan(int x, int y, int z, int num, const T *input, int *flag_in) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
    if (std::isnan(input[i]) && flag_in[i / (y * z) * z + i % (y * z) % z] != 2) {
      flag_in[i / (y * z) * z + i % (y * z) % z] = 2;
    }
  }
}

template <typename T>
__global__ void QuantileKernelCheckNanInit(int x, int y, int z, int num, const T *input, int *flag_in) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
    flag_in[i] = 0;
  }
}

template <typename T>
CUDA_LIB_EXPORT void Quantile(const T *input, const T *q, T *out, T *sort, const int dim, const int x, const int y,
                              const int z, const int each_q_elements, const int output_elements, int *flag_in,
                              int *ret_flag_device, int *nan_flags, const uint32_t &device_id,
                              cudaStream_t cuda_stream) {
  (void)cudaMemset(ret_flag_device, 0, sizeof(int));
  QuantileKernelCheck<<<CUDA_BLOCKS(device_id, output_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    output_elements / each_q_elements, q, ret_flag_device);
  cudaStreamSynchronize(cuda_stream);
  (void)cudaMemcpy(flag_in, ret_flag_device, sizeof(int), cudaMemcpyDeviceToHost);
  (void)cudaMemset(nan_flags, 0, sizeof(int));
  QuantileKernelCheckNanInit<<<CUDA_BLOCKS(device_id, x * z), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    x, y, z, x * z, input, nan_flags);
  QuantileKernelCheckNan<<<CUDA_BLOCKS(device_id, x * y * z), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    x, y, z, x * y * z, input, nan_flags);
  int ceil_p_2 = RoundUpPower2(y);
  int thread = std::min(ceil_p_2, CUDA_THREADS(device_id));
  Copy<<<CUDA_BLOCKS(device_id, x * ceil_p_2 * z), CUDA_THREADS(device_id), 0, cuda_stream>>>(input, sort, x, ceil_p_2,
                                                                                              y, z);
  BitonicSort<<<x * z, thread, 0, cuda_stream>>>(ceil_p_2, sort, x * z, z);
  DoQuantile<<<CUDA_BLOCKS(device_id, output_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    input, q, out, sort, dim, x, y, z, each_q_elements, output_elements, ceil_p_2, nan_flags);
}

template CUDA_LIB_EXPORT void Quantile<float>(const float *input, const float *q, float *out, float *sort,
                                              const int dim, const int x, const int y, const int z,
                                              const int each_q_elements, const int output_elements, int *flag_in,
                                              int *ret_flag_device, int *nan_flags, const uint32_t &device_id,
                                              cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Quantile<double>(const double *input, const double *q, double *out, double *sort,
                                               const int dim, const int x, const int y, const int z,
                                               const int each_q_elements, const int output_elements, int *flag_in,
                                               int *ret_flag_device, int *nan_flags, const uint32_t &device_id,
                                               cudaStream_t cuda_stream);
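The kernel pipeline above is: pad each length-y slice to the next power of two (Copy), sort every slice with a block-level bitonic sort (BitonicSort), then gather with linear interpolation (DoQuantile). A minimal host-side NumPy sketch of the same math, useful as a reference when reviewing the indexing; the names here are illustrative, not part of the commit:

import numpy as np

def quantile_reference(x, q, dim):
    """Host-side sketch of the Copy -> BitonicSort -> DoQuantile pipeline."""
    x = np.moveaxis(np.asarray(x, dtype=np.float64), dim, -1)  # reduced dim last
    s = np.sort(x, axis=-1)                                    # stand-in for BitonicSort
    y = s.shape[-1]
    out = []
    for iq in np.atleast_1d(q):
        pos = iq * (y - 1)            # fractional rank, matches iq * (y - 1) in DoQuantile
        lo = int(pos)                 # iqy_int
        frac = pos - lo               # iqy_T - iqy_int
        hi = min(lo + 1, y - 1)
        out.append(s[..., lo] + frac * (s[..., hi] - s[..., lo]))
    return np.stack(out)

# Cross-check against NumPy's own linear-interpolation quantile:
x = np.arange(1.0, 17.0).reshape(4, 4)
assert np.allclose(quantile_reference(x, [0.25, 0.5, 0.75], 0),
                   np.quantile(x, [0.25, 0.5, 0.75], axis=0))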
@@ -0,0 +1,27 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_QUANTILE_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_QUANTILE_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

template <typename T>
CUDA_LIB_EXPORT void Quantile(const T *input, const T *q, T *out, T *sort, const int dim, const int x, const int y,
                              const int z, const int each_q_elements, const int output_elements, int *flag_in,
                              int *ret_flag_device, int *nan_flags, const uint32_t &device_id,
                              cudaStream_t cuda_stream);
CUDA_LIB_EXPORT int RoundUpPower2(int v);
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_QUANTILE_IMPL_CUH_
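RoundUpPower2, exported here alongside the launcher, is the classic bit-smearing trick: after v--, each shift-and-or propagates the highest set bit into every lower position, so v++ lands on the next power of two, and an exact power of two maps to itself. A quick sketch of the same computation for reference:

def round_up_power2(v):
    """Next power of two >= v, mirroring the bit tricks in RoundUpPower2."""
    v -= 1
    for shift in (1, 2, 4, 8, 16):  # smear the top bit into every lower bit
        v |= v >> shift
    return v + 1

assert [round_up_power2(v) for v in (1, 3, 4, 5, 100)] == [1, 4, 4, 8, 128]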
@@ -0,0 +1,162 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/math/quantile_gpu_kernel.h"
#include <functional>
#include <utility>
#include <string>
#include <algorithm>
#include <memory>
#include "abstract/utils.h"
#include "mindspore/core/ops/quantile.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/quantile_impl.cuh"

namespace mindspore {
namespace kernel {
// Sentinel dim value meaning "no dim given": reduce over the flattened input.
constexpr int kQuantileDefaultDim = 10000;
bool QuantileGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::Quantile>(base_operator);
  if (kernel_ptr == nullptr) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', cast to Quantile ops failed!";
    return false;
  }
  kernel_name_ = kernel_ptr->name();
  dim_ = kernel_ptr->get_dim();
  if (inputs.empty() || outputs.empty()) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', got empty inputs or outputs, which is invalid.";
    return false;
  }
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  input_unit_size_ = abstract::TypeIdSize(inputs[kIndex0]->GetDtype());
  q_unit_size_ = abstract::TypeIdSize(inputs[kIndex1]->GetDtype());
  return true;
}

// Normalizes a possibly negative dim into [0, dim_post_expr); the sentinel value
// kQuantileDefaultDim is passed through unchanged.
int MaybeWrapDim(int dim, int dim_post_expr) {
  if (dim == kQuantileDefaultDim) {
    return dim;
  }
  if (dim_post_expr <= 0) {
    dim_post_expr = 1;
  }
  int min = -dim_post_expr;
  int max = dim_post_expr - 1;
  if (dim < min || dim > max) {
    MS_LOG(ERROR) << "For Quantile, dimension out of range (expected to be in range of [" << min << ", " << max
                  << "]).";
  }
  if (dim < 0) {
    dim += dim_post_expr;
  }
  return dim;
}

int QuantileGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                 const std::vector<KernelTensorPtr> &outputs,
                                 const std::map<uint32_t, tensor::TensorPtr> &) {
  input_elements_ = 0;
  input_size_list_.clear();
  output_size_list_.clear();
  workspace_size_list_.clear();
  for (const auto &input : inputs) {
    // If any input shape contains -1, the shape is dynamic; defer sizing until it is known.
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  auto input_shape = inputs.at(kIndex0)->GetShapeVector();
  auto q_shape = inputs.at(kIndex1)->GetShapeVector();
  auto output_shape = outputs.at(kIndex0)->GetShapeVector();
  input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<size_t>());
  auto q_elements = std::accumulate(q_shape.begin(), q_shape.end(), size_t(1), std::multiplies<size_t>());
  output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), size_t(1), std::multiplies<size_t>());
  if (input_elements_ == 0) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', input size must be greater than zero.";
    return KRET_RESIZE_FAILED;
  }
  dim_ = MaybeWrapDim(dim_, input_shape.size());
  // Decompose the input into (x_, y_, z_): x_ is the product of the dimensions before
  // dim_, y_ the size of the reduced dimension, z_ the product of the dimensions after it.
  if (dim_ == kQuantileDefaultDim) {
    x_ = 1;
    y_ = 1;
    for (size_t i = 0; i < input_shape.size(); i++) y_ *= input_shape.at(i);
    z_ = 1;
    dim_ = 0;
  } else {
    x_ = 1;
    y_ = input_shape.at(dim_);
    z_ = 1;
    for (int i = 0; i < dim_; i++) x_ *= input_shape.at(i);
    for (size_t i = dim_ + 1; i < input_shape.size(); i++) z_ *= input_shape.at(i);
  }
  each_q_elements_ = input_elements_ / y_;
  size_t input_size = input_elements_ * input_unit_size_;
  size_t q_size = q_elements * q_unit_size_;
  input_size_list_.push_back(input_size);
  input_size_list_.push_back(q_size);
  output_size_list_.push_back(output_elements_ * input_unit_size_);
  ceil_power2_ = RoundUpPower2(y_);
  // Workspaces (in bytes): padded sort buffer, the single q-range flag, and the
  // per-slice NaN flags (int each, at most output_elements_ of them).
  workspace_size_list_.push_back(input_size / y_ * ceil_power2_);
  workspace_size_list_.push_back(sizeof(int));
  workspace_size_list_.push_back(output_elements_ * sizeof(int));
  return KRET_OK;
}

template <typename T>
bool QuantileGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                        const std::vector<AddressPtr> &outputs) {
  T *input = GetDeviceAddress<T>(inputs, kIndex0);
  T *q = GetDeviceAddress<T>(inputs, kIndex1);
  T *out = GetDeviceAddress<T>(outputs, kIndex0);
  T *sort = GetDeviceAddress<T>(workspace, kIndex0);
  int *ret_flag_device = GetDeviceAddress<int>(workspace, kIndex1);
  int *nan_flags = GetDeviceAddress<int>(workspace, kIndex2);
  total_ = inputs[0]->size / sizeof(T);
  if (total_ == 0) {
    MS_LOG(ERROR) << "For Quantile, input tensor must be non-empty.";
    return false;
  }
  int flag_in = 0;
  Quantile(input, q, out, sort, dim_, x_, y_, z_, each_q_elements_, output_elements_, &flag_in, ret_flag_device,
           nan_flags, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
  if (flag_in == 1) {
    MS_EXCEPTION(ValueError) << "For Quantile, q out of range (expected to be in range of [0, 1]).";
    return false;
  }
  return true;
}

std::vector<std::pair<KernelAttr, QuantileGpuKernelMod::QuantileFunc>> QuantileGpuKernelMod::func_list_ = {
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   &QuantileGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   &QuantileGpuKernelMod::LaunchKernel<double>}};

std::vector<KernelAttr> QuantileGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, QuantileFunc> &pair) { return pair.first; });
  return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Quantile, QuantileGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
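Resize collapses the input into an (x, y, z) view around the reduced dimension: x is the product of the leading dimensions, y the size of dim, z the product of the trailing dimensions, so each of the x*z slices handed to the sort is a strided column of length y. A small sketch of that decomposition, with illustrative names and the same "no dim" sentinel:

KQUANTILE_DEFAULT_DIM = 10000  # sentinel: reduce over the flattened input

def decompose(shape, dim):
    """Mirror of the (x_, y_, z_) computation in QuantileGpuKernelMod::Resize."""
    if dim == KQUANTILE_DEFAULT_DIM:
        y = 1
        for s in shape:
            y *= s
        return 1, y, 1
    x = 1
    for s in shape[:dim]:
        x *= s
    y = shape[dim]
    z = 1
    for s in shape[dim + 1:]:
        z *= s
    return x, y, z

assert decompose((4, 4), 0) == (1, 4, 4)      # the test case below: 4 columns of length 4
assert decompose((2, 3, 5), 1) == (2, 3, 5)
assert decompose((2, 3, 5), KQUANTILE_DEFAULT_DIM) == (1, 30, 1)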
@@ -0,0 +1,73 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_QUANTILE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_QUANTILE_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <utility>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class QuantileGpuKernelMod : public NativeGpuKernelMod {
 public:
  QuantileGpuKernelMod() {}
  ~QuantileGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
    cuda_stream_ = cuda_stream;
    return kernel_func_(this, inputs, workspace, outputs);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
  using QuantileFunc =
    std::function<bool(QuantileGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;

  void *cuda_stream_{nullptr};
  QuantileFunc kernel_func_{};
  static std::vector<std::pair<KernelAttr, QuantileFunc>> func_list_;
  size_t input_unit_size_{1};
  size_t q_unit_size_{1};
  size_t input_elements_{};
  size_t output_elements_{};
  size_t each_q_elements_{};
  int ceil_power2_{0};
  int dim_{0};
  int x_{1};
  int y_{1};
  int z_{1};
  size_t total_{0};
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_QUANTILE_GPU_KERNEL_H_
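The func_list_/kernel_func_ pair is a dtype dispatch table: Init matches the runtime KernelAttr against the registered attributes and binds kernel_func_ to the matching LaunchKernel<T> instantiation, so Launch itself stays template-free. Roughly this pattern, sketched with illustrative names:

import numpy as np

# Sketch of the KernelAttr -> LaunchKernel<T> dispatch used by the kernel mod.
def launch_f32(inputs):
    return np.float32(sum(inputs))

def launch_f64(inputs):
    return np.float64(sum(inputs))

FUNC_LIST = {("float32", "float32"): launch_f32,
             ("float64", "float64"): launch_f64}

def init(attr):
    if attr not in FUNC_LIST:          # MatchKernelAttr failed
        raise TypeError(f"unsupported kernel type: {attr}")
    return FUNC_LIST[attr]             # kernel_func_ is bound once at Init

kernel_func = init(("float32", "float32"))
print(kernel_func([1.0, 2.0]))  # 3.0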
@@ -928,6 +928,7 @@ GVAR_DEF(PrimitivePtr, kPrimSeLU, std::make_shared<Primitive>("SeLU"));
GVAR_DEF(PrimitivePtr, kPrimGLU, std::make_shared<Primitive>(kGLU));
GVAR_DEF(PrimitivePtr, kPrimGluGrad, std::make_shared<Primitive>(kGluGrad));
GVAR_DEF(PrimitivePtr, kPrimSoftplus, std::make_shared<Primitive>("Softplus"));
GVAR_DEF(PrimitivePtr, kPrimQuantile, std::make_shared<Primitive>("Quantile"));
GVAR_DEF(PrimitivePtr, kPrimSoftplusGrad, std::make_shared<Primitive>("SoftplusGrad"));
GVAR_DEF(PrimitivePtr, kPrimZeros, std::make_shared<Primitive>("Zeros"));
GVAR_DEF(PrimitivePtr, kPrimZerosLike, std::make_shared<Primitive>(kZerosLike));
@@ -177,6 +177,7 @@ constexpr auto kOutQuantized = "out_quantized";
constexpr auto kMvlgammaP = "mvlgamma_p";
constexpr auto kP = "p";
constexpr auto kMargin = "margin";
constexpr auto kKeepdim = "keepdim";
constexpr auto kPad = "pad";
constexpr auto kPadding = "padding";
constexpr auto kPaddingsElementSize = "paddings_element_size";
@@ -0,0 +1,160 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ops/quantile.h"

#include <set>
#include <map>
#include <string>
#include <vector>
#include <utility>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
constexpr int kQuantileDefaultDim = 10000;

abstract::ShapePtr QuantileInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto prim_name = primitive->name();
  (void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
  auto input = input_args[0]->BuildShape();
  MS_EXCEPTION_IF_NULL(input);
  auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  auto q_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
  auto q_dim = q_shape.size();
  if (IsDynamicRank(input_shape) || IsDynamicRank(q_shape)) {
    return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
  }

  std::vector<int64_t> out_shape;
  auto dim_ptr = primitive->GetAttr("dim");
  MS_EXCEPTION_IF_NULL(dim_ptr);

  auto dim = GetValue<int64_t>(dim_ptr);
  int64_t input_dim = SizeToLong(input_shape.size());
  int64_t wrapped_input_dim = input_dim;

  if (wrapped_input_dim == 0) {
    wrapped_input_dim = 1;
  }

  if (dim != kQuantileDefaultDim && (dim < -wrapped_input_dim || dim >= wrapped_input_dim)) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the attr 'dim' must be in range of [" << -wrapped_input_dim
                             << ", " << (wrapped_input_dim - 1) << "].";
  }

  if (q_dim > 1) {
    MS_EXCEPTION(ValueError) << "For '" << prim_name
                             << "', the input q must be a scalar or 1D tensor, but got dimension = " << q_dim << ".";
  }

  if (dim < 0) {
    dim = dim + wrapped_input_dim;
  }
  auto keep_dims_ptr = primitive->GetAttr("keep_dims");
  MS_EXCEPTION_IF_NULL(keep_dims_ptr);
  auto keep_dims = GetValue<bool>(keep_dims_ptr);
  int q_size = 1;
  for (uint64_t i = 0; i < q_shape.size(); i++) {
    q_size *= q_shape[i];
  }

  // Reduce the quantile dimension (or all dimensions when dim is the default
  // sentinel), then prepend the number of quantiles when q is a 1D tensor.
  if (dim != kQuantileDefaultDim && input_dim > 0) {
    out_shape = input_shape;
    if (keep_dims) {
      out_shape[dim] = 1;
    } else {
      out_shape.erase(out_shape.begin() + dim);
    }
  } else if (keep_dims) {
    out_shape = std::vector<int64_t>(input_dim, 1);
  }
  if (q_dim > 0) {
    out_shape.insert(out_shape.begin(), q_size);
  }

  return std::make_shared<abstract::Shape>(out_shape);
}

TypePtr QuantileInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto input_type = input_args[0]->BuildType();
  MS_EXCEPTION_IF_NULL(input_type);
  auto q = input_args[1];
  MS_EXCEPTION_IF_NULL(q);
  auto q_type = input_args[1]->BuildType();
  MS_EXCEPTION_IF_NULL(q_type);
  auto prim_name = primitive->name();
  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
  std::map<std::string, TypePtr> dict_type;
  (void)dict_type.insert(std::make_pair("q", q_type));
  (void)dict_type.insert(std::make_pair("input", input_type));
  (void)CheckAndConvertUtils::CheckTensorTypeValid("input", input_type, valid_types, prim_name);

  auto q_value = q->BuildValue();
  MS_EXCEPTION_IF_NULL(q_value);
  if (q->isa<abstract::AbstractTensor>()) {
    CheckAndConvertUtils::CheckTensorTypeSame(dict_type, valid_types, prim_name);
  } else if (q->isa<abstract::AbstractScalar>()) {
    if (q_value != nullptr) {
      if (!q_value->isa<FloatImm>()) {
        MS_EXCEPTION(TypeError) << "For '" << prim_name
                                << "', the type of 'q' must be float or tensor, but got: " << q_type->ToString()
                                << ".";
      }
      auto value = GetValue<float>(q_value);
      if (value < 0 || value > 1) {
        MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the 'q' must be in the range [0, 1], but got: "
                                 << value << ".";
      }
    }
  } else {
    MS_EXCEPTION(TypeError) << "For '" << prim_name
                            << "', the type of 'q' must be float or tensor, but got: " << q_type->ToString() << ".";
  }
  return input_type;
}
}  // namespace

void Quantile::set_dim(int64_t dim) { (void)AddAttr(kDim, api::MakeValue(dim)); }

void Quantile::set_keepdim(bool keepdim) { (void)AddAttr(kKeepdim, api::MakeValue(keepdim)); }

int64_t Quantile::get_dim() const {
  auto value_ptr = GetAttr(kDim);
  return GetValue<int64_t>(value_ptr);
}

bool Quantile::get_keepdim() const {
  auto value_ptr = GetAttr(kKeepdim);
  return GetValue<bool>(value_ptr);
}

MIND_API_OPERATOR_IMPL(Quantile, BaseOperator);
AbstractBasePtr QuantileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  auto infer_type = QuantileInferType(primitive, input_args);
  auto infer_shape = QuantileInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Quantile, prim::kPrimQuantile, QuantileInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
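QuantileInferShape's rule can be summarized as: drop the reduced dimension (or keep it as 1 with keep_dims), then prepend len(q) when q is a 1D tensor. A sketch of the rule for static shapes, with the shapes it produces; the helper name is illustrative:

DEFAULT_DIM = 10000

def infer_out_shape(input_shape, q_shape, dim, keep_dims):
    """Sketch of QuantileInferShape for static shapes."""
    rank = max(len(input_shape), 1)
    out = []
    if dim != DEFAULT_DIM and len(input_shape) > 0:
        if dim < 0:
            dim += rank
        out = list(input_shape)
        if keep_dims:
            out[dim] = 1
        else:
            out.pop(dim)
    elif keep_dims:
        out = [1] * len(input_shape)
    if len(q_shape) > 0:             # 1D q prepends one quantile axis
        q_size = 1
        for s in q_shape:
            q_size *= s
        out.insert(0, q_size)
    return out

assert infer_out_shape([4, 4], [3], dim=0, keep_dims=True) == [3, 1, 4]
assert infer_out_shape([4, 4], [3], dim=0, keep_dims=False) == [3, 4]
assert infer_out_shape([4, 4], [], dim=DEFAULT_DIM, keep_dims=False) == []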
@@ -0,0 +1,44 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_QUANTILE_H_
#define MINDSPORE_CORE_OPS_QUANTILE_H_
#include <vector>
#include <memory>

#include "ops/base_operator.h"
#include "mindapi/base/types.h"

namespace mindspore {
namespace ops {
constexpr auto kNameQuantile = "Quantile";
class MIND_API Quantile : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(Quantile);
  Quantile() : BaseOperator(kNameQuantile) { InitIOName({"input", "q"}, {"out"}); }
  void Init() {}
  void set_dim(int64_t dim);
  void set_keepdim(bool keepdim);
  int64_t get_dim() const;
  bool get_keepdim() const;
};
abstract::AbstractBasePtr QuantileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimQuantilePtr = std::shared_ptr<Quantile>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_QUANTILE_H_
@@ -3700,6 +3700,42 @@ class _LogicBinaryOp(_BinaryOp):
        return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.name)


class Quantile(Primitive):
    r"""
    Computes the q-th quantiles of all elements in the input tensor, doing a linear interpolation when the
    q-th quantile lies between two data points.

    Refer to :func:`mindspore.ops.quantile` and :func:`mindspore.ops.nanquantile` for more details.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> quantile = ops.Quantile()
        >>> input = Tensor(np.array([0.0700, -0.5446, 0.9214]), mindspore.float32)
        >>> q = Tensor(np.array([0, 0.5, 1]), mindspore.float32)
        >>> output = quantile(input, q)
        >>> print(output)
        [-0.5446  0.07    0.9214]
    """

    @prim_attr_register
    def __init__(self, dim=None, keep_dims=False, ignore_nan=False):
        """Initialize Quantile"""
        if dim is not None:
            validator.check_value_type("dim", dim, [int], self.name)
        else:
            self.add_prim_attr("dim", 10000)
        if keep_dims is not None:
            validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
        else:
            self.add_prim_attr("keep_dims", False)
        if ignore_nan is not None:
            validator.check_value_type("ignore_nan", ignore_nan, [bool], self.name)
        else:
            self.add_prim_attr("ignore_nan", False)


class Equal(Primitive):
    r"""
    Computes the equivalence between two tensors element-wise.
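Because the attributes must always be materialized for the C++ infer functions, dim=None is encoded as the sentinel 10000 (kQuantileDefaultDim on the C++ side), which the backend treats as "flatten and reduce everything". With dim left at the default, the primitive should behave like NumPy's flattening quantile; a quick equivalence check, assuming a GPU build of MindSpore is available to run it:

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = np.array([0.0700, -0.5446, 0.9214], dtype=np.float32)
q = np.array([0, 0.5, 1], dtype=np.float32)
out = ops.Quantile()(Tensor(x), Tensor(q)).asnumpy()
# NumPy's default linear interpolation matches the docstring example above.
assert np.allclose(out, np.quantile(x, q).astype(np.float32))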
@@ -0,0 +1,79 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np
from mindspore import Tensor
import mindspore.ops.operations.math_ops as op
from mindspore.nn import Cell
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')


class Quantile(Cell):
    def __init__(self, dim=0, keep_dims=False):
        super().__init__()
        self.quantile = op.Quantile(dim=dim, keep_dims=keep_dims)

    def construct(self, x, q):
        return self.quantile(x, q)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_quantile_fp32():
    """
    Feature: Quantile
    Description: Test Quantile with float32 input along dim 0 with keep_dims=True.
    Expectation: The results are as expected.
    """
    type_i = np.float32
    rtol_loss = 1e-04
    x = np.array([[1.0, 5.0, 9.0, 13], [2, 6, 10, 14],
                  [3, 7, 11, 15], [4, 8, 12, 16]]).astype(type_i)
    q = np.array([0.25, 0.5, 0.75]).astype(type_i)
    dim = 0
    keep_dims = True
    net = Quantile(dim=dim, keep_dims=keep_dims)
    output = net(Tensor(x), Tensor(q))
    output = output.asnumpy()
    expect_output = np.array([[[1.75, 5.75, 9.75, 13.75]], [[2.5, 6.5, 10.5, 14.5]],
                              [[3.25, 7.25, 11.25, 15.25]]]).astype(type_i)
    assert np.allclose(output, expect_output, rtol_loss)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_quantile_fp64():
    """
    Feature: Quantile
    Description: Test Quantile with float64 input along dim 0 with keep_dims=True.
    Expectation: The results are as expected.
    """
    type_i = np.float64
    rtol_loss = 1e-05
    x = np.array([[1.0, 5.0, 9.0, 13], [2, 6, 10, 14],
                  [3, 7, 11, 15], [4, 8, 12, 16]]).astype(type_i)
    q = np.array([0.25, 0.5, 0.75]).astype(type_i)
    dim = 0
    keep_dims = True
    net = Quantile(dim=dim, keep_dims=keep_dims)
    output = net(Tensor(x), Tensor(q))
    output = output.asnumpy()
    expect_output = np.array([[[1.75, 5.75, 9.75, 13.75]], [[2.5, 6.5, 10.5, 14.5]],
                              [[3.25, 7.25, 11.25, 15.25]]]).astype(type_i)
    assert np.allclose(output, expect_output, rtol_loss)
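The expected values follow directly from the interpolation formula: for the first column [1, 2, 3, 4] and q = 0.25, the fractional rank is 0.25 * (4 - 1) = 0.75, so the result is 1 + 0.75 * (2 - 1) = 1.75, and the whole expectation table is just the column-wise linear-interpolation quantile. A way to derive the fixtures instead of hard-coding them:

import numpy as np

x = np.array([[1.0, 5.0, 9.0, 13], [2, 6, 10, 14],
              [3, 7, 11, 15], [4, 8, 12, 16]])
q = np.array([0.25, 0.5, 0.75])
# keepdims=True reproduces the (3, 1, 4) expect_output used in both tests.
expected = np.quantile(x, q, axis=0, keepdims=True)
print(expected[0, 0])  # [ 1.75  5.75  9.75 13.75]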