forked from mindspore-Ecosystem/mindspore
add check for ControlDepend
This commit is contained in:
parent
a3b9c238cc
commit
3e2074b113
|
@ -69,6 +69,8 @@ class ControlDepend(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self, depend_mode=0):
|
||||
"""init"""
|
||||
validator.check_int_range(
|
||||
"depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name)
|
||||
|
||||
def __call__(self, src, dst):
|
||||
return src
|
||||
|
@ -128,8 +130,10 @@ class GeSwitch(PrimitiveWithInfer):
|
|||
return (data, data)
|
||||
|
||||
def infer_dtype(self, data_type, pred_type):
|
||||
validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
|
||||
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
|
||||
validator.check_subclass(
|
||||
"data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
|
||||
validator.check_tensor_type_same(
|
||||
{"pred": pred_type}, [mstype.bool_], self.name)
|
||||
return (data_type, data_type)
|
||||
|
||||
|
||||
|
@ -161,5 +165,6 @@ class Merge(PrimitiveWithInfer):
|
|||
for i, item in enumerate(inputs):
|
||||
args['inputs[%d]' % i] = item
|
||||
|
||||
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
validator.check_tensor_type_same(
|
||||
args, (mstype.bool_,) + mstype.number_type, self.name)
|
||||
return (inputs[0], mstype.int32)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test control ops """
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
|
@ -434,3 +435,11 @@ def test_index_to_switch_layer():
|
|||
C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
|
||||
Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||
|
||||
def test_control_depend_check():
|
||||
with pytest.raises(TypeError) as e:
|
||||
depend = P.ControlDepend(0.0)
|
||||
with pytest.raises(ValueError) as e:
|
||||
depend = P.ControlDepend(2)
|
||||
with pytest.raises(TypeError) as e:
|
||||
depend = P.ControlDepend((2,))
|
Loading…
Reference in New Issue