From 99d6308340228375c021fc85ee64e0250927b0f1 Mon Sep 17 00:00:00 2001 From: panzhihui Date: Fri, 3 Mar 2023 15:40:16 +0800 Subject: [PATCH] Bug fix --- .jenkins/check/config/filter_pylint.txt | 1 + .../device/ascend/kernel/aicpu/aicpu_util.h | 4 ++ .../argmaxandminwithvalue_gpu_kernel.cc | 32 +++++++++ .../arrays/argmaxandminwithvalue_gpu_kernel.h | 36 ++++++---- .../cuda_ops/general_reduction_impl.cu | 72 ++++++++++--------- .../cuda_ops/general_reduction_impl.cuh | 3 + mindspore/core/ops/argmax_with_value.cc | 5 +- mindspore/core/ops/argmin_with_value.cc | 11 +-- mindspore/core/ops/bernoulli.cc | 30 +++----- .../mindspore/ops/_op_impl/aicpu/__init__.py | 2 + .../mindspore/ops/operations/array_ops.py | 2 - 11 files changed, 118 insertions(+), 80 deletions(-) diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index 1376e41fd7d..c31d2a3b874 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -22,6 +22,7 @@ "mindspore/mindspore/python/mindspore/ops/operations" "super-init-not-called" "mindspore/mindspore/python/mindspore/ops/operations/_quant_ops.py" "unused-import" "mindspore/mindspore/python/mindspore/ops/operations/nn_ops.py" "redefined-builtin" +"mindspore/mindspore/python/mindspore/ops/operations/array_ops.py" "missing-docstring" "mindspore/mindspore/python/mindspore/ops/operations/_inner_ops.py" "dangerous-default-value" "mindspore/mindspore/python/mindspore/ops/operations/_thor_ops.py" "dangerous-default-value" "mindspore/mindspore/python/mindspore/ops/operations/_thor_ops.py" "redefined-outer-name" diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h index 3b4a26b3cb9..ed793876127 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h @@ -87,6 +87,7 @@ constexpr auto kNonMaxSuppressionV3 = 
"NonMaxSuppressionV3"; constexpr auto kMaskedSelect = "MaskedSelect"; constexpr auto kMaskedSelectGrad = "MaskedSelectGrad"; constexpr auto kDynamicStitch = "DynamicStitch"; +constexpr auto kSort = "Sort"; constexpr auto kSearchSorted = "SearchSorted"; constexpr auto kLinSpace = "LinSpace"; constexpr auto kResizeBilinear = "ResizeBilinear"; @@ -208,12 +209,15 @@ constexpr auto kQuantDTypeCast = "QuantDTypeCast"; constexpr auto kFSEDecode = "FSEDecode"; constexpr auto kSparseSegmentSum = "SparseSegmentSum"; constexpr auto kRealDiv = "RealDiv"; +constexpr auto kMaskedFill = "MaskedFill"; constexpr auto kDeformableOffsets = "DeformableOffsets"; constexpr auto kDeformableOffsetsGrad = "DeformableOffsetsGrad"; const std::set kCpuKernelOps{kIdentity, + kMaskedFill, kGather, kDynamicStitch, + kSort, kSearchSorted, kSparseSegmentSum, kResizeBilinear, diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmaxandminwithvalue_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmaxandminwithvalue_gpu_kernel.cc index 35c2ff99559..bcb2b32a3c0 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmaxandminwithvalue_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmaxandminwithvalue_gpu_kernel.cc @@ -18,6 +18,22 @@ namespace mindspore { namespace kernel { +MS_REG_GPU_KERNEL_TWO( + ArgMaxWithValue, + KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), + ArgMaxAndMinWithValueGpuKernelMod, int8_t, int) +MS_REG_GPU_KERNEL_TWO( + ArgMaxWithValue, + KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + ArgMaxAndMinWithValueGpuKernelMod, int64_t, int) +MS_REG_GPU_KERNEL_TWO( + ArgMaxWithValue, + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), + ArgMaxAndMinWithValueGpuKernelMod, uint8_t, int) +MS_REG_GPU_KERNEL_TWO( + ArgMaxWithValue, + 
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), + ArgMaxAndMinWithValueGpuKernelMod, uint64_t, int) MS_REG_GPU_KERNEL_TWO( ArgMaxWithValue, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), @@ -47,6 +63,22 @@ MS_REG_GPU_KERNEL_TWO( KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), ArgMaxAndMinWithValueGpuKernelMod, half, int) +MS_REG_GPU_KERNEL_TWO( + ArgMinWithValue, + KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), + ArgMaxAndMinWithValueGpuKernelMod, int8_t, int) +MS_REG_GPU_KERNEL_TWO( + ArgMinWithValue, + KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + ArgMaxAndMinWithValueGpuKernelMod, int64_t, int) +MS_REG_GPU_KERNEL_TWO( + ArgMinWithValue, + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), + ArgMaxAndMinWithValueGpuKernelMod, uint8_t, int) +MS_REG_GPU_KERNEL_TWO( + ArgMinWithValue, + KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), + ArgMaxAndMinWithValueGpuKernelMod, uint64_t, int) MS_REG_GPU_KERNEL_TWO( ArgMinWithValue, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmaxandminwithvalue_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmaxandminwithvalue_gpu_kernel.h index 3eb73afabca..3b2bf348881 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmaxandminwithvalue_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/argmaxandminwithvalue_gpu_kernel.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "mindspore/core/ops/argmax_with_value.h" #include
"mindspore/core/ops/argmin_with_value.h" #include "plugin/device/gpu/kernel/gpu_kernel.h" @@ -51,6 +52,10 @@ class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod { std::vector GetOpSupport() override { static std::vector support_list = { + KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), + KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), @@ -102,7 +107,7 @@ class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod { axis_ = kernel_ptr->axis(); } small_ = (kernel_name_ == "ArgMinWithValue") ? true : false; - return InitSize(base_operator, inputs, outputs); + return true; } bool InitSize(const BaseOperatorPtr &, const std::vector &inputs, @@ -112,17 +117,26 @@ class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod { MS_EXCEPTION_IF_NULL(outputs[0]); auto output_shape = Convert2SizeTClipNeg(outputs[0]->GetShapeVector()); int64_t dims = SizeToLong(shape.size()); + is_zero_dim_ = (dims == 0); - // If the rank is uncertain, do not update the axis.
- ShapeVector dynamic_rank_shape = {-2}; - if (inputs[0]->GetShapeVector() != dynamic_rank_shape) { + if (is_zero_dim_) { + if (axis_ != -1 && axis_ != 0) { + MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-1, " + << "0], but got " << axis_; + } + } else { if (axis_ < -dims || axis_ >= dims) { MS_EXCEPTION(ValueError) << "For '" << kernel_name_ << "', the 'axis' must be in the range [-" << dims << "," << dims << "), but got " << axis_; } - if (axis_ < 0) { - axis_ += dims; - } + } + + if (axis_ < 0) { + axis_ += dims; + } + size_t input_element_num = std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies()); + if (input_element_num == 0) { + return KRET_OK; } input_size_ = sizeof(T); @@ -133,11 +147,8 @@ class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod { for (auto x : output_shape) { output_size_ *= x; } - bound_ = static_cast(shape[axis_]); - if (static_cast(shape[axis_]) != bound_) { - MS_EXCEPTION(ArgumentError) << "For '" << kernel_name_ << "', the value of shape[axis] must be " - << static_cast(bound_) << ", but got " << shape[axis_]; - } + + bound_ = is_zero_dim_ ? 
1 : static_cast(shape[axis_]); outer_size_ = 1; for (int64_t i = axis_ - 1; i >= 0; i--) { outer_size_ *= shape[i]; @@ -170,6 +181,7 @@ class ArgMaxAndMinWithValueGpuKernelMod : public NativeGpuKernelMod { private: bool small_ = false; + bool is_zero_dim_{false}; int64_t axis_; size_t input_size_; size_t output_size_; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/general_reduction_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/general_reduction_impl.cu index 1a86e44f991..a2709c089bd 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/general_reduction_impl.cu +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/general_reduction_impl.cu @@ -44,11 +44,7 @@ inline __device__ void ConditionAssign(bool is_assign, T *x, const T &y) { template __global__ void ThreadReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, - T *output, S *output_index, bool fp16_flag, T init_K) { - if (fp16_flag) { - init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504); - } - + T *output, S *output_index, T init_K) { const S init_V = static_cast(-1); for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < outer_size * inner_size; @@ -75,10 +71,7 @@ __global__ void ThreadReduction(bool small, size_t outer_size, size_t bound, siz template __global__ void WarpReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output, - S *output_index, bool fp16_flag, T init_K) { - if (fp16_flag) { - init_K = small ? 
__int2half_rd(65504) : __int2half_rd(-65504); - } + S *output_index, T init_K) { const S init_V = static_cast(-1); for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kWarpSize * outer_size * inner_size; @@ -123,12 +116,9 @@ __global__ void WarpReduction(bool small, size_t outer_size, size_t bound, size_ template __global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, - T *output, S *output_index, bool fp16_flag, T init_K) { + T *output, S *output_index, T init_K) { __shared__ T shared_K[kNumWarps]; __shared__ S shared_V[kNumWarps]; - if (fp16_flag) { - init_K = small ? __int2half_rd(65504) : __int2half_rd(-65504); - } const S init_V = static_cast(-1); for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kGroupSize * outer_size * inner_size; @@ -212,12 +202,9 @@ __global__ void Warp4Reduction(bool small, size_t outer_size, size_t bound, size template __global__ void BlockReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, - T *output, S *output_index, bool fp16_flag, T init_K) { + T *output, S *output_index, T init_K) { __shared__ T shared_K[kNumWarps]; __shared__ S shared_V[kNumWarps]; - if (fp16_flag) { - init_K = small ? 
__int2half_rd(65504) : __int2half_rd(-65504); - } const S init_V = static_cast(-1); for (int t_idx = blockIdx.x * blockDim.x + threadIdx.x; t_idx < kBlockSize * outer_size * inner_size; @@ -295,37 +282,52 @@ __global__ void BlockReduction(bool small, size_t outer_size, size_t bound, size } template -void GeneralReduction(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output, - S *output_index, cudaStream_t stream) { +void GeneralReductionImpl(bool small, size_t outer_size, size_t bound, size_t inner_size, const T *input, T *output, + S *output_index, T init_K, cudaStream_t stream) { int block_num_limit = outer_size * inner_size; - bool fp16_flag = false; - if (std::is_same::value) { - fp16_flag = true; - } - T init_K = small ? std::numeric_limits::max() : std::numeric_limits::lowest(); - if (bound <= kMaxThreadLoop) { - ThreadReduction<<>>( - small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K); + ThreadReduction<<>>( + small, outer_size, bound, inner_size, input, output, output_index, init_K); } else if (bound <= kMaxWarpLoop) { - WarpReduction<<>>( - small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K); + WarpReduction<<>>( + small, outer_size, bound, inner_size, input, output, output_index, init_K); } else if (bound <= kMaxGroupLoop) { - Warp4Reduction<<>>( - small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K); + Warp4Reduction<<>>( + small, outer_size, bound, inner_size, input, output, output_index, init_K); } else { BlockReduction<<>>( - small, outer_size, bound, inner_size, input, output, output_index, fp16_flag, init_K); + small, outer_size, bound, inner_size, input, output, output_index, init_K); } } template -void CalGeneralReduction(bool small, const T *input, const size_t bound, const size_t outerSize, const size_t innerSize, - S *index, T *output, cudaStream_t cuda_stream) { - GeneralReduction(small, outerSize, bound, 
innerSize, input, output, index, cuda_stream); +void CalGeneralReduction(bool small, const T *input, const size_t bound, const size_t outerSize, + const size_t innerSize, S *output_index, T *output, cudaStream_t stream) { + T init_K = small ? std::numeric_limits::max() : std::numeric_limits::lowest(); + GeneralReductionImpl(small, outerSize, bound, innerSize, input, output, output_index, init_K, stream); return; } +template +void CalGeneralReduction(bool small, const half *input, const size_t bound, const size_t outerSize, + const size_t innerSize, S *output_index, half *output, cudaStream_t stream) { + half init_K = small ? static_cast(65504) : static_cast(-65504); + GeneralReductionImpl(small, outerSize, bound, innerSize, input, output, output_index, init_K, stream); + return; +} + +template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const int8_t *input, const size_t bound_, + const size_t outerSize_, const size_t innerSize_, int *index, + int8_t *output, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const int64_t *input, const size_t bound_, + const size_t outerSize_, const size_t innerSize_, int *index, + int64_t *output, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const uint8_t *input, const size_t bound_, + const size_t outerSize_, const size_t innerSize_, int *index, + uint8_t *output, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const uint64_t *input, const size_t bound_, + const size_t outerSize_, const size_t innerSize_, int *index, + uint64_t *output, cudaStream_t cuda_stream); template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const int16_t *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_, int *index, int16_t *output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/general_reduction_impl.cuh 
b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/general_reduction_impl.cuh index fc6b7237eaf..cd96043d6cc 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/general_reduction_impl.cuh +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/general_reduction_impl.cuh @@ -20,4 +20,7 @@ template CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const T *input, const size_t bound_, const size_t outerSize_, const size_t innerSize_, S *index, T *output, cudaStream_t cuda_stream); +template +CUDA_LIB_EXPORT void CalGeneralReduction(bool small, const half *input, const size_t bound_, const size_t outerSize_, + const size_t innerSize_, S *index, half *output, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_GENERAL_REDUCTION_IMPL_CUH_ diff --git a/mindspore/core/ops/argmax_with_value.cc b/mindspore/core/ops/argmax_with_value.cc index b2130e0c749..0104a97c995 100644 --- a/mindspore/core/ops/argmax_with_value.cc +++ b/mindspore/core/ops/argmax_with_value.cc @@ -37,6 +37,7 @@ #include "mindapi/ir/value.h" #include "ops/core_ops.h" #include "ops/op_name.h" +#include "ops/op_utils.h" #include "ops/primitive_c.h" #include "utils/convert_utils_base.h" #include "utils/log_adapter.h" @@ -107,10 +108,8 @@ abstract::TupleShapePtr ArgMaxWithValueInferShape(const PrimitivePtr &primitive, TuplePtr ArgMaxWithValueInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(input_args[0]); - const std::set valid_types = {kFloat16, kFloat32, kFloat64, kInt8, kInt16, kInt32, - kInt64, kUInt8, kUInt16, kUInt32, kUInt64}; TypePtr input_x_type = input_args[0]->BuildType(); - (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, prim->name()); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, common_valid_types, prim->name()); auto index_type = std::make_shared(kInt32); return 
std::make_shared(std::vector{index_type, input_x_type}); } diff --git a/mindspore/core/ops/argmin_with_value.cc b/mindspore/core/ops/argmin_with_value.cc index 42699ba15c0..250a17954a9 100644 --- a/mindspore/core/ops/argmin_with_value.cc +++ b/mindspore/core/ops/argmin_with_value.cc @@ -39,6 +39,7 @@ #include "mindapi/ir/value.h" #include "ops/core_ops.h" #include "ops/op_name.h" +#include "ops/op_utils.h" #include "ops/primitive_c.h" #include "utils/convert_utils_base.h" #include "utils/log_adapter.h" @@ -116,16 +117,8 @@ abstract::TupleShapePtr ArgMinWithValueInferShape(const PrimitivePtr &primitive, TuplePtr ArgMinWithValueInferType(const PrimitivePtr &prim, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(input_args[0]); - std::set valid_types; TypePtr input_x_type = input_args[0]->BuildType(); - auto context = MsContext::GetInstance(); - bool is_gpu = (context->get_param(MS_CTX_DEVICE_TARGET) == kGPUDevice); - if (is_gpu) { - valid_types = {kInt16, kInt32, kUInt16, kUInt32, kFloat16, kFloat32, kFloat64}; - } else { - valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64}; - } - (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, valid_types, prim->name()); + (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_x_type, common_valid_types, prim->name()); auto index_type = std::make_shared(kInt32); return std::make_shared(std::vector{index_type, input_x_type}); } diff --git a/mindspore/core/ops/bernoulli.cc b/mindspore/core/ops/bernoulli.cc index e1cdc08bb92..dad21a6bb1f 100644 --- a/mindspore/core/ops/bernoulli.cc +++ b/mindspore/core/ops/bernoulli.cc @@ -15,27 +15,12 @@ */ #include "ops/bernoulli.h" - +#include #include #include -#include - +#include "ops/op_utils.h" #include "utils/check_convert_utils.h" #include "abstract/ops/primitive_infer_map.h" -#include "abstract/abstract_value.h" -#include "abstract/dshape.h" -#include 
"abstract/ops/op_infer.h" -#include "abstract/utils.h" -#include "base/base.h" -#include "ir/anf.h" -#include "ir/dtype/number.h" -#include "ir/primitive.h" -#include "mindapi/base/shared_ptr.h" -#include "mindapi/ir/value.h" -#include "ops/core_ops.h" -#include "ops/op_name.h" -#include "ops/primitive_c.h" -#include "utils/log_adapter.h" #include "mindapi/src/helper.h" namespace mindspore { @@ -43,8 +28,15 @@ namespace ops { namespace { abstract::ShapePtr BernoulliInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - return std::make_shared(out_shape); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto p_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape]; + if (!IsDynamic(x_shape) && !IsDynamic(p_shape)) { + if (SizeOf(p_shape) != 1 && p_shape != x_shape) { + MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', " + << "'x' and 'p' should have same shape or 'p' have a size of 1."; + } + } + return std::make_shared(x_shape); } TypePtr BernoulliInferType(const PrimitivePtr &primitive, const std::vector &input_args) { diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py index ebe8df0987f..0d30fe9e0a3 100644 --- a/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py @@ -100,6 +100,8 @@ from .random_choice_with_mask import _random_choice_with_mask_aicpu from .rsqrt import _rsqrt_aicpu from .sqrt import _sqrt_aicpu from .sqrt_grad import _sqrt_grad_aicpu +from .masked_fill import _masked_fill_aicpu +from .sort import _sort_aicpu from .search_sorted import _search_sorted_aicpu from .stack import _stack_aicpu from .unstack import _unstack_aicpu diff 
--git a/mindspore/python/mindspore/ops/operations/array_ops.py b/mindspore/python/mindspore/ops/operations/array_ops.py index 608d3691a0d..2535058d896 100755 --- a/mindspore/python/mindspore/ops/operations/array_ops.py +++ b/mindspore/python/mindspore/ops/operations/array_ops.py @@ -334,7 +334,6 @@ class Cast(PrimitiveWithCheck): self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output']) def check_elim(self, x, dtype): - """Cast Infer Value in Pynative Mode""" if isinstance(x, (Tensor, numbers.Number, Parameter)): if isinstance(x, Parameter): data = x.data @@ -349,7 +348,6 @@ class Cast(PrimitiveWithCheck): return (False, None) def infer_value(self, x, dst_type): - """Cast Infer Value""" if x is None: return None src_type = mstype.get_py_obj_dtype(x)