forked from mindspore-Ecosystem/mindspore
fix example of ControlDepend to ensure pass in gpu and cpu
This commit is contained in:
parent
80978cf3cc
commit
3684ec4967
|
@ -51,16 +51,18 @@ class ControlDepend(Primitive):
|
|||
>>> class Net(nn.Cell):
|
||||
>>> def __init__(self):
|
||||
>>> super(Net, self).__init__()
|
||||
>>> self.global_step = mindspore.Parameter(initializer(0, [1]), name="global_step")
|
||||
>>> self.rate = 0.2
|
||||
>>> self.control_depend = P.ControlDepend()
|
||||
>>> self.softmax = P.Softmax()
|
||||
>>>
|
||||
>>> def construct(self, x):
|
||||
>>> data = self.rate * self.global_step + x
|
||||
>>> added_global_step = self.global_step + 1
|
||||
>>> self.global_step = added_global_step
|
||||
>>> self.control_depend(data, added_global_step)
|
||||
>>> return data
|
||||
>>> def construct(self, x, y):
|
||||
>>> mul = x * y
|
||||
>>> softmax = self.softmax(x)
|
||||
>>> ret = self.control_depend(mul, softmax)
|
||||
>>> return ret
|
||||
>>> x = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
|
||||
>>> y = Tensor(np.ones([4, 5]), dtype=mindspore.float32)
|
||||
>>> net = Net()
|
||||
>>> output = net(x, y)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
|
Loading…
Reference in New Issue