[DynamicShape][GPU] Add dynamic shape support for Concat and its backward for DCN

hanhuifeng2020 2022-02-22 10:38:32 +08:00
parent 81260a2319
commit 662c51c019
7 changed files with 492 additions and 49 deletions
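For a dynamic-shape op, the GPU kernel can no longer rely on sizes computed once at graph compile time: its cached state is cleared and Init() is re-run against the freshly inferred shapes before every launch, which is why the kernels below gain ResetResource() overrides and switch to GetInputRealDeviceShapeIfExist(). A minimal sketch of that per-step cycle, with the framework-side calls assumed rather than taken from this commit:

// Sketch only; the caller shown here is hypothetical.
dynamic_kernel->InferShape();       // re-infer shapes from the real inputs of this step
kernel_mod->ResetResource();        // clear size lists and per-step state
kernel_mod->Init(kernel_node);      // re-read shapes via GetInputRealDeviceShapeIfExist()
kernel_mod->Launch(inputs, workspaces, outputs, stream_ptr);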

View File

@@ -75,7 +75,7 @@ class ConcatV2FwdGpuKernelMod : public NativeGpuKernelMod {
if (!CheckParam(kernel_node)) {
return false;
}
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
int dims = SizeToInt(input_shape.size());
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis_ < -dims || axis_ >= dims) {
@@ -95,7 +95,7 @@ class ConcatV2FwdGpuKernelMod : public NativeGpuKernelMod {
int current_dim = 0;
for (int i = 0; i < input_num_; i++) {
size_t input_size = 1;
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i);
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, i);
for (size_t j = 0; j < input_shape.size(); j++) {
input_size *= input_shape[j];
}
@@ -128,6 +128,18 @@ class ConcatV2FwdGpuKernelMod : public NativeGpuKernelMod {
return true;
}
void ResetResource() noexcept override {
ResetSizeLists();
axis_ = 0;
input_num_ = 1;
output_size_ = 0;
all_size_before_axis_ = 1;
all_size_axis_ = 1;
kernel_name_ = "ConcatV2";
inputs_host_ = nullptr;
len_axis_ = nullptr;
}
protected:
void InitSizeLists() override {}

View File

@@ -18,21 +18,34 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
SliceFwdGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SliceFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SliceFwdGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SliceFwdGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SliceFwdGpuKernelMod, int32_t)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
SliceFwdGpuKernelMod, int16_t)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
SliceFwdGpuKernelMod, uchar)
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
SliceFwdGpuKernelMod, bool)
#define REG_SLICE_GPU(MS_DTYPE, DTYPE) \
MS_REG_GPU_KERNEL_ONE(Slice, KernelAttr().AddInputAttr(MS_DTYPE).AddOutputAttr(MS_DTYPE), SliceFwdGpuKernelMod, DTYPE)
#define REG_SLICE_GPU_DTYPES(F) \
F(kNumberTypeFloat64, double) \
F(kNumberTypeFloat32, float) \
F(kNumberTypeFloat16, half) \
F(kNumberTypeInt64, int64_t) \
F(kNumberTypeInt32, int32_t) \
F(kNumberTypeInt16, int16_t) \
F(kNumberTypeUInt8, uchar) \
F(kNumberTypeBool, bool)
REG_SLICE_GPU_DTYPES(REG_SLICE_GPU)
#define REG_DYNAMIC_SLICE_GPU_ATTR(T0_MS_DTYPE, T0_DTYPE, T1_MS_DTYPE) \
MS_REG_GPU_KERNEL_ONE(Slice, \
KernelAttr() \
.AddInputAttr(T0_MS_DTYPE) \
.AddInputAttr(T1_MS_DTYPE) \
.AddInputAttr(T1_MS_DTYPE) \
.AddOutputAttr(T0_MS_DTYPE), \
SliceFwdGpuKernelMod, T0_DTYPE)
#define REG_DYNAMIC_SLICE_GPU(MS_DTYPE, DTYPE) \
REG_DYNAMIC_SLICE_GPU_ATTR(MS_DTYPE, DTYPE, kNumberTypeInt32) \
REG_DYNAMIC_SLICE_GPU_ATTR(MS_DTYPE, DTYPE, kNumberTypeInt64)
REG_SLICE_GPU_DTYPES(REG_DYNAMIC_SLICE_GPU)
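// Illustrative expansion (not generated code): for float inputs,
// REG_DYNAMIC_SLICE_GPU(kNumberTypeFloat32, float) above amounts to
//   MS_REG_GPU_KERNEL_ONE(Slice,
//                         KernelAttr()
//                           .AddInputAttr(kNumberTypeFloat32)   // x
//                           .AddInputAttr(kNumberTypeInt32)     // begin, passed as a tensor input
//                           .AddInputAttr(kNumberTypeInt32)     // size, passed as a tensor input
//                           .AddOutputAttr(kNumberTypeFloat32),
//                         SliceFwdGpuKernelMod, float)
// plus the same registration with kNumberTypeInt64 for begin/size.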
} // namespace kernel
} // namespace mindspore

View File

@@ -14,9 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_SLICE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_SLICE_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <utility>
@@ -44,8 +43,7 @@ constexpr auto kIdx6 = 6;
template <typename T>
class SliceFwdGpuKernelMod : public NativeGpuKernelMod {
public:
SliceFwdGpuKernelMod()
: is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0), kernel_name_("Slice") {}
SliceFwdGpuKernelMod() { kernel_name_ = "Slice"; }
~SliceFwdGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
@@ -56,6 +54,9 @@ class SliceFwdGpuKernelMod : public NativeGpuKernelMod {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
if (is_dynamic_attr_ && !get_dynamic_attr_value_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', fail to get value of the dynamic attr!";
}
size_t input_rank = input_shape_.size();
switch (input_rank) {
@@ -105,8 +106,8 @@ class SliceFwdGpuKernelMod : public NativeGpuKernelMod {
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
(void)CheckParam(kernel_node);
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto out_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
auto out_shape = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
is_null_input_ =
CHECK_SHAPE_NULL(input_shape, kernel_name_, "input") || CHECK_SHAPE_NULL(out_shape, kernel_name_, "output");
if (is_null_input_) {
@@ -116,24 +117,36 @@ class SliceFwdGpuKernelMod : public NativeGpuKernelMod {
(void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(input_shape_),
[](const int64_t &e) { return static_cast<int32_t>(e); });
input_size_ = sizeof(T);
size_t input_size = sizeof(T);
for (size_t x : input_shape) {
input_size_ *= x;
input_size *= x;
}
input_size_list_.push_back(input_size);
if (is_dynamic_attr_) {
std::vector<size_t> dynamic_attr_indexs = {kBeginIndex_, kSizeIndex_};
for (size_t index : dynamic_attr_indexs) {
input_size = sizeof(T);
for (size_t x : AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, index)) {
input_size *= x;
}
input_size_list_.push_back(input_size);
}
}
output_size_ = sizeof(T);
size_t output_size = sizeof(T);
for (size_t x : out_shape) {
output_size_ *= x;
output_size *= x;
}
output_size_list_.push_back(output_size);
// transpose begin and size for NHWC data
auto data_format = AnfAlgo::GetInputFormat(kernel_node, 0);
if (data_format == "NHWC") {
if (data_format == kOpFormat_NHWC) {
std::swap(begin_[1], begin_[kIdx3]);
std::swap(begin_[1], begin_[kIdx2]);
std::swap(size_[1], size_[kIdx3]);
std::swap(size_[1], size_[kIdx2]);
} else if (data_format == "NDHWC") {
} else if (data_format == kOpFormat_NDHWC) {
std::swap(begin_[1], begin_[kIdx4]);
std::swap(begin_[1], begin_[kIdx3]);
std::swap(begin_[1], begin_[kIdx2]);
@@ -147,17 +160,30 @@ class SliceFwdGpuKernelMod : public NativeGpuKernelMod {
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
void ResetResource() noexcept override {
ResetSizeLists();
begin_.clear();
size_.clear();
input_shape_.clear();
is_null_input_ = false;
kernel_name_ = "Slice";
is_dynamic_attr_ = false;
get_dynamic_attr_value_ = false;
}
protected:
void InitSizeLists() override {}
private:
void CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be 1, but got " << input_num;
constexpr size_t kDynamicSliceInputNum = 3;
if (input_num != 1 && input_num != kDynamicSliceInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs should be 1 or " << kDynamicSliceInputNum
<< ", but got " << input_num;
}
if (input_num == kDynamicSliceInputNum) {
is_dynamic_attr_ = true;
}
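// In the three-input (dynamic) form, 'begin' and 'size' arrive as tensor inputs 1 and 2
// (kBeginIndex_ / kSizeIndex_ below) instead of node attributes.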
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
@@ -173,8 +199,19 @@ class SliceFwdGpuKernelMod : public NativeGpuKernelMod {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be equal to 0, but got "
<< input_shape.size();
}
auto size = GetAttr<std::vector<int64_t>>(kernel_node, "size");
auto begin = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
std::vector<int64_t> size, begin;
if (!is_dynamic_attr_) {
size = GetAttr<std::vector<int64_t>>(kernel_node, "size");
begin = GetAttr<std::vector<int64_t>>(kernel_node, "begin");
} else {
// The value of dynamic attr can only be obtained after the InferShape() of dynamic kernel is executed
if (DynamicKernel() == nullptr) {
return;
}
begin = GetDynamicAttrIntValue(kernel_node, kBeginIndex_);
size = GetDynamicAttrIntValue(kernel_node, kSizeIndex_);
get_dynamic_attr_value_ = true;
}
if (size.size() != input_shape.size() || begin.size() != input_shape.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
@@ -182,9 +219,9 @@ class SliceFwdGpuKernelMod : public NativeGpuKernelMod {
<< "of size: " << size.size() << ", the dimension of begin: " << begin.size()
<< ", the dimension of input_x: " << input_shape.size();
}
const int64_t kDynamicShape = -1;
const int64_t NEG_ONE = -1;
for (size_t i = 0; i < input_shape.size(); i++) {
if (size[i] == kDynamicShape) {
if (size[i] == NEG_ONE) {
size[i] = input_shape[i] - begin[i];
}
if (input_shape[i] <= 0 || size[i] <= 0) {
@@ -207,14 +244,13 @@ class SliceFwdGpuKernelMod : public NativeGpuKernelMod {
std::vector<int32_t> size_;
std::vector<int32_t> input_shape_;
bool is_null_input_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
std::string kernel_name_;
bool is_null_input_{false};
bool is_dynamic_attr_{false};
bool get_dynamic_attr_value_{false};
static constexpr size_t kBeginIndex_{1};
static constexpr size_t kSizeIndex_{2};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_SLICE_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_SLICE_GPU_KERNEL_H_

View File

@@ -80,6 +80,8 @@ class GpuDynamicKernel : public device::DynamicKernel {
void UpdateArgs() override;
void PostExecute() final { MS_LOG(EXCEPTION) << "`PostExecute()` should not be invoked with the GPU backend"; };
void Execute() final { MS_LOG(EXCEPTION) << "`Execute()` should not be invoked with the GPU backend"; }
std::map<uint32_t, tensor::TensorPtr> GetDependTensorMap() { return depend_tensor_map_; }
};
class NativeGpuKernelMod : public GpuKernelMod {
@@ -100,6 +102,12 @@ class NativeGpuKernelMod : public GpuKernelMod {
virtual void InitSizeLists() = 0;
std::weak_ptr<CNode> kernel_node_;
inline void ResetSizeLists() {
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
template <typename T>
inline T *GetDeviceAddress(const std::vector<AddressPtr> &addr_list, size_t index) {
if (index >= addr_list.size()) {
@@ -321,7 +329,66 @@ class NativeGpuKernelMod : public GpuKernelMod {
return type->second;
}
device::DynamicKernelPtr dynamic_kernel_;
inline std::map<uint32_t, tensor::TensorPtr> GetDependTensorMap() {
auto gpu_dynamic_kernel = dynamic_cast<GpuDynamicKernel *>(DynamicKernel().get());
if (gpu_dynamic_kernel != nullptr) {
return gpu_dynamic_kernel->GetDependTensorMap();
}
return {};
}
inline std::vector<int64_t> GetTensorIntValue(const tensor::TensorPtr input_tensor, const size_t input_index) {
std::vector<int64_t> tensor_value;
MS_EXCEPTION_IF_NULL(input_tensor);
size_t data_size = input_tensor->DataSize();
auto tensor_type = input_tensor->Dtype();
if (tensor_type->type_id() == kNumberTypeInt32) {
auto tensor_data = reinterpret_cast<int32_t *>(input_tensor->data_c());
MS_EXCEPTION_IF_NULL(tensor_data);
tensor_value.assign(tensor_data, tensor_data + data_size);
} else if (tensor_type->type_id() == kNumberTypeInt64) {
auto tensor_data = reinterpret_cast<int64_t *>(input_tensor->data_c());
MS_EXCEPTION_IF_NULL(tensor_data);
tensor_value.assign(tensor_data, tensor_data + data_size);
} else {
MS_EXCEPTION(TypeError) << "For '" << kernel_name_ << "', the " << input_index
<< "th input must be a Tensor[Int64] or Tensor[Int32] type, but got "
<< input_tensor->ToString();
}
return tensor_value;
}
inline bool ShapeEqual(const std::vector<size_t> &s1, const std::vector<int64_t> &s2) {
std::vector<size_t> s2_trans;
std::transform(s2.begin(), s2.end(), std::back_inserter(s2_trans), [](const int64_t &e) { return LongToSize(e); });
return std::equal(s1.begin(), s1.end(), s2_trans.begin(), s2_trans.end());
}
inline std::vector<int64_t> GetDynamicAttrIntValue(const CNodePtr &kernel_node, const size_t input_index) {
const auto &depend_tensor_map = GetDependTensorMap();
if (depend_tensor_map.empty()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the depend_tensor_map is empty!";
}
auto depend_iter = depend_tensor_map.find(input_index);
if (depend_iter == depend_tensor_map.end()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', fail to find the " << input_index
<< "th input in the depend_tensor_map";
}
auto input_tensor = depend_iter->second;
const auto &input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_index);
if (!ShapeEqual(input_shape, input_tensor->shape())) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the " << input_index
<< "th input is different between the InferShape and the TensorShape";
}
const auto &data_format = AnfAlgo::GetInputFormat(kernel_node, input_index);
if (data_format != kOpFormat_DEFAULT && data_format != kOpFormat_NCHW) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the format of the " << input_index
<< "th input currently should be the default format and does not support " << data_format;
}
return GetTensorIntValue(input_tensor, input_index);
}
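// Usage sketch (hypothetical call site, mirroring the Slice kernel above): a kernel with
// dynamic attrs reads them only after the dynamic kernel's InferShape() has run, e.g.
//   if (DynamicKernel() != nullptr) {
//     auto begin = GetDynamicAttrIntValue(kernel_node, 1);  // index of the 'begin' input
//     auto size = GetDynamicAttrIntValue(kernel_node, 2);   // index of the 'size' input
//   }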
device::DynamicKernelPtr dynamic_kernel_{nullptr};
};
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,47 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/other/concat_offset_gpu_kernel.h"
namespace mindspore {
namespace kernel {
#define REG_CONCAT_OFFSET_GPU_INPUT_OUTPUT(T0_MS_DTYPE, T0_DTYPE, T1_MS_DTYPE, T1_DTYPE) \
MS_REG_GPU_KERNEL_TWO(ConcatOffset, \
KernelAttr().AddAllSameAttr(true).AddInputAttr(T0_MS_DTYPE).AddOutputAttr(T1_MS_DTYPE), \
ConcatOffsetGpuKernelMod, T0_DTYPE, T1_DTYPE)
#define REG_CONCAT_OFFSET_GPU_INPUT(MS_DTYPE, DTYPE) \
REG_CONCAT_OFFSET_GPU_INPUT_OUTPUT(MS_DTYPE, DTYPE, kNumberTypeInt64, int64_t) \
REG_CONCAT_OFFSET_GPU_INPUT_OUTPUT(MS_DTYPE, DTYPE, kNumberTypeInt32, int32_t)
#define REG_CONCAT_OFFSET_GPU_INPUT_FLOAT(F) \
F(kNumberTypeFloat64, double) F(kNumberTypeFloat32, float) F(kNumberTypeFloat16, half)
#define REG_CONCAT_OFFSET_GPU_INPUT_INT(F) \
F(kNumberTypeInt64, int64_t) \
F(kNumberTypeInt32, int32_t) \
F(kNumberTypeInt16, int16_t) \
F(kNumberTypeInt8, char) \
F(kNumberTypeUInt64, uint64_t) \
F(kNumberTypeUInt32, uint32_t) \
F(kNumberTypeUInt16, uint16_t) \
F(kNumberTypeUInt8, uchar) \
F(kNumberTypeBool, bool)
REG_CONCAT_OFFSET_GPU_INPUT_FLOAT(REG_CONCAT_OFFSET_GPU_INPUT)
REG_CONCAT_OFFSET_GPU_INPUT_INT(REG_CONCAT_OFFSET_GPU_INPUT)
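// Illustrative expansion: REG_CONCAT_OFFSET_GPU_INPUT(kNumberTypeFloat32, float) above
// amounts to the two registrations
//   MS_REG_GPU_KERNEL_TWO(ConcatOffset,
//                         KernelAttr().AddAllSameAttr(true)
//                           .AddInputAttr(kNumberTypeFloat32)
//                           .AddOutputAttr(kNumberTypeInt64),
//                         ConcatOffsetGpuKernelMod, float, int64_t)
// and the same with kNumberTypeInt32 / int32_t for the output offsets.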
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,136 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_OTHER_CONCAT_OFFSET_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_OTHER_CONCAT_OFFSET_GPU_KERNEL_H_
#include <vector>
#include <string>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T, typename S>
class ConcatOffsetGpuKernelMod : public NativeGpuKernelMod {
public:
ConcatOffsetGpuKernelMod() {}
~ConcatOffsetGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
S *output_device_address = GetDeviceAddress<S>(outputs, 0);
size_t out_size = out_offset_.size() * sizeof(S);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(output_device_address, out_offset_.data(), out_size,
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync error in ConcatOffsetGpuKernelMod::Launch");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
if (!CheckParam(kernel_node)) {
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto rank = input_shape.size();
auto rank_int = SizeToInt(rank);
auto axis = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis < -rank_int || axis >= rank_int) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'axis' should be in the range [-" << rank << "," << rank
<< "), but got " << axis;
}
if (axis < 0) {
axis += rank_int;
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num == 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of input should be greater than 0";
}
for (size_t i = 0; i < input_num; i++) {
size_t input_size = 1;
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, i);
for (size_t j = 0; j < input_shape.size(); j++) {
input_size *= input_shape[j];
}
input_size_list_.push_back(input_size * sizeof(T));
}
// calculate the offset of each input along the concat axis
size_t shape_offset = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0)[axis];
std::vector<size_t> offset(input_num, 0);
for (size_t i = 1; i < input_num; i++) {
input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
if (input_shape.size() != rank) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << " the dimension of input should be equal, but got:"
<< " the dimension of the " << i << "'th input: " << input_shape.size()
<< " and the dimension of the first input: " << rank;
}
offset[i] = shape_offset;
shape_offset += input_shape[axis];
}
constexpr size_t kConcatOffsetOutputShapeSize = 2;
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
if (output_shape.size() != kConcatOffsetOutputShapeSize) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output should be "
<< kConcatOffsetOutputShapeSize << ", but got:" << output_shape.size();
}
if (output_shape[0] != input_num) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the first dimension value of output should be equal to "
"the number of input, but got the first dimension value of output: "
<< output_shape[0] << ", and the number of input: " << input_num;
}
if (output_shape[1] != rank) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the second dimension value of output should be equal to "
"the dimension of input, but got the second dimension value of output: "
<< output_shape[1] << ", and the dimension of input: " << rank;
}
auto output_size = input_num * rank;
out_offset_.assign(output_size, 0);
for (size_t i = 0; i < input_num; ++i) {
out_offset_[i * rank + axis] = offset[i];
}
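// Worked example: concatenating inputs of shapes (2, 16), (2, 32) and (2, 48) along
// axis = 1 yields offset = {0, 16, 48}, so out_offset_ = {0, 0, 0, 16, 0, 48} and the
// output tensor, viewed as (input_num, rank) = (3, 2), is [[0, 0], [0, 16], [0, 48]].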
output_size_list_.push_back(out_offset_.size() * sizeof(S));
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
ResetSizeLists();
out_offset_.clear();
}
protected:
void InitSizeLists() override {}
private:
bool CheckParam(const CNodePtr &kernel_node) {
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs should be 1, but got " << output_num;
}
return true;
}
std::vector<S> out_offset_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_OTHER_CONCAT_OFFSET_GPU_KERNEL_H_

View File

@@ -0,0 +1,132 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import ops, nn, ParameterTuple, context, set_seed
from mindspore.train import DatasetHelper, connect_network_with_dataset
import mindspore.dataset as ds
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
set_seed(2)
def _exec_preprocess(network, is_train, dataset, dataset_sink_mode, epoch_num, sink_size):
if dataset_sink_mode and not is_train:
dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(
dataset, dataset_sink_mode, sink_size, epoch_num)
if dataset_sink_mode:
network = connect_network_with_dataset(network, dataset_helper)
return dataset_helper, network
def dynamic_shape_sink_process(network, dataset, is_train=True):
dataset_sink_mode = True
sink_size = 1
epoch_num = 1
dataset_helper, network = _exec_preprocess(
network, is_train, dataset, dataset_sink_mode, epoch_num, sink_size)
network.set_train(is_train)
for inputs in dataset_helper:
outputs = network(*inputs)
return outputs
def fixed_shape_process(network, dataset, is_train=True):
network.set_train(is_train)
for inputs in dataset.create_tuple_iterator():
outputs = network(*inputs)
return outputs
def dataset_generator(data_list):
for data in data_list:
yield data
def get_columns(tensor_num):
columns = []
for i in range(tensor_num):
columns.append("data" + str(i))
return columns
def compare(output, expect):
if isinstance(output, (tuple, list)):
assert isinstance(expect, (tuple, list))
for output_, expect_ in zip(output, expect):
if not compare(output_, expect_):
return False
else:
if not np.allclose(output.asnumpy(), expect.asnumpy(), rtol=1.0e-4, atol=1.0e-4):
return False
return True
class GradNetWrtX(nn.Cell):
def __init__(self, net):
super(GradNetWrtX, self).__init__()
self.net = net
self.grad_op = ops.GradOperation(
get_all=True, get_by_list=True, sens_param=True)
self.params = ParameterTuple(net.trainable_params())
def construct(self, *inputs):
gradient_function = self.grad_op(self.net, self.params)
return gradient_function(*inputs)
class ConcatNet(nn.Cell):
def __init__(self, axis):
super(ConcatNet, self).__init__()
self.op = ops.Concat(axis)
def construct(self, x1, x2):
return self.op((x1, x2))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_dynamic_concat():
"""
Feature: Test Concat and its backward.
Description: The shape of inputs is dynamic.
Expectation: Assert that results are consistent with fixed shape.
"""
axis = 1
dtype = np.float32
data_list = []
for i in [2, 64]:
data = []
data.append(np.random.rand(i, 16).astype(dtype))
data.append(np.random.rand(i, 32).astype(dtype))
data.append(np.random.rand(i, 48).astype(dtype))
data_list.append(tuple(data))
column_names = get_columns(len(data_list[0]))
dataset = ds.GeneratorDataset(data_list, column_names, shuffle=False)
dataset.set_dynamic_columns(
columns={column_names[0]: [None, 16], column_names[1]: [None, 32], column_names[2]: [None, 48]})
net = GradNetWrtX(ConcatNet(axis))
gradients = dynamic_shape_sink_process(net, dataset)
gradients_cmp = fixed_shape_process(net, dataset)
assert compare(gradients, gradients_cmp)