forked from mindspore-Ecosystem/mindspore
!47595 remove numpy constexpr
Merge pull request !47595 from hujiahui8/constexpr
This commit is contained in:
commit 92c63200f1
@@ -236,12 +236,6 @@ def raise_none_error(name):
                     f" It can not be None since it is not specified during initialization.")


@constexpr
def raise_probs_logits_error():
    raise TypeError(
        "Either 'probs' or 'logits' must be specified, but not both.")


@constexpr
def raise_broadcast_error(shape_a, shape_b):
    raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.")

@@ -200,28 +200,6 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
    Validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)


@constexpr
def _check_input_shape(input_shape, param_name, func_name, target_len):
    # check the input length
    _LayerInputCheck.check_shape_length(input_shape, param_name, func_name, target_len)


@constexpr
def _check_shape_equal(input_shape, param_name, func_name, target_shape):
    # check the input length
    _LayerInputCheck.check_shape_equal(input_shape, param_name, func_name, target_shape)


@constexpr
def _check_input_shape_value(input_shape, dim, param_name, cls_name, target_value):
    _LayerInputCheck.check_shape_value_on_axis(input_shape, dim, param_name, cls_name, target_value)


@constexpr
def _check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape):
    _LayerInputCheck.check_shape_equal_without_batch(input_shape, param_name, func_name, target_shape)


class _Dropout(nn.Cell):
    r"""
    A Dropout Implements with P.DropoutGenMask and P.DropoutDoMask for parallel training.

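For readers unfamiliar with these helpers: a hypothetical, standalone sketch of the kind of check that _check_shape_equal delegates to _LayerInputCheck. The name and message below are illustrative only, not the actual MindSpore implementation.

def check_shape_equal(input_shape, param_name, cls_name, target_shape):
    # Raise if the runtime shape differs from the expected shape.
    if list(input_shape) != list(target_shape):
        raise ValueError(f"For '{cls_name}', the shape of '{param_name}' should be {target_shape}, "
                         f"but got {input_shape}")
    return True
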
@@ -707,17 +685,9 @@ class FixedSparseAttention(nn.Cell):
        self.slice1 = P.StridedSlice().shard(((dp, 1, 1),))

    def construct(self, q, k, v, attention_mask):
        _check_shape_equal(F.shape(q), "q", self.cls_name,
                           [self.batch_size, self.seq_length, self.hidden_size])
        _check_input_dtype(F.dtype(q), "q", [mstype.float16], self.cls_name)
        _check_shape_equal(F.shape(k), "k", self.cls_name,
                           [self.batch_size, self.seq_length, self.hidden_size])
        _check_input_dtype(F.dtype(k), "k", [mstype.float16], self.cls_name)
        _check_shape_equal(F.shape(v), "v", self.cls_name,
                           [self.batch_size, self.seq_length, self.hidden_size])
        _check_input_dtype(F.dtype(v), "v", [mstype.float16], self.cls_name)
        _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
                           [self.batch_size, self.seq_length, self.seq_length])
        _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16], self.cls_name)

        q, k, v = self._transpose_inputs(q, k, v)

@@ -30,7 +30,7 @@ from mindspore.context import ParallelMode
from mindspore.parallel._utils import _get_device_num, _get_pipeline_stages
from mindspore.log import _LogActionOnce
from mindspore import log as logger
from mindspore.nn.transformer.layers import _check_input_dtype, _check_input_shape
from mindspore.nn.transformer.layers import _check_input_dtype
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, OpParallelConfig

__all__ = ["CrossEntropyLoss"]

@@ -247,7 +247,4 @@ class CrossEntropyLoss(Cell):
        _check_input_dtype(F.dtype(logits), "logits", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(label), "label", [mstype.int32], self.cls_name)
        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32], self.cls_name)
        _check_input_shape(F.shape(logits), "logits", self.cls_name, 2)
        _check_input_shape(F.shape(label), "label", self.cls_name, 1)
        _check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 1)
        return True

@@ -18,7 +18,6 @@ Note: Mixture of Expert (MoE) structure. This is an experimental interface that
from __future__ import absolute_import
from __future__ import division

import math
import numpy as np

from mindspore.common.tensor import Tensor

@@ -134,7 +133,9 @@ def _check_moe_config(moe_config=None, parallel_config=None):

@constexpr
def calculate_expert_capacity(k, tokens_per_group, capacity_factor, expert_dim):
    return math.ceil(k * tokens_per_group * capacity_factor / expert_dim)
    res = k * tokens_per_group * capacity_factor / expert_dim
    res_int = int(res)
    return res_int if res < 0 or res == res_int else res_int + 1


class MoE(Cell):

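Note on the rewrite above: once the @constexpr wrapper is removed, the function presumably has to avoid math.ceil, so the ceiling is computed with plain arithmetic instead. A minimal standalone sketch, using made-up inputs, suggesting the two forms agree:

import math

def ceil_without_math(k, tokens_per_group, capacity_factor, expert_dim):
    # Same arithmetic as the new calculate_expert_capacity body above.
    res = k * tokens_per_group * capacity_factor / expert_dim
    res_int = int(res)
    return res_int if res < 0 or res == res_int else res_int + 1

# 10.1 * 2.0 * 3.3 / 4 = 16.665, so both forms return 17.
assert ceil_without_math(10.1, 2.0, 3.3, 4) == math.ceil(10.1 * 2.0 * 3.3 / 4) == 17
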
@@ -35,10 +35,9 @@ from mindspore import log as logger
from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation
from mindspore.context import ParallelMode
from mindspore.log import _LogActionOnce
from mindspore.nn.transformer.layers import _LayerNorm, _Linear, _check_input_shape, \
from mindspore.nn.transformer.layers import _LayerNorm, _Linear, \
    _args_type_validator_check, _valid_type_checks, _valid_value_checks, \
    _check_shape_equal, _check_past_none_input_none, _check_input_dtype, _check_input_shape_value, \
    _check_shape_equal_without_batch
    _check_past_none_input_none, _check_input_dtype
from mindspore.nn.transformer.op_parallel_config import default_dpmp_config, _PipeLineConfig, OpParallelConfig, \
    _Config, _check_config, MoEParallelConfig
from mindspore.nn.transformer.moe import default_moe_config, MoE, _check_moe_config

@@ -566,7 +565,6 @@ class FeedForward(Cell):
        self.cast = P.Cast()

    def construct(self, x):
        _check_input_shape(F.shape(x), "x", self.cls_name, [2, 3])
        _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
        x = self.cast(x, mstype.float16)
        # returned shape is [bs, seq_length, ffn_hidden_size] or [bs * seq_length, ffn_hidden_size]

@@ -639,9 +637,7 @@ class AttentionMask(Cell):
        self.multiply = P.Mul().shard(((parallel_config.data_parallel, 1, 1), (1, 1, 1)))

    def construct(self, input_mask):
        _check_input_shape(F.shape(input_mask), "input_mask", self.cls_name, 2)
        _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_shape_value(F.shape(input_mask), 1, "input_mask", self.cls_name, self.seq_length)
        input_mask = P.Cast()(self.not_equal(input_mask, 0), mstype.float16)
        input_shape = P.Shape()(input_mask)
        shape_right = (input_shape[0], 1, input_shape[1])

@@ -736,7 +732,6 @@ class VocabEmbedding(Cell):
                             f"model parallel for the embedding lookup.")

    def construct(self, input_ids):
        _check_input_shape(F.shape(input_ids), "input_ids", self.cls_name, 2)
        _check_input_dtype(F.dtype(input_ids), "input_ids", [mstype.int32], self.cls_name)
        output = self.gather(self.embedding_table, input_ids, 0)
        return output, self.embedding_table.value()

@@ -1223,27 +1218,6 @@ class MultiHeadAttention(Cell):
    def _check_inputs(self, query_tensor, key_tensor, value_tensor, attention_mask, key_past=None,
                      value_past=None, batch_valid_length=None):
        r"""Check inputs"""
        if not self.use_past or (self.use_past and self.is_first_iteration):
            _check_shape_equal_without_batch(F.shape(query_tensor), "query_tensor", self.cls_name,
                                             [self.src_seq_length, self.hidden_size])
            _check_shape_equal_without_batch(F.shape(key_tensor), "key_tensor", self.cls_name,
                                             [self.tgt_seq_length, self.hidden_size])
            _check_shape_equal_without_batch(F.shape(value_tensor), "value_tensor", self.cls_name,
                                             [self.tgt_seq_length, self.hidden_size])
            if attention_mask is not None:
                _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
                                   [F.shape(attention_mask)[0], self.src_seq_length, self.tgt_seq_length])
        else:
            _check_shape_equal(F.shape(query_tensor), "query_tensor", self.cls_name,
                               [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
            _check_shape_equal(F.shape(key_tensor), "key_tensor", self.cls_name,
                               [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
            _check_shape_equal(F.shape(value_tensor), "value_tensor", self.cls_name,
                               [[self.batch_size, 1, self.hidden_size], [self.batch_size, self.hidden_size]])
            if attention_mask is not None:
                _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
                                   [[self.batch_size, 1, self.tgt_seq_length], [self.batch_size, self.hidden_size]])

        _check_input_dtype(F.dtype(query_tensor), "query_tensor", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(key_tensor), "key_tensor", [mstype.float32, mstype.float16], self.cls_name)
        _check_input_dtype(F.dtype(value_tensor), "value_tensor", [mstype.float32, mstype.float16], self.cls_name)

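As a quick illustration of the two branches in the check above, with made-up sizes (batch_size=2, src_seq_length=16, hidden_size=8), the expected query layouts are roughly:

full_sequence_query = (2, 16, 8)  # first iteration, or use_past=False: [batch, seq_length, hidden]
incremental_query = (2, 1, 8)     # later iterations with use_past=True: a single step per batch item
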
@@ -1264,13 +1238,8 @@ class MultiHeadAttention(Cell):
        _check_past_none_input_none(self.use_past, "batch_valid_length", self.cls_name, None,
                                    batch_valid_length_is_tensor, batch_is_default)
        if self.use_past:
            _check_shape_equal(F.shape(key_past), "key_past", self.cls_name,
                               [self.batch_size, self.n_head, self.size_per_head, self.tgt_seq_length])
            _check_input_dtype(F.dtype(key_past), "key_past", [mstype.float16], self.cls_name)
            _check_shape_equal(F.shape(value_past), "value_past", self.cls_name,
                               [self.batch_size, self.n_head, self.tgt_seq_length, self.size_per_head])
            _check_input_dtype(F.dtype(value_past), "value_past", [mstype.float16], self.cls_name)
            _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
            _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
        return True

@@ -1770,17 +1739,6 @@ class TransformerEncoderLayer(Cell):

    def _check_input(self, x, input_mask, init_reset, batch_valid_length):
        r"""Check inputs"""
        if not self.use_past or (self.use_past and self.is_first_iteration):
            _check_shape_equal_without_batch(F.shape(x), "x", self.cls_name,
                                             [self.seq_length, self.hidden_size])
            if input_mask is not None:
                _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
                                   [F.shape(input_mask)[0], self.seq_length, self.seq_length])
        else:
            _check_shape_equal(F.shape(x), "x", self.cls_name, [self.batch_size, 1, self.hidden_size])
            if input_mask is not None:
                _check_shape_equal(F.shape(input_mask), "input_mask", self.cls_name,
                                   [F.shape(input_mask)[0], 1, self.seq_length])
        _check_input_dtype(F.dtype(x), "x", [mstype.float32, mstype.float16], self.cls_name)
        if input_mask is not None:
            _check_input_dtype(F.dtype(input_mask), "input_mask", [mstype.float32, mstype.float16], self.cls_name)

@@ -1795,9 +1753,7 @@ class TransformerEncoderLayer(Cell):
                                    batch_valid_length_is_tensor, batch_is_default)

        if self.use_past:
            _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
            _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
            _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
            _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
        return True

@@ -2226,31 +2182,14 @@ class TransformerDecoderLayer(Cell):

    def _check_input(self, hidden_states, attention_mask, encoder_output, memory_mask, init_reset, batch_valid_length):
        r"""Check inputs"""
        if not self.use_past or (self.use_past and self.is_first_iteration):
            _check_shape_equal_without_batch(F.shape(hidden_states), "hidden_states", self.cls_name,
                                             [self.tgt_seq_length, self.hidden_size])
            if attention_mask is not None:
                _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
                                   [F.shape(attention_mask)[0], self.tgt_seq_length, self.tgt_seq_length])

        else:
            _check_shape_equal(F.shape(hidden_states), "hidden_states", self.cls_name,
                               [self.batch_size, 1, self.hidden_size])
            if attention_mask is not None:
                _check_shape_equal(F.shape(attention_mask), "attention_mask", self.cls_name,
                                   [self.batch_size, 1, self.tgt_seq_length])
        _check_input_dtype(F.dtype(hidden_states), "hidden_states", [mstype.float32, mstype.float16], self.cls_name)
        if attention_mask is not None:
            _check_input_dtype(F.dtype(attention_mask), "attention_mask", [mstype.float32, mstype.float16],
                               self.cls_name)
        if encoder_output is not None:
            _check_shape_equal_without_batch(F.shape(encoder_output), "encoder_output", self.cls_name,
                                             [self.src_seq_length, self.hidden_size])
            _check_input_dtype(F.dtype(encoder_output), "encoder_output",
                               [mstype.float32, mstype.float16], self.cls_name)
        if memory_mask is not None:
            _check_shape_equal_without_batch(F.shape(memory_mask), "memory_mask", self.cls_name,
                                             [self.tgt_seq_length, self.src_seq_length])
            _check_input_dtype(F.dtype(memory_mask), "memory_mask",
                               [mstype.float32, mstype.float16], self.cls_name)

@@ -2264,9 +2203,7 @@ class TransformerDecoderLayer(Cell):
                                    batch_valid_length_is_tensor, batch_is_default)

        if self.use_past:
            _check_shape_equal(F.shape(init_reset), "init_reset", self.cls_name, [1])
            _check_input_dtype(F.dtype(init_reset), "init_reset", [mstype.bool_], self.cls_name)
            _check_shape_equal(F.shape(batch_valid_length), "batch_valid_length", self.cls_name, [self.batch_size])
            _check_input_dtype(F.dtype(batch_valid_length), "batch_valid_length", [mstype.int32], self.cls_name)
        return True

@@ -27,7 +27,6 @@ from mindspore._checkparam import Validator as validator
from mindspore import ops, nn
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.ops.primitive import constexpr
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

@@ -467,14 +466,6 @@ class _VirtualDatasetCell(Cell):
        return self._backbone(*output)


@constexpr
def _check_shape_value_on_axis_divided_by_target_value(input_shape, micro_size):
    if input_shape[0] % micro_size != 0:
        raise ValueError(f"For micro batch initialization, the 0th dimension shape of input({input_shape[0]}) must be "
                         f"divided by micro size({micro_size})")
    return True


class _MicroBatch(Cell):
    """
    transform mini-batch to micro-batch in pipeline parallel.

@@ -493,7 +484,6 @@ class _MicroBatch(Cell):
        micro_inputs = ()
        for each_input in inputs:
            input_shape = self.shape(each_input)
            _check_shape_value_on_axis_divided_by_target_value(input_shape, self.micro_size)
            micro_batch_begin = i * input_shape[0] // self.micro_size
            micro_batch_end = (i + 1) * input_shape[0] // self.micro_size
            strided_slice_begin = (micro_batch_begin,)

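A small standalone sketch of the micro-batch slicing arithmetic used above, assuming a batch dimension of 8 split into micro_size = 4 slices (made-up numbers):

batch_dim, micro_size = 8, 4
assert batch_dim % micro_size == 0  # the condition the removed helper used to enforce
slices = []
for i in range(micro_size):
    micro_batch_begin = i * batch_dim // micro_size
    micro_batch_end = (i + 1) * batch_dim // micro_size
    slices.append((micro_batch_begin, micro_batch_end))
assert slices == [(0, 2), (2, 4), (4, 6), (6, 8)]
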
@@ -39,7 +39,7 @@ from mindspore.numpy.array_ops import ravel, expand_dims, moveaxis, concatenate,
    split

from mindspore.numpy.utils_const import _infer_out_shape, _check_axis_valid, _get_device, \
    _check_shape_aligned, _raise_type_error, _check_same_type, _check_is_float, \
    _raise_type_error, _check_same_type, _check_is_float, \
    _raise_value_error, _promote, _check_axis_type, _canonicalize_axis, \
    _is_shape_empty, _check_is_int, _expanded_shape, _check_axis_in_range, \
    _check_dtype, _list_comprehensions, _tuple_setitem, _add_unit_axes, _seq_prod, \

@@ -682,7 +682,6 @@ def inner(a, b):
    if F.rank(a) == 0 or F.rank(b) == 0:
        return F.tensor_mul(a, b)

    _check_shape_aligned(F.shape(a), F.shape(b))
    aligned_shape_a = (F.shape_mul(F.shape(a)[:-1]), F.shape(a)[-1])
    aligned_shape_b = (F.shape_mul(F.shape(b)[:-1]), F.shape(a)[-1])
    a_aligned = F.reshape(a, aligned_shape_a)

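A rough NumPy sketch of the alignment that inner() performs before the matrix multiply, with made-up shapes (2, 3, 4) and (5, 4); the reshape of the result back to (2, 3, 5) is omitted here:

import numpy as np

a = np.ones((2, 3, 4))
b = np.ones((5, 4))
a_aligned = a.reshape(-1, a.shape[-1])   # (6, 4): flatten every leading dimension
b_aligned = b.reshape(-1, b.shape[-1])   # (5, 4): the last dims must already be equal
out = a_aligned @ b_aligned.T            # (6, 5)
assert out.shape == (6, 5)
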
@@ -30,15 +30,12 @@ from mindspore._checkparam import Validator as validator

from mindspore.numpy.dtypes import promotion_rule, dtype_tuple, all_types, dtype_map, rule_for_trigonometric


_check_axis_type = constexpr(validator.check_axis_type)


@constexpr
def _check_shape(shape):
    """check the shape param to match the numpy style"""
    if not isinstance(shape, (int, tuple, list, Tensor, typing.Tuple, typing.List)):
        raise TypeError(f"only int, tuple, list and tensor are allowed for shape, but got {type(shape)}")
    # convert tensor to int/list, use followed if statements to do further conversions
    if isinstance(shape, Tensor):
        shape = shape.asnumpy().tolist()

@@ -47,11 +44,6 @@ def _check_shape(shape):
        shape = (shape,)
    elif isinstance(shape, (list, typing.List)):
        shape = tuple(shape)
    for s in shape:
        if not isinstance(s, int):
            raise TypeError("each entry in shape should be int.")
        if s < 0:
            raise ValueError("each entry in shape should no less than 0.")
    return shape

@@ -89,8 +81,6 @@ def _is_shape_empty(shp):
@constexpr
def _check_start_normalize(start, ndim):
    """check and normalize start argument for rollaxis."""
    if start < -ndim or start > ndim:
        raise ValueError(f"For rollaxis, start {start} is out of bounds. Ranging from {-ndim} to {ndim} is allowed.")
    if start < 0:
        start = start + ndim
    return start

@@ -112,9 +102,6 @@ def _check_axes_range(axes, ndim):
        TypeError: If the axes are not integer, tuple(int) or list(int).
        ValueError: If duplicate axes exists or some axis is out of bounds.
    """
    _check_axis_type(axes, True, True, True)
    if isinstance(axes, (list, tuple)):
        _check_element_int(axes)
    axes = _canonicalize_axis(axes, ndim)
    return axes

@@ -125,14 +112,12 @@ def _get_device():
    return context.get_context('device_target')


#remove constexpr
def _infer_out_shape(*shapes):
    """
    Returns shape of output after broadcasting. Raises ValueError if shapes cannot be broadcast.
    """
    shape_out = list()
    max_len = max([len(it) for it in shapes])

    for i in range(max_len):
        items = [it[i-max_len+len(it)] if i-max_len +
                 len(it) >= 0 else 1 for it in shapes]

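A self-contained sketch of the broadcasting rule that _infer_out_shape implements (illustrative only; the real function builds shape_out from the items list shown above):

def broadcast_shape(*shapes):
    # Align shapes on the right, treating missing leading dimensions as 1.
    max_len = max(len(s) for s in shapes)
    out = []
    for i in range(max_len):
        dims = [s[i - max_len + len(s)] if i - max_len + len(s) >= 0 else 1 for s in shapes]
        target = max(dims)
        if any(d not in (1, target) for d in dims):
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        out.append(target)
    return tuple(out)

assert broadcast_shape((5,), (6, 1), (7, 1, 5), (8, 1, 6, 1)) == (8, 7, 6, 5)
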
@@ -146,23 +131,14 @@ def _can_broadcast(*shapes):
    """
    Returns Ture if shapes can broadcast, False if they cannot.
    """
    try:
        _infer_out_shape(*shapes)
    except ValueError:
        return False
    finally:
        pass
    _infer_out_shape(*shapes)
    return True


@constexpr
def _check_axis_in_range(axis, ndim):
    """Checks axes are with the bounds of ndim"""
    if not isinstance(axis, int):
        raise TypeError(f'axes should be integers, not {type(axis)}')
    if not -ndim <= axis < ndim:
        raise ValueError(f'axis {axis} is out of bounds for array of dimension {ndim}')
    return axis % ndim
    return axis - axis // ndim * ndim


@constexpr

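The last change in the hunk above swaps axis % ndim for axis - axis // ndim * ndim; for Python integers the two are equivalent because // is floor division, as this small check with made-up values suggests:

def normalize_axis(axis, ndim):
    # Equivalent to axis % ndim for ints: maps a possibly negative axis into [0, ndim).
    return axis - axis // ndim * ndim

for axis, ndim in [(-1, 3), (0, 3), (2, 3), (-3, 3)]:
    assert normalize_axis(axis, ndim) == axis % ndim
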
@@ -176,19 +152,10 @@ def _check_axis_valid(axes, ndim):
        return axes
    if isinstance(axes, (tuple, list)):
        axes = tuple(map(lambda x: _check_axis_in_range(x, ndim), axes))
        if any(axes.count(el) > 1 for el in axes):
            raise ValueError('duplicate value in "axis"')
        return axes
    return (_check_axis_in_range(axes, ndim),)


@constexpr
def _check_shape_aligned(shape1, shape2):
    """Checks shape1 and shape2 are valid shapes to perform inner product"""
    if shape1[-1] != shape2[-1]:
        raise ValueError(f'shapes {shape1} {shape2} not aligned: {shape1[-1]} (dim 0) != {shape2[-1]} (dim 0)')


@constexpr
def _tile_size(shape, out_shape, ndim):
    """Returns tile_size such that shape*tile_size = out_shape"""

@@ -341,10 +308,31 @@ def _canonicalize_axis(axis, ndim):
    def canonicalizer(ax):
        return ax + ndim if ax < 0 else ax

    def _sort_axis(ax):
        def merge(left, right):
            result = []
            i = j = 0
            while i < len(left) and j < len(right):
                if left[i] <= right[j]:
                    result.append(left[i])
                    i += 1
                else:
                    result.append(right[j])
                    j += 1

            return result + left[i:] + right[j:]

        if len(ax) <= 1:
            return ax

        middle = len(ax) // 2
        left = _sort_axis(ax[:middle])
        right = _sort_axis(ax[middle:])

        return merge(left, right)

    axis = tuple([canonicalizer(axis) for axis in axis])
    if all(axis.count(el) <= 1 for el in axis):
        return tuple(sorted(axis)) if len(axis) > 1 else axis[0]
    raise ValueError(f"duplicate axes in {axis}.")
    return tuple(_sort_axis(axis)) if len(axis) > 1 else axis[0]


@constexpr

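The hand-rolled _sort_axis above is a plain merge sort standing in for the removed call to sorted(). A quick standalone check, with made-up axes, that such a merge sort matches sorted():

def merge_sort(seq):
    # Recursive merge sort over a tuple or list of ints.
    if len(seq) <= 1:
        return list(seq)
    middle = len(seq) // 2
    left, right = merge_sort(seq[:middle]), merge_sort(seq[middle:])
    result, i, j = [], 0, 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    return result + left[i:] + right[j:]

assert tuple(merge_sort((3, 0, 2, 1))) == tuple(sorted((3, 0, 2, 1))) == (0, 1, 2, 3)
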
@@ -209,3 +209,75 @@ def test_flip():
    net = Net()
    output = net(x)
    assert np.allclose(output.asnumpy(), expect.asnumpy())


@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_calculate_expert_capacity():
    """
    Feature: calculate_expert_capacity func
    Description: Verify the result of calculate_expert_capacity
    Expectation: success
    """
    from mindspore.nn.transformer.moe import calculate_expert_capacity

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.func = calculate_expert_capacity

        def construct(self, k, tokens_per_group, capacity_factor, expert_dim):
            return self.func(k, tokens_per_group, capacity_factor, expert_dim)
    net = Net()
    assert net(10.1, 2.0, 3.3, 4) == 17

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_infer_out_shape():
    """
    Feature: _infer_out_shape func
    Description: Verify the result of _infer_out_shape
    Expectation: success
    """
    from mindspore.numpy.utils_const import _infer_out_shape

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.func = _infer_out_shape

        def construct(self, *shape):
            return self.func(*shape)
    net = Net()
    assert net((5,), (6, 1), (7, 1, 5), (8, 1, 6, 1)) == (8, 7, 6, 5)

@pytest.mark.level1
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_canonicalize_axis():
    """
    Feature: _canonicalize_axis func
    Description: Verify the result of _canonicalize_axis
    Expectation: success
    """
    from mindspore.numpy.utils_const import _canonicalize_axis

    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.func = _canonicalize_axis

        def construct(self, axis, ndim):
            return self.func(axis, ndim)
    net = Net()
    assert net(0, 2) == 0