Merge pull request !49747 from panzhihui/bugfix
This commit is contained in:
i-robot 2023-03-07 08:06:43 +00:00 committed by Gitee
commit 6c8e82c675
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 118 additions and 80 deletions

View File

@ -22,6 +22,7 @@
"mindspore/mindspore/python/mindspore/ops/operations" "super-init-not-called"
"mindspore/mindspore/python/mindspore/ops/operations/_quant_ops.py" "unused-import"
"mindspore/mindspore/python/mindspore/ops/operations/nn_ops.py" "redefined-builtin"
"mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "missing-docstring"
"mindspore/mindspore/python/mindspore/ops/operations/_inner_ops.py" "dangerous-default-value"
"mindspore/mindspore/python/mindspore/ops/operations/_thor_ops.py" "dangerous-default-value"
"mindspore/mindspore/python/mindspore/ops/operations/_thor_ops.py" "redefined-outer-name"

View File

@ -87,6 +87,7 @@ constexpr auto kNonMaxSuppressionV3 = "NonMaxSuppressionV3";
constexpr auto kMaskedSelect = "MaskedSelect";
constexpr auto kMaskedSelectGrad = "MaskedSelectGrad";
constexpr auto kDynamicStitch = "DynamicStitch";
constexpr auto kSort = "Sort";
constexpr auto kSearchSorted = "SearchSorted";
constexpr auto kLinSpace = "LinSpace";
constexpr auto kResizeBilinear = "ResizeBilinear";
@ -208,12 +209,15 @@ constexpr auto kQuantDTypeCast = "QuantDTypeCast";
constexpr auto kFSEDecode = "FSEDecode";
constexpr auto kSparseSegmentSum = "SparseSegmentSum";
constexpr auto kRealDiv = "RealDiv";
constexpr auto kMaskedFill = "MaskedFill";
constexpr auto kDeformableOffsets = "DeformableOffsets";
constexpr auto kDeformableOffsetsGrad = "DeformableOffsetsGrad";
const std::set<std::string> kCpuKernelOps{kIdentity,
kMaskedFill,
kGather,
kDynamicStitch,
kSort,
kSearchSorted,
kSparseSegmentSum,
kResizeBilinear,

View File

@ -18,6 +18,22 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_TWO(
ArgMaxWithValue,
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
ArgMaxAndMinWithValueGpuKernelMod, int8_t, int)
MS_REG_GPU_KERNEL_TWO(
ArgMaxWithValue,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
ArgMaxAndMinWithValueGpuKernelMod, int64_t, int)
MS_REG_GPU_KERNEL_TWO(
ArgMaxWithValue,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
ArgMaxAndMinWithValueGpuKernelMod, uint8_t, int)
MS_REG_GPU_KERNEL_TWO(
ArgMaxWithValue,
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
ArgMaxAndMinWithValueGpuKernelMod, uint64_t, int)
MS_REG_GPU_KERNEL_TWO(
ArgMaxWithValue,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
@ -47,6 +63,22 @@ MS_REG_GPU_KERNEL_TWO(
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
ArgMaxAndMinWithValueGpuKernelMod, half, int)
MS_REG_GPU_KERNEL_TWO(
ArgMinWithValue,
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
ArgMaxAndMinWithValueGpuKernelMod, int8_t, int)
MS_REG_GPU_KERNEL_TWO(
ArgMinWithValue,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
ArgMaxAndMinWithValueGpuKernelMod, int64_t, int)
MS_REG_GPU_KERNEL_TWO(
ArgMinWithValue,
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
ArgMaxAndMinWithValueGpuKernelMod, uint8_t, int)
MS_REG_GPU_KERNEL_TWO(
ArgMinWithValue,
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
ArgMaxAndMinWithValueGpuKernelMod, uint64_t, int)
MS_REG_GPU_KERNEL_TWO(
ArgMinWithValue,
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),

View File

@ -20,6 +20,7 @@
#include <vector>
#include <string>
#include <map>
#include <functional>
#include "mindspore/core/ops/argmax_with_value.h"
#include "mindspore/core/ops/argmin_with_value.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
@ -51,6 +52,10 @@ class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod {
std::vector<KernelAttr> GetOpSupport() override {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
@ -102,7 +107,7 @@ class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod {
axis_ = kernel_ptr->axis();
}
small_ = (kernel_name_ == "ArgMinWithValue") ? true : false;
return InitSize(base_operator, inputs, outputs);
return true;
}
bool InitSize(const BaseOperatorPtr &, const std::vector<KernelTensorPtr> &inputs,
@ -112,17 +117,26 @@ class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod {
MS_EXCEPTION_IF_NULL(outputs[0]);
auto output_shape = Convert2SizeTClipNeg(outputs[0]->GetShapeVector());
int64_t dims = SizeToLong(shape.size());
is_zero_dim_ = (dims == 0);
// If the rank is uncertain, do not update the axis.
ShapeVector dynamic_rank_shape = {-2};
if (inputs[0]->GetShapeVector() != dynamic_rank_shape) {
if (is_zero_dim_) {
if (axis_ != -1 && axis_ != 0) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-1, "
<< "0], but got " << axis_;
}
} else {
if (axis_ < -dims || axis_ >= dims) {
MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << ","
<< dims << "), but got " << axis_;
}
if (axis_ < 0) {
axis_ += dims;
}
}
if (axis_ < 0) {
axis_ += dims;
}
size_t input_element_num = std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
if (input_element_num == 0) {
return KRET_OK;
}
input_size_ = sizeof(T);
@ -133,11 +147,8 @@ class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod {
for (auto x : output_shape) {
output_size_ *= x;
}
bound_ = static_cast<S>(shape[axis_]);
if (static_cast<S>(shape[axis_]) != bound_) {
MS_EXCEPTION(ArgumentError) << "For '" << kernel_name_ << "', the value of shape[axis] must be "
<< static_cast<size_t>(bound_) << ", but got " << shape[axis_];
}
bound_ = is_zero_dim_ ? 1 : static_cast<S>(shape[axis_]);
outer_size_ = 1;
for (int64_t i = axis_ - 1; i >= 0; i--) {
outer_size_ *= shape[i];
@ -170,6 +181,7 @@ class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod {
private:
bool small_ = false;
bool is_zero_dim_{false};
int64_t axis_;
size_t input_size_;
size_t output_size_;

View File

@ -44,11 +44,7 @@ inline __device__ void ConditionAssign(bool is_assign, T *x, const T &y) {
template <typename T, typename S>
__global__ void ThreadReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
T *output, S *output_index, bool fp16_flag, T init_K) {
if (fp16_flag) {
init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
}
T *output, S *output_index, T init_K) {
const S init_V = static_cast<S>(-1);
for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < outer_size * inner_size;
@ -75,10 +71,7 @@ __global__ void ThreadReduction(bool small, size_t outer_size, size_t bound, siz
template <typename T, typename S>
__global__ void WarpReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output,
S *output_index, bool fp16_flag, T init_K) {
if (fp16_flag) {
init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
}
S *output_index, T init_K) {
const S init_V = static_cast<S>(-1);
for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kWarpSize * outer_size * inner_size;
@ -123,12 +116,9 @@ __global__ void WarpReduction(bool small, size_t outer_size, size_t bound, size_
template <typename T, typename S>
__global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
T *output, S *output_index, bool fp16_flag, T init_K) {
T *output, S *output_index, T init_K) {
__shared__ T shared_K[kNumWarps];
__shared__ S shared_V[kNumWarps];
if (fp16_flag) {
init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
}
const S init_V = static_cast<S>(-1);
for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kGroupSize * outer_size * inner_size;
@ -212,12 +202,9 @@ __global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size
template <typename T, typename S>
__global__ void BlockReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input,
T *output, S *output_index, bool fp16_flag, T init_K) {
T *output, S *output_index, T init_K) {
__shared__ T shared_K[kNumWarps];
__shared__ S shared_V[kNumWarps];
if (fp16_flag) {
init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504);
}
const S init_V = static_cast<S>(-1);
for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kBlockSize * outer_size * inner_size;
@ -295,37 +282,52 @@ __global__ void BlockReduction(bool small, size_t outer_size, size_t bound, size
}
template <typename T, typename S>
void GeneralReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output,
S *output_index, cudaStream_t stream) {
void GeneralReductionImpl(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output,
S *output_index, T init_K, cudaStream_t stream) {
int block_num_limit = outer_size * inner_size;
bool fp16_flag = false;
if (std::is_same<T, half>::value) {
fp16_flag = true;
}
T init_K = small ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();
if (bound <= kMaxThreadLoop) {
ThreadReduction<T, S><<<GET_BLOCKS(block_num_limit), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
ThreadReduction<T, S><<<GET_BLOCKS(block_num_limit * kBlockSize), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, init_K);
} else if (bound <= kMaxWarpLoop) {
WarpReduction<T, S><<<GET_BLOCKS(block_num_limit * kWarpSize), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
WarpReduction<T, S><<<GET_BLOCKS(block_num_limit * kBlockSize), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, init_K);
} else if (bound <= kMaxGroupLoop) {
Warp4Reduction<T, S><<<GET_BLOCKS(block_num_limit * kGroupSize), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
Warp4Reduction<T, S><<<GET_BLOCKS(block_num_limit * kBlockSize), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, init_K);
} else {
BlockReduction<T, S><<<GET_BLOCKS(block_num_limit * kBlockSize), kBlockSize, 0, stream>>>(
small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K);
small, outer_size, bound, inner_size, input, output, output_index, init_K);
}
}
template <typename T, typename S>
void CalGeneralReduction(bool small, const T *input, const size_t bound, const size_t outerSize, const size_t innerSize,
S *index, T *output, cudaStream_t cuda_stream) {
GeneralReduction(small, outerSize, bound, innerSize, input, output, index, cuda_stream);
void CalGeneralReduction(bool small, const T *input, const size_t bound, const size_t outerSize,
const size_t innerSize, S *output_index, T *output, cudaStream_t stream) {
T init_K = small ? std::numeric_limits<T>::max() : std::numeric_limits<T>::lowest();
GeneralReductionImpl(small, outerSize, bound, innerSize, input, output, output_index, init_K, stream);
return;
}
template <typename S>
void CalGeneralReduction(bool small, const half *input, const size_t bound, const size_t outerSize,
const size_t innerSize, S *output_index, half *output, cudaStream_t stream) {
half init_K = small ? static_cast<half>(65504) : static_cast<half>(-65504);
GeneralReductionImpl(small, outerSize, bound, innerSize, input, output, output_index, init_K, stream);
return;
}
template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const int8_t *input, const size_t bound_,
const size_t outerSize_, const size_t innerSize_, int *index,
int8_t *output, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const int64_t *input, const size_t bound_,
const size_t outerSize_, const size_t innerSize_, int *index,
int64_t *output, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const uint8_t *input, const size_t bound_,
const size_t outerSize_, const size_t innerSize_, int *index,
uint8_t *output, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const uint64_t *input, const size_t bound_,
const size_t outerSize_, const size_t innerSize_, int *index,
uint64_t *output, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const int16_t *input, const size_t bound_,
const size_t outerSize_, const size_t innerSize_, int *index,
int16_t *output, cudaStream_t cuda_stream);

View File

@ -20,4 +20,7 @@
template <typename T, typename S>
CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const T *input, const size_t bound_, const size_t outerSize_,
const size_t innerSize_, S *index, T *output, cudaStream_t cuda_stream);
template <typename S>
CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const half *input, const size_t bound_, const size_t outerSize_,
const size_t innerSize_, S *index, half *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GENERAL_REDUCTION_IMPL_CUH_

View File

@ -37,6 +37,7 @@
#include "mindapi/ir/value.h"
#include "ops/core_ops.h"
#include "ops/op_name.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
@ -107,10 +108,8 @@ abstract::TupleShapePtr ArgMaxWithValueInferShape(const PrimitivePtr &primitive,
TuplePtr ArgMaxWithValueInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(input_args[0]);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32,
kInt64, kUInt8, kUInt16, kUInt32, kUInt64};
TypePtr input_x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, common_valid_types, prim->name());
auto index_type = std::make_shared<TensorType>(kInt32);
return std::make_shared<Tuple>(std::vector<TypePtr>{index_type, input_x_type});
}

View File

@ -39,6 +39,7 @@
#include "mindapi/ir/value.h"
#include "ops/core_ops.h"
#include "ops/op_name.h"
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
@ -116,16 +117,8 @@ abstract::TupleShapePtr ArgMinWithValueInferShape(const PrimitivePtr &primitive,
TuplePtr ArgMinWithValueInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
MS_EXCEPTION_IF_NULL(input_args[0]);
std::set<TypePtr> valid_types;
TypePtr input_x_type = input_args[0]->BuildType();
auto context = MsContext::GetInstance();
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
if (is_gpu) {
valid_types = {kInt16, kInt32, kUInt16, kUInt32, kFloat16, kFloat32, kFloat64};
} else {
valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
}
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, common_valid_types, prim->name());
auto index_type = std::make_shared<TensorType>(kInt32);
return std::make_shared<Tuple>(std::vector<TypePtr>{index_type, input_x_type});
}

View File

@ -15,27 +15,12 @@
*/
#include "ops/bernoulli.h"
#include <algorithm>
#include <memory>
#include <vector>
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "abstract/abstract_value.h"
#include "abstract/dshape.h"
#include "abstract/ops/op_infer.h"
#include "abstract/utils.h"
#include "base/base.h"
#include "ir/anf.h"
#include "ir/dtype/number.h"
#include "ir/primitive.h"
#include "mindapi/base/shared_ptr.h"
#include "mindapi/ir/value.h"
#include "ops/core_ops.h"
#include "ops/op_name.h"
#include "ops/primitive_c.h"
#include "utils/log_adapter.h"
#include "mindapi/src/helper.h"
namespace mindspore {
@ -43,8 +28,15 @@ namespace ops {
namespace {
abstract::ShapePtr BernoulliInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(out_shape);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto p_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
if (!IsDynamic(x_shape) && !IsDynamic(p_shape)) {
if (SizeOf(p_shape) != 1 && p_shape != x_shape) {
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', "
<< "'x' and 'p' should have same shape or 'p' have a size of 1.";
}
}
return std::make_shared<abstract::Shape>(x_shape);
}
TypePtr BernoulliInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {

View File

@ -100,6 +100,8 @@ from .random_choice_with_mask import _random_choice_with_mask_aicpu
from .rsqrt import _rsqrt_aicpu
from .sqrt import _sqrt_aicpu
from .sqrt_grad import _sqrt_grad_aicpu
from .masked_fill import _masked_fill_aicpu
from .sort import _sort_aicpu
from .search_sorted import _search_sorted_aicpu
from .stack import _stack_aicpu
from .unstack import _unstack_aicpu

View File

@ -334,7 +334,6 @@ class Cast(PrimitiveWithCheck):
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
def check_elim(self, x, dtype):
"""Cast Infer Value in Pynative Mode"""
if isinstance(x, (Tensor, numbers.Number, Parameter)):
if isinstance(x, Parameter):
data = x.data
@ -349,7 +348,6 @@ class Cast(PrimitiveWithCheck):
return (False, None)
def infer_value(self, x, dst_type):
"""Cast Infer Value"""
if x is None:
return None
src_type = mstype.get_py_obj_dtype(x)