fixed some codecheck warnings

huoxinyou 2022-07-28 10:02:43 +08:00
parent d24f4e65e9
commit 059a48ffdd
9 changed files with 102 additions and 122 deletions

View File

@@ -32,7 +32,9 @@ namespace {
 int64_t BNTrainingReduceGetAndCheckFormat(const PrimitivePtr &primitive, const ValuePtr &value) {
   int64_t data_format;
   bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
-  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
+  if (!result ||
+      (data_format != static_cast<int64_t>(Format::NHWC) && data_format != static_cast<int64_t>(Format::NCHW) &&
+       data_format != static_cast<int64_t>(Format::NCDHW))) {
     MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', data format must be NCHW, NHWC or NCDHW, but got "
                       << data_format << ".";
   }
@@ -51,7 +53,7 @@ abstract::TupleShapePtr BNTrainingReduceInferShape(const PrimitivePtr &primitive
   MS_EXCEPTION_IF_NULL(data_format_ptr);
   int64_t data_format = BNTrainingReduceGetAndCheckFormat(primitive, data_format_ptr);
   size_t c_axis = kInputIndex1;
-  if (data_format == Format::NHWC) {
+  if (data_format == static_cast<int64_t>(Format::NHWC)) {
     c_axis = kInputIndex3;
   }
   ShapeVector batch = {shape[c_axis]};
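Both hunks in this file fix the same pattern: data_format is an int64_t, while Format::NHWC, Format::NCHW and Format::NCDHW are enumerators, so the old comparisons leaned on an implicit enum-to-integer conversion that the codecheck tooling flags. A minimal standalone sketch of the idiom, with simplified names rather than the MindSpore sources:

#include <cstdint>
#include <iostream>

// Stand-in for mindspore::Format: a plain enum with a fixed underlying type.
enum Format : int64_t { NCHW = 0, NHWC = 1, NCDHW = 2 };

bool IsSupportedFormat(int64_t data_format) {
  // The explicit static_cast makes the enum-to-int64_t comparison
  // intentional, which is exactly what the hunks above add.
  return data_format == static_cast<int64_t>(Format::NCHW) ||
         data_format == static_cast<int64_t>(Format::NHWC) ||
         data_format == static_cast<int64_t>(Format::NCDHW);
}

int main() {
  std::cout << std::boolalpha << IsSupportedFormat(1) << '\n';  // true
  std::cout << IsSupportedFormat(7) << '\n';                    // false
}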

View File

@@ -33,7 +33,9 @@ constexpr auto kBNTrainingUpdateInputNum = 7;
 int64_t BNTrainingUpdateGetAndCheckFormat(const PrimitivePtr &primitive, const ValuePtr &value) {
   int64_t data_format;
   bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
-  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
+  if (!result ||
+      (data_format != static_cast<int64_t>(Format::NHWC) && data_format != static_cast<int64_t>(Format::NCHW) &&
+       data_format != static_cast<int64_t>(Format::NCDHW))) {
     MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', data format must be NCHW, NHWC and NCDHW, but got "
                       << data_format << ".";
   }
@@ -58,7 +60,7 @@ abstract::TupleShapePtr BNTrainingUpdateInferShape(const PrimitivePtr &primitive
   MS_EXCEPTION_IF_NULL(data_format_ptr);
   int64_t data_format = BNTrainingUpdateGetAndCheckFormat(primitive, data_format_ptr);
   size_t c_axis = kInputIndex1;
-  if (data_format == Format::NHWC) {
+  if (data_format == static_cast<int64_t>(Format::NHWC)) {
     c_axis = kInputIndex3;
   }
   // input_x rank must be equal to 4
@@ -107,19 +109,19 @@ TuplePtr BNTrainingUpdateInferType(const PrimitivePtr &primitive, const std::vec
   auto variance_type = input_args[kInputIndex6]->BuildType();
   const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   // input_x type must be valid
-  CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type, valid_types, prim_name);
   // sum type must be valid
-  CheckAndConvertUtils::CheckTensorTypeValid("sum type", sum_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("sum type", sum_type, valid_types, prim_name);
   // square_sum type must be valid
-  CheckAndConvertUtils::CheckTensorTypeValid("square_sum type", square_sum_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("square_sum type", square_sum_type, valid_types, prim_name);
   // scale type must be valid
-  CheckAndConvertUtils::CheckTensorTypeValid("scale_type", scale_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("scale_type", scale_type, valid_types, prim_name);
   // offset type must be valid
-  CheckAndConvertUtils::CheckTensorTypeValid("offset_type", offset_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("offset_type", offset_type, valid_types, prim_name);
   // mean type must be valid
-  CheckAndConvertUtils::CheckTensorTypeValid("mean_type", mean_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("mean_type", mean_type, valid_types, prim_name);
   // variance type must be valid
-  CheckAndConvertUtils::CheckTensorTypeValid("variance_type", variance_type, valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("variance_type", variance_type, valid_types, prim_name);
   return std::make_shared<Tuple>(
     std::vector<TypePtr>{input_x_type, variance_type, variance_type, variance_type, variance_type});
 }
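The (void) casts added above are the conventional way to tell a static checker that a return value is discarded deliberately rather than forgotten. A self-contained sketch of the idiom; ValidateType is a hypothetical stand-in, not the CheckAndConvertUtils API:

#include <cstdio>

// [[nodiscard]] makes the compiler itself complain about a silently
// dropped result, much like the codecheck rule targeted here.
[[nodiscard]] int ValidateType(const char *name) {
  std::printf("checking %s\n", name);
  return 0;  // status the caller may legitimately not need
}

int main() {
  (void)ValidateType("input_x");  // explicit: result intentionally ignored
}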

View File

@@ -16,14 +16,13 @@
 #include "ops/grad/bn_training_update_grad.h"
 #include <algorithm>
 #include <set>
 #include "ops/op_utils.h"
+#include "abstract/ops/primitive_infer_map.h"
+#include "utils/tensor_construct_utils.h"
 #include "utils/check_convert_utils.h"
 #include "mindapi/src/helper.h"
-#include "abstract/ops/primitive_infer_map.h"
-#include "utils/tensor_construct_utils.h"
 namespace mindspore {
 namespace ops {

View File

@@ -324,25 +324,23 @@ constexpr auto kSearchStep = "search_step";
 constexpr auto kWithOffset = "with_offset";
 constexpr auto kLinearSumAssignment = "linear_sum_assignment";
-enum Index : size_t {
-  kInputIndex0 = 0,
-  kInputIndex1,
-  kInputIndex2,
-  kInputIndex3,
-  kInputIndex4,
-  kInputIndex5,
-  kInputIndex6,
-  kInputIndex7,
-  kInputIndex8,
-  kInputIndex9,
-  kInputIndex10,
-  kInputIndex11,
-  kInputIndex12,
-  kInputIndex13,
-  kInputIndex14,
-  kInputIndex15,
-  kInputIndex16,
-};
+constexpr size_t kInputIndex0 = 0;
+constexpr size_t kInputIndex1 = 1;
+constexpr size_t kInputIndex2 = 2;
+constexpr size_t kInputIndex3 = 3;
+constexpr size_t kInputIndex4 = 4;
+constexpr size_t kInputIndex5 = 5;
+constexpr size_t kInputIndex6 = 6;
+constexpr size_t kInputIndex7 = 7;
+constexpr size_t kInputIndex8 = 8;
+constexpr size_t kInputIndex9 = 9;
+constexpr size_t kInputIndex10 = 10;
+constexpr size_t kInputIndex11 = 11;
+constexpr size_t kInputIndex12 = 12;
+constexpr size_t kInputIndex13 = 13;
+constexpr size_t kInputIndex14 = 14;
+constexpr size_t kInputIndex15 = 15;
+constexpr size_t kInputIndex16 = 16;
 enum Dims : size_t { kDim0 = 0, kDim1, kDim2, kDim3, kDim4, kDim5, kDim6, kDim7, kDim8 };
 }  // namespace mindspore::ops
 #endif  // MINDSPORE_CORE_OPS_OP_NAME_H
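Switching from enum Index : size_t to plain constexpr size_t constants removes an enum-to-integer conversion at every site where these names are used as container indices, at the cost of spelling out each value. A simplified sketch of why the constants are quieter under codecheck (not the real header):

#include <cstddef>
#include <vector>

// As an enumerator, kInputIndex1 would need a conversion (implicit or a
// cast) to become a size_t subscript; as a constexpr size_t it is
// already the right type at every use site.
constexpr size_t kInputIndex0 = 0;
constexpr size_t kInputIndex1 = 1;

int main() {
  std::vector<int> shape = {8, 3, 224, 224};
  return shape[kInputIndex1] == 3 ? 0 : 1;  // indexes with no cast
}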

View File

@@ -14,6 +14,6 @@
 # ============================================================================
 """ops utils."""
-from .utils import get_broadcast_shape, get_concat_offset, is_shape_unknown, is_shape_known
+from .utils import get_broadcast_shape, get_concat_offset, is_shape_unknown
 __all__ = ['get_broadcast_shape', 'get_concat_offset']

View File

@@ -146,14 +146,6 @@ def is_shape_unknown(shape):
     return False
-@constexpr
-def is_shape_known(shape):
-    for i in shape:
-        if i < 0:
-            return False
-    return True
 @constexpr
 def is_dim_unknown(shape):
     for i in shape:
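is_shape_known could be deleted because it is exactly the negation of is_shape_unknown, so every caller is rewritten as not is_shape_unknown(...) in the files below. A standalone check of that equivalence (plain functions; the body of is_shape_unknown is reconstructed from the tail shown above, and MindSpore's @constexpr decorator is omitted):

def is_shape_unknown(shape):
    for i in shape:
        if i < 0:
            return True
    return False

def is_shape_known(shape):
    for i in shape:
        if i < 0:
            return False
    return True

# The removed helper always agrees with the negation of the kept one.
for shape in ([2, 3], [-1, 3], [2, -1, 4], []):
    assert is_shape_known(shape) == (not is_shape_unknown(shape))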

View File

@@ -351,6 +351,7 @@ def _tensor_index_by_bool(data, bool_value):
 def check_range(x, dim_size):
     """Check whether x is within the range of dim_size"""
     tensor_x = const_utils.scalar_to_tensor(x)
     if tensor_x >= dim_size or tensor_x < -dim_size:
         return tensor_x
@@ -359,6 +360,7 @@ def check_range(x, dim_size):
 def get_stride_info_from_integer(tensor_int):
     """Convert integer to slice"""
     begin_strides = [tensor_int]
     end_strides = [tensor_int + const_utils.scalar_to_tensor(1)]
     step_strides = [const_utils.scalar_to_tensor(1)]

View File

@@ -26,7 +26,7 @@ from ... import context
 from ..._checkparam import Validator as validator, Rel
 from ...common import dtype as mstype
 from ...communication.management import GlobalComm
-from .._utils import is_shape_known
+from .._utils import is_shape_unknown
 class AbsGrad(PrimitiveWithInfer):
@@ -2123,7 +2123,7 @@ class SliceGrad(PrimitiveWithInfer):
     def __infer__(self, dy, x, begin, size):
         dy_shape, x_shape, size_value, begin_v = dy['shape'], x['shape'], size['value'], begin['value']
         dy_shape_len = len(dy_shape)
-        if (size_value is not None) and is_shape_known(x_shape):
+        if (size_value is not None) and not is_shape_unknown(x_shape):
             size_value = list(size_value)
             for i in range(dy_shape_len):
                 if size_value[i] == -1:

View File

@@ -26,7 +26,7 @@ from mindspore import log as logger
 from mindspore import context
 from mindspore.common.initializer import Zero
 from .. import signature as sig
-from .._utils import get_broadcast_shape, is_shape_unknown, is_shape_known
+from .._utils import get_broadcast_shape, is_shape_unknown
 from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
 from ..._checkparam import Rel
 from ..._checkparam import Validator as validator
@@ -1570,7 +1570,7 @@ class Fill(PrimitiveWithInfer):
                         mstype.complex128]
         validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name)
         x_nptype = mstype.dtype_to_nptype(dtype['value'])
-        if is_shape_known(dims['value']):
+        if not is_shape_unknown(dims['value']):
             for i, item in enumerate(dims['value']):
                 validator.check_positive_int(item, f'dims[{i}]', self.name)
             ret = np.full(dims['value'], x['value'], x_nptype)
@@ -2298,7 +2298,6 @@ class Tile(PrimitiveWithInfer):
             for a, b in zip(multiples_v_min, multiples_v_max):
                 if isinstance(a, (Tensor_, Tensor)):
                     a = a.asnumpy()
-                if isinstance(b, (Tensor_, Tensor)):
                     b = b.asnumpy()
                 if x_shp[i] >= 0:
                     x_shp[i] *= a
@@ -2411,7 +2410,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name)
         validator.check(f'rank of input_x', len(x_shp),
                         'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
-        if is_shape_known(x_shp) and is_shape_known(segment_ids_shp):
+        if not is_shape_unknown(x_shp) and not is_shape_unknown(segment_ids_shp):
             # only validate when both shapes fully known
             for i, value in enumerate(segment_ids_shp):
                 validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
@@ -2495,7 +2494,7 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
         num_segments_type = num_segments['dtype']
         validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
-        if is_shape_known(x_shape) and is_shape_known(segment_ids_shape):
+        if not is_shape_unknown(x_shape) and not is_shape_unknown(segment_ids_shape):
             # only validate when both shapes fully known
             validator.check(f'first shape of input_x', x_shape[0],
                             'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@@ -2603,7 +2602,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
         num_segments_type = num_segments['dtype']
         validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
-        if is_shape_known(x_shape) and is_shape_known(segment_ids_shape):
+        if not is_shape_unknown(x_shape) and not is_shape_unknown(segment_ids_shape):
             # only validate when both shapes fully known
             validator.check(f'first shape of input_x', x_shape[0],
                             'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
@@ -3457,63 +3456,6 @@ class StridedSlice(PrimitiveWithInfer):
         validator.check_non_negative_int(new_axis_mask, 'new_axis_mask', self.name)
         validator.check_non_negative_int(shrink_axis_mask, 'shrink_axis_mask', self.name)
-    def _check_and_get_value(self, slice_input, name):
-        """Check begin, end, strides. Get its length and value."""
-        slice_value = slice_input['value']
-        has_special_value = False
-        if "min_value" in slice_input and "max_value" in slice_input:
-            slice_min = slice_input["min_value"]
-            slice_max = slice_input["max_value"]
-            has_special_value = True
-        else:
-            slice_min = slice_value
-            slice_max = slice_value
-        if slice_value is None:
-            validator.check_tensor_dtype_valid(name, slice_input['dtype'], [mstype.int64], self.name)
-            slice_shape = slice_input['shape']
-            if len(slice_shape) != 1:
-                raise ValueError(f"For '{self.name}', both the 'begins', 'ends', and 'strides' must be 1-D, "
-                                 f"but got '{name}' shape: {slice_shape}.")
-            # not support scalar
-            slices = {
-                'value': slice_value,
-                'min_value': slice_min,
-                'max_value': slice_max
-            }
-            return slices, slice_shape[0], has_special_value
-        if isinstance(slice_value, Tensor_):
-            validator.check_tensor_dtype_valid(name, slice_input['dtype'], [mstype.int64], self.name)
-            slice_value = slice_value.asnumpy().tolist()
-        elif not isinstance(slice_value, tuple):
-            raise TypeError(f"For '{self.name}', both the 'begin', 'end', and 'strides' must be a tuple or Tensor, "
-                            f"but got '{name}': {slice_value}.")
-        if tuple(filter(lambda x: not isinstance(x, int), slice_value)):
-            raise TypeError(f"For '{self.name}', the elements of 'begin', 'end', and 'strides' must be int, "
-                            f"but got {name}: {slice_value}.")
-        if name == 'strides':
-            if slice_value is not None and tuple(filter(lambda x: x == 0, slice_value)):
-                raise ValueError(f"For '{self.name}', 'strides' cannot contain 0, but got 'strides': {slice_value}.")
-        slices = {
-            'value': slice_value,
-            'min_value': slice_min,
-            'max_value': slice_max
-        }
-        return slices, len(slice_value), has_special_value
-    def _check_and_get_shape(self, x):
-        """Check the shape of x. Get its shape and min/max_shape."""
-        x_shape = x['shape']
-        min_shape = None
-        max_shape = None
-        if "min_shape" in x and "max_shape" in x:
-            min_shape = x["min_shape"]
-            max_shape = x["max_shape"]
-        return x_shape, min_shape, max_shape
     def __infer__(self, x, begin, end, strides):
         x_shape, min_shape, max_shape = self._check_and_get_shape(x)
         begin_v, begin_len, begin_specical_value = self._check_and_get_value(begin, 'begin')
@@ -3528,7 +3470,7 @@ class StridedSlice(PrimitiveWithInfer):
         if begin_specical_value or end_specical_value:
             bd_has_min_max_value = True
-        if bd_has_min_max_value and is_shape_known(x_shape):
+        if bd_has_min_max_value and not is_shape_unknown(x_shape):
             ret_shape = [-1] * len(x_shape)
             ret_min_shape = list(x_shape)
             ret_max_shape = list(x_shape)
@@ -3557,9 +3499,6 @@ class StridedSlice(PrimitiveWithInfer):
             rets['min_shape'] = ret_min_shape
             rets['max_shape'] = ret_max_shape
-            if is_shape_known(x_shape):
-                return self._compute_max_min_shape(rets, x_shape, ret_shape)
             return rets
         ret_shape = self._compute_slicing_shape(x_shape, begin_v['value'], end_v['value'], strides_v['value'])
@@ -3591,17 +3530,16 @@ class StridedSlice(PrimitiveWithInfer):
                 'dtype': x['dtype'],
                 'value': value}
-    def _compute_max_min_shape(self, rets, x_shape, ret_shape):
-        """compute max/min shape"""
-        ret_min_shape = [1] * len(x_shape)
-        ret_max_shape = x_shape
-        for i, val in enumerate(ret_shape):
-            if val > 0:
-                ret_min_shape[i] = val
-                ret_max_shape[i] = val
-        rets['max_shape'] = ret_max_shape
-        rets['min_shape'] = ret_min_shape
-        return rets
+    @staticmethod
+    def _check_and_get_shape(x):
+        """Check the shape of x. Get its shape and min/max_shape."""
+        x_shape = x['shape']
+        min_shape = None
+        max_shape = None
+        if "min_shape" in x and "max_shape" in x:
+            min_shape = x["min_shape"]
+            max_shape = x["max_shape"]
+        return x_shape, min_shape, max_shape
     def _compute_slicing_shape(self, x_shape, begin_v, end_v, strides_v):
         """Computes the shape of the slicing."""
@@ -3738,6 +3676,53 @@ class StridedSlice(PrimitiveWithInfer):
                 j += 1
         return ret_shape, ret_min_shape, ret_max_shape
+    def _check_and_get_value(self, slice_input, name):
+        """Check begin, end, strides. Get its length and value."""
+        slice_value = slice_input['value']
+        has_special_value = False
+        if "min_value" in slice_input and "max_value" in slice_input:
+            slice_min = slice_input["min_value"]
+            slice_max = slice_input["max_value"]
+            has_special_value = True
+        else:
+            slice_min = slice_value
+            slice_max = slice_value
+        if slice_value is None:
+            validator.check_tensor_dtype_valid(name, slice_input['dtype'], [mstype.int64], self.name)
+            slice_shape = slice_input['shape']
+            if len(slice_shape) != 1:
+                raise ValueError(f"For '{self.name}', both the 'begins', 'ends', and 'strides' must be 1-D, "
+                                 f"but got '{name}' shape: {slice_shape}.")
+            # not support scalar
+            slices = {
+                'value': slice_value,
+                'min_value': slice_min,
+                'max_value': slice_max
+            }
+            return slices, slice_shape[0], has_special_value
+        if isinstance(slice_value, Tensor_):
+            validator.check_tensor_dtype_valid(name, slice_input['dtype'], [mstype.int64], self.name)
+            slice_value = slice_value.asnumpy().tolist()
+        elif not isinstance(slice_value, tuple):
+            raise TypeError(f"For '{self.name}', both the 'begin', 'end', and 'strides' must be a tuple or Tensor, "
+                            f"but got '{name}': {slice_value}.")
+        if tuple(filter(lambda x: not isinstance(x, int), slice_value)):
+            raise TypeError(f"For '{self.name}', the elements of 'begin', 'end', and 'strides' must be int, "
+                            f"but got {name}: {slice_value}.")
+        if name == 'strides':
+            if slice_value is not None and tuple(filter(lambda x: x == 0, slice_value)):
+                raise ValueError(f"For '{self.name}', 'strides' cannot contain 0, but got 'strides': {slice_value}.")
+        slices = {
+            'value': slice_value,
+            'min_value': slice_min,
+            'max_value': slice_max
+        }
+        return slices, len(slice_value), has_special_value
 class Diag(PrimitiveWithCheck):
     r"""