!1782 GPU slice support null output

Merge pull request !1782 from VectorSL/slice-null
This commit is contained in:
mindspore-ci-bot 2020-06-02 15:54:36 +08:00 committed by Gitee
commit 39ff4572ce
1 changed files with 11 additions and 2 deletions

View File

@ -27,7 +27,8 @@ namespace kernel {
template <typename T>
class SliceGpuFwdKernel : public GpuKernel {
public:
SliceGpuFwdKernel() : is_strided_slice_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
SliceGpuFwdKernel()
: is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
~SliceGpuFwdKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@ -35,6 +36,9 @@ class SliceGpuFwdKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
if (is_strided_slice_) {
@ -79,7 +83,11 @@ class SliceGpuFwdKernel : public GpuKernel {
if (size_[i] < 0) {
size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
}
if (size_[i] == 0) {
if (begin_[i] == size_[i] && is_strided_slice_) {
MS_LOG(WARNING) << "Output is null.";
is_null_input_ = true;
}
if (size_[i] == 0 && strides_[i] > 0) {
size_[i] = begin_[i] + 1;
}
}
@ -143,6 +151,7 @@ class SliceGpuFwdKernel : public GpuKernel {
std::vector<size_t> workspace_size_list_;
bool is_strided_slice_;
bool is_null_input_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;