Add input check

This commit is contained in:
huangxinjing 2020-12-18 19:17:46 +08:00
parent 2cc4cade73
commit dd0da1542d
2 changed files with 68 additions and 1 deletions

View File

@ -34,6 +34,16 @@ from ..cell import Cell
__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']
@constexpr
def _check_input_2d(input_shape, param_name, func_name):
if len(input_shape) != 2:
raise ValueError(f"{func_name} {param_name} should be 2d, but got shape {input_shape}")
return True
@constexpr
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
class Embedding(Cell):
r"""
@ -428,6 +438,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
def construct(self, input_indices, input_values, field_ids):
_check_input_2d(F.shape(input_indices), "input_indices", self.cls_name)
_check_input_2d(F.shape(input_values), "input_values", self.cls_name)
_check_input_2d(F.shape(field_ids), "field_ids", self.cls_name)
_check_input_dtype(F.dtype(input_indices), "input_indices", [mstype.int32, mstype.int64], self.cls_name)
_check_input_dtype(F.dtype(input_values), "input_values", [mstype.float32], self.cls_name)
_check_input_dtype(F.dtype(field_ids), "field_ids", [mstype.int32], self.cls_name)
batch_size = self.shape(input_indices)[0]
num_segments = batch_size * self.field_size
bias = Range(0, num_segments, self.field_size)()

View File

@ -14,11 +14,12 @@
# ============================================================================
""" test nn embedding """
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.common import dtype
from mindspore.common.api import _executor
from mindspore.nn import Embedding
from mindspore.nn import Embedding, MultiFieldEmbeddingLookup
from ..ut_filter import non_graph_engine
@ -43,6 +44,55 @@ def test_check_embedding_3():
_executor.compile(net, input_data)
def compile_multi_field_embedding(shape_id, shape_value, shape_field,
type_id, type_value, type_field):
net = MultiFieldEmbeddingLookup(20000, 768, 3)
input_data = Tensor(np.ones(shape_id), type_id)
input_value = Tensor(np.ones(shape_value), type_value)
input_field = Tensor(np.ones(shape_field), type_field)
_executor.compile(net, input_data, input_value, input_field)
@non_graph_engine
def test_check_multifield_embedding_right_type():
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
dtype.int64, dtype.float32, dtype.int32)
@non_graph_engine
def test_check_multifield_embedding_false_type_input():
with pytest.raises(TypeError):
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
dtype.int16, dtype.float32, dtype.int32)
@non_graph_engine
def test_check_multifield_embedding_false_type_value():
with pytest.raises(TypeError):
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
dtype.int16, dtype.float16, dtype.int32)
@non_graph_engine
def test_check_multifield_embedding_false_type_field_id():
with pytest.raises(TypeError):
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
dtype.int16, dtype.float32, dtype.int16)
@non_graph_engine
def test_check_multifield_embedding_false_input_shape():
with pytest.raises(TypeError):
compile_multi_field_embedding((8,), (8, 200), (8, 200),
dtype.int16, dtype.float32, dtype.int16)
@non_graph_engine
def test_check_multifield_embedding_false_value_shape():
with pytest.raises(TypeError):
compile_multi_field_embedding((8, 200), (8,), (8, 200),
dtype.int16, dtype.float32, dtype.int16)
@non_graph_engine
def test_print_embedding():
net = Embedding(20000, 768, False)