!47595 remove numpy constexpr

Merge pull request !47595 from hujiahui8/constexpr
i-robot 2023-01-09 04:02:21 +00:00 committed by Gitee
commit 92c63200f1
9 changed files with 105 additions and 157 deletions

View File

@ -236,12 +236,6 @@ def raise_none_error(name):
f" It can not be None since it is not specified during initialization.")
@constexpr
def raise_probs_logits_error():
raise TypeError(
"Either 'probs' or 'logits' must be specified, but not both.")
@constexpr
def raise_broadcast_error(shape_a, shape_b):
raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.")

View File

@ -200,28 +200,6 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
Validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
@constexpr
def _check_input_shape(input_shape, param_name, func_name, target_len):
# check the input length
_LayerInputCheck.check_shape_length(input_shape, param_name, func_name, target_len)
@constexpr
def _check_shape_equal(input_shape, param_name, func_name, target_shape):
# check the input length
_LayerInputCheck.check_shape_equal(input_shape, param_name, func_name, target_shape)
@constexpr
def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_value):
_LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value)
@constexpr
def _check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape):
_LayerInputCheck.check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape)
class _Dropout(nn.Cell):
r"""
A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training.
@ -707,17 +685,9 @@ class FixedSparseAttention(nn.Cell):
self.slice1 = P.StridedSlice().shard(((dp, 1, 1),))
def construct(self, q, k, v, attention_mask):
_check_shape_equal(F.shape(q), "q", self.cls_name,
[self.batch_size, self.seq_length, self.hidden_size])
_check_input_dtype(F.dtype(q), "q", [mstype.float16], self.cls_name)
_check_shape_equal(F.shape(k), "k", self.cls_name,
[self.batch_size, self.seq_length, self.hidden_size])
_check_input_dtype(F.dtype(k), "k", [mstype.float16], self.cls_name)
_check_shape_equal(F.shape(v), "v", self.cls_name,
[self.batch_size, self.seq_length, self.hidden_size])
_check_input_dtype(F.dtype(v), "v", [mstype.float16], self.cls_name)
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, self.seq_length, self.seq_length])
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)
q, k, v = self._transpose_inputs(q, k, v)

View File

@ -30,7 +30,7 @@ from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_pipeline_stages
from mindspore.log import _LogActionOnce
from mindspore import log as logger
from mindspore.nn.transformer.layers import _check_input_dtype, _check_input_shape
from mindspore.nn.transformer.layers import _check_input_dtype
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig
__all__ = ["CrossEntropyLoss"]
@ -247,7 +247,4 @@ class CrossEntropyLoss(Cell):
_check_input_dtype(F.dtype(logits), "logits", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(label), "label", [mstype.int32], self.cls_name)
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name)
_check_input_shape(F.shape(logits), "logits", self.cls_name, 2)
_check_input_shape(F.shape(label), "label", self.cls_name, 1)
_check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 1)
return True

View File

@ -18,7 +18,6 @@ Note: Mixture of Expert (MoE) structure. This is an experimental interface that
from __future__ import absolute_import
from __future__ import division
import math
import numpy as np
from mindspore.common.tensor import Tensor
@ -134,7 +133,9 @@ def _check_moe_config(moe_config=None, parallel_config=None):
@constexpr
def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim):
return math.ceil(k * tokens_per_group * capacity_factor / expert_dim)
res = k * tokens_per_group * capacity_factor / expert_dim
res_int = int(res)
return res_int if res < 0 or res == res_int else res_int + 1
class MoE(Cell):
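
For reference, the new ceiling arithmetic can be sanity-checked against math.ceil outside the graph. A minimal plain-Python sketch (the 10.1, 2.0, 3.3, 4 case comes from the new unit test; the other inputs are made up):

import math

def expert_capacity(k, tokens_per_group, capacity_factor, expert_dim):
    # same arithmetic as the new calculate_expert_capacity body
    res = k * tokens_per_group * capacity_factor / expert_dim
    res_int = int(res)
    # int() truncates toward zero, so a negative res is already its own ceiling
    return res_int if res < 0 or res == res_int else res_int + 1

for args in ((10.1, 2.0, 3.3, 4), (2, 8, 1.0, 4), (1, 3, 1.5, 2)):
    assert expert_capacity(*args) == math.ceil(args[0] * args[1] * args[2] / args[3])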

View File

@ -35,10 +35,9 @@ from mindspore import log as logger
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.context import ParallelMode
from mindspore.log import _LogActionOnce
from mindspore.nn.transformer.layers import _LayerNorm, _Linear, _check_input_shape, \
from mindspore.nn.transformer.layers import _LayerNorm, _Linear, \
_args_type_validator_check, _valid_type_checks, _valid_value_checks, \
_check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, \
_check_shape_equal_without_batch
_check_past_none_input_none, _check_input_dtype
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, \
_Config, _check_config, MoEParallelConfig
from mindspore.nn.transformer.moe import default_moe_config, MoE, _check_moe_config
@ -566,7 +565,6 @@ class FeedForward(Cell):
self.cast = P.Cast()
def construct(self, x):
_check_input_shape(F.shape(x), "x", self.cls_name, [2, 3])
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
x = self.cast(x, mstype.float16)
# returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]
@ -639,9 +637,7 @@ class AttentionMask(Cell):
self.multiply = P.Mul().shard(((parallel_config.data_parallel, 1, 1), (1, 1, 1)))
def construct(self, input_mask):
_check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 2)
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
_check_input_shape_value(F.shape(input_mask), 1, "input_mask", self.cls_name, self.seq_length)
input_mask = P.Cast()(self.not_equal(input_mask, 0), mstype.float16)
input_shape = P.Shape()(input_mask)
shape_right = (input_shape[0], 1, input_shape[1])
@ -736,7 +732,6 @@ class VocabEmbedding(Cell):
f"model parallel for the embedding lookup.")
def construct(self, input_ids):
_check_input_shape(F.shape(input_ids), "input_ids", self.cls_name, 2)
_check_input_dtype(F.dtype(input_ids), "input_ids", [mstype.int32], self.cls_name)
output = self.gather(self.embedding_table, input_ids, 0)
return output, self.embedding_table.value()
@ -1223,27 +1218,6 @@ class MultiHeadAttention(Cell):
def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
value_past=None, batch_valid_length=None):
r"""Check inputs"""
if not self.use_past or (self.use_past and self.is_first_iteration):
_check_shape_equal_without_batch(F.shape(query_tensor), "query_tensor", self.cls_name,
[self.src_seq_length, self.hidden_size])
_check_shape_equal_without_batch(F.shape(key_tensor), "key_tensor", self.cls_name,
[self.tgt_seq_length, self.hidden_size])
_check_shape_equal_without_batch(F.shape(value_tensor), "value_tensor", self.cls_name,
[self.tgt_seq_length, self.hidden_size])
if attention_mask is not None:
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[F.shape(attention_mask)[0], self.src_seq_length, self.tgt_seq_length])
else:
_check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
[[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
_check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
[[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
_check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
[[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
if attention_mask is not None:
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[[self.batch_size, 1, self.tgt_seq_length], [self.batch_size, self.hidden_size]])
_check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
_check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)
@ -1264,13 +1238,8 @@ class MultiHeadAttention(Cell):
_check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
batch_valid_length_is_tensor, batch_is_default)
if self.use_past:
_check_shape_equal(F.shape(key_past), "key_past", self.cls_name,
[self.batch_size, self.n_head, self.size_per_head, self.tgt_seq_length])
_check_input_dtype(F.dtype(key_past), "key_past", [mstype.float16], self.cls_name)
_check_shape_equal(F.shape(value_past), "value_past", self.cls_name,
[self.batch_size, self.n_head, self.tgt_seq_length, self.size_per_head])
_check_input_dtype(F.dtype(value_past), "value_past", [mstype.float16], self.cls_name)
_check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
_check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
return True
@ -1770,17 +1739,6 @@ class TransformerEncoderLayer(Cell):
def _check_input(self, x, input_mask, init_reset, batch_valid_length):
r"""Check inputs"""
if not self.use_past or (self.use_past and self.is_first_iteration):
_check_shape_equal_without_batch(F.shape(x), "x", self.cls_name,
[self.seq_length, self.hidden_size])
if input_mask is not None:
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[F.shape(input_mask)[0], self.seq_length, self.seq_length])
else:
_check_shape_equal(F.shape(x), "x", self.cls_name, [self.batch_size, 1, self.hidden_size])
if input_mask is not None:
_check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
[F.shape(input_mask)[0], 1, self.seq_length])
_check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
if input_mask is not None:
_check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
@ -1795,9 +1753,7 @@ class TransformerEncoderLayer(Cell):
batch_valid_length_is_tensor, batch_is_default)
if self.use_past:
_check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
_check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
_check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
_check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
return True
@ -2226,31 +2182,14 @@ class TransformerDecoderLayer(Cell):
def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
r"""Check inputs"""
if not self.use_past or (self.use_past and self.is_first_iteration):
_check_shape_equal_without_batch(F.shape(hidden_states), "hidden_states", self.cls_name,
[self.tgt_seq_length, self.hidden_size])
if attention_mask is not None:
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[F.shape(attention_mask)[0], self.tgt_seq_length, self.tgt_seq_length])
else:
_check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
[self.batch_size, 1, self.hidden_size])
if attention_mask is not None:
_check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
[self.batch_size, 1, self.tgt_seq_length])
_check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
if attention_mask is not None:
_check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16],
self.cls_name)
if encoder_output is not None:
_check_shape_equal_without_batch(F.shape(encoder_output), "encoder_output", self.cls_name,
[self.src_seq_length, self.hidden_size])
_check_input_dtype(F.dtype(encoder_output), "encoder_output",
[mstype.float32, mstype.float16], self.cls_name)
if memory_mask is not None:
_check_shape_equal_without_batch(F.shape(memory_mask), "memory_mask", self.cls_name,
[self.tgt_seq_length, self.src_seq_length])
_check_input_dtype(F.dtype(memory_mask), "memory_mask",
[mstype.float32, mstype.float16], self.cls_name)
@ -2264,9 +2203,7 @@ class TransformerDecoderLayer(Cell):
batch_valid_length_is_tensor, batch_is_default)
if self.use_past:
_check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
_check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
_check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
_check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
return True

View File

@ -27,7 +27,6 @@ from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops.primitive import constexpr
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
@ -467,14 +466,6 @@ class _VirtualDatasetCell(Cell):
return self._backbone(*output)
@constexpr
def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size):
if input_shape[0] % micro_size != 0:
raise ValueError(f"For micro batch initialization, the 0th dimension shape of input({input_shape[0]}) must be "
f"divided by micro size({micro_size})")
return True
class _MicroBatch(Cell):
"""
transform mini-batch to micro-batch in pipeline parallel.
@ -493,7 +484,6 @@ class _MicroBatch(Cell):
micro_inputs = ()
for each_input in inputs:
input_shape = self.shape(each_input)
_check_shape_value_on_axis_divided_by_target_value(input_shape, self.micro_size)
micro_batch_begin = i * input_shape[0] // self.micro_size
micro_batch_end = (i + 1) * input_shape[0] // self.micro_size
strided_slice_begin = (micro_batch_begin,)
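
The removed @constexpr check used to reject inputs whose leading dimension is not a multiple of the micro size; the slicing arithmetic itself is unchanged. A small plain-Python illustration of how those boundaries partition the batch dimension (batch size 8 and micro size 4 are made-up values):

batch_dim, micro_size = 8, 4  # illustrative values only
for i in range(micro_size):
    micro_batch_begin = i * batch_dim // micro_size
    micro_batch_end = (i + 1) * batch_dim // micro_size
    print(i, (micro_batch_begin, micro_batch_end))
# prints (0, 2), (2, 4), (4, 6), (6, 8)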

View File

@ -39,7 +39,7 @@ from mindspore.numpy.array_ops import ravel, expand_dims, moveaxis, concatenate,
split
from mindspore.numpy.utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
_check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \
_raise_type_error, _check_same_type, _check_is_float, \
_raise_value_error, _promote, _check_axis_type, _canonicalize_axis, \
_is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range, \
_check_dtype, _list_comprehensions, _tuple_setitem, _add_unit_axes, _seq_prod, \
@ -682,7 +682,6 @@ def inner(a, b):
if F.rank(a) == 0 or F.rank(b) == 0:
return F.tensor_mul(a, b)
_check_shape_aligned(F.shape(a), F.shape(b))
aligned_shape_a = (F.shape_mul(F.shape(a)[:-1]), F.shape(a)[-1])
aligned_shape_b = (F.shape_mul(F.shape(b)[:-1]), F.shape(a)[-1])
a_aligned = F.reshape(a, aligned_shape_a)
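
The reshape here collapses every leading dimension of each operand into one, leaving the shared last dimension in place. A minimal sketch of that shape arithmetic in plain Python, with math.prod standing in for F.shape_mul and made-up shapes:

import math

shape_a, shape_b = (2, 3, 4), (5, 4)  # last dimensions must match for inner()
aligned_shape_a = (math.prod(shape_a[:-1]), shape_a[-1])  # (6, 4)
aligned_shape_b = (math.prod(shape_b[:-1]), shape_b[-1])  # (5, 4)
assert aligned_shape_a == (6, 4) and aligned_shape_b == (5, 4)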

View File

@ -30,15 +30,12 @@ from mindspore._checkparam import Validator as validator
from mindspore.numpy.dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric
_check_axis_type = constexpr(validator.check_axis_type)
@constexpr
def _check_shape(shape):
"""check the shape param to match the numpy style"""
if not isinstance(shape, (int, tuple, list, Tensor, typing.Tuple, typing.List)):
raise TypeError(f"only int, tuple, list and tensor are allowed for shape, but got {type(shape)}")
# convert tensor to int/list, use followed if statements to do further conversions
if isinstance(shape, Tensor):
shape = shape.asnumpy().tolist()
@ -47,11 +44,6 @@ def _check_shape(shape):
shape = (shape,)
elif isinstance(shape, (list, typing.List)):
shape = tuple(shape)
for s in shape:
if not isinstance(s, int):
raise TypeError("each entry in shape should be int.")
if s < 0:
raise ValueError("each entry in shape should no less than 0.")
return shape
@ -89,8 +81,6 @@ def _is_shape_empty(shp):
@constexpr
def _check_start_normalize(start, ndim):
"""check and normalize start argument for rollaxis."""
if start < -ndim or start > ndim:
raise ValueError(f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.")
if start < 0:
start = start + ndim
return start
@ -112,9 +102,6 @@ def _check_axes_range(axes, ndim):
TypeError: If the axes are not integer, tuple(int) or list(int).
ValueError: If duplicate axes exists or some axis is out of bounds.
"""
_check_axis_type(axes, True, True, True)
if isinstance(axes, (list, tuple)):
_check_element_int(axes)
axes = _canonicalize_axis(axes, ndim)
return axes
@ -125,14 +112,12 @@ def _get_device():
return context.get_context('device_target')
#remove constexpr
def _infer_out_shape(*shapes):
"""
Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
"""
shape_out = list()
max_len = max([len(it) for it in shapes])
for i in range(max_len):
items = [it[i-max_len+len(it)] if i-max_len +
len(it) >= 0 else 1 for it in shapes]
@ -146,23 +131,14 @@ def _can_broadcast(*shapes):
"""
Returns True if shapes can broadcast, False if they cannot.
"""
try:
_infer_out_shape(*shapes)
except ValueError:
return False
finally:
pass
_infer_out_shape(*shapes)
return True
@constexpr
def _check_axis_in_range(axis, ndim):
"""Checks axes are with the bounds of ndim"""
if not isinstance(axis, int):
raise TypeError(f'axes should be integers, not {type(axis)}')
if not -ndim <= axis < ndim:
raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
return axis % ndim
return axis - axis // ndim * ndim
@constexpr
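
The new return expression computes the same value as axis % ndim via floor division, just without the modulo operator. A quick off-graph check over in-range axes (the ndim values below are arbitrary):

def normalize_axis(axis, ndim):
    # equivalent to axis % ndim for positive ndim, written without %
    return axis - axis // ndim * ndim

for ndim in (1, 2, 5):
    for axis in range(-ndim, ndim):
        assert normalize_axis(axis, ndim) == axis % ndim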
@ -176,19 +152,10 @@ def _check_axis_valid(axes, ndim):
return axes
if isinstance(axes, (tuple, list)):
axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes))
if any(axes.count(el) > 1 for el in axes):
raise ValueError('duplicate value in "axis"')
return axes
return (_check_axis_in_range(axes, ndim),)
@constexpr
def _check_shape_aligned(shape1, shape2):
"""Checks shape1 and shape2 are valid shapes to perform inner product"""
if shape1[-1] != shape2[-1]:
raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)')
@constexpr
def _tile_size(shape, out_shape, ndim):
"""Returns tile_size such that shape*tile_size = out_shape"""
@ -341,10 +308,31 @@ def _canonicalize_axis(axis, ndim):
def canonicalizer(ax):
return ax + ndim if ax < 0 else ax
def _sort_axis(ax):
def merge(left, right):
result = []
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
return result + left[i:] + right[j:]
if len(ax) <= 1:
return ax
middle = len(ax) // 2
left = _sort_axis(ax[:middle])
right = _sort_axis(ax[middle:])
return merge(left, right)
axis = tuple([canonicalizer(axis) for axis in axis])
if all(axis.count(el) <= 1 for el in axis):
return tuple(sorted(axis)) if len(axis) > 1 else axis[0]
raise ValueError(f"duplicate axes in {axis}.")
return tuple(_sort_axis(axis)) if len(axis) > 1 else axis[0]
@constexpr
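
The added _sort_axis replaces the former sorted() call with an explicit merge sort. A simplified, list-based sketch of the same merge logic, checked against sorted() (the helper name below is local to this sketch):

def sort_axis(ax):
    # recursive merge sort over a list of ints, mirroring the added _sort_axis
    def merge(left, right):
        result = []
        i = j = 0
        while i < len(left) and j < len(right):
            if left[i] <= right[j]:
                result.append(left[i])
                i += 1
            else:
                result.append(right[j])
                j += 1
        return result + left[i:] + right[j:]

    if len(ax) <= 1:
        return ax
    middle = len(ax) // 2
    return merge(sort_axis(ax[:middle]), sort_axis(ax[middle:]))

assert sort_axis([3, 0, 2, 1]) == sorted([3, 0, 2, 1])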

View File

@ -209,3 +209,75 @@ def test_flip():
net = Net()
output = net(x)
assert np.allclose(output.asnumpy(), expect.asnumpy())

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_calculate_expert_capacity():
    """
    Feature: calculate_expert_capacity func
    Description: Verify the result of calculate_expert_capacity
    Expectation: success
    """
    from mindspore.nn.transformer.moe import calculate_expert_capacity

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.func = calculate_expert_capacity

        def construct(self, k, tokens_per_group, capacity_factor, expert_dim):
            return self.func(k, tokens_per_group, capacity_factor, expert_dim)

    net = Net()
    assert net(10.1, 2.0, 3.3, 4) == 17


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_infer_out_shape():
    """
    Feature: _infer_out_shape func
    Description: Verify the result of _infer_out_shape
    Expectation: success
    """
    from mindspore.numpy.utils_const import _infer_out_shape

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.func = _infer_out_shape

        def construct(self, *shape):
            return self.func(*shape)

    net = Net()
    assert net((5,), (6, 1), (7, 1, 5), (8, 1, 6, 1)) == (8, 7, 6, 5)
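
The expected shape in this assertion follows the usual broadcasting rule of taking the largest size per aligned axis. A hand check in plain Python (compatibility checking omitted):

shapes = [(5,), (6, 1), (7, 1, 5), (8, 1, 6, 1)]
max_len = max(len(s) for s in shapes)
padded = [(1,) * (max_len - len(s)) + s for s in shapes]  # left-pad with 1s
out = tuple(max(dims) for dims in zip(*padded))  # fine here since non-1 sizes agree
assert out == (8, 7, 6, 5)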

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_canonicalize_axis():
    """
    Feature: _canonicalize_axis func
    Description: Verify the result of _canonicalize_axis
    Expectation: success
    """
    from mindspore.numpy.utils_const import _canonicalize_axis

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.func = _canonicalize_axis

        def construct(self, axis, ndim):
            return self.func(axis, ndim)

    net = Net()
    assert net(0, 2) == 0