forked from mindspore-Ecosystem/mindspore
!39504 maxpoolwithargmax gpu & ascend fix
Merge pull request !39504 from panfengfeng/maxpoolwithargmax_fix
This commit is contained in:
commit
49e9e55d26
|
@ -24,6 +24,7 @@
|
|||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
@ -87,6 +88,8 @@ void MaxPoolWithArgmax::Init(const std::vector<int64_t> &kernel_size, const std:
|
|||
namespace {
|
||||
abstract::TupleShapePtr MaxPoolWithArgmaxInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kDim0]->BuildShape())[kShape];
|
||||
Format format = Format(CheckAndConvertUtils::GetAndCheckFormat(primitive->GetAttr(kFormat)));
|
||||
|
@ -136,6 +139,18 @@ abstract::TupleShapePtr MaxPoolWithArgmaxInferShape(const PrimitivePtr &primitiv
|
|||
out_w = static_cast<int64_t>(std::ceil(in_w / static_cast<float>(stride_w)));
|
||||
}
|
||||
std::vector<int64_t> out_shape = {batch, channel, out_h, out_w};
|
||||
bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
|
||||
bool is_gpu = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice);
|
||||
if (is_ascend || is_gpu) {
|
||||
for (size_t i = 0; i < out_shape.size(); i++) {
|
||||
if (out_shape[i] <= 0 && out_shape[i] != -1) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << op_name << "',"
|
||||
<< " the each element of the output shape must be larger than 0, but got: "
|
||||
<< "output shape: [" << batch << ", " << channel << ", " << out_h << ", " << out_w
|
||||
<< "].";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process attr mapping problems from mindspore to tbe
|
||||
// kernel_size -> ksize
|
||||
|
|
|
@ -1819,7 +1819,7 @@ def grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zero
|
|||
Examples:
|
||||
>>> input_x = Tensor(np.arange(16).reshape((2, 2, 2, 2)).astype(np.float32))
|
||||
>>> grid = Tensor(np.arange(0.2, 1, 0.1).reshape((2, 2, 1, 2)).astype(np.float32))
|
||||
>>> output = grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zeros',
|
||||
>>> output = ops.grid_sample(input_x, grid, interpolation_mode='bilinear', padding_mode='zeros',
|
||||
align_corners=True)
|
||||
>>> print(output)
|
||||
[[[[ 1.9 ]
|
||||
|
@ -1943,7 +1943,8 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
|
|||
>>> inputs = Tensor(np.array([[[0.6, 0.4, 0.2], [0.8, 0.6, 0.3]],
|
||||
... [[0.0, 0.6, 0.0], [0.5, 0.4, 0.5]]]), mindspore.float32)
|
||||
>>> sequence_length = Tensor(np.array([2, 2]), mindspore.int32)
|
||||
>>> decoded_indices, decoded_values, decoded_shape, log_probability = ctc_greedy_decode(inputs, sequence_length)
|
||||
>>> decoded_indices, decoded_values, decoded_shape, log_probability = ops.ctc_greedy_decoder(inputs,
|
||||
sequence_length)
|
||||
>>> print(decoded_indices)
|
||||
[[0 0]
|
||||
[0 1]
|
||||
|
|
|
@ -1810,7 +1810,7 @@ class MaxPoolV1(Primitive):
|
|||
self.add_prim_attr("strides", strides_adapted)
|
||||
|
||||
|
||||
class MaxPoolWithArgmax(_Pool):
|
||||
class MaxPoolWithArgmax(Primitive):
|
||||
r"""
|
||||
Performs max pooling on the input Tensor and returns both max values and indices.
|
||||
|
||||
|
@ -1878,16 +1878,25 @@ class MaxPoolWithArgmax(_Pool):
|
|||
@prim_attr_register
|
||||
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
|
||||
"""Initialize MaxPoolWithArgmax."""
|
||||
super(MaxPoolWithArgmax, self).__init__(kernel_size, strides, pad_mode, data_format)
|
||||
self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask'])
|
||||
validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
|
||||
validator.check_value_type('strides', strides, [int, tuple], self.name)
|
||||
validator.check_value_type('pad_mode', pad_mode, [str], self.name)
|
||||
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
|
||||
self.add_prim_attr("pad_mode", self.pad_mode)
|
||||
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
||||
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
||||
raise ValueError(f"For '{self.name}', the 'NHWC' format is only supported in GPU target, "
|
||||
f"but got the 'data_format' is {self.format} and "
|
||||
f"the platform is {context.get_context('device_target')}.")
|
||||
self.kernel_size = _check_positive_int_or_tuple(
|
||||
"kernel_size", kernel_size, self.name, allow_four=False, ret_four=True)
|
||||
self.kernel_size = (1, self.kernel_size[-2], self.kernel_size[-1], 1)
|
||||
self.add_prim_attr("kernel_size", self.kernel_size)
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
out_shape = _Pool.infer_shape(self, x_shape)
|
||||
return out_shape, out_shape
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_dtype_valid("x", x_dtype, (mstype.float16, mstype.float32), self.name)
|
||||
argmax_dtype = mstype.int32
|
||||
return x_dtype, argmax_dtype
|
||||
self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=False, ret_four=True)
|
||||
self.strides = (1, self.strides[-2], self.strides[-1], 1)
|
||||
self.add_prim_attr("strides", self.strides)
|
||||
|
||||
|
||||
class MaxPool3D(PrimitiveWithInfer):
|
||||
|
|
Loading…
Reference in New Issue