add check for ControlDepend

This commit is contained in:
huangdongrun 2020-05-20 20:39:59 +08:00
parent a3b9c238cc
commit 3e2074b113
2 changed files with 17 additions and 3 deletions

View File

@ -69,6 +69,8 @@ class ControlDepend(Primitive):
@prim_attr_register
def __init__(self, depend_mode=0):
"""init"""
validator.check_int_range(
"depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name)
def __call__(self, src, dst):
return src
@ -128,8 +130,10 @@ class GeSwitch(PrimitiveWithInfer):
return (data, data)
def infer_dtype(self, data_type, pred_type):
validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
validator.check_subclass(
"data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
validator.check_tensor_type_same(
{"pred": pred_type}, [mstype.bool_], self.name)
return (data_type, data_type)
@ -161,5 +165,6 @@ class Merge(PrimitiveWithInfer):
for i, item in enumerate(inputs):
args['inputs[%d]' % i] = item
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
validator.check_tensor_type_same(
args, (mstype.bool_,) + mstype.number_type, self.name)
return (inputs[0], mstype.int32)

View File

@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
""" test control ops """
import pytest
import numpy as np
import mindspore as ms
@ -434,3 +435,11 @@ def test_index_to_switch_layer():
C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
def test_control_depend_check():
with pytest.raises(TypeError) as e:
depend = P.ControlDepend(0.0)
with pytest.raises(ValueError) as e:
depend = P.ControlDepend(2)
with pytest.raises(TypeError) as e:
depend = P.ControlDepend((2,))