diff --git a/mindspore/python/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/python/mindspore/nn/probability/distribution/_utils/utils.py index 1f87f4dc904..cd7c464e909 100644 --- a/mindspore/python/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/python/mindspore/nn/probability/distribution/_utils/utils.py @@ -236,12 +236,6 @@ def raise_none_error(name): f" It can not be None since it is not specified during initialization.") -@constexpr -def raise_probs_logits_error(): - raise TypeError( - "Either 'probs' or 'logits' must be specified, but not both.") - - @constexpr def raise_broadcast_error(shape_a, shape_b): raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.") diff --git a/mindspore/python/mindspore/nn/transformer/layers.py b/mindspore/python/mindspore/nn/transformer/layers.py index 316409d4d61..6e73da129aa 100644 --- a/mindspore/python/mindspore/nn/transformer/layers.py +++ b/mindspore/python/mindspore/nn/transformer/layers.py @@ -200,28 +200,6 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name): Validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name) -@constexpr -def _check_input_shape(input_shape, param_name, func_name, target_len): - # check the input length - _LayerInputCheck.check_shape_length(input_shape, param_name, func_name, target_len) - - -@constexpr -def _check_shape_equal(input_shape, param_name, func_name, target_shape): - # check the input length - _LayerInputCheck.check_shape_equal(input_shape, param_name, func_name, target_shape) - - -@constexpr -def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_value): - _LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value) - - -@constexpr -def _check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape): - _LayerInputCheck.check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape) - - class _Dropout(nn.Cell): r""" A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training. @@ -707,17 +685,9 @@ class FixedSparseAttention(nn.Cell): self.slice1 = P.StridedSlice().shard(((dp, 1, 1),)) def construct(self, q, k, v, attention_mask): - _check_shape_equal(F.shape(q), "q", self.cls_name, - [self.batch_size, self.seq_length, self.hidden_size]) _check_input_dtype(F.dtype(q), "q", [mstype.float16], self.cls_name) - _check_shape_equal(F.shape(k), "k", self.cls_name, - [self.batch_size, self.seq_length, self.hidden_size]) _check_input_dtype(F.dtype(k), "k", [mstype.float16], self.cls_name) - _check_shape_equal(F.shape(v), "v", self.cls_name, - [self.batch_size, self.seq_length, self.hidden_size]) _check_input_dtype(F.dtype(v), "v", [mstype.float16], self.cls_name) - _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, - [self.batch_size, self.seq_length, self.seq_length]) _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name) q, k, v = self._transpose_inputs(q, k, v) diff --git a/mindspore/python/mindspore/nn/transformer/loss.py b/mindspore/python/mindspore/nn/transformer/loss.py index aa4b968c3bb..cc97dc6e030 100644 --- a/mindspore/python/mindspore/nn/transformer/loss.py +++ b/mindspore/python/mindspore/nn/transformer/loss.py @@ -30,7 +30,7 @@ from mindspore.context import ParallelMode from mindspore.parallel._utils import _get_device_num, _get_pipeline_stages from mindspore.log import _LogActionOnce from mindspore import log as logger -from mindspore.nn.transformer.layers import _check_input_dtype, _check_input_shape +from mindspore.nn.transformer.layers import _check_input_dtype from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig __all__ = ["CrossEntropyLoss"] @@ -247,7 +247,4 @@ class CrossEntropyLoss(Cell): _check_input_dtype(F.dtype(logits), "logits", [mstype.float32, mstype.float16], self.cls_name) _check_input_dtype(F.dtype(label), "label", [mstype.int32], self.cls_name) _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name) - _check_input_shape(F.shape(logits), "logits", self.cls_name, 2) - _check_input_shape(F.shape(label), "label", self.cls_name, 1) - _check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 1) return True diff --git a/mindspore/python/mindspore/nn/transformer/moe.py b/mindspore/python/mindspore/nn/transformer/moe.py index e74878efc01..27440676a64 100644 --- a/mindspore/python/mindspore/nn/transformer/moe.py +++ b/mindspore/python/mindspore/nn/transformer/moe.py @@ -18,7 +18,6 @@ Note: Mixture of Expert (MoE) structure. This is an experimental interface that from __future__ import absolute_import from __future__ import division -import math import numpy as np from mindspore.common.tensor import Tensor @@ -134,7 +133,9 @@ def _check_moe_config(moe_config=None, parallel_config=None): @constexpr def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim): - return math.ceil(k * tokens_per_group * capacity_factor / expert_dim) + res = k * tokens_per_group * capacity_factor / expert_dim + res_int = int(res) + return res_int if res < 0 or res == res_int else res_int + 1 class MoE(Cell): diff --git a/mindspore/python/mindspore/nn/transformer/transformer.py b/mindspore/python/mindspore/nn/transformer/transformer.py index dc46791e681..a07a5ac4921 100644 --- a/mindspore/python/mindspore/nn/transformer/transformer.py +++ b/mindspore/python/mindspore/nn/transformer/transformer.py @@ -35,10 +35,9 @@ from mindspore import log as logger from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation from mindspore.context import ParallelMode from mindspore.log import _LogActionOnce -from mindspore.nn.transformer.layers import _LayerNorm, _Linear, _check_input_shape, \ +from mindspore.nn.transformer.layers import _LayerNorm, _Linear, \ _args_type_validator_check, _valid_type_checks, _valid_value_checks, \ - _check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, \ - _check_shape_equal_without_batch + _check_past_none_input_none, _check_input_dtype from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, \ _Config, _check_config, MoEParallelConfig from mindspore.nn.transformer.moe import default_moe_config, MoE, _check_moe_config @@ -566,7 +565,6 @@ class FeedForward(Cell): self.cast = P.Cast() def construct(self, x): - _check_input_shape(F.shape(x), "x", self.cls_name, [2, 3]) _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name) x = self.cast(x, mstype.float16) # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size] @@ -639,9 +637,7 @@ class AttentionMask(Cell): self.multiply = P.Mul().shard(((parallel_config.data_parallel, 1, 1), (1, 1, 1))) def construct(self, input_mask): - _check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 2) _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name) - _check_input_shape_value(F.shape(input_mask), 1, "input_mask", self.cls_name, self.seq_length) input_mask = P.Cast()(self.not_equal(input_mask, 0), mstype.float16) input_shape = P.Shape()(input_mask) shape_right = (input_shape[0], 1, input_shape[1]) @@ -736,7 +732,6 @@ class VocabEmbedding(Cell): f"model parallel for the embedding lookup.") def construct(self, input_ids): - _check_input_shape(F.shape(input_ids), "input_ids", self.cls_name, 2) _check_input_dtype(F.dtype(input_ids), "input_ids", [mstype.int32], self.cls_name) output = self.gather(self.embedding_table, input_ids, 0) return output, self.embedding_table.value() @@ -1223,27 +1218,6 @@ class MultiHeadAttention(Cell): def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None, value_past=None, batch_valid_length=None): r"""Check inputs""" - if not self.use_past or (self.use_past and self.is_first_iteration): - _check_shape_equal_without_batch(F.shape(query_tensor), "query_tensor", self.cls_name, - [self.src_seq_length, self.hidden_size]) - _check_shape_equal_without_batch(F.shape(key_tensor), "key_tensor", self.cls_name, - [self.tgt_seq_length, self.hidden_size]) - _check_shape_equal_without_batch(F.shape(value_tensor), "value_tensor", self.cls_name, - [self.tgt_seq_length, self.hidden_size]) - if attention_mask is not None: - _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, - [F.shape(attention_mask)[0], self.src_seq_length, self.tgt_seq_length]) - else: - _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name, - [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]]) - _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name, - [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]]) - _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name, - [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]]) - if attention_mask is not None: - _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, - [[self.batch_size, 1, self.tgt_seq_length], [self.batch_size, self.hidden_size]]) - _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name) _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name) _check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name) @@ -1264,13 +1238,8 @@ class MultiHeadAttention(Cell): _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None, batch_valid_length_is_tensor, batch_is_default) if self.use_past: - _check_shape_equal(F.shape(key_past), "key_past", self.cls_name, - [self.batch_size, self.n_head, self.size_per_head, self.tgt_seq_length]) _check_input_dtype(F.dtype(key_past), "key_past", [mstype.float16], self.cls_name) - _check_shape_equal(F.shape(value_past), "value_past", self.cls_name, - [self.batch_size, self.n_head, self.tgt_seq_length, self.size_per_head]) _check_input_dtype(F.dtype(value_past), "value_past", [mstype.float16], self.cls_name) - _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size]) _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name) return True @@ -1770,17 +1739,6 @@ class TransformerEncoderLayer(Cell): def _check_input(self, x, input_mask, init_reset, batch_valid_length): r"""Check inputs""" - if not self.use_past or (self.use_past and self.is_first_iteration): - _check_shape_equal_without_batch(F.shape(x), "x", self.cls_name, - [self.seq_length, self.hidden_size]) - if input_mask is not None: - _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name, - [F.shape(input_mask)[0], self.seq_length, self.seq_length]) - else: - _check_shape_equal(F.shape(x), "x", self.cls_name, [self.batch_size, 1, self.hidden_size]) - if input_mask is not None: - _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name, - [F.shape(input_mask)[0], 1, self.seq_length]) _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name) if input_mask is not None: _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name) @@ -1795,9 +1753,7 @@ class TransformerEncoderLayer(Cell): batch_valid_length_is_tensor, batch_is_default) if self.use_past: - _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1]) _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name) - _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size]) _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name) return True @@ -2226,31 +2182,14 @@ class TransformerDecoderLayer(Cell): def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length): r"""Check inputs""" - if not self.use_past or (self.use_past and self.is_first_iteration): - _check_shape_equal_without_batch(F.shape(hidden_states), "hidden_states", self.cls_name, - [self.tgt_seq_length, self.hidden_size]) - if attention_mask is not None: - _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, - [F.shape(attention_mask)[0], self.tgt_seq_length, self.tgt_seq_length]) - - else: - _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name, - [self.batch_size, 1, self.hidden_size]) - if attention_mask is not None: - _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name, - [self.batch_size, 1, self.tgt_seq_length]) _check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name) if attention_mask is not None: _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name) if encoder_output is not None: - _check_shape_equal_without_batch(F.shape(encoder_output), "encoder_output", self.cls_name, - [self.src_seq_length, self.hidden_size]) _check_input_dtype(F.dtype(encoder_output), "encoder_output", [mstype.float32, mstype.float16], self.cls_name) if memory_mask is not None: - _check_shape_equal_without_batch(F.shape(memory_mask), "memory_mask", self.cls_name, - [self.tgt_seq_length, self.src_seq_length]) _check_input_dtype(F.dtype(memory_mask), "memory_mask", [mstype.float32, mstype.float16], self.cls_name) @@ -2264,9 +2203,7 @@ class TransformerDecoderLayer(Cell): batch_valid_length_is_tensor, batch_is_default) if self.use_past: - _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1]) _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name) - _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size]) _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name) return True diff --git a/mindspore/python/mindspore/nn/wrap/cell_wrapper.py b/mindspore/python/mindspore/nn/wrap/cell_wrapper.py index dbc5b0f72cc..1af0544fad6 100644 --- a/mindspore/python/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/python/mindspore/nn/wrap/cell_wrapper.py @@ -27,7 +27,6 @@ from mindspore._checkparam import Validator as validator from mindspore import ops, nn from mindspore.common import dtype as mstype from mindspore.common.parameter import Parameter, ParameterTuple -from mindspore.ops.primitive import constexpr from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.ops import operations as P @@ -467,14 +466,6 @@ class _VirtualDatasetCell(Cell): return self._backbone(*output) -@constexpr -def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size): - if input_shape[0] % micro_size != 0: - raise ValueError(f"For micro batch initialization, the 0th dimension shape of input({input_shape[0]}) must be " - f"divided by micro size({micro_size})") - return True - - class _MicroBatch(Cell): """ transform mini-batch to micro-batch in pipeline parallel. @@ -493,7 +484,6 @@ class _MicroBatch(Cell): micro_inputs = () for each_input in inputs: input_shape = self.shape(each_input) - _check_shape_value_on_axis_divided_by_target_value(input_shape, self.micro_size) micro_batch_begin = i * input_shape[0] // self.micro_size micro_batch_end = (i + 1) * input_shape[0] // self.micro_size strided_slice_begin = (micro_batch_begin,) diff --git a/mindspore/python/mindspore/numpy/math_ops.py b/mindspore/python/mindspore/numpy/math_ops.py index 83738a8999e..b4c92f92c7a 100644 --- a/mindspore/python/mindspore/numpy/math_ops.py +++ b/mindspore/python/mindspore/numpy/math_ops.py @@ -39,7 +39,7 @@ from mindspore.numpy.array_ops import ravel, expand_dims, moveaxis, concatenate, split from mindspore.numpy.utils_const import _infer_out_shape, _check_axis_valid, _get_device, \ - _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \ + _raise_type_error, _check_same_type, _check_is_float, \ _raise_value_error, _promote, _check_axis_type, _canonicalize_axis, \ _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range, \ _check_dtype, _list_comprehensions, _tuple_setitem, _add_unit_axes, _seq_prod, \ @@ -682,7 +682,6 @@ def inner(a, b): if F.rank(a) == 0 or F.rank(b) == 0: return F.tensor_mul(a, b) - _check_shape_aligned(F.shape(a), F.shape(b)) aligned_shape_a = (F.shape_mul(F.shape(a)[:-1]), F.shape(a)[-1]) aligned_shape_b = (F.shape_mul(F.shape(b)[:-1]), F.shape(a)[-1]) a_aligned = F.reshape(a, aligned_shape_a) diff --git a/mindspore/python/mindspore/numpy/utils_const.py b/mindspore/python/mindspore/numpy/utils_const.py index 1ccf2caf157..dc0da4f70a6 100644 --- a/mindspore/python/mindspore/numpy/utils_const.py +++ b/mindspore/python/mindspore/numpy/utils_const.py @@ -30,15 +30,12 @@ from mindspore._checkparam import Validator as validator from mindspore.numpy.dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric - _check_axis_type = constexpr(validator.check_axis_type) @constexpr def _check_shape(shape): """check the shape param to match the numpy style""" - if not isinstance(shape, (int, tuple, list, Tensor, typing.Tuple, typing.List)): - raise TypeError(f"only int, tuple, list and tensor are allowed for shape, but got {type(shape)}") # convert tensor to int/list, use followed if statements to do further conversions if isinstance(shape, Tensor): shape = shape.asnumpy().tolist() @@ -47,11 +44,6 @@ def _check_shape(shape): shape = (shape,) elif isinstance(shape, (list, typing.List)): shape = tuple(shape) - for s in shape: - if not isinstance(s, int): - raise TypeError("each entry in shape should be int.") - if s < 0: - raise ValueError("each entry in shape should no less than 0.") return shape @@ -89,8 +81,6 @@ def _is_shape_empty(shp): @constexpr def _check_start_normalize(start, ndim): """check and normalize start argument for rollaxis.""" - if start < -ndim or start > ndim: - raise ValueError(f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.") if start < 0: start = start + ndim return start @@ -112,9 +102,6 @@ def _check_axes_range(axes, ndim): TypeError: If the axes are not integer, tuple(int) or list(int). ValueError: If duplicate axes exists or some axis is out of bounds. """ - _check_axis_type(axes, True, True, True) - if isinstance(axes, (list, tuple)): - _check_element_int(axes) axes = _canonicalize_axis(axes, ndim) return axes @@ -125,14 +112,12 @@ def _get_device(): return context.get_context('device_target') -#remove constexpr def _infer_out_shape(*shapes): """ Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast. """ shape_out = list() max_len = max([len(it) for it in shapes]) - for i in range(max_len): items = [it[i-max_len+len(it)] if i-max_len + len(it) >= 0 else 1 for it in shapes] @@ -146,23 +131,14 @@ def _can_broadcast(*shapes): """ Returns Ture if shapes can broadcast, False if they cannot. """ - try: - _infer_out_shape(*shapes) - except ValueError: - return False - finally: - pass + _infer_out_shape(*shapes) return True @constexpr def _check_axis_in_range(axis, ndim): """Checks axes are with the bounds of ndim""" - if not isinstance(axis, int): - raise TypeError(f'axes should be integers, not {type(axis)}') - if not -ndim <= axis < ndim: - raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}') - return axis % ndim + return axis - axis // ndim * ndim @constexpr @@ -176,19 +152,10 @@ def _check_axis_valid(axes, ndim): return axes if isinstance(axes, (tuple, list)): axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes)) - if any(axes.count(el) > 1 for el in axes): - raise ValueError('duplicate value in "axis"') return axes return (_check_axis_in_range(axes, ndim),) -@constexpr -def _check_shape_aligned(shape1, shape2): - """Checks shape1 and shape2 are valid shapes to perform inner product""" - if shape1[-1] != shape2[-1]: - raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)') - - @constexpr def _tile_size(shape, out_shape, ndim): """Returns tile_size such that shape*tile_size = out_shape""" @@ -341,10 +308,31 @@ def _canonicalize_axis(axis, ndim): def canonicalizer(ax): return ax + ndim if ax < 0 else ax + def _sort_axis(ax): + def merge(left, right): + result = [] + i = j = 0 + while i < len(left) and j < len(right): + if left[i] <= right[j]: + result.append(left[i]) + i += 1 + else: + result.append(right[j]) + j += 1 + + return result + left[i:] + right[j:] + + if len(ax) <= 1: + return ax + + middle = len(ax) // 2 + left = _sort_axis(ax[:middle]) + right = _sort_axis(ax[middle:]) + + return merge(left, right) + axis = tuple([canonicalizer(axis) for axis in axis]) - if all(axis.count(el) <= 1 for el in axis): - return tuple(sorted(axis)) if len(axis) > 1 else axis[0] - raise ValueError(f"duplicate axes in {axis}.") + return tuple(_sort_axis(axis)) if len(axis) > 1 else axis[0] @constexpr diff --git a/tests/st/ops/test_constexpr_modfied.py b/tests/st/ops/test_constexpr_modfied.py index 3c059feb9a3..93e4882141c 100644 --- a/tests/st/ops/test_constexpr_modfied.py +++ b/tests/st/ops/test_constexpr_modfied.py @@ -209,3 +209,75 @@ def test_flip(): net = Net() output = net(x) assert np.allclose(output.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_calculate_expert_capacity(): + """ + Feature: calculate_expert_capacity func + Description: Verify the result of calculate_expert_capacity + Expectation: success + """ + from mindspore.nn.transformer.moe import calculate_expert_capacity + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.func = calculate_expert_capacity + + def construct(self, k, tokens_per_group, capacity_factor, expert_dim): + return self.func(k, tokens_per_group, capacity_factor, expert_dim) + net = Net() + assert net(10.1, 2.0, 3.3, 4) == 17 + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_infer_out_shape(): + """ + Feature: _infer_out_shape func + Description: Verify the result of _infer_out_shape + Expectation: success + """ + from mindspore.numpy.utils_const import _infer_out_shape + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.func = _infer_out_shape + + def construct(self, *shape): + return self.func(*shape) + net = Net() + assert net((5,), (6, 1), (7, 1, 5), (8, 1, 6, 1)) == (8, 7, 6, 5) + + +@pytest.mark.level1 +@pytest.mark.platform_x86_cpu +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_canonicalize_axis(): + """ + Feature: _canonicalize_axis func + Description: Verify the result of _canonicalize_axis + Expectation: success + """ + from mindspore.numpy.utils_const import _canonicalize_axis + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.func = _canonicalize_axis + + def construct(self, axis, ndim): + return self.func(axis, ndim) + net = Net() + assert net(0, 2) == 0