fixed some codecheck
This commit is contained in:
parent
d24f4e65e9
commit
059a48ffdd
|
@ -32,7 +32,9 @@ namespace {
|
|||
int64_t BNTrainingReduceGetAndCheckFormat(const PrimitivePtr &primitive, const ValuePtr &value) {
|
||||
int64_t data_format;
|
||||
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
|
||||
if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
|
||||
if (!result ||
|
||||
(data_format != static_cast<int64_t>(Format::NHWC) && data_format != static_cast<int64_t>(Format::NCHW) &&
|
||||
data_format != static_cast<int64_t>(Format::NCDHW))) {
|
||||
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', data format must be NCHW, NHWC or NCDHW, but got "
|
||||
<< data_format << ".";
|
||||
}
|
||||
|
@ -51,7 +53,7 @@ abstract::TupleShapePtr BNTrainingReduceInferShape(const PrimitivePtr &primitive
|
|||
MS_EXCEPTION_IF_NULL(data_format_ptr);
|
||||
int64_t data_format = BNTrainingReduceGetAndCheckFormat(primitive, data_format_ptr);
|
||||
size_t c_axis = kInputIndex1;
|
||||
if (data_format == Format::NHWC) {
|
||||
if (data_format == static_cast<int64_t>(Format::NHWC)) {
|
||||
c_axis = kInputIndex3;
|
||||
}
|
||||
ShapeVector batch = {shape[c_axis]};
|
||||
|
|
|
@ -33,7 +33,9 @@ constexpr auto kBNTrainingUpdateInputNum = 7;
|
|||
int64_t BNTrainingUpdateGetAndCheckFormat(const PrimitivePtr &primitive, const ValuePtr &value) {
|
||||
int64_t data_format;
|
||||
bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
|
||||
if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
|
||||
if (!result ||
|
||||
(data_format != static_cast<int64_t>(Format::NHWC) && data_format != static_cast<int64_t>(Format::NCHW) &&
|
||||
data_format != static_cast<int64_t>(Format::NCDHW))) {
|
||||
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', data format must be NCHW, NHWC and NCDHW, but got "
|
||||
<< data_format << ".";
|
||||
}
|
||||
|
@ -58,7 +60,7 @@ abstract::TupleShapePtr BNTrainingUpdateInferShape(const PrimitivePtr &primitive
|
|||
MS_EXCEPTION_IF_NULL(data_format_ptr);
|
||||
int64_t data_format = BNTrainingUpdateGetAndCheckFormat(primitive, data_format_ptr);
|
||||
size_t c_axis = kInputIndex1;
|
||||
if (data_format == Format::NHWC) {
|
||||
if (data_format == static_cast<int64_t>(Format::NHWC)) {
|
||||
c_axis = kInputIndex3;
|
||||
}
|
||||
// input_x rank must be equal to 4
|
||||
|
@ -107,19 +109,19 @@ TuplePtr BNTrainingUpdateInferType(const PrimitivePtr &primitive, const std::vec
|
|||
auto variance_type = input_args[kInputIndex6]->BuildType();
|
||||
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
|
||||
// input_x type must be valid
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("input_x type", input_x_type, valid_types, prim_name);
|
||||
// sum type must be valid
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("sum type", sum_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("sum type", sum_type, valid_types, prim_name);
|
||||
// square_sum type must be valid
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("square_sum type", square_sum_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("square_sum type", square_sum_type, valid_types, prim_name);
|
||||
// scale type must be valid
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("scale_type", scale_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("scale_type", scale_type, valid_types, prim_name);
|
||||
// offset type must be valid
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("offset_type", offset_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("offset_type", offset_type, valid_types, prim_name);
|
||||
// mean type must be valid
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("mean_type", mean_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("mean_type", mean_type, valid_types, prim_name);
|
||||
// variance type must be valid
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("variance_type", variance_type, valid_types, prim_name);
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("variance_type", variance_type, valid_types, prim_name);
|
||||
return std::make_shared<Tuple>(
|
||||
std::vector<TypePtr>{input_x_type, variance_type, variance_type, variance_type, variance_type});
|
||||
}
|
||||
|
|
|
@ -16,14 +16,13 @@
|
|||
|
||||
#include "ops/grad/bn_training_update_grad.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
|
||||
#include "ops/op_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "utils/tensor_construct_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
|
|
|
@ -324,25 +324,23 @@ constexpr auto kSearchStep = "search_step";
|
|||
constexpr auto kWithOffset = "with_offset";
|
||||
constexpr auto kLinearSumAssignment = "linear_sum_assignment";
|
||||
|
||||
enum Index : size_t {
|
||||
kInputIndex0 = 0,
|
||||
kInputIndex1,
|
||||
kInputIndex2,
|
||||
kInputIndex3,
|
||||
kInputIndex4,
|
||||
kInputIndex5,
|
||||
kInputIndex6,
|
||||
kInputIndex7,
|
||||
kInputIndex8,
|
||||
kInputIndex9,
|
||||
kInputIndex10,
|
||||
kInputIndex11,
|
||||
kInputIndex12,
|
||||
kInputIndex13,
|
||||
kInputIndex14,
|
||||
kInputIndex15,
|
||||
kInputIndex16,
|
||||
};
|
||||
constexpr size_t kInputIndex0 = 0;
|
||||
constexpr size_t kInputIndex1 = 1;
|
||||
constexpr size_t kInputIndex2 = 2;
|
||||
constexpr size_t kInputIndex3 = 3;
|
||||
constexpr size_t kInputIndex4 = 4;
|
||||
constexpr size_t kInputIndex5 = 5;
|
||||
constexpr size_t kInputIndex6 = 6;
|
||||
constexpr size_t kInputIndex7 = 7;
|
||||
constexpr size_t kInputIndex8 = 8;
|
||||
constexpr size_t kInputIndex9 = 9;
|
||||
constexpr size_t kInputIndex10 = 10;
|
||||
constexpr size_t kInputIndex11 = 11;
|
||||
constexpr size_t kInputIndex12 = 12;
|
||||
constexpr size_t kInputIndex13 = 13;
|
||||
constexpr size_t kInputIndex14 = 14;
|
||||
constexpr size_t kInputIndex15 = 15;
|
||||
constexpr size_t kInputIndex16 = 16;
|
||||
enum Dims : size_t { kDim0 = 0, kDim1, kDim2, kDim3, kDim4, kDim5, kDim6, kDim7, kDim8 };
|
||||
} // namespace mindspore::ops
|
||||
#endif // MINDSPORE_CORE_OPS_OP_NAME_H
|
||||
|
|
|
@ -14,6 +14,6 @@
|
|||
# ============================================================================
|
||||
|
||||
"""ops utils."""
|
||||
from .utils import get_broadcast_shape, get_concat_offset, is_shape_unknown, is_shape_known
|
||||
from .utils import get_broadcast_shape, get_concat_offset, is_shape_unknown
|
||||
|
||||
__all__ = ['get_broadcast_shape', 'get_concat_offset']
|
||||
|
|
|
@ -146,14 +146,6 @@ def is_shape_unknown(shape):
|
|||
return False
|
||||
|
||||
|
||||
@constexpr
|
||||
def is_shape_known(shape):
|
||||
for i in shape:
|
||||
if i < 0:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@constexpr
|
||||
def is_dim_unknown(shape):
|
||||
for i in shape:
|
||||
|
|
|
@ -351,6 +351,7 @@ def _tensor_index_by_bool(data, bool_value):
|
|||
|
||||
|
||||
def check_range(x, dim_size):
|
||||
"""Check whether x is within the range of dim_size"""
|
||||
tensor_x = const_utils.scalar_to_tensor(x)
|
||||
if tensor_x >= dim_size or tensor_x < -dim_size:
|
||||
return tensor_x
|
||||
|
@ -359,6 +360,7 @@ def check_range(x, dim_size):
|
|||
|
||||
|
||||
def get_stride_info_from_integer(tensor_int):
|
||||
"""Convert integer to slice"""
|
||||
begin_strides = [tensor_int]
|
||||
end_strides = [tensor_int + const_utils.scalar_to_tensor(1)]
|
||||
step_strides = [const_utils.scalar_to_tensor(1)]
|
||||
|
|
|
@ -26,7 +26,7 @@ from ... import context
|
|||
from ..._checkparam import Validator as validator, Rel
|
||||
from ...common import dtype as mstype
|
||||
from ...communication.management import GlobalComm
|
||||
from .._utils import is_shape_known
|
||||
from .._utils import is_shape_unknown
|
||||
|
||||
|
||||
class AbsGrad(PrimitiveWithInfer):
|
||||
|
@ -2123,7 +2123,7 @@ class SliceGrad(PrimitiveWithInfer):
|
|||
def __infer__(self, dy, x, begin, size):
|
||||
dy_shape, x_shape, size_value, begin_v = dy['shape'], x['shape'], size['value'], begin['value']
|
||||
dy_shape_len = len(dy_shape)
|
||||
if (size_value is not None) and is_shape_known(x_shape):
|
||||
if (size_value is not None) and not is_shape_unknown(x_shape):
|
||||
size_value = list(size_value)
|
||||
for i in range(dy_shape_len):
|
||||
if size_value[i] == -1:
|
||||
|
|
|
@ -26,7 +26,7 @@ from mindspore import log as logger
|
|||
from mindspore import context
|
||||
from mindspore.common.initializer import Zero
|
||||
from .. import signature as sig
|
||||
from .._utils import get_broadcast_shape, is_shape_unknown, is_shape_known
|
||||
from .._utils import get_broadcast_shape, is_shape_unknown
|
||||
from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
||||
from ..._checkparam import Rel
|
||||
from ..._checkparam import Validator as validator
|
||||
|
@ -1570,7 +1570,7 @@ class Fill(PrimitiveWithInfer):
|
|||
mstype.complex128]
|
||||
validator.check_types_same_and_valid({"value": dtype['value']}, valid_dtypes, self.name)
|
||||
x_nptype = mstype.dtype_to_nptype(dtype['value'])
|
||||
if is_shape_known(dims['value']):
|
||||
if not is_shape_unknown(dims['value']):
|
||||
for i, item in enumerate(dims['value']):
|
||||
validator.check_positive_int(item, f'dims[{i}]', self.name)
|
||||
ret = np.full(dims['value'], x['value'], x_nptype)
|
||||
|
@ -2298,7 +2298,6 @@ class Tile(PrimitiveWithInfer):
|
|||
for a, b in zip(multiples_v_min, multiples_v_max):
|
||||
if isinstance(a, (Tensor_, Tensor)):
|
||||
a = a.asnumpy()
|
||||
if isinstance(b, (Tensor_, Tensor)):
|
||||
b = b.asnumpy()
|
||||
if x_shp[i] >= 0:
|
||||
x_shp[i] *= a
|
||||
|
@ -2411,7 +2410,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
|
|||
validator.check_positive_int(segment_ids_shp_len, "rank of segment_ids", self.name)
|
||||
validator.check(f'rank of input_x', len(x_shp),
|
||||
'rank of segments_id', len(segment_ids_shp), Rel.GE, self.name)
|
||||
if is_shape_known(x_shp) and is_shape_known(segment_ids_shp):
|
||||
if not is_shape_unknown(x_shp) and not is_shape_unknown(segment_ids_shp):
|
||||
# only validate when both shapes fully known
|
||||
for i, value in enumerate(segment_ids_shp):
|
||||
validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name)
|
||||
|
@ -2495,7 +2494,7 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
|
|||
|
||||
num_segments_type = num_segments['dtype']
|
||||
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
|
||||
if is_shape_known(x_shape) and is_shape_known(segment_ids_shape):
|
||||
if not is_shape_unknown(x_shape) and not is_shape_unknown(segment_ids_shape):
|
||||
# only validate when both shapes fully known
|
||||
validator.check(f'first shape of input_x', x_shape[0],
|
||||
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
|
||||
|
@ -2603,7 +2602,7 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
|
|||
|
||||
num_segments_type = num_segments['dtype']
|
||||
validator.check_subclass("num_segments", num_segments_type, [mstype.number], self.name)
|
||||
if is_shape_known(x_shape) and is_shape_known(segment_ids_shape):
|
||||
if not is_shape_unknown(x_shape) and not is_shape_unknown(segment_ids_shape):
|
||||
# only validate when both shapes fully known
|
||||
validator.check(f'first shape of input_x', x_shape[0],
|
||||
'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
|
||||
|
@ -3457,63 +3456,6 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
validator.check_non_negative_int(new_axis_mask, 'new_axis_mask', self.name)
|
||||
validator.check_non_negative_int(shrink_axis_mask, 'shrink_axis_mask', self.name)
|
||||
|
||||
def _check_and_get_value(self, slice_input, name):
|
||||
"""Check begin, end, strides. Get its length and value."""
|
||||
slice_value = slice_input['value']
|
||||
has_special_value = False
|
||||
if "min_value" in slice_input and "max_value" in slice_input:
|
||||
slice_min = slice_input["min_value"]
|
||||
slice_max = slice_input["max_value"]
|
||||
has_special_value = True
|
||||
else:
|
||||
slice_min = slice_value
|
||||
slice_max = slice_value
|
||||
if slice_value is None:
|
||||
validator.check_tensor_dtype_valid(name, slice_input['dtype'], [mstype.int64], self.name)
|
||||
slice_shape = slice_input['shape']
|
||||
if len(slice_shape) != 1:
|
||||
raise ValueError(f"For '{self.name}', both the 'begins', 'ends', and 'strides' must be 1-D, "
|
||||
f"but got '{name}' shape: {slice_shape}.")
|
||||
# not support scalar
|
||||
slices = {
|
||||
'value': slice_value,
|
||||
'min_value': slice_min,
|
||||
'max_value': slice_max
|
||||
}
|
||||
return slices, slice_shape[0], has_special_value
|
||||
|
||||
if isinstance(slice_value, Tensor_):
|
||||
validator.check_tensor_dtype_valid(name, slice_input['dtype'], [mstype.int64], self.name)
|
||||
slice_value = slice_value.asnumpy().tolist()
|
||||
elif not isinstance(slice_value, tuple):
|
||||
raise TypeError(f"For '{self.name}', both the 'begin', 'end', and 'strides' must be a tuple or Tensor, "
|
||||
f"but got '{name}': {slice_value}.")
|
||||
|
||||
if tuple(filter(lambda x: not isinstance(x, int), slice_value)):
|
||||
raise TypeError(f"For '{self.name}', the elements of 'begin', 'end', and 'strides' must be int, "
|
||||
f"but got {name}: {slice_value}.")
|
||||
|
||||
if name == 'strides':
|
||||
if slice_value is not None and tuple(filter(lambda x: x == 0, slice_value)):
|
||||
raise ValueError(f"For '{self.name}', 'strides' cannot contain 0, but got 'strides': {slice_value}.")
|
||||
|
||||
slices = {
|
||||
'value': slice_value,
|
||||
'min_value': slice_min,
|
||||
'max_value': slice_max
|
||||
}
|
||||
return slices, len(slice_value), has_special_value
|
||||
|
||||
def _check_and_get_shape(self, x):
|
||||
"""Check the shape of x. Get its shape and min/max_shape."""
|
||||
x_shape = x['shape']
|
||||
min_shape = None
|
||||
max_shape = None
|
||||
if "min_shape" in x and "max_shape" in x:
|
||||
min_shape = x["min_shape"]
|
||||
max_shape = x["max_shape"]
|
||||
return x_shape, min_shape, max_shape
|
||||
|
||||
def __infer__(self, x, begin, end, strides):
|
||||
x_shape, min_shape, max_shape = self._check_and_get_shape(x)
|
||||
begin_v, begin_len, begin_specical_value = self._check_and_get_value(begin, 'begin')
|
||||
|
@ -3528,7 +3470,7 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
if begin_specical_value or end_specical_value:
|
||||
bd_has_min_max_value = True
|
||||
|
||||
if bd_has_min_max_value and is_shape_known(x_shape):
|
||||
if bd_has_min_max_value and not is_shape_unknown(x_shape):
|
||||
ret_shape = [-1] * len(x_shape)
|
||||
ret_min_shape = list(x_shape)
|
||||
ret_max_shape = list(x_shape)
|
||||
|
@ -3557,9 +3499,6 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
rets['min_shape'] = ret_min_shape
|
||||
rets['max_shape'] = ret_max_shape
|
||||
|
||||
if is_shape_known(x_shape):
|
||||
return self._compute_max_min_shape(rets, x_shape, ret_shape)
|
||||
|
||||
return rets
|
||||
|
||||
ret_shape = self._compute_slicing_shape(x_shape, begin_v['value'], end_v['value'], strides_v['value'])
|
||||
|
@ -3591,17 +3530,16 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
'dtype': x['dtype'],
|
||||
'value': value}
|
||||
|
||||
def _compute_max_min_shape(self, rets, x_shape, ret_shape):
|
||||
"""compute max/min shape"""
|
||||
ret_min_shape = [1] * len(x_shape)
|
||||
ret_max_shape = x_shape
|
||||
for i, val in enumerate(ret_shape):
|
||||
if val > 0:
|
||||
ret_min_shape[i] = val
|
||||
ret_max_shape[i] = val
|
||||
rets['max_shape'] = ret_max_shape
|
||||
rets['min_shape'] = ret_min_shape
|
||||
return rets
|
||||
@staticmethod
|
||||
def _check_and_get_shape(x):
|
||||
"""Check the shape of x. Get its shape and min/max_shape."""
|
||||
x_shape = x['shape']
|
||||
min_shape = None
|
||||
max_shape = None
|
||||
if "min_shape" in x and "max_shape" in x:
|
||||
min_shape = x["min_shape"]
|
||||
max_shape = x["max_shape"]
|
||||
return x_shape, min_shape, max_shape
|
||||
|
||||
def _compute_slicing_shape(self, x_shape, begin_v, end_v, strides_v):
|
||||
"""Computes the shape of the slicing."""
|
||||
|
@ -3738,6 +3676,53 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
j += 1
|
||||
return ret_shape, ret_min_shape, ret_max_shape
|
||||
|
||||
def _check_and_get_value(self, slice_input, name):
|
||||
"""Check begin, end, strides. Get its length and value."""
|
||||
slice_value = slice_input['value']
|
||||
has_special_value = False
|
||||
if "min_value" in slice_input and "max_value" in slice_input:
|
||||
slice_min = slice_input["min_value"]
|
||||
slice_max = slice_input["max_value"]
|
||||
has_special_value = True
|
||||
else:
|
||||
slice_min = slice_value
|
||||
slice_max = slice_value
|
||||
if slice_value is None:
|
||||
validator.check_tensor_dtype_valid(name, slice_input['dtype'], [mstype.int64], self.name)
|
||||
slice_shape = slice_input['shape']
|
||||
if len(slice_shape) != 1:
|
||||
raise ValueError(f"For '{self.name}', both the 'begins', 'ends', and 'strides' must be 1-D, "
|
||||
f"but got '{name}' shape: {slice_shape}.")
|
||||
# not support scalar
|
||||
slices = {
|
||||
'value': slice_value,
|
||||
'min_value': slice_min,
|
||||
'max_value': slice_max
|
||||
}
|
||||
return slices, slice_shape[0], has_special_value
|
||||
|
||||
if isinstance(slice_value, Tensor_):
|
||||
validator.check_tensor_dtype_valid(name, slice_input['dtype'], [mstype.int64], self.name)
|
||||
slice_value = slice_value.asnumpy().tolist()
|
||||
elif not isinstance(slice_value, tuple):
|
||||
raise TypeError(f"For '{self.name}', both the 'begin', 'end', and 'strides' must be a tuple or Tensor, "
|
||||
f"but got '{name}': {slice_value}.")
|
||||
|
||||
if tuple(filter(lambda x: not isinstance(x, int), slice_value)):
|
||||
raise TypeError(f"For '{self.name}', the elements of 'begin', 'end', and 'strides' must be int, "
|
||||
f"but got {name}: {slice_value}.")
|
||||
|
||||
if name == 'strides':
|
||||
if slice_value is not None and tuple(filter(lambda x: x == 0, slice_value)):
|
||||
raise ValueError(f"For '{self.name}', 'strides' cannot contain 0, but got 'strides': {slice_value}.")
|
||||
|
||||
slices = {
|
||||
'value': slice_value,
|
||||
'min_value': slice_min,
|
||||
'max_value': slice_max
|
||||
}
|
||||
return slices, len(slice_value), has_special_value
|
||||
|
||||
|
||||
class Diag(PrimitiveWithCheck):
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue