Dynamic shape for bias_add

liangzhibo 2022-09-08 19:02:02 +08:00
parent 3bea901fff
commit 2ce7a2a845
9 changed files with 244 additions and 205 deletions

View File

@@ -1,19 +1,18 @@
mindspore.ops.bias_add
======================
mindspore.ops.func_bias_add
===========================
.. py:function:: mindspore.ops.bias_add(input_x, bias)
Returns the sum of the input Tensor and the bias Tensor. Before the addition, the bias Tensor is broadcast to match the shape of the input Tensor.
Parameters
- **input_x** (Tensor) - The input Tensor. The shape can have 2 to 5 dimensions. The data type should be float16 or float32.
- **bias** (Tensor) - The bias Tensor, with shape :math:`(C)`. C must be the same as the channel dimension C of `input_x`. The data type should be float16 or float32.
Inputs
- **input_x** (Tensor) - The input Tensor. The shape can have 2 to 5 dimensions.
- **bias** (Tensor) - The bias Tensor, with shape :math:`(C)`. C must be the same as the channel dimension C of `input_x`.
Returns
Outputs
Tensor, with the same shape and data type as `input_x`.
Raises:
- **TypeError** - If `input_x` or `bias` is not a Tensor.
- **TypeError** - If the data type of `input_x` or `bias` is neither float16 nor float32.
- **TypeError** - If the data types of `input_x` and `bias` are inconsistent.
- **TypeError** - If `input_x` or `bias` is not a Tensor.
- **TypeError** - If the data types of `input_x` and `bias` are inconsistent.
- **TypeError** - If the dimension of `input_x` is not in the range [2, 5].
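
For reference, a minimal usage sketch of the functional interface documented above (assuming a standard MindSpore install; shapes and values are illustrative):

import numpy as np
import mindspore as ms
from mindspore import ops, Tensor

# input_x has 3 dimensions (N, C, W); bias has shape (C,).
input_x = Tensor(np.zeros((2, 3, 4)), ms.float32)
bias = Tensor(np.array([0.1, 0.2, 0.3]), ms.float32)
# bias is broadcast over every axis except the channel axis before the add.
output = ops.bias_add(input_x, bias)
print(output.shape)  # (2, 3, 4)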

View File

@@ -15,8 +15,9 @@
*/
#include "plugin/device/cpu/kernel/bias_add_cpu_kernel.h"
#include "plugin/device/cpu/kernel/nnacl/fp32/add_fp32.h"
#include "plugin/device/cpu/kernel/nnacl/errorcode.h"
#include "ops/bias_add.h"
#include <map>
#include <complex>
namespace mindspore {
namespace kernel {
@@ -27,14 +28,25 @@ constexpr size_t kBiasAddInputsNum = 2;
constexpr size_t kBiasAddOutputsNum = 1;
} // namespace
void BiasAddCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
bias_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
if (AnfAlgo::IsShapesDynamic({input_shape_, bias_shape_})) {
return;
bool BiasAddCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
MS_EXCEPTION_IF_NULL(base_operator);
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
kernel_name_ = base_operator->name();
return true;
}
int BiasAddCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
int ret = KernelMod::Resize(base_operator, inputs, outputs);
if (ret != KRET_OK) {
return ret;
}
input_shape_ = Convert2SizeTClipNeg(inputs[kIndex0]->GetShapeVector());
bias_shape_ = Convert2SizeTClipNeg(inputs[kIndex1]->GetShapeVector());
data_shape_ = input_shape_.size();
if (input_shape_.size() < kBiasAddMinDim || input_shape_.size() > kBiasAddMaxDim) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
@@ -51,32 +63,40 @@ void BiasAddCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
"the second dimension length of 'input_x', the first dimension length of 'bias': "
<< bias_shape_[0] << ", and the second dimension length of 'input_x': " << input_shape_[1];
}
return ret;
}
bool BiasAddCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
bool BiasAddCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kBiasAddInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kBiasAddOutputsNum, kernel_name_);
const auto *src_addr = reinterpret_cast<float *>(inputs[0]->addr);
const auto *bias_addr = reinterpret_cast<float *>(inputs[1]->addr);
auto *output_addr = reinterpret_cast<float *>(outputs[0]->addr);
kernel_func_(this, inputs, workspace, outputs);
return true;
}
if (input_shape_.size() > 2) {
int64_t hw_size = 1;
template <typename T>
bool BiasAddCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs) {
const auto *src_addr = reinterpret_cast<T *>(inputs[kIndex0]->addr);
const auto *bias_addr = reinterpret_cast<T *>(inputs[kIndex1]->addr);
auto *output_addr = reinterpret_cast<T *>(outputs[kIndex0]->addr);
if (input_shape_.size() > kBiasAddMinDim) {
size_t hw_size = 1;
for (size_t i = 2; i < input_shape_.size(); ++i) {
hw_size *= input_shape_[i];
}
int64_t c_size = input_shape_[1];
for (int64_t n = 0; n < input_shape_[0]; ++n) {
for (int64_t c = 0; c < c_size; ++c) {
size_t c_size = input_shape_[kIndex1];
for (size_t n = 0; n < input_shape_[kIndex0]; ++n) {
for (size_t c = 0; c < c_size; ++c) {
size_t offset = LongToSize(n * c_size * hw_size + c * hw_size);
size_t hw = 0;
#ifdef ENABLE_AVX
constexpr size_t C8NUM = 8;
size_t hw8 = hw_size / C8NUM * C8NUM;
const float *in_ptr = src_addr + offset;
float *out_ptr = output_addr + offset;
const T *in_ptr = src_addr + offset;
T *out_ptr = output_addr + offset;
for (; hw < hw8; hw += C8NUM) {
__m256 src_r1 = _mm256_loadu_ps(in_ptr);
__m256 bias_r2 = _mm256_set1_ps(bias_addr[c]);
@@ -95,18 +115,44 @@ bool BiasAddCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const st
} else {
auto task = [&](size_t start, size_t end) {
for (size_t n = start; n < end; ++n) {
size_t n_offset = LongToSize(input_shape_[1] * n);
if (ElementAdd(src_addr + n_offset, bias_addr, output_addr + n_offset, input_shape_[1]) != NNACL_OK) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', ElementAdd failed.";
size_t n_offset = LongToSize(input_shape_[kIndex1] * n);
const T *inner_src = src_addr + n_offset;
T *inner_dst = output_addr + n_offset;
for (size_t index = 0; index < input_shape_[kIndex1]; ++index) {
inner_dst[index] = inner_src[index] + bias_addr[index];
}
}
};
ParallelLaunchAutoSearch(task, LongToSize(input_shape_[0]), this, &parallel_search_info_);
ParallelLaunchAutoSearch(task, LongToSize(input_shape_[kIndex0]), this, &parallel_search_info_);
}
return true;
}
template <typename T>
std::pair<KernelAttr, BiasAddCpuKernelMod::KernelRunFunc> BiasAddCpuKernelMod::MakeKernelFunc(TypeId type_id) const {
return std::make_pair(KernelAttr().AddInputAttr(type_id).AddInputAttr(type_id).AddOutputAttr(type_id),
&BiasAddCpuKernelMod::LaunchKernel<T>);
}
const std::vector<std::pair<KernelAttr, BiasAddCpuKernelMod::KernelRunFunc>> &BiasAddCpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, BiasAddCpuKernelMod::KernelRunFunc>> func_list = {
MakeKernelFunc<float16>(kNumberTypeFloat16),
MakeKernelFunc<float>(kNumberTypeFloat32),
MakeKernelFunc<double>(kNumberTypeFloat64),
MakeKernelFunc<int8_t>(kNumberTypeInt8),
MakeKernelFunc<int16_t>(kNumberTypeInt16),
MakeKernelFunc<int32_t>(kNumberTypeInt32),
MakeKernelFunc<int64_t>(kNumberTypeInt64),
MakeKernelFunc<uint8_t>(kNumberTypeUInt8),
MakeKernelFunc<uint16_t>(kNumberTypeUInt16),
MakeKernelFunc<uint32_t>(kNumberTypeUInt32),
MakeKernelFunc<uint64_t>(kNumberTypeUInt64),
MakeKernelFunc<std::complex<float>>(kNumberTypeComplex64),
MakeKernelFunc<std::complex<double>>(kNumberTypeComplex128),
};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, BiasAdd, BiasAddCpuKernelMod);
} // namespace kernel
} // namespace mindspore
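
The Init/Resize split above is what enables dynamic shapes: Init matches the typed kernel function once, while Resize re-reads the concrete input shapes whenever they change. A Python-level sketch of exercising that path (assuming the `set_inputs` dynamic-shape API of this MindSpore generation; names are illustrative):

import numpy as np
import mindspore as ms
from mindspore import nn, ops, Tensor

class BiasAddNet(nn.Cell):
    def construct(self, x, b):
        return ops.bias_add(x, b)

net = BiasAddNet()
# None marks the batch dimension as dynamic; the kernel's Resize() then
# recomputes input_shape_/bias_shape_ before each launch.
dyn_x = Tensor(shape=[None, 3, 4], dtype=ms.float32)
bias = Tensor(np.zeros(3), ms.float32)
net.set_inputs(dyn_x, bias)
print(net(Tensor(np.ones((2, 3, 4)), ms.float32), bias).shape)  # (2, 3, 4)
print(net(Tensor(np.ones((5, 3, 4)), ms.float32), bias).shape)  # (5, 3, 4)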

View File

@@ -14,31 +14,45 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_BIAS_ADD_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_BIAS_ADD_CPU_KERNEL_H_
#include <map>
#include <vector>
#include <memory>
#include <utility>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class BiasAddCpuKernelMod : public DeprecatedNativeCpuKernelMod {
class BiasAddCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<BiasAddCpuKernelMod> {
public:
BiasAddCpuKernelMod() = default;
~BiasAddCpuKernelMod() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
template <typename T>
std::pair<KernelAttr, BiasAddCpuKernelMod::KernelRunFunc> MakeKernelFunc(TypeId type_id) const;
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
size_t data_shape_{0};
std::vector<int64_t> input_shape_;
std::vector<int64_t> bias_shape_;
std::vector<size_t> input_shape_;
std::vector<size_t> bias_shape_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_BIAS_ADD_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_BIAS_ADD_CPU_KERNEL_H_

View File

@@ -15,16 +15,116 @@
*/
#include "plugin/device/gpu/kernel/nn/bias_add_gpu_kernel.h"
#include <mindspore/core/abstract/utils.h>
#include <map>
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
BiasAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BiasAddGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(
BiasAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BiasAddGpuKernelMod, float16)
bool BiasAddGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->GetPrim()->name();
constexpr size_t input_num = 2;
constexpr size_t output_num = 1;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), output_num, kernel_name_);
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel data type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(inputs[kIndex1]->GetDtype()));
return true;
}
int BiasAddGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (auto ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
auto x_shape = LongVecToSizeVec(inputs[kIndex0]->GetShapeVector());
auto num_dims = x_shape.size();
is_null_input_ = CHECK_SHAPE_NULL(x_shape, kernel_name_, "input_x");
constexpr size_t min_num_dims = 2;
if (num_dims < min_num_dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input_x cannot be less than 2, but got "
<< num_dims;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::BiasAdd>(base_operator);
auto format = kernel_ptr->get_format();
auto format_str = format_str_list[format + 1];
string::size_type pos = format_str.find("C");
if (pos == std::string::npos || pos >= num_dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'C' character must be in 'format', but got " << format_str;
}
// Expand to 4 dims for cudnnSetTensorNdDescriptorEx.
constexpr size_t four_4D = 4;
size_t cudnn_dims = std::max(num_dims, four_4D);
std::unique_ptr<int[]> x_dims = std::make_unique<int[]>(cudnn_dims);
std::unique_ptr<int[]> b_dims = std::make_unique<int[]>(cudnn_dims);
for (size_t i = 0; i < cudnn_dims; i++) {
x_dims[i] = (i < num_dims) ? LongToInt(x_shape[i]) : 1;
b_dims[i] = (i == pos) ? LongToInt(x_shape[i]) : 1;
}
auto input_device_format = inputs[kIndex0]->GetFormat();
auto cudnn_cal_format = (input_device_format == Format::NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensorNdDescriptorEx(x_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetTensorNdDescriptorEx(b_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnSetOpTensorDescriptor(op_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN),
"cudnnSetOpTensorDescriptor failed");
return KRET_OK;
}
template <typename T>
bool BiasAddGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(stream_ptr);
if (is_null_input_) {
return true;
}
T *x_addr = GetDeviceAddress<T>(inputs, 0);
T *b_addr = GetDeviceAddress<T>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
try {
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnOpTensor(cudnn_handle_, op_desc_, &alpha, x_desc_, x_addr, &alpha, b_desc_,
b_addr, &beta, x_desc_, output_addr),
"cudnnOpTensor failed");
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cudnnOpTensor";
}
return true;
}
std::vector<std::pair<KernelAttr, BiasAddGpuKernelMod::BiasAddLaunchFunc>> BiasAddGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&BiasAddGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&BiasAddGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&BiasAddGpuKernelMod::LaunchKernel<int8_t>}};
std::vector<KernelAttr> BiasAddGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(
func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, BiasAddGpuKernelMod::BiasAddLaunchFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BiasAdd, BiasAddGpuKernelMod);
} // namespace kernel
} // namespace mindspore
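
To make the Resize logic above easier to follow, here is a small pure-Python model (illustrative only, not MindSpore code) of how x_dims and b_dims are padded to at least four dimensions: the bias descriptor keeps size 1 on every axis except the channel axis, which is what lets cudnnOpTensor broadcast the bias over the input:

def make_cudnn_dims(x_shape, c_pos):
    # cudnnSetTensorNdDescriptorEx needs at least 4 dimensions.
    cudnn_dims = max(len(x_shape), 4)
    x_dims = [x_shape[i] if i < len(x_shape) else 1 for i in range(cudnn_dims)]
    b_dims = [x_shape[i] if i == c_pos else 1 for i in range(cudnn_dims)]
    return x_dims, b_dims

print(make_cudnn_dims([8, 3, 32, 32], 1))  # ([8, 3, 32, 32], [1, 3, 1, 1])
print(make_cudnn_dims([8, 3], 1))          # ([8, 3, 1, 1], [1, 3, 1, 1])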

View File

@@ -14,150 +14,64 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <string>
#include <map>
#include <algorithm>
#include <memory>
#include <vector>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "mindspore/core/ops/bias_add.h"
namespace mindspore {
namespace kernel {
template <typename T>
class BiasAddGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class BiasAddGpuKernelMod : public NativeGpuKernelMod {
public:
BiasAddGpuKernelMod() { ResetResource(); }
~BiasAddGpuKernelMod() override { DestroyResource(); }
BiasAddGpuKernelMod() {}
~BiasAddGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(stream_ptr);
if (is_null_input_) {
return true;
}
T *x_addr = GetDeviceAddress<T>(inputs, 0);
T *b_addr = GetDeviceAddress<T>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
try {
const float alpha = 1;
const float beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnOpTensor(cudnn_handle_, op_desc_, &alpha, x_desc_, x_addr, &alpha, b_desc_,
b_addr, &beta, x_desc_, output_addr),
"cudnnOpTensor failed");
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cudnnOpTensor";
}
return true;
return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
}
bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
auto x_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (IsDynamic(x_shape)) {
return true;
}
auto num_dims = x_shape.size();
is_null_input_ = CHECK_SHAPE_NULL(x_shape, kernel_name_, "input_x");
if (is_null_input_) {
InitSizeLists();
return true;
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
if (num_dims < 2) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input_x cannot be less than 2, but got "
<< num_dims;
}
std::string format = GetAttr<std::string>(kernel_node, "format");
string::size_type pos = format.find("C");
if (pos == std::string::npos || pos >= num_dims) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'C' character must be in 'format', but got " << format;
}
// Expand to 4 dims for cudnnSetTensorNdDescriptorEx.
constexpr size_t four_4D = 4;
size_t cudnn_dims = std::max(num_dims, four_4D);
std::unique_ptr<int[]> x_dims = std::make_unique<int[]>(cudnn_dims);
std::unique_ptr<int[]> b_dims = std::make_unique<int[]>(cudnn_dims);
for (size_t i = 0; i < cudnn_dims; i++) {
x_dims[i] = (i < num_dims) ? LongToInt(x_shape[i]) : 1;
b_dims[i] = (i == pos) ? LongToInt(x_shape[i]) : 1;
}
auto input_device_format = AnfAlgo::GetInputFormat(kernel_node, 0);
auto cudnn_cal_format = (input_device_format == kOpFormat_NHWC) ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW;
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensorNdDescriptorEx(x_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), x_dims.get()),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetTensorNdDescriptorEx(b_desc_, cudnn_cal_format, cudnn_data_type_, SizeToInt(cudnn_dims), b_dims.get()),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_,
cudnnSetOpTensorDescriptor(op_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN),
"cudnnSetOpTensorDescriptor failed");
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
cudnn_handle_ = nullptr;
cudnn_data_type_ = CUDNN_DATA_FLOAT;
x_desc_ = nullptr;
b_desc_ = nullptr;
op_desc_ = nullptr;
is_null_input_ = false;
kernel_name_ = "BiasAdd";
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
void DestroyResource() noexcept override {
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyOpTensorDescriptor(op_desc_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(b_desc_),
"cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_),
"cudnnDestroyOpTensorDescriptor failed");
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
protected:
void InitResource() override {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&x_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateTensorDescriptor(&b_desc_),
"cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateOpTensorDescriptor(&op_desc_),
"cudnnCreateOpTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&x_desc_), "cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateTensorDescriptor(&b_desc_), "cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(cudnnCreateOpTensorDescriptor(&op_desc_),
"cudnnCreateOpTensorDescriptor failed");
}
void InitSizeLists() override {
size_t x_size, b_size;
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &x_size),
"cudnnGetTensorSizeInBytes failed.");
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(b_desc_, &b_size),
"cudnnGetTensorSizeInBytes failed.");
input_size_list_.push_back(x_size);
input_size_list_.push_back(b_size);
output_size_list_.push_back(x_size);
}
std::vector<KernelAttr> GetOpSupport() override;
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr);
using BiasAddLaunchFunc =
std::function<bool(BiasAddGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
std::vector<std::string> format_str_list = {"DEFAULT", "NCHW", "NHWC", "NHWC4", "HWKC", "HWCK", "KCHW",
"CKHW", "KHWC", "CHWK", "HW", "HW4", "NC", "NC4",
"NC4HW4", "NCDHW", "NWC", "NCW", "NDHWC", "NC8HW8"};
private:
BiasAddLaunchFunc kernel_func_;
static std::vector<std::pair<KernelAttr, BiasAddLaunchFunc>> func_list_;
cudnnHandle_t cudnn_handle_;
cudnnDataType_t cudnn_data_type_;
cudnnTensorDescriptor_t x_desc_;
@@ -169,4 +83,4 @@ class BiasAddGpuKernelMod : public DeprecatedNativeGpuKernelMod {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_GPU_NN_BIAS_ADD_GPU_KERNEL_H_

View File

@@ -92,7 +92,7 @@ TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
std::map<std::string, TypePtr> types;
(void)types.emplace("input_x", input_args[0]->BuildType());
(void)types.emplace("bias", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name);
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types_with_complex, prim_name);
}
} // namespace
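
Switching the check to common_valid_types_with_complex means the infer pass now accepts complex inputs as well; a sketch of what that enables (assuming complex64 Tensor support on the target backend, e.g. the CPU kernel list above):

import numpy as np
import mindspore as ms
from mindspore import ops, Tensor

x = Tensor(np.ones((2, 3), dtype=np.complex64))
b = Tensor(np.array([1 + 1j, 2 + 2j, 3 + 3j], dtype=np.complex64))
# With the relaxed type check, complex bias_add type-checks and runs on
# backends that register a complex kernel.
print(ops.bias_add(x, b))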

View File

@@ -29,7 +29,6 @@ from .adam_weight_decay import _adam_weight_decay_cpu
from .arg_max import _arg_max_cpu
from .arg_min_with_value import _arg_min_with_value_cpu
from .arg_max_with_value import _arg_max_with_value_cpu
from .bias_add import _bias_add_cpu
from .bias_add_grad import _bias_add_grad_cpu
from .dropout import _dropout_cpu
from .dropout_grad import _dropout_grad_cpu

View File

@@ -1,30 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BiasAdd op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType
bias_add_op_info = CpuRegOp("BiasAdd") \
.input(0, "x", "required") \
.input(1, "bias", "required") \
.output(0, "y", "required") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(bias_add_op_info)
def _bias_add_cpu():
"""BiasAdd cpu register"""
return

View File

@@ -2949,16 +2949,13 @@ def bias_add(input_x, bias):
Args:
input_x (Tensor): The input tensor. The shape can be 2-5 dimensions.
The data type should be float16 or float32.
bias (Tensor): The bias tensor, with shape :math:`(C)`. C must be the same as channel dimension C of
`input_x`. The data type should be float16 or float32.
bias (Tensor): The bias tensor, with shape :math:`(C)`. C must be the same as channel dimension C of `input_x`.
Returns:
Tensor, with the same shape and data type as `input_x`.
Raises:
TypeError: If `input_x` or `bias` is not a Tensor.
TypeError: If dtype of `input_x` or `bias` is neither float16 nor float32.
TypeError: If the dtypes of `input_x` and `bias` are inconsistent.
TypeError: If dimension of `input_x` is not in the range [2, 5].
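
As a reference for the broadcast this docstring describes, an equivalent NumPy sketch (illustrative; `bias_add_ref` is not a MindSpore API):

import numpy as np

def bias_add_ref(x, b):
    # Reshape bias (C,) to (1, C, 1, ..., 1) so it broadcasts over all
    # axes except the channel axis, then add.
    shape = [1, -1] + [1] * (x.ndim - 2)
    return x + b.reshape(shape)

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
b = np.array([10.0, 20.0, 30.0], dtype=np.float32)
out = bias_add_ref(x, b)
print(out.shape)     # (2, 3, 4), same as input_x
print(out[0, 1, 0])  # 4.0 + 20.0 -> 24.0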