addn int64 support

This commit is contained in:
TFBunny 2021-04-27 16:55:14 -04:00
parent 37af5b2f7a
commit bc5434090c
3 changed files with 76 additions and 21 deletions

View File

@ -30,5 +30,8 @@ MS_REG_GPU_KERNEL_ONE(
// Register the AddN GPU kernel for int32 inputs/outputs. Integer types bypass
// cuDNN and are summed via the ElewiseArith path inside the kernel.
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AddNGpuFwdKernel, int)
// Register the AddN GPU kernel for int64 inputs/outputs. int64 is not a cuDNN
// data type, so the kernel sets is_int64_ and takes the cuDNN-free path.
MS_REG_GPU_KERNEL_ONE(AddN,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
AddNGpuFwdKernel, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -34,6 +34,7 @@ class AddNGpuFwdKernel : public GpuKernel {
: cudnn_handle_(nullptr),
input_descriptor_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT),
is_int64_(false),
input_size_(0),
output_size_(0),
workspace_size_(0),
@ -57,7 +58,7 @@ class AddNGpuFwdKernel : public GpuKernel {
break;
}
}
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
if (cudnn_data_type_ == CUDNN_DATA_INT32 || is_int64_) {
FillDeviceArray(outputs[0]->size / sizeof(T), output_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
FillDeviceArray(outputs[0]->size / sizeof(T), work_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr));
}
@ -67,7 +68,7 @@ class AddNGpuFwdKernel : public GpuKernel {
const double dbeta = static_cast<double>(0.0f);
for (size_t i = 0; i < num_input_; i++) {
T *input_addr = GetDeviceAddress<T>(inputs, i);
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
if (cudnn_data_type_ == CUDNN_DATA_INT32 || is_int64_) {
ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else if (cudnn_data_type_ == CUDNN_DATA_DOUBLE) {
@ -92,8 +93,12 @@ class AddNGpuFwdKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
// because int64 is not supported in cudnn, so we go separate path
is_int64_ = (AnfAlgo::GetInputDeviceDataType(kernel_node, 0) == kNumberTypeInt64) ? true : false;
InitResource();
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
if (!is_int64_) {
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
num_input_ = GetAttr<int64_t>(kernel_node, "n");
if (num_input_ != input_num) {
@ -112,24 +117,31 @@ class AddNGpuFwdKernel : public GpuKernel {
InitSizeLists();
return true;
}
for (size_t i = input_shape.size(); i < 4; i++) {
(void)input_shape.insert(input_shape.begin(), 1);
}
std::vector<int> dimA;
for (size_t i = 0; i < input_shape.size(); i++) {
dimA.push_back(SizeToInt(input_shape[i]));
}
auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);
if (input_format == kOpFormat_NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
SizeToInt(input_shape.size()), dimA.data()),
"cudnnSetTensorNdDescriptor failed");
if (is_int64_) {
input_size_ = sizeof(T);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
SizeToInt(input_shape.size()), dimA.data()),
"cudnnSetTensorNdDescriptor failed");
for (size_t i = input_shape.size(); i < 4; i++) {
(void)input_shape.insert(input_shape.begin(), 1);
}
std::vector<int> dimA;
for (size_t i = 0; i < input_shape.size(); i++) {
dimA.push_back(SizeToInt(input_shape[i]));
}
auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);
if (input_format == kOpFormat_NHWC) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NHWC, cudnn_data_type_,
SizeToInt(input_shape.size()), dimA.data()),
"cudnnSetTensorNdDescriptor failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptorEx(input_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
SizeToInt(input_shape.size()), dimA.data()),
"cudnnSetTensorNdDescriptor failed");
}
}
InitSizeLists();
return true;
@ -147,7 +159,7 @@ class AddNGpuFwdKernel : public GpuKernel {
"cudnnCreateTensorDescriptor failed");
}
void InitSizeLists() override {
if (!is_null_input_) {
if (!is_null_input_ && !is_int64_) {
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(input_descriptor_, &input_size_),
"cudnnGetTensorSizeInBytes failed");
}
@ -167,6 +179,7 @@ class AddNGpuFwdKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
bool is_int64_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;

View File

@ -94,3 +94,42 @@ def test_net_float64():
[84., 87., 90., 93.],
[96., 99., 102., 105.]]]]).astype(np.float64)
assert (output.asnumpy() == expect_result).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_int64():
    """AddN forward with int64 inputs on GPU, checked in both execution modes.

    Feeds three identical 1x3x3x4 int64 tensors to AddN and verifies the
    output equals the elementwise sum (3 * arange). The original test
    duplicated the whole body for PyNative and Graph mode; this runs the
    same check once per mode instead.
    """
    expect_result = (3 * np.arange(1 * 3 * 3 * 4)).reshape(1, 3, 3, 4).astype(np.int64)
    for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=mode, device_target="GPU")
        x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.int64)
        y = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.int64)
        z = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.int64)
        add = Net()
        output = add(Tensor(x), Tensor(y), Tensor(z))
        # Exact integer comparison: int64 addition on GPU must be lossless.
        assert (output.asnumpy() == expect_result).all()