add AddN,AssignAdd,Conv2d,BN,Relu6,TensorAdd CPU operator

This commit is contained in:
zhaoting 2020-08-20 11:52:26 +08:00
parent c7b7af6c3a
commit cbca6be8bf
20 changed files with 792 additions and 129 deletions

View File

@ -1,65 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/addn_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node);
input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
}
bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
size_t offset = 0;
for (size_t i = 0; i < output_shape_[0]; ++i) {
for (size_t j = 0; j < output_shape_[1]; ++j) {
for (size_t k = 0; k < output_shape_[2]; ++k) {
for (size_t m = 0; m < output_shape_[3]; ++m) {
float sum = 0;
for (size_t index = 0; index < input_num_; ++index) {
auto input_addr = reinterpret_cast<float *>(inputs[index]->addr);
sum += input_addr[offset];
}
output_addr[offset++] = sum;
}
}
}
}
return true;
}
void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() > 4) {
MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but AddNCPUKernel olny support 4d or lower.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AddNCPUKernel needs 1 output.";
}
}
} // namespace kernel
} // namespace mindspore

View File

@ -50,6 +50,7 @@ const char BEGIN[] = "begin";
const char END[] = "end";
const char SIZE[] = "size";
const char USE_NESTEROV[] = "use_nesterov";
const char GROUP[] = "group";
class CPUKernel : public kernel::KernelMod {
public:

View File

@ -0,0 +1,76 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/mkldnn/addn_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
void AddNCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
CheckParam(kernel_node);
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
dnnl::memory::desc src0_mem_desc = GetDefaultMemDesc(src0_shape);
dnnl::memory::desc src1_mem_desc = GetDefaultMemDesc(src1_shape);
dnnl::memory::desc dst_mem_desc = GetDefaultMemDesc(dst_shape);
dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_mem_desc, src1_mem_desc, dst_mem_desc);
auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::binary>(prim_desc);
AddArgument(DNNL_ARG_SRC_0, src0_mem_desc);
AddArgument(DNNL_ARG_SRC_1, src1_mem_desc);
AddArgument(DNNL_ARG_DST, dst_mem_desc);
}
bool AddNCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
for (size_t index = 2; index < input_num_; ++index) {
SetArgumentHandle(DNNL_ARG_SRC_0, outputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[index]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
}
return true;
}
void AddNCPUKernel::CheckParam(const CNodePtr &kernel_node) {
auto src0_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
if (src0_shape != dst_shape) {
MS_LOG(EXCEPTION) << "AddN output shape must be equal to input shape.";
}
for (size_t index = 1; index < input_num_; ++index) {
auto src_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, index);
if (src0_shape != src_shape) {
MS_LOG(EXCEPTION) << "AddN input shapes must be equal.";
}
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but AddNCPUKernel needs 1 output.";
}
}
} // namespace kernel
} // namespace mindspore

View File

@ -1,48 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class AddNCPUKernel : public CPUKernel {
public:
AddNCPUKernel() : input_num_(0) {}
~AddNCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
void CheckParam(const CNodePtr &kernel_node);
size_t input_num_;
std::vector<size_t> output_shape_;
};
MS_REG_CPU_KERNEL(AddN,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AddNCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
namespace kernel {
class AddNCPUKernel : public MKLCPUKernel {
public:
AddNCPUKernel() : input_num_(0) {}
~AddNCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
void CheckParam(const CNodePtr &kernel_node);
size_t input_num_;
std::vector<size_t> output_shape_;
};
MS_REG_CPU_KERNEL(AddN,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AddNCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ADDN_CPU_KERNEL_H_

View File

@ -0,0 +1,60 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/mkldnn/assignadd_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
void AssignAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) {
MS_LOG(EXCEPTION) << "AssignAdd only support same dim input or tensor * scalar " << src0_shape.size() << " vs "
<< src1_shape.size();
}
if (src1_shape.size() < src0_shape.size()) {
for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
src1_shape.emplace_back(1);
}
}
dnnl::memory::desc src0_desc = GetDefaultMemDesc(src0_shape);
dnnl::memory::desc src1_desc = GetDefaultMemDesc(src1_shape);
dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_desc, src1_desc, src0_desc);
auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::binary>(prim_desc);
AddArgument(DNNL_ARG_SRC_0, src0_desc);
AddArgument(DNNL_ARG_SRC_1, src1_desc);
AddArgument(DNNL_ARG_DST, src0_desc);
}
bool AssignAddCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() < 2) {
MS_LOG(EXCEPTION) << "AssignAdd error input output size!";
}
SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
memcpy_s(inputs[0]->addr, inputs[0]->size, outputs[0]->addr, outputs[0]->size);
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ASSIGNADD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ASSIGNADD_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
namespace kernel {
class AssignAddCPUKernel : public MKLCPUKernel {
public:
AssignAddCPUKernel() = default;
~AssignAddCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(
AssignAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AssignAddCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_ASSIGNADD_CPU_KERNEL_H_

View File

@ -29,6 +29,15 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (src_shape.size() != 4 || weight_shape.size() != 4) {
MS_LOG(EXCEPTION) << "conv2d only support nchw input!";
}
std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
size_t group = IntToSize(AnfAlgo::GetNodeAttr<int>(kernel_node, GROUP));
if (group != 1) {
if (src_shape[1] % group != 0) {
MS_LOG(EXCEPTION) << "conv2d channels should be divided by group!";
}
weight_shape.insert(weight_shape.begin(), group);
weight_shape[1] = weight_shape[1] / group;
}
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape);
dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
@ -48,14 +57,11 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) {
}
int stride = stride_ori[2];
int dilation = dilation_ori[2];
dnnl::memory::dims strides{stride, stride};
dnnl::memory::dims dilates{dilation - 1, dilation - 1};
std::vector<int> int_padding_l;
std::vector<int> int_padding_r;
const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r);
if (int_padding_l.size() != 2 || int_padding_r.size() != 2) {
MS_LOG(EXCEPTION) << "get padding failed";
@ -68,7 +74,6 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) {
auto prim_desc = dnnl::convolution_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::convolution_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, src_desc);
AddArgument(DNNL_ARG_WEIGHTS, weights_desc);
AddArgument(DNNL_ARG_DST, dst_desc);

View File

@ -29,6 +29,15 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (src_shape.size() != 4 || weight_shape.size() != 4) {
MS_LOG(EXCEPTION) << ("conv2d grad filter only support nchw input!");
}
std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
size_t group = IntToSize(AnfAlgo::GetNodeAttr<int>(kernel_node, GROUP));
if (group != 1) {
if (src_shape[1] % group != 0) {
MS_LOG(EXCEPTION) << "conv2d channels should be divided by group!";
}
weight_shape.insert(weight_shape.begin(), group);
weight_shape[1] = weight_shape[1] / group;
}
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape);
dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
@ -51,7 +60,6 @@ void Conv2dGradFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) {
const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
std::vector<int> int_padding_l;
std::vector<int> int_padding_r;
std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r);
if (int_padding_l.size() != 2 || int_padding_r.size() != 2) {
MS_LOG(EXCEPTION) << "get padding failed";

View File

@ -29,6 +29,15 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (src_shape.size() != 4 || weight_shape.size() != 4) {
MS_LOG(EXCEPTION) << "conv2d grad filter only support nchw input!";
}
std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
size_t group = IntToSize(AnfAlgo::GetNodeAttr<int>(kernel_node, GROUP));
if (group != 1) {
if (src_shape[1] % group != 0) {
MS_LOG(EXCEPTION) << "conv2d channels should be divided by group!";
}
weight_shape.insert(weight_shape.begin(), group);
weight_shape[1] = weight_shape[1] / group;
}
dnnl::memory::desc src_desc = GetDefaultMemDesc(src_shape);
dnnl::memory::desc weights_desc = GetDefaultMemDesc(weight_shape);
dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
@ -51,7 +60,6 @@ void Conv2dGradInputCPUKernel::InitKernel(const CNodePtr &kernel_node) {
std::vector<int> int_padding_l;
std::vector<int> int_padding_r;
const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r);
if (int_padding_l.size() != 2 || int_padding_r.size() != 2) {
MS_LOG(EXCEPTION) << "conv2d grad get padding failed";

View File

@ -0,0 +1,97 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "backend/kernel_compiler/cpu/mkldnn/fused_batch_norm_cpu_kernel.h"
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
void FusedBatchNormCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
CPUKernel::InitInputOutputSize(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_node);
size_t type_size = sizeof(float);
std::vector<size_t> shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
size_t tensor_size = shape[1] * 2 * type_size; // [2, c] to store scale and bias
workspace_size_list_.emplace_back(tensor_size);
}
void FusedBatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
auto node_name = AnfAlgo::GetCNodeName(kernel_node);
if (node_name == "FusedBatchNorm") {
momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum");
is_train = true;
}
std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (x_shape.size() != 4) {
MS_LOG(EXCEPTION) << "fused batchnorm only support nchw input!";
}
batch_size = x_shape[0];
channel = x_shape[1];
hw_size = x_shape[2] * x_shape[3];
nhw_size = x_shape[0] * hw_size;
dnnl::memory::desc x_desc = GetDefaultMemDesc(x_shape);
dnnl::memory::desc scale_bias_desc = GetDefaultMemDesc({2, channel});
auto epsilon = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
auto prop_kind = dnnl::prop_kind::forward_inference;
if (is_train) {
prop_kind = dnnl::prop_kind::forward_training;
}
dnnl::batch_normalization_forward::desc desc =
dnnl::batch_normalization_forward::desc(prop_kind, x_desc, epsilon, dnnl::normalization_flags::use_scale_shift);
auto prim_desc = dnnl::batch_normalization_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::batch_normalization_forward>(prim_desc);
AddArgument(DNNL_ARG_SRC, x_desc);
AddArgument(DNNL_ARG_MEAN, prim_desc.mean_desc());
AddArgument(DNNL_ARG_VARIANCE, prim_desc.variance_desc());
AddArgument(DNNL_ARG_SCALE_SHIFT, scale_bias_desc);
AddArgument(DNNL_ARG_WORKSPACE, prim_desc.workspace_desc());
AddArgument(DNNL_ARG_DST, x_desc);
}
bool FusedBatchNormCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &workspace,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() < 5 || outputs.empty()) {
MS_LOG(EXCEPTION) << "error input output size!";
}
auto wksp = reinterpret_cast<float *>(workspace[0]->addr);
memcpy_s(wksp, workspace[0]->size, inputs[1]->addr, inputs[1]->size);
memcpy_s(wksp + (inputs[1]->size / sizeof(float)), inputs[2]->size, inputs[2]->addr, inputs[2]->size);
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_MEAN, outputs[3]->addr);
SetArgumentHandle(DNNL_ARG_VARIANCE, outputs[4]->addr);
SetArgumentHandle(DNNL_ARG_SCALE_SHIFT, workspace[0]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
if (is_train) {
auto moving_mean = reinterpret_cast<float *>(inputs[3]->addr);
auto moving_variance = reinterpret_cast<float *>(inputs[4]->addr);
auto mean = reinterpret_cast<float *>(outputs[3]->addr);
auto variance = reinterpret_cast<float *>(outputs[4]->addr);
for (size_t i = 0; i < inputs[3]->size / sizeof(float); ++i) {
moving_mean[i] = moving_mean[i] * (1 - momentum) + mean[i] * momentum;
moving_variance[i] = moving_variance[i] * (1 - momentum) + variance[i] * momentum;
}
}
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,77 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_FUSED_BATCH_NORM_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
namespace kernel {
class FusedBatchNormCPUKernel : public MKLCPUKernel {
public:
FusedBatchNormCPUKernel() = default;
~FusedBatchNormCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
protected:
void InitInputOutputSize(const CNodePtr &kernel_node) override;
private:
bool is_train{false};
float momentum{0.9};
size_t batch_size{0};
size_t channel{0};
size_t hw_size{0};
size_t nhw_size{0};
};
MS_REG_CPU_KERNEL(FusedBatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormCPUKernel)
MS_REG_CPU_KERNEL(BatchNorm,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
FusedBatchNormCPUKernel)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_CPU_KERNEL_H_

View File

@ -16,6 +16,7 @@
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
#include <vector>
#include <string>
#include <algorithm>
#include "utils/ms_utils.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
@ -32,7 +33,7 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
weight_height.emplace_back(src_shape[src_shape.size() - 2]);
weight_height.emplace_back(src_shape[src_shape.size() - 1]);
MS_LOG(INFO) << "pad mode " << pad_mode;
MS_LOG(INFO) << "pad mode: " << pad_mode;
if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) {
for (size_t i = 0; i < weight_height.size(); ++i) {
auto wh = weight_height[i];
@ -51,21 +52,20 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
padding_r->emplace_back(0);
padding_r->emplace_back(0);
} else {
std::vector<int> pad = AnfAlgo::GetNodeAttr<std::vector<int>>(kernel_node, PAD);
if (pad.size() != 4) {
MS_LOG(EXCEPTION) << "wrong pad size in max pooling " << pad.size();
}
padding_l->emplace_back(pad[0]);
padding_l->emplace_back(pad[1]);
padding_r->emplace_back(pad[2]);
padding_r->emplace_back(pad[3]);
int pad = AnfAlgo::GetNodeAttr<int>(kernel_node, PAD);
padding_l->emplace_back(pad);
padding_l->emplace_back(pad);
padding_r->emplace_back(pad);
padding_r->emplace_back(pad);
}
}
dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const {
dnnl::memory::format_tag mem_tag;
auto dim_size = dims.size();
if (dim_size == 4) {
if (dim_size == 5) {
mem_tag = dnnl::memory::format_tag::abcde;
} else if (dim_size == 4) {
mem_tag = dnnl::memory::format_tag::abcd;
} else if (dim_size == 3) {
mem_tag = dnnl::memory::format_tag::abc;

View File

@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include "backend/kernel_compiler/cpu/mkldnn/relu_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
@ -30,6 +31,12 @@ void ReluCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::eltwise_forward::desc desc =
dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_relu, src_desc, 0.0);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == "ReLU6") {
desc =
dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::eltwise_clip, src_desc, 0.0, 6.0);
}
auto prim_desc = dnnl::eltwise_forward::primitive_desc(desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::eltwise_forward>(prim_desc);

View File

@ -34,6 +34,8 @@ class ReluCPUKernel : public MKLCPUKernel {
};
MS_REG_CPU_KERNEL(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ReluCPUKernel);
MS_REG_CPU_KERNEL(ReLU6, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ReluCPUKernel);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/mkldnn/tensoradd_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
namespace mindspore {
namespace kernel {
void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
if (src0_shape.size() != src1_shape.size() && src1_shape.size() > 1) {
MS_LOG(EXCEPTION) << "TensorAdd only support same dim input or tensor * scalar " << src0_shape.size() << " vs "
<< src1_shape.size();
}
if (src1_shape.size() < src0_shape.size()) {
for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
src1_shape.emplace_back(1);
}
}
dnnl::memory::desc src0_desc = GetDefaultMemDesc(src0_shape);
dnnl::memory::desc src1_desc = GetDefaultMemDesc(src1_shape);
dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
dnnl::binary::desc desc = dnnl::binary::desc(dnnl::algorithm::binary_add, src0_desc, src1_desc, dst_desc);
auto prim_desc = dnnl::binary::primitive_desc(desc, MKLKernelEngine::Get().engine());
primitive_ = std::make_shared<dnnl::binary>(prim_desc);
AddArgument(DNNL_ARG_SRC_0, src0_desc);
AddArgument(DNNL_ARG_SRC_1, src1_desc);
AddArgument(DNNL_ARG_DST, dst_desc);
}
bool TensorAddCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() < 2 || outputs.empty()) {
MS_LOG(EXCEPTION) << "TensorAdd error input output size!";
}
SetArgumentHandle(DNNL_ARG_SRC_0, inputs[0]->addr);
SetArgumentHandle(DNNL_ARG_SRC_1, inputs[1]->addr);
SetArgumentHandle(DNNL_ARG_DST, outputs[0]->addr);
ExecutePrimitive();
return true;
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TENSORADD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TENSORADD_CPU_KERNEL_H_
#include <vector>
#include <memory>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
namespace mindspore {
namespace kernel {
class TensorAddCPUKernel : public MKLCPUKernel {
public:
TensorAddCPUKernel() = default;
~TensorAddCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
};
MS_REG_CPU_KERNEL(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TensorAddCPUKernel);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TENSORADD_CPU_KERNEL_H_

View File

@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
class AssignAdd(nn.Cell):
def __init__(self, value):
super(AssignAdd, self).__init__()
self.var = Parameter(value, name="var")
self.add = P.AssignAdd()
def construct(self, y):
res = self.add(self.var, y)
return res
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_assign_add():
expect1 = np.array([[[[0, 2, 4.],
[6, 8, 10.],
[12, 14, 16.]],
[[18, 20, 22.],
[24, 26, 28.],
[30, 32, 34.]],
[[36, 38, 40.],
[42, 44, 46.],
[48, 50, 52.]]]])
expect2 = np.array([[[[0, 3, 6],
[9, 12, 15],
[18, 21, 24]],
[[27, 30, 33],
[36, 39, 42],
[45, 48, 51]],
[[54, 57, 60],
[63, 66, 69],
[72, 75, 78]]]])
x2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
add = AssignAdd(x2)
output1 = add(y2)
assert (output1.asnumpy() == expect1).all()
add = AssignAdd(output1)
output2 = add(y2)
assert (output2.asnumpy() == expect2).all()

View File

@ -0,0 +1,82 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore.common.tensor import Tensor
from mindspore.nn import BatchNorm2d
from mindspore.nn import Cell
from mindspore.ops import composite as C
class Batchnorm_Net(Cell):
def __init__(self, c, weight, bias, moving_mean, moving_var_init):
super(Batchnorm_Net, self).__init__()
self.bn = BatchNorm2d(c, eps=0.00001, momentum=0.1, beta_init=bias, gamma_init=weight,
moving_mean_init=moving_mean, moving_var_init=moving_var_init)
def construct(self, input_data):
x = self.bn(input_data)
return x
class Grad(Cell):
def __init__(self, network):
super(Grad, self).__init__()
self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
self.network = network
def construct(self, input_data, sens):
gout = self.grad(self.network)(input_data, sens)
return gout
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_train_forward():
x = np.array([[
[[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]],
[[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32)
expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294],
[-0.1471, 0.7706, 1.6882, 2.6059],
[0.3118, 1.6882, 2.1471, 2.1471],
[0.7706, 0.3118, 2.6059, -0.1471]],
[[0.9119, 1.8518, 1.3819, -0.0281],
[-0.0281, 0.9119, 1.3819, 1.8518],
[2.7918, 0.4419, -0.4981, 0.9119],
[1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32)
weight = np.ones(2).astype(np.float32)
bias = np.ones(2).astype(np.float32)
moving_mean = np.ones(2).astype(np.float32)
moving_var_init = np.ones(2).astype(np.float32)
error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
bn_net = Batchnorm_Net(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var_init))
bn_net.set_train()
output = bn_net(Tensor(x))
diff = output.asnumpy() - expect_output
assert np.all(diff < error)
assert np.all(-diff < error)
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
bn_net = Batchnorm_Net(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var_init))
bn_net.set_train(False)
output = bn_net(Tensor(x))

View File

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class NetReLU6(nn.Cell):
def __init__(self):
super(NetReLU6, self).__init__()
self.relu6 = P.ReLU6()
def construct(self, x):
return self.relu6(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_relu6():
x = Tensor(np.array([[[[-1, 1, 10],
[5.9, 6.1, 6],
[10, 1, -1]]]]).astype(np.float32))
expect = np.array([[[[0, 1, 6,],
[5.9, 6, 6,],
[6, 1, 0.]]]]).astype(np.float32)
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
relu6 = NetReLU6()
output = relu6(x)
assert (output.asnumpy() == expect).all()

View File

@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.context as context
class TensorAdd(nn.Cell):
def __init__(self):
super(TensorAdd, self).__init__()
self.add = P.TensorAdd()
def construct(self, x, y):
res = self.add(x, y)
return res
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_tensor_add():
x = np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32)
y = np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32)
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
add = TensorAdd()
output = add(Tensor(x), Tensor(y))
assert (output.asnumpy() == x + y).all()