forked from mindspore-Ecosystem/mindspore
!42360 [MSLITE] Support dynamic shape for slicegrad
Merge pull request !42360 from zhangyongxian/dev_zhangyongxian_slicegrad
This commit is contained in:
commit
6277a023ed
|
@ -30,60 +30,64 @@ constexpr size_t kOutputsNum = 1;
|
|||
constexpr size_t kSliceGradMaxInputShapeSize = 8;
|
||||
} // namespace
|
||||
|
||||
void SliceGradCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
cnode_ptr_ = kernel_node;
|
||||
bool SliceGradCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
auto input_num = inputs.size();
|
||||
dtype_ = inputs.at(0)->GetDtype();
|
||||
constexpr size_t kInputNum2 = 2;
|
||||
ClearVectors();
|
||||
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (input_num == kSliceGradDynamicInputsNum || input_num == kStridedSliceGradDynamicInputsNum) {
|
||||
is_dynamic_attr_ = true;
|
||||
strides_dtype_ = inputs.at(kInputNum2)->GetDtype();
|
||||
return true;
|
||||
}
|
||||
auto prim = base_operator->GetPrim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto begin_value = prim->GetAttr(kAttrBegin);
|
||||
MS_EXCEPTION_IF_NULL(begin_value);
|
||||
begin_ = GetValue<std::vector<int64_t>>(begin_value);
|
||||
auto strides_value = prim->GetAttr(STRIDES);
|
||||
if (strides_value != nullptr) { // StrideSliceGrad
|
||||
strides_ = GetValue<std::vector<int64_t>>(strides_value);
|
||||
auto end_value = prim->GetAttr(kAttrEnd);
|
||||
MS_EXCEPTION_IF_NULL(end_value);
|
||||
end_ = GetValue<std::vector<int64_t>>(end_value);
|
||||
} else { // SliceGrad
|
||||
auto size_value = prim->GetAttr(kAttrSize);
|
||||
MS_EXCEPTION_IF_NULL(size_value);
|
||||
size_ = GetValue<std::vector<int64_t>>(size_value);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int SliceGradCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
auto begin = GetDynamicAttrIntValue(inputs, kBeginIndex_, inputsOnHost, kernel_name_, &begin_);
|
||||
if (kernel_name_ == prim::kPrimStridedSliceGrad->name()) {
|
||||
auto end = GetDynamicAttrIntValue(inputs, kEndIndex_, inputsOnHost, kernel_name_, &end_);
|
||||
auto stride = GetDynamicAttrIntValue(inputs, kStrideIndex_, inputsOnHost, kernel_name_, &strides_);
|
||||
get_dynamic_attr_value_ = begin && end && stride;
|
||||
} else {
|
||||
auto size = GetDynamicAttrIntValue(inputs, kSizeIndex_, inputsOnHost, kernel_name_, &size_);
|
||||
get_dynamic_attr_value_ = begin && size;
|
||||
}
|
||||
auto ret = KernelMod::Resize(base_operator, inputs, outputs);
|
||||
if (ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
auto input_shape = inputs[0]->GetShapeVector();
|
||||
if (input_shape.size() > kSliceGradMaxInputShapeSize) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input tensor must be 8D or lower, but got "
|
||||
<< input_shape.size() << "D.";
|
||||
}
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num == kSliceGradDynamicInputsNum || input_num == kStridedSliceGradDynamicInputsNum) {
|
||||
strides_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kInputNum2);
|
||||
return;
|
||||
}
|
||||
// in the case that begin, end, size, stride are const value.
|
||||
std::vector<int64_t> begin_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, BEGIN);
|
||||
(void)std::transform(begin_me.begin(), begin_me.end(), std::back_inserter(begin_),
|
||||
[](const int64_t &value) { return LongToInt(value); });
|
||||
auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
auto strides = prim->GetAttr(STRIDES);
|
||||
if (strides != nullptr) { // StridedSliceGrad
|
||||
std::vector<int64_t> strides_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDES);
|
||||
std::vector<int64_t> end_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, END);
|
||||
(void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides_),
|
||||
[](const int64_t &value) { return LongToInt(value); });
|
||||
(void)std::transform(end_me.begin(), end_me.end(), std::back_inserter(end_),
|
||||
[](const int64_t &value) { return LongToInt(value); });
|
||||
if (strides_.size() != end_.size() || strides_.size() != output_shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the dimension of 'strides|end|output' must be equal, but got the dimension of "
|
||||
<< "'strides': " << strides_.size() << ", the dimension of 'end': " << end_.size()
|
||||
<< ", and the dimension of output: " << output_shape_.size();
|
||||
}
|
||||
FormatArgs(true);
|
||||
} else { // SliceGrad
|
||||
std::vector<int64_t> size_me = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, SIZE);
|
||||
(void)std::transform(size_me.begin(), size_me.end(), std::back_inserter(size_),
|
||||
[](const int64_t &value) { return LongToInt(value); });
|
||||
if (size_.size() != output_shape_.size() || begin_.size() != output_shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', 'begin|size|input' size must be equal, but got 'begin' size: " << begin_.size()
|
||||
<< ", 'size' size: " << size_.size() << " and 'input' size: " << output_shape_.size();
|
||||
}
|
||||
FormatArgs(false);
|
||||
}
|
||||
output_shape_ = outputs[0]->GetShapeVector();
|
||||
FormatArgs(kernel_name_ == prim::kPrimStridedSliceGrad->name());
|
||||
ExpandAllMemberDims(kSliceGradMaxInputShapeSize);
|
||||
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
void SliceGradCpuKernelMod::ClearVectors() {
|
||||
|
@ -127,67 +131,6 @@ void SliceGradCpuKernelMod::ExpandAllMemberDims(size_t expand_dims) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SliceGradCpuKernelMod::InitParams(const std::vector<kernel::AddressPtr> &inputs) {
|
||||
auto cnode = cnode_ptr_.lock();
|
||||
ClearVectors();
|
||||
output_shape_ = common::AnfAlgo::GetOutputInferShape(cnode, 0);
|
||||
std::string kernel_name = common::AnfAlgo::GetCNodeName(cnode);
|
||||
auto begin_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 2);
|
||||
auto begin_ptr = reinterpret_cast<T *>(inputs[2]->addr);
|
||||
std::vector<T> begin{begin_ptr, begin_ptr + begin_shape[0]};
|
||||
(void)std::transform(begin.begin(), begin.end(), std::back_inserter(begin_),
|
||||
[](const T &value) { return static_cast<int>(value); });
|
||||
if (kernel_name == prim::kPrimStridedSliceGrad->name()) { // StridedSliceGrad
|
||||
auto end_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 3);
|
||||
auto stride_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 4);
|
||||
if (begin_shape.size() != 1 || end_shape.size() != 1 || stride_shape.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the dimensions of 'begin', 'end', 'strides' must be 1, "
|
||||
"but got the dimension of 'begin': "
|
||||
<< begin_shape.size() << ", the dimension of 'end': " << end_shape.size()
|
||||
<< ", and the dimension of 'strides': " << stride_shape.size();
|
||||
}
|
||||
|
||||
auto end_ptr = reinterpret_cast<T *>(inputs[3]->addr);
|
||||
auto strides_ptr = reinterpret_cast<T *>(inputs[4]->addr);
|
||||
|
||||
std::vector<T> end{end_ptr, end_ptr + end_shape[0]};
|
||||
std::vector<T> strides{strides_ptr, strides_ptr + stride_shape[0]};
|
||||
(void)std::transform(strides.begin(), strides.end(), std::back_inserter(strides_),
|
||||
[](const T &value) { return static_cast<int>(value); });
|
||||
(void)std::transform(end.begin(), end.end(), std::back_inserter(end_), [](const T &value) { return value; });
|
||||
if (strides_.size() != end_.size()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the dimension of 'strides|end|output' must be equal, but got the dimension of "
|
||||
<< "'strides': " << strides_.size() << ", the dimension of 'end': " << end_.size()
|
||||
<< ", and the dimension of output: " << output_shape_.size();
|
||||
}
|
||||
FormatArgs(true);
|
||||
} else { // SliceGrad
|
||||
auto size_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 3);
|
||||
if (begin_shape.size() != 1 || size_shape.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the dimensions of 'begin', 'end' must be 1, but got the dimension of 'begin': "
|
||||
<< begin_shape.size() << ", and the dimension of 'end': " << size_shape.size();
|
||||
}
|
||||
auto size_ptr = reinterpret_cast<T *>(inputs[3]->addr);
|
||||
std::vector<T> size{size_ptr, size_ptr + size_shape[0]};
|
||||
(void)std::transform(size.begin(), size.end(), std::back_inserter(size_),
|
||||
[](const T &value) { return static_cast<int>(value); });
|
||||
if (size_.size() != output_shape_.size() || begin_.size() != output_shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', 'begin|size|input' size must be equal, but got 'begin' size: " << begin_.size()
|
||||
<< ", 'size' size: " << size_.size() << " and 'input' size: " << output_shape_.size();
|
||||
}
|
||||
FormatArgs(false);
|
||||
}
|
||||
ExpandAllMemberDims(kSliceGradMaxInputShapeSize);
|
||||
|
||||
CPUKernelUtils::GetElementNumEveryDim(input_shape_, &input_element_num_);
|
||||
CPUKernelUtils::GetElementNumEveryDim(output_shape_, &output_element_num_);
|
||||
}
|
||||
|
||||
bool SliceGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
|
@ -216,17 +159,12 @@ bool SliceGradCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs
|
|||
template <typename T>
|
||||
bool SliceGradCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (is_dynamic_attr_ && !get_dynamic_attr_value_) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', fail to get value of dynamic attr!";
|
||||
}
|
||||
auto *input_addr = reinterpret_cast<T *>(inputs[0]->addr);
|
||||
auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
|
||||
// init params for not const inputs
|
||||
if (inputs.size() == kSliceGradDynamicInputsNum || inputs.size() == kStridedSliceGradDynamicInputsNum) {
|
||||
if (strides_dtype_ == kNumberTypeInt32) {
|
||||
InitParams<int32_t>(inputs);
|
||||
} else {
|
||||
InitParams<int64_t>(inputs);
|
||||
}
|
||||
}
|
||||
auto ret = memset_s(output_addr, outputs[0]->size, 0, outputs[0]->size);
|
||||
if (ret != EOK) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', output buff memset failed. Error no: " << ret;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
@ -30,7 +31,7 @@ constexpr auto kSliceGrad = "SliceGrad";
|
|||
constexpr auto kStridedSliceGrad = "StridedSliceGrad";
|
||||
constexpr auto kUnknown = "Unknown";
|
||||
|
||||
class SliceGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
||||
class SliceGradCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
SliceGradCpuKernelMod() = default;
|
||||
|
||||
|
@ -38,7 +39,11 @@ class SliceGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
|
||||
~SliceGradCpuKernelMod() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
@ -71,16 +76,23 @@ class SliceGradCpuKernelMod : public DeprecatedNativeCpuKernelMod {
|
|||
bool SliceGrad8D(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs,
|
||||
T *input_addr, T *output_addr);
|
||||
|
||||
std::vector<int> begin_;
|
||||
std::vector<int> end_;
|
||||
std::vector<int> strides_;
|
||||
std::vector<int> size_;
|
||||
std::vector<int64_t> begin_;
|
||||
std::vector<int64_t> end_;
|
||||
std::vector<int64_t> strides_;
|
||||
std::vector<int64_t> size_;
|
||||
static constexpr size_t kShapexIndex_{1};
|
||||
static constexpr size_t kBeginIndex_{2};
|
||||
static constexpr size_t kEndIndex_{3};
|
||||
static constexpr size_t kStrideIndex_{4};
|
||||
static constexpr size_t kSizeIndex_{3};
|
||||
ShapeVector input_shape_;
|
||||
std::vector<size_t> input_element_num_;
|
||||
ShapeVector output_shape_;
|
||||
std::vector<size_t> output_element_num_;
|
||||
TypeId dtype_{kTypeUnknown};
|
||||
TypeId strides_dtype_{kNumberTypeInt32};
|
||||
bool get_dynamic_attr_value_{false};
|
||||
bool is_dynamic_attr_{false};
|
||||
std::string kernel_type_{kUnknown};
|
||||
};
|
||||
} // namespace kernel
|
||||
|
|
|
@ -15,32 +15,186 @@
|
|||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/arrays/slice_grad_gpu_kernel.h"
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Per-dtype SliceGrad registrations through MS_REG_GPU_KERNEL_ONE. Each entry declares two inputs
// and one output of the same dtype and names the element type paired with SliceGradGpuKernelMod
// (float64, float32, float16, int32, int16, uint8, bool).
// NOTE(review): the same seven dtype combinations are listed again in the `kernel_attr` table and
// the class is also registered via MS_KERNEL_FACTORY_REG later in this file; these macro-based
// entries look like the pre-refactor form - confirm whether both registration paths should coexist.
MS_REG_GPU_KERNEL_ONE(
|
||||
SliceGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
SliceGradGpuKernelMod, double)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SliceGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
SliceGradGpuKernelMod, float)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SliceGrad,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
SliceGradGpuKernelMod, half)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
SliceGradGpuKernelMod, int)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SliceGrad, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
SliceGradGpuKernelMod, int16_t)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SliceGrad, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
SliceGradGpuKernelMod, uchar)
|
||||
MS_REG_GPU_KERNEL_ONE(
|
||||
SliceGrad, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
SliceGradGpuKernelMod, bool)
|
||||
namespace {
|
||||
void ShapeNdToMd(const ShapeVector &src, ShapeVector *dst, size_t nd_maximum_size) {
|
||||
if (src.size() > nd_maximum_size) {
|
||||
MS_LOG(ERROR) << src.size() << "-D data is not supported!";
|
||||
return;
|
||||
}
|
||||
|
||||
for (size_t i = nd_maximum_size; i > 0; --i) {
|
||||
dst->push_back(src.size() < i ? 1 : src[src.size() - i]);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Signature of the factory callables stored in `kernel_attr`: given a kernel name and a device id,
// return an owning pointer to a cukernel::GpuKernelHelperBase.
using SliceGradPtrCreatorFunc =
|
||||
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
|
||||
|
||||
template <typename T, typename S = int64_t>
|
||||
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateSliceKernelPtr(const std::string &kernel_name,
|
||||
const uint32_t &device_id) {
|
||||
return std::make_unique<cukernel::SliceGradHelperGpuKernel<T, S>>(kernel_name, device_id);
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, SliceGradPtrCreatorFunc>> kernel_attr = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateSliceKernelPtr<double>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateSliceKernelPtr<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateSliceKernelPtr<half>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
CreateSliceKernelPtr<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
|
||||
CreateSliceKernelPtr<int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
|
||||
CreateSliceKernelPtr<uchar>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
|
||||
CreateSliceKernelPtr<bool>},
|
||||
};
|
||||
|
||||
std::vector<KernelAttr> SliceGradGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, SliceGradPtrCreatorFunc> &item) { return item.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
// Starts with the default kernel name "SliceGrad" and a freshly allocated attribute object that
// ProccessAttr later fills and Resize hands to the helper.
SliceGradGpuKernelMod::SliceGradGpuKernelMod()
    : kernel_name_("SliceGrad"), attr_ptr_(std::make_shared<cukernel::SliceGradAttr>()) {}
|
||||
|
||||
bool SliceGradGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
if (is_dynamic_attr_ && !get_dynamic_attr_value_) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', fail to get value of the dynamic attr!";
|
||||
}
|
||||
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
|
||||
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
|
||||
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
|
||||
return helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) == 0;
|
||||
}
|
||||
|
||||
bool SliceGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::SliceGrad>(base_operator);
|
||||
kernel_name_ = kernel_ptr->name();
|
||||
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
return false;
|
||||
}
|
||||
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
|
||||
|
||||
(void)CheckParam(inputs, outputs);
|
||||
|
||||
if (!is_dynamic_attr_) {
|
||||
auto begin_value = kernel_ptr->GetAttr(ops::kBegin);
|
||||
begin_ = GetValue<std::vector<int64_t>>(begin_value);
|
||||
auto size_value = kernel_ptr->GetAttr(ops::kSize);
|
||||
size_ = GetValue<std::vector<int64_t>>(size_value);
|
||||
ProccessAttr(inputs);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int SliceGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
if (is_dynamic_attr_) {
|
||||
if (GetDynamicAttrIntValue(inputs, kBeginIndex_, inputsOnHost, kernel_name_, &begin_) &&
|
||||
GetDynamicAttrIntValue(inputs, kSizeIndex_, inputsOnHost, kernel_name_, &size_)) {
|
||||
get_dynamic_attr_value_ = true;
|
||||
ProccessAttr(inputs);
|
||||
}
|
||||
}
|
||||
helper_ptr_->SetKernelParam(attr_ptr_);
|
||||
std::vector<std::vector<int64_t>> input_shapes;
|
||||
std::vector<std::vector<int64_t>> output_shapes;
|
||||
std::transform(inputs.begin(), inputs.end(), std::back_inserter(input_shapes),
|
||||
[](const KernelTensorPtr &input) { return input->GetDeviceShapeAdaptively(); });
|
||||
std::vector<int64_t> out_shape = outputs[0]->GetDeviceShapeAdaptively();
|
||||
output_shapes.emplace_back(out_shape);
|
||||
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
input_size_list_ = helper_ptr_->GetInputSizeList();
|
||||
output_size_list_ = helper_ptr_->GetOutputSizeList();
|
||||
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
void SliceGradGpuKernelMod::ProccessAttr(const std::vector<KernelTensorPtr> &inputs) {
|
||||
auto input_shape = inputs[1]->GetShapeVector();
|
||||
auto data_format = inputs[1]->GetFormat();
|
||||
auto dy_shape = inputs[0]->GetShapeVector();
|
||||
if (dy_shape.size() <= kSliceGradDefaultInputShapeSize) {
|
||||
ShapeNdToMd(dy_shape, &dy_shape_, kDim4);
|
||||
CalcBeginAndSize(data_format, kSliceGradDefaultInputShapeSize);
|
||||
} else {
|
||||
ShapeNdToMd(dy_shape, &dy_shape_, kDim7);
|
||||
CalcBeginAndSize(data_format, kSliceGradMaxInputShapeSize);
|
||||
}
|
||||
if (input_shape.size() <= kSliceGradDefaultInputShapeSize) {
|
||||
ShapeNdToMd(input_shape, &input_shape_, kDim4);
|
||||
} else {
|
||||
ShapeNdToMd(input_shape, &input_shape_, kDim7);
|
||||
}
|
||||
attr_ptr_->size = size_;
|
||||
attr_ptr_->begin = begin_;
|
||||
attr_ptr_->input_shape = input_shape_;
|
||||
int64_t output_num = std::accumulate(dy_shape_.begin(), dy_shape_.end(), 1, std::multiplies<int64_t>());
|
||||
attr_ptr_->output_num = output_num;
|
||||
}
|
||||
|
||||
// Brings begin_ and size_ to `dim` entries and resolves negative values against input_shape_.
//
// data_format: when it is NHWC and `dim` equals kSliceGradDefaultInputShapeSize, entries 1..3 of
//              both vectors are permuted.
// dim: target number of entries; shorter vectors are padded at the front (begin with 0, size with 1).
void SliceGradGpuKernelMod::CalcBeginAndSize(const mindspore::Format &data_format, size_t dim) {
  // Front-pad in one step instead of inserting element by element.
  if (begin_.size() < dim) {
    (void)begin_.insert(begin_.begin(), dim - begin_.size(), 0);
  }
  if (size_.size() < dim) {
    (void)size_.insert(size_.begin(), dim - size_.size(), 1);
  }

  const bool permute_nhwc = (dim == kSliceGradDefaultInputShapeSize) && (data_format == mindspore::Format::NHWC);
  if (permute_nhwc) {
    // Same two swaps for both vectors: positions (1, 2, 3) end up holding the old (2, 3, 1).
    auto permute = [](std::vector<int64_t> *values) {
      std::swap((*values)[1], (*values)[3]);
      std::swap((*values)[1], (*values)[2]);
    };
    permute(&begin_);
    permute(&size_);
  }

  // Only indices present in input_shape_ can be resolved.
  const size_t begin_limit = std::min(begin_.size(), input_shape_.size());
  for (size_t idx = 0; idx < begin_limit; ++idx) {
    if (begin_[idx] < 0) {
      begin_[idx] += input_shape_[idx];
    }
  }

  const size_t size_limit = std::min(size_.size(), input_shape_.size());
  for (size_t idx = 0; idx < size_limit; ++idx) {
    if (size_[idx] >= 0) {
      continue;
    }
    // A negative size is offset by the dimension length and clamped at zero.
    const auto shifted = size_[idx] + input_shape_[idx];
    size_[idx] = shifted > 0 ? shifted : 0;
  }
}
|
||||
|
||||
void SliceGradGpuKernelMod::CheckParam(const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
size_t output_num = outputs.size();
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num;
|
||||
}
|
||||
auto input_shape = inputs[0]->GetShapeVector();
|
||||
if (input_shape.size() > kSliceGradMaxInputShapeSize) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than 7, but got "
|
||||
<< input_shape.size();
|
||||
}
|
||||
if (inputs.size() == DynamicInputNum) {
|
||||
is_dynamic_attr_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Factory registration tying the name SliceGrad to SliceGradGpuKernelMod for NativeGpuKernelMod.
// The supported dtype combinations are reported by GetOpSupport().
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SliceGrad, SliceGradGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -14,159 +14,52 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "mindspore/core/ops/op_name.h"
|
||||
#include "mindspore/core/ops/grad/slice_grad.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/slice_impl.cuh"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/slice_grad_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// Shapes with at most this many dimensions are handled in the 4-D padded form.
constexpr size_t kSliceGradDefaultInputShapeSize = 4;
|
||||
// Largest number of input dimensions accepted by CheckParam; also the padded size for shapes
// with more than kSliceGradDefaultInputShapeSize dimensions.
constexpr size_t kSliceGradMaxInputShapeSize = 7;
|
||||
// Number of inputs at which CheckParam treats begin/size as runtime inputs instead of attributes.
constexpr size_t DynamicInputNum = 4;
|
||||
// Input index from which Resize reads the dynamic 'begin' value.
constexpr size_t kBeginIndex_ = 2;
|
||||
// Input index from which Resize reads the dynamic 'size' value.
constexpr size_t kSizeIndex_ = 3;
|
||||
// Target dimension counts passed to ShapeNdToMd when padding shapes.
constexpr size_t kDim4 = 4;
|
||||
constexpr size_t kDim7 = 7;
|
||||
|
||||
template <typename T>
|
||||
class SliceGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
class SliceGradGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
SliceGradGpuKernelMod()
|
||||
: is_strided_slice_(false),
|
||||
is_null_input_(false),
|
||||
input_size_(0),
|
||||
output_size_(0),
|
||||
workspace_size_(0),
|
||||
kernel_name_("SliceGrad") {}
|
||||
SliceGradGpuKernelMod();
|
||||
~SliceGradGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
T *dy = GetDeviceAddress<T>(inputs, 0);
|
||||
T *dx = GetDeviceAddress<T>(outputs, 0);
|
||||
FillDeviceArray(outputs[0]->size / sizeof(T), dx, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
if (input_shape_.size() <= kSliceGradDefaultInputShapeSize) {
|
||||
CalSlice4DGrad(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3],
|
||||
input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], dy, dx,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CalSlice7DGrad(begin_[0], begin_[1], begin_[2], begin_[3], begin_[4], begin_[5], begin_[6], size_[0], size_[1],
|
||||
size_[2], size_[3], size_[4], size_[5], size_[6], input_shape_[0], input_shape_[1],
|
||||
input_shape_[2], input_shape_[3], input_shape_[4], input_shape_[5], input_shape_[6], dy, dx,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
kernel_node_ = kernel_node;
|
||||
(void)CheckParam(kernel_node);
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0);
|
||||
if (kernel_name == "StridedSliceGrad") {
|
||||
is_strided_slice_ = true;
|
||||
std::vector<int64_t> shapex = GetAttr<std::vector<int64_t>>(kernel_node, "shapex");
|
||||
for (auto x : shapex) {
|
||||
input_shape_.push_back(x);
|
||||
}
|
||||
for (auto i = input_shape_.size(); i < kSliceGradDefaultInputShapeSize; i++) {
|
||||
(void)input_shape_.insert(input_shape_.begin(), 1);
|
||||
}
|
||||
strides_ = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
|
||||
for (auto i = strides_.size(); i < kSliceGradDefaultInputShapeSize; i++) {
|
||||
(void)strides_.insert(strides_.begin(), 1);
|
||||
}
|
||||
size_ = GetAttr<std::vector<int64_t>>(kernel_node, "end");
|
||||
} else {
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(input_shape, kernel_name_, "input");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
if (input_shape.size() <= kSliceGradDefaultInputShapeSize) {
|
||||
ShapeNdTo4d(input_shape, &input_shape_);
|
||||
} else {
|
||||
ShapeNdTo7d(input_shape, &input_shape_);
|
||||
}
|
||||
size_ = GetAttr<std::vector<int64_t>>(kernel_node, "size");
|
||||
}
|
||||
auto dy_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
is_null_input_ = CHECK_SHAPE_NULL(dy_shape, kernel_name_, "input");
|
||||
if (is_null_input_) {
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
if (dy_shape.size() <= kSliceGradDefaultInputShapeSize) {
|
||||
ShapeNdTo4d(dy_shape, &dy_shape_);
|
||||
begin_ = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
|
||||
CalcBeginAndSize(data_format, kSliceGradDefaultInputShapeSize);
|
||||
} else {
|
||||
ShapeNdTo7d(dy_shape, &dy_shape_);
|
||||
begin_ = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
|
||||
CalcBeginAndSize(data_format, kSliceGradMaxInputShapeSize);
|
||||
}
|
||||
input_size_ = sizeof(T);
|
||||
for (auto shape : input_shape_) {
|
||||
input_size_ = input_size_ * static_cast<size_t>(shape);
|
||||
}
|
||||
output_size_ = sizeof(T);
|
||||
for (auto x : dy_shape_) {
|
||||
output_size_ = output_size_ * static_cast<size_t>(x);
|
||||
}
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(output_size_);
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(input_size_);
|
||||
}
|
||||
void CalcBeginAndSize(const std::string &data_format, size_t dim = 4) {
|
||||
for (auto i = begin_.size(); i < dim; i++) {
|
||||
(void)begin_.insert(begin_.begin(), 0);
|
||||
}
|
||||
for (auto i = size_.size(); i < dim; i++) {
|
||||
(void)size_.insert(size_.begin(), 1);
|
||||
}
|
||||
if (dim == kSliceGradDefaultInputShapeSize && data_format == "NHWC") {
|
||||
std::swap(begin_[1], begin_[3]);
|
||||
std::swap(begin_[1], begin_[2]);
|
||||
std::swap(size_[1], size_[3]);
|
||||
std::swap(size_[1], size_[2]);
|
||||
}
|
||||
for (size_t i = 0; i < begin_.size(); i++) {
|
||||
if (begin_[i] < 0 && i < input_shape_.size()) {
|
||||
begin_[i] = begin_[i] + input_shape_[i];
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < size_.size(); i++) {
|
||||
if (size_[i] < 0 && i < input_shape_.size()) {
|
||||
size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
void CheckParam(const CNodePtr &kernel_node) {
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num;
|
||||
}
|
||||
auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
if (input_shape.size() > kSliceGradMaxInputShapeSize) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than 7, but got "
|
||||
<< input_shape.size();
|
||||
}
|
||||
}
|
||||
void ProccessAttr(const std::vector<KernelTensorPtr> &inputs);
|
||||
// Declaration only; takes the data format as a mindspore::Format value and a target rank `dim`
// that defaults to kDim4. The definition is not in this header.
void CalcBeginAndSize(const mindspore::Format &data_format, size_t dim = kDim4);
|
||||
// Declaration only; receives the kernel's input and output tensors. The definition is not in this header.
// NOTE(review): presumably validates tensor counts/ranks - confirm against the definition.
void CheckParam(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs);
|
||||
|
||||
std::vector<int64_t> begin_;
|
||||
std::vector<int64_t> size_;
|
||||
|
@ -174,14 +67,13 @@ class SliceGradGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
ShapeVector input_shape_;
|
||||
ShapeVector dy_shape_;
|
||||
|
||||
bool is_strided_slice_;
|
||||
bool is_null_input_;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
bool is_dynamic_attr_{false};
|
||||
bool get_dynamic_attr_value_{false};
|
||||
std::string kernel_name_;
|
||||
}; // namespace kernel
|
||||
std::shared_ptr<cukernel::SliceGradAttr> attr_ptr_{nullptr};
|
||||
|
||||
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_SLICE_GRAD_GPU_KERNEL_H_
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SLICE_GRAD_HELPER_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SLICE_GRAD_HELPER_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/slice_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cukernel {
|
||||
constexpr size_t kSliceGradDefaultInputShapeSize = 4;
|
||||
constexpr size_t kSliceGradMaxInputShapeSize = 7;
|
||||
constexpr size_t kDim4 = 4;
|
||||
constexpr size_t kDim7 = 7;
|
||||
class SliceGradAttr : public GpuKernelAttrBase {
|
||||
public:
|
||||
SliceGradAttr() = default;
|
||||
~SliceGradAttr() override = default;
|
||||
std::vector<int64_t> begin;
|
||||
std::vector<int64_t> size;
|
||||
std::vector<int64_t> input_shape;
|
||||
int64_t output_num;
|
||||
};
|
||||
|
||||
template <typename T, typename S>
|
||||
class SliceGradHelperGpuKernel : public GpuKernelHelperBase {
|
||||
public:
|
||||
explicit SliceGradHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
|
||||
: GpuKernelHelperBase(kernel_name, device_id) {}
|
||||
|
||||
virtual ~SliceGradHelperGpuKernel() = default;
|
||||
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
|
||||
const std::vector<std::vector<int64_t>> &output_shapes) {
|
||||
ResetResource();
|
||||
input_size_ = sizeof(T);
|
||||
for (auto shape : attr_ptr_->input_shape) {
|
||||
input_size_ = input_size_ * static_cast<size_t>(shape);
|
||||
}
|
||||
size_t output_size = sizeof(T) * attr_ptr_->output_num;
|
||||
input_size_list_.push_back(output_size);
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(input_size_);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
|
||||
const std::vector<void *> &work_ptrs, void *stream_ptr) override {
|
||||
T *dy = nullptr;
|
||||
T *dx = nullptr;
|
||||
int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &dy);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &dx);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
FillDeviceArray(input_size_ / sizeof(T), dx, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
auto &input_shape = attr_ptr_->input_shape;
|
||||
auto &begin = attr_ptr_->begin;
|
||||
auto &size = attr_ptr_->size;
|
||||
if (input_shape.size() <= kSliceGradDefaultInputShapeSize) {
|
||||
CalSlice4DGrad(begin[0], begin[1], begin[2], begin[3], size[0], size[1], size[2], size[3], input_shape[0],
|
||||
input_shape[1], input_shape[2], input_shape[3], dy, dx,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
} else {
|
||||
CalSlice7DGrad(begin[0], begin[1], begin[2], begin[3], begin[4], begin[5], begin[6], size[0], size[1], size[2],
|
||||
size[3], size[4], size[5], size[6], input_shape[0], input_shape[1], input_shape[2], input_shape[3],
|
||||
input_shape[4], input_shape[5], input_shape[6], dy, dx,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
|
||||
attr_ptr_ = std::dynamic_pointer_cast<SliceGradAttr>(kernel_attr);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t input_size_;
|
||||
std::shared_ptr<SliceGradAttr> attr_ptr_{nullptr};
|
||||
};
|
||||
} // namespace cukernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SLICE_GRAD_HELPER_H_
|
Loading…
Reference in New Issue