!45062 [MS][OP]unstack gpu kernelMod

Merge pull request !45062 from mengyuanli/ds_unstack_gpu
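This change migrates the Unstack/Unpack GPU kernel from the deprecated CNode-based interface (DeprecatedNativeGpuKernelMod with Init(const CNodePtr &)) to the NativeGpuKernelMod Init/Resize flow, presumably for dynamic-shape support (the ds_ branch prefix): the input count is validated in Init, while the "axis"/"num" attributes are read and all shape-dependent sizes recomputed in Resize, which now returns KRET_OK. Below is a minimal, standalone sketch of that split for orientation only; Shape, KernelTensor and kRetOk are simplified stand-ins for illustration, not the real BaseOperatorPtr/KernelTensorPtr API used in the diff.

// Minimal standalone sketch of the Init/Resize split (illustration only; the
// real kernel uses MindSpore's BaseOperatorPtr/KernelTensorPtr types).
#include <cstddef>
#include <vector>

using Shape = std::vector<size_t>;
struct KernelTensor { Shape shape; };
constexpr int kRetOk = 0;

class UnstackKernelSketch {
 public:
  // Init: one-time setup that does not depend on shapes (in the diff, the
  // input count is checked here).
  bool Init(int axis, size_t output_num) {
    axis_ = axis;
    output_num_ = output_num;
    return true;
  }
  // Resize: re-run whenever shapes change; recompute every shape-derived
  // size and report success, mirroring the KRET_OK returns added in the diff.
  int Resize(const std::vector<KernelTensor> &inputs) {
    input_size_ = 1;
    for (size_t dim : inputs[0].shape) {
      input_size_ *= dim;
    }
    return kRetOk;
  }
  // Launch (omitted): copies the per-output pointers to the device and calls
  // the CUDA unpack kernel using the sizes cached by Resize.
 private:
  int axis_ = 0;
  size_t output_num_ = 0;
  size_t input_size_ = 1;
};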
commit 9ee0476a3b
i-robot, 2022-11-03 14:07:15 +00:00, committed by Gitee
1 changed file with 48 additions and 46 deletions


@@ -14,12 +14,14 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNPACK_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNPACK_GPU_KERNEL_H
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_UNPACK_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_UNPACK_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/unpack.cuh"
@@ -27,7 +29,7 @@
namespace mindspore {
namespace kernel {
template <typename T>
class UnpackFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class UnpackFwdGpuKernelMod : public NativeGpuKernelMod {
public:
UnpackFwdGpuKernelMod()
: axis_(0), is_null_input_(false), output_num_(0), input_size_(1), dims_after_axis_(1), outputs_host_(nullptr) {}
@@ -43,42 +45,53 @@ class UnpackFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
for (size_t i = 0; i < outputs.size(); i++) {
outputs_host_[i] = GetDeviceAddress<T>(outputs, i);
}
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(outputs_array, // NOLINT
outputs_host_.get(), sizeof(T *) * output_num_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Unpack opt cudaMemcpyAsync outputs failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(outputs_array, // NOLINT
outputs_host_.get(), sizeof(T *) * output_num_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Unpack opt cudaMemcpyAsync outputs failed");
UnpackKernel(input_size_, output_num_, dims_after_axis_, outputs_array, input,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
(void)CheckParam(kernel_node);
axis_ = static_cast<int32_t>(GetAttr<int64_t>(kernel_node, "axis"));
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
int32_t shape_size = SizeToInt(input_shape.size());
if (axis_ < -shape_size || axis_ >= shape_size) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the `axis` should be in [" << -shape_size << ", "
<< shape_size << "), but got " << axis_;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
MS_EXCEPTION_IF_NULL(base_operator);
kernel_name_ = base_operator->name();
constexpr size_t input_num = 1;
CHECK_KERNEL_INPUTS_NUM(inputs.size(), input_num, kernel_name_);
return true;
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
axis_ = static_cast<int32_t>(GetValue<int64_t>(prim->GetAttr("axis")));
origin_data_format_ = GetValue<std::string>(prim->GetAttr("operator_origin_format"));
auto input_shape = inputs[kIndex0]->GetDeviceShapeAdaptively();
if (axis_ < 0) {
axis_ += SizeToInt(input_shape.size());
}
auto origin_data_format = AnfAlgo::GetOriginDataFormat(kernel_node);
auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);
axis_ = AxisTransform(origin_data_format, input_format, axis_);
output_num_ = LongToSize(GetAttr<int64_t>(kernel_node, "num"));
auto input_format = FormatEnumToString(inputs[0]->GetFormat());
axis_ = AxisTransform(origin_data_format_, input_format, axis_);
output_num_ = LongToSize(GetValue<int64_t>(prim->GetAttr("num")));
outputs_host_ = std::make_unique<T *[]>(output_num_);
ResetResource();
for (size_t i = 0; i < output_num_; i++) {
size_t _size = 1;
auto _shape = AnfAlgo::GetOutputDeviceShape(kernel_node, i);
auto _shape = outputs[i]->GetDeviceShapeAdaptively();
is_null_input_ = CHECK_SHAPE_NULL(_shape, kernel_name_, "output");
if (is_null_input_) {
InitSizeLists();
return true;
return KRET_OK;
}
for (size_t j = 0; j < _shape.size(); j++) {
_size *= static_cast<size_t>(_shape[j]);
@@ -89,8 +102,7 @@ class UnpackFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input");
if (is_null_input_) {
InitSizeLists();
return true;
return KRET_OK;
}
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= static_cast<size_t>(input_shape[i]);
@@ -99,39 +111,29 @@ class UnpackFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
}
}
input_size_list_.push_back(input_size_ * sizeof(T));
InitSizeLists();
return true;
return KRET_OK;
}
protected:
void InitSizeLists() override {}
private:
void ResetResource() noexcept override {
axis_ = 0;
void ResetResource() noexcept {
is_null_input_ = false;
output_num_ = 0;
input_size_ = 1;
dims_after_axis_ = 1;
outputs_host_ = nullptr;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
void CheckParam(const CNodePtr &kernel_node) {
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1, but got " << input_num;
}
}
int axis_;
bool is_null_input_;
size_t output_num_;
int axis_{0};
bool is_null_input_{false};
size_t output_num_{0};
size_t input_size_;
size_t dims_after_axis_;
std::unique_ptr<T *[]> outputs_host_;
std::string origin_data_format_{};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_UNPACK_GPU_KERNEL_H
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_UNPACK_GPU_KERNEL_H_