add new GPU operator quantile

This commit is contained in:
arranclo 2022-11-23 19:39:19 +08:00
parent f411ab0e6a
commit dfdab165de
10 changed files with 738 additions and 0 deletions

View File

@@ -0,0 +1,155 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <limits>
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/quantile_impl.cuh"
int RoundUpPower2(int v) {
v--;
v |= v >> 1;
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
v++;
return v;
}
template <typename T>
__inline__ __device__ void Swap(T *lhs, T *rhs) {
T tmp = lhs[0];
lhs[0] = rhs[0];
rhs[0] = tmp;
}
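// For each output element, interpolate the q-th quantile from its sorted slice:
// with pos = q * (y - 1), the result is
// sort[floor(pos)] + frac(pos) * (sort[floor(pos) + 1] - sort[floor(pos)]).
// Slices flagged as containing NaN (nan[...] == 2) yield NaN directly.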
template <typename T>
__global__ void DoQuantile(const T *input, const T *q, T *out, T *sort, const int dim, const int x, const int y,
const int z, const int each_q_elements, const int output_elements, const int ceil_p_2,
int *nan) {
for (size_t index = blockIdx.x * blockDim.x + threadIdx.x; index < output_elements; index += blockDim.x * gridDim.x) {
size_t q_index = index / each_q_elements;
size_t start = static_cast<size_t>((index % each_q_elements) / z) * ceil_p_2 * z + (index % each_q_elements) % z;
T iq = q[q_index];
int iqy_int = static_cast<int>(iq * static_cast<T>(y - 1));
T iqy_T = static_cast<T>(iq * static_cast<T>(y - 1));
int step = z * iqy_int;
int input_index = start + step;
if (nan[index % each_q_elements] == 2) {
out[index] = NAN;
} else {
out[index] = static_cast<T>(sort[input_index] +
(iqy_T - static_cast<T>(iqy_int)) * (sort[input_index + z] - sort[input_index]));
}
}
}
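// Copy each length-y slice along the reduction axis into the workspace, padding
// it to ceil_p_2 elements with +max so the padding sorts to the end and never
// affects the interpolated result.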
template <typename T>
__global__ void Copy(const T *input, T *sort, const int x, const int ceil_p_2, const int y, const int z) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < x * ceil_p_2 * z; pos += blockDim.x * gridDim.x) {
size_t input_x = static_cast<size_t>(pos / (ceil_p_2 * z));
size_t input_y = static_cast<size_t>(pos % (ceil_p_2 * z) / z);
size_t input_z = pos % z;
sort[pos] = input_y < y ? input[input_x * y * z + input_y * z + input_z] : std::numeric_limits<T>::max();
}
}
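// In-place bitonic sort of every padded slice. Each block sorts one of the
// clip_num slices; step (= z) is the stride between consecutive elements of a
// slice in the flattened workspace.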
template <typename T>
__global__ void BitonicSort(const int ceil_power2, T *rank_buff, const int clip_num, const int step) {
for (size_t clip_i = blockIdx.x; clip_i < clip_num; clip_i += gridDim.x) {
T *rank_buff_offset = rank_buff + static_cast<size_t>(clip_i / step) * ceil_power2 * step + clip_i % step;
for (size_t i = 2; i <= ceil_power2; i <<= 1) {
for (size_t j = (i >> 1); j > 0; j >>= 1) {
for (size_t tid = threadIdx.x; tid < ceil_power2; tid += blockDim.x) {
size_t tid_comp = tid ^ j;
if (tid_comp > tid) {
if ((tid & i) == 0) {
if (rank_buff_offset[tid * step] > rank_buff_offset[tid_comp * step]) {
Swap(&rank_buff_offset[tid * step], &rank_buff_offset[tid_comp * step]);
}
} else {
if (rank_buff_offset[tid * step] < rank_buff_offset[tid_comp * step]) {
Swap(&rank_buff_offset[tid * step], &rank_buff_offset[tid_comp * step]);
}
}
}
}
__syncthreads();
}
}
}
}
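// Set *flag_in to 1 if any quantile value q lies outside [0, 1]; the host
// copies the flag back and raises an error.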
template <typename T>
__global__ void QuantileKernelCheck(int num, const T *q, int *flag_in) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
if (q[i] < 0 || q[i] > 1) {
*flag_in = 1;
return;
}
}
}
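// Mark flag 2 for every (x, z) slice that contains at least one NaN along the
// reduction axis y.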
template <typename T>
__global__ void QuantileKernelCheckNan(int x, int y, int z, int num, const T *input, int *flag_in) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
if (std::isnan(input[i]) && flag_in[i / (y * z) * z + i % (y * z) % z] != 2) {
flag_in[i / (y * z) * z + i % (y * z) % z] = 2;
}
}
}
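// Zero-initialize the per-slice NaN flags before the NaN scan.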
template <typename T>
__global__ void QuantileKernelCheckNanInit(int x, int y, int z, int num, const T *input, int *flag_in) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += gridDim.x * blockDim.x) {
flag_in[i] = 0;
}
}
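// Host-side launcher: validate q, scan for NaN slices, pad and bitonic-sort
// each slice in the workspace, then interpolate the requested quantiles.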
template <typename T>
CUDA_LIB_EXPORT void Quantile(const T *input, const T *q, T *out, T *sort, const int dim, const int x, const int y,
const int z, const int each_q_elements, const int output_elements, int *flag_in,
int *ret_flag_device, int *nan_flags, const uint32_t &device_id,
cudaStream_t cuda_stream) {
(void)cudaMemset(ret_flag_device, 0, sizeof(int));
QuantileKernelCheck<<<CUDA_BLOCKS(device_id, output_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
output_elements / each_q_elements, q, ret_flag_device);
cudaStreamSynchronize(cuda_stream);
(void)cudaMemcpy(flag_in, ret_flag_device, sizeof(int), cudaMemcpyDeviceToHost);
(void)cudaMemset(nan_flags, 0, sizeof(int));
QuantileKernelCheckNanInit<<<CUDA_BLOCKS(device_id, x * z), CUDA_THREADS(device_id), 0, cuda_stream>>>(
x, y, z, x * z, input, nan_flags);
QuantileKernelCheckNan<<<CUDA_BLOCKS(device_id, x * y * z), CUDA_THREADS(device_id), 0, cuda_stream>>>(
x, y, z, x * y * z, input, nan_flags);
int ceil_p_2 = RoundUpPower2(y);
int thread = std::min(ceil_p_2, CUDA_THREADS(device_id));
Copy<<<CUDA_BLOCKS(device_id, x * ceil_p_2 * z), CUDA_THREADS(device_id), 0, cuda_stream>>>(input, sort, x, ceil_p_2,
y, z);
BitonicSort<<<x * z, thread, 0, cuda_stream>>>(ceil_p_2, sort, x * z, z);
DoQuantile<<<CUDA_BLOCKS(device_id, output_elements), CUDA_THREADS(device_id), 0, cuda_stream>>>(
input, q, out, sort, dim, x, y, z, each_q_elements, output_elements, ceil_p_2, nan_flags);
}
template CUDA_LIB_EXPORT void Quantile<float>(const float *input, const float *q, float *out, float *sort,
const int dim, const int x, const int y, const int z,
const int each_q_elements, const int output_elements, int *flag_in,
int *ret_flag_device, int *nan_flags, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Quantile<double>(const double *input, const double *q, double *out, double *sort,
const int dim, const int x, const int y, const int z,
const int each_q_elements, const int output_elements, int *flag_in,
int *ret_flag_device, int *nan_flags, const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@@ -0,0 +1,27 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_QUANTILE_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_QUANTILE_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
CUDA_LIB_EXPORT void Quantile(const T *input, const T *q, T *out, T *sort, const int dim, const int x, const int y,
const int z, const int each_q_elements, const int output_elements, int *flag_in,
int *ret_flag_device, int *nan_flags, const uint32_t &device_id,
cudaStream_t cuda_stream);
CUDA_LIB_EXPORT int RoundUpPower2(int v);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_QUANTILE_IMPL_CUH_

View File

@@ -0,0 +1,162 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/quantile_gpu_kernel.h"
#include <functional>
#include <utility>
#include <string>
#include <algorithm>
#include <memory>
#include "abstract/utils.h"
#include "mindspore/core/ops/quantile.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/quantile_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr int kQuantileDefaultDim = 10000;
bool QuantileGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::Quantile>(base_operator);
if (kernel_ptr == nullptr) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' cast Quantile ops failed!";
return false;
}
kernel_name_ = kernel_ptr->name();
dim_ = kernel_ptr->get_dim();
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
input_unit_size_ = abstract::TypeIdSize(inputs[kIndex0]->GetDtype());
q_unit_size_ = abstract::TypeIdSize(inputs[kIndex1]->GetDtype());
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
return false;
}
return true;
}
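// Normalize a possibly negative dim into [0, dim_post_expr - 1]; the sentinel
// kQuantileDefaultDim means "no dim given" and is passed through unchanged.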
int MaybeWrapDim(int dim, int dim_post_expr) {
if (dim == kQuantileDefaultDim) {
return dim;
}
if (dim_post_expr <= 0) {
dim_post_expr = 1;
}
int min = -dim_post_expr;
int max = dim_post_expr - 1;
if (dim < min || dim > max) {
MS_LOG(ERROR) << "For Quantile, dimension out of range (expected to be in the range [" << min << ", " << max
<< "]).";
}
if (dim < 0) {
dim += dim_post_expr;
}
return dim;
}
int QuantileGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
input_elements_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
for (const auto &input : inputs) {
// If any input shape contains -1, the input shape is dynamic, so just return without doing anything.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
auto input_shape = inputs.at(kIndex0)->GetShapeVector();
auto q_shape = inputs.at(kIndex1)->GetShapeVector();
auto output_shape = outputs.at(kIndex0)->GetShapeVector();
input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), size_t(1), std::multiplies<size_t>());
auto q_elements = std::accumulate(q_shape.begin(), q_shape.end(), size_t(1), std::multiplies<size_t>());
output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), size_t(1), std::multiplies<size_t>());
if (input_elements_ == 0) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' input size must be greater than zero.";
return KRET_RESIZE_FAILED;
}
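// Decompose the input into (x, y, z): y is the size of the reduced dimension,
// x the product of the dimensions before it, and z the product of the
// dimensions after it. Without a dim, the whole tensor becomes one slice.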
dim_ = MaybeWrapDim(dim_, input_shape.size());
if (dim_ == kQuantileDefaultDim) {
x_ = 1;
y_ = 1;
for (size_t i = 0; i < input_shape.size(); i++) y_ *= input_shape.at(i);
z_ = 1;
dim_ = 0;
} else {
x_ = 1;
y_ = input_shape.at(dim_);
z_ = 1;
for (int i = 0; i < dim_; i++) x_ *= input_shape.at(i);
for (size_t i = dim_ + 1; i < input_shape.size(); i++) z_ *= input_shape.at(i);
}
each_q_elements_ = input_elements_ / y_;
size_t input_size = input_elements_ * input_unit_size_;
size_t q_size = q_elements * q_unit_size_;
input_size_list_.push_back(input_size);
input_size_list_.push_back(q_size);
output_size_list_.push_back(output_elements_ * input_unit_size_);
ceil_power2_ = RoundUpPower2(y_);
workspace_size_list_.push_back(input_size / y_ * ceil_power2_);
workspace_size_list_.push_back(sizeof(int));
workspace_size_list_.push_back(each_q_elements_ * sizeof(int));
return KRET_OK;
}
template <typename T>
bool QuantileGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *input = GetDeviceAddress<T>(inputs, kIndex0);
T *q = GetDeviceAddress<T>(inputs, kIndex1);
T *out = GetDeviceAddress<T>(outputs, kIndex0);
T *sort = GetDeviceAddress<T>(workspace, kIndex0);
int *ret_flag_device = GetDeviceAddress<int>(workspace, kIndex1);
int *nan_flags = GetDeviceAddress<int>(workspace, kIndex2);
total_ = inputs[0]->size / sizeof(T);
if (total_ == 0) {
MS_LOG(ERROR) << "For Quantile, input tensor must be non-empty.";
return false;
}
int flag_in = 0;
Quantile(input, q, out, sort, dim_, x_, y_, z_, each_q_elements_, output_elements_, &flag_in, ret_flag_device,
nan_flags, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
if (flag_in == 1) {
MS_EXCEPTION(ValueError) << "For Quantile, q out of range (expected to be in the range [0, 1]).";
return false;
}
return true;
}
std::vector<std::pair<KernelAttr, QuantileGpuKernelMod::QuantileFunc>> QuantileGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&QuantileGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&QuantileGpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> QuantileGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, QuantileFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Quantile, QuantileGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,73 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_QUANTILE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_QUANTILE_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <utility>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class QuantileGpuKernelMod : public NativeGpuKernelMod {
public:
QuantileGpuKernelMod() {}
~QuantileGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using QuantileFunc =
std::function<bool(QuantileGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
private:
void *cuda_stream_{nullptr};
QuantileFunc kernel_func_{};
static std::vector<std::pair<KernelAttr, QuantileFunc>> func_list_;
size_t input_unit_size_{1};
size_t q_unit_size_{1};
size_t input_elements_{};
size_t output_elements_{};
size_t each_q_elements_{};
int ceil_power2_{0};
int dim_{0};
int x_{1};
int y_{1};
int z_{1};
size_t total_ = 0;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_QUANTILE_GPU_KERNEL_H_

View File

@@ -928,6 +928,7 @@ GVAR_DEF(PrimitivePtr, kPrimSeLU, std::make_shared<Primitive>("SeLU"));
GVAR_DEF(PrimitivePtr, kPrimGLU, std::make_shared<Primitive>(kGLU));
GVAR_DEF(PrimitivePtr, kPrimGluGrad, std::make_shared<Primitive>(kGluGrad));
GVAR_DEF(PrimitivePtr, kPrimSoftplus, std::make_shared<Primitive>("Softplus"));
GVAR_DEF(PrimitivePtr, kPrimQuantile, std::make_shared<Primitive>("Quantile"));
GVAR_DEF(PrimitivePtr, kPrimSoftplusGrad, std::make_shared<Primitive>("SoftplusGrad"));
GVAR_DEF(PrimitivePtr, kPrimZeros, std::make_shared<Primitive>("Zeros"));
GVAR_DEF(PrimitivePtr, kPrimZerosLike, std::make_shared<Primitive>(kZerosLike));

View File

@@ -177,6 +177,7 @@ constexpr auto kOutQuantized = "out_quantized";
constexpr auto kMvlgammaP = "mvlgamma_p";
constexpr auto kP = "p";
constexpr auto kMargin = "margin";
constexpr auto kKeepdim = "keepdim";
constexpr auto kPad = "pad";
constexpr auto kPadding = "padding";
constexpr auto kPaddingsElementSize = "paddings_element_size";

View File

@@ -0,0 +1,160 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/quantile.h"
#include <set>
#include <map>
#include <string>
#include <vector>
#include <utility>
#include <iostream>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
constexpr int kQuantileDefaultDim = 10000;
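// Infer the output shape: the reduced dimension is dropped (or kept as size 1
// with keep_dims), and a leading dimension of q's size is prepended when q is
// a 1-D tensor.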
abstract::ShapePtr QuantileInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto input = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(input);
auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto q_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto q_dim = q_shape.size();
if (IsDynamicRank(input_shape) || IsDynamicRank(q_shape)) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
}
std::vector<int64_t> out_shape;
auto dim_ptr = primitive->GetAttr("dim");
MS_EXCEPTION_IF_NULL(dim_ptr);
auto dim = GetValue<int64_t>(dim_ptr);
int64_t input_dim = SizeToLong(input_shape.size());
int64_t wrapped_input_dim = input_dim;
if (wrapped_input_dim == 0) {
wrapped_input_dim = 1;
}
if (dim != kQuantileDefaultDim && (dim < -wrapped_input_dim || dim >= wrapped_input_dim)) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the attr dim must be in the range ["
<< -wrapped_input_dim << ", " << (wrapped_input_dim - 1) << "].";
}
if (q_dim > 1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', the input q must be a scalar or 1D tensor,but got dimension = " << q_dim << ".";
}
if (dim < 0) {
dim = dim + wrapped_input_dim;
}
auto keep_dims_ptr = primitive->GetAttr("keep_dims");
MS_EXCEPTION_IF_NULL(keep_dims_ptr);
auto keep_dims = GetValue<bool>(keep_dims_ptr);
int q_size = 1;
for (uint64_t i = 0; i < q_shape.size(); i++) {
q_size *= q_shape[i];
}
if (dim != kQuantileDefaultDim && input_dim > 0) {
out_shape = input_shape;
if (keep_dims) {
out_shape[dim] = 1;
} else {
out_shape.erase(out_shape.begin() + dim);
}
} else if (keep_dims) {
out_shape = std::vector<int64_t>(input_dim, 1);
}
if (q_dim > 0) {
out_shape.insert(out_shape.begin(), q_size);
}
return std::make_shared<abstract::Shape>(out_shape);
}
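// Validate dtypes: input and q must both be float32 or float64, and a scalar q
// must be a float in [0, 1].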
TypePtr QuantileInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto input_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
auto q = input_args[1];
MS_EXCEPTION_IF_NULL(q);
auto q_type = input_args[1]->BuildType();
MS_EXCEPTION_IF_NULL(q_type);
auto prim_name = primitive->name();
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
std::map<std::string, TypePtr> dict_type;
(void)dict_type.insert(std::make_pair("q", q_type));
(void)dict_type.insert(std::make_pair("input", input_type));
(void)CheckAndConvertUtils::CheckTensorTypeValid("input", input_type, valid_types, prim_name);
auto q_value = q->BuildValue();
MS_EXCEPTION_IF_NULL(q_value);
if (q->isa<abstract::AbstractTensor>()) {
CheckAndConvertUtils::CheckTensorTypeSame(dict_type, valid_types, prim_name);
} else if (q->isa<abstract::AbstractScalar>()) {
if (q_value != nullptr) {
if (!q_value->isa<FloatImm>()) {
MS_EXCEPTION(TypeError) << "For '" << prim_name
<< "', the type of 'q' must be float or tensor, but got: " << q_type->ToString() << ".";
}
auto value = GetValue<float>(q_value);
if (value < 0 || value > 1) {
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the 'q' must be in the range [0, 1], but got: "
<< value << ".";
}
}
} else {
MS_EXCEPTION(ValueError) << "For '" << prim_name
<< "', the type of 'q' must be float or tensor, but got: " << q_type->ToString() << ".";
}
return input_type;
}
} // namespace
void Quantile::set_dim(int64_t dim) { (void)AddAttr(kDim, api::MakeValue(dim)); }
void Quantile::set_keepdim(bool keepdim) { (void)AddAttr(kKeepdim, api::MakeValue(keepdim)); }
int64_t Quantile::get_dim() const {
auto value_ptr = GetAttr(kDim);
return GetValue<int64_t>(value_ptr);
}
bool Quantile::get_keepdim() const {
auto value_ptr = GetAttr(kKeepdim);
return GetValue<bool>(value_ptr);
}
MIND_API_OPERATOR_IMPL(Quantile, BaseOperator);
AbstractBasePtr QuantileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = QuantileInferType(primitive, input_args);
auto infer_shape = QuantileInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Quantile, prim::kPrimQuantile, QuantileInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -0,0 +1,44 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_QUANTILE_H_
#define MINDSPORE_CORE_OPS_QUANTILE_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameQuantile = "Quantile";
class MIND_API Quantile : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Quantile);
Quantile() : BaseOperator(kNameQuantile) { InitIOName({"input", "q"}, {"out"}); }
void Init() {}
void set_dim(int64_t dim);
void set_keepdim(bool keepdim);
int64_t get_dim() const;
bool get_keepdim() const;
};
abstract::AbstractBasePtr QuantileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimQuantilePtr = std::shared_ptr<Quantile>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_QUANTILE_H_

View File

@@ -3700,6 +3700,42 @@ class _LogicBinaryOp(_BinaryOp):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.name)
class Quantile(Primitive):
r"""
Computes the q-th quantiles of all elements in the input tensor, doing a linear interpolation when the
q-th quantile lies between two data points.
Refer to :func:`mindspore.ops.quantile` and :func:`mindspore.ops.nanquantile` for more details.
Supported Platforms:
``GPU``
Examples:
>>> quantile = ops.Quantile()
>>> input = Tensor(np.array([0.0700, -0.5446, 0.9214]), mindspore.float32)
>>> q = Tensor(np.array([0, 0.5, 1]), mindspore.float32)
>>> output = quantile(input, q)
>>> print(output)
[-0.5446 0.07 0.9214]
"""
@prim_attr_register
def __init__(self, dim=None, keep_dims=False, ignore_nan=False):
"""Initialize Quantile"""
if dim is not None:
validator.check_value_type("dim", dim, [int], self.name)
else:
self.add_prim_attr("dim", 10000)
if keep_dims is not None:
validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
else:
self.add_prim_attr("keep_dims", False)
if ignore_nan is not None:
validator.check_value_type("ignore_nan", ignore_nan, [bool], self.name)
else:
self.add_prim_attr("ignore_nan", False)
class Equal(Primitive):
r"""
Computes the equivalence between two tensors element-wise.

View File

@@ -0,0 +1,79 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import Tensor
import mindspore.ops.operations.math_ops as op
from mindspore.nn import Cell
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
class Quantile(Cell):
def __init__(self, dim=0, keep_dims=False):
super().__init__()
self.quantile = op.Quantile(dim=dim, keep_dims=keep_dims)
def construct(self, x, q):
return self.quantile(x, q)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_quantile_fp32():
"""
Feature: Quantile
Description: Test of input
Expectation: The results are as expected
"""
type_i = np.float32
rtol = 1e-04
x = np.array([[1.0, 5.0, 9.0, 13], [2, 6, 10, 14],
[3, 7, 11, 15], [4, 8, 12, 16]]).astype(type_i)
q = np.array([0.25, 0.5, 0.75]).astype(type_i)
dim = 0
keep_dims = True
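# Along dim 0 each column is [1, 2, 3, 4]; e.g. q = 0.25 gives position
# 0.25 * 3 = 0.75, interpolating 1 + 0.75 * (2 - 1) = 1.75. keep_dims keeps the
# reduced axis as size 1 and the q axis is prepended, giving shape (3, 1, 4).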
net = Quantile(dim=dim, keep_dims=keep_dims)
output = net(Tensor(x), Tensor(q))
output = output.asnumpy()
expect_output = np.array([[[1.75, 5.75, 9.75, 13.75]], [[2.5, 6.5, 10.5, 14.5]],
[[3.25, 7.25, 11.25, 15.25]]]).astype(type_i)
assert np.allclose(output, expect_output, rtol)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_quantile_fp64():
"""
Feature: Quantile
Description: Test of input
Expectation: The results are as expected
"""
type_i = np.float64
rtol = 1e-05
x = np.array([[1.0, 5.0, 9.0, 13], [2, 6, 10, 14],
[3, 7, 11, 15], [4, 8, 12, 16]]).astype(type_i)
q = np.array([0.25, 0.5, 0.75]).astype(type_i)
dim = 0
keep_dims = True
net = Quantile(dim=dim, keep_dims=keep_dims)
output = net(Tensor(x), Tensor(q))
output = output.asnumpy()
expect_output = np.array([[[1.75, 5.75, 9.75, 13.75]], [[2.5, 6.5, 10.5, 14.5]],
[[3.25, 7.25, 11.25, 15.25]]]).astype(type_i)
assert np.allclose(output, expect_output, rtol)