!1782 GPU slice support null output
Merge pull request !1782 from VectorSL/slice-null
This commit is contained in:
commit
39ff4572ce
|
@ -27,7 +27,8 @@ namespace kernel {
|
|||
template <typename T>
|
||||
class SliceGpuFwdKernel : public GpuKernel {
|
||||
public:
|
||||
SliceGpuFwdKernel() : is_strided_slice_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
|
||||
SliceGpuFwdKernel()
|
||||
: is_strided_slice_(false), is_null_input_(false), input_size_(0), output_size_(0), workspace_size_(0) {}
|
||||
~SliceGpuFwdKernel() override = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
|
@ -35,6 +36,9 @@ class SliceGpuFwdKernel : public GpuKernel {
|
|||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
if (is_strided_slice_) {
|
||||
|
@ -79,7 +83,11 @@ class SliceGpuFwdKernel : public GpuKernel {
|
|||
if (size_[i] < 0) {
|
||||
size_[i] = (size_[i] + input_shape_[i]) > 0 ? (size_[i] + input_shape_[i]) : 0;
|
||||
}
|
||||
if (size_[i] == 0) {
|
||||
if (begin_[i] == size_[i] && is_strided_slice_) {
|
||||
MS_LOG(WARNING) << "Output is null.";
|
||||
is_null_input_ = true;
|
||||
}
|
||||
if (size_[i] == 0 && strides_[i] > 0) {
|
||||
size_[i] = begin_[i] + 1;
|
||||
}
|
||||
}
|
||||
|
@ -143,6 +151,7 @@ class SliceGpuFwdKernel : public GpuKernel {
|
|||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
bool is_strided_slice_;
|
||||
bool is_null_input_;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t workspace_size_;
|
||||
|
|
Loading…
Reference in New Issue