change from CRLF to LF

This commit is contained in:
zhujingxuan 2021-07-13 16:50:33 +08:00
parent 85e20508eb
commit b3d4399d32
27 changed files with 3102 additions and 3102 deletions

View File

@ -1,116 +1,116 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/binary_cross_entropy_cpu_kernel.h" #include "backend/kernel_compiler/cpu/binary_cross_entropy_cpu_kernel.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
constexpr size_t kBceInputNumWithWeight = 3; constexpr size_t kBceInputNumWithWeight = 3;
template <typename T>
void BinaryCrossEntropyCpuKernel::LaunchToScalar(const int &input_size, const int &reduction, T *loss, T *tmp_loss) {
  // Pairwise (tree) reduction of tmp_loss[0..input_size) into tmp_loss[0];
  // numerically more stable than a naive running sum (matters for float16).
  // Fix: the original folded tmp_loss[0] onto itself when input_size == 1,
  // doubling the result; only fold the trailing element when input_size > 1.
  if (input_size % 2 == 1 && input_size > 1) {
    tmp_loss[0] += tmp_loss[input_size - 1];
  }
  for (int stride = input_size / 2; stride > 0; stride = stride / 2) {
    for (int i = 0; i < stride; i++) {
      tmp_loss[i] += tmp_loss[i + stride];
    }
    // An odd active length would orphan its last element at the next halving;
    // fold it into slot 0 before the stride shrinks.
    if (stride > 2 && stride % 2 == 1) {
      tmp_loss[0] += tmp_loss[stride - 1];
    }
  }
  loss[0] += tmp_loss[0];
  if (reduction == 1) {
    // reduction == 1 means 'mean': divide the accumulated sum by the count.
    loss[0] /= static_cast<T>(input_size);
  }
}
template <typename T> template <typename T>
void BinaryCrossEntropyCpuKernel::Launchkernel(const std::vector<AddressPtr> &inputs, void BinaryCrossEntropyCpuKernel::Launchkernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
T *input_x = reinterpret_cast<T *>(inputs[0]->addr); T *input_x = reinterpret_cast<T *>(inputs[0]->addr);
T *input_y = reinterpret_cast<T *>(inputs[1]->addr); T *input_y = reinterpret_cast<T *>(inputs[1]->addr);
T *weight = nullptr; T *weight = nullptr;
if (weight_defined_) { if (weight_defined_) {
weight = reinterpret_cast<T *>(inputs[2]->addr); weight = reinterpret_cast<T *>(inputs[2]->addr);
} }
T *loss = reinterpret_cast<T *>(outputs[0]->addr); T *loss = reinterpret_cast<T *>(outputs[0]->addr);
std::vector<T> tmp_loss(input_size_); std::vector<T> tmp_loss(input_size_);
T epsilon = static_cast<T>(1e-12); T epsilon = static_cast<T>(1e-12);
T one = static_cast<T>(1); T one = static_cast<T>(1);
if (reduction_ == 0 && weight_defined_) { if (reduction_ == 0 && weight_defined_) {
for (size_t i = 0; i < input_size_; i++) { for (size_t i = 0; i < input_size_; i++) {
T value = T value =
-weight[i] * (input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon)); -weight[i] * (input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon));
loss[i] = value; loss[i] = value;
} }
} else if (reduction_ == 0 && (!weight_defined_)) { } else if (reduction_ == 0 && (!weight_defined_)) {
for (size_t i = 0; i < input_size_; i++) { for (size_t i = 0; i < input_size_; i++) {
T value = -(input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon)); T value = -(input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon));
loss[i] = value; loss[i] = value;
} }
} else if ((reduction_ != 0) && weight_defined_) { } else if ((reduction_ != 0) && weight_defined_) {
for (size_t i = 0; i < input_size_; i++) { for (size_t i = 0; i < input_size_; i++) {
T value = T value =
-weight[i] * (input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon)); -weight[i] * (input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon));
tmp_loss[i] = value; tmp_loss[i] = value;
} }
} else { } else {
for (size_t i = 0; i < input_size_; i++) { for (size_t i = 0; i < input_size_; i++) {
T value = -(input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon)); T value = -(input_y[i] * log(input_x[i] + epsilon) + (one - input_y[i]) * log(one - input_x[i] + epsilon));
tmp_loss[i] = value; tmp_loss[i] = value;
} }
} }
if (reduction_ != 0) { if (reduction_ != 0) {
LaunchToScalar<T>(input_size_, reduction_, loss, tmp_loss.data()); LaunchToScalar<T>(input_size_, reduction_, loss, tmp_loss.data());
} }
} }
bool BinaryCrossEntropyCpuKernel::Launch(const std::vector<AddressPtr> &inputs, bool BinaryCrossEntropyCpuKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
if (input_size_ > 0) { if (input_size_ > 0) {
if (dtype_ == kNumberTypeFloat32) { if (dtype_ == kNumberTypeFloat32) {
Launchkernel<float>(inputs, workspace, outputs); Launchkernel<float>(inputs, workspace, outputs);
} else if (dtype_ == kNumberTypeFloat16) { } else if (dtype_ == kNumberTypeFloat16) {
Launchkernel<float16>(inputs, workspace, outputs); Launchkernel<float16>(inputs, workspace, outputs);
} }
} }
return true; return true;
} }
void BinaryCrossEntropyCpuKernel::InitKernel(const CNodePtr &kernel_node) { void BinaryCrossEntropyCpuKernel::InitKernel(const CNodePtr &kernel_node) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); i++) { for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i]; input_size_ *= input_shape[i];
} }
string reduction = AnfAlgo::GetNodeAttr<string>(kernel_node, "reduction"); string reduction = AnfAlgo::GetNodeAttr<string>(kernel_node, "reduction");
if (reduction == "none") { if (reduction == "none") {
reduction_ = 0; reduction_ = 0;
} else if (reduction == "sum") { } else if (reduction == "sum") {
reduction_ = 2; reduction_ = 2;
} }
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
weight_defined_ = (input_num == kBceInputNumWithWeight); weight_defined_ = (input_num == kBceInputNumWithWeight);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,71 +1,71 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H
#include <vector> #include <vector>
#include <string> #include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// CPU kernel computing binary cross entropy between probabilities (input 0)
// and labels (input 1), with an optional element-wise weight (input 2).
class BinaryCrossEntropyCpuKernel : public CPUKernel {
 public:
  // input_size_ starts at 1 so InitKernel can fold shape dims in by
  // multiplication; reduction_ defaults to 1 ('mean').
  BinaryCrossEntropyCpuKernel() : input_size_(1), reduction_(1), weight_defined_(false) {}
  ~BinaryCrossEntropyCpuKernel() override = default;
  // Reads the input shape, 'reduction' attribute, input count and dtype
  // from the graph node.
  void InitKernel(const CNodePtr &kernel_node) override;
  // Dispatches to Launchkernel<float|float16> based on the cached dtype.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  // Reduces per-element losses in tmp_loss into the scalar loss[0]
  // (reduction == 1 additionally divides by input_size).
  template <typename T>
  void LaunchToScalar(const int &input_size, const int &reduction, T *loss, T *tmp_loss);
  template <typename T>
  void Launchkernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
  TypeId dtype_{kTypeUnknown};
  size_t input_size_;  // total element count of input 0
  int reduction_;      // 0: none, 1: mean, 2: sum
  bool weight_defined_;  // true: there are 3 inputs, false: there are 2 inputs(no [weight])
};
// 3-input registrations: (probabilities, labels, weight) -> loss, fp16 and fp32.
MS_REG_CPU_KERNEL(BinaryCrossEntropy,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16),
                  BinaryCrossEntropyCpuKernel);
MS_REG_CPU_KERNEL(BinaryCrossEntropy,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  BinaryCrossEntropyCpuKernel);
// 2-input registrations: same op without the optional weight tensor.
MS_REG_CPU_KERNEL(
  BinaryCrossEntropy,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  BinaryCrossEntropyCpuKernel);
MS_REG_CPU_KERNEL(
  BinaryCrossEntropy,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  BinaryCrossEntropyCpuKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_KERNEL_H

View File

@ -1,102 +1,102 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/binary_cross_entropy_grad_kernel.h" #include "backend/kernel_compiler/cpu/binary_cross_entropy_grad_kernel.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
constexpr size_t kBceGradInputNumWithWeight = 4; constexpr size_t kBceGradInputNumWithWeight = 4;
template <typename T> template <typename T>
void BinaryCrossEntropyGradCpuKernel::Launchkernel(const std::vector<AddressPtr> &inputs, void BinaryCrossEntropyGradCpuKernel::Launchkernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
T *input_x = reinterpret_cast<T *>(inputs[0]->addr); T *input_x = reinterpret_cast<T *>(inputs[0]->addr);
T *input_y = reinterpret_cast<T *>(inputs[1]->addr); T *input_y = reinterpret_cast<T *>(inputs[1]->addr);
T *dloss = reinterpret_cast<T *>(inputs[2]->addr); T *dloss = reinterpret_cast<T *>(inputs[2]->addr);
T *weight = nullptr; T *weight = nullptr;
if (weight_defined_) { if (weight_defined_) {
weight = reinterpret_cast<T *>(inputs[3]->addr); weight = reinterpret_cast<T *>(inputs[3]->addr);
} }
T *dx = reinterpret_cast<T *>(outputs[0]->addr); T *dx = reinterpret_cast<T *>(outputs[0]->addr);
T epsilon = static_cast<T>(1e-12); T epsilon = static_cast<T>(1e-12);
T one = static_cast<T>(1); T one = static_cast<T>(1);
if (reduction_ == 0) { if (reduction_ == 0) {
if (weight_defined_) { if (weight_defined_) {
for (size_t i = 0; i < input_size_; i++) { for (size_t i = 0; i < input_size_; i++) {
T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon; T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon;
T value = weight[i] * (input_x[i] - input_y[i]) / denominator; T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss[i]; dx[i] = value * dloss[i];
} }
} else { } else {
for (size_t i = 0; i < input_size_; i++) { for (size_t i = 0; i < input_size_; i++) {
T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon; T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon;
T value = (input_x[i] - input_y[i]) / denominator; T value = (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss[i]; dx[i] = value * dloss[i];
} }
} }
} else { } else {
T dloss1 = dloss[0]; T dloss1 = dloss[0];
if (reduction_ == 1) { if (reduction_ == 1) {
dloss1 = dloss[0] / static_cast<T>(input_size_); dloss1 = dloss[0] / static_cast<T>(input_size_);
} }
if (weight_defined_) { if (weight_defined_) {
for (size_t i = 0; i < input_size_; i++) { for (size_t i = 0; i < input_size_; i++) {
T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon; T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon;
T value = weight[i] * (input_x[i] - input_y[i]) / denominator; T value = weight[i] * (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss1; dx[i] = value * dloss1;
} }
} else { } else {
for (size_t i = 0; i < input_size_; i++) { for (size_t i = 0; i < input_size_; i++) {
T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon; T denominator = ((input_x[i] * (one - input_x[i])) > epsilon) ? (input_x[i] * (one - input_x[i])) : epsilon;
T value = (input_x[i] - input_y[i]) / denominator; T value = (input_x[i] - input_y[i]) / denominator;
dx[i] = value * dloss1; dx[i] = value * dloss1;
} }
} }
} }
} }
bool BinaryCrossEntropyGradCpuKernel::Launch(const std::vector<AddressPtr> &inputs, bool BinaryCrossEntropyGradCpuKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
if (input_size_ > 0) { if (input_size_ > 0) {
if (dtype_ == kNumberTypeFloat32) { if (dtype_ == kNumberTypeFloat32) {
Launchkernel<float>(inputs, outputs); Launchkernel<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat16) { } else if (dtype_ == kNumberTypeFloat16) {
Launchkernel<float16>(inputs, outputs); Launchkernel<float16>(inputs, outputs);
} }
} }
return true; return true;
} }
void BinaryCrossEntropyGradCpuKernel::InitKernel(const CNodePtr &kernel_node) { void BinaryCrossEntropyGradCpuKernel::InitKernel(const CNodePtr &kernel_node) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); i++) { for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i]; input_size_ *= input_shape[i];
} }
string reduction = AnfAlgo::GetNodeAttr<string>(kernel_node, "reduction"); string reduction = AnfAlgo::GetNodeAttr<string>(kernel_node, "reduction");
if (reduction == "none") { if (reduction == "none") {
reduction_ = 0; reduction_ = 0;
} else if (reduction == "sum") { } else if (reduction == "sum") {
reduction_ = 2; reduction_ = 2;
} }
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
weight_defined_ = (input_num == kBceGradInputNumWithWeight); weight_defined_ = (input_num == kBceGradInputNumWithWeight);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,76 +1,76 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H
#include <vector> #include <vector>
#include <string> #include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// CPU kernel computing the gradient of binary cross entropy w.r.t. the
// probabilities, given (probabilities, labels, upstream grad[, weight]).
class BinaryCrossEntropyGradCpuKernel : public CPUKernel {
 public:
  // input_size_ starts at 1 so InitKernel can fold shape dims in by
  // multiplication; reduction_ defaults to 1 ('mean').
  BinaryCrossEntropyGradCpuKernel() : input_size_(1), reduction_(1), weight_defined_(false) {}
  ~BinaryCrossEntropyGradCpuKernel() override = default;
  // Reads the input shape, 'reduction' attribute, input count and dtype
  // from the graph node.
  void InitKernel(const CNodePtr &kernel_node) override;
  // Dispatches to Launchkernel<float|float16> based on the cached dtype.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  template <typename T>
  void Launchkernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
  TypeId dtype_{kTypeUnknown};
  size_t input_size_;  // total element count of input 0
  int reduction_;      // 0: none, 1: mean, 2: sum
  bool weight_defined_;  // true: there are 4 inputs, false: there are 3 inputs(no [weight])
};
// 4-input registrations: (probabilities, labels, grad, weight) -> dx,
// fp16 and fp32.
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16),
                  BinaryCrossEntropyGradCpuKernel);
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  BinaryCrossEntropyGradCpuKernel);
// 3-input registrations: same op without the optional weight tensor.
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16),
                  BinaryCrossEntropyGradCpuKernel);
MS_REG_CPU_KERNEL(BinaryCrossEntropyGrad,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  BinaryCrossEntropyGradCpuKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_NN_BINARY_CROSS_ENTROPY_GRAD_KERNEL_H

View File

@ -1,271 +1,271 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include "common/thread_pool.h" #include "common/thread_pool.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { void CPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t input_index = 0; input_index < input_num; ++input_index) { for (size_t input_index = 0; input_index < input_num; ++input_index) {
TypeId type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, input_index); TypeId type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, input_index);
size_t type_size = GetTypeByte(TypeIdToType(type_id)); size_t type_size = GetTypeByte(TypeIdToType(type_id));
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index); std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index);
size_t tensor_size = size_t tensor_size =
shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
tensor_size = std::max(tensor_size, type_size); tensor_size = std::max(tensor_size, type_size);
input_size_list_.emplace_back(tensor_size); input_size_list_.emplace_back(tensor_size);
} }
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t output_index = 0; output_index < output_num; ++output_index) { for (size_t output_index = 0; output_index < output_num; ++output_index) {
TypeId type_id = AnfAlgo::GetOutputDeviceDataType(kernel_node, output_index); TypeId type_id = AnfAlgo::GetOutputDeviceDataType(kernel_node, output_index);
size_t type_size = GetTypeByte(TypeIdToType(type_id)); size_t type_size = GetTypeByte(TypeIdToType(type_id));
std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index); std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index);
size_t tensor_size = size_t tensor_size =
shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
tensor_size = std::max(tensor_size, type_size); tensor_size = std::max(tensor_size, type_size);
output_size_list_.emplace_back(tensor_size); output_size_list_.emplace_back(tensor_size);
} }
} }
void CPUKernel::Init(const CNodePtr &kernel_node) {
  // Template method: subclass-specific setup first, then compute the
  // input/output byte sizes the runtime uses to allocate device addresses.
  InitKernel(kernel_node);
  InitInputOutputSize(kernel_node);
}
void CPUKernelUtils::ExpandDimsTo4(std::vector<size_t> *shape) { void CPUKernelUtils::ExpandDimsTo4(std::vector<size_t> *shape) {
auto len = shape->size(); auto len = shape->size();
if (len < 4) { if (len < 4) {
for (size_t i = 0; i < 4 - len; ++i) { for (size_t i = 0; i < 4 - len; ++i) {
shape->insert(shape->begin(), 1); shape->insert(shape->begin(), 1);
} }
} }
} }
size_t CPUKernelUtils::CalcOffset(const std::vector<size_t> &shape, size_t dim0, size_t dim1, size_t dim2,
                                  size_t dim3) {
  // Row-major linear offset into a rank-4 tensor, evaluated Horner-style
  // (equivalent to d0*S1*S2*S3 + d1*S2*S3 + d2*S3 + d3 with fewer multiplies).
  return ((dim0 * shape[1] + dim1) * shape[2] + dim2) * shape[3] + dim3;
}
size_t CPUKernelUtils::GetElementNumOnAxis(const std::vector<size_t> &shape, int axis) {
  // Number of elements spanned by the dimensions strictly after `axis` in a
  // rank-4 shape; a negative axis counts from the end, numpy-style.
  if (axis < 0) {
    axis += SizeToInt(shape.size());
  }
  size_t result = 1;
  for (int j = axis + 1; j <= 3; ++j) {
    result *= shape[j];
  }
  return result;
}
void CPUKernelUtils::GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num) {
  // Appends, for each dimension i, the row-major stride of that dimension:
  // element_num[i] = prod(shape[i+1:]). Built back-to-front, then reversed
  // into dimension order.
  // Fix: iterate with an offset index so an empty shape no longer underflows
  // `shape.size() - 1` (size_t wrap) and reads out of bounds; an empty shape
  // now just appends the single trailing stride of 1.
  size_t accumulation = 1;
  element_num->emplace_back(1);
  for (size_t i = shape.size(); i > 1; --i) {
    accumulation *= shape[i - 1];
    element_num->emplace_back(accumulation);
  }
  std::reverse(element_num->begin(), element_num->end());
}
void CPUKernelUtils::ParallelFor(const CTask &task, size_t count) { void CPUKernelUtils::ParallelFor(const CTask &task, size_t count) {
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0; const float block_size = 128.0;
size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num; size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num;
std::vector<common::Task> tasks; std::vector<common::Task> tasks;
size_t start = 0; size_t start = 0;
size_t once_compute_size = (count + thread_num - 1) / thread_num; size_t once_compute_size = (count + thread_num - 1) / thread_num;
while (start < count) { while (start < count) {
size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size); size_t end = (start + once_compute_size) > count ? count : (start + once_compute_size);
auto block = [&, start, end]() { auto block = [&, start, end]() {
task(start, end); task(start, end);
return common::SUCCESS; return common::SUCCESS;
}; };
tasks.emplace_back(block); tasks.emplace_back(block);
start += once_compute_size; start += once_compute_size;
} }
common::ThreadPool::GetInstance().SyncRun(tasks); common::ThreadPool::GetInstance().SyncRun(tasks);
} }
std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) { std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) {
if (axis < 0) { if (axis < 0) {
axis = axis + SizeToInt(shape.size()); axis = axis + SizeToInt(shape.size());
} }
size_t dim_row = 1; size_t dim_row = 1;
size_t dim_col = 1; size_t dim_col = 1;
std::vector<size_t> flat_shape; std::vector<size_t> flat_shape;
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
if (SizeToInt(i) < axis) { if (SizeToInt(i) < axis) {
dim_row *= shape[i]; dim_row *= shape[i];
} else { } else {
dim_col *= shape[i]; dim_col *= shape[i];
} }
} }
flat_shape.push_back(dim_row); flat_shape.push_back(dim_row);
flat_shape.push_back(dim_col); flat_shape.push_back(dim_col);
return flat_shape; return flat_shape;
} }
// Build an iterator that walks the broadcast output element-by-element while
// tracking the matching flat offsets into both (possibly lower-rank) inputs.
// Steps: pad the input shapes to the output rank, then precompute per-axis
// strides (with stride 0 on broadcast axes) and back-strides.
BroadcastIterator::BroadcastIterator(std::vector<size_t> input_shape_a, std::vector<size_t> input_shape_b,
                                     std::vector<size_t> output_shape)
    : input_shape_a_(std::move(input_shape_a)),
      input_shape_b_(std::move(input_shape_b)),
      output_shape_(std::move(output_shape)) {
  output_dimension_ = SizeToInt(output_shape_.size());  // Assign dimension to int for iterator
  BroadcastShape();
  // Allocate strides memory
  input_strides_a_.resize(output_dimension_);
  input_strides_b_.resize(output_dimension_);
  input_back_strides_a_.resize(output_dimension_);
  input_back_strides_b_.resize(output_dimension_);
  coordinates_.resize(output_dimension_);
  InitStrides();
}
// Position the iterator at flat output index `pos`: decompose pos into
// per-axis coordinates (innermost axis first) and accumulate the matching
// input offsets through the broadcast-aware strides.
// NOTE(review): input_pos_ is accumulated into, not reset; this assumes it is
// still zero (as after construction) when SetPos is called — confirm callers
// never call SetPos twice on the same iterator.
void BroadcastIterator::SetPos(size_t pos) {
  for (int i = output_dimension_ - 1; i >= 0 && pos != 0; --i) {
    coordinates_[i] = pos % output_shape_[i];
    input_pos_[0] += coordinates_[i] * input_strides_a_[i];
    input_pos_[1] += coordinates_[i] * input_strides_b_[i];
    pos /= output_shape_[i];
  }
}
void BroadcastIterator::GenNextPos() { void BroadcastIterator::GenNextPos() {
// Calculate output next coordinate // Calculate output next coordinate
for (int i = output_dimension_ - 1; i >= 0; --i) { for (int i = output_dimension_ - 1; i >= 0; --i) {
if (coordinates_[i] + 1 == output_shape_[i]) { if (coordinates_[i] + 1 == output_shape_[i]) {
coordinates_[i] = 0; coordinates_[i] = 0;
input_pos_[0] -= input_back_strides_a_[i]; input_pos_[0] -= input_back_strides_a_[i];
input_pos_[1] -= input_back_strides_b_[i]; input_pos_[1] -= input_back_strides_b_[i];
} else { } else {
++coordinates_[i]; ++coordinates_[i];
input_pos_[0] += input_strides_a_[i]; input_pos_[0] += input_strides_a_[i];
input_pos_[1] += input_strides_b_[i]; input_pos_[1] += input_strides_b_[i];
break; break;
} }
} }
} }
void BroadcastIterator::BroadcastShape() { void BroadcastIterator::BroadcastShape() {
int input_dimension_a = input_shape_a_.size(); int input_dimension_a = input_shape_a_.size();
if (input_dimension_a < output_dimension_) { if (input_dimension_a < output_dimension_) {
input_shape_a_.insert(input_shape_a_.begin(), output_dimension_ - input_dimension_a, 1); input_shape_a_.insert(input_shape_a_.begin(), output_dimension_ - input_dimension_a, 1);
} }
int input_dimension_b = input_shape_b_.size(); int input_dimension_b = input_shape_b_.size();
if (input_dimension_b < output_dimension_) { if (input_dimension_b < output_dimension_) {
input_shape_b_.insert(input_shape_b_.begin(), output_dimension_ - input_dimension_b, 1); input_shape_b_.insert(input_shape_b_.begin(), output_dimension_ - input_dimension_b, 1);
} }
} }
// Precompute, for every output axis:
//  - input_strides_*:      row-major stride of that axis in each input,
//  - input_back_strides_*: distance to undo a full wrap of that axis
//                          ((dim - 1) * stride), used when the odometer carries.
// Finally zero the stride of any axis whose input extent is 1, so broadcast
// axes do not move the input offset.
void BroadcastIterator::InitStrides() {
  input_strides_a_[output_dimension_ - 1] = 1;
  input_strides_b_[output_dimension_ - 1] = 1;
  for (int i = output_dimension_ - 2; i >= 0; --i) {
    input_strides_a_[i] = input_shape_a_[i + 1] * input_strides_a_[i + 1];
    input_strides_b_[i] = input_shape_b_[i + 1] * input_strides_b_[i + 1];
    input_back_strides_a_[i + 1] = (input_shape_a_[i + 1] - 1) * input_strides_a_[i + 1];
    input_back_strides_b_[i + 1] = (input_shape_b_[i + 1] - 1) * input_strides_b_[i + 1];
  }
  // Update strides for broadcast
  // While the axis value is 1, the stride is 0
  std::transform(input_strides_a_.begin(), input_strides_a_.end(), input_shape_a_.begin(), input_strides_a_.begin(),
                 [](const auto &a, const auto &b) { return b == 1 ? 0 : a; });
  std::transform(input_strides_b_.begin(), input_strides_b_.end(), input_shape_b_.begin(), input_strides_b_.begin(),
                 [](const auto &a, const auto &b) { return b == 1 ? 0 : a; });
}
// Build an iterator that walks the *output* of a transpose in row-major order
// while producing the matching flat offset into the *input* buffer.
// The input's row-major strides are permuted by `axes` so that advancing an
// output coordinate moves the input position correctly.
TransposeIterator::TransposeIterator(std::vector<size_t> output_shape, std::vector<size_t> axes,
                                     const std::vector<size_t> &input_shape)
    : shape_(std::move(output_shape)), axes_(std::move(axes)) {
  // Calculate strides
  dimension_ = shape_.size();
  std::vector<uint32_t> strides(dimension_, 1);
  for (int i = dimension_ - 2; i >= 0; --i) {
    strides[i] = input_shape[i + 1] * strides[i + 1];
  }
  // Swap shape and strides and calculate back strides
  strides_.resize(dimension_);
  back_strides_.resize(dimension_);
  for (int i = dimension_ - 1; i >= 0; --i) {
    strides_[i] = strides[axes_[i]];  // permute input strides into output order
    back_strides_[i] = (shape_[i] - 1) * strides_[i];
  }
  // Calculate coordinate by pos
  coordinates_.resize(dimension_);
}
// Position the iterator at flat output index `pos`, accumulating the
// corresponding input offset into pos_ via the permuted strides.
// NOTE(review): pos_ is accumulated into, not reset; assumes it is still zero
// (as after construction) when SetPos is called — confirm single use.
void TransposeIterator::SetPos(size_t pos) {
  for (int i = dimension_ - 1; i >= 0 && pos != 0; --i) {
    coordinates_[i] = pos % shape_[i];
    pos_ += coordinates_[i] * strides_[i];
    pos /= shape_[i];
  }
}
void TransposeIterator::GenNextPos() { void TransposeIterator::GenNextPos() {
for (int i = dimension_ - 1; i >= 0; --i) { for (int i = dimension_ - 1; i >= 0; --i) {
if (coordinates_[i] + 1 == shape_[i]) { if (coordinates_[i] + 1 == shape_[i]) {
coordinates_[i] = 0; coordinates_[i] = 0;
pos_ -= back_strides_[i]; pos_ -= back_strides_[i];
} else { } else {
coordinates_[i]++; coordinates_[i]++;
pos_ += strides_[i]; pos_ += strides_[i];
break; break;
} }
} }
} }
// Compute the numpy-style broadcast shape of x and y: trailing axes must be
// equal or 1 (the non-1 extent wins); leading axes of the longer shape are
// copied through.
// NOTE: incompatible trailing axes (a != b, neither 1) are skipped; the
// original code then read broadcast_shape_back out of bounds — now the
// result is simply shorter. Callers should validate compatibility upstream.
std::vector<size_t> CPUKernelUtils::GetBroadcastShape(const std::vector<size_t> &x, const std::vector<size_t> &y) {
  size_t x_len = x.size();
  size_t y_len = y.size();
  size_t length = x_len < y_len ? x_len : y_len;
  std::vector<size_t> broadcast_shape;
  std::vector<size_t> broadcast_shape_back;
  // BUGFIX: `int i = -length` converted a huge size_t to int
  // (implementation-defined); cast to a signed value explicitly.
  for (int i = -static_cast<int>(length); i < 0; ++i) {
    if (x[x_len + i] == 1) {
      broadcast_shape_back.push_back(y[y_len + i]);
    } else if (y[y_len + i] == 1) {
      broadcast_shape_back.push_back(x[x_len + i]);
    } else if (x[x_len + i] == y[y_len + i]) {
      broadcast_shape_back.push_back(x[x_len + i]);
    }
  }
  // Copy the leading (non-overlapping) dimensions of the longer shape.
  if (length == x_len) {
    for (size_t i = 0; i < y_len - length; ++i) {
      broadcast_shape.push_back(y[i]);
    }
  } else {
    for (size_t i = 0; i < x_len - length; ++i) {
      broadcast_shape.push_back(x[i]);
    }
  }
  // BUGFIX: iterate over what was actually collected instead of `length`,
  // which read out of bounds when an incompatible axis was skipped above.
  for (size_t i = 0; i < broadcast_shape_back.size(); ++i) {
    broadcast_shape.push_back(broadcast_shape_back[i]);
  }
  return broadcast_shape;
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,205 +1,205 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_H_
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h" #include "backend/kernel_compiler/common_utils.h"
#include "ir/anf.h" #include "ir/anf.h"
using mindspore::kernel::Address; using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr; using mindspore::kernel::AddressPtr;
using CTask = std::function<void(size_t, size_t)>; using CTask = std::function<void(size_t, size_t)>;
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// String keys used to read CNode attributes in the CPU kernels.
// Convolution / pooling geometry.
const char KERNEL_SIZE[] = "kernel_size";
const char STRIDE[] = "stride";
const char STRIDES[] = "strides";
const char DILATION[] = "dilation";
const char DILATIONS[] = "dilations";
const char FORMAT[] = "format";
// Padding configuration and its accepted mode spellings.
const char PAD[] = "pad";
const char PAD_LIST[] = "pad_list";
const char PAD_MODE[] = "pad_mode";
const char PAD_MODE_LOWER_SAME[] = "same";
const char PAD_MODE_LOWER_VALID[] = "valid";
const char PAD_MODE_UPPER_SAME[] = "SAME";
const char PAD_MODE_UPPER_VALID[] = "VALID";
// MatMul transpose flags ('N' = no transpose, 'T' = transpose).
const char TRANSPOSE_A[] = "transpose_a";
const char TRANSPOSE_B[] = "transpose_b";
const char IS_GRAD[] = "is_grad";
const char TRANSPOSE_NO = 'N';
const char TRANSPOSE_YES = 'T';
// Axis / slicing attributes.
const char AXIS[] = "axis";
const char DIM[] = "dim";
const char BEGIN[] = "begin";
const char END[] = "end";
const char SIZE[] = "size";
// Optimizer / range / misc attributes.
const char USE_NESTEROV[] = "use_nesterov";
const char GROUP[] = "group";
const char START[] = "start";
const char LIMIT[] = "limit";
const char DELTA[] = "delta";
const char SORTED[] = "sorted";
const char ADJ_ST[] = "adjoint_st";
const char ADJ_dT[] = "adjoint_dt";
// Operation selector shared by the generic element-wise arithmetic CPU
// kernels. Covers binary arithmetic, comparisons, logical ops, activation
// gradients and unary math functions. Values are contiguous from ADD = 0.
enum OperateType {
  ADD = 0,
  SUB,
  MUL,
  DIV,
  SQUARE,
  SQRT,
  POW,
  REALDIV,
  FLOORDIV,
  MOD,
  FLOORMOD,
  NEG,
  LESS,
  ASSIGNADD,
  RELUGRAD,
  RELU6GRAD,
  ABSGRAD,
  TANHGRAD,
  SQRTGRAD,
  SIGMOIDGRAD,
  ONESLIKE,
  ZEROSLIKE,
  SIGN,
  EQUAL,
  NOTEQUAL,
  LESSEQUAL,
  LOGICALAND,
  LOGICALOR,
  LOGICALNOT,
  FLOOR,
  SQUAREDDIFFERENCE,
  GREATER,
  GREATEREQUAL,
  RECIPROCAL,
  GELU,
  GELUGRAD,
  ASIN,
  ACOS,
  ATAN,
  ASINGRAD,
  ACOSGRAD,
  ATANGRAD,
  SIN,
  COS,
  TAN,
  SINH,
  COSH,
  ASINH,
  ACOSH,
  ATANH,
  ASINHGRAD,
  ACOSHGRAD,
  ATAN2,
  RINT,
  ROUND,
  IDENTITY,
};
// Base class of all CPU kernels. Subclasses implement InitKernel() (parse the
// CNode once) and the three-argument Launch() (do the computation); the
// stream-aware KernelMod::Launch overload simply forwards to it, since CPU
// kernels have no stream.
class CPUKernel : public kernel::KernelMod {
 public:
  CPUKernel() = default;
  ~CPUKernel() override = default;
  // One-time setup entry point (implementation not shown in this header).
  virtual void Init(const CNodePtr &kernel_node);
  // Parse shapes/attributes from the node; must be provided by subclasses.
  virtual void InitKernel(const CNodePtr &kernel_node) = 0;
  // KernelMod interface: drop the unused stream pointer and delegate.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void * /*stream_ptr*/) override {
    return Launch(inputs, workspace, outputs);
  };
  virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                      const std::vector<AddressPtr> &outputs) = 0;
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
 protected:
  // Populate the three size lists below (overridable for custom layouts).
  virtual void InitInputOutputSize(const CNodePtr &kernel_node);
  std::vector<size_t> input_size_list_;      // byte sizes of each input
  std::vector<size_t> output_size_list_;     // byte sizes of each output
  std::vector<size_t> workspace_size_list_;  // byte sizes of each workspace buffer
};
// Stateless shape/offset/parallelism helpers shared by the CPU kernels.
class CPUKernelUtils {
 public:
  // Left-pad a shape with 1s so it has at least four dimensions.
  static void ExpandDimsTo4(std::vector<size_t> *shape);
  // Row-major flat offset of (dim0, dim1, dim2, dim3) in a 4-D shape.
  static size_t CalcOffset(const std::vector<size_t> &shape, size_t dim0, size_t dim1, size_t dim2, size_t dim3);
  // Product of the dimensions after `axis` (negative axis counts from the back).
  static size_t GetElementNumOnAxis(const std::vector<size_t> &shape, int axis);
  // Row-major stride of every dimension; last entry is 1.
  static void GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num);
  // Run task(start, end) over [0, count) in chunks on the shared thread pool.
  static void ParallelFor(const CTask &task, size_t count);
  // Collapse a shape to 2-D {prod(before axis), prod(from axis)}.
  static std::vector<size_t> FlatShapeByAxis(const std::vector<size_t> &shape, int axis);
  // Numpy-style broadcast shape of x and y.
  static std::vector<size_t> GetBroadcastShape(const std::vector<size_t> &x, const std::vector<size_t> &y);
};
// Walks a broadcast output element-by-element while exposing the matching
// flat offsets into both inputs. Usage: SetPos(start) once, read
// GetInputPosA()/GetInputPosB(), then GenNextPos() per output element.
class BroadcastIterator {
 public:
  BroadcastIterator(std::vector<size_t> input_shape_a, std::vector<size_t> input_shape_b,
                    std::vector<size_t> output_shape);
  virtual ~BroadcastIterator() = default;
  inline size_t GetInputPosA() const { return input_pos_[0]; }
  inline size_t GetInputPosB() const { return input_pos_[1]; }
  // Jump to flat output index `pos` (intended for one initial call).
  void SetPos(size_t pos);
  // Advance to the next output element.
  void GenNextPos();
 private:
  void BroadcastShape();
  void InitStrides();
  std::vector<size_t> coordinates_;          // current multi-axis output position
  std::vector<size_t> input_shape_a_;        // padded to output rank
  std::vector<size_t> input_shape_b_;        // padded to output rank
  std::vector<size_t> output_shape_;
  std::vector<size_t> input_strides_a_;      // 0 on broadcast axes
  std::vector<size_t> input_strides_b_;      // 0 on broadcast axes
  std::vector<size_t> input_back_strides_a_; // undo of a full axis wrap
  std::vector<size_t> input_back_strides_b_;
  std::array<size_t, 2> input_pos_{0};       // {offset into a, offset into b}
  int output_dimension_{0};
};
// Walks the output of a transpose in row-major order while exposing the
// matching flat offset into the input buffer (GetPos()). Usage mirrors
// BroadcastIterator: SetPos(start) once, then GenNextPos() per element.
class TransposeIterator {
 public:
  TransposeIterator(std::vector<size_t> output_shape, std::vector<size_t> axes, const std::vector<size_t> &input_shape);
  virtual ~TransposeIterator() = default;
  inline size_t GetPos() const { return pos_; }
  void SetPos(size_t pos);
  void GenNextPos();
 private:
  int dimension_{0};
  std::vector<size_t> coordinates_;   // current multi-axis output position
  std::vector<size_t> shape_;         // output shape
  std::vector<size_t> strides_;       // input strides permuted by axes_
  std::vector<size_t> back_strides_;  // undo of a full axis wrap
  std::vector<size_t> axes_;          // permutation applied by the transpose
  size_t pos_{0};                     // flat offset into the input
};
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CPU_KERNEL_H_

View File

@ -1,340 +1,340 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/ctcloss_cpu_kernel.h" #include "backend/kernel_compiler/cpu/ctcloss_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void CTCLossCPUKernel::InitKernel(const CNodePtr &kernel_node) { void CTCLossCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node); CheckParam(kernel_node);
probs_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); probs_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
indice_dims_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); indice_dims_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
labels_dims_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); labels_dims_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (probs_shape_.size() != 3) { if (probs_shape_.size() != 3) {
MS_LOG(EXCEPTION) << "Probs dims: " << probs_shape_.size() << " not support."; MS_LOG(EXCEPTION) << "Probs dims: " << probs_shape_.size() << " not support.";
} }
if (labels_dims_.size() != 1) { if (labels_dims_.size() != 1) {
MS_LOG(EXCEPTION) << "Labels dims: " << labels_dims_.size() << " not support."; MS_LOG(EXCEPTION) << "Labels dims: " << labels_dims_.size() << " not support.";
} }
if (indice_dims_.size() != 2) { if (indice_dims_.size() != 2) {
MS_LOG(EXCEPTION) << "Labels indice dims: " << indice_dims_.size() << " not support."; MS_LOG(EXCEPTION) << "Labels indice dims: " << indice_dims_.size() << " not support.";
} }
preprocess_collapse_repeated_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "preprocess_collapse_repeated"); preprocess_collapse_repeated_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "preprocess_collapse_repeated");
ctc_merge_repeated_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "ctc_merge_repeated"); ctc_merge_repeated_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "ctc_merge_repeated");
ignore_longer_outputs_than_inputs_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "ignore_longer_outputs_than_inputs"); ignore_longer_outputs_than_inputs_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "ignore_longer_outputs_than_inputs");
max_time_ = probs_shape_[0]; max_time_ = probs_shape_[0];
batch_size_ = probs_shape_[1]; batch_size_ = probs_shape_[1];
num_class_ = probs_shape_[2]; num_class_ = probs_shape_[2];
blank_index_ = num_class_ - 1; blank_index_ = num_class_ - 1;
} }
bool CTCLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, bool CTCLossCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) { if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs); LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) { } else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs); LaunchKernel<float>(inputs, outputs);
} }
return true; return true;
} }
// Numerically stable log(exp(logprob1) + exp(logprob2)).
// If either argument is -inf (log of zero probability), the other is returned
// unchanged; otherwise the computation is anchored on the larger value so the
// exp() argument is never positive.
template <typename T>
inline T LogSumExp(const T logprob1, const T logprob2) {
  const T neg_inf = -std::numeric_limits<T>::infinity();
  if (logprob1 <= neg_inf) {
    return logprob2;
  }
  if (logprob2 <= neg_inf) {
    return logprob1;
  }
  const bool first_is_larger = logprob1 > logprob2;
  const T hi = first_is_larger ? logprob1 : logprob2;
  const T lo = first_is_larger ? logprob2 : logprob1;
  return hi + static_cast<T>(log1p(exp(lo - hi)));
}
// Compute the CTC forward variables: log_alpha_b[u][t] accumulates the
// log-probability of having consumed the first u+1 symbols of the
// blank-augmented label sequence within the first t+1 time frames.
// @param label_with_blank labels interleaved with blank_index_ entries
// @param y                per-class probabilities, indexed y[class][time]
// @param log_alpha_b      output table of shape [U][T]; initialization of
//                         unreachable cells is the caller's responsibility
template <typename TT>
void CTCLossCPUKernel::CalculateFwdVar(const std::vector<uint32_t> &label_with_blank,
                                       const std::vector<std::vector<TT>> &y,
                                       std::vector<std::vector<TT>> *log_alpha_b) {
  int U = label_with_blank.size();
  int T = (*log_alpha_b)[0].size();
  TT kLogZero_ = -std::numeric_limits<TT>::infinity();
  // t == 0: only the leading blank (and the first real label, if present)
  // are reachable.
  (*log_alpha_b)[0][0] = static_cast<TT>(log(y[blank_index_][0]));
  auto label_0 = (label_with_blank.size() > 1) ? label_with_blank[1] : blank_index_;
  if (label_with_blank.size() > 1) {
    (*log_alpha_b)[1][0] = static_cast<TT>(log(y[label_0][0]));
  }
  for (int t = 1; t < T; ++t) {
    // Restrict u to states that are both reachable from the start by time t
    // and able to reach the end state within the remaining T - t frames.
    int low = std::max(0, U - (2 * (T - t)));
    int high = std::min(U, 2 * (t + 1));
    for (int u = low; u < high; ++u) {
      auto sum_log_alpha_b = kLogZero_;
      // Self-transition: allowed for blanks, and for any label when
      // ctc_merge_repeated_ is set.
      if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) {
        sum_log_alpha_b = (*log_alpha_b)[u][t - 1];
      }
      // Transition from the immediately preceding state.
      if (u > 0) {
        sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 1][t - 1]);
      }
      // Skip-transition over a blank — forbidden into a blank, and forbidden
      // between identical labels when repeats are merged.
      if (u > 1) {
        bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u - 2]);
        if (label_with_blank[u] != blank_index_ && !matching_labels_merge) {
          sum_log_alpha_b = LogSumExp(sum_log_alpha_b, (*log_alpha_b)[u - 2][t - 1]);
        }
      }
      // Fold in the emission probability of this state at time t.
      (*log_alpha_b)[u][t] =
        static_cast<TT>(log(static_cast<TT>(y[label_with_blank[IntToSize(u)]][IntToSize(t)]))) + sum_log_alpha_b;
    }
  }
}
// Computes the CTC backward (beta) variables in log space.
//
// label_with_blank: blank-augmented label sequence (blank, l1, blank, ...).
// y:                per-class softmax probabilities, indexed [class][time].
// log_beta_b:       [label_with_blank.size() x T] matrix, pre-filled with
//                   log(0); entry [u][t] accumulates the log-probability of
//                   emitting the label suffix starting at position u using
//                   frames t..T-1.
template <typename TT>
void CTCLossCPUKernel::CalculateBwdVar(const std::vector<uint32_t> &label_with_blank,
                                       const std::vector<std::vector<TT>> &y,
                                       std::vector<std::vector<TT>> *log_beta_b) {
  int T = (*log_beta_b)[0].size();   // number of time frames
  int U = label_with_blank.size();   // augmented label length (2 * label_len + 1)
  // Initialization at the final frame: a valid path may end on the last label
  // or on the trailing blank, so both positions get log-probability 0.
  if (U > 1) {
    for (int u = U - 2; u < U; ++u) {
      (*log_beta_b)[u][T - 1] = TT(0);
    }
  } else {
    // NOTE(review): when U == 1 this also writes index [0][T - 2]; if T == 1
    // that index underflows -- presumably callers guarantee T >= 2 for this
    // branch. Confirm against the shape checks in LaunchKernel.
    (*log_beta_b)[0][T - 1] = TT(0);
    (*log_beta_b)[0][T - 2] = TT(0);
  }
  // Backward recursion over time; [low, high) restricts work to the band of
  // label positions actually reachable at frame t.
  for (int t = T - 2; t >= 0; --t) {
    int low = std::max(0, U - (2 * (T - t)));
    int high = std::min(U, 2 * (t + 1));
    for (int u = low; u < high; ++u) {
      // Transition: stay on the same symbol (allowed for blanks, or for
      // non-blanks when repeated emissions are merged).
      if (ctc_merge_repeated_ || label_with_blank[u] == blank_index_) {
        (*log_beta_b)[u][t] =
          LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u][t + 1] + TT(log(y[label_with_blank[u]][t + 1])));
      }
      // Transition: advance to the next position in the augmented label.
      if (u + 1 < U) {
        (*log_beta_b)[u][t] =
          LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 1][t + 1] + TT(log(y[label_with_blank[u + 1]][t + 1])));
      }
      // Transition: skip over the blank to the next non-blank symbol, which is
      // only legal when the two non-blanks are distinct (or merging is off).
      if (u + 2 < U) {
        bool matching_labels_merge = ctc_merge_repeated_ && (label_with_blank[u] == label_with_blank[u + 2]);
        if (label_with_blank[u] != blank_index_ && !matching_labels_merge) {
          (*log_beta_b)[u][t] =
            LogSumExp((*log_beta_b)[u][t], (*log_beta_b)[u + 2][t + 1] + TT(log(y[label_with_blank[u + 2]][t + 1])));
        }
      }
    }
  }
}
template <typename TT> template <typename TT>
void CTCLossCPUKernel::CalculateGrad(const std::vector<uint32_t> &label_with_blank, void CTCLossCPUKernel::CalculateGrad(const std::vector<uint32_t> &label_with_blank,
const std::vector<std::vector<TT>> &y, const std::vector<std::vector<TT>> &y,
const std::vector<std::vector<TT>> &log_alpha_b, const std::vector<std::vector<TT>> &log_alpha_b,
const std::vector<std::vector<TT>> &log_beta_b, const TT log_pzx, const std::vector<std::vector<TT>> &log_beta_b, const TT log_pzx,
std::vector<std::vector<TT>> *dy) { std::vector<std::vector<TT>> *dy) {
auto dy_b = dy; auto dy_b = dy;
TT kLogZero_ = -std::numeric_limits<TT>::infinity(); TT kLogZero_ = -std::numeric_limits<TT>::infinity();
if (log_pzx <= kLogZero_) { if (log_pzx <= kLogZero_) {
MS_LOG(INFO) << "No valid path found"; MS_LOG(INFO) << "No valid path found";
return; return;
} }
size_t L = y.size(); size_t L = y.size();
size_t T = y[0].size(); size_t T = y[0].size();
size_t U = label_with_blank.size(); size_t U = label_with_blank.size();
for (size_t t = 0; t < T; ++t) { for (size_t t = 0; t < T; ++t) {
std::vector<TT> prob_sum(L, kLogZero_); std::vector<TT> prob_sum(L, kLogZero_);
for (size_t u = 0; u < U; ++u) { for (size_t u = 0; u < U; ++u) {
uint32_t l = label_with_blank[u]; uint32_t l = label_with_blank[u];
prob_sum[l] = LogSumExp(prob_sum[l], log_alpha_b[u][t] + log_beta_b[u][t]); prob_sum[l] = LogSumExp(prob_sum[l], log_alpha_b[u][t] + log_beta_b[u][t]);
} }
for (size_t l = 0; l < L; ++l) { for (size_t l = 0; l < L; ++l) {
(*dy_b)[l][t] = y[l][t] - static_cast<TT>(exp(prob_sum[l] - log_pzx)); (*dy_b)[l][t] = y[l][t] - static_cast<TT>(exp(prob_sum[l] - log_pzx));
} }
} }
} }
void CTCLossCPUKernel::GenLableWithBlank(const uint32_t *seq_len, const std::vector<std::vector<uint32_t>> &batch_label, void CTCLossCPUKernel::GenLableWithBlank(const uint32_t *seq_len, const std::vector<std::vector<uint32_t>> &batch_label,
std::vector<std::vector<uint32_t>> *label_with_blank) { std::vector<std::vector<uint32_t>> *label_with_blank) {
for (size_t b = 0; b < batch_size_; ++b) { for (size_t b = 0; b < batch_size_; ++b) {
std::vector<uint32_t> l; std::vector<uint32_t> l;
const std::vector<uint32_t> &label = batch_label[b]; const std::vector<uint32_t> &label = batch_label[b];
bool has_blank = false; bool has_blank = false;
for (size_t i = 0; i < label.size(); ++i) { for (size_t i = 0; i < label.size(); ++i) {
if (i == 0 || !preprocess_collapse_repeated_ || label[i] != label[i - 1]) { if (i == 0 || !preprocess_collapse_repeated_ || label[i] != label[i - 1]) {
if (label[i] >= num_class_ - 1) { if (label[i] >= num_class_ - 1) {
has_blank = true; has_blank = true;
} else { } else {
if (has_blank) { if (has_blank) {
MS_LOG(EXCEPTION) << "Invalid labels(index >= num_class - 1) should not appear between two valid labels"; MS_LOG(EXCEPTION) << "Invalid labels(index >= num_class - 1) should not appear between two valid labels";
} }
l.push_back(label[i]); l.push_back(label[i]);
} }
} }
} }
if (!ignore_longer_outputs_than_inputs_) { if (!ignore_longer_outputs_than_inputs_) {
if (l.size() > seq_len[b]) { if (l.size() > seq_len[b]) {
MS_LOG(EXCEPTION) << "Input time(sequence length) should greater than output size(label length), but gets " MS_LOG(EXCEPTION) << "Input time(sequence length) should greater than output size(label length), but gets "
<< seq_len[b] << "< " << l.size(); << seq_len[b] << "< " << l.size();
} }
} }
(*label_with_blank)[b].reserve(2 * l.size() + 1); (*label_with_blank)[b].reserve(2 * l.size() + 1);
for (auto l_i : l) { for (auto l_i : l) {
(*label_with_blank)[b].push_back(blank_index_); (*label_with_blank)[b].push_back(blank_index_);
(*label_with_blank)[b].push_back(l_i); (*label_with_blank)[b].push_back(l_i);
} }
(*label_with_blank)[b].push_back(blank_index_); (*label_with_blank)[b].push_back(blank_index_);
} }
} }
// Computes a numerically stable softmax over the class axis for batch entry b.
//
// inputs_addr:     logits laid out as [time, batch, class].
// softmax_probs:   output, pre-sized to [num_class][>= sequence_length];
//                  (*softmax_probs)[c][t] receives the probability of class c
//                  at frame t.
// sequence_length: number of frames to process for this batch entry.
//
// Fixes vs. the previous version: the max logit was initialized to 0, so a
// frame whose logits are all very negative made every exp() underflow and the
// normalization divide by zero (NaN). The max is now taken over the actual
// data. The exp() of each logit is also computed once instead of twice.
template <typename T>
void InnerSoftMax(const T *inputs_addr, std::vector<std::vector<T>> *softmax_probs, const uint32_t sequence_length,
                  size_t num_class, size_t batch_size, size_t b) {
  for (size_t t = 0; t < sequence_length; ++t) {
    const size_t base = t * batch_size * num_class + b * num_class;
    // Subtract the true maximum logit so exp() cannot overflow; seeding with
    // the first element keeps the shift correct even when all logits < 0.
    T maxCoeff = inputs_addr[base];
    for (size_t c = 1; c < num_class; ++c) {
      if (inputs_addr[base + c] > maxCoeff) {
        maxCoeff = inputs_addr[base + c];
      }
    }
    T sumCoeff(T(0));
    for (size_t c = 0; c < num_class; ++c) {
      const T e = static_cast<T>(exp(inputs_addr[base + c] - maxCoeff));
      sumCoeff += e;
      (*softmax_probs)[c][t] = e;  // store the unnormalized value once
    }
    for (size_t c = 0; c < num_class; ++c) {
      (*softmax_probs)[c][t] /= sumCoeff;
    }
  }
}
// Sizes *array2D to row x col; cells created by the resize are filled with
// init_value (pre-existing cells, if any, keep their values).
template <typename T>
void MatrixfromVector(uint32_t row, uint32_t col, std::vector<std::vector<T>> *array2D, const T init_value) {
  array2D->resize(row);
  for (auto &line : *array2D) {
    line.resize(col, init_value);
  }
}
// Computes the CTC loss and gradient for the whole batch.
//
// inputs:  [0] logits, indexed as [time, batch, class] (see the flat index
//              t * batch * num_class + b * num_class + c below),
//          [1] labels_indices (sparse COO index pairs, uint64),
//          [2] labels_values (uint32), [3] sequence_length (uint32 per row).
// outputs: [0] per-batch loss, [1] gradient with the same layout as inputs[0].
template <typename T>
void CTCLossCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
  auto inputs_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto labels_indices_addr = reinterpret_cast<uint64_t *>(inputs[1]->addr);
  auto labels_values_addr = reinterpret_cast<uint32_t *>(inputs[2]->addr);
  auto sequence_length_addr = reinterpret_cast<uint32_t *>(inputs[3]->addr);
  auto loss_addr = reinterpret_cast<T *>(outputs[0]->addr);
  auto gradient_addr = reinterpret_cast<T *>(outputs[1]->addr);
  std::vector<std::vector<uint32_t>> label_batch;       // dense labels per batch row
  std::vector<std::vector<uint32_t>> labels_with_blank; // blank-augmented labels
  std::vector<uint64_t> each_label_length;              // label count per batch row
  label_batch.resize(batch_size_);
  labels_with_blank.resize(batch_size_);
  each_label_length.resize(batch_size_, 0);
  T kLogZero_ = -std::numeric_limits<T>::infinity();
  // check validation of sequence length
  for (size_t b = 0; b < batch_size_; ++b) {
    if (sequence_length_addr[b] == uint32_t(0)) {
      MS_LOG(EXCEPTION) << "Sequence length should > 0, but gets " << sequence_length_addr[b];
    }
    if (sequence_length_addr[b] > max_time_) {
      MS_LOG(EXCEPTION) << "Max time should be greater than sequence length, but gets " << max_time_ << " < "
                        << sequence_length_addr[b];
    }
  }
  // Count labels per batch row: indices come as (batch, position) pairs, so
  // the batch id of sparse entry i is at labels_indices_addr[i * 2].
  for (size_t i = 0; i < indice_dims_[0]; ++i) {
    each_label_length[labels_indices_addr[i * 2]]++;
  }
  // convert label format of label_value and label_indices to batch_label
  // NOTE(review): this assumes the sparse entries are ordered by batch row --
  // presumably guaranteed by the op contract; confirm.
  uint64_t cum_sum = 0;
  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> *b_value = &label_batch[b];
    for (size_t l = 0; l < each_label_length[b]; ++l) {
      b_value->push_back(labels_values_addr[cum_sum + l]);
    }
    cum_sum += each_label_length[b];
  }
  // convert label to label with blank
  GenLableWithBlank(sequence_length_addr, label_batch, &labels_with_blank);
  for (size_t b = 0; b < batch_size_; ++b) {
    std::vector<uint32_t> label_with_blank = labels_with_blank[b];
    // y_b [num_class, sequence_length]
    std::vector<std::vector<T>> y_b;          // softmax probabilities
    std::vector<std::vector<T>> dy;           // gradient buffer, same shape as y_b
    std::vector<std::vector<T>> log_alpha_b;  // forward CTC variables (log space)
    std::vector<std::vector<T>> log_beta_b;   // backward CTC variables (log space)
    MatrixfromVector(num_class_, sequence_length_addr[b], &y_b, kLogZero_);
    MatrixfromVector(y_b.size(), y_b[0].size(), &dy, T(0));
    MatrixfromVector(label_with_blank.size(), sequence_length_addr[b], &log_alpha_b, kLogZero_);
    MatrixfromVector(label_with_blank.size(), sequence_length_addr[b], &log_beta_b, kLogZero_);
    InnerSoftMax(inputs_addr, &y_b, sequence_length_addr[b], num_class_, batch_size_, b);
    CalculateFwdVar(label_with_blank, y_b, &log_alpha_b);
    CalculateBwdVar(label_with_blank, y_b, &log_beta_b);
    // Total log-probability of the label sequence: sum over all positions of
    // alpha * beta at the first frame.
    T log_pzx = kLogZero_;
    for (size_t u = 0; u < label_with_blank.size(); ++u) {
      log_pzx = LogSumExp(log_pzx, log_alpha_b[u][0] + log_beta_b[u][0]);
    }
    loss_addr[b] = -log_pzx;
    CalculateGrad(label_with_blank, y_b, log_alpha_b, log_beta_b, log_pzx, &dy);
    // Scatter the per-batch gradient back into the [time, batch, class] output.
    for (size_t t = 0; t < sequence_length_addr[b]; ++t) {
      for (size_t c = 0; c < num_class_; ++c) {
        gradient_addr[t * batch_size_ * num_class_ + b * num_class_ + c] = dy[c][t];
      }
    }
  }
}
void CTCLossCPUKernel::CheckParam(const CNodePtr &kernel_node) { void CTCLossCPUKernel::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) { if (input_num != 4) {
MS_LOG(EXCEPTION) << "CTCLossCPUKernel needs 4 inputs, but gets " << input_num; MS_LOG(EXCEPTION) << "CTCLossCPUKernel needs 4 inputs, but gets " << input_num;
} }
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 2) { if (output_num != 2) {
MS_LOG(EXCEPTION) << "CTCLossCPUKernel expects 2 outputs, but gets" << output_num; MS_LOG(EXCEPTION) << "CTCLossCPUKernel expects 2 outputs, but gets" << output_num;
} }
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,92 +1,92 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// CPU implementation of the CTCLoss operator: computes the Connectionist
// Temporal Classification loss and its gradient from softmax probabilities,
// sparse labels (indices + values) and per-batch sequence lengths.
class CTCLossCPUKernel : public CPUKernel {
 public:
  CTCLossCPUKernel() = default;
  ~CTCLossCPUKernel() override = default;
  // Reads shapes and attributes from the graph node.
  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
  // Expands each label sequence to the blank-augmented form (blank, l1, blank, ...).
  void GenLableWithBlank(const uint32_t *seq_len, const std::vector<std::vector<uint32_t>> &batch_label,
                         std::vector<std::vector<uint32_t>> *label_with_blank);
  // Forward (alpha) recursion of the CTC dynamic program, in log space.
  template <typename T>
  void CalculateFwdVar(const std::vector<uint32_t> &label_with_blank, const std::vector<std::vector<T>> &y,
                       std::vector<std::vector<T>> *log_alpha_b);
  // Backward (beta) recursion of the CTC dynamic program, in log space.
  template <typename T>
  void CalculateBwdVar(const std::vector<uint32_t> &label_with_blank, const std::vector<std::vector<T>> &y,
                       std::vector<std::vector<T>> *log_beta_b);
  // Gradient of the loss w.r.t. the softmax output, given alpha/beta and log p(z|x).
  template <typename T>
  void CalculateGrad(const std::vector<uint32_t> &label_with_blank, const std::vector<std::vector<T>> &y,
                     const std::vector<std::vector<T>> &log_alpha_b, const std::vector<std::vector<T>> &log_beta_b,
                     const T log_pzx, std::vector<std::vector<T>> *dy);
  // Typed implementation dispatched by Launch based on dtype_.
  template <typename T>
  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
 private:
  void CheckParam(const CNodePtr &kernel_node);
  std::vector<size_t> probs_shape_;   // shape of the probs input
  std::vector<size_t> indice_dims_;   // shape of the labels_indices input
  std::vector<size_t> labels_dims_;   // shape of the labels_values input
  size_t num_class_;
  size_t max_time_;
  size_t batch_size_;
  uint32_t blank_index_;  // class index used as the CTC blank symbol
  TypeId dtype_{kTypeUnknown};
  bool preprocess_collapse_repeated_;          // collapse repeated labels before training
  bool ctc_merge_repeated_;                    // merge repeated emissions in the recursions
  bool ignore_longer_outputs_than_inputs_;     // skip the label-length-vs-seq-length check
};
// Register CTCLoss for float16 and float32 probabilities. For both variants
// labels_indices is int64 while labels_values and sequence_length are int32;
// the two outputs (loss, gradient) match the probability dtype.
MS_REG_CPU_KERNEL(CTCLoss,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat16)
                    .AddInputAttr(kNumberTypeInt64)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddOutputAttr(kNumberTypeFloat16)
                    .AddOutputAttr(kNumberTypeFloat16),
                  CTCLossCPUKernel);
MS_REG_CPU_KERNEL(CTCLoss,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeInt64)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddInputAttr(kNumberTypeInt32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  CTCLossCPUKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CTCLOSS_CPU_KERNEL_H_

View File

@ -1,89 +1,89 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/depthtospace_cpu_kernel.h" #include "backend/kernel_compiler/cpu/depthtospace_cpu_kernel.h"
#include <vector> #include <vector>
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// Validates the node and caches the input/output device shapes plus the
// "block_size" attribute for use in Launch().
template <typename T>
void DepthToSpaceCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  CheckParam(kernel_node);
  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  block_size_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "block_size");
}
// Rearranges depth (channel) data into spatial blocks: for each flat output
// index, the corresponding input element is located and copied.
// NOTE(review): the index math hard-codes a 4-D layout (output_strides has 3
// slots, output_pos_array[0..3] is read as n/c/h/w) -- presumably shape
// inference guarantees 4-D NCHW here; confirm.
template <typename T>
bool DepthToSpaceCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> & /*workspace*/,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
  auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
  size_t size = IntToSize(inputs[0]->size / sizeof(T));  // total element count
  std::vector<size_t> input_shape = input_shape_;
  std::vector<size_t> output_shape = output_shape_;
  size_t block_size = block_size_;
  size_t input_dimension = input_shape.size();
  // Row-major strides for the first three output dims; the innermost dim has
  // stride 1 and needs no slot.
  size_t output_strides[3] = {1, 1, 1};
  for (size_t i = input_dimension - 1; i >= 1; --i) {
    for (size_t j = 0; j < i; ++j) {
      output_strides[j] *= output_shape[i];
    }
  }
  auto task = [&, input_addr, output_addr](size_t start, size_t end) {
    std::vector<size_t> output_pos_array(input_dimension, 0);
    for (size_t i = start; i < end; ++i) {
      // Decompose the flat output index i into per-dimension coordinates.
      size_t tmp_pos = i;
      for (size_t j = 0; j < input_dimension - 1; ++j) {
        output_pos_array[j] = tmp_pos / output_strides[j];
        tmp_pos %= output_strides[j];
      }
      output_pos_array.back() = tmp_pos;
      // Map the output coordinate back to its source input element: the input
      // channel encodes both the output channel and the position within the
      // block_size x block_size spatial block.
      size_t input_pos = output_pos_array[0];
      input_pos =
        (input_pos * input_shape[1]) +
        (output_pos_array[1] +
         (block_size * (output_pos_array[2] % block_size) + output_pos_array[3] % block_size) * output_shape[1]);
      input_pos = (input_pos * input_shape[2]) + (output_pos_array[2] / block_size);
      input_pos = (input_pos * input_shape[3]) + (output_pos_array[3] / block_size);
      output_addr[i] = input_addr[input_pos];
    }
  };
  CPUKernelUtils::ParallelFor(task, size);
  return true;
}
template <typename T> template <typename T>
void DepthToSpaceCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) { void DepthToSpaceCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) { if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DepthToSpaceCPUKerrnel needs 1 input."; MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DepthToSpaceCPUKerrnel needs 1 input.";
} }
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) { if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DepthToSpaceCPUKernel needs 1 output."; MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DepthToSpaceCPUKernel needs 1 output.";
} }
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,85 +1,85 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// CPU implementation of the DepthToSpace operator: rearranges channel (depth)
// data of the input tensor into block_size x block_size spatial blocks.
template <typename T>
class DepthToSpaceCPUKernel : public CPUKernel {
 public:
  DepthToSpaceCPUKernel() = default;
  ~DepthToSpaceCPUKernel() override = default;
  // Caches input/output device shapes and the "block_size" node attribute.
  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
 private:
  void CheckParam(const CNodePtr &kernel_node);
  std::vector<size_t> input_shape_;   // device shape of the single input
  std::vector<size_t> output_shape_;  // device shape of the single output
  size_t block_size_;                 // spatial block edge length (node attr)
};
MS_REG_CPU_KERNEL_T( MS_REG_CPU_KERNEL_T(
DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
DepthToSpaceCPUKernel, float); DepthToSpaceCPUKernel, float);
MS_REG_CPU_KERNEL_T( MS_REG_CPU_KERNEL_T(
DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
DepthToSpaceCPUKernel, float16); DepthToSpaceCPUKernel, float16);
MS_REG_CPU_KERNEL_T(DepthToSpace, MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
DepthToSpaceCPUKernel, int8_t); DepthToSpaceCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(DepthToSpace, MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
DepthToSpaceCPUKernel, int16_t); DepthToSpaceCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(DepthToSpace, MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
DepthToSpaceCPUKernel, int); DepthToSpaceCPUKernel, int);
MS_REG_CPU_KERNEL_T(DepthToSpace, MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
DepthToSpaceCPUKernel, int64_t); DepthToSpaceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(DepthToSpace, MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
DepthToSpaceCPUKernel, uint8_t); DepthToSpaceCPUKernel, uint8_t);
MS_REG_CPU_KERNEL_T(DepthToSpace, MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
DepthToSpaceCPUKernel, uint16_t); DepthToSpaceCPUKernel, uint16_t);
MS_REG_CPU_KERNEL_T(DepthToSpace, MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
DepthToSpaceCPUKernel, uint32_t); DepthToSpaceCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(DepthToSpace, MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
DepthToSpaceCPUKernel, uint64_t); DepthToSpaceCPUKernel, uint64_t);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_

View File

@ -1,102 +1,102 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
#include "backend/kernel_compiler/cpu/nnacl/fp32/add_fp32.h" #include "backend/kernel_compiler/cpu/nnacl/fp32/add_fp32.h"
#include "backend/kernel_compiler/cpu/nnacl/errorcode.h" #include "backend/kernel_compiler/cpu/nnacl/errorcode.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "common/thread_pool.h" #include "common/thread_pool.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void AddInt(const int *in_0, const int *in_1, int *out, int start, int end) { void AddInt(const int *in_0, const int *in_1, int *out, int start, int end) {
int ret = ElementAddInt(in_0 + start, in_1 + start, out + start, end - start); int ret = ElementAddInt(in_0 + start, in_1 + start, out + start, end - start);
if (ret != NNACL_OK) { if (ret != NNACL_OK) {
MS_LOG(EXCEPTION) << "Add failed."; MS_LOG(EXCEPTION) << "Add failed.";
} }
} }
void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) { void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
CheckParam(kernel_node); CheckParam(kernel_node);
input_num_ = AnfAlgo::GetInputTensorNum(kernel_node); input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape); dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape);
dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape); dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape);
dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape); dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape);
dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc); dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc);
auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine()); auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::binary>(prim_desc); primitive_ = std::make_shared<dnnl::binary>(prim_desc);
AddArgument(DNNL_ARG_SRC_0, src0_mem_desc); AddArgument(DNNL_ARG_SRC_0, src0_mem_desc);
AddArgument(DNNL_ARG_SRC_1, src1_mem_desc); AddArgument(DNNL_ARG_SRC_1, src1_mem_desc);
AddArgument(DNNL_ARG_DST, dst_mem_desc); AddArgument(DNNL_ARG_DST, dst_mem_desc);
} }
bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat32) { if (dtype_ == kNumberTypeFloat32) {
SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr); SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr); SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive(); ExecutePrimitive();
for (size_t index = 2; index < input_num_; ++index) { for (size_t index = 2; index < input_num_; ++index) {
SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr); SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr); SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr); SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive(); ExecutePrimitive();
} }
} else if (dtype_ == kNumberTypeInt32) { } else if (dtype_ == kNumberTypeInt32) {
size_t elements_num = outputs[0]->size / sizeof(int); size_t elements_num = outputs[0]->size / sizeof(int);
const auto input_0 = reinterpret_cast<int *>(inputs[0]->addr); const auto input_0 = reinterpret_cast<int *>(inputs[0]->addr);
const auto input_1 = reinterpret_cast<int *>(inputs[1]->addr); const auto input_1 = reinterpret_cast<int *>(inputs[1]->addr);
auto output = reinterpret_cast<int *>(outputs[0]->addr); auto output = reinterpret_cast<int *>(outputs[0]->addr);
auto task_0 = std::bind(AddInt, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2); auto task_0 = std::bind(AddInt, input_0, input_1, output, std::placeholders::_1, std::placeholders::_2);
CPUKernelUtils::ParallelFor(task_0, elements_num); CPUKernelUtils::ParallelFor(task_0, elements_num);
for (size_t index = 2; index < input_num_; ++index) { for (size_t index = 2; index < input_num_; ++index) {
const auto input = reinterpret_cast<int *>(inputs[index]->addr); const auto input = reinterpret_cast<int *>(inputs[index]->addr);
auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2); auto task = std::bind(AddInt, input, output, output, std::placeholders::_1, std::placeholders::_2);
CPUKernelUtils::ParallelFor(task, elements_num); CPUKernelUtils::ParallelFor(task, elements_num);
} }
} else { } else {
MS_LOG(EXCEPTION) << "AddN only support float32 and int32, but got " << TypeIdToType(dtype_)->ToString(); MS_LOG(EXCEPTION) << "AddN only support float32 and int32, but got " << TypeIdToType(dtype_)->ToString();
} }
return true; return true;
} }
void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) { void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) {
auto src0_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto src0_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); auto dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
if (src0_shape != dst_shape) { if (src0_shape != dst_shape) {
MS_LOG(EXCEPTION) << "AddN output shape must be equal to input shape."; MS_LOG(EXCEPTION) << "AddN output shape must be equal to input shape.";
} }
for (size_t index = 1; index < input_num_; ++index) { for (size_t index = 1; index < input_num_; ++index) {
auto src_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index); auto src_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
if (src0_shape != src_shape) { if (src0_shape != src_shape) {
MS_LOG(EXCEPTION) << "AddN input shapes must be equal."; MS_LOG(EXCEPTION) << "AddN input shapes must be equal.";
} }
} }
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) { if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AddNCPUKernel needs 1 output."; MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AddNCPUKernel needs 1 output.";
} }
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,51 +1,51 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
class AddNCPUKernel : public MKLCPUKernel { class AddNCPUKernel : public MKLCPUKernel {
public: public:
AddNCPUKernel() = default; AddNCPUKernel() = default;
~AddNCPUKernel() override = default; ~AddNCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override; void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
private: private:
void CheckParam(const CNodePtr &kernel_node); void CheckParam(const CNodePtr &kernel_node);
size_t input_num_{0}; size_t input_num_{0};
std::vector<size_t> output_shape_; std::vector<size_t> output_shape_;
TypeId dtype_{kNumberTypeFloat32}; TypeId dtype_{kNumberTypeFloat32};
}; };
MS_REG_CPU_KERNEL(AddN, MS_REG_CPU_KERNEL(AddN,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AddNCPUKernel); AddNCPUKernel);
MS_REG_CPU_KERNEL(AddN, MS_REG_CPU_KERNEL(AddN,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AddNCPUKernel); AddNCPUKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_

View File

@ -1,178 +1,178 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/lstm_cpu_kernel.h"
#include <string> #include <string>
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
const int kMaxLSTMLayer = 100; const int kMaxLSTMLayer = 100;
const int kOutputWorkSpaceIndex = 3; const int kOutputWorkSpaceIndex = 3;
void LstmCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { void LstmCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node); CPUKernel::InitInputOutputSize(kernel_node);
output_size_list_[kOutputWorkSpaceIndex] = reserve_size_; output_size_list_[kOutputWorkSpaceIndex] = reserve_size_;
auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node); auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
auto output_type = AnfAlgo::GetOutputInferDataType(kernel_node, 0); auto output_type = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
auto output_types = std::vector<TypeId>(output_num, output_type); auto output_types = std::vector<TypeId>(output_num, output_type);
std::vector<std::vector<size_t>> output_shapes; std::vector<std::vector<size_t>> output_shapes;
for (size_t output_index = 0; output_index < output_num; ++output_index) { for (size_t output_index = 0; output_index < output_num; ++output_index) {
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(kernel_node, output_index); std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(kernel_node, output_index);
output_shapes.emplace_back(shape); output_shapes.emplace_back(shape);
} }
size_t len = reserve_size_ / 4; size_t len = reserve_size_ / 4;
output_shapes[kOutputWorkSpaceIndex] = {len, 1}; output_shapes[kOutputWorkSpaceIndex] = {len, 1};
AnfAlgo::SetOutputInferTypeAndShape(output_types, output_shapes, kernel_node.get()); AnfAlgo::SetOutputInferTypeAndShape(output_types, output_shapes, kernel_node.get());
} }
void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) { void LstmCPUKernel::InitKernel(const CNodePtr &kernel_node) {
#ifdef PLATFORM_86 #ifdef PLATFORM_86
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
#endif #endif
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
using tag = dnnl::memory::format_tag; using tag = dnnl::memory::format_tag;
using dim = dnnl::memory::dims; using dim = dnnl::memory::dims;
CheckParam(kernel_node); CheckParam(kernel_node);
auto eng = MKLKernelEngine::Get().engine(); auto eng = MKLKernelEngine::Get().engine();
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
if (bidirectional_) { if (bidirectional_) {
direction = dnnl::rnn_direction::bidirectional_concat; direction = dnnl::rnn_direction::bidirectional_concat;
} }
dim src_dims = {seq_len_, batch_size_, input_size_}; dim src_dims = {seq_len_, batch_size_, input_size_};
dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_};
weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_};
bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_};
dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_}; dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_};
dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc);
dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo);
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
if (!kernel_node->HasAttr(kAttrIsTraining)) { if (!kernel_node->HasAttr(kAttrIsTraining)) {
is_training = true; is_training = true;
} else { } else {
is_training = GetValue<bool>(kernel_node->GetAttr(kAttrIsTraining)); is_training = GetValue<bool>(kernel_node->GetAttr(kAttrIsTraining));
} }
auto prop_kind = dnnl::prop_kind::forward_training; auto prop_kind = dnnl::prop_kind::forward_training;
if (!is_training) { if (!is_training) {
prop_kind = dnnl::prop_kind::forward_inference; prop_kind = dnnl::prop_kind::forward_inference;
} }
auto desc = std::make_shared<dnnl::lstm_forward::desc>( auto desc = std::make_shared<dnnl::lstm_forward::desc>(
prop_kind, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), prop_kind, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc); formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc);
prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng); prim_desc_ = dnnl::lstm_forward::primitive_desc(*desc, eng);
primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_); primitive_ = std::make_shared<dnnl::lstm_forward>(prim_desc_);
if (is_training) { if (is_training) {
reserve_size_ = static_cast<size_t>(prim_desc_.workspace_desc().get_size()); reserve_size_ = static_cast<size_t>(prim_desc_.workspace_desc().get_size());
AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc()); AddArgument(DNNL_ARG_WORKSPACE, prim_desc_.workspace_desc());
} else { } else {
reserve_size_ = 1; reserve_size_ = 1;
} }
AddArgument(DNNL_ARG_SRC_LAYER, src_desc); AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
AddArgument(DNNL_ARG_SRC_ITER, src_h_desc); AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc); AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc()); AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_desc_.weights_layer_desc());
AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc()); AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_desc_.weights_iter_desc());
AddArgument(DNNL_ARG_BIAS, bias_desc); AddArgument(DNNL_ARG_BIAS, bias_desc);
AddArgument(DNNL_ARG_DST_LAYER, dst_desc); AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
AddArgument(DNNL_ARG_DST_ITER, dst_h_desc); AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc); AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc);
} }
void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) { void LstmCPUKernel::CheckParam(const CNodePtr &kernel_node) {
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional"); bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "input_size")); input_size_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "input_size"));
hidden_size_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "hidden_size")); hidden_size_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "hidden_size"));
num_layers_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "num_layers")); num_layers_ = static_cast<int>(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "num_layers"));
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias"); has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
batch_size_ = SizeToInt(src_shape[1]); batch_size_ = SizeToInt(src_shape[1]);
seq_len_ = SizeToInt(src_shape[0]); seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1; num_directions_ = 1;
if (bidirectional_) { if (bidirectional_) {
num_directions_ = 2; num_directions_ = 2;
} }
const int gate_size = 4 * hidden_size_; const int gate_size = 4 * hidden_size_;
if (num_layers_ <= 0) { if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "Layers must be greater than zero!"; MS_LOG(EXCEPTION) << "Layers must be greater than zero!";
} }
if (num_layers_ > kMaxLSTMLayer) { if (num_layers_ > kMaxLSTMLayer) {
MS_LOG(EXCEPTION) << "Layers must be lower than 100!"; MS_LOG(EXCEPTION) << "Layers must be lower than 100!";
} }
for (int i = 0; i < num_layers_; ++i) { for (int i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_; weight_h_size_ += gate_size * hidden_size_;
} }
weight_size_ = weight_size_ * num_directions_; weight_size_ = weight_size_ * num_directions_;
weight_h_size_ = weight_h_size_ * num_directions_; weight_h_size_ = weight_h_size_ * num_directions_;
if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) { if (num_directions_ * num_layers_ != SizeToInt(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "Error iteration shape!"; MS_LOG(EXCEPTION) << "Error iteration shape!";
} }
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
MS_LOG(EXCEPTION) << "Lstm only support 3-D input!"; MS_LOG(EXCEPTION) << "Lstm only support 3-D input!";
} }
} }
bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, bool LstmCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
using dt = dnnl::memory::data_type; using dt = dnnl::memory::data_type;
using tag = dnnl::memory::format_tag; using tag = dnnl::memory::format_tag;
auto eng = MKLKernelEngine::Get().engine(); auto eng = MKLKernelEngine::Get().engine();
auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng); auto weights_memory = dnnl::memory(prim_desc_.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng); auto weights_h_memory = dnnl::memory(prim_desc_.weights_iter_desc(), eng);
user_weights_memory.set_data_handle(inputs[3]->addr); user_weights_memory.set_data_handle(inputs[3]->addr);
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_); user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
Reorder(&user_weights_memory, &weights_memory); Reorder(&user_weights_memory, &weights_memory);
Reorder(&user_weights_h_memory, &weights_h_memory); Reorder(&user_weights_h_memory, &weights_h_memory);
auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng); auto bias_memory = dnnl::memory(prim_desc_.bias_desc(), eng);
if (has_bias_) { if (has_bias_) {
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
} else { } else {
if (memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0, if (memset_s(bias_memory.get_data_handle(), prim_desc_.bias_desc().get_size(), 0,
prim_desc_.bias_desc().get_size())) { prim_desc_.bias_desc().get_size())) {
MS_LOG(EXCEPTION) << "Bias memset error"; MS_LOG(EXCEPTION) << "Bias memset error";
} }
} }
// set handle // set handle
SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr); SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr); SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr); SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle()); SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle()); SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle()); SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr); SetArgumentHandle(DNNL_ARG_DST_LAYER, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr); SetArgumentHandle(DNNL_ARG_DST_ITER, outputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr); SetArgumentHandle(DNNL_ARG_DST_ITER_C, outputs[2]->addr);
if (is_training) { if (is_training) {
SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr); SetArgumentHandle(DNNL_ARG_WORKSPACE, outputs[3]->addr);
} }
ExecutePrimitive(); ExecutePrimitive();
return true; return true;
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,76 +1,76 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H_
#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64) #if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64)
#define PLATFORM_86 #define PLATFORM_86
#endif #endif
#ifdef PLATFORM_86 #ifdef PLATFORM_86
#include <pmmintrin.h> #include <pmmintrin.h>
#endif #endif
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
class LstmCPUKernel : public MKLCPUKernel { class LstmCPUKernel : public MKLCPUKernel {
public: public:
LstmCPUKernel() = default; LstmCPUKernel() = default;
~LstmCPUKernel() override = default; ~LstmCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override; void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
protected: protected:
void InitInputOutputSize(const CNodePtr &kernel_node) override; void InitInputOutputSize(const CNodePtr &kernel_node) override;
private: private:
void CheckParam(const CNodePtr &kernel_node); void CheckParam(const CNodePtr &kernel_node);
int weight_size_ = 0; int weight_size_ = 0;
int weight_h_size_ = 0; int weight_h_size_ = 0;
int input_size_; int input_size_;
int hidden_size_; int hidden_size_;
int num_layers_; int num_layers_;
int batch_size_; int batch_size_;
int seq_len_; int seq_len_;
int num_directions_; int num_directions_;
bool bidirectional_; bool bidirectional_;
bool has_bias_; bool has_bias_;
size_t reserve_size_; size_t reserve_size_;
bool is_training; bool is_training;
dnnl::memory::dims weights_dims_; dnnl::memory::dims weights_dims_;
dnnl::memory::dims weights_h_dims_; dnnl::memory::dims weights_h_dims_;
dnnl::memory::dims bias_dims_; dnnl::memory::dims bias_dims_;
dnnl::lstm_forward::primitive_desc prim_desc_; dnnl::lstm_forward::primitive_desc prim_desc_;
}; };
// Register the LSTM forward kernel: four float32 inputs
// (presumably x, h, c and the packed weights — see Launch) and five
// float32 outputs (y, h_out, c_out plus reserved/workspace buffers).
MS_REG_CPU_KERNEL(LSTM,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddInputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32)
                    .AddOutputAttr(kNumberTypeFloat32),
                  LstmCPUKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H

View File

@ -1,218 +1,218 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h"
#include <cstring> #include <cstring>
#include <string> #include <string>
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
const int kMaxLSTMLayer = 100; const int kMaxLSTMLayer = 100;
const int kInputWorkSpaceIndex = 10; const int kInputWorkSpaceIndex = 10;
void LSTMGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { void LSTMGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node); CPUKernel::InitInputOutputSize(kernel_node);
input_size_list_[kInputWorkSpaceIndex] = reserve_size_; input_size_list_[kInputWorkSpaceIndex] = reserve_size_;
} }
void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { void LSTMGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
using tag = dnnl::memory::format_tag; using tag = dnnl::memory::format_tag;
using dim = dnnl::memory::dims; using dim = dnnl::memory::dims;
CheckParam(kernel_node); CheckParam(kernel_node);
auto eng = MKLKernelEngine::Get().engine(); auto eng = MKLKernelEngine::Get().engine();
dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional; dnnl::rnn_direction direction = dnnl::rnn_direction::unidirectional;
if (bidirectional_) { if (bidirectional_) {
direction = dnnl::rnn_direction::bidirectional_concat; direction = dnnl::rnn_direction::bidirectional_concat;
} }
dim src_dims = {seq_len_, batch_size_, input_size_}; dim src_dims = {seq_len_, batch_size_, input_size_};
dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim src_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim src_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_}; weights_dims_ = {num_layers_, num_directions_, input_size_, 4, hidden_size_};
weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_}; weights_h_dims_ = {num_layers_, num_directions_, hidden_size_, 4, hidden_size_};
bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_}; bias_dims_ = {num_layers_, num_directions_, 4, hidden_size_};
dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_}; dim dst_dims = {seq_len_, batch_size_, static_cast<int64_t>(hidden_size_) * num_directions_};
dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim dst_h_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_}; dim dst_c_dims = {num_layers_, num_directions_, batch_size_, hidden_size_};
dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc); dnnl::memory::desc src_desc = formatted_md(src_dims, tag::tnc);
dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc); dnnl::memory::desc src_h_desc = formatted_md(src_h_dims, tag::ldnc);
dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc); dnnl::memory::desc src_c_desc = formatted_md(src_c_dims, tag::ldnc);
dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo); dnnl::memory::desc bias_desc = formatted_md(bias_dims_, tag::ldgo);
dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc); dnnl::memory::desc dst_desc = formatted_md(dst_dims, tag::tnc);
dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc); dnnl::memory::desc dst_h_desc = formatted_md(dst_h_dims, tag::ldnc);
dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc); dnnl::memory::desc dst_c_desc = formatted_md(dst_c_dims, tag::ldnc);
auto forward_desc = std::make_shared<dnnl::lstm_forward::desc>( auto forward_desc = std::make_shared<dnnl::lstm_forward::desc>(
dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc, dnnl::prop_kind::forward_training, direction, src_desc, src_h_desc, src_c_desc,
formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc,
dst_c_desc); dst_c_desc);
auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng); auto prim_forward_desc = dnnl::lstm_forward::primitive_desc(*forward_desc, eng);
auto backward_desc = std::make_shared<dnnl::lstm_backward::desc>( auto backward_desc = std::make_shared<dnnl::lstm_backward::desc>(
dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any), dnnl::prop_kind::backward, direction, src_desc, src_h_desc, src_c_desc, formatted_md(weights_dims_, tag::any),
formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc, formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, dst_h_desc, dst_c_desc, src_desc, src_h_desc,
src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc, src_c_desc, formatted_md(weights_dims_, tag::any), formatted_md(weights_h_dims_, tag::any), bias_desc, dst_desc,
dst_h_desc, dst_c_desc); dst_h_desc, dst_c_desc);
prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc); prim_backward_desc_ = dnnl::lstm_backward::primitive_desc(*backward_desc, eng, prim_forward_desc);
primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_); primitive_ = std::make_shared<dnnl::lstm_backward>(prim_backward_desc_);
reserve_size_ = static_cast<size_t>(prim_forward_desc.workspace_desc().get_size()); reserve_size_ = static_cast<size_t>(prim_forward_desc.workspace_desc().get_size());
AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc()); AddArgument(DNNL_ARG_WORKSPACE, prim_forward_desc.workspace_desc());
AddArgumentOp(src_desc, src_h_desc, src_c_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc); AddArgumentOp(src_desc, src_h_desc, src_c_desc, bias_desc, dst_desc, dst_h_desc, dst_c_desc);
} }
// Declares every dnnl execution argument of the backward primitive:
// forward tensors, their gradients, and the (layout-opaque) weight
// descriptors taken from the backward primitive descriptor.
void LSTMGradCPUKernel::AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
                                      const dnnl::memory::desc &src_c_desc, const dnnl::memory::desc &bias_desc,
                                      const dnnl::memory::desc &dst_desc, const dnnl::memory::desc &dst_h_desc,
                                      const dnnl::memory::desc &dst_c_desc) {
  // Forward-pass tensors.
  AddArgument(DNNL_ARG_SRC_LAYER, src_desc);
  AddArgument(DNNL_ARG_SRC_ITER, src_h_desc);
  AddArgument(DNNL_ARG_SRC_ITER_C, src_c_desc);
  AddArgument(DNNL_ARG_WEIGHTS_LAYER, prim_backward_desc_.weights_layer_desc());
  AddArgument(DNNL_ARG_WEIGHTS_ITER, prim_backward_desc_.weights_iter_desc());
  AddArgument(DNNL_ARG_BIAS, bias_desc);
  AddArgument(DNNL_ARG_DST_LAYER, dst_desc);
  AddArgument(DNNL_ARG_DST_ITER, dst_h_desc);
  AddArgument(DNNL_ARG_DST_ITER_C, dst_c_desc);
  // Gradient tensors (same logical shapes as their forward counterparts).
  AddArgument(DNNL_ARG_DIFF_SRC_LAYER, src_desc);
  AddArgument(DNNL_ARG_DIFF_SRC_ITER, src_h_desc);
  AddArgument(DNNL_ARG_DIFF_SRC_ITER_C, src_c_desc);
  AddArgument(DNNL_ARG_DIFF_WEIGHTS_LAYER, prim_backward_desc_.diff_weights_layer_desc());
  AddArgument(DNNL_ARG_DIFF_WEIGHTS_ITER, prim_backward_desc_.diff_weights_iter_desc());
  AddArgument(DNNL_ARG_DIFF_BIAS, bias_desc);
  AddArgument(DNNL_ARG_DIFF_DST_LAYER, dst_desc);
  AddArgument(DNNL_ARG_DIFF_DST_ITER, dst_h_desc);
  AddArgument(DNNL_ARG_DIFF_DST_ITER_C, dst_c_desc);
}
void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { void LSTMGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); std::vector<size_t> src_h_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2); std::vector<size_t> src_c_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional"); bidirectional_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "bidirectional");
input_size_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "input_size"); input_size_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "input_size");
hidden_size_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "hidden_size"); hidden_size_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "hidden_size");
num_layers_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "num_layers"); num_layers_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "num_layers");
has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias"); has_bias_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "has_bias");
batch_size_ = SizeToInt(src_shape[1]); batch_size_ = SizeToInt(src_shape[1]);
seq_len_ = SizeToInt(src_shape[0]); seq_len_ = SizeToInt(src_shape[0]);
num_directions_ = 1; num_directions_ = 1;
if (bidirectional_) { if (bidirectional_) {
num_directions_ = 2; num_directions_ = 2;
} }
const int64_t gate_size = 4 * hidden_size_; const int64_t gate_size = 4 * hidden_size_;
if (num_layers_ <= 0) { if (num_layers_ <= 0) {
MS_LOG(EXCEPTION) << "Layers must be greater than zero!"; MS_LOG(EXCEPTION) << "Layers must be greater than zero!";
} }
if (num_layers_ > kMaxLSTMLayer) { if (num_layers_ > kMaxLSTMLayer) {
MS_LOG(EXCEPTION) << "Layers must be lower than 100!"; MS_LOG(EXCEPTION) << "Layers must be lower than 100!";
} }
for (int64_t i = 0; i < num_layers_; ++i) { for (int64_t i = 0; i < num_layers_; ++i) {
weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_); weight_size_ += gate_size * (i == 0 ? input_size_ : hidden_size_ * num_directions_);
weight_h_size_ += gate_size * hidden_size_; weight_h_size_ += gate_size * hidden_size_;
} }
weight_size_ = weight_size_ * num_directions_; weight_size_ = weight_size_ * num_directions_;
weight_h_size_ = weight_h_size_ * num_directions_; weight_h_size_ = weight_h_size_ * num_directions_;
if (num_directions_ * num_layers_ != SizeToLong(src_h_shape[0])) { if (num_directions_ * num_layers_ != SizeToLong(src_h_shape[0])) {
MS_LOG(EXCEPTION) << "Error iteration shape!"; MS_LOG(EXCEPTION) << "Error iteration shape!";
} }
if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) { if (src_shape.size() != 3 || src_h_shape.size() != 3 || src_c_shape.size() != 3) {
MS_LOG(EXCEPTION) << "Lstm only support 3-D input!"; MS_LOG(EXCEPTION) << "Lstm only support 3-D input!";
} }
} }
// Binds concrete buffers to the arguments declared in AddArgumentOp.
// inputs: [0]=x, [1]=h, [2]=c, [4..6]=forward outputs, [7..9]=output grads,
// [kInputWorkSpaceIndex]=forward workspace (indices taken from the handles
// bound below). outputs: [0..2]=input grads; weight/bias grads flow through
// the reordered dnnl memories passed in.
void LSTMGradCPUKernel::SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
                                            const std::vector<kernel::AddressPtr> &outputs,
                                            const dnnl::memory &weights_memory, const dnnl::memory &weights_h_memory,
                                            const dnnl::memory &bias_memory, const dnnl::memory &diff_weights_memory,
                                            const dnnl::memory &diff_weights_h_memory,
                                            const dnnl::memory &diff_bias_memory) {
  SetArgumentHandle(DNNL_ARG_SRC_LAYER, inputs[0]->addr);
  SetArgumentHandle(DNNL_ARG_SRC_ITER, inputs[1]->addr);
  SetArgumentHandle(DNNL_ARG_SRC_ITER_C, inputs[2]->addr);
  SetArgumentHandle(DNNL_ARG_WEIGHTS_LAYER, weights_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_WEIGHTS_ITER, weights_h_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_BIAS, bias_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DST_LAYER, inputs[4]->addr);
  SetArgumentHandle(DNNL_ARG_DST_ITER, inputs[5]->addr);
  SetArgumentHandle(DNNL_ARG_DST_ITER_C, inputs[6]->addr);
  // Was a magic `inputs[10]`; use the named constant kept in sync with
  // InitInputOutputSize.
  SetArgumentHandle(DNNL_ARG_WORKSPACE, inputs[kInputWorkSpaceIndex]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_LAYER, outputs[0]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER, outputs[1]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_SRC_ITER_C, outputs[2]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_h_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DIFF_BIAS, diff_bias_memory.get_data_handle());
  SetArgumentHandle(DNNL_ARG_DIFF_DST_LAYER, inputs[7]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER, inputs[8]->addr);
  SetArgumentHandle(DNNL_ARG_DIFF_DST_ITER_C, inputs[9]->addr);
}
// Zero-fills a dnnl memory buffer; `name` is only used to label the error
// raised when memset_s reports failure.
void LSTMGradCPUKernel::ResetMemory(const dnnl::memory &mem, const string name) const {
  if (memset_s(mem.get_data_handle(), mem.get_desc().get_size(), 0, mem.get_desc().get_size())) {
    MS_LOG(EXCEPTION) << name << " memset error";
  }
}
bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, bool LSTMGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
using dt = dnnl::memory::data_type; using dt = dnnl::memory::data_type;
using tag = dnnl::memory::format_tag; using tag = dnnl::memory::format_tag;
auto eng = MKLKernelEngine::Get().engine(); auto eng = MKLKernelEngine::Get().engine();
// construct fw memory // construct fw memory
auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); auto user_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); auto user_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng); auto weights_memory = dnnl::memory(prim_backward_desc_.weights_layer_desc(), eng);
auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng); auto weights_h_memory = dnnl::memory(prim_backward_desc_.weights_iter_desc(), eng);
auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng); auto bias_memory = dnnl::memory(prim_backward_desc_.bias_desc(), eng);
user_weights_memory.set_data_handle(inputs[3]->addr); user_weights_memory.set_data_handle(inputs[3]->addr);
user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_); user_weights_h_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_);
Reorder(&user_weights_memory, &weights_memory); Reorder(&user_weights_memory, &weights_memory);
Reorder(&user_weights_h_memory, &weights_h_memory); Reorder(&user_weights_h_memory, &weights_h_memory);
if (has_bias_) { if (has_bias_) {
bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_); bias_memory.set_data_handle(reinterpret_cast<float *>(inputs[3]->addr) + weight_size_ + weight_h_size_);
} else { } else {
if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0, if (memset_s(bias_memory.get_data_handle(), prim_backward_desc_.bias_desc().get_size(), 0,
prim_backward_desc_.bias_desc().get_size())) { prim_backward_desc_.bias_desc().get_size())) {
MS_LOG(EXCEPTION) << "Bias memset error"; MS_LOG(EXCEPTION) << "Bias memset error";
} }
} }
// construct bw memory // construct bw memory
auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng); auto diff_weights_memory = dnnl::memory(prim_backward_desc_.diff_weights_layer_desc(), eng);
auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng); auto diff_weights_h_memory = dnnl::memory(prim_backward_desc_.diff_weights_iter_desc(), eng);
auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng); auto diff_bias_memory = dnnl::memory(prim_backward_desc_.diff_bias_desc(), eng);
auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng); auto user_diff_weights_memory = dnnl::memory(dnnl::memory::desc{{weights_dims_}, dt::f32, tag::ldgoi}, eng);
auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng);
user_diff_weights_memory.set_data_handle(outputs[3]->addr); user_diff_weights_memory.set_data_handle(outputs[3]->addr);
user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_); user_diff_weights_h_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_);
ResetMemory(user_diff_weights_memory, "user weights grad"); ResetMemory(user_diff_weights_memory, "user weights grad");
ResetMemory(user_diff_weights_h_memory, "user weights iter grad"); ResetMemory(user_diff_weights_h_memory, "user weights iter grad");
ResetMemory(diff_weights_memory, "weights grad"); ResetMemory(diff_weights_memory, "weights grad");
ResetMemory(diff_weights_h_memory, "weights iter grad"); ResetMemory(diff_weights_h_memory, "weights iter grad");
if (has_bias_) { if (has_bias_) {
diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_); diff_bias_memory.set_data_handle(reinterpret_cast<float *>(outputs[3]->addr) + weight_size_ + weight_h_size_);
} }
if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0, if (memset_s(diff_bias_memory.get_data_handle(), prim_backward_desc_.diff_bias_desc().get_size(), 0,
prim_backward_desc_.diff_bias_desc().get_size())) { prim_backward_desc_.diff_bias_desc().get_size())) {
MS_LOG(EXCEPTION) << "Bias grad memset error"; MS_LOG(EXCEPTION) << "Bias grad memset error";
} }
SetArgumentHandleOp(inputs, outputs, weights_memory, weights_h_memory, bias_memory, diff_weights_memory, SetArgumentHandleOp(inputs, outputs, weights_memory, weights_h_memory, bias_memory, diff_weights_memory,
diff_weights_h_memory, diff_bias_memory); diff_weights_h_memory, diff_bias_memory);
ExecutePrimitive(); ExecutePrimitive();
Reorder(&diff_weights_memory, &user_diff_weights_memory); Reorder(&diff_weights_memory, &user_diff_weights_memory);
Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory); Reorder(&diff_weights_h_memory, &user_diff_weights_h_memory);
return true; return true;
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,87 +1,87 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
class LSTMGradCPUKernel : public MKLCPUKernel { class LSTMGradCPUKernel : public MKLCPUKernel {
public: public:
LSTMGradCPUKernel() = default; LSTMGradCPUKernel() = default;
~LSTMGradCPUKernel() override = default; ~LSTMGradCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override; void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
protected: protected:
void InitInputOutputSize(const CNodePtr &kernel_node) override; void InitInputOutputSize(const CNodePtr &kernel_node) override;
private: private:
void AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc, void AddArgumentOp(const dnnl::memory::desc &src_desc, const dnnl::memory::desc &src_h_desc,
const dnnl::memory::desc &src_c_desc, const dnnl::memory::desc &bias_desc, const dnnl::memory::desc &src_c_desc, const dnnl::memory::desc &bias_desc,
const dnnl::memory::desc &dst_desc, const dnnl::memory::desc &dst_h_desc, const dnnl::memory::desc &dst_desc, const dnnl::memory::desc &dst_h_desc,
const dnnl::memory::desc &dst_c_desc); const dnnl::memory::desc &dst_c_desc);
void SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs, void SetArgumentHandleOp(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs, const dnnl::memory &weights_memory, const std::vector<kernel::AddressPtr> &outputs, const dnnl::memory &weights_memory,
const dnnl::memory &weights_h_memory, const dnnl::memory &bias_memory, const dnnl::memory &weights_h_memory, const dnnl::memory &bias_memory,
const dnnl::memory &diff_weights_memory, const dnnl::memory &diff_weights_h_memory, const dnnl::memory &diff_weights_memory, const dnnl::memory &diff_weights_h_memory,
const dnnl::memory &diff_bias_memory); const dnnl::memory &diff_bias_memory);
void ResetMemory(const dnnl::memory &mem, const string name) const; void ResetMemory(const dnnl::memory &mem, const string name) const;
void CheckParam(const CNodePtr &kernel_node); void CheckParam(const CNodePtr &kernel_node);
int64_t weight_size_ = 0; int64_t weight_size_ = 0;
int64_t weight_h_size_ = 0; int64_t weight_h_size_ = 0;
int64_t input_size_; int64_t input_size_;
int64_t hidden_size_; int64_t hidden_size_;
int64_t num_layers_; int64_t num_layers_;
int64_t batch_size_; int64_t batch_size_;
int64_t seq_len_; int64_t seq_len_;
int num_directions_; int num_directions_;
bool bidirectional_; bool bidirectional_;
bool has_bias_; bool has_bias_;
size_t reserve_size_; size_t reserve_size_;
dnnl::memory::dims weights_dims_; dnnl::memory::dims weights_dims_;
dnnl::memory::dims weights_h_dims_; dnnl::memory::dims weights_h_dims_;
dnnl::memory::dims bias_dims_; dnnl::memory::dims bias_dims_;
dnnl::lstm_backward::primitive_desc prim_backward_desc_; dnnl::lstm_backward::primitive_desc prim_backward_desc_;
}; };
MS_REG_CPU_KERNEL(LSTMGrad, MS_REG_CPU_KERNEL(LSTMGrad,
KernelAttr() KernelAttr()
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32), .AddOutputAttr(kNumberTypeFloat32),
LSTMGradCPUKernel); LSTMGradCPUKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_GRAD_CPU_KERNEL_H_

View File

@ -1,99 +1,99 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/softmax_cross_entropy_with_logits_cpu_kernel.h"
#include <numeric> #include <numeric>
#include <functional> #include <functional>
#include <cmath> #include <cmath>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h" #include "utils/ms_utils.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void SoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) { void SoftmaxCrossEntropyWithLogitsCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node); CPUKernel::InitInputOutputSize(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
size_t type_size = sizeof(float); size_t type_size = sizeof(float);
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>()); size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
workspace_size_list_.emplace_back(tensor_size); workspace_size_list_.emplace_back(tensor_size);
} }
void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) { void SoftmaxCrossEntropyWithLogitsCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
dnnl::memory::dims mem_dims; dnnl::memory::dims mem_dims;
mem_dims.insert(mem_dims.end(), shape.begin(), shape.end()); mem_dims.insert(mem_dims.end(), shape.begin(), shape.end());
if (mem_dims.size() != 2) { if (mem_dims.size() != 2) {
MS_LOG(EXCEPTION) << "SoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size(); MS_LOG(EXCEPTION) << "SoftmaxCrossEntropyWithLogits kernel dims invalid " << mem_dims.size();
} }
batch_size_ = shape[0]; batch_size_ = shape[0];
class_num_ = shape[1]; class_num_ = shape[1];
if (batch_size_ == 0 || class_num_ == 0) { if (batch_size_ == 0 || class_num_ == 0) {
MS_LOG(EXCEPTION) << "Invalid batch size or class num input!"; MS_LOG(EXCEPTION) << "Invalid batch size or class num input!";
} }
dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc); dnnl::memory::desc mem_desc(mem_dims, dnnl::memory::data_type::f32, dnnl::memory::format_tag::nc);
dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1); dnnl::softmax_forward::desc desc = dnnl::softmax_forward::desc(dnnl::prop_kind::forward_training, mem_desc, 1);
auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine()); auto prim_desc = dnnl::softmax_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc); primitive_ = std::make_shared<dnnl::softmax_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, mem_desc); AddArgument(DNNL_ARG_SRC, mem_desc);
AddArgument(DNNL_ARG_DST, mem_desc); AddArgument(DNNL_ARG_DST, mem_desc);
} }
void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *logits, const float *labels, void SoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const float *logits, const float *labels,
float *output1, float *output2) const { float *output1, float *output2) const {
float epsilon = 1e-6; float epsilon = 1e-6;
for (size_t i = 0; i < batch_size_; ++i) { for (size_t i = 0; i < batch_size_; ++i) {
output1[i] = 0; output1[i] = 0;
float loss = 0.0; float loss = 0.0;
for (size_t j = 0; j < class_num_; ++j) { for (size_t j = 0; j < class_num_; ++j) {
float logit = logf(logits[i * class_num_ + j] <= 0.0 ? epsilon : logits[i * class_num_ + j]); float logit = logf(logits[i * class_num_ + j] <= 0.0 ? epsilon : logits[i * class_num_ + j]);
output2[i * class_num_ + j] = logits[i * class_num_ + j] - labels[i * class_num_ + j]; output2[i * class_num_ + j] = logits[i * class_num_ + j] - labels[i * class_num_ + j];
loss += labels[i * class_num_ + j] * logit; loss += labels[i * class_num_ + j] * logit;
} }
output1[i] = -loss; output1[i] = -loss;
} }
} }
bool SoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool SoftmaxCrossEntropyWithLogitsCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace, const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.empty() || workspace.empty() || outputs.empty()) { if (inputs.empty() || workspace.empty() || outputs.empty()) {
MS_LOG(EXCEPTION) << "Error input output size!"; MS_LOG(EXCEPTION) << "Error input output size!";
} }
size_t batch_float_size = batch_size_ * sizeof(float); size_t batch_float_size = batch_size_ * sizeof(float);
size_t batch_class_float_size = class_num_ * batch_float_size; size_t batch_class_float_size = class_num_ * batch_float_size;
if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size || if (inputs[0]->size != workspace[0]->size || inputs[0]->size != batch_class_float_size ||
inputs[1]->size != batch_class_float_size) { inputs[1]->size != batch_class_float_size) {
MS_LOG(EXCEPTION) << "Error input data size!"; MS_LOG(EXCEPTION) << "Error input data size!";
} }
if (outputs[1]->size != batch_class_float_size || outputs[0]->size != batch_float_size) { if (outputs[1]->size != batch_class_float_size || outputs[0]->size != batch_float_size) {
MS_LOG(EXCEPTION) << "Error output data size!"; MS_LOG(EXCEPTION) << "Error output data size!";
} }
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr); SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr); SetArgumentHandle(DNNL_ARG_DST, workspace[0]->addr);
ExecutePrimitive(); ExecutePrimitive();
auto labels = reinterpret_cast<float *>(inputs[1]->addr); auto labels = reinterpret_cast<float *>(inputs[1]->addr);
auto logits = reinterpret_cast<float *>(workspace[0]->addr); auto logits = reinterpret_cast<float *>(workspace[0]->addr);
auto output1 = reinterpret_cast<float *>(outputs[0]->addr); auto output1 = reinterpret_cast<float *>(outputs[0]->addr);
auto output2 = reinterpret_cast<float *>(outputs[1]->addr); auto output2 = reinterpret_cast<float *>(outputs[1]->addr);
ForwardPostExecute(logits, labels, output1, output2); ForwardPostExecute(logits, labels, output1, output2);
return true; return true;
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,53 +1,53 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h" #include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
class SoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel { class SoftmaxCrossEntropyWithLogitsCPUKernel : public MKLCPUKernel {
public: public:
SoftmaxCrossEntropyWithLogitsCPUKernel() = default; SoftmaxCrossEntropyWithLogitsCPUKernel() = default;
~SoftmaxCrossEntropyWithLogitsCPUKernel() override = default; ~SoftmaxCrossEntropyWithLogitsCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override; void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
protected: protected:
void InitInputOutputSize(const CNodePtr &kernel_node) override; void InitInputOutputSize(const CNodePtr &kernel_node) override;
private: private:
void ForwardPostExecute(const float *logits, const float *labels, float *output1, float *output2) const; void ForwardPostExecute(const float *logits, const float *labels, float *output1, float *output2) const;
size_t class_num_{0}; size_t class_num_{0};
size_t batch_size_{0}; size_t batch_size_{0};
}; };
MS_REG_CPU_KERNEL(SoftmaxCrossEntropyWithLogits, MS_REG_CPU_KERNEL(SoftmaxCrossEntropyWithLogits,
KernelAttr() KernelAttr()
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32) .AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32), .AddOutputAttr(kNumberTypeFloat32),
SoftmaxCrossEntropyWithLogitsCPUKernel); SoftmaxCrossEntropyWithLogitsCPUKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS_CPU_KERNEL_H_

View File

@ -1,59 +1,59 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PSERVER_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PSERVER_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PSERVER_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PSERVER_KERNEL_H_
#include <vector> #include <vector>
#include <memory> #include <memory>
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
#include "ps/util.h" #include "ps/util.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
namespace ps { namespace ps {
using mindspore::ps::Util; using mindspore::ps::Util;
class PServerKernel { class PServerKernel {
public: public:
PServerKernel(size_t rank_id, size_t pserver_num, size_t worker_num) PServerKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
: rank_id_(rank_id), pserver_num_(pserver_num), worker_num_(worker_num) {} : rank_id_(rank_id), pserver_num_(pserver_num), worker_num_(worker_num) {}
~PServerKernel() = default; ~PServerKernel() = default;
PServerKernel(const PServerKernel &) = delete; PServerKernel(const PServerKernel &) = delete;
PServerKernel &operator=(const PServerKernel &) = delete; PServerKernel &operator=(const PServerKernel &) = delete;
virtual void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {} virtual void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {}
virtual void InitKernel(const CNodePtr &cnode, virtual void InitKernel(const CNodePtr &cnode,
const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {} const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) {}
virtual void ReInit(const std::vector<std::vector<size_t>> &) {} virtual void ReInit(const std::vector<std::vector<size_t>> &) {}
virtual bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, virtual bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) = 0; const std::vector<AddressPtr> &outputs) = 0;
virtual void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals, virtual void UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals,
size_t ids_size) {} size_t ids_size) {}
virtual const std::vector<size_t> &input_sizes() const = 0; virtual const std::vector<size_t> &input_sizes() const = 0;
virtual const std::vector<size_t> &output_sizes() const = 0; virtual const std::vector<size_t> &output_sizes() const = 0;
virtual const std::vector<size_t> &workspace_sizes() const = 0; virtual const std::vector<size_t> &workspace_sizes() const = 0;
protected: protected:
virtual void ReInit(const std::vector<AddressPtr> &) {} virtual void ReInit(const std::vector<AddressPtr> &) {}
void Shard(std::vector<size_t> *shape, int axis); void Shard(std::vector<size_t> *shape, int axis);
size_t rank_id_; size_t rank_id_;
size_t pserver_num_; size_t pserver_num_;
size_t worker_num_; size_t worker_num_;
}; };
} // namespace ps } // namespace ps
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PSERVER_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_PS_PSERVER_KERNEL_H_

View File

@ -1,138 +1,138 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/reduce_cpu_kernel.h" #include "backend/kernel_compiler/cpu/reduce_cpu_kernel.h"
#include <string> #include <string>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T> template <typename T>
void ReduceCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { void ReduceCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS); auto axis_addr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(AXIS);
if (axis_addr->isa<ValueTuple>() || axis_addr->isa<ValueList>()) { if (axis_addr->isa<ValueTuple>() || axis_addr->isa<ValueList>()) {
axis_ = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, AXIS); axis_ = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, AXIS);
} else if (axis_addr->isa<Int64Imm>()) { } else if (axis_addr->isa<Int64Imm>()) {
axis_.emplace_back(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS)); axis_.emplace_back(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS));
} else { } else {
MS_LOG(EXCEPTION) << "Attribute is invalid"; MS_LOG(EXCEPTION) << "Attribute is invalid";
} }
int dimension = input_shape_.size(); int dimension = input_shape_.size();
std::transform(axis_.begin(), axis_.end(), axis_.begin(), std::transform(axis_.begin(), axis_.end(), axis_.begin(),
[dimension](const auto &a) { return a < 0 ? dimension + a : a; }); [dimension](const auto &a) { return a < 0 ? dimension + a : a; });
sort(axis_.begin(), axis_.end()); sort(axis_.begin(), axis_.end());
// Delete the duplicate axis. // Delete the duplicate axis.
auto last = std::unique(axis_.begin(), axis_.end()); auto last = std::unique(axis_.begin(), axis_.end());
axis_.erase(last, axis_.end()); axis_.erase(last, axis_.end());
auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if constexpr (std::is_same<T, bool>::value) { if constexpr (std::is_same<T, bool>::value) {
if (kernel_name == "ReduceAll") { if (kernel_name == "ReduceAll") {
reduce_type_ = kReduceAll; reduce_type_ = kReduceAll;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out &= input[pos]; }; reduce_func_ = [](const T *input, size_t pos, T *out) { *out &= input[pos]; };
} else if (kernel_name == "ReduceAny") { } else if (kernel_name == "ReduceAny") {
reduce_type_ = kReduceAny; reduce_type_ = kReduceAny;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out |= input[pos]; }; reduce_func_ = [](const T *input, size_t pos, T *out) { *out |= input[pos]; };
} else { } else {
MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << fullname_ << " for bool."; MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << fullname_ << " for bool.";
} }
} else { } else {
if (kernel_name == "ReduceMax") { if (kernel_name == "ReduceMax") {
reduce_type_ = kReduceMax; reduce_type_ = kReduceMax;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::max(input[pos], *out); }; reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::max(input[pos], *out); };
} else if (kernel_name == "ReduceMin") { } else if (kernel_name == "ReduceMin") {
reduce_type_ = kReduceMin; reduce_type_ = kReduceMin;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::min(input[pos], *out); }; reduce_func_ = [](const T *input, size_t pos, T *out) { *out = std::min(input[pos], *out); };
} else if (kernel_name == "ReduceSum") { } else if (kernel_name == "ReduceSum") {
reduce_type_ = kReduceSum; reduce_type_ = kReduceSum;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; }; reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; };
} else if (kernel_name == "ReduceMean") { } else if (kernel_name == "ReduceMean") {
reduce_type_ = kReduceMean; reduce_type_ = kReduceMean;
reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; }; reduce_func_ = [](const T *input, size_t pos, T *out) { *out += input[pos]; };
} else { } else {
MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name; MS_LOG(EXCEPTION) << "Unsupported reduce operation: " << kernel_name;
} }
} }
} }
template <typename T> template <typename T>
bool ReduceCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, bool ReduceCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
size_t input_size = inputs[0]->size / sizeof(T); size_t input_size = inputs[0]->size / sizeof(T);
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr); auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr); auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
if (axis_.empty() || input_shape_.empty() || input_shape_.size() == 1) { if (axis_.empty() || input_shape_.empty() || input_shape_.size() == 1) {
// Get one ret // Get one ret
*output_addr = input_addr[0]; *output_addr = input_addr[0];
for (size_t i = 1; i < input_size; ++i) { for (size_t i = 1; i < input_size; ++i) {
reduce_func_(input_addr, i, output_addr); reduce_func_(input_addr, i, output_addr);
} }
if (reduce_type_ == kReduceMean) { if (reduce_type_ == kReduceMean) {
*output_addr /= input_size; *output_addr /= input_size;
} }
} else { } else {
// Calculate transpose axes and stride // Calculate transpose axes and stride
int dimension = input_shape_.size(); int dimension = input_shape_.size();
size_t stride = 1; size_t stride = 1;
std::vector<size_t> axes(input_shape_.size()); std::vector<size_t> axes(input_shape_.size());
size_t j = 0; size_t j = 0;
size_t k = 0; size_t k = 0;
for (int i = 0; i < dimension; ++i) { for (int i = 0; i < dimension; ++i) {
if (j == axis_.size() || i != axis_[j]) { if (j == axis_.size() || i != axis_[j]) {
axes[k] = i; axes[k] = i;
++k; ++k;
} else { } else {
stride *= input_shape_[i]; stride *= input_shape_[i];
++j; ++j;
} }
} }
for (auto &it : axis_) { for (auto &it : axis_) {
axes[k] = it; axes[k] = it;
++k; ++k;
} }
// Calculate transpose shape // Calculate transpose shape
std::vector<size_t> transpose_shape(input_shape_.size()); std::vector<size_t> transpose_shape(input_shape_.size());
for (int i = 0; i < dimension; ++i) { for (int i = 0; i < dimension; ++i) {
transpose_shape[i] = input_shape_[axes[i]]; transpose_shape[i] = input_shape_[axes[i]];
} }
size_t output_size = outputs[0]->size / sizeof(T); size_t output_size = outputs[0]->size / sizeof(T);
TransposeIterator base_iter(std::move(transpose_shape), std::move(axes), input_shape_); TransposeIterator base_iter(std::move(transpose_shape), std::move(axes), input_shape_);
auto task = [this, &base_iter, input_addr, output_addr, stride](size_t start, size_t end) { auto task = [this, &base_iter, input_addr, output_addr, stride](size_t start, size_t end) {
auto iter = base_iter; auto iter = base_iter;
iter.SetPos(start * stride); iter.SetPos(start * stride);
for (size_t i = start; i < end; ++i) { for (size_t i = start; i < end; ++i) {
output_addr[i] = input_addr[iter.GetPos()]; output_addr[i] = input_addr[iter.GetPos()];
iter.GenNextPos(); iter.GenNextPos();
for (size_t j = 1; j < stride; ++j) { for (size_t j = 1; j < stride; ++j) {
reduce_func_(input_addr, iter.GetPos(), &output_addr[i]); reduce_func_(input_addr, iter.GetPos(), &output_addr[i]);
iter.GenNextPos(); iter.GenNextPos();
} }
if (reduce_type_ == kReduceMean) { if (reduce_type_ == kReduceMean) {
output_addr[i] /= stride; output_addr[i] /= stride;
} }
} }
}; };
CPUKernelUtils::ParallelFor(task, output_size); CPUKernelUtils::ParallelFor(task, output_size);
} }
return true; return true;
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,69 +1,69 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include <functional> #include <functional>
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T> template <typename T>
class ReduceCPUKernel : public CPUKernel { class ReduceCPUKernel : public CPUKernel {
public: public:
ReduceCPUKernel() = default; ReduceCPUKernel() = default;
~ReduceCPUKernel() override = default; ~ReduceCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override; void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
private: private:
enum ReduceType { kReduceAll, kReduceAny, kReduceMax, kReduceMin, kReduceSum, kReduceMean }; enum ReduceType { kReduceAll, kReduceAny, kReduceMax, kReduceMin, kReduceSum, kReduceMean };
std::vector<size_t> input_shape_; std::vector<size_t> input_shape_;
std::vector<int64_t> axis_; std::vector<int64_t> axis_;
ReduceType reduce_type_{kReduceAll}; ReduceType reduce_type_{kReduceAll};
std::function<void(const T *, size_t, T *)> reduce_func_; std::function<void(const T *, size_t, T *)> reduce_func_;
}; };
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, float); MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, double); MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, int32_t); MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, int64_t); MS_REG_CPU_KERNEL_T(ReduceMean, KernelAttr(), ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, float); MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, double); MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, int32_t); MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, int64_t); MS_REG_CPU_KERNEL_T(ReduceMax, KernelAttr(), ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, float); MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, double); MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, int32_t); MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, int64_t); MS_REG_CPU_KERNEL_T(ReduceSum, KernelAttr(), ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, float); MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, float);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, double); MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, double);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int32_t); MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int32_t);
MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int64_t); MS_REG_CPU_KERNEL_T(ReduceMin, KernelAttr(), ReduceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(ReduceAll, KernelAttr(), ReduceCPUKernel, bool); MS_REG_CPU_KERNEL_T(ReduceAll, KernelAttr(), ReduceCPUKernel, bool);
MS_REG_CPU_KERNEL_T(ReduceAny, KernelAttr(), ReduceCPUKernel, bool); MS_REG_CPU_KERNEL_T(ReduceAny, KernelAttr(), ReduceCPUKernel, bool);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_REDUCE_CPU_KERNEL_H_

View File

@ -1,91 +1,91 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.h" #include "backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.h"
#include <vector> #include <vector>
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T> template <typename T>
void SpaceToDepthCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { void SpaceToDepthCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
CheckParam(kernel_node); CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
block_size_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "block_size"); block_size_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "block_size");
} }
template <typename T> template <typename T>
bool SpaceToDepthCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, bool SpaceToDepthCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/, const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr); auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr); auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t size = IntToSize(inputs[0]->size / sizeof(T)); size_t size = IntToSize(inputs[0]->size / sizeof(T));
std::vector<size_t> input_shape = input_shape_; std::vector<size_t> input_shape = input_shape_;
std::vector<size_t> output_shape = output_shape_; std::vector<size_t> output_shape = output_shape_;
size_t block_size = block_size_; size_t block_size = block_size_;
size_t input_dimension = input_shape.size(); size_t input_dimension = input_shape.size();
size_t input_strides[3] = {1, 1, 1}; size_t input_strides[3] = {1, 1, 1};
for (size_t i = input_dimension - 1; i >= 1; --i) { for (size_t i = input_dimension - 1; i >= 1; --i) {
for (size_t j = 0; j < i; ++j) { for (size_t j = 0; j < i; ++j) {
input_strides[j] *= input_shape[i]; input_strides[j] *= input_shape[i];
} }
} }
auto task = [&, input_addr, output_addr](size_t start, size_t end) { auto task = [&, input_addr, output_addr](size_t start, size_t end) {
std::vector<size_t> input_pos_array(input_dimension, 0); std::vector<size_t> input_pos_array(input_dimension, 0);
for (size_t i = start; i < end; ++i) { for (size_t i = start; i < end; ++i) {
size_t tmp_pos = i; size_t tmp_pos = i;
for (size_t j = 0; j < input_dimension - 1; ++j) { for (size_t j = 0; j < input_dimension - 1; ++j) {
input_pos_array[j] = tmp_pos / input_strides[j]; input_pos_array[j] = tmp_pos / input_strides[j];
tmp_pos %= input_strides[j]; tmp_pos %= input_strides[j];
} }
input_pos_array.back() = tmp_pos; input_pos_array.back() = tmp_pos;
size_t output_pos = input_pos_array[0]; size_t output_pos = input_pos_array[0];
output_pos = output_pos =
(output_pos * output_shape[1]) + (output_pos * output_shape[1]) +
(input_pos_array[1] + (input_pos_array[1] +
(block_size * (input_pos_array[2] % block_size) + input_pos_array[3] % block_size) * input_shape[1]); (block_size * (input_pos_array[2] % block_size) + input_pos_array[3] % block_size) * input_shape[1]);
output_pos = (output_pos * output_shape[2]) + (input_pos_array[2] / block_size); output_pos = (output_pos * output_shape[2]) + (input_pos_array[2] / block_size);
output_pos = (output_pos * output_shape[3]) + (input_pos_array[3] / block_size); output_pos = (output_pos * output_shape[3]) + (input_pos_array[3] / block_size);
output_addr[output_pos] = input_addr[i]; output_addr[output_pos] = input_addr[i];
} }
}; };
CPUKernelUtils::ParallelFor(task, size); CPUKernelUtils::ParallelFor(task, size);
return true; return true;
} }
template <typename T> template <typename T>
void SpaceToDepthCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) { void SpaceToDepthCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) { if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DepthToSpaceCPUKerrnel needs 1 input."; MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DepthToSpaceCPUKerrnel needs 1 input.";
} }
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) { if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DepthToSpaceCPUKernel needs 1 output."; MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DepthToSpaceCPUKernel needs 1 output.";
} }
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,84 +1,84 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_
#include <string> #include <string>
#include <vector> #include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T> template <typename T>
class SpaceToDepthCPUKernel : public CPUKernel { class SpaceToDepthCPUKernel : public CPUKernel {
public: public:
SpaceToDepthCPUKernel() = default; SpaceToDepthCPUKernel() = default;
~SpaceToDepthCPUKernel() override = default; ~SpaceToDepthCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override; void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
private: private:
void CheckParam(const CNodePtr &kernel_node); void CheckParam(const CNodePtr &kernel_node);
std::vector<size_t> input_shape_; std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_; std::vector<size_t> output_shape_;
size_t block_size_; size_t block_size_;
}; };
MS_REG_CPU_KERNEL_T( MS_REG_CPU_KERNEL_T(
SpaceToDepth, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), SpaceToDepth, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpaceToDepthCPUKernel, float); SpaceToDepthCPUKernel, float);
MS_REG_CPU_KERNEL_T( MS_REG_CPU_KERNEL_T(
SpaceToDepth, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), SpaceToDepth, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SpaceToDepthCPUKernel, float16); SpaceToDepthCPUKernel, float16);
MS_REG_CPU_KERNEL_T(SpaceToDepth, MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
SpaceToDepthCPUKernel, int8_t); SpaceToDepthCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth, MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
SpaceToDepthCPUKernel, int16_t); SpaceToDepthCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth, MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpaceToDepthCPUKernel, int); SpaceToDepthCPUKernel, int);
MS_REG_CPU_KERNEL_T(SpaceToDepth, MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpaceToDepthCPUKernel, int64_t); SpaceToDepthCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth, MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
SpaceToDepthCPUKernel, uint8_t); SpaceToDepthCPUKernel, uint8_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth, MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
SpaceToDepthCPUKernel, uint16_t); SpaceToDepthCPUKernel, uint16_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth, MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
SpaceToDepthCPUKernel, uint32_t); SpaceToDepthCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth, MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
SpaceToDepthCPUKernel, uint64_t); SpaceToDepthCPUKernel, uint64_t);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_

View File

@ -1,87 +1,87 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include <string> #include <string>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <map> #include <map>
#include "backend/kernel_compiler/cpu/topk_cpu_kernel.h" #include "backend/kernel_compiler/cpu/topk_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
template <typename T> template <typename T>
void TopKCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { void TopKCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) {
if (inputs.size() != 2 || outputs.size() != 2) { if (inputs.size() != 2 || outputs.size() != 2) {
MS_LOG(EXCEPTION) << "TopK needs 2 inputs and 2 outputs, but get inputs: " << inputs.size() MS_LOG(EXCEPTION) << "TopK needs 2 inputs and 2 outputs, but get inputs: " << inputs.size()
<< "outputs: " << outputs.size(); << "outputs: " << outputs.size();
} }
if (inputs[0]->size != outer_size_ * inner_size_ * sizeof(T)) { if (inputs[0]->size != outer_size_ * inner_size_ * sizeof(T)) {
MS_LOG(EXCEPTION) << "Error input data size!"; MS_LOG(EXCEPTION) << "Error input data size!";
} }
if (inputs[1]->size != sizeof(int)) { if (inputs[1]->size != sizeof(int)) {
MS_LOG(EXCEPTION) << "Input K must be int!"; MS_LOG(EXCEPTION) << "Input K must be int!";
} }
auto input = reinterpret_cast<T *>(inputs[0]->addr); auto input = reinterpret_cast<T *>(inputs[0]->addr);
int k = reinterpret_cast<int *>(inputs[1]->addr)[0]; int k = reinterpret_cast<int *>(inputs[1]->addr)[0];
auto output = reinterpret_cast<T *>(outputs[0]->addr); auto output = reinterpret_cast<T *>(outputs[0]->addr);
auto indices = reinterpret_cast<int *>(outputs[1]->addr); auto indices = reinterpret_cast<int *>(outputs[1]->addr);
if (k < 1) { if (k < 1) {
MS_LOG(EXCEPTION) << "Input k must > 0!"; MS_LOG(EXCEPTION) << "Input k must > 0!";
} }
size_t k_num = IntToSize(std::min<int>(inner_size_, k)); size_t k_num = IntToSize(std::min<int>(inner_size_, k));
if (outputs[0]->size != outer_size_ * k_num * sizeof(T)) { if (outputs[0]->size != outer_size_ * k_num * sizeof(T)) {
MS_LOG(EXCEPTION) << "Error output data size!"; MS_LOG(EXCEPTION) << "Error output data size!";
} }
for (size_t i = 0; i < outer_size_; ++i) { for (size_t i = 0; i < outer_size_; ++i) {
std::vector<size_t> idx(inner_size_); std::vector<size_t> idx(inner_size_);
auto base_input = i * inner_size_; auto base_input = i * inner_size_;
std::iota(idx.begin(), idx.end(), base_input); std::iota(idx.begin(), idx.end(), base_input);
std::stable_sort(idx.begin(), idx.end(), std::stable_sort(idx.begin(), idx.end(),
[&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; }); [&input](size_t index_1, size_t index_2) { return input[index_1] > input[index_2]; });
auto base_output = i * k_num; auto base_output = i * k_num;
if (!sorted_) { if (!sorted_) {
std::stable_sort(idx.begin(), idx.begin() + SizeToLong(k_num)); std::stable_sort(idx.begin(), idx.begin() + SizeToLong(k_num));
} }
for (size_t j = 0; j < k_num; ++j) { for (size_t j = 0; j < k_num; ++j) {
indices[base_output + j] = SizeToInt(idx[j]) - SizeToInt(base_input); indices[base_output + j] = SizeToInt(idx[j]) - SizeToInt(base_input);
output[base_output + j] = input[idx[j]]; output[base_output + j] = input[idx[j]];
} }
} }
} }
void TopKCPUKernel::InitKernel(const CNodePtr &kernel_node) { void TopKCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
auto x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); auto x_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < x_shape_.size() - 1; ++i) { for (size_t i = 0; i < x_shape_.size() - 1; ++i) {
outer_size_ *= x_shape_[i]; outer_size_ *= x_shape_[i];
} }
inner_size_ = x_shape_[x_shape_.size() - 1]; inner_size_ = x_shape_[x_shape_.size() - 1];
sorted_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "sorted"); sorted_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "sorted");
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
} }
bool TopKCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, bool TopKCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat16) { if (dtype_ == kNumberTypeFloat16) {
LaunchKernel<float16>(inputs, outputs); LaunchKernel<float16>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat32) { } else if (dtype_ == kNumberTypeFloat32) {
LaunchKernel<float>(inputs, outputs); LaunchKernel<float>(inputs, outputs);
} }
return true; return true;
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,46 +1,46 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <string> #include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
class TopKCPUKernel : public CPUKernel { class TopKCPUKernel : public CPUKernel {
public: public:
TopKCPUKernel() = default; TopKCPUKernel() = default;
~TopKCPUKernel() override = default; ~TopKCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override; void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
private: private:
template <typename T> template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
size_t outer_size_{1}; size_t outer_size_{1};
size_t inner_size_{1}; size_t inner_size_{1};
bool sorted_{false}; bool sorted_{false};
TypeId dtype_{kTypeUnknown}; TypeId dtype_{kTypeUnknown};
}; };
MS_REG_CPU_KERNEL(TopK, KernelAttr(), TopKCPUKernel) MS_REG_CPU_KERNEL(TopK, KernelAttr(), TopKCPUKernel)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TOPK_CPU_KERNEL_H_

View File

@ -1,159 +1,159 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "backend/kernel_compiler/cpu/transpose_cpu_kernel.h" #include "backend/kernel_compiler/cpu/transpose_cpu_kernel.h"
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "runtime/device/cpu/cpu_device_address.h" #include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h" #include "common/thread_pool.h"
#include "nnacl/fp32/transpose_fp32.h" #include "nnacl/fp32/transpose_fp32.h"
#include "nnacl/int8/transpose_int8.h" #include "nnacl/int8/transpose_int8.h"
#include "nnacl/errorcode.h" #include "nnacl/errorcode.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) { void TransposeCPUFwdKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
auto tmp = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "perm"); auto tmp = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "perm");
axes_ = {tmp.begin(), tmp.end()}; axes_ = {tmp.begin(), tmp.end()};
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (axes_.size() > MAX_TRANSPOSE_DIM_SIZE) { if (axes_.size() > MAX_TRANSPOSE_DIM_SIZE) {
MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_TRANSPOSE_DIM_SIZE << "D, but got " MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_TRANSPOSE_DIM_SIZE << "D, but got "
<< axes_.size() << "D."; << axes_.size() << "D.";
} }
for (size_t i = 0; i < axes_.size(); ++i) { for (size_t i = 0; i < axes_.size(); ++i) {
transpose_param_.perm_[i] = SizeToInt(axes_[i]); transpose_param_.perm_[i] = SizeToInt(axes_[i]);
} }
int num_axes = SizeToInt(input_shape_.size()); int num_axes = SizeToInt(input_shape_.size());
transpose_param_.perm_size_ = axes_.size(); transpose_param_.perm_size_ = axes_.size();
transpose_param_.num_axes_ = num_axes; transpose_param_.num_axes_ = num_axes;
transpose_param_.strides_[num_axes - 1] = 1; transpose_param_.strides_[num_axes - 1] = 1;
transpose_param_.out_strides_[num_axes - 1] = 1; transpose_param_.out_strides_[num_axes - 1] = 1;
for (int i = num_axes - 2; i >= 0; i--) { for (int i = num_axes - 2; i >= 0; i--) {
transpose_param_.strides_[i] = input_shape_[i + 1] * transpose_param_.strides_[i + 1]; transpose_param_.strides_[i] = input_shape_[i + 1] * transpose_param_.strides_[i + 1];
transpose_param_.out_strides_[i] = output_shape_[i + 1] * transpose_param_.out_strides_[i + 1]; transpose_param_.out_strides_[i] = output_shape_[i + 1] * transpose_param_.out_strides_[i + 1];
} }
launch_map_[kNumberTypeInt8] = &TransposeCPUFwdKernel::LaunchKernel<int8_t>; launch_map_[kNumberTypeInt8] = &TransposeCPUFwdKernel::LaunchKernel<int8_t>;
launch_map_[kNumberTypeInt16] = &TransposeCPUFwdKernel::LaunchKernel<int16_t>; launch_map_[kNumberTypeInt16] = &TransposeCPUFwdKernel::LaunchKernel<int16_t>;
launch_map_[kNumberTypeInt32] = &TransposeCPUFwdKernel::LaunchKernel<int>; launch_map_[kNumberTypeInt32] = &TransposeCPUFwdKernel::LaunchKernel<int>;
launch_map_[kNumberTypeInt64] = &TransposeCPUFwdKernel::LaunchKernel<int64_t>; launch_map_[kNumberTypeInt64] = &TransposeCPUFwdKernel::LaunchKernel<int64_t>;
launch_map_[kNumberTypeUInt8] = &TransposeCPUFwdKernel::LaunchKernel<uint8_t>; launch_map_[kNumberTypeUInt8] = &TransposeCPUFwdKernel::LaunchKernel<uint8_t>;
launch_map_[kNumberTypeUInt16] = &TransposeCPUFwdKernel::LaunchKernel<uint16_t>; launch_map_[kNumberTypeUInt16] = &TransposeCPUFwdKernel::LaunchKernel<uint16_t>;
launch_map_[kNumberTypeUInt32] = &TransposeCPUFwdKernel::LaunchKernel<uint32_t>; launch_map_[kNumberTypeUInt32] = &TransposeCPUFwdKernel::LaunchKernel<uint32_t>;
launch_map_[kNumberTypeUInt64] = &TransposeCPUFwdKernel::LaunchKernel<uint64_t>; launch_map_[kNumberTypeUInt64] = &TransposeCPUFwdKernel::LaunchKernel<uint64_t>;
launch_map_[kNumberTypeFloat32] = &TransposeCPUFwdKernel::LaunchKernel<float>; launch_map_[kNumberTypeFloat32] = &TransposeCPUFwdKernel::LaunchKernel<float>;
launch_map_[kNumberTypeBool] = &TransposeCPUFwdKernel::LaunchKernel<bool>; launch_map_[kNumberTypeBool] = &TransposeCPUFwdKernel::LaunchKernel<bool>;
auto iter = launch_map_.find(dtype_); auto iter = launch_map_.find(dtype_);
if (iter != launch_map_.end()) { if (iter != launch_map_.end()) {
launch_func_ = iter->second; launch_func_ = iter->second;
} else { } else {
MS_LOG(EXCEPTION) << "Input data type: " << dtype_ << "is not supported for Transpose kernel on CPU."; MS_LOG(EXCEPTION) << "Input data type: " << dtype_ << "is not supported for Transpose kernel on CPU.";
} }
} }
bool TransposeCPUFwdKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool TransposeCPUFwdKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
launch_func_(this, inputs, outputs); launch_func_(this, inputs, outputs);
return true; return true;
} }
template <typename T> template <typename T>
void TransposeCPUFwdKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, void TransposeCPUFwdKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) { const std::vector<AddressPtr> &outputs) {
const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr); const auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr); auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
transpose_param_.data_num_ = inputs[0]->size / sizeof(T); transpose_param_.data_num_ = inputs[0]->size / sizeof(T);
int output_shape[SizeToInt(output_shape_.size())]; int output_shape[SizeToInt(output_shape_.size())];
for (size_t i = 0; i < output_shape_.size(); ++i) { for (size_t i = 0; i < output_shape_.size(); ++i) {
output_shape[i] = SizeToInt(output_shape_[i]); output_shape[i] = SizeToInt(output_shape_[i]);
} }
size_t data_count = (inputs[0]->size) / sizeof(T); size_t data_count = (inputs[0]->size) / sizeof(T);
if (axes_.size() <= DIMENSION_6D && data_count < MAX_TRANSPOSE_SERIAL_SIZE) { if (axes_.size() <= DIMENSION_6D && data_count < MAX_TRANSPOSE_SERIAL_SIZE) {
int res = NNACL_ERR; int res = NNACL_ERR;
if constexpr (std::is_same_v<T, int8_t>) { if constexpr (std::is_same_v<T, int8_t>) {
res = DoTransposeInt8(input_addr, output_addr, output_shape, &transpose_param_); res = DoTransposeInt8(input_addr, output_addr, output_shape, &transpose_param_);
} else if constexpr (std::is_same_v<T, int16_t>) { } else if constexpr (std::is_same_v<T, int16_t>) {
res = DoTransposeInt16(input_addr, output_addr, output_shape, &transpose_param_); res = DoTransposeInt16(input_addr, output_addr, output_shape, &transpose_param_);
} else if constexpr (std::is_same_v<T, int32_t>) { } else if constexpr (std::is_same_v<T, int32_t>) {
res = DoTransposeInt32(input_addr, output_addr, output_shape, &transpose_param_); res = DoTransposeInt32(input_addr, output_addr, output_shape, &transpose_param_);
} else if constexpr (std::is_same_v<T, int64_t>) { } else if constexpr (std::is_same_v<T, int64_t>) {
res = DoTransposeInt64(input_addr, output_addr, output_shape, &transpose_param_); res = DoTransposeInt64(input_addr, output_addr, output_shape, &transpose_param_);
} else if constexpr (std::is_same_v<T, uint8_t>) { } else if constexpr (std::is_same_v<T, uint8_t>) {
res = DoTransposeUInt8(input_addr, output_addr, output_shape, &transpose_param_); res = DoTransposeUInt8(input_addr, output_addr, output_shape, &transpose_param_);
} else if constexpr (std::is_same_v<T, uint16_t>) { } else if constexpr (std::is_same_v<T, uint16_t>) {
res = DoTransposeUInt16(input_addr, output_addr, output_shape, &transpose_param_); res = DoTransposeUInt16(input_addr, output_addr, output_shape, &transpose_param_);
} else if constexpr (std::is_same_v<T, uint32_t>) { } else if constexpr (std::is_same_v<T, uint32_t>) {
res = DoTransposeUInt32(input_addr, output_addr, output_shape, &transpose_param_); res = DoTransposeUInt32(input_addr, output_addr, output_shape, &transpose_param_);
} else if constexpr (std::is_same_v<T, uint64_t>) { } else if constexpr (std::is_same_v<T, uint64_t>) {
res = DoTransposeUInt64(input_addr, output_addr, output_shape, &transpose_param_); res = DoTransposeUInt64(input_addr, output_addr, output_shape, &transpose_param_);
} else if constexpr (std::is_same_v<T, float>) { } else if constexpr (std::is_same_v<T, float>) {
res = DoTransposeFp32(input_addr, output_addr, output_shape, &transpose_param_); res = DoTransposeFp32(input_addr, output_addr, output_shape, &transpose_param_);
} else if constexpr (std::is_same_v<T, bool>) { } else if constexpr (std::is_same_v<T, bool>) {
res = DoTransposeBool(input_addr, output_addr, output_shape, &transpose_param_); res = DoTransposeBool(input_addr, output_addr, output_shape, &transpose_param_);
} }
if (res != NNACL_OK) { if (res != NNACL_OK) {
MS_LOG(ERROR) << "Transpose run failed"; MS_LOG(ERROR) << "Transpose run failed";
} }
} else { } else {
ParallelRun(input_addr, output_addr, output_shape, data_count); ParallelRun(input_addr, output_addr, output_shape, data_count);
} }
} }
template <typename T> template <typename T>
void TransposeCPUFwdKernel::ParallelRun(const T *input_addr, T *output_addr, const int *output_shape, size_t count) { void TransposeCPUFwdKernel::ParallelRun(const T *input_addr, T *output_addr, const int *output_shape, size_t count) {
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
const float block_size = 128.0; const float block_size = 128.0;
size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num; size_t thread_num = count < block_size * max_thread_num ? std::ceil(count / block_size) : max_thread_num;
std::vector<common::Task> tasks; std::vector<common::Task> tasks;
std::function<void(const T *, T *, const int *, TransposeParameter *, int, int)> TransposeDims; std::function<void(const T *, T *, const int *, TransposeParameter *, int, int)> TransposeDims;
if constexpr (std::is_same_v<T, int8_t>) { if constexpr (std::is_same_v<T, int8_t>) {
TransposeDims = &TransposeDimsInt8; TransposeDims = &TransposeDimsInt8;
} else if constexpr (std::is_same_v<T, int16_t>) { } else if constexpr (std::is_same_v<T, int16_t>) {
TransposeDims = &TransposeDimsInt16; TransposeDims = &TransposeDimsInt16;
} else if constexpr (std::is_same_v<T, int32_t>) { } else if constexpr (std::is_same_v<T, int32_t>) {
TransposeDims = &TransposeDimsInt32; TransposeDims = &TransposeDimsInt32;
} else if constexpr (std::is_same_v<T, int64_t>) { } else if constexpr (std::is_same_v<T, int64_t>) {
TransposeDims = &TransposeDimsInt64; TransposeDims = &TransposeDimsInt64;
} else if constexpr (std::is_same_v<T, uint8_t>) { } else if constexpr (std::is_same_v<T, uint8_t>) {
TransposeDims = &TransposeDimsUInt8; TransposeDims = &TransposeDimsUInt8;
} else if constexpr (std::is_same_v<T, uint16_t>) { } else if constexpr (std::is_same_v<T, uint16_t>) {
TransposeDims = &TransposeDimsUInt16; TransposeDims = &TransposeDimsUInt16;
} else if constexpr (std::is_same_v<T, uint32_t>) { } else if constexpr (std::is_same_v<T, uint32_t>) {
TransposeDims = &TransposeDimsUInt32; TransposeDims = &TransposeDimsUInt32;
} else if constexpr (std::is_same_v<T, uint64_t>) { } else if constexpr (std::is_same_v<T, uint64_t>) {
TransposeDims = &TransposeDimsUInt64; TransposeDims = &TransposeDimsUInt64;
} else if constexpr (std::is_same_v<T, float>) { } else if constexpr (std::is_same_v<T, float>) {
TransposeDims = &TransposeDimsFp32; TransposeDims = &TransposeDimsFp32;
} else if constexpr (std::is_same_v<T, bool>) { } else if constexpr (std::is_same_v<T, bool>) {
TransposeDims = &TransposeDimsBool; TransposeDims = &TransposeDimsBool;
} }
for (int task_id = 0; task_id < SizeToInt(thread_num); ++task_id) { for (int task_id = 0; task_id < SizeToInt(thread_num); ++task_id) {
auto task = [&, task_id, thread_num]() { auto task = [&, task_id, thread_num]() {
TransposeDims(input_addr, output_addr, output_shape, &transpose_param_, task_id, SizeToInt(thread_num)); TransposeDims(input_addr, output_addr, output_shape, &transpose_param_, task_id, SizeToInt(thread_num));
return common::SUCCESS; return common::SUCCESS;
}; };
tasks.emplace_back(task); tasks.emplace_back(task);
} }
common::ThreadPool::GetInstance().SyncRun(tasks); common::ThreadPool::GetInstance().SyncRun(tasks);
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -1,58 +1,58 @@
/** /**
* Copyright 2021 Huawei Technologies Co., Ltd * Copyright 2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <memory> #include <memory>
#include <string> #include <string>
#include "backend/kernel_compiler/cpu/cpu_kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "nnacl/base/transpose_base.h" #include "nnacl/base/transpose_base.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
class TransposeCPUFwdKernel : public CPUKernel { class TransposeCPUFwdKernel : public CPUKernel {
public: public:
TransposeCPUFwdKernel() = default; TransposeCPUFwdKernel() = default;
~TransposeCPUFwdKernel() override = default; ~TransposeCPUFwdKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override; void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override; const std::vector<AddressPtr> &outputs) override;
private: private:
template <typename T> template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T> template <typename T>
void ParallelRun(const T *input_addr, T *output_addr, const int *output_shape, size_t count); void ParallelRun(const T *input_addr, T *output_addr, const int *output_shape, size_t count);
TransposeParameter transpose_param_; TransposeParameter transpose_param_;
std::vector<size_t> input_shape_; std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_; std::vector<size_t> output_shape_;
std::vector<size_t> axes_; std::vector<size_t> axes_;
TypeId dtype_{kTypeUnknown}; TypeId dtype_{kTypeUnknown};
using TypeKernel = using TypeKernel =
std::function<void(TransposeCPUFwdKernel *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>; std::function<void(TransposeCPUFwdKernel *, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &)>;
std::unordered_map<TypeId, TypeKernel> launch_map_; std::unordered_map<TypeId, TypeKernel> launch_map_;
TypeKernel launch_func_; TypeKernel launch_func_;
}; };
MS_REG_CPU_KERNEL(Transpose, KernelAttr(), TransposeCPUFwdKernel); MS_REG_CPU_KERNEL(Transpose, KernelAttr(), TransposeCPUFwdKernel);
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TRANSPOSE_CPU_KERNEL_H_