forked from OSSInnovation/mindspore
!2369 add cpu reduce op and cpu softmax_cross_entropy_with_logits op
Merge pull request !2369 from baihuawei/reduce
This commit is contained in:
commit
c9b8a8da0a
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "kernel/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h"
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
#include <cmath>
|
||||
#include "kernel/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "device/cpu/cpu_device_address.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
void SoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
|
||||
CPUKernel::InitInputOutputSize(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
size_t type_size = sizeof(float);
|
||||
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
|
||||
workspace_size_list_.emplace_back(tensor_size);
|
||||
}
|
||||
|
||||
void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
dnnl::memory::dims mem_dims;
|
||||
mem_dims.insert(mem_dims.end(), shape.begin(), shape.end());
|
||||
if (mem_dims.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "SoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size();
|
||||
}
|
||||
batch_size_ = shape[0];
|
||||
class_num_ = shape[1];
|
||||
if (batch_size_ == 0 || class_num_ == 0) {
|
||||
MS_LOG(EXCEPTION) << "invalid batch size or class num input!";
|
||||
}
|
||||
dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
|
||||
|
||||
dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1);
|
||||
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
|
||||
primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
|
||||
|
||||
AddArgument(DNNL_ARG_SRC, mem_desc);
|
||||
AddArgument(DNNL_ARG_DST, mem_desc);
|
||||
}
|
||||
|
||||
void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *logits, const float *labels,
|
||||
float *output1, float *output2) const {
|
||||
float epsilon = 1e-6;
|
||||
for (size_t i = 0; i < batch_size_; ++i) {
|
||||
output1[i] = 0;
|
||||
float loss = 0.0;
|
||||
for (size_t j = 0; j < class_num_; ++j) {
|
||||
float logit = logf(logits[i * class_num_ + j] <= 0.0 ? epsilon : logits[i * class_num_ + j]);
|
||||
output2[i * class_num_ + j] = logits[i * class_num_ + j] - labels[i * class_num_ + j];
|
||||
loss += labels[i * class_num_ + j] * logit;
|
||||
}
|
||||
output1[i] = -loss;
|
||||
}
|
||||
}
|
||||
|
||||
bool SoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (inputs.empty() || workspace.empty() || outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "error input output size!";
|
||||
}
|
||||
size_t batch_float_size = batch_size_ * sizeof(float);
|
||||
size_t batch_class_float_size = class_num_ * batch_float_size;
|
||||
if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size ||
|
||||
inputs[1]->size != batch_class_float_size) {
|
||||
MS_LOG(EXCEPTION) << "error input data size!";
|
||||
}
|
||||
if (outputs[1]->size != batch_class_float_size || outputs[0]->size != batch_float_size) {
|
||||
MS_LOG(EXCEPTION) << "error output data size!";
|
||||
}
|
||||
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
|
||||
SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr);
|
||||
ExecutePrimitive();
|
||||
auto labels = reinterpret_cast<float *>(inputs[1]->addr);
|
||||
auto logits = reinterpret_cast<float *>(workspace[0]->addr);
|
||||
auto output1 = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
auto output2 = reinterpret_cast<float *>(outputs[1]->addr);
|
||||
ForwardPostExecute(logits, labels, output1, output2);
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "kernel/cpu/mkldnn/mkl_cpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel {
|
||||
public:
|
||||
SoftmaxCrossEntropyWithLogitsCPUKernel() = default;
|
||||
~SoftmaxCrossEntropyWithLogitsCPUKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
void InitInputOutputSize(const CNodePtr &kernel_node) override;
|
||||
|
||||
private:
|
||||
void ForwardPostExecute(const float *logits, const float *labels, float *output1, float *output2) const;
|
||||
size_t class_num_{0};
|
||||
size_t batch_size_{0};
|
||||
};
|
||||
MS_REG_CPU_KERNEL(SoftmaxCrossEntropyWithLogits,
|
||||
KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
SoftmaxCrossEntropyWithLogitsCPUKernel);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_
|
|
@ -0,0 +1,161 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "kernel/cpu/reduce_cpu_kernel.h"
|
||||
#include "device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
const size_t kReduceTypeMax = 0;
|
||||
const size_t kReduceTypeMean = 1;
|
||||
const size_t kReduceTypeSum = 2;
|
||||
const size_t kMaxDim = 100;
|
||||
void ReduceCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
|
||||
if (kernel_name == "ReduceMax") {
|
||||
reduce_type_ = kReduceTypeMax;
|
||||
} else if (kernel_name == "ReduceMean") {
|
||||
reduce_type_ = kReduceTypeMean;
|
||||
} else if (kernel_name == "ReduceSum") {
|
||||
reduce_type_ = kReduceTypeSum;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Array reduce kernel type " << kernel_name << " is not supported.";
|
||||
}
|
||||
shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
|
||||
if (axis_addr->isa<ValueTuple>()) {
|
||||
auto attr_axis = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, AXIS);
|
||||
if (attr_axis.size() > shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "invalid axis size: " << axis_.size();
|
||||
} else if (attr_axis.empty()) {
|
||||
axis_.push_back(shape_.size() - 1);
|
||||
} else {
|
||||
for (auto axis : attr_axis) {
|
||||
if (IntToSize(axis) >= (shape_.size())) {
|
||||
MS_LOG(EXCEPTION) << "axis value is oversize.";
|
||||
}
|
||||
axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis);
|
||||
}
|
||||
}
|
||||
} else if (axis_addr->isa<Int32Imm>()) {
|
||||
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS);
|
||||
|
||||
if (axis >= 0 && IntToSize(axis) >= shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "axis value is oversize.";
|
||||
}
|
||||
axis < 0 ? axis_.push_back(axis + shape_.size()) : axis_.push_back(axis);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Attribute axis type is invalid.";
|
||||
}
|
||||
for (size_t i = 0; i < shape_.size(); ++i) {
|
||||
if (shape_[i] <= 0) {
|
||||
MS_LOG(EXCEPTION) << "shape value is invalid.";
|
||||
}
|
||||
left_dims_ *= shape_[i];
|
||||
}
|
||||
for (size_t i = 0; i < axis_.size(); ++i) {
|
||||
stride_ *= shape_[axis_[i]];
|
||||
}
|
||||
if (stride_ <= 0) {
|
||||
MS_LOG(EXCEPTION) << "stride_ must greater than zero.";
|
||||
}
|
||||
left_dims_ = left_dims_ / stride_;
|
||||
}
|
||||
bool ReduceCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> & /*workspaces*/,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "input or output empty!";
|
||||
}
|
||||
size_t out_float_size = left_dims_ * sizeof(float);
|
||||
size_t in_float_size = stride_ * out_float_size;
|
||||
if (inputs[0]->size != in_float_size || outputs[0]->size != out_float_size) {
|
||||
MS_LOG(EXCEPTION) << "invalid input or output data size!";
|
||||
}
|
||||
auto input = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
auto output = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
int size = inputs[0]->size / sizeof(float);
|
||||
std::vector<float> new_input(IntToSize(size), 0.0);
|
||||
std::vector<size_t> transpose_axis;
|
||||
for (size_t i = 0; i < shape_.size(); ++i) {
|
||||
bool insert = true;
|
||||
for (size_t j = 0; j < axis_.size(); ++j) {
|
||||
if (axis_[j] == i) {
|
||||
insert = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (insert) {
|
||||
transpose_axis.push_back(i);
|
||||
}
|
||||
}
|
||||
(void)transpose_axis.insert(transpose_axis.end(), axis_.begin(), axis_.end());
|
||||
Transpose(size, input, shape_, transpose_axis, SizeToInt(shape_.size()), &new_input[0]);
|
||||
if (reduce_type_ == kReduceTypeMax) {
|
||||
for (size_t i = 0; i < left_dims_; ++i) {
|
||||
float value = new_input[i * stride_];
|
||||
for (size_t k = 0; k < stride_; ++k) {
|
||||
if (value < new_input[i * stride_ + k]) {
|
||||
value = new_input[i * stride_ + k];
|
||||
}
|
||||
}
|
||||
output[i] = value;
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < left_dims_; ++i) {
|
||||
float value = 0.0;
|
||||
for (size_t k = 0; k < stride_; ++k) {
|
||||
value += new_input[i * stride_ + k];
|
||||
}
|
||||
if (reduce_type_ == kReduceTypeMean) {
|
||||
output[i] = value / stride_;
|
||||
} else {
|
||||
output[i] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
void ReduceCPUKernel::Transpose(const int size, const float *input, const std::vector<size_t> &input_shape,
|
||||
const std::vector<size_t> &input_axis, const int shape_size, float *output) {
|
||||
int pos_array[kMaxDim];
|
||||
int size_offset[kMaxDim];
|
||||
size_offset[0] = size / SizeToInt(input_shape[0]);
|
||||
for (int i = 1; i < shape_size; i++) {
|
||||
size_offset[i] = size_offset[i - 1] / SizeToInt(input_shape[i]);
|
||||
}
|
||||
for (int position = 0; position < size; position += 1) {
|
||||
int temp_position = position;
|
||||
pos_array[0] = temp_position / size_offset[0];
|
||||
for (int i = 1; i < shape_size; i++) {
|
||||
temp_position -= pos_array[i - 1] * size_offset[i - 1];
|
||||
pos_array[i] = temp_position / size_offset[i];
|
||||
}
|
||||
int new_position = pos_array[SizeToInt(input_axis[shape_size - 1])];
|
||||
int new_position_size = 1;
|
||||
for (int j = shape_size - 2; j >= 0; j--) {
|
||||
new_position_size *= SizeToInt(input_shape[SizeToInt(input_axis[j + 1])]);
|
||||
new_position += pos_array[SizeToInt(input_axis[j])] * new_position_size;
|
||||
}
|
||||
output[new_position] = input[position];
|
||||
}
|
||||
return;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "kernel/cpu/cpu_kernel.h"
|
||||
#include "kernel/cpu/cpu_kernel_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class ReduceCPUKernel : public CPUKernel {
|
||||
public:
|
||||
ReduceCPUKernel() = default;
|
||||
~ReduceCPUKernel() override = default;
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
void Transpose(const int size, const float *input, const std::vector<size_t> &input_shape,
|
||||
const std::vector<size_t> &input_axis, const int shape_size, float *output);
|
||||
size_t reduce_type_;
|
||||
std::vector<size_t> axis_;
|
||||
std::vector<size_t> shape_;
|
||||
size_t left_dims_ = 1;
|
||||
size_t stride_ = 1;
|
||||
};
|
||||
MS_REG_CPU_KERNEL(ReduceMean, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReduceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(ReduceMax, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReduceCPUKernel);
|
||||
MS_REG_CPU_KERNEL(ReduceSum, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ReduceCPUKernel);
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_KERNEL_CPU_REDUCE_CPU_KERNEL_H_
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore.common.api import ms_function
|
||||
|
||||
context.set_context(device_target="CPU")
|
||||
|
||||
|
||||
class NetReduce(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetReduce, self).__init__()
|
||||
self.axis0 = 0
|
||||
self.axis1 = 1
|
||||
self.axis2 = -1
|
||||
self.axis3 = (0, 1)
|
||||
self.axis4 = (0, 1, 2)
|
||||
self.reduce_mean = P.ReduceMean(False)
|
||||
self.reduce_sum = P.ReduceSum(False)
|
||||
self.reduce_max = P.ReduceMax(False)
|
||||
|
||||
@ms_function
|
||||
def construct(self, indice):
|
||||
return (self.reduce_mean(indice, self.axis0),
|
||||
self.reduce_mean(indice, self.axis1),
|
||||
self.reduce_mean(indice, self.axis2),
|
||||
self.reduce_mean(indice, self.axis3),
|
||||
self.reduce_mean(indice, self.axis4),
|
||||
self.reduce_sum(indice, self.axis0),
|
||||
self.reduce_sum(indice, self.axis2),
|
||||
self.reduce_max(indice, self.axis0),
|
||||
self.reduce_max(indice, self.axis2))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_reduce():
|
||||
reduce = NetReduce()
|
||||
indice = Tensor(np.array([
|
||||
[[0., 2., 1., 4., 0., 2.], [3., 1., 2., 2., 4., 0.]],
|
||||
[[2., 0., 1., 5., 0., 1.], [1., 0., 0., 4., 4., 3.]],
|
||||
[[4., 1., 4., 0., 0., 0.], [2., 5., 1., 0., 1., 3.]]
|
||||
]).astype(np.float32))
|
||||
output = reduce(indice)
|
||||
print(output[0])
|
||||
print(output[1])
|
||||
print(output[2])
|
||||
print(output[3])
|
||||
print(output[4])
|
||||
print(output[5])
|
||||
print(output[6])
|
||||
print(output[7])
|
||||
print(output[8])
|
||||
expect_0 = np.array([[2., 1., 2., 3., 0., 1], [2., 2., 1., 2., 3., 2.]]).astype(np.float32)
|
||||
expect_1 = np.array([[1.5, 1.5, 1.5, 3., 2., 1.], [1.5, 0., 0.5, 4.5, 2., 2.], [3., 3., 2.5, 0., 0.5, 1.5]]).astype(
|
||||
np.float32)
|
||||
expect_2 = np.array([[1.5, 2.], [1.5, 2.], [1.5, 2.]]).astype(np.float32)
|
||||
expect_3 = np.array([2, 1.5, 1.5, 2.5, 1.5, 1.5]).astype(np.float32)
|
||||
expect_4 = np.array([1.75]).astype(np.float32)
|
||||
expect_5 = np.array([[6., 3., 6., 9., 0., 3.], [6., 6., 3., 6., 9., 6.]]).astype(np.float32)
|
||||
expect_6 = np.array([[9., 12.], [9., 12.], [9., 12.]]).astype(np.float32)
|
||||
expect_7 = np.array([[4., 2., 4., 5., 0., 2.], [3., 5., 2., 4., 4., 3.]]).astype(np.float32)
|
||||
expect_8 = np.array([[4., 4.], [5., 4.], [4., 5.]]).astype(np.float32)
|
||||
assert (output[0].asnumpy() == expect_0).all()
|
||||
assert (output[1].asnumpy() == expect_1).all()
|
||||
assert (output[2].asnumpy() == expect_2).all()
|
||||
assert (output[3].asnumpy() == expect_3).all()
|
||||
assert (output[4].asnumpy() == expect_4).all()
|
||||
assert (output[5].asnumpy() == expect_5).all()
|
||||
assert (output[6].asnumpy() == expect_6).all()
|
||||
assert (output[7].asnumpy() == expect_7).all()
|
||||
assert (output[8].asnumpy() == expect_8).all()
|
||||
|
||||
|
||||
test_reduce()
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
|
||||
|
||||
class NetSoftmaxCrossEntropyWithLogits(nn.Cell):
|
||||
def __init__(self):
|
||||
super(NetSoftmaxCrossEntropyWithLogits, self).__init__()
|
||||
self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
|
||||
|
||||
def construct(self, logits, labels):
|
||||
return self.loss(logits, labels)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_softmax_cross_entropy_with_logits():
|
||||
logits = Tensor(np.array([[1, 1, 10],
|
||||
[1, 10, 1],
|
||||
[10, 1, 1]]).astype(np.float32))
|
||||
labels = Tensor(np.array([[0, 0, 1],
|
||||
[0, 1, 0],
|
||||
[1, 0, 0]]).astype(np.float32))
|
||||
expect_loss = [0.00024673, 0.00024673, 0.00024673]
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
softmax_cross_entropy_with_logits = NetSoftmaxCrossEntropyWithLogits()
|
||||
output = softmax_cross_entropy_with_logits(logits, labels)
|
||||
error0 = 1.0e-6
|
||||
diff0 = output.asnumpy() - expect_loss
|
||||
assert np.all(abs(diff0) < error0)
|
||||
|
||||
test_softmax_cross_entropy_with_logits()
|
Loading…
Reference in New Issue