Support inplace update v2 gpu kernel.

This commit is contained in:
ZPaC 2023-03-06 16:08:40 +08:00
parent b834aa3a45
commit 01e3ea47aa
8 changed files with 341 additions and 24 deletions

View File

@ -23,8 +23,6 @@ static std::unordered_map<std::string, int> op_type_map = {
{"InplaceUpdate", INPLACE_OP_TYPE_UPDATE}, {"InplaceAdd", INPLACE_OP_TYPE_ADD}, {"InplaceSub", INPLACE_OP_TYPE_SUB}};
bool InplaceOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
// auto kernel_ptr_ = std::dynamic_pointer_cast<ops::InplaceUpdate>(base_operator);
// kernel_name_ = kernel_ptr_->name();
kernel_name_ = base_operator->name();
auto iter = op_type_map.find(kernel_name_);
if (iter == op_type_map.end()) {

View File

@ -0,0 +1,154 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/arrays/inplace_op_v2_gpu_kernel.h"
#include <unordered_map>
#include <string>
namespace mindspore {
namespace kernel {
static std::unordered_map<std::string, int> op_type_map = {{"InplaceUpdateV2", INPLACE_OP_TYPE_UPDATE}};
bool InplaceOpV2GpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
auto iter = op_type_map.find(kernel_name_);
if (iter == op_type_map.end()) {
MS_LOG(ERROR) << "For InplaceOpV2 kernel, Can only support InplaceUpdateV2, but got " << kernel_name_;
return false;
}
kernel_type_ = iter->second;
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the kernel type should be in [float16, float32, float64, int32]"
", but got: "
<< kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
unit_size_ = abstract::TypeIdSize(inputs[0]->GetDtype());
return true;
}
int InplaceOpV2GpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
for (const auto &input : inputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
std::vector<int64_t> input_shape_x = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
std::vector<int64_t> input_shape_indices = std::vector<int64_t>(
inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(), inputs.at(kIndex1)->GetDeviceShapeAdaptively().end());
std::vector<int64_t> input_shape_v = std::vector<int64_t>(inputs.at(kIndex2)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex2)->GetDeviceShapeAdaptively().end());
band_size_ = 1;
for (size_t i = 1; i < input_shape_x.size(); ++i) {
band_size_ *= input_shape_x[i];
}
input_elements_x = std::accumulate(input_shape_x.begin(), input_shape_x.end(), 1, std::multiplies<int64_t>());
input_elements_v = std::accumulate(input_shape_v.begin(), input_shape_v.end(), 1, std::multiplies<int64_t>());
size_t input_size_x = input_elements_x * unit_size_;
size_t indices_size = input_shape_indices.size() * sizeof(int32_t);
size_t input_size_v = input_elements_v * unit_size_;
input_size_list_.push_back(input_size_x);
input_size_list_.push_back(IntToSize(indices_size));
input_size_list_.push_back(input_size_v);
output_size_list_.push_back(input_size_x);
if (kernel_name_ == ops::kNameInplaceUpdateV2) {
workspace_size_list_.push_back(indices_size);
}
return KRET_OK;
}
void InplaceOpV2GpuKernelMod::ResetResource() noexcept {
band_size_ = 1;
input_elements_x = 0;
input_elements_v = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
template <typename T>
bool InplaceOpV2GpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *input_x = GetDeviceAddress<T>(inputs, kIndex0);
int32_t *input_indices = GetDeviceAddress<int32_t>(inputs, kIndex1);
T *input_v = GetDeviceAddress<T>(inputs, kIndex2);
T *output = GetDeviceAddress<T>(outputs, kIndex0);
int32_t *indices_key_ptr = nullptr;
if (kernel_name_ == ops::kNameInplaceUpdateV2) {
indices_key_ptr = GetDeviceAddress<int32_t>(workspace, kIndex0);
}
auto cuda_stream = reinterpret_cast<cudaStream_t>(cuda_stream_);
// Copy from 'x' into 'y'.
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(output, input_x, input_elements_x * unit_size_, cudaMemcpyDeviceToDevice, cuda_stream),
"cudaMemcpyAsync output 'output' from 'input_x' failed.");
CalInplaceOp(input_elements_v, input_v, output, input_indices, indices_key_ptr, band_size_, device_id_, kernel_type_,
cuda_stream);
return true;
}
std::vector<std::pair<KernelAttr, InplaceOpV2GpuKernelMod::InplaceOpFunc>> InplaceOpV2GpuKernelMod::func_list_ = {
{KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
&InplaceOpV2GpuKernelMod::LaunchKernel<half>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
&InplaceOpV2GpuKernelMod::LaunchKernel<float>},
{KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
&InplaceOpV2GpuKernelMod::LaunchKernel<double>},
{KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
&InplaceOpV2GpuKernelMod::LaunchKernel<int>}};
std::vector<KernelAttr> InplaceOpV2GpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, InplaceOpFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, InplaceUpdateV2, InplaceOpV2GpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,83 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_INPLACE_UPDATE_V2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_INPLACE_UPDATE_V2_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <algorithm>
#include <functional>
#include <map>
#include "mindspore/core/ops/inplace_update_v2.h"
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/inplace_op_impl.cuh"
namespace mindspore {
namespace kernel {
class InplaceOpV2GpuKernelMod : public NativeGpuKernelMod {
public:
InplaceOpV2GpuKernelMod() { ResetResource(); }
~InplaceOpV2GpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
std::vector<KernelAttr> GetOpSupport() override;
void ResetResource() noexcept;
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using InplaceOpFunc =
std::function<bool(InplaceOpV2GpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
private:
std::vector<int64_t> indices_;
int kernel_type_{-1};
size_t unit_size_{1};
size_t input_elements_x;
size_t input_elements_v;
int64_t band_size_;
InplaceOpFunc kernel_func_{};
bool is_null_input_{false};
void *cuda_stream_{nullptr};
static std::vector<std::pair<KernelAttr, InplaceOpFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_INPLACE_UPDATE_V2_GPU_KERNEL_H_

View File

@ -32,14 +32,14 @@ struct AddFunc {
__device__ __forceinline__ void operator()(T *lhs, const T &rhs) { MsAtomicAdd(lhs, rhs); }
};
template <typename T>
__global__ void InplaceUpdate(const size_t size, const T *input_v, T *output, const int64_t *indices,
int64_t *indices_key, size_t indices_len, const int64_t band_size) {
template <typename T, typename S>
__global__ void InplaceUpdate(const size_t size, const T *input_v, T *output, const S *indices, S *indices_key,
size_t indices_len, const int64_t band_size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
size_t row = pos / band_size;
if (row == indices_len || indices[row] != indices[row + 1]) {
int x_row = indices[row];
int v_row = indices_key[row];
S x_row = indices[row];
S v_row = indices_key[row];
int offset = pos % band_size;
int x_offset = x_row * band_size;
output[x_offset + offset] = input_v[v_row * band_size + offset];
@ -47,12 +47,12 @@ __global__ void InplaceUpdate(const size_t size, const T *input_v, T *output, co
}
return;
}
template <typename T, typename Func>
__global__ void InplaceAddOrSub(const size_t size, const T *input_v, T *output, const int64_t *indices,
template <typename T, typename S, typename Func>
__global__ void InplaceAddOrSub(const size_t size, const T *input_v, T *output, const S *indices,
const int64_t band_size) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
int v_row = pos / band_size;
int x_row = indices[v_row];
S x_row = indices[v_row];
int offset = pos % band_size;
int x_offset = x_row * band_size;
Func()(&output[x_offset + offset], input_v[pos]);
@ -60,9 +60,9 @@ __global__ void InplaceAddOrSub(const size_t size, const T *input_v, T *output,
return;
}
template <typename T>
void CalInplaceOp(const size_t size_v, const T *input_v, T *output, int64_t *indices, int64_t *indices_key,
const int64_t band_size, const uint32_t &device_id, int op_type, cudaStream_t cuda_stream) {
template <typename T, typename S>
void CalInplaceOp(const size_t size_v, const T *input_v, T *output, S *indices, S *indices_key, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream) {
int thread_num = 256 > size_v ? size_v : 256;
if (op_type == INPLACE_OP_TYPE_UPDATE) {
auto policy = thrust::cuda::par.on(cuda_stream);
@ -75,10 +75,10 @@ void CalInplaceOp(const size_t size_v, const T *input_v, T *output, int64_t *ind
InplaceUpdate<<<CUDA_BLOCKS_CAL(device_id, size_v, thread_num), thread_num, 0, cuda_stream>>>(
size_v, input_v, output, indices, indices_key, indices_element, band_size);
} else if (op_type == INPLACE_OP_TYPE_ADD) {
InplaceAddOrSub<T, AddFunc<T>><<<CUDA_BLOCKS_CAL(device_id, size_v, thread_num), thread_num, 0, cuda_stream>>>(
InplaceAddOrSub<T, S, AddFunc<T>><<<CUDA_BLOCKS_CAL(device_id, size_v, thread_num), thread_num, 0, cuda_stream>>>(
size_v, input_v, output, indices, band_size);
} else if (op_type == INPLACE_OP_TYPE_SUB) {
InplaceAddOrSub<T, SubFunc<T>><<<CUDA_BLOCKS_CAL(device_id, size_v, thread_num), thread_num, 0, cuda_stream>>>(
InplaceAddOrSub<T, S, SubFunc<T>><<<CUDA_BLOCKS_CAL(device_id, size_v, thread_num), thread_num, 0, cuda_stream>>>(
size_v, input_v, output, indices, band_size);
}
return;
@ -99,3 +99,19 @@ template CUDA_LIB_EXPORT void CalInplaceOp<double>(const size_t size_v, const do
template CUDA_LIB_EXPORT void CalInplaceOp<int>(const size_t size_v, const int *input_v, int *output, int64_t *indices,
int64_t *indices_key, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalInplaceOp<half>(const size_t size_v, const half *input_v, half *output,
int32_t *indices, int32_t *indices_key, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalInplaceOp<float>(const size_t size_v, const float *input_v, float *output,
int32_t *indices, int32_t *indices_key, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalInplaceOp<double>(const size_t size_v, const double *input_v, double *output,
int32_t *indices, int32_t *indices_key, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalInplaceOp<int>(const size_t size_v, const int *input_v, int *output, int32_t *indices,
int32_t *indices_key, const int64_t band_size,
const uint32_t &device_id, int op_type, cudaStream_t cuda_stream);

View File

@ -25,9 +25,9 @@ enum BroadcastOpType {
INPLACE_OP_TYPE_SUB = 2,
};
template <typename T>
CUDA_LIB_EXPORT void CalInplaceOp(const size_t size_v, const T *input_v, T *output, int64_t *indices,
int64_t *indices_key_ptr, const int64_t band_size, const uint32_t &device_id,
int op_type, cudaStream_t cuda_stream);
template <typename T, typename S>
CUDA_LIB_EXPORT void CalInplaceOp(const size_t size_v, const T *input_v, T *output, S *indices, S *indices_key_ptr,
const int64_t band_size, const uint32_t &device_id, int op_type,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_INPLACE_UPDATE_IMPL_CUH_

View File

@ -67,8 +67,8 @@ from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerA
FusedAdaFactorWithGlobalNorm)
from .linalg_ops import (Svd, Geqrf)
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
BitwiseAnd, BitwiseOr, Ger,
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub, InplaceUpdate,
BitwiseAnd, BitwiseOr, Ger, BitwiseXor, Inv, Invert, ApproximateEqual,
InplaceAdd, InplaceSub, InplaceUpdate, InplaceUpdateV2,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
Cos, Cross, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod,
Ceil, Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
@ -470,6 +470,7 @@ __all__ = [
"DataFormatDimMap",
"ApproximateEqual",
"InplaceUpdate",
"InplaceUpdateV2",
"InTopK",
"UniformCandidateSampler",
"LogUniformCandidateSampler",

View File

@ -1838,14 +1838,14 @@ class InplaceUpdateV2(Primitive):
TypeError: If `indices` is a tuple and its element is not an int.
Supported Platforms:
``Ascend``
``Ascend` ``GPU```
Examples:
>>> indices = (0, 1)
>>> x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32)
>>> v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32)
>>> inplace_update = ops.InplaceUpdate(indices)
>>> output = inplace_update(x, v)
>>> inplace_update_v2 = ops.InplaceUpdateV2()
>>> output = inplace_update_v2(x, indices, v)
>>> print(output)
[[0.5 1. ]
[1. 1.5]

View File

@ -0,0 +1,65 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# This example should be run with multiple processes.
# Please refer to the Programming Guide > Distributed Training -> Distributed Parallel Usage Example
# on mindspore.cn and focus on the contents of these three parts: Configuring Distributed Environment
# Variables, Calling the Collective Communication Library, Running the Script.
import pytest
import mindspore
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore import nn
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetInplaceUpdateV2(nn.Cell):
def __init__(self, x, v):
super(NetInplaceUpdateV2, self).__init__()
self.x = x
self.v = v
self.inplace_update_v2 = P.InplaceUpdateV2()
def construct(self, indices):
output = self.inplace_update_v2(self.x, indices, self.v)
return output
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_inplace_update_fp16():
"""
Feature: ALL To ALL
Description: test cases for InplaceUpdateV2
Expectation: the result match to expect result
"""
x = Tensor([[1, 2], [3, 4], [5, 6]], mindspore.float16)
v = Tensor([[0.5, 1.0], [1.0, 1.5]], mindspore.float16)
inplace_update_v2 = NetInplaceUpdateV2(x, v)
indices = Tensor(shape=[None], dtype=mindspore.int32)
inplace_update_v2.set_inputs(indices)
real_indices = Tensor([0, 1], dtype=mindspore.int32)
output = inplace_update_v2(real_indices)
expect = Tensor([[0.5, 1.0], [1.0, 1.5], [5, 6]], mindspore.float16)
assert (output.asnumpy() == expect).all()