!44313 [assistant][ops] New operator implementation, including Quantile and NanQuantile
Merge pull request !44313 from jakcmanftr/pr-quantile
commit e88d49cc53
@@ -0,0 +1,328 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <cmath>
#include "plugin/device/cpu/kernel/quantile_cpu_kernel.h"
#include "mindspore/core/ops/quantile.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kQuantileInputsNum = 2;
constexpr size_t kQuantileOutputsNum = 1;
// Sentinel meaning "no dim attribute was given"; the input is then flattened before computation.
constexpr int kQuantileDefaultDim = 10000;
}  // namespace
bool QuantileCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::Quantile>(base_operator);
  if (!kernel_ptr) {
    MS_LOG(ERROR) << "cast Quantile ops failed!";
    return false;
  }
  input_shape_ = inputs.at(0)->GetShapeVector();
  q_shape_ = inputs.at(1)->GetShapeVector();
  kernel_name_ = kernel_ptr->name();
  dim_ = GetValue<int64_t>(base_operator->GetAttr("dim"));
  keep_dims_ = GetValue<bool>(base_operator->GetAttr("keep_dims"));
  ignore_nan_ = GetValue<bool>(base_operator->GetAttr("ignore_nan"));
  return MatchKernelFunc(base_operator, inputs, outputs);
}

bool QuantileCpuKernelMod::CheckInputAndAttr() { return true; }
int QuantileCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                 const std::vector<KernelTensorPtr> &outputs,
                                 const std::map<uint32_t, tensor::TensorPtr> &others) {
  int ret = 0;
  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != 0) {
    MS_LOG(WARNING) << kernel_name_ << " reinit failed.";
    return ret;
  }
  return 0;
}

// Normalizes a possibly negative dim into [0, dim_post_expr); the sentinel kQuantileDefaultDim passes through.
uint32_t QuantileCpuKernelMod::MaybeWrapDim(int dim, int dim_post_expr) {
  if (dim == kQuantileDefaultDim) {
    return dim;
  }
  if (dim_post_expr <= 0) {
    dim_post_expr = 1;
  }
  int min = -dim_post_expr;
  int max = dim_post_expr - 1;
  if (dim < min || dim > max) {
    MS_EXCEPTION(ValueError) << "For 'Quantile', dimension out of range (expected to be in range of [" << min << ", "
                             << max << "]).";
  }
  if (dim < 0) {
    dim += dim_post_expr;
  }
  return dim;
}
// Moves dimension `index` of the flattened tensor `f` (with shape `shape`) to the last position and
// returns the permuted flat data, so that the reduced dimension becomes contiguous in memory.
template <typename T>
std::vector<T> transpose(const std::vector<T> &f, const std::vector<int64_t> &shape, int index) {
  size_t element_count = f.size();
  size_t m = shape.size();
  std::vector<int> pos(m);
  std::vector<int> indexA(m);
  std::vector<int> indexB(m);
  std::vector<T> target(element_count);
  std::vector<int64_t> shapeTarget(shape);

  for (size_t j = 0; j < m; j++) {
    pos[j] = j;
  }
  if (m != 0) {
    std::swap(pos[m - 1], pos[((index + m) % m)]);
  }

  for (size_t j = 0; j < m; j++) {
    shapeTarget[j] = shape[pos[j]];
  }

  for (size_t src = 0; src < element_count; src++) {
    // Decompose the flat source index into multi-dimensional coordinates.
    int temp = src;
    for (int i = m - 1; i >= 0; i--) {
      indexA[i] = temp % shape[i];
      temp = temp / shape[i];
    }

    // Permute the coordinates, then recompose them into the flat destination index.
    for (size_t i = 0; i < m; i++) {
      indexB[i] = indexA[pos[i]];
    }

    int dst = 0;
    temp = 1;
    for (int i = m - 1; i >= 0; i--) {
      dst = dst + indexB[i] * temp;
      temp = temp * shapeTarget[i];
    }
    target[dst] = f[src];
  }

  return target;
}
// Flattened (no dim given) path: sort the non-NaN elements once and evaluate every requested quantile.
template <typename T>
void QuantileComputeDefaultFunc(uint64_t n, uint64_t q_size, const std::vector<T> &sorted, T *output_addr, T *q_addrs,
                                bool has_nan_, bool ignore_nan_) {
  std::vector<T> tmp;
  bool all_nan = true;
  for (auto &x : sorted) {
    if (!std::isnan(x)) {
      tmp.push_back(x);
      all_nan = false;
    }
  }
  std::sort(tmp.begin(), tmp.end());
  for (uint64_t i = 0; i < q_size; ++i) {
    if ((has_nan_ && !ignore_nan_) || all_nan) {
      output_addr[i] = NAN;
      continue;
    }

    // Linear interpolation between the two nearest order statistics.
    T index = (tmp.size() - 1) * q_addrs[i];
    int32_t idx = static_cast<int32_t>(index);

    if (idx == static_cast<int32_t>(tmp.size()) - 1) {
      output_addr[i] = tmp[idx];
      continue;
    }

    output_addr[i] = tmp[idx] + (tmp[idx + 1] - tmp[idx]) * (index - idx);
  }
}
// Builds the output shape: reduce (or keep as size 1) the quantile dimension, and prepend the q dimension
// when q is a 1-D tensor.
std::vector<int64_t> SetQuantileOutputShape(int64_t dim, int64_t input_dim, bool keep_dims, const int64_t q_size,
                                            const std::vector<int64_t> &input_shapesize, int64_t q_dim) {
  std::vector<int64_t> out_shape;
  if (dim != kQuantileDefaultDim && input_dim > 0) {
    out_shape = input_shapesize;
    if (keep_dims) {
      out_shape[dim] = 1;
    } else {
      out_shape.erase(out_shape.begin() + dim);
    }
  } else if (keep_dims) {
    out_shape = std::vector<int64_t>(input_dim, 1);
  }
  if (q_dim > 0) {
    out_shape.insert(out_shape.begin(), q_size);
  }
  return out_shape;
}
template <typename T>
bool QuantileCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                        const std::vector<kernel::AddressPtr> &workspace,
                                        const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kQuantileInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kQuantileOutputsNum, kernel_name_);

  auto input = reinterpret_cast<T *>(inputs[0]->addr);
  auto q = reinterpret_cast<T *>(inputs[1]->addr);
  auto output = reinterpret_cast<T *>(outputs[0]->addr);

  size_t input_dim = input_shape_.size();
  size_t q_dim = q_shape_.size();
  size_t q_size = inputs[1]->size / sizeof(T);
  total_ = inputs[0]->size / sizeof(T);
  auto output_size = outputs[0]->size / sizeof(T);

  auto input_shape = input_shape_;
  dim_ = MaybeWrapDim(dim_, input_shape.size());
  if (total_ == 0) {
    MS_EXCEPTION(ValueError) << "For 'Quantile', the input tensor must be non-empty.";
  }

  if (q_dim > 1) {
    MS_EXCEPTION(ValueError) << "For 'Quantile', the input q must be a scalar or 1D tensor, but got dimension = "
                             << q_dim << ".";
  }

  for (size_t j = 0; j < q_size; ++j) {
    if (q[j] < 0 || q[j] > 1) {
      MS_EXCEPTION(ValueError) << "For 'Quantile', q values must be in the range [0, 1].";
    }
  }

  auto input_shapesize = input_shape_;
  std::vector<T> sorted;
  std::vector<int64_t> out_shape =
    SetQuantileOutputShape(dim_, input_shape_.size(), keep_dims_, q_size, input_shapesize, q_shape_.size());
  for (uint64_t i = 0; i < total_; i++) {
    sorted.push_back(input[i]);
    if (std::isnan(input[i])) {
      has_nan_ = true;
    }
  }

  if (dim_ == kQuantileDefaultDim) {
    // No dim given: compute the quantiles over the flattened input.
    QuantileComputeDefaultFunc<T>(total_, q_size, sorted, output, q, has_nan_, ignore_nan_);
  } else if (dim_ == static_cast<int>(input_dim) - 1) {
    // Reducing the last dimension: slices are already contiguous.
    int64_t last_shape_size = input_shape[input_shape.size() - 1];
    ParallelRun(last_shape_size, q_size, sorted, output, q);
  } else {
    // Otherwise move the reduced dimension to the end first, then process contiguous slices.
    input_shapesize.push_back(1);
    sorted = transpose<T>(sorted, input_shapesize, dim_);

    int32_t m = input_shapesize.size();
    if (m != 0) {
      std::swap(input_shapesize[m - 1], input_shapesize[((dim_ + m) % m)]);
    }
    int64_t last_shape_size = input_shapesize[input_shapesize.size() - 1];
    ParallelRun(last_shape_size, q_size, sorted, output, q);
  }

  std::vector<T> out;
  if (q_dim > 0) {
    // Move the q dimension from the innermost position to the front of the output.
    for (size_t i = 0; i < output_size; i++) {
      out.push_back(*(output + i));
    }
    int64_t out_end_shape = out_shape[out_shape.size() - 1];
    out_shape.push_back(out_end_shape);
    std::swap(out_shape[0], out_shape[out_shape.size() - 1]);
    out_shape.erase(out_shape.begin());
    out_shape.insert(out_shape.begin(), 1);
    out = transpose<T>(out, out_shape, 0);
    for (size_t i = 0; i < output_size; i++) {
      output[i] = out[i];
    }
  }
  return true;
}
// Splits the work into one task per output slice (each slice covers `last_shape_size` input elements).
template <typename T>
void QuantileCpuKernelMod::ParallelRun(int64_t last_shape_size, uint64_t q_size, const std::vector<T> &sorted,
                                       T *output_addr, T *q_addrs) {
  const int64_t thread_num = SizeToLong(FloatToSize(std::ceil(static_cast<float>(total_) / last_shape_size)));
  std::vector<common::Task> tasks;
  for (uint64_t task_id = 0; task_id < (uint64_t)thread_num && task_id * last_shape_size < total_; task_id++) {
    uint64_t start = task_id * last_shape_size;
    uint64_t end = (task_id + 1) * last_shape_size;
    auto task = [this, &last_shape_size, &q_size, &sorted, &output_addr, &q_addrs, start, end]() {
      DoQuantile(last_shape_size, q_size, sorted, output_addr, q_addrs, start, end);
      return common::SUCCESS;
    };

    (void)tasks.emplace_back(task);
  }
  ParallelLaunch(tasks);
}
// Computes the quantiles for one contiguous slice of the (possibly transposed) input, covering
// elements [start, end) and writing q_size results for that slice.
template <typename T>
void QuantileCpuKernelMod::DoQuantile(int64_t last_shape_size, uint64_t q_size, const std::vector<T> &sorted,
                                      T *output_addr, T *q_addrs, uint64_t start, uint64_t end) {
  bool has_nan = false;
  bool all_nan = true;
  for (auto i = start; i < end; i++) {
    if (std::isnan(sorted[i])) {
      has_nan = true;
    } else {
      all_nan = false;
    }
  }

  bool flag = (has_nan && !ignore_nan_) || all_nan;
  if (flag) {
    for (uint64_t j = 0; j < q_size; ++j) {
      output_addr[start / last_shape_size * q_size + j] = NAN;
    }
  } else {
    std::vector<T> tmp;
    for (auto i = start; i < end; i++) {
      auto x = sorted[i];
      if (!std::isnan(x)) {
        tmp.push_back(x);
      }
    }
    std::sort(tmp.begin(), tmp.end());

    for (uint64_t j = 0; j < q_size; ++j) {
      // Linear interpolation between the two nearest order statistics of the slice.
      T index = (tmp.size() - 1) * q_addrs[j];

      int32_t idx = static_cast<int32_t>(index);
      if (idx == static_cast<int32_t>(tmp.size()) - 1) {
        output_addr[start / last_shape_size * q_size + j] = tmp[idx];
        continue;
      }
      output_addr[start / last_shape_size * q_size + j] = tmp[idx] + (index - idx) * (tmp[idx + 1] - tmp[idx]);
    }
  }
}
const std::vector<std::pair<KernelAttr, QuantileCpuKernelMod::KernelRunFunc>> &QuantileCpuKernelMod::GetFuncList()
  const {
  static const std::vector<std::pair<KernelAttr, QuantileCpuKernelMod::KernelRunFunc>> func_list = {
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
     &QuantileCpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
     &QuantileCpuKernelMod::LaunchKernel<double>}};
  return func_list;
}

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Quantile, QuantileCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
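The per-slice computation above (QuantileComputeDefaultFunc / DoQuantile) boils down to sorting the non-NaN values and linearly interpolating between the two nearest order statistics. A minimal NumPy sketch of that rule, for orientation only (quantile_1d is an illustrative helper, not part of this PR):

import numpy as np

def quantile_1d(values, qs, ignore_nan=False):
    """Illustrative 1-D quantile with linear interpolation, mirroring the per-slice kernel logic."""
    values = np.asarray(values, dtype=np.float64)
    qs = np.atleast_1d(np.asarray(qs, dtype=np.float64))
    finite = values[~np.isnan(values)]
    has_nan = finite.size != values.size
    if (has_nan and not ignore_nan) or finite.size == 0:
        return np.full(qs.shape, np.nan)           # propagate NaN, as the kernel does
    data = np.sort(finite)
    out = np.empty(qs.shape)
    for j, q in enumerate(qs):
        index = (data.size - 1) * q                # fractional position between order statistics
        idx = int(index)
        if idx == data.size - 1:
            out[j] = data[idx]
        else:
            out[j] = data[idx] + (data[idx + 1] - data[idx]) * (index - idx)
    return out

print(quantile_1d([0.07, -0.5446, 0.9214], [0.0, 0.5, 1.0]))   # approximately [-0.5446, 0.07, 0.9214]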
@@ -0,0 +1,73 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_QUANTILE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_QUANTILE_CPU_KERNEL_H_

#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class QuantileCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<QuantileCpuKernelMod> {
 public:
  QuantileCpuKernelMod() = default;
  ~QuantileCpuKernelMod() override = default;

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override {
    return kernel_func_(this, inputs, workspace, outputs);
  }
  uint32_t MaybeWrapDim(int dim, int dim_post_expr);
  template <typename T>
  void ParallelRun(int64_t last_shape_size, uint64_t q_size, const std::vector<T> &sorted, T *output_addr, T *q_addrs);
  template <typename T>
  void DoQuantile(int64_t last_shape_size, uint64_t q_size, const std::vector<T> &sorted, T *output_addr, T *q_addrs,
                  uint64_t start, uint64_t end);
  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs);
  bool CheckInputAndAttr();
  int dim_ = 0;
  bool keep_dims_ = false;
  std::vector<int64_t> input_shape_;
  std::vector<int64_t> q_shape_;
  size_t total_ = 0;
  bool ignore_nan_ = false;
  bool has_nan_ = false;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_QUANTILE_CPU_KERNEL_H_
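MatchKernelHelper plus GetFuncList give the kernel a small dispatch table: the registered dtype signature selects the matching typed LaunchKernel<T>. A loose Python analogy of that idea (illustrative names only, not MindSpore API):

import numpy as np

def launch_f32(x, q):   # stands in for LaunchKernel<float>
    return np.quantile(x, q).astype(np.float32)

def launch_f64(x, q):   # stands in for LaunchKernel<double>
    return np.quantile(x, q)

# Analogue of the (KernelAttr -> typed launch function) pairs in GetFuncList.
FUNC_LIST = {np.float32: launch_f32, np.float64: launch_f64}

def launch(x, q):
    return FUNC_LIST[x.dtype.type](x, q)

x = np.array([0.1, 0.5, 0.9], dtype=np.float32)
print(launch(x, 0.25))   # dispatches to launch_f32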
@@ -124,8 +124,8 @@ TypePtr QuantileInferType(const PrimitivePtr &primitive, const std::vector<Abstr
      }
    }
  } else {
-    MS_EXCEPTION(ValueError) << "For '" << prim_name
-                             << "', the type of 'q' must be float or tensor, but got: " << q_type->ToString() << ".";
+    MS_EXCEPTION(TypeError) << "For '" << prim_name
+                            << "', the type of 'q' must be float or tensor, but got: " << q_type->ToString() << ".";
  }
  return input_type;
}
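The only behavioral change in this hunk is the exception class: an unsupported 'q' type now surfaces as TypeError instead of ValueError. A rough sketch of what that means for callers, assuming the bad argument reaches type inference (it may be rejected even earlier depending on execution mode):

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([0.07, -0.5446, 0.9214]), mindspore.float32)
try:
    ops.quantile(x, "0.5")    # q is neither a float nor a Tensor
except TypeError as err:
    print(err)                # previously this case was raised as ValueError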
@@ -0,0 +1,35 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Quantile op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

quantile_op_info = AiCPURegOp("Quantile") \
    .fusion_type("OPAQUE") \
    .input(0, "input", "required") \
    .input(1, "q", "required") \
    .attr("dim", "int") \
    .attr("keep_dims", "bool") \
    .attr("ignore_nan", "bool") \
    .output(0, "out", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .get_op_info()


@op_info_register(quantile_op_info)
def _quantile_aicpu():
    """Quantile AiCPU register"""
    return
@@ -56,6 +56,7 @@ from mindspore.ops.operations.math_ops import (
    Lcm,
    Gcd,
    Sinc,
    Quantile,
    NanToNum,
    SparseSegmentMean,
    TrilIndices,
@@ -6227,6 +6228,117 @@ def bmm(input_x, mat2):
    return bmm_op(input_x, mat2)


def quantile(x, q, axis=None, keepdims=False):
    r"""
    Computes the q-th quantiles of all elements in the input tensor, doing a linear interpolation when the
    q-th quantile lies between two data points.

    Args:
        x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
            Supported dtypes: float32, float64.
        q (float or Tensor): A scalar or 1D tensor of quantile values in the range [0, 1].
            Supported dtypes: float32, float64.
        axis (int): The dimension to reduce. Default: None. When `axis` is None,
            the input tensor is flattened before computation.
        keepdims (bool): Whether the output tensor has dim retained or not. Default: False.

    Returns:
        Tensor, has the same dtype as `x`.

        Assume the input shape is :math:`(x_0, x_1, ..., x_i, ..., x_R)`, `axis` = :math:`i` and :math:`m` is
        the element count of input `q`.

        - If `q` is scalar and `keepdims` is True, the shape of output is :math:`(x_0, x_1, ..., 1, ..., x_R)`.
        - If `q` is scalar and `keepdims` is False, the shape of output is :math:`(x_0, x_1, ..., x_R)`.
        - If `q` is 1D Tensor and `keepdims` is True, the shape of output is :math:`(m, x_0, x_1, ..., 1, ..., x_R)`.
        - If `q` is 1D Tensor and `keepdims` is False, the shape of output is :math:`(m, x_0, x_1, ..., x_R)`.

    Raises:
        TypeError: If `x` is not a Tensor.
        TypeError: If `q` is not a Tensor or float.
        TypeError: If dtype of `x` is not float32 or float64.
        TypeError: If dtype of `q` is not float32 or float64.
        TypeError: If the dtype of `x` and the dtype of `q` are different.
        ValueError: If `q` values are not in the range [0, 1].
        ValueError: If `axis` is out of range.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([0.0700, -0.5446, 0.9214]), mindspore.float32)
        >>> q = Tensor(np.array([0, 0.5, 1]), mindspore.float32)
        >>> output = ops.quantile(x, q)
        >>> print(output.asnumpy())
        [-0.5446 0.07 0.9214]
    """

    if axis is not None:
        _check_attr_dtype("axis", axis, [int], "quantile")
    if keepdims is not None:
        _check_attr_dtype("keepdims", keepdims, [bool], "quantile")

    quantile_ = _get_cache_prim(Quantile)(dim=axis, keep_dims=keepdims)
    return quantile_(x, q)
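As a reading aid for the shape rules documented above, a short hedged example (the expected shape is derived from the docstring, not from running this PR):

import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([[0.2, 0.8, 0.5], [1.0, 0.1, 0.6]]), mindspore.float32)  # shape (2, 3)
q = Tensor(np.array([0.25, 0.75]), mindspore.float32)                        # m = 2 quantiles

out = ops.quantile(x, q, axis=1, keepdims=True)
# q is a 1D Tensor and keepdims=True, so the documented output shape is (m, x_0, 1) = (2, 2, 1).
print(out.shape)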

def nanquantile(x, q, axis=None, keepdims=False):
    r"""
    This is a variant of :func:`mindspore.ops.quantile` that ignores NaN values,
    computing the quantiles `q` as if NaN values in the input did not exist.
    If all values in a reduced row are NaN, then the quantiles for that reduction will be NaN.

    Refer to :func:`mindspore.ops.quantile` for more details.

    Args:
        x (Tensor): The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
            Supported dtypes: float32, float64.
        q (float or Tensor): A scalar or 1D tensor of quantile values in the range [0, 1].
            Supported dtypes: float32, float64.
        axis (int): The dimension to reduce. Default: None. When `axis` is None,
            the input tensor is flattened before computation.
        keepdims (bool): Whether the output tensor has dim retained or not. Default: False.

    Returns:
        Tensor, has the same dtype as `x`.

        Assume the input shape is :math:`(x_0, x_1, ..., x_i, ..., x_R)`, `axis` = :math:`i` and :math:`m` is
        the element count of input `q`.

        - If `q` is scalar and `keepdims` is True, the shape of output is :math:`(x_0, x_1, ..., 1, ..., x_R)`.
        - If `q` is scalar and `keepdims` is False, the shape of output is :math:`(x_0, x_1, ..., x_R)`.
        - If `q` is 1D Tensor and `keepdims` is True, the shape of output is :math:`(m, x_0, x_1, ..., 1, ..., x_R)`.
        - If `q` is 1D Tensor and `keepdims` is False, the shape of output is :math:`(m, x_0, x_1, ..., x_R)`.

    Raises:
        TypeError: If `x` is not a Tensor.
        TypeError: If `q` is not a Tensor or float.
        TypeError: If dtype of `x` is not float32 or float64.
        TypeError: If dtype of `q` is not float32 or float64.
        TypeError: If the dtype of `x` and the dtype of `q` are different.
        ValueError: If `q` values are not in the range [0, 1].
        ValueError: If `axis` is out of range.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> x = Tensor(np.array([0.0700, -0.5446, 0.9214]), mindspore.float32)
        >>> q = Tensor(np.array([0, 0.5, 1]), mindspore.float32)
        >>> output = ops.nanquantile(x, q)
        >>> print(output.asnumpy())
        [-0.5446 0.07 0.9214]
    """

    if axis is not None:
        _check_attr_dtype("axis", axis, [int], "nanquantile")
    if keepdims is not None:
        _check_attr_dtype("keepdims", keepdims, [bool], "nanquantile")

    quantile_ = _get_cache_prim(Quantile)(dim=axis, keep_dims=keepdims, ignore_nan=True)
    return quantile_(x, q)
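The quantile/nanquantile split mirrors NumPy's pair of functions; NumPy's own versions (not the MindSpore kernels) illustrate the NaN behavior described above:

import numpy as np

data = np.array([0.07, np.nan, 0.9214], dtype=np.float32)

# quantile-style behavior: a NaN anywhere in the reduced slice makes the result NaN.
print(np.quantile(data, 0.5))     # nan

# nanquantile-style behavior: NaNs are dropped before interpolation.
print(np.nanquantile(data, 0.5))  # midpoint of [0.07, 0.9214], about 0.4957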

def baddbmm(x, batch1, batch2, beta=1, alpha=1):
    r"""
    Performs a batch matrix-matrix product of matrices in batch1 and batch2. input is added to the final result.
@@ -3743,7 +3743,7 @@ class Quantile(Primitive):
    Refer to :func:`mindspore.ops.quantile` and :func:`mindspore.ops.nanquantile` for more detail.

    Supported Platforms:
-        ``GPU``
+        ``GPU`` ``CPU``

    Examples:
        >>> quantile = ops.Quantile()
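For orientation, the functional wrappers added earlier construct this primitive with the three registered attributes, roughly as sketched below (illustrative values; ignore_nan=True is what ops.nanquantile passes under the hood):

import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.ops.operations.math_ops import Quantile

op = Quantile(dim=0, keep_dims=False, ignore_nan=True)   # attrs match the AiCPU/CPU registration

x = Tensor(np.array([0.07, np.nan, 0.9214]), mindspore.float32)
q = Tensor(np.array([0.5]), mindspore.float32)
print(op(x, q))   # the NaN is skipped because ignore_nan=True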
@@ -28,6 +28,8 @@ from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.function.math_func import matrix_exp
from mindspore.ops.function.math_func import sinc
from mindspore.ops.function.math_func import quantile
from mindspore.ops.function.math_func import nanquantile
from mindspore.ops.function.math_func import nan_to_num
from mindspore.ops.operations.image_ops import CropAndResizeGradBoxes, AdjustHue, AdjustContrastv2, \
    AdjustSaturation, CombinedNonMaxSuppression, CropAndResizeGradImage
@@ -100,6 +102,7 @@ from mindspore.ops.operations.math_ops import Betainc
from mindspore.ops.operations.math_ops import Diagonal
from mindspore.ops.operations.math_ops import Hypot
from mindspore.ops.operations.math_ops import Heaviside
from mindspore.ops.operations.math_ops import Quantile
from mindspore.ops.operations.math_ops import Lcm
from mindspore.ops.operations.math_ops import DivNoNan
from mindspore.ops.operations.math_ops import Fmin
@@ -1368,6 +1371,26 @@ class MatrixSetDiagV3Net(nn.Cell):
        return self.matrix_set_diag_v3(x, diagonal, self.k)


class QuantileFunc(nn.Cell):
    def __init__(self):
        super(QuantileFunc, self).__init__()
        self.quantile = quantile

    def construct(self, x, q):
        y = self.quantile(x, q)
        return y


class NanQuantileFunc(nn.Cell):
    def __init__(self):
        super(NanQuantileFunc, self).__init__()
        self.nanquantile = nanquantile

    def construct(self, x, q):
        y = self.nanquantile(x, q)
        return y


class SparseApplyCenteredRMSPropNet(nn.Cell):
    def __init__(self, use_locking=False):
        super(SparseApplyCenteredRMSPropNet, self).__init__()
@@ -1736,6 +1759,18 @@ test_case_math_ops = [
        'block': SincFunc(),
        'desc_inputs': [[2, 3]],
        'desc_bprop': [[2, 3]]}),
    ('Quantile_1', {
        'block': Quantile(),
        'desc_inputs': [[2, 3], [3]],
        'skip': ['backward']}),
    ('Quantile_2', {
        'block': QuantileFunc(),
        'desc_inputs': [[2, 3], [3]],
        'skip': ['backward']}),
    ('NanQuantile', {
        'block': NanQuantileFunc(),
        'desc_inputs': [[2, 3], [3]],
        'skip': ['backward']}),
    ('Asin', {
        'block': P.Asin(),
        'desc_inputs': [[2, 3]],