Add input check
This commit is contained in:
parent
2cc4cade73
commit
dd0da1542d
|
@ -34,6 +34,16 @@ from ..cell import Cell
|
|||
|
||||
__all__ = ['Embedding', 'EmbeddingLookup', 'MultiFieldEmbeddingLookup']
|
||||
|
||||
@constexpr
|
||||
def _check_input_2d(input_shape, param_name, func_name):
|
||||
if len(input_shape) != 2:
|
||||
raise ValueError(f"{func_name} {param_name} should be 2d, but got shape {input_shape}")
|
||||
return True
|
||||
|
||||
@constexpr
|
||||
def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
|
||||
validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
|
||||
|
||||
|
||||
class Embedding(Cell):
|
||||
r"""
|
||||
|
@ -428,6 +438,13 @@ class MultiFieldEmbeddingLookup(EmbeddingLookup):
|
|||
|
||||
def construct(self, input_indices, input_values, field_ids):
|
||||
|
||||
_check_input_2d(F.shape(input_indices), "input_indices", self.cls_name)
|
||||
_check_input_2d(F.shape(input_values), "input_values", self.cls_name)
|
||||
_check_input_2d(F.shape(field_ids), "field_ids", self.cls_name)
|
||||
_check_input_dtype(F.dtype(input_indices), "input_indices", [mstype.int32, mstype.int64], self.cls_name)
|
||||
_check_input_dtype(F.dtype(input_values), "input_values", [mstype.float32], self.cls_name)
|
||||
_check_input_dtype(F.dtype(field_ids), "field_ids", [mstype.int32], self.cls_name)
|
||||
|
||||
batch_size = self.shape(input_indices)[0]
|
||||
num_segments = batch_size * self.field_size
|
||||
bias = Range(0, num_segments, self.field_size)()
|
||||
|
|
|
@ -14,11 +14,12 @@
|
|||
# ============================================================================
|
||||
""" test nn embedding """
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype
|
||||
from mindspore.common.api import _executor
|
||||
from mindspore.nn import Embedding
|
||||
from mindspore.nn import Embedding, MultiFieldEmbeddingLookup
|
||||
from ..ut_filter import non_graph_engine
|
||||
|
||||
|
||||
|
@ -43,6 +44,55 @@ def test_check_embedding_3():
|
|||
_executor.compile(net, input_data)
|
||||
|
||||
|
||||
def compile_multi_field_embedding(shape_id, shape_value, shape_field,
|
||||
type_id, type_value, type_field):
|
||||
net = MultiFieldEmbeddingLookup(20000, 768, 3)
|
||||
input_data = Tensor(np.ones(shape_id), type_id)
|
||||
input_value = Tensor(np.ones(shape_value), type_value)
|
||||
input_field = Tensor(np.ones(shape_field), type_field)
|
||||
_executor.compile(net, input_data, input_value, input_field)
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_check_multifield_embedding_right_type():
|
||||
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
|
||||
dtype.int64, dtype.float32, dtype.int32)
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_check_multifield_embedding_false_type_input():
|
||||
with pytest.raises(TypeError):
|
||||
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
|
||||
dtype.int16, dtype.float32, dtype.int32)
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_check_multifield_embedding_false_type_value():
|
||||
with pytest.raises(TypeError):
|
||||
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
|
||||
dtype.int16, dtype.float16, dtype.int32)
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_check_multifield_embedding_false_type_field_id():
|
||||
with pytest.raises(TypeError):
|
||||
compile_multi_field_embedding((8, 200), (8, 200), (8, 200),
|
||||
dtype.int16, dtype.float32, dtype.int16)
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_check_multifield_embedding_false_input_shape():
|
||||
with pytest.raises(TypeError):
|
||||
compile_multi_field_embedding((8,), (8, 200), (8, 200),
|
||||
dtype.int16, dtype.float32, dtype.int16)
|
||||
|
||||
|
||||
@non_graph_engine
|
||||
def test_check_multifield_embedding_false_value_shape():
|
||||
with pytest.raises(TypeError):
|
||||
compile_multi_field_embedding((8, 200), (8,), (8, 200),
|
||||
dtype.int16, dtype.float32, dtype.int16)
|
||||
|
||||
@non_graph_engine
|
||||
def test_print_embedding():
|
||||
net = Embedding(20000, 768, False)
|
||||
|
|
Loading…
Reference in New Issue