fix softmax

This commit is contained in:
baihuawei 2020-09-27 09:53:27 +08:00
parent f5128faba5
commit ab427caf53
2 changed files with 33 additions and 2 deletions

View File

@ -28,9 +28,12 @@ void SoftmaxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "cpu softmax only support input axis size 1";
}
int axis = axis_list[0];
if (axis == -1 || axis >= SizeToInt(src_shape.size())) {
if (axis >= SizeToInt(src_shape.size())) {
axis = SizeToInt(src_shape.size()) - 1;
}
while (axis < 0) {
axis += SizeToInt(src_shape.size());
}
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, src_desc, axis);
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());

View File

@ -29,7 +29,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class NetSoftmax(nn.Cell):
def __init__(self):
super(NetSoftmax, self).__init__()
self.softmax = P.Softmax()
self.softmax = P.Softmax(axis=-1)
x = Tensor(np.array([[0.1, 0.3, 0.6],
[0.2, -0.6, 0.8],
[0.6, 1, 0.4]]).astype(np.float32))
@ -52,3 +52,31 @@ def test_softmax():
diff = np.abs(outputSum - expect)
print(diff)
assert np.all(diff < error)
class NetSoftmax1(nn.Cell):
def __init__(self):
super(NetSoftmax1, self).__init__()
self.softmax = P.Softmax(axis=-2)
x = Tensor(np.array([[0.1, 0.3, 0.6],
[0.2, -0.6, 0.8],
[0.6, 1, 0.4]]).astype(np.float32))
self.x = Parameter(initializer(x, x.shape), name='x')
def construct(self):
return self.softmax(self.x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_softmax1():
Softmax = NetSoftmax1()
output = Softmax()
output = output.asnumpy()
outputSum = output.sum(axis=0)
expect = np.ones(3)
error = expect * 1.0e-6
diff = np.abs(outputSum - expect)
print(diff)
assert np.all(diff < error)