From 790118a4e875d6d388e0e84db0fe1a55694f9252 Mon Sep 17 00:00:00 2001 From: hanhuifeng2020 Date: Sat, 19 Mar 2022 18:15:20 +0800 Subject: [PATCH] [Dynamic][Gpu]fix some dynamic operators and add test cases --- .../kernel/arrays/broadcast_to_gpu_kernel.h | 17 ++- .../other/dynamic_broadcastto_gpu_kernel.cc | 50 +++---- .../other/dynamic_broadcastto_gpu_kernel.h | 133 ------------------ .../kernel/other/dynamic_stitch_gpu_kernel.cc | 13 +- .../kernel/other/dynamic_stitch_gpu_kernel.h | 3 +- mindspore/core/ops/concat.cc | 4 +- tests/st/ops/gpu/test_dynamic_ops.py | 79 ++++++++++- 7 files changed, 123 insertions(+), 176 deletions(-) delete mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_broadcastto_gpu_kernel.h diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.h index 0e19c45ff5e..0efb435c413 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.h @@ -47,8 +47,8 @@ class BroadcastToGpuKernelMod : public NativeGpuKernelMod { } bool Init(const CNodePtr &kernel_node) override { kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); - auto input_shapes = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto output_shapes = common::AnfAlgo::GetOutputInferShape(kernel_node, 0); + auto input_shapes = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0); + auto output_shapes = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0); kernel_node_ = kernel_node; is_null_input_ = CHECK_SHAPE_NULL(input_shapes, kernel_name_, "input") || CHECK_SHAPE_NULL(output_shapes, kernel_name_, "output"); @@ -82,6 +82,15 @@ class BroadcastToGpuKernelMod : public NativeGpuKernelMod { return true; } + void ResetResource() noexcept override { + ResetSizeLists(); + for (size_t i = 0; i < SHAPE_SIZE; ++i) { + input_shape_[i] = 1; + output_shape_[i] = 1; + } + is_null_input_ = false; + } + protected: void InitSizeLists() override { input_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T)); @@ -89,8 +98,8 @@ class BroadcastToGpuKernelMod : public NativeGpuKernelMod { } private: - size_t input_shape_[4] = {1, 1, 1, 1}; - size_t output_shape_[4] = {1, 1, 1, 1}; + size_t input_shape_[SHAPE_SIZE] = {1, 1, 1, 1}; + size_t output_shape_[SHAPE_SIZE] = {1, 1, 1, 1}; bool is_null_input_ = false; std::string kernel_name_; }; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_broadcastto_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_broadcastto_gpu_kernel.cc index 841d63b5298..f26fabd5462 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_broadcastto_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_broadcastto_gpu_kernel.cc @@ -13,57 +13,57 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "plugin/device/gpu/kernel/other/dynamic_broadcastto_gpu_kernel.h" +#include "plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.h" namespace mindspore { namespace kernel { -MS_REG_GPU_KERNEL_TWO( +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), - DynamicBroadcastToGpuKernelMod, double, int64_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, double) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), - DynamicBroadcastToGpuKernelMod, float, int64_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, float) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), - DynamicBroadcastToGpuKernelMod, half, int64_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, half) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), - DynamicBroadcastToGpuKernelMod, int16_t, int64_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, int16_t) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), - DynamicBroadcastToGpuKernelMod, int32_t, int64_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, int32_t) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - DynamicBroadcastToGpuKernelMod, int64_t, int64_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, int64_t) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), - DynamicBroadcastToGpuKernelMod, double, int32_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, double) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), - DynamicBroadcastToGpuKernelMod, float, int32_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, float) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), - DynamicBroadcastToGpuKernelMod, half, int32_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, half) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), - DynamicBroadcastToGpuKernelMod, int16_t, int32_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, int16_t) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - DynamicBroadcastToGpuKernelMod, int32_t, int32_t) -MS_REG_GPU_KERNEL_TWO( + BroadcastToGpuKernelMod, int32_t) +MS_REG_GPU_KERNEL_ONE( DynamicBroadcastTo, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), - DynamicBroadcastToGpuKernelMod, int64_t, int32_t) + BroadcastToGpuKernelMod, int64_t) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_broadcastto_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_broadcastto_gpu_kernel.h deleted file mode 100644 index 2f77c99f8a5..00000000000 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_broadcastto_gpu_kernel.h +++ /dev/null @@ -1,133 +0,0 @@ -/** - * Copyright 2021-2022 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BRAODCASTTO_GPU_KERNEL_H_ -#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BRAODCASTTO_GPU_KERNEL_H_ - -#include -#include -#include -#include -#include -#include "plugin/device/gpu/kernel/gpu_kernel.h" -#include "plugin/device/gpu/kernel/gpu_kernel_factory.h" -#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh" - -namespace mindspore { -namespace kernel { -constexpr size_t SHAPE_SIZE = 4; -constexpr size_t kIndex2 = 2; -constexpr size_t kIndex3 = 3; -template -class DynamicBroadcastToGpuKernelMod : public NativeGpuKernelMod { - public: - DynamicBroadcastToGpuKernelMod() : shape_size_(0), is_null_input_(false) { ResetResource(); } - ~DynamicBroadcastToGpuKernelMod() = default; - - bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - auto cuda_stream = reinterpret_cast(stream_ptr); - auto data_addr = GetDeviceAddress(inputs, 0); - auto shape_addr = GetDeviceAddress(inputs, 1); - auto output_addr = GetDeviceAddress(outputs, 0); - - BroadcastTo(input_shape_[0], input_shape_[1], input_shape_[kIndex2], input_shape_[kIndex3], output_shape_[0], - output_shape_[1], output_shape_[kIndex2], output_shape_[kIndex3], data_addr, output_addr, cuda_stream); - real_output_shape_ = std::vector(input_size_list_[1] / sizeof(S), 0); - CHECK_CUDA_RET_WITH_EXCEPT( - kernel_node_, - cudaMemcpyAsync(&real_output_shape_[0], shape_addr, input_size_list_[1], cudaMemcpyDeviceToHost, cuda_stream), - "DynamicBroadcastTo copy real output shape value failed"); - return true; - } - bool Init(const CNodePtr &kernel_node) override { - kernel_node_ = kernel_node; - auto input_shapes = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0); - auto shape_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 1); - auto output_shapes = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0); - is_null_input_ = CHECK_NULL_INPUT(input_shapes) || CHECK_NULL_INPUT(output_shapes) || CHECK_NULL_INPUT(shape_shape); - if (is_null_input_) { - MS_LOG(WARNING) << "For 'BroadcastToGpuKernelMod', input or output is null"; - InitSizeLists(); - return true; - } - - if (input_shapes.size() > SHAPE_SIZE || output_shapes.size() > SHAPE_SIZE) { - MS_LOG(EXCEPTION) << "BroadcastTo operation does not support dim greater than " << SHAPE_SIZE; - } - - if (output_shapes.size() < input_shapes.size()) { - MS_LOG(EXCEPTION) << "The rank of BroadcastTo's output [" << output_shapes.size() - << "] cannot be smaller than the rank of the input [" << input_shapes.size() << "]."; - } - - shape_size_ = std::accumulate(shape_shape.begin(), shape_shape.end(), sizeof(S), std::multiplies()); - - size_t offset = output_shapes.size() - input_shapes.size(); - for (size_t i = 0; i < input_shapes.size(); i++) { - input_shape_[i + offset] = input_shapes[i]; - } - - for (size_t j = 0; j < output_shapes.size(); j++) { - output_shape_[j] = (output_shapes[j] > 0 ? output_shapes[j] : input_shapes[j]); - } - - InitSizeLists(); - is_need_updateop_ = true; - return true; - } - void ResetResource() noexcept override { - real_output_shape_.clear(); - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - for (size_t i = 0; i < SHAPE_SIZE; i++) { - input_shape_[i] = 1; - output_shape_[i] = 1; - } - } - void UpdateOp() override { - auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node_.lock(), 0); - std::vector output_shape; - std::transform(real_output_shape_.begin(), real_output_shape_.end(), std::back_inserter(output_shape), - [](const S &i) { return static_cast(i); }); - common::AnfAlgo::SetOutputInferTypeAndShape({data_type}, {output_shape}, kernel_node_.lock().get()); - MS_LOG(DEBUG) << "Run PostExecute for DynamicBroadcastTo, real output shape is " << output_shape; - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[kIndex2] * input_shape_[kIndex3] * - sizeof(T)); - input_size_list_.push_back(shape_size_); - output_size_list_.push_back(output_shape_[0] * output_shape_[1] * output_shape_[kIndex2] * output_shape_[kIndex3] * - sizeof(T)); - } - - private: - size_t shape_size_; - size_t input_shape_[SHAPE_SIZE] = {1, 1, 1, 1}; - size_t output_shape_[SHAPE_SIZE] = {1, 1, 1, 1}; - bool is_null_input_ = false; - std::vector real_output_shape_; -}; -} // namespace kernel -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BRAODCASTTO_GPU_KERNEL_H_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_stitch_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_stitch_gpu_kernel.cc index 3fef0ad2d2d..a59d70e758f 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_stitch_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_stitch_gpu_kernel.cc @@ -61,8 +61,10 @@ bool DynamicStitchKernelMod::Init(const CNodePtr &kernel_node) { } void DynamicStitchKernelMod::UpdateOp() { + CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast(stream_ptr_)), + "DynamicStitch cudaStreamSynchronized failed"); auto output_shape = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node_.lock(), 0); - output_shape[0] = max_index_ + 1; + output_shape[0] = IntToSize(max_index_) + 1; auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node_.lock(), n_); common::AnfAlgo::SetOutputInferTypeAndShape({data_type}, {output_shape}, kernel_node_.lock().get()); MS_LOG(DEBUG) << "Run PostExecute for dynamicstitch, real output shape is " << output_shape; @@ -79,6 +81,7 @@ void DynamicStitchKernelMod::InitSizeLists() { return; } bool DynamicStitchKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream) { auto cuda_stream = reinterpret_cast(stream); + stream_ptr_ = stream; auto max_index_dev = GetDeviceAddress(workspace, 0); auto output_addr = GetDeviceAddress(outputs, 0); // Init output and max_index with 0 @@ -94,11 +97,9 @@ bool DynamicStitchKernelMod::Launch(const std::vector &inputs, const CallStitch(index_addr, data_addr, output_addr, index_num, one_data_ele_num_ * data_type_size_, max_index_dev, cuda_stream); } - int temp = 0; - CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, - cudaMemcpyAsync(&temp, max_index_dev, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream), - "Copy max_index failed") - max_index_ = IntToSize(temp); + CHECK_CUDA_RET_WITH_EXCEPT( + kernel_node_, cudaMemcpyAsync(&max_index_, max_index_dev, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream), + "Copy max_index failed") return true; } } // namespace kernel diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_stitch_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_stitch_gpu_kernel.h index d0b445e2e29..36626828d1a 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_stitch_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/other/dynamic_stitch_gpu_kernel.h @@ -43,9 +43,10 @@ class DynamicStitchKernelMod : public NativeGpuKernelMod { private: size_t n_; size_t real_ele_num_; - size_t max_index_; + int max_index_; size_t one_data_ele_num_; size_t data_type_size_; + void *stream_ptr_; }; MS_REG_GPU_KERNEL(DynamicStitch, DynamicStitchKernelMod) diff --git a/mindspore/core/ops/concat.cc b/mindspore/core/ops/concat.cc index 901d0500c08..be921b06632 100644 --- a/mindspore/core/ops/concat.cc +++ b/mindspore/core/ops/concat.cc @@ -70,8 +70,8 @@ abstract::ShapePtr ConcatInferShape(const PrimitivePtr &primitive, const std::ve auto ret_max_shape = element0_max_shape; auto ret_min_shape = element0_min_shape; for (size_t i = 1; i < elements.size(); ++i) { - auto elementi_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMaxShape]; - auto elementi_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMinShape]; + auto elementi_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMaxShape]; + auto elementi_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMinShape]; ret_max_shape[axis] += elementi_max_shape[axis]; ret_min_shape[axis] += elementi_min_shape[axis]; } diff --git a/tests/st/ops/gpu/test_dynamic_ops.py b/tests/st/ops/gpu/test_dynamic_ops.py index f3daa6f34e3..195587863b3 100644 --- a/tests/st/ops/gpu/test_dynamic_ops.py +++ b/tests/st/ops/gpu/test_dynamic_ops.py @@ -222,9 +222,9 @@ def test_dynamic_reshape(): assert compare(output, output_cmp) -class ReduceSumNet(nn.Cell): +class ReduceSumInputAxisNet(nn.Cell): def __init__(self): - super(ReduceSumNet, self).__init__() + super(ReduceSumInputAxisNet, self).__init__() self.reduce = ops.ReduceSum() def construct(self, x, y): @@ -234,9 +234,9 @@ class ReduceSumNet(nn.Cell): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_dynamic_reduce_sum(): +def test_dynamic_reduce_sum_input_axis(): """ - Feature: Test ReduceSum. + Feature: Test ReduceSum with axis is input. Description: The shape of inputs is dynamic. Expectation: Assert that results are consistent with result of the numpy compute """ @@ -251,7 +251,7 @@ def test_dynamic_reduce_sum(): dataset = ds.GeneratorDataset(data_list, column_names, shuffle=False) dynamic_columns = {column_names[0]: [None, 256]} dataset.set_dynamic_columns(columns=dynamic_columns) - net = ReduceSumNet() + net = ReduceSumInputAxisNet() output = dynamic_shape_sink_process(net, dataset) # Currently, the parameter axis of ReduceSum operator is dynamic(tensor) is # not supported under the fixed shape, so numpy is used for comparison @@ -290,3 +290,72 @@ def test_dynamic_nop(): output = dynamic_shape_sink_process(net, dataset) output_cmp = fixed_shape_process(net, dataset) assert compare(output, output_cmp) + + +class ReduceSumNet(nn.Cell): + def __init__(self, axis=()): + super(ReduceSumNet, self).__init__() + self.reduce = ops.ReduceSum() + self.axis = axis + + def construct(self, x): + return self.reduce(x, self.axis) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dynamic_reduce_sum(): + """ + Feature: Test ReduceSum and its backward. + Description: The shape of inputs is dynamic. + Expectation: Assert that results are consistent with result of with fixed shape. + """ + dtype = np.float32 + data_list = [] + for i in [2, 96]: + data = [] + data.append(np.random.rand(i, 256).astype(dtype)) + data.append(np.array(1).astype(np.float32)) + data_list.append(tuple(data)) + column_names = get_columns(len(data_list[0])) + dataset = ds.GeneratorDataset(data_list, column_names, shuffle=False) + dynamic_columns = {column_names[0]: [None, 256]} + dataset.set_dynamic_columns(columns=dynamic_columns) + net = GradNetWrtX(ReduceSumNet()) + output = dynamic_shape_sink_process(net, dataset) + output_cmp = fixed_shape_process(net, dataset) + assert compare(output, output_cmp) + + +class AddNet(nn.Cell): + def construct(self, x, y): + return ops.add(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dynamic_add(): + """ + Feature: Test add and its backward. + Description: The shape of inputs is dynamic. + Expectation: Assert that results are consistent with result of with fixed shape. + """ + dtype = np.float32 + data_list = [] + for i in [2, 96]: + data = [] + data.append(np.random.rand(i, 256).astype(dtype)) + data.append(np.random.rand(i, 256).astype(dtype)) + data.append(np.random.rand(i, 256).astype(dtype)) + data_list.append(tuple(data)) + column_names = get_columns(len(data_list[0])) + dataset = ds.GeneratorDataset(data_list, column_names, shuffle=False) + dynamic_columns = {column_names[0]: [None, 256], column_names[1]: [ + None, 256], column_names[2]: [None, 256]} + dataset.set_dynamic_columns(columns=dynamic_columns) + net = GradNetWrtX(AddNet()) + output = dynamic_shape_sink_process(net, dataset) + output_cmp = fixed_shape_process(net, dataset) + assert compare(output, output_cmp)