!43211 Add Multinomial, Adam, LogSoftmaxGrad and MaxPoolGrad Vmap test case
Merge pull request !43211 from YijieChen/vmap
This commit is contained in:
commit
8388e614f8
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,16 +13,19 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.nn import Dense
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.functional import vmap
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
|
||||
|
@ -184,3 +187,58 @@ def test_adam_offload_acc():
|
|||
losses1.append(loss.asnumpy())
|
||||
|
||||
assert np.array_equal(losses1[-1], np.array(2.2237475, np.float32))
|
||||
|
||||
|
||||
def numpy_apply_adam(var, m, v, grad, beta1=0.9, beta2=0.999, eps=1e-8, lr=0.01):
|
||||
new_lr = lr * math.sqrt(1 - beta2) / (1 - beta1)
|
||||
m = m * beta1 + grad * (1 - beta1)
|
||||
v = v * beta2 + grad * grad * (1 - beta2)
|
||||
var = var - new_lr * m / (np.sqrt(v) + eps)
|
||||
return var
|
||||
|
||||
|
||||
class AdamNetVmap(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(AdamNetVmap, self).__init__()
|
||||
shape = (8, 9, 6, 10, 5)
|
||||
self.net = net
|
||||
self.var_np = np.random.randn(*shape).astype(np.float32)
|
||||
self.m_np = np.random.randn(*shape).astype(np.float32)
|
||||
self.v_np = np.random.randn(*shape).astype(np.float32)
|
||||
self.var = Parameter(Tensor(self.var_np), name="var")
|
||||
self.m = Parameter(Tensor(self.m_np), name="m")
|
||||
self.v = Parameter(Tensor(self.v_np), name="v")
|
||||
self.vmap_adam = vmap(self.net, in_axes=(
|
||||
0, 0, 0, None, None, None, None, None, None, 0), out_axes=0)
|
||||
|
||||
def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
|
||||
return self.vmap_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_arm_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_apply_adam_witm_adam_op_vmap():
|
||||
"""
|
||||
Feature: Adam cpu kernel
|
||||
Description: test the Adam vmap.
|
||||
Expectation: match to np benchmark.
|
||||
"""
|
||||
shape = (8, 9, 6, 10, 5)
|
||||
|
||||
def cal_amsgrad(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
|
||||
return P.Adam()(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
|
||||
|
||||
error = 1e-4
|
||||
grad_np = np.random.randn(*shape).astype(np.float32)
|
||||
grad = Tensor(grad_np)
|
||||
|
||||
vmap_adam = AdamNetVmap(cal_amsgrad)
|
||||
_ = vmap_adam(Tensor(0.9, ms.float32), Tensor(0.999, ms.float32), Tensor(0.01, ms.float32), Tensor(
|
||||
0.9, ms.float32), Tensor(0.999, ms.float32), Tensor(1e-8, ms.float32), grad)
|
||||
ms_var = vmap_adam.var.asnumpy()
|
||||
np_var = numpy_apply_adam(vmap_adam.var_np, vmap_adam.m_np,
|
||||
vmap_adam.v_np, grad_np)
|
||||
|
||||
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
|
||||
|
|
|
@ -237,3 +237,43 @@ def test_logsoftmaxgrad1():
|
|||
diff = dx[0].asnumpy() - expect
|
||||
err = np.ones(shape=expect.shape) * 1.0e-5
|
||||
assert np.all(diff < err)
|
||||
|
||||
|
||||
class LogSoftmaxForForward(nn.Cell):
|
||||
def __init__(self, axis=0):
|
||||
super().__init__()
|
||||
self.axis = axis
|
||||
self.logsoftmax = P.LogSoftmax(axis=axis)
|
||||
self.stack = P.Stack(axis=axis)
|
||||
|
||||
def construct(self, x):
|
||||
out = []
|
||||
for i in range(x.shape[self.axis]):
|
||||
out.append(self.logsoftmax(x[i]))
|
||||
out = self.stack(out)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_logsoftmaxgrad_vmap():
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for LogSoftmax Grad vmap
|
||||
Expectation: the result match result
|
||||
"""
|
||||
seed = np.random.RandomState()
|
||||
x = Tensor(seed.random((3, 5, 1)).astype(np.float32))
|
||||
sens = Tensor(seed.random((3, 5, 1)).astype(np.float32))
|
||||
|
||||
forward = LogSoftmax(axis=0)
|
||||
for_forward = LogSoftmaxForForward(axis=0)
|
||||
backward = Grad(forward)
|
||||
for_backward = Grad(for_forward)
|
||||
|
||||
forward_result = forward(x)
|
||||
backward_vmap = ops.vmap(backward, in_axes=0, out_axes=0)(forward_result, sens)
|
||||
backward_for = for_backward(forward_result, sens)
|
||||
|
||||
np.testing.assert_allclose(backward_for[0].asnumpy(), backward_vmap[0].asnumpy(), rtol=1e-5)
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.functional import vmap
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
class MaxPool(nn.Cell):
|
||||
def __init__(self, kernel_size, strides, pad_mode):
|
||||
super().__init__()
|
||||
self.maxpool = P.MaxPool(kernel_size=kernel_size, strides=strides, pad_mode=pad_mode)
|
||||
|
||||
def construct(self, x):
|
||||
return self.maxpool(x)
|
||||
|
||||
|
||||
class MaxPoolGrad(nn.Cell):
|
||||
def __init__(self, forward):
|
||||
super().__init__()
|
||||
self.forward = forward
|
||||
self.grad = C.GradOperation(get_all=True, sens_param=True)
|
||||
|
||||
def construct(self, x, sens):
|
||||
return self.grad(self.forward)(x, sens)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_maxpool_grad_vmap():
|
||||
"""
|
||||
Feature: test MaxPoolGrad vmap feature.
|
||||
Description: test MaxPoolGrad vmap feature.
|
||||
Expectation: success.
|
||||
"""
|
||||
in_axes = -1
|
||||
seed = np.random.RandomState()
|
||||
x = Tensor(seed.random((1, 1, 6, 6, 3, 6)).astype(np.float32))
|
||||
sens = Tensor(seed.random((1, 1, 3, 3, 3, 6)).astype(np.float32))
|
||||
maxpool = MaxPool(kernel_size=2, strides=2, pad_mode="VALID")
|
||||
bp = MaxPoolGrad(maxpool)
|
||||
maxpoolgrad_vmap = vmap(vmap(bp, in_axes=in_axes, out_axes=0), in_axes=in_axes, out_axes=0)
|
||||
out = maxpoolgrad_vmap(x, sens)
|
||||
|
||||
assert out[0].shape == (6, 3, 1, 1, 6, 6)
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
# Copyright 2021-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -17,9 +17,11 @@ import pytest
|
|||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.functional import vmap
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
@ -97,3 +99,35 @@ def test_multinomial_dynamic_shape():
|
|||
outputs = dynamic_shape_net(x, indices_ms)
|
||||
expect_shape = (len(np.unique(indices_np)), 2)
|
||||
assert outputs.asnumpy().shape == expect_shape
|
||||
|
||||
|
||||
class BatchedMultinomial(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.multinomial = P.Multinomial(seed=5, seed2=6)
|
||||
|
||||
def construct(self, prob, num_sample):
|
||||
return self.multinomial(prob, num_sample)
|
||||
|
||||
|
||||
def multinomial(prob, num_sample):
|
||||
return P.Multinomial(seed=5, seed2=6)(prob, num_sample)
|
||||
|
||||
|
||||
@pytest.mark.level2
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_multinomial_vmap():
|
||||
"""
|
||||
Feature: test Multinomial vmap feature.
|
||||
Description: test Multinomial vmap feature.
|
||||
Expectation: success.
|
||||
"""
|
||||
prob = Tensor([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], ms.float32)
|
||||
num_sample = 3
|
||||
|
||||
batched_multinomial = BatchedMultinomial()
|
||||
batched_out = batched_multinomial(prob, num_sample)
|
||||
vmap_out = vmap(multinomial, in_axes=(0, None), out_axes=0)(prob, num_sample)
|
||||
|
||||
assert (batched_out.asnumpy() == vmap_out.asnumpy()).all()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -13,16 +13,19 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.nn import Dense
|
||||
from mindspore.nn import TrainOneStepCell, WithLossCell
|
||||
from mindspore.nn.optim import Adam
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.functional import vmap
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
|
@ -142,3 +145,57 @@ def test_adam_offload_acc():
|
|||
losses1.append(loss.asnumpy())
|
||||
|
||||
assert np.array_equal(losses1[-1], np.array(2.2237475, np.float32))
|
||||
|
||||
|
||||
def numpy_apply_adam(var, m, v, grad, beta1=0.9, beta2=0.999, eps=1e-8, lr=0.01):
|
||||
new_lr = lr * math.sqrt(1 - beta2) / (1 - beta1)
|
||||
m = m * beta1 + grad * (1 - beta1)
|
||||
v = v * beta2 + grad * grad * (1 - beta2)
|
||||
var = var - new_lr * m / (np.sqrt(v) + eps)
|
||||
return var
|
||||
|
||||
|
||||
class AdamNetVmap(nn.Cell):
|
||||
def __init__(self, net):
|
||||
super(AdamNetVmap, self).__init__()
|
||||
shape = (8, 9, 6, 10, 5)
|
||||
self.net = net
|
||||
self.var_np = np.random.randn(*shape).astype(np.float32)
|
||||
self.m_np = np.random.randn(*shape).astype(np.float32)
|
||||
self.v_np = np.random.randn(*shape).astype(np.float32)
|
||||
self.var = Parameter(Tensor(self.var_np), name="var")
|
||||
self.m = Parameter(Tensor(self.m_np), name="m")
|
||||
self.v = Parameter(Tensor(self.v_np), name="v")
|
||||
self.vmap_adam = vmap(self.net, in_axes=(
|
||||
0, 0, 0, None, None, None, None, None, None, 0), out_axes=0)
|
||||
|
||||
def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
|
||||
return self.vmap_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_apply_adam_witm_adam_op_vmap():
|
||||
"""
|
||||
Feature: Adam gpu kernel
|
||||
Description: test the Adam vmap.
|
||||
Expectation: match to np benchmark.
|
||||
"""
|
||||
shape = (8, 9, 6, 10, 5)
|
||||
|
||||
def cal_amsgrad(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
|
||||
return P.Adam()(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad)
|
||||
|
||||
error = 1e-4
|
||||
grad_np = np.random.randn(*shape).astype(np.float32)
|
||||
grad = Tensor(grad_np)
|
||||
|
||||
vmap_adam = AdamNetVmap(cal_amsgrad)
|
||||
_ = vmap_adam(Tensor(0.9, ms.float32), Tensor(0.999, ms.float32), Tensor(0.01, ms.float32), Tensor(
|
||||
0.9, ms.float32), Tensor(0.999, ms.float32), Tensor(1e-8, ms.float32), grad)
|
||||
ms_var = vmap_adam.var.asnumpy()
|
||||
np_var = numpy_apply_adam(vmap_adam.var_np, vmap_adam.m_np,
|
||||
vmap_adam.v_np, grad_np)
|
||||
|
||||
np.testing.assert_allclose(ms_var, np_var, rtol=error, atol=error)
|
||||
|
|
|
@ -21,6 +21,7 @@ import mindspore.nn as nn
|
|||
from mindspore import Tensor
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.functional import vmap
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
|
@ -35,8 +36,8 @@ def test_logsoftmax():
|
|||
[-3.452001, -1.2546989, -1.4618242, -0.79552734]]).astype(np.float32)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
logSoftmax = P.LogSoftmax()
|
||||
output = logSoftmax(Tensor(x))
|
||||
log_softmax = P.LogSoftmax()
|
||||
output = log_softmax(Tensor(x))
|
||||
assert np.allclose(output.asnumpy(), expect)
|
||||
|
||||
|
||||
|
@ -140,3 +141,43 @@ def test_logsoftmaxgrad1():
|
|||
net = LogSoftmax(0)
|
||||
dx = Grad(net)(Tensor(x), Tensor(dy))
|
||||
assert np.allclose(dx[0].asnumpy(), expect)
|
||||
|
||||
|
||||
class LogSoftmaxForForward(nn.Cell):
|
||||
def __init__(self, axis=0):
|
||||
super().__init__()
|
||||
self.axis = axis
|
||||
self.logsoftmax = P.LogSoftmax(axis=axis)
|
||||
self.stack = P.Stack(axis=axis)
|
||||
|
||||
def construct(self, x):
|
||||
out = []
|
||||
for i in range(x.shape[self.axis]):
|
||||
out.append(self.logsoftmax(x[i]))
|
||||
out = self.stack(out)
|
||||
return out
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_logsoftmaxgrad_vmap():
|
||||
"""
|
||||
Feature: ALL To ALL
|
||||
Description: test cases for LogSoftmax Grad vmap
|
||||
Expectation: the result match result
|
||||
"""
|
||||
seed = np.random.RandomState()
|
||||
x = Tensor(seed.random((3, 5, 1)).astype(np.float32))
|
||||
sens = Tensor(seed.random((3, 5, 1)).astype(np.float32))
|
||||
|
||||
forward = LogSoftmax(axis=0)
|
||||
for_forward = LogSoftmaxForForward(axis=0)
|
||||
backward = Grad(forward)
|
||||
for_backward = Grad(for_forward)
|
||||
|
||||
forward_result = forward(x)
|
||||
backward_vmap = vmap(backward, in_axes=0, out_axes=0)(forward_result, sens)
|
||||
backward_for = for_backward(forward_result, sens)
|
||||
|
||||
np.testing.assert_allclose(backward_for[0].asnumpy(), backward_vmap[0].asnumpy(), rtol=1e-5)
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.functional import vmap
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
|
||||
class MaxPool(nn.Cell):
|
||||
def __init__(self, kernel_size, strides, pad_mode):
|
||||
super().__init__()
|
||||
self.maxpool = P.MaxPool(kernel_size=kernel_size, strides=strides, pad_mode=pad_mode)
|
||||
|
||||
def construct(self, x):
|
||||
return self.maxpool(x)
|
||||
|
||||
|
||||
class MaxPoolGrad(nn.Cell):
|
||||
def __init__(self, forward):
|
||||
super().__init__()
|
||||
self.forward = forward
|
||||
self.grad = C.GradOperation(get_all=True, sens_param=True)
|
||||
|
||||
def construct(self, x, sens):
|
||||
return self.grad(self.forward)(x, sens)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_maxpool_grad_vmap():
|
||||
"""
|
||||
Feature: test MaxPoolGrad vmap feature.
|
||||
Description: test MaxPoolGrad vmap feature.
|
||||
Expectation: success.
|
||||
"""
|
||||
in_axes = -1
|
||||
seed = np.random.RandomState()
|
||||
x = Tensor(seed.random((1, 1, 6, 6, 3, 6)).astype(np.float32))
|
||||
sens = Tensor(seed.random((1, 1, 3, 3, 3, 6)).astype(np.float32))
|
||||
maxpool = MaxPool(kernel_size=2, strides=2, pad_mode="VALID")
|
||||
bp = MaxPoolGrad(maxpool)
|
||||
maxpoolgrad_vmap = vmap(vmap(bp, in_axes=in_axes, out_axes=0), in_axes=in_axes, out_axes=0)
|
||||
out = maxpoolgrad_vmap(x, sens)
|
||||
|
||||
assert out[0].shape == (6, 3, 1, 1, 6, 6)
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -16,12 +16,16 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.functional import vmap
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, sample, replacement, seed=0):
|
||||
super(Net, self).__init__()
|
||||
|
@ -32,6 +36,7 @@ class Net(nn.Cell):
|
|||
def construct(self, x):
|
||||
return C.multinomial(x, self.sample, self.replacement, self.seed)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -47,3 +52,35 @@ def test_multinomial():
|
|||
assert out0.asnumpy().shape == (1,)
|
||||
assert out1.asnumpy().shape == (2,)
|
||||
assert out2.asnumpy().shape == (2, 6)
|
||||
|
||||
|
||||
class BatchedMultinomial(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.multinomial = P.Multinomial(seed=5, seed2=6)
|
||||
|
||||
def construct(self, prob, num_sample):
|
||||
return self.multinomial(prob, num_sample)
|
||||
|
||||
|
||||
def multinomial(prob, num_sample):
|
||||
return P.Multinomial(seed=5, seed2=6)(prob, num_sample)
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_multinomial_vmap():
|
||||
"""
|
||||
Feature: test Multinomial vmap feature.
|
||||
Description: test Multinomial vmap feature.
|
||||
Expectation: success.
|
||||
"""
|
||||
prob = Tensor([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]], ms.float32)
|
||||
num_sample = 3
|
||||
|
||||
batched_multinomial = BatchedMultinomial()
|
||||
batched_out = batched_multinomial(prob, num_sample)
|
||||
vmap_out = vmap(multinomial, in_axes=(0, None), out_axes=0)(prob, num_sample)
|
||||
|
||||
assert (batched_out.asnumpy() == vmap_out.asnumpy()).all()
|
||||
|
|
Loading…
Reference in New Issue