add median op for gpu
parent 33f12ac66b · commit 17e37b8546
@ -0,0 +1,23 @@
mindspore.Tensor.median
=======================

.. py:method:: mindspore.Tensor.median(global_median=False, axis=0, keep_dims=False)

    Returns the median of the Tensor along the specified dimension.

    Parameters:
        - **global_median** (bool) - Whether to take the median of all elements of the Tensor. Default: False.
        - **axis** (int) - The dimension along which the median is computed. Default: 0. The value must be in the range [-ndim, ndim - 1], where `ndim` is the number of dimensions of the Tensor.
        - **keep_dims** (bool) - Whether to keep the reduced dimension. If True, the output has the same number of dimensions as the input; if False, the reduced dimension is removed. Default: False.

    Returns:
        - **y** (Tensor) - The median along the specified dimension, with the same dtype as the Tensor.
        - **indices** (Tensor) - The indices of the median values, with dtype int64. The result is meaningless if `global_median` is True.

    Raises:
        - **TypeError** - The dtype of the Tensor is not one of: int16, int32, int64, float32, float64.
        - **TypeError** - `global_median` is not a bool.
        - **TypeError** - `axis` is not an int.
        - **TypeError** - `keep_dims` is not a bool.
        - **ValueError** - `axis` is not in the range [-ndim, ndim - 1].
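A minimal usage sketch of the method documented above (not part of the commit itself); it assumes this commit is applied, and reuses the values from the commit's own docstring examples:

```python
import numpy as np
import mindspore
from mindspore import Tensor

x = Tensor(np.array([[0.57, 0.11, 0.21],
                     [0.38, 0.50, 0.57],
                     [0.36, 0.16, 0.44]]), mindspore.float32)

# Median along axis 0: one (value, index) pair per column.
y, indices = x.median(global_median=False, axis=0, keep_dims=False)
print(y)        # [0.38 0.16 0.44]
print(indices)  # [1 2 2]
```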
@ -83,6 +83,7 @@ mindspore.Tensor
    mindspore.Tensor.matrix_determinant
    mindspore.Tensor.max
    mindspore.Tensor.mean
    mindspore.Tensor.median
    mindspore.Tensor.min
    mindspore.Tensor.narrow
    mindspore.Tensor.nbytes
@ -0,0 +1,29 @@
mindspore.ops.median
====================

.. py:function:: mindspore.ops.median(x, global_median=False, axis=0, keep_dims=False)

    Returns the median of the input tensor along the specified dimension.

    Parameters:
        - **x** (Tensor) - The input of median, a Tensor of any dimension. Supported dtypes: int16, int32, int64, float32 and float64.
        - **global_median** (bool) - Whether to take the median of all elements of `x`. Default: False.
        - **axis** (int) - The dimension along which the median is computed. Default: 0. The value must be in the range [-dims, dims - 1], where `dims` is the number of dimensions of `x`.
        - **keep_dims** (bool) - Whether to keep the reduced dimension. If True, the output has the same number of dimensions as the input; if False, the reduced dimension is removed. Default: False.

    Returns:
        - **y** (Tensor) - The median along the specified dimension, with the same dtype as `x`.
        - **indices** (Tensor) - The indices of the median values, with dtype int64. The result is meaningless if `global_median` is True.

    Raises:
        - **TypeError** - The dtype of `x` is not one of: int16, int32, int64, float32, float64.
        - **TypeError** - `x` is not a Tensor.
        - **TypeError** - `global_median` is not a bool.
        - **TypeError** - `axis` is not an int.
        - **TypeError** - `keep_dims` is not a bool.
        - **ValueError** - `axis` is not in the range [-dims, dims - 1], where `dims` is the number of dimensions of `x`.
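One behavioral detail worth calling out with a hedged sketch (not part of the commit): the CUDA kernel below selects the `(size - 1) // 2` order statistic, so for an even-length input this op returns the lower of the two middle elements, unlike `numpy.median`, which averages them:

```python
import numpy as np
import mindspore
from mindspore import Tensor, ops

x = Tensor(np.array([1.0, 2.0, 3.0, 4.0]), mindspore.float32)

# Lower median: the (size - 1) // 2 order statistic, i.e. 2.0 here.
y, _ = ops.median(x, global_median=False, axis=0)
print(y)  # 2.0

# numpy averages the two middle elements instead.
print(np.median(np.array([1.0, 2.0, 3.0, 4.0])))  # 2.5
```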
@ -206,6 +206,7 @@ BuiltInTypeMap &GetMethodMap() {
     {"expand_dims", std::string("expand_dims")},  // P.expand_dims()
     {"squeeze", std::string("squeeze")},          // P.squeeze()
     {"astype", std::string("astype")},            // P.cast()
     {"median", std::string("median")},            // P.median()
     {"cumsum", std::string("cumsum")},            // P.cumsum()
     {"cummin", std::string("cummin")},            // cummin()
     {"cummax", std::string("cummax")},            // cummax()
@ -0,0 +1,398 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "median_impl.cuh"
#include <iostream>
#include <vector>
#include <algorithm>

constexpr int WARP_SIZE = 32;
constexpr int MAX_THREAD = 1024;
constexpr int RADIX_BITS = 2;
constexpr int RADIX_SIZE = 4;
constexpr int RADIX_MASK = (RADIX_SIZE - 1);

__device__ __forceinline__ unsigned int warp_ballot(int predicate) {
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 700
  return __ballot(predicate);
#else
  return __ballot_sync(__activemask(), predicate);
#endif
}

// Rounds a up to the nearest multiple of b.
template <typename T>
static __device__ __host__ T round_up(T a, T b) {
  return ((a + b - 1) / b) * b;
}

template <typename T>
struct Bitfield {};

template <>
struct Bitfield<unsigned int> {
  static __device__ unsigned int GetBitfield(unsigned int val, int pos, int len) {
#if !defined(__CUDA_ARCH__)
    pos &= 0xff;
    len &= 0xff;

    unsigned int m = (1u << len) - 1u;
    return (val >> pos) & m;
#else
    unsigned int ret;
    asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
    return ret;
#endif
  }

  static __device__ unsigned int SetBitfield(unsigned int val, unsigned int to_insert, int pos, int len) {
#if !defined(__CUDA_ARCH__)
    pos &= 0xff;
    len &= 0xff;

    unsigned int m = (1u << len) - 1u;
    to_insert &= m;
    to_insert <<= pos;
    m <<= pos;

    return (val & ~m) | to_insert;
#else
    unsigned int ret;
    asm("bfi.b32 %0, %1, %2, %3, %4;" : "=r"(ret) : "r"(to_insert), "r"(val), "r"(pos), "r"(len));
    return ret;
#endif
  }
};

template <>
struct Bitfield<uint64_t> {
  static __device__ uint64_t GetBitfield(uint64_t val, int pos, int len) {
#if !defined(__CUDA_ARCH__)
    pos &= 0xff;
    len &= 0xff;

    uint64_t m = (1ull << len) - 1ull;  // 64-bit literal keeps the shift well defined for len >= 32
    return (val >> pos) & m;
#else
    uint64_t ret;
    asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
    return ret;
#endif
  }

  static __device__ uint64_t SetBitfield(uint64_t val, uint64_t to_insert, int pos, int len) {
#if !defined(__CUDA_ARCH__)
    pos &= 0xff;
    len &= 0xff;

    uint64_t m = (1ull << len) - 1ull;
    to_insert &= m;
    to_insert <<= pos;
    m <<= pos;

    return (val & ~m) | to_insert;
#else
    uint64_t ret;
    asm("bfi.b64 %0, %1, %2, %3, %4;" : "=l"(ret) : "l"(to_insert), "l"(val), "r"(pos), "r"(len));
    return ret;
#endif
  }
};

template <typename T>
struct MedianTypeConfig {};

template <>
struct MedianTypeConfig<float> {
  typedef uint32_t RadixType;

  // Converts a float to an integer representation with the same sorting order:
  // flip all bits of negatives, flip only the sign bit of non-negatives; NaN maps to the maximum key.
  static inline __device__ RadixType Convert(float v) {
    RadixType x = __float_as_int(v);
    RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
    return (v == v) ? (x ^ mask) : 0xffffffff;
  }

  static inline __device__ float Deconvert(RadixType v) {
    RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
    return __int_as_float(v ^ mask);
  }
};

template <>
struct MedianTypeConfig<uint8_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType Convert(uint8_t v) { return v; }

  static inline __device__ uint8_t Deconvert(RadixType v) { return v; }
};

template <>
struct MedianTypeConfig<int8_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType Convert(int8_t v) { return 128u + v; }

  static inline __device__ int8_t Deconvert(RadixType v) { return v - 128; }
};

template <>
struct MedianTypeConfig<int16_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType Convert(int16_t v) { return 32768u + v; }

  static inline __device__ int16_t Deconvert(RadixType v) { return v - 32768; }
};

template <>
struct MedianTypeConfig<int32_t> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType Convert(int32_t v) { return 2147483648u + v; }

  static inline __device__ int32_t Deconvert(RadixType v) { return v - 2147483648u; }
};

template <>
struct MedianTypeConfig<int64_t> {
  typedef uint64_t RadixType;

  static inline __device__ RadixType Convert(int64_t v) { return 9223372036854775808ull + v; }

  static inline __device__ int64_t Deconvert(RadixType v) { return v - 9223372036854775808ull; }
};

template <>
struct MedianTypeConfig<double> {
  typedef uint64_t RadixType;

  static inline __device__ RadixType Convert(double v) {
    RadixType x = __double_as_longlong(v);
    RadixType mask = -((x >> 63)) | 0x8000000000000000;
    return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
  }

  static inline __device__ double Deconvert(RadixType v) {
    RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
    return __longlong_as_double(v ^ mask);
  }
};

// This function counts the distribution of all input values in a slice
template <typename T, typename R_T, typename S, int RadixSize, int RadixBits>
__device__ void CountRadixUsingMask(int counts[RadixSize], int *smem, R_T desired, R_T desired_mask,
                                    int radix_digit_pos, S size, S stride, const T *data) {
#pragma unroll
  for (int i = 0; i < RadixSize; ++i) {
    counts[i] = 0;
  }

  if (threadIdx.x < RadixSize) {
    smem[threadIdx.x] = 0;
  }
  __syncthreads();

  for (S i = threadIdx.x; i < size; i += blockDim.x) {
    R_T val = MedianTypeConfig<T>::Convert(data[i * stride]);

    bool hasVal = ((val & desired_mask) == desired);
    R_T digit_in_radix = Bitfield<R_T>::GetBitfield(val, radix_digit_pos, RadixBits);

#pragma unroll
    for (uint32_t j = 0; j < RadixSize; ++j) {
      bool vote = hasVal && (digit_in_radix == j);
      counts[j] += __popc(warp_ballot(vote));
    }
  }

  if (threadIdx.x % 32 == 0) {
#pragma unroll
    for (uint32_t i = 0; i < RadixSize; ++i) {
      atomicAdd(&smem[i], counts[i]);
    }
  }

  __syncthreads();

#pragma unroll
  for (uint32_t i = 0; i < RadixSize; ++i) {
    counts[i] = smem[i];
  }

  __syncthreads();
}

// This finds the unique value that matches the pattern
template <typename T, typename R_T, typename S>
__device__ T FindPattern(T *smem, const T *data, S size, S stride, R_T desired, R_T desired_mask) {
  if (threadIdx.x < 2) {
    smem[threadIdx.x] = static_cast<T>(0);
  }
  __syncthreads();

  S size_round = round_up(size, static_cast<S>(blockDim.x));
  for (S i = threadIdx.x; i < size_round; i += blockDim.x) {
    bool in_range = (i < size);
    T v = in_range ? data[i * stride] : static_cast<T>(0);

    if (in_range && ((MedianTypeConfig<T>::Convert(v) & desired_mask) == desired)) {
      smem[0] = static_cast<T>(1);
      smem[1] = v;
    }

    __syncthreads();

    T found = smem[0];
    T val = smem[1];

    __syncthreads();

    if (found != static_cast<T>(0)) {
      return val;
    }
  }

  return static_cast<T>(0);
}

// Returns the top-Kth element found in the data using radix selection
template <typename T, typename R_T, typename S>
__device__ void RadixSelect(const T *data, S kth, bool largest, S size, S stride, int *smem, T *top_k) {
  int counts[RADIX_SIZE];
  R_T desired = 0;
  R_T desired_mask = 0;
  int k = kth;

  for (int digit_pos = sizeof(T) * 8 - RADIX_BITS; digit_pos >= 0; digit_pos -= RADIX_BITS) {
    CountRadixUsingMask<T, R_T, S, RADIX_SIZE, RADIX_BITS>(counts, smem, desired, desired_mask, digit_pos, size, stride,
                                                           data);

    auto found_unique = [&](int i, int count) -> bool {
      if (count == 1 && k == 1) {
        desired = Bitfield<R_T>::SetBitfield(desired, i, digit_pos, RADIX_BITS);
        desired_mask = Bitfield<R_T>::SetBitfield(desired_mask, RADIX_MASK, digit_pos, RADIX_BITS);
        *top_k = FindPattern<T, R_T, S>(reinterpret_cast<T *>(smem), data, size, stride, desired, desired_mask);
        return true;
      }
      return false;
    };
    auto found_non_unique = [&](int i, int count) -> bool {
      if (count >= k) {
        desired = Bitfield<R_T>::SetBitfield(desired, i, digit_pos, RADIX_BITS);
        desired_mask = Bitfield<R_T>::SetBitfield(desired_mask, RADIX_MASK, digit_pos, RADIX_BITS);
        return true;
      }
      k -= count;
      return false;
    };

    if (largest) {
#pragma unroll
      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    } else {
#pragma unroll
      for (int i = 0; i < RADIX_SIZE; ++i) {
        int count = counts[i];
        if (found_unique(i, count)) {
          return;
        }
        if (found_non_unique(i, count)) {
          break;
        }
      }
    }
  }

  *top_k = MedianTypeConfig<T>::Deconvert(desired);
}

template <typename T, typename S>
__global__ void MedianKernel(const T *input, T *output, S *indices, S size, S num, S stride, bool global_median) {
  __shared__ int smem[WARP_SIZE];

  S slice = blockIdx.y * gridDim.x + blockIdx.x;
  if (slice >= num) {
    return;
  }

  S offset_y = size * gridDim.x;

  S k = (size - 1) / 2;  // order statistic of the lower median

  // Find the median value
  T median = static_cast<T>(0);
  RadixSelect<T, typename MedianTypeConfig<T>::RadixType, S>(input + blockIdx.y * offset_y + blockIdx.x, k + 1, false,
                                                             size, stride, smem, &median);
  output[slice] = median;

  // Find the index of the median value in the slice
  if (!global_median) {
    for (S i = threadIdx.x; i < size; i += blockDim.x) {
      T val = input[blockIdx.y * offset_y + blockIdx.x + i * stride];
      if (val == median) {
        indices[slice] = i;
        break;
      }
    }
  }
}

template <typename T, typename S>
void Median(const T *input_value, T *output, S *indices, const std::vector<int64_t> input_shape, const int64_t axis,
            bool global_median, cudaStream_t cuda_stream) {
  // One block per slice: grids.y covers the dimensions before `axis`, grids.x the dimensions after it.
  dim3 threads, grids;
  size_t i = 0;
  for (; i < static_cast<size_t>(axis); i++) {
    grids.y *= input_shape[i];
  }
  size_t size = input_shape[axis];
  for (i = axis + 1; i < input_shape.size(); i++) {
    grids.x *= input_shape[i];
  }
  threads.x = std::min(round_up(static_cast<int>(size), WARP_SIZE), MAX_THREAD);

  S num = grids.y * grids.x;
  S stride = grids.x;

  MedianKernel<T, S>
    <<<grids, threads, 0, cuda_stream>>>(input_value, output, indices, size, num, stride, global_median);
  return;
}

template CUDA_LIB_EXPORT void Median<int16_t, int64_t>(const int16_t *input_value, int16_t *output, int64_t *indices,
                                                       const std::vector<int64_t> input_shape, const int64_t axis,
                                                       bool global_median, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Median<int32_t, int64_t>(const int32_t *input_value, int32_t *output, int64_t *indices,
                                                       const std::vector<int64_t> input_shape, const int64_t axis,
                                                       bool global_median, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Median<int64_t, int64_t>(const int64_t *input_value, int64_t *output, int64_t *indices,
                                                       const std::vector<int64_t> input_shape, const int64_t axis,
                                                       bool global_median, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Median<float, int64_t>(const float *input_value, float *output, int64_t *indices,
                                                     const std::vector<int64_t> input_shape, const int64_t axis,
                                                     bool global_median, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void Median<double, int64_t>(const double *input_value, double *output, int64_t *indices,
                                                      const std::vector<int64_t> input_shape, const int64_t axis,
                                                      bool global_median, cudaStream_t cuda_stream);
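The selection loop above is hard to follow in CUDA, so here is a hedged, sequential Python sketch of the same digit-narrowing idea (2 bits per pass over unsigned keys); it is an illustration, not MindSpore API, and omits the unique-value early exit that `FindPattern` handles on the GPU:

```python
RADIX_BITS = 2
RADIX_SIZE = 1 << RADIX_BITS  # 4 buckets per pass

def radix_select(keys, k, key_bits=32):
    """Return the k-th smallest (1-based) of a list of unsigned integer keys.

    Sequential analogue of RadixSelect: scan digits from the most significant
    downward, keep only keys whose already-fixed high digits match `desired`,
    and narrow into the bucket containing the k-th element.
    """
    desired = 0
    desired_mask = 0
    for pos in range(key_bits - RADIX_BITS, -1, -RADIX_BITS):
        # Histogram of the current digit over still-matching keys
        # (CountRadixUsingMask does this with warp ballots).
        counts = [0] * RADIX_SIZE
        for key in keys:
            if key & desired_mask == desired:
                counts[(key >> pos) & (RADIX_SIZE - 1)] += 1
        # Walk buckets from the smallest digit, as in the largest=False branch.
        for digit, count in enumerate(counts):
            if count >= k:  # k-th element lies in this bucket: fix the digit
                desired |= digit << pos
                desired_mask |= (RADIX_SIZE - 1) << pos
                break
            k -= count  # skip the whole bucket
    return desired

assert radix_select([7, 3, 9, 1, 4], k=3) == 4  # median of 5 elements
```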
@ -0,0 +1,27 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_IMPL_CUH_
#include <vector>
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

template <typename T, typename S>
void Median(const T *input_value, T *output, S *indices, const std::vector<int64_t> input_shape, const int64_t axis,
            bool global_median, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MEDIAN_IMPL_CUH_
@ -0,0 +1,39 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/math/median_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
  Median, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
  MedianGpuKernelMod, int16_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  Median, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
  MedianGpuKernelMod, int32_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  Median, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
  MedianGpuKernelMod, int64_t, int64_t)
MS_REG_GPU_KERNEL_TWO(
  Median,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
  MedianGpuKernelMod, float, int64_t)
MS_REG_GPU_KERNEL_TWO(
  Median,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
  MedianGpuKernelMod, double, int64_t)
}  // namespace kernel
}  // namespace mindspore
@ -0,0 +1,123 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MEDIAN_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MEDIAN_GPU_KERNEL_H_

#include <vector>
#include <map>
#include "mindspore/core/ops/median.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/median_impl.cuh"

namespace mindspore {
namespace kernel {
constexpr size_t kMedianInputsNum = 1;
constexpr size_t kMedianOutputsNum = 2;
template <typename T, typename S>
class MedianGpuKernelMod : public NativeGpuKernelMod {
 public:
  MedianGpuKernelMod() : global_median_(false), keep_dims_(false), axis_(0) {}
  ~MedianGpuKernelMod() = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input_addr = GetDeviceAddress<T>(inputs, 0);
    T *output0_addr = GetDeviceAddress<T>(outputs, 0);
    S *output1_addr = nullptr;
    if (!global_median_) {
      output1_addr = GetDeviceAddress<S>(outputs, 1);
    }
    Median(input_addr, output0_addr, output1_addr, input_shape_, axis_, global_median_,
           reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override {
    kernel_name_ = base_operator->name();
    auto kernel_ptr = std::dynamic_pointer_cast<ops::Median>(base_operator);
    if (kernel_ptr == nullptr) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', cast to Median ops failed!";
      return false;
    }
    if (inputs.size() != kMedianInputsNum || outputs.size() > kMedianOutputsNum) {
      MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size should be " << kMedianInputsNum << " and "
                    << kMedianOutputsNum << ", but got " << inputs.size() << " and " << outputs.size();
      return false;
    }
    global_median_ = kernel_ptr->get_global_median();
    keep_dims_ = kernel_ptr->get_keep_dims();
    axis_ = kernel_ptr->get_axis();
    input_shape_ = inputs[0]->GetShapeVector();
    if (global_median_) {
      // A global median is computed by flattening the input into one dimension.
      int input_size = 1;
      for (size_t i = 0; i < input_shape_.size(); i++) {
        input_size *= input_shape_[i];
      }
      input_shape_.clear();
      input_shape_.push_back(input_size);
    }
    int64_t dims = static_cast<int64_t>(input_shape_.size());
    if (axis_ < -dims || axis_ >= dims) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims
                        << "), but got " << axis_;
    }
    if (axis_ < 0) {
      axis_ += dims;
    }
    return true;
  }
  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
    int ret = KernelMod::Resize(base_operator, inputs, outputs);
    if (ret != 0) {
      return ret;
    }
    input_shape_ = inputs[0]->GetShapeVector();
    if (global_median_) {
      int input_size = 1;
      for (size_t i = 0; i < input_shape_.size(); i++) {
        input_size *= input_shape_[i];
      }
      input_shape_.clear();
      input_shape_.push_back(input_size);
    }
    return KRET_OK;
  }

  std::vector<KernelAttr> GetOpSupport() override {
    static std::vector<KernelAttr> support_list = {
      KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64),
      KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
      KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64)};
    return support_list;
  }

 private:
  bool global_median_;
  bool keep_dims_;
  int64_t axis_;
  std::vector<int64_t> input_shape_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MEDIAN_GPU_KERNEL_H_
@ -27,6 +27,38 @@

namespace mindspore {
namespace ops {
void Median::Init(const bool global_median, const int64_t axis, const bool keep_dims) {
  this->set_global_median(global_median);
  this->set_axis(axis);
  this->set_keep_dims(keep_dims);
}

void Median::set_global_median(const bool global_median) {
  (void)this->AddAttr(kGlobalMedian, api::MakeValue(global_median));
}

void Median::set_keep_dims(const bool keep_dims) { (void)this->AddAttr(kKeepDims, api::MakeValue(keep_dims)); }

void Median::set_axis(const int64_t &axis) {
  int64_t f = axis;
  (void)this->AddAttr(kAxis, api::MakeValue(f));
}

bool Median::get_global_median() const {
  auto value_ptr = GetAttr(kGlobalMedian);
  return GetValue<bool>(value_ptr);
}

bool Median::get_keep_dims() const {
  auto value_ptr = GetAttr(kKeepDims);
  return GetValue<bool>(value_ptr);
}

int64_t Median::get_axis() const {
  auto value_ptr = GetAttr(kAxis);
  return GetValue<int64_t>(value_ptr);
}

namespace {
abstract::TupleShapePtr MedianInferShape(const PrimitivePtr &primitive,
                                         const std::vector<AbstractBasePtr> &input_args) {
@ -34,18 +66,18 @@ abstract::TupleShapePtr MedianInferShape(const PrimitivePtr &primitive,
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
   int64_t x_size = x_shape.size();
   std::vector<int64_t> out;
-  auto check_global_median = primitive->GetAttr("global_median");
+  auto check_global_median = primitive->GetAttr(kGlobalMedian);
   MS_EXCEPTION_IF_NULL(check_global_median);
   bool global_median = GetValue<bool>(check_global_median);
   if (!global_median) {
-    auto check_axis = primitive->GetAttr("axis");
+    auto check_axis = primitive->GetAttr(kAxis);
     auto axis = GetValue<int64_t>(check_axis);
-    auto check_keepdim = primitive->GetAttr("keep_dims");
+    auto check_keepdim = primitive->GetAttr(kKeepDims);
     bool keepdim = GetValue<bool>(check_keepdim);
     if (x_size == 0) {
-      CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeLeft, {-1, 1}, "Median");
+      CheckAndConvertUtils::CheckInRange(kAxis, axis, kIncludeLeft, {-1, 1}, "Median");
     } else {
-      CheckAndConvertUtils::CheckInRange("axis", axis, kIncludeLeft, {-x_size, x_size}, "Median");
+      CheckAndConvertUtils::CheckInRange(kAxis, axis, kIncludeLeft, {-x_size, x_size}, "Median");
     }
     if (axis < 0) {
       axis += x_size;
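The infer function above is truncated in this view, so as a hedged restatement of the output-shape rule being implemented (an illustration consistent with this commit's docs and examples, not the C++ API):

```python
def median_out_shape(x_shape, global_median, axis, keep_dims):
    # Global median: the commit's examples show a single-element output of shape [1].
    if global_median:
        return [1]
    axis = axis if axis >= 0 else axis + len(x_shape)
    if keep_dims:
        # The reduced axis is kept with size 1.
        return [1 if i == axis else d for i, d in enumerate(x_shape)]
    # Otherwise the reduced axis is removed from the shape.
    return [d for i, d in enumerate(x_shape) if i != axis]

# Matches the rst docs and the docstring examples in this commit:
assert median_out_shape([3, 3], False, 0, False) == [3]
assert median_out_shape([3, 3], False, 0, True) == [1, 3]
assert median_out_shape([3, 3], True, 0, False) == [1]
```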
|
@ -34,6 +34,26 @@ class MIND_API Median : public BaseOperator {
  MIND_API_BASE_MEMBER(Median);
  /// \brief Constructor.
  Median() : BaseOperator(kNameMedian) { InitIOName({"x"}, {"y", "indices"}); }
  /// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Median for the inputs.
  void Init(const bool global_median = false, const int64_t axis = 0, const bool keep_dims = false);
  /// \brief Set global_median.
  void set_global_median(const bool global_median);
  /// \brief Set keep_dims.
  void set_keep_dims(const bool keep_dims);
  /// \brief Set axis.
  void set_axis(const int64_t &axis);
  /// \brief Get global_median.
  ///
  /// \return global_median.
  bool get_global_median() const;
  /// \brief Get keep_dims.
  ///
  /// \return keep_dims.
  bool get_keep_dims() const;
  /// \brief Get axis.
  ///
  /// \return axis.
  int64_t get_axis() const;
};

abstract::AbstractBasePtr MedianInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
@ -87,6 +87,7 @@ constexpr auto kFreqUpperLimit = "freq_upper_limit";
constexpr auto kFreezeBn = "freeze_bn";
constexpr auto kGateOrder = "gate_order";
constexpr auto kGlobal = "global";
constexpr auto kGlobalMedian = "global_median";
constexpr auto kGrad = "grad";
constexpr auto kIsGrad = "is_grad";
constexpr auto kGradientScale = "gradient_scale";
|
@ -28,6 +28,7 @@ from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add,
from ...ops.composite.base import _append, _insert, _pop
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
from ...ops.operations.math_ops import Median
from ...ops.operations._inner_ops import Format
from ...ops.operations import _csr_ops
from ...ops.primitive import constexpr
@ -632,6 +633,19 @@ def argmin(x, axis=None):
    return P.Argmax(axis)(F.neg_tensor(x))


def median(x, global_median, axis=0, keep_dims=False):
    r"""
    Computes the median of the input tensor.

    .. warning::
        When attr `global_median` is True, the second output Tensor value is meaningless.

    """
    check_axis_in_range_const(axis, x.ndim)
    median_ = Median(global_median, axis, keep_dims)
    return median_(x)


def cumsum(x, axis=None, dtype=None):
    """
    Returns the cumulative sum of the elements along a given axis.
|
@ -4724,6 +4724,56 @@ class Tensor(Tensor_):
        validator.check_bool(sorted, 'sorted')
        return tensor_operator_registry.get("top_k")(sorted)(self, k)

    def median(self, global_median=False, axis=0, keep_dims=False):
        r"""
        Computes the median of the input tensor.

        .. warning::
            When attr `global_median` is True, the second output Tensor value is meaningless.

        Args:
            global_median (bool): Whether the output tensor is the global median of all input tensor elements or not.
                Default: False.
            axis (int): The dimension to reduce. Default: 0.
            keep_dims (bool): Whether the output tensor needs to retain the `axis` dimension or not. Default: False.

        Returns:
            y (Tensor) - Has the same dtype as the self Tensor.
                If `keep_dims` is true, the output tensor has the same dimension as the self Tensor except in
                dimension `axis`, which is of size 1. Otherwise, the output tensor has 1 fewer dimension
                than the input.
            indices (Tensor) - Has the same shape as `y`, but dtype is int64. The value is meaningless
                when `global_median` is True.

        Raises:
            TypeError: If dtype of the self Tensor is not one of the following: int16, int32, int64, float32, double.
            TypeError: If `global_median` is not a bool.
            TypeError: If `axis` is not an int.
            TypeError: If `keep_dims` is not a bool.
            ValueError: If `axis` is not in the range [-len(self.shape), len(self.shape) - 1].

        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``

        Examples:
            >>> # case 1 : common median compute
            >>> x = Tensor(np.array([[0.57, 0.11, 0.21], [0.38, 0.50, 0.57], [0.36, 0.16, 0.44]]).astype(np.float32))
            >>> y = x.median(global_median=False, axis=0, keep_dims=False)
            >>> print(y)
            (Tensor(shape=[3], dtype=Float32, value=[0.38, 0.16, 0.44]),
            Tensor(shape=[3], dtype=Int64, value=[1, 2, 2]))
            >>> # case 2 : global median compute
            >>> x = Tensor(np.array([[1, 7, 6], [5, 1, 3], [9, 17, 1]]), mindspore.int32)
            >>> y = x.median(global_median=True)
            >>> print(y)
            (Tensor(shape=[1], dtype=Int32, value=[5]), Tensor(shape=[1], dtype=Int64, value=[1]))
        """
        self._init_check()
        validator.check_axis_in_range(axis, self.ndim)
        return tensor_operator_registry.get('median')(global_median, axis, keep_dims)(self)


class RowTensor(RowTensor_):
    """
|
@ -30,7 +30,7 @@ from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, get_a
    get_unop_vmap_rule, _raise_value_error, _bdim_at_front, _broadcast_by_axis, _handle_broadcasting, \
    get_unary_grad_vmap_rule, _vmap_clone_prim, _bdim_at_any
from ..operations.math_ops import (Bernoulli, BesselJ0, BesselJ1, BesselK0, BesselK0e, BesselY0, BesselY1, BesselK1,
-                                  BesselK1e)
+                                  BesselK1e, Median)


@constexpr
@ -469,6 +469,34 @@ def get_reducer_vmap_rule(prim, axis_size):
    return vmap_rule


@vmap_rules_getters.register(Median)
def get_median_vmap_rule(prim, axis_size):
    """VmapRule for median operations."""
    global_median = prim.global_median
    axis = prim.axis
    keep_dims = prim.keep_dims

    @constexpr
    def trans_axis(axis, rank, dim, keep_dims):
        if axis < 0:
            axis += rank - 1
        axis_new = axis + 1 if dim <= axis else axis
        if keep_dims:
            dim_new = axis_new
        else:
            dim_new = dim - 1 if dim > axis_new else dim
        return axis_new, dim_new

    def vmap_rule(x_bdim):
        x, x_dim = x_bdim
        rank = len(x.shape)
        axis_new, dim_new = trans_axis(axis, rank, x_dim, keep_dims)
        y, indices = Median(global_median, axis_new, keep_dims)(x)
        return (y, dim_new), (indices, dim_new)

    return vmap_rule


@vmap_rules_getters.register(P.IndexAdd)
def get_index_add_vmap_rule(prim, axis_size):
    """VmapRule for IndexAdd."""
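A worked example of the `trans_axis` remapping above — a standalone, hedged copy for illustration, not the MindSpore registration — showing how the reduced axis and the vmap batch dimension shift for the `keep_dims=False` path:

```python
# Standalone copy of trans_axis from the rule above, for illustration only.
def trans_axis(axis, rank, dim, keep_dims):
    if axis < 0:
        axis += rank - 1
    axis_new = axis + 1 if dim <= axis else axis
    if keep_dims:
        dim_new = axis_new
    else:
        dim_new = dim - 1 if dim > axis_new else dim
    return axis_new, dim_new

# Batch dim at position 0, reduce logical axis 0 of each per-example tensor:
# in the batched tensor the reduced axis shifts to 1, and after the reduction
# removes it, the batch dim stays at 0.
assert trans_axis(axis=0, rank=3, dim=0, keep_dims=False) == (1, 0)
# Batch dim at position 2, reduce logical axis 1: the axis stays 1 in the
# batched tensor, and the batch dim slides down to 1 once axis 1 is removed.
assert trans_axis(axis=1, rank=3, dim=2, keep_dims=False) == (1, 1)
```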
@ -176,6 +176,7 @@ from .math_func import (
    linspace,
    matrix_solve,
    maximum,
    median,
    logaddexp,
    logaddexp2,
    logit,
|
@ -42,6 +42,7 @@ from ..operations.math_ops import (
    BesselK1,
    BesselK1e,
    MatrixSolve,
    Median,
    Orgqr,
    Renorm,
    Hypot,
@ -3053,6 +3054,53 @@ def minimum(x, y):
    return minimum_(x, y)


def median(x, global_median=False, axis=0, keep_dims=False):
    r"""
    Computes the median of the input tensor.

    .. warning::
        When attr `global_median` is True, the second output Tensor value is meaningless.

    Args:
        x (Tensor) - The input tensor of any dimension, with a dtype of int16, int32, int64, float32 or float64.
        global_median (bool) - Whether the output tensor is the global median of all input tensor elements or not.
            Default: False.
        axis (int) - The dimension to reduce. Default: 0.
        keep_dims (bool) - Whether the output tensor needs to retain the `axis` dimension or not. Default: False.

    Returns:
        y (Tensor) - Has the same dtype as `x`. If `global_median` is true, `y` has only one
            element. If `keep_dims` is true, `y` has the same shape as `x` except that the shape of `y` in
            dimension `axis` is size 1. Otherwise, `y` has 1 fewer dimension than `x` (the `axis` dimension
            is removed).
        indices (Tensor) - Has the same shape as `y`, but dtype is int64. The value is meaningless when
            `global_median` is True.

    Raises:
        TypeError: If dtype of `x` is not one of the following: int16, int32, int64, float32, double.
        TypeError: If input `x` is not a Tensor.
        TypeError: If `global_median` is not a bool.
        TypeError: If `axis` is not an int.
        TypeError: If `keep_dims` is not a bool.
        ValueError: If `axis` is not in the range [-x.dim, x.dim - 1].

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # case 1 : common median compute
        >>> x = Tensor(np.array([[0.57, 0.11, 0.21], [0.38, 0.50, 0.57], [0.36, 0.16, 0.44]]).astype(np.float32))
        >>> y = ops.median(x, global_median=False, axis=0, keep_dims=False)
        >>> print(y)
        (Tensor(shape=[3], dtype=Float32, value=[0.38, 0.16, 0.44]), Tensor(shape=[3], dtype=Int64, value=[1, 2, 2]))
        >>> # case 2 : global median compute
        >>> x = Tensor(np.array([[1, 7, 6], [5, 1, 3], [9, 17, 1]]), mindspore.int32)
        >>> y = ops.median(x, global_median=True)
        >>> print(y)
        (Tensor(shape=[1], dtype=Int32, value=[5]), Tensor(shape=[1], dtype=Int64, value=[1]))
    """
    median_ = Median(global_median, axis, keep_dims)
    return median_(x)


def orgqr(x, tau):
    r"""
    Computes the first :math:`N` columns of a product of Householder matrices. Take the case of input without batch
@ -5379,6 +5427,7 @@ __all__ = [
    'same_type_shape',
    'maximum',
    'minimum',
    'median',
    'floor',
    'logical_not',
    'logical_or',
|
@ -32,6 +32,7 @@ from . import operations as P
from .operations import _grad_ops
from .operations import _csr_ops
from .operations import linalg_ops
from .operations.math_ops import Median
from .operations.array_ops import UniqueConsecutive
from .operations.nn_ops import AdaptiveMaxPool2D
from .function.sparse_func import sparse_add
@ -404,6 +405,7 @@ tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('erf', P.Erf)
tensor_operator_registry.register('erfc', P.Erfc)
tensor_operator_registry.register('standard_normal', P.StandardNormal)
tensor_operator_registry.register('median', Median)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
|
@ -6328,7 +6328,7 @@ class Median(Primitive):
        ValueError: If `axis` is not in range of [-x.dim, x.dim-1].

    Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # case 1 : common median compute