add median op for gpu

This commit is contained in:
Corleone 2022-05-27 20:50:52 +08:00
parent 33f12ac66b
commit 17e37b8546
18 changed files with 845 additions and 7 deletions

View File

@ -0,0 +1,23 @@
mindspore.Tensor.median
=======================
.. py:method:: mindspore.Tensor.median(global_median=False, axis=0, keep_dims=False)
返回指定维度上的中值。
参数:
- **global_median** (bool) - 表示是否对当前Tensor的全部元素取中值。默认值False。
- **axis** (int) - 计算中值的维度。默认值:(0), 取值范围为[-ndim, ndim - 1]'ndim' 表示当前Tensor的维度长度。
- **keep_dims** (bool) - 表示是否减少维度如果为True输出将与输入保持相同的维度如果为False输出将减少维度。默认值False。
返回:
- **y** (Tensor) - 返回指定维度上的中值数据类型与当前Tensor相同。
- **indices** (bool) - 指定中值索引。数据类型为int64。如果 `global_median` 为True则结果无意义。
异常:
- **TypeError** - 当前Tensor的类型不是: int16, int32, int64, float32, float64。
- **TypeError** - `global_median` 不是bool。
- **TypeError** - `axis` 不是int。
- **TypeError** - `keep_dims` 不是bool。
- **ValueError** - `axis` 的范围不在[-ndim, ndim - 1]。

View File

@ -83,6 +83,7 @@ mindspore.Tensor
mindspore.Tensor.matrix_determinant
mindspore.Tensor.max
mindspore.Tensor.mean
mindspore.Tensor.median
mindspore.Tensor.min
mindspore.Tensor.narrow
mindspore.Tensor.nbytes

View File

@ -0,0 +1,29 @@
mindspore.ops.median
====================
.. py:function:: mindspore.ops.median(x, global_median=False, axis=0, keep_dims=Fasle)
输出张量指定维度上的中值。
参数:
- **x** (Tensor) - median的输入任意维度的Tensor。数据类型支持int16、int32、int64、float32或float64。
- **global_median** (bool) - 表示是否对x的全部元素取中值。默认值False。
- **axis** (int) - 指定计算维度。默认值:(0), 取值范围为[-dims, dims - 1]`dims` 表示 `x` 的维度长度
- **keep_dims** (bool) - 表示是否减少维度如果为True输出将与输入保持相同的维度如果为False输出将减少维度。
默认值False。
返回:
- **y** (Tensor) - 返回指定维度上的中值,数据类型与 `x` 相同。
- **indices** (bool) - 指定中值索引。数据类型为int64。如果 `global_median` 为True则结果无意义。
异常:
- **TypeError** - `x` 的数据类型不是: int16, int32, int64, float32, float64。
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `global_median` 不是bool。
- **TypeError** - `axis` 不是int。
- **TypeError** - `keep_dims` 不是bool。
- **ValueError** - `axis` 的范围不在[-dims, dims - 1]`dims` 表示 `x` 的维度长度。

View File

@ -206,6 +206,7 @@ BuiltInTypeMap &GetMethodMap() {
{"expand_dims", std::string("expand_dims")}, // P.expand_dims()
{"squeeze", std::string("squeeze")}, // P.squeeze()
{"astype", std::string("astype")}, // P.cast()
{"median", std::string("median")}, // P.median()
{"cumsum", std::string("cumsum")}, // P.cumsum()
{"cummin", std::string("cummin")}, // cummin()
{"cummax", std::string("cummax")}, // cummax()

View File

@ -0,0 +1,398 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "median_impl.cuh"
#include <iostream>
#include <vector>
#include <algorithm>
constexpr int WARP_SIZE = 32;
constexpr int MAX_THREAD = 1024;
constexpr int RADIX_BITS = 2;
constexpr int RADIX_SIZE = 4;
constexpr int RADIX_MASK = (RADIX_SIZE - 1);
__device__ __forceinline__ unsigned int warp_ballot(int predicate) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 700
return __ballot(predicate);
#else
return __ballot_sync(__activemask(), predicate);
#endif
}
template <typename T>
static __device__ __host__ T round_up(T a, T b) {
return (a / b) * b;
}
template <typename T>
struct Bitfield {};
template <>
struct Bitfield<unsigned int> {
static __device__ unsigned int GetBitfield(unsigned int val, int pos, int len) {
#if !defined(__CUDA_ARCH__)
pos &= 0xff;
len &= 0xff;
unsigned int m = (1u << len) - 1u;
return (val >> pos) & m;
#else
unsigned int ret;
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
return ret;
#endif
}
static __device__ unsigned int SetBitfield(unsigned int val, unsigned int to_insert, int pos, int len) {
#if !defined(__CUDA_ARCH__)
pos &= 0xff;
len &= 0xff;
unsigned int m = (1u << len) - 1u;
to_insert &= m;
to_insert <<= pos;
m <<= pos;
return (val & ~m) | to_insert;
#else
unsigned int ret;
asm("bfi.b32 %0, %1, %2, %3, %4;" : "=r"(ret) : "r"(to_insert), "r"(val), "r"(pos), "r"(len));
return ret;
#endif
}
};
template <>
struct Bitfield<uint64_t> {
static __device__ uint64_t GetBitfield(uint64_t val, int pos, int len) {
#if !defined(__CUDA_ARCH__)
pos &= 0xff;
len &= 0xff;
uint64_t m = (1u << len) - 1u;
return (val >> pos) & m;
#else
uint64_t ret;
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
return ret;
#endif
}
static __device__ uint64_t SetBitfield(uint64_t val, uint64_t to_insert, int pos, int len) {
#if !defined(__CUDA_ARCH__)
pos &= 0xff;
len &= 0xff;
uint64_t m = (1u << len) - 1u;
to_insert &= m;
to_insert <<= pos;
m <<= pos;
return (val & ~m) | to_insert;
#else
uint64_t ret;
asm("bfi.b64 %0, %1, %2, %3, %4;" : "=l"(ret) : "l"(to_insert), "l"(val), "r"(pos), "r"(len));
return ret;
#endif
}
};
template <typename T>
struct MedianTypeConfig {};
template <>
struct MedianTypeConfig<float> {
typedef uint32_t RadixType;
// Converts a float to an integer representation with the same sorting
static inline __device__ RadixType Convert(float v) {
RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (v == v) ? (x ^ mask) : 0xffffffff;
}
static inline __device__ float Deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask);
}
};
template <>
struct MedianTypeConfig<uint8_t> {
typedef uint32_t RadixType;
static inline __device__ RadixType Convert(uint8_t v) { return v; }
static inline __device__ uint8_t Deconvert(RadixType v) { return v; }
};
template <>
struct MedianTypeConfig<int8_t> {
typedef uint32_t RadixType;
static inline __device__ RadixType Convert(int8_t v) { return 128u + v; }
static inline __device__ int8_t Deconvert(RadixType v) { return v - 128; }
};
template <>
struct MedianTypeConfig<int16_t> {
typedef uint32_t RadixType;
static inline __device__ RadixType Convert(int16_t v) { return 32768u + v; }
static inline __device__ int16_t Deconvert(RadixType v) { return v - 32768; }
};
template <>
struct MedianTypeConfig<int32_t> {
typedef uint32_t RadixType;
static inline __device__ RadixType Convert(int32_t v) { return 2147483648u + v; }
static inline __device__ int32_t Deconvert(RadixType v) { return v - 2147483648u; }
};
template <>
struct MedianTypeConfig<int64_t> {
typedef uint64_t RadixType;
static inline __device__ RadixType Convert(int64_t v) { return 9223372036854775808ull + v; }
static inline __device__ int64_t Deconvert(RadixType v) { return v - 9223372036854775808ull; }
};
template <>
struct MedianTypeConfig<double> {
typedef uint64_t RadixType;
static inline __device__ RadixType Convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
}
static inline __device__ double Deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
}
};
// This function counts the distribution of all input values in a slice
template <typename T, typename R_T, typename S, int RadixSize, int RadixBits>
__device__ void CountRadixUsingMask(int counts[RadixSize], int *smem, R_T desired, R_T desired_mask,
int radix_digit_pos, S size, S stride, const T *data) {
#pragma unroll
for (int i = 0; i < RadixSize; ++i) {
counts[i] = 0;
}
if (threadIdx.x < RadixSize) {
smem[threadIdx.x] = 0;
}
__syncthreads();
for (S i = threadIdx.x; i < size; i += blockDim.x) {
R_T val = MedianTypeConfig<T>::Convert(data[i * stride]);
bool hasVal = ((val & desired_mask) == desired);
R_T digit_in_radix = Bitfield<R_T>::GetBitfield(val, radix_digit_pos, RadixBits);
#pragma unroll
for (uint32_t j = 0; j < RadixSize; ++j) {
bool vote = hasVal && (digit_in_radix == j);
counts[j] += __popc(warp_ballot(vote));
}
}
if (threadIdx.x % 32 == 0) {
#pragma unroll
for (uint32_t i = 0; i < RadixSize; ++i) {
atomicAdd(&smem[i], counts[i]);
}
}
__syncthreads();
#pragma unroll
for (uint32_t i = 0; i < RadixSize; ++i) {
counts[i] = smem[i];
}
__syncthreads();
}
// This finds the unique value that matches the pattern
template <typename T, typename R_T, typename S>
__device__ T FindPattern(T *smem, const T *data, S size, S stride, R_T desired, R_T desired_mask) {
if (threadIdx.x < 2) {
smem[threadIdx.x] = static_cast<T>(0);
}
__syncthreads();
S size_round = round_up(size, static_cast<S>(blockDim.x));
for (S i = threadIdx.x; i < size_round; i += blockDim.x) {
bool in_range = (i < size);
T v = in_range ? data[i * stride] : static_cast<T>(0);
if (in_range && ((MedianTypeConfig<T>::Convert(v) & desired_mask) == desired)) {
smem[0] = static_cast<T>(1);
smem[1] = v;
}
__syncthreads();
T found = smem[0];
T val = smem[1];
__syncthreads();
if (found != static_cast<T>(0)) {
return val;
}
}
return static_cast<T>(0);
}
// Returns the top-Kth element found in the data using radix selection
template <typename T, typename R_T, typename S>
__device__ void RadixSelect(const T *data, S kth, bool largest, S size, S stride, int *smem, T *top_k) {
int counts[RADIX_SIZE];
R_T desired = 0;
R_T desired_mask = 0;
int k = kth;
for (int digit_pos = sizeof(T) * 8 - RADIX_BITS; digit_pos >= 0; digit_pos -= RADIX_BITS) {
CountRadixUsingMask<T, R_T, S, RADIX_SIZE, RADIX_BITS>(counts, smem, desired, desired_mask, digit_pos, size, stride,
data);
auto found_unique = [&](int i, int count) -> bool {
if (count == 1 && k == 1) {
desired = Bitfield<R_T>::SetBitfield(desired, i, digit_pos, RADIX_BITS);
desired_mask = Bitfield<R_T>::SetBitfield(desired_mask, RADIX_MASK, digit_pos, RADIX_BITS);
*top_k = FindPattern<T, R_T, S>(reinterpret_cast<T *>(smem), data, size, stride, desired, desired_mask);
return true;
}
return false;
};
auto found_non_unique = [&](int i, int count) -> bool {
if (count >= k) {
desired = Bitfield<R_T>::SetBitfield(desired, i, digit_pos, RADIX_BITS);
desired_mask = Bitfield<R_T>::SetBitfield(desired_mask, RADIX_MASK, digit_pos, RADIX_BITS);
return true;
}
k -= count;
return false;
};
if (largest) {
#pragma unroll
for (int i = RADIX_SIZE - 1; i >= 0; --i) {
int count = counts[i];
if (found_unique(i, count)) {
return;
}
if (found_non_unique(i, count)) {
break;
}
}
} else {
#pragma unroll
for (int i = 0; i < RADIX_SIZE; ++i) {
int count = counts[i];
if (found_unique(i, count)) {
return;
}
if (found_non_unique(i, count)) {
break;
}
}
}
}
*top_k = MedianTypeConfig<T>::Deconvert(desired);
}
template <typename T, typename S>
__global__ void MedianKernel(const T *input, T *output, S *indices, S size, S num, S stride, bool global_median) {
__shared__ int smem[WARP_SIZE];
S slice = blockIdx.y * gridDim.x + blockIdx.x;
if (slice >= num) {
return;
}
S offset_y = size * gridDim.x;
S k = (size - 1) / 2;
// Find the median value
T median = static_cast<T>(0);
RadixSelect<T, typename MedianTypeConfig<T>::RadixType, S>(input + blockIdx.y * offset_y + blockIdx.x, k + 1, false,
size, stride, smem, &median);
output[slice] = median;
// Find the index of the median value in the slice
if (!global_median) {
for (S i = threadIdx.x; i < size; i += blockDim.x) {
T val = input[blockIdx.y * offset_y + blockIdx.x + i * stride];
if (val == median) {
indices[slice] = i;
break;
}
}
}
}
template <typename T, typename S>
void Median(const T *input_value, T *output, S *indices, const std::vector<int64_t> input_shape, const int64_t axis,
bool global_median, cudaStream_t cuda_stream) {
dim3 threads, grids;
size_t i = 0;
for (; i < static_cast<size_t>(axis); i++) {
grids.y *= input_shape[i];
}
size_t size = input_shape[axis];
for (i = axis + 1; i < input_shape.size(); i++) {
grids.x *= input_shape[i];
}
threads.x = std::min(round_up(static_cast<int>(size), WARP_SIZE), MAX_THREAD);
S num = grids.y * grids.x;
S stride = grids.x;
MedianKernel<T, S>
<<<grids, threads, 0, cuda_stream>>>(input_value, output, indices, size, num, stride, global_median);
return;
}
template CUDA_LIB_EXPORT void Median<int16_t, int64_t>(const int16_t *input_value, int16_t *output, int64_t *indices,
const std::vector<int64_t> input_shape, const int64_t axis,
bool global_median, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Median<int32_t, int64_t>(const int32_t *input_value, int32_t *output, int64_t *indices,
const std::vector<int64_t> input_shape, const int64_t axis,
bool global_median, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Median<int64_t, int64_t>(const int64_t *input_value, int64_t *output, int64_t *indices,
const std::vector<int64_t> input_shape, const int64_t axis,
bool global_median, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Median<float, int64_t>(const float *input_value, float *output, int64_t *indices,
const std::vector<int64_t> input_shape, const int64_t axis,
bool global_median, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Median<double, int64_t>(const double *input_value, double *output, int64_t *indices,
const std::vector<int64_t> input_shape, const int64_t axis,
bool global_median, cudaStream_t cuda_stream);

View File

@ -0,0 +1,27 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_IMPL_CUH_
#include <vector>
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T, typename S>
void Median(const T *input_value, T *output, S *indices, const std::vector<int64_t> input_shape, const int64_t axis,
bool global_median, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_IMPL_CUH_

View File

@ -0,0 +1,39 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/median_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
Median, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
MedianGpuKernelMod, int16_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
Median, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
MedianGpuKernelMod, int32_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
Median, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
MedianGpuKernelMod, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
Median,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
MedianGpuKernelMod, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
Median,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
MedianGpuKernelMod, double, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,123 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MEDIAN_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MEDIAN_GPU_KERNEL_H_
#include <vector>
#include <map>
#include "mindspore/core/ops/median.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_impl.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t kMedianInputsNum = 1;
constexpr size_t kMedianOutputsNum = 2;
template <typename T, typename S>
class MedianGpuKernelMod : public NativeGpuKernelMod {
public:
MedianGpuKernelMod() : global_median_(false), keep_dims_(false), axis_(0) {}
~MedianGpuKernelMod() = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output0_addr = GetDeviceAddress<T>(outputs, 0);
S *output1_addr = nullptr;
if (!global_median_) {
output1_addr = GetDeviceAddress<S>(outputs, 1);
}
Median(input_addr, output0_addr, output1_addr, input_shape_, axis_, global_median_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
kernel_name_ = base_operator->name();
auto kernel_ptr = std::dynamic_pointer_cast<ops::Median>(base_operator);
if (kernel_ptr == nullptr) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' cast Median ops failed!";
return false;
}
if (inputs.size() != kMedianInputsNum || outputs.size() > kMedianOutputsNum) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size should be " << kMedianInputsNum << " and "
<< kMedianOutputsNum << ", but got " << inputs.size() << " and " << outputs.size();
return false;
}
global_median_ = kernel_ptr->get_global_median();
keep_dims_ = kernel_ptr->get_keep_dims();
axis_ = kernel_ptr->get_axis();
input_shape_ = inputs[0]->GetShapeVector();
if (global_median_) {
int input_size = 1;
for (size_t i = 0; i < input_shape_.size(); i++) {
input_size *= input_shape_[i];
}
input_shape_.clear();
input_shape_.push_back(input_size);
}
int64_t dims = static_cast<int64_t>(input_shape_.size());
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
<< "), but got " << axis_;
}
if (axis_ < 0) {
axis_ += dims;
}
return true;
}
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
int ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != 0) {
return ret;
}
input_shape_ = inputs[0]->GetShapeVector();
if (global_median_) {
int input_size = 1;
for (size_t i = 0; i < input_shape_.size(); i++) {
input_size *= input_shape_[i];
}
input_shape_.clear();
input_shape_.push_back(input_size);
}
return KRET_OK;
}
std::vector<KernelAttr> GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64)};
return support_list;
}
private:
bool global_median_;
bool keep_dims_;
int64_t axis_;
std::vector<int64_t> input_shape_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MEDIAN_GPU_KERNEL_H_

View File

@ -27,6 +27,38 @@
namespace mindspore {
namespace ops {
void Median::Init(const bool global_median, const int64_t axis, const bool keep_dims) {
this->set_global_median(global_median);
this->set_axis(axis);
this->set_keep_dims(keep_dims);
}
void Median::set_global_median(const bool global_median) {
(void)this->AddAttr(kGlobalMedian, api::MakeValue(global_median));
}
void Median::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }
void Median::set_axis(const int64_t &axis) {
int64_t f = axis;
(void)this->AddAttr(kAxis, api::MakeValue(f));
}
bool Median::get_global_median() const {
auto value_ptr = GetAttr(kGlobalMedian);
return GetValue<bool>(value_ptr);
}
bool Median::get_keep_dims() const {
auto value_ptr = GetAttr(kKeepDims);
return GetValue<bool>(value_ptr);
}
int64_t Median::get_axis() const {
auto value_ptr = GetAttr(kAxis);
return GetValue<int64_t>(value_ptr);
}
namespace {
abstract::TupleShapePtr MedianInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
@ -34,18 +66,18 @@ abstract::TupleShapePtr MedianInferShape(const PrimitivePtr &primitive,
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
int64_t x_size = x_shape.size();
std::vector<int64_t> out;
auto check_global_median = primitive->GetAttr("global_median");
auto check_global_median = primitive->GetAttr(kGlobalMedian);
MS_EXCEPTION_IF_NULL(check_global_median);
bool global_median = GetValue<bool>(check_global_median);
if (!global_median) {
auto check_axis = primitive->GetAttr("axis");
auto check_axis = primitive->GetAttr(kAxis);
auto axis = GetValue<int64_t>(check_axis);
auto check_keepdim = primitive->GetAttr("keep_dims");
auto check_keepdim = primitive->GetAttr(kKeepDims);
bool keepdim = GetValue<bool>(check_keepdim);
if (x_size == 0) {
CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeLeft, {-1, 1}, "Median");
CheckAndConvertUtils::CheckInRange(kAxis, axis, kIncludeLeft, {-1, 1}, "Median");
} else {
CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeLeft, {-x_size, x_size}, "Median");
CheckAndConvertUtils::CheckInRange(kAxis, axis, kIncludeLeft, {-x_size, x_size}, "Median");
}
if (axis < 0) {
axis += x_size;

View File

@ -34,6 +34,26 @@ class MIND_API Median : public BaseOperator {
MIND_API_BASE_MEMBER(Median);
/// \brief Constructor.
Median() : BaseOperator(kNameMedian) { InitIOName({"x"}, {"y", "indices"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Median for the inputs.
void Init(const bool global_median = false, const int64_t axis = 0, const bool keep_dims = false);
/// \brief Set global_median.
void set_global_median(const bool global_median);
/// \brief Set keep_dims.
void set_keep_dims(const bool keep_dims);
/// \brief Set axis.
void set_axis(const int64_t &axis);
/// \brief Get global_median.
///
/// \return global_median.
bool get_global_median() const;
/// \brief Get keep_dims.
///
/// \return keep_dims.
bool get_keep_dims() const;
/// \brief Get axis.
///
/// \return axis.
int64_t get_axis() const;
};
abstract::AbstractBasePtr MedianInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -87,6 +87,7 @@ constexpr auto kFreqUpperLimit = "freq_upper_limit";
constexpr auto kFreezeBn = "freeze_bn";
constexpr auto kGateOrder = "gate_order";
constexpr auto kGlobal = "global";
constexpr auto kGlobalMedian = "global_median";
constexpr auto kGrad = "grad";
constexpr auto kIsGrad = "is_grad";
constexpr auto kGradientScale = "gradient_scale";

View File

@ -28,6 +28,7 @@ from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add,
from ...ops.composite.base import _append, _insert, _pop
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
from ...ops.operations.math_ops import Median
from ...ops.operations._inner_ops import Format
from ...ops.operations import _csr_ops
from ...ops.primitive import constexpr
@ -632,6 +633,19 @@ def argmin(x, axis=None):
return P.Argmax(axis)(F.neg_tensor(x))
def median(x, global_median, axis=0, keep_dims=False):
r"""
Computes the median of input tensor.
.. warning::
When attr `global_median` is True, the second output Tensor value is meaningless.
"""
check_axis_in_range_const(axis, x.ndim)
median_ = Median(global_median, axis, keep_dims)
return median_(x)
def cumsum(x, axis=None, dtype=None):
"""
Returns the cumulative sum of the elements along a given axis.

View File

@ -4724,6 +4724,56 @@ class Tensor(Tensor_):
validator.check_bool(sorted, 'sorted')
return tensor_operator_registry.get("top_k")(sorted)(self, k)
def median(self, global_median=False, axis=0, keep_dims=False):
r"""
Computes the median of input tensor.
.. warning::
When attr `global_median` is True, the second output Tensor value is meaningless.
Args:
global_median (bool): Whether the output tensor is the global median of all input tensor elements or not.
Default: False.
axis (int): The dimension need to reduce. Default: 0.
keepdim (bool): Whether the output tensor need to retain `axis` dimension or not. Default: False.
Outputs:
Not exist when global_median is True.
Returns:
y (Tensor) - Has the same dtype as the self Tensor.
If `keep_dims` is true, the output tensors have the same dimension as the self Tensor except in
dimension `axis` which are of size 1. Otherwise, the outputs tensor have 1 fewer dimension than input.
indices (Tensor) - Has the same shape as the `y`, but dtype is int64.
Not exist when global_median is True.
Raises:
TypeError: If dtype of self Tensor is not one of the following: int16, int32, int64, float32, double.
TypeError: If `global_median` is not a bool.
TypeError: If `axis` is not a int.
TypeError: If `keep_dims` is not a bool.
ValueError: If `axis` is not in range of [-len(`self.shape`), len(`self.shape`) - 1).
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> # case 1 : common median compute
>>> x = Tensor(np.array([[0.57, 0.11, 0.21],[0.38, 0.50, 0.57], [0.36, 0.16, 0.44]]).astype(np.float32))
>>> y = x.median(global_median=False, axis=0, keep_dims=False)
>>> print(y)
(Tensor(shape=[3], dtype=Float32, value=[0.38, 0.16, 0.44]),
Tensor(shape=[3], dtype=Int64, value=[1, 2, 2]))
>>> # case 2 : global median compute
>>> x = Tensor(np.array([1, 7, 6],[5, 1, 3],[9, 17, 1]), mindspore.int32)
>>> y = x.median(global_median=True)
>>> print(y)
(Tensor(shape=[1], dtype=Int32, value=[5]), Tensor(shape=[1], dtype=Int64, value=[1]))
"""
self._init_check()
validator.check_axis_in_range(axis, self.ndim)
return tensor_operator_registry.get('median')(global_median, axis, keep_dims)(self)
class RowTensor(RowTensor_):
"""

View File

@ -30,7 +30,7 @@ from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_a
get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \
get_unary_grad_vmap_rule, _vmap_clone_prim, _bdim_at_any
from ..operations.math_ops import (Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1,
BesselK1e)
BesselK1e, Median)
@constexpr
@ -469,6 +469,34 @@ def get_reducer_vmap_rule(prim, axis_size):
return vmap_rule
@vmap_rules_getters.register(Median)
def get_median_vmap_rule(prim, axis_size):
"""VmapRule for median operations."""
global_median = prim.global_median
axis = prim.axis
keep_dims = prim.keep_dims
@constexpr
def trans_axis(axis, rank, dim, keep_dims):
if axis < 0:
axis += rank - 1
axis_new = axis + 1 if dim <= axis else axis
if keep_dims:
dim_new = axis_new
else:
dim_new = dim - 1 if dim > axis_new else dim
return axis_new, dim_new
def vmap_rule(x_bdim):
x, x_dim = x_bdim
rank = len(x.shape)
axis_new, dim_new = trans_axis(axis, rank, x_dim, keep_dims)
y, indices = Median(global_median, axis_new, keep_dims)(x)
return (y, dim_new), (indices, dim_new)
return vmap_rule
@vmap_rules_getters.register(P.IndexAdd)
def get_index_add_vmap_rule(prim, axis_size):
"""VmapRule for IndexAdd."""

View File

@ -176,6 +176,7 @@ from .math_func import (
linspace,
matrix_solve,
maximum,
median,
logaddexp,
logaddexp2,
logit,

View File

@ -42,6 +42,7 @@ from ..operations.math_ops import (
BesselK1,
BesselK1e,
MatrixSolve,
Median,
Orgqr,
Renorm,
Hypot,
@ -3053,6 +3054,53 @@ def minimum(x, y):
return minimum_(x, y)
def median(x, global_median=False, axis=0, keep_dims=False):
r"""
Computes the median of input tensor.
.. warning::
When attr `global_median` is True, the second output Tensor value is meaningless.
Args:
x (Tensor) - The first input is a tensor whose data type is number.
global_median (bool) - Whether the output tensor is the global median of all input tensor elements or not.
Default: False.
axis (int) - The dimension need to reduce. Default: 0.
keep_dims (bool) - Whether the output tensor need to retain `axis` dimension or not. Default: False.
Returns:
y (Tensor) - Has the same dtype as the `x`. If `global_median` is true, the `y` has only one
element. If `keep_dims` is true, the `y` has the same shape as the `x` except the shape of `y` in dimension
`axis` is size 1. Otherwise, the `y` lacks `axis` dimension than input.
indices (Tensor) - Has the same shape as the `y`, but dtype is int64.
Raises:
TypeError: If dtype of `x` is not one of the following: int16, int32, int64, float32, double.
TypeError: If input `x` is not a Tensor.
TypeError: If `global_median` is not a bool.
TypeError: If `axis` is not a int.
TypeError: If `keep_dims` is not a bool.
ValueError: If `axis` is not in range of [-x.dim, x.dim-1].
Supported Platforms:
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> # case 1 : common median compute
>>> x = Tensor(np.array([[0.57, 0.11, 0.21],[0.38, 0.50, 0.57], [0.36, 0.16, 0.44]]).astype(np.float32))
>>> y = ops.median(x, global_median=False, axis=0, keep_dims=False)
>>> print(y)
(Tensor(shape=[3], dtype=Float32, value=[0.38, 0.16, 0.44]), Tensor(shape=[3], dtype=Int64, value=[1, 2, 2]))
>>> # case 2 : global median compute
>>> x = Tensor(np.array([1, 7, 6],[5, 1, 3],[9, 17, 1]), mindspore.int32)
>>> y = ops.median(x, global_median=True)
>>> print(y)
(Tensor(shape=[1], dtype=Int32, value=[5]), Tensor(shape=[1], dtype=Int64, value=[1]))
"""
median_ = Median(global_median, axis, keep_dims)
return median_(x)
def orgqr(x, tau):
r"""
Computes the first :math:`N` columns of a product of Householder matrices. Take the case of input without batch
@ -5379,6 +5427,7 @@ __all__ = [
'same_type_shape',
'maximum',
'minimum',
'median',
'floor',
'logical_not',
'logical_or',

View File

@ -32,6 +32,7 @@ from . import operations as P
from .operations import _grad_ops
from .operations import _csr_ops
from .operations import linalg_ops
from .operations.math_ops import Median
from .operations.array_ops import UniqueConsecutive
from .operations.nn_ops import AdaptiveMaxPool2D
from .function.sparse_func import sparse_add
@ -404,6 +405,7 @@ tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('erf', P.Erf)
tensor_operator_registry.register('erfc', P.Erfc)
tensor_operator_registry.register('standard_normal', P.StandardNormal)
tensor_operator_registry.register('median', Median)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)

View File

@ -6328,7 +6328,7 @@ class Median(Primitive):
ValueError: If `axis` is not in range of [-x.dim, x.dim-1].
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> # case 1 : common median compute