[Dynamic][GPU] Fix some dynamic operators and add test cases
parent 86b4de336b
commit 790118a4e8
@@ -47,8 +47,8 @@ class BroadcastToGpuKernelMod : public NativeGpuKernelMod {
   }
   bool Init(const CNodePtr &kernel_node) override {
     kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
-    auto input_shapes = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    auto output_shapes = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
+    auto input_shapes = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
+    auto output_shapes = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0);
     kernel_node_ = kernel_node;
     is_null_input_ =
       CHECK_SHAPE_NULL(input_shapes, kernel_name_, "input") || CHECK_SHAPE_NULL(output_shapes, kernel_name_, "output");
@@ -82,6 +82,15 @@ class BroadcastToGpuKernelMod : public NativeGpuKernelMod {
     return true;
   }

+  void ResetResource() noexcept override {
+    ResetSizeLists();
+    for (size_t i = 0; i < SHAPE_SIZE; ++i) {
+      input_shape_[i] = 1;
+      output_shape_[i] = 1;
+    }
+    is_null_input_ = false;
+  }
+
  protected:
   void InitSizeLists() override {
     input_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[2] * input_shape_[3] * sizeof(T));
@@ -89,8 +98,8 @@ class BroadcastToGpuKernelMod : public NativeGpuKernelMod {
   }

  private:
-  size_t input_shape_[4] = {1, 1, 1, 1};
-  size_t output_shape_[4] = {1, 1, 1, 1};
+  size_t input_shape_[SHAPE_SIZE] = {1, 1, 1, 1};
+  size_t output_shape_[SHAPE_SIZE] = {1, 1, 1, 1};
   bool is_null_input_ = false;
   std::string kernel_name_;
 };
@@ -13,57 +13,57 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "plugin/device/gpu/kernel/other/dynamic_broadcastto_gpu_kernel.h"
+#include "plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.h"

 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
-  DynamicBroadcastToGpuKernelMod, double, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, double)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
-  DynamicBroadcastToGpuKernelMod, float, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, float)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
-  DynamicBroadcastToGpuKernelMod, half, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, half)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
-  DynamicBroadcastToGpuKernelMod, int16_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, int16_t)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
-  DynamicBroadcastToGpuKernelMod, int32_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, int32_t)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-  DynamicBroadcastToGpuKernelMod, int64_t, int64_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, int64_t)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
-  DynamicBroadcastToGpuKernelMod, double, int32_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, double)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
-  DynamicBroadcastToGpuKernelMod, float, int32_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, float)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
-  DynamicBroadcastToGpuKernelMod, half, int32_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, half)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
-  DynamicBroadcastToGpuKernelMod, int16_t, int32_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, int16_t)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  DynamicBroadcastToGpuKernelMod, int32_t, int32_t)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastToGpuKernelMod, int32_t)
+MS_REG_GPU_KERNEL_ONE(
   DynamicBroadcastTo,
   KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
-  DynamicBroadcastToGpuKernelMod, int64_t, int32_t)
+  BroadcastToGpuKernelMod, int64_t)
 } // namespace kernel
 } // namespace mindspore
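Note on the registration change: because the op now reuses the shared BroadcastToGpuKernelMod, the shape operand's integer type (int32 or int64) no longer selects a second template parameter; it only appears in the KernelAttr. That is why each MS_REG_GPU_KERNEL_TWO pair above collapses into a MS_REG_GPU_KERNEL_ONE entry instantiated on the data type alone. A minimal plain-C++ sketch of the difference (hypothetical LaunchOld/LaunchNew helpers, not the real macros):

    #include <cstdint>
    #include <iostream>

    // Old: one kernel instantiation per (data type, shape index type) pair.
    template <typename T, typename S>
    void LaunchOld() { std::cout << "T=" << sizeof(T) << " S=" << sizeof(S) << '\n'; }

    // New: the shape tensor's index type is handled uniformly at runtime,
    // so only the data type parameterizes the kernel.
    template <typename T>
    void LaunchNew() { std::cout << "T=" << sizeof(T) << '\n'; }

    int main() {
      LaunchOld<float, int64_t>();  // MS_REG_GPU_KERNEL_TWO style
      LaunchNew<float>();           // MS_REG_GPU_KERNEL_ONE style
      return 0;
    }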
@@ -1,133 +0,0 @@
-/**
- * Copyright 2021-2022 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BRAODCASTTO_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BRAODCASTTO_GPU_KERNEL_H_
-
-#include <memory>
-#include <string>
-#include <vector>
-#include <functional>
-#include <algorithm>
-#include "plugin/device/gpu/kernel/gpu_kernel.h"
-#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
-#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
-
-namespace mindspore {
-namespace kernel {
-constexpr size_t SHAPE_SIZE = 4;
-constexpr size_t kIndex2 = 2;
-constexpr size_t kIndex3 = 3;
-template <typename T, typename S>
-class DynamicBroadcastToGpuKernelMod : public NativeGpuKernelMod {
- public:
-  DynamicBroadcastToGpuKernelMod() : shape_size_(0), is_null_input_(false) { ResetResource(); }
-  ~DynamicBroadcastToGpuKernelMod() = default;
-
-  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
-              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    if (is_null_input_) {
-      return true;
-    }
-    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
-    auto data_addr = GetDeviceAddress<T>(inputs, 0);
-    auto shape_addr = GetDeviceAddress<S>(inputs, 1);
-    auto output_addr = GetDeviceAddress<T>(outputs, 0);
-
-    BroadcastTo(input_shape_[0], input_shape_[1], input_shape_[kIndex2], input_shape_[kIndex3], output_shape_[0],
-                output_shape_[1], output_shape_[kIndex2], output_shape_[kIndex3], data_addr, output_addr, cuda_stream);
-    real_output_shape_ = std::vector<S>(input_size_list_[1] / sizeof(S), 0);
-    CHECK_CUDA_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudaMemcpyAsync(&real_output_shape_[0], shape_addr, input_size_list_[1], cudaMemcpyDeviceToHost, cuda_stream),
-      "DynamicBroadcastTo copy real output shape value failed");
-    return true;
-  }
-  bool Init(const CNodePtr &kernel_node) override {
-    kernel_node_ = kernel_node;
-    auto input_shapes = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
-    auto shape_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 1);
-    auto output_shapes = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node, 0);
-    is_null_input_ = CHECK_NULL_INPUT(input_shapes) || CHECK_NULL_INPUT(output_shapes) || CHECK_NULL_INPUT(shape_shape);
-    if (is_null_input_) {
-      MS_LOG(WARNING) << "For 'BroadcastToGpuKernelMod', input or output is null";
-      InitSizeLists();
-      return true;
-    }
-
-    if (input_shapes.size() > SHAPE_SIZE || output_shapes.size() > SHAPE_SIZE) {
-      MS_LOG(EXCEPTION) << "BroadcastTo operation does not support dim greater than " << SHAPE_SIZE;
-    }
-
-    if (output_shapes.size() < input_shapes.size()) {
-      MS_LOG(EXCEPTION) << "The rank of BroadcastTo's output [" << output_shapes.size()
-                        << "] cannot be smaller than the rank of the input [" << input_shapes.size() << "].";
-    }
-
-    shape_size_ = std::accumulate(shape_shape.begin(), shape_shape.end(), sizeof(S), std::multiplies<size_t>());
-
-    size_t offset = output_shapes.size() - input_shapes.size();
-    for (size_t i = 0; i < input_shapes.size(); i++) {
-      input_shape_[i + offset] = input_shapes[i];
-    }
-
-    for (size_t j = 0; j < output_shapes.size(); j++) {
-      output_shape_[j] = (output_shapes[j] > 0 ? output_shapes[j] : input_shapes[j]);
-    }
-
-    InitSizeLists();
-    is_need_updateop_ = true;
-    return true;
-  }
-  void ResetResource() noexcept override {
-    real_output_shape_.clear();
-    input_size_list_.clear();
-    output_size_list_.clear();
-    workspace_size_list_.clear();
-    for (size_t i = 0; i < SHAPE_SIZE; i++) {
-      input_shape_[i] = 1;
-      output_shape_[i] = 1;
-    }
-  }
-  void UpdateOp() override {
-    auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node_.lock(), 0);
-    std::vector<size_t> output_shape;
-    std::transform(real_output_shape_.begin(), real_output_shape_.end(), std::back_inserter(output_shape),
-                   [](const S &i) { return static_cast<size_t>(i); });
-    common::AnfAlgo::SetOutputInferTypeAndShape({data_type}, {output_shape}, kernel_node_.lock().get());
-    MS_LOG(DEBUG) << "Run PostExecute for DynamicBroadcastTo, real output shape is " << output_shape;
-  }
-
- protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(input_shape_[0] * input_shape_[1] * input_shape_[kIndex2] * input_shape_[kIndex3] *
-                               sizeof(T));
-    input_size_list_.push_back(shape_size_);
-    output_size_list_.push_back(output_shape_[0] * output_shape_[1] * output_shape_[kIndex2] * output_shape_[kIndex3] *
-                                sizeof(T));
-  }
-
- private:
-  size_t shape_size_;
-  size_t input_shape_[SHAPE_SIZE] = {1, 1, 1, 1};
-  size_t output_shape_[SHAPE_SIZE] = {1, 1, 1, 1};
-  bool is_null_input_ = false;
-  std::vector<S> real_output_shape_;
-};
-} // namespace kernel
-} // namespace mindspore
-
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_DYNAMIC_BRAODCASTTO_GPU_KERNEL_H_
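For reference, the Init logic deleted above (and assumed to live on in the shared BroadcastToGpuKernelMod) right-aligns a low-rank input into fixed 4-slot shape arrays, leaving the leading slots at 1 so the 4D BroadcastTo launch sees broadcastable dimensions. A standalone C++ sketch of that alignment, simplified to pad against the 4-slot maximum rather than the output rank:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    int main() {
      constexpr size_t SHAPE_SIZE = 4;                  // same 4-slot limit as the kernel
      std::vector<size_t> input_shape = {8, 256};       // a rank-2 input
      size_t padded[SHAPE_SIZE] = {1, 1, 1, 1};         // unused leading dims stay 1
      size_t offset = SHAPE_SIZE - input_shape.size();  // right-align the real dims
      for (size_t i = 0; i < input_shape.size(); ++i) {
        padded[i + offset] = input_shape[i];
      }
      for (size_t d : padded) {
        std::cout << d << ' ';                          // prints: 1 1 8 256
      }
      std::cout << '\n';
      return 0;
    }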
@@ -61,8 +61,10 @@ bool DynamicStitchKernelMod::Init(const CNodePtr &kernel_node) {
 }

 void DynamicStitchKernelMod::UpdateOp() {
+  CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
+                             "DynamicStitch cudaStreamSynchronized failed");
   auto output_shape = AnfAlgo::GetOutputDeviceShapeAdaptively(kernel_node_.lock(), 0);
-  output_shape[0] = max_index_ + 1;
+  output_shape[0] = IntToSize(max_index_) + 1;
   auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node_.lock(), n_);
   common::AnfAlgo::SetOutputInferTypeAndShape({data_type}, {output_shape}, kernel_node_.lock().get());
   MS_LOG(DEBUG) << "Run PostExecute for dynamicstitch, real output shape is " << output_shape;
@@ -79,6 +81,7 @@ void DynamicStitchKernelMod::InitSizeLists() { return; }
 bool DynamicStitchKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                     const std::vector<AddressPtr> &outputs, void *stream) {
   auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+  stream_ptr_ = stream;
   auto max_index_dev = GetDeviceAddress<int>(workspace, 0);
   auto output_addr = GetDeviceAddress<unsigned char>(outputs, 0);
   // Init output and max_index with 0
@@ -94,11 +97,9 @@ bool DynamicStitchKernelMod::Launch(const std::vector<AddressPtr> &inputs, const
     CallStitch(index_addr, data_addr, output_addr, index_num, one_data_ele_num_ * data_type_size_, max_index_dev,
                cuda_stream);
   }
-  int temp = 0;
-  CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                             cudaMemcpyAsync(&temp, max_index_dev, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream),
-                             "Copy max_index failed")
-  max_index_ = IntToSize(temp);
+  CHECK_CUDA_RET_WITH_EXCEPT(
+    kernel_node_, cudaMemcpyAsync(&max_index_, max_index_dev, sizeof(int), cudaMemcpyDeviceToHost, cuda_stream),
+    "Copy max_index failed")
   return true;
 }
 } // namespace kernel
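Two details carry the DynamicStitch fix: max_index_ is now an int, so the sizeof(int) device-to-host copy lands in it directly instead of going through the old temp-then-IntToSize round trip, and UpdateOp synchronizes the stream before reading the value, since the cudaMemcpyAsync issued in Launch may not have completed yet. A minimal sketch of the corrected pattern (hypothetical ReadMaxIndex helper, plain CUDA runtime calls):

    #include <cuda_runtime.h>

    int ReadMaxIndex(const int *max_index_dev, cudaStream_t stream) {
      int max_index = 0;  // int, matching the sizeof(int) copy below
      cudaMemcpyAsync(&max_index, max_index_dev, sizeof(int), cudaMemcpyDeviceToHost, stream);
      cudaStreamSynchronize(stream);  // the async copy may not have landed yet
      return max_index;
    }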
@@ -43,9 +43,10 @@ class DynamicStitchKernelMod : public NativeGpuKernelMod {
  private:
   size_t n_;
   size_t real_ele_num_;
-  size_t max_index_;
+  int max_index_;
   size_t one_data_ele_num_;
   size_t data_type_size_;
+  void *stream_ptr_;
 };

 MS_REG_GPU_KERNEL(DynamicStitch, DynamicStitchKernelMod)
@@ -70,8 +70,8 @@ abstract::ShapePtr ConcatInferShape(const PrimitivePtr &primitive, const std::ve
   auto ret_max_shape = element0_max_shape;
   auto ret_min_shape = element0_min_shape;
   for (size_t i = 1; i < elements.size(); ++i) {
-    auto elementi_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMaxShape];
-    auto elementi_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(element0->BuildShape())[kMinShape];
+    auto elementi_max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMaxShape];
+    auto elementi_min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(elements[i]->BuildShape())[kMinShape];
     ret_max_shape[axis] += elementi_max_shape[axis];
     ret_min_shape[axis] += elementi_min_shape[axis];
   }
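The ConcatInferShape fix is an indexing bug: the loop re-read element 0's max/min shape on every iteration, so each element contributed element 0's extent along the concat axis instead of its own. A standalone C++ sketch of the corrected accumulation (hypothetical ConcatAxisShape helper):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> ConcatAxisShape(const std::vector<std::vector<int64_t>> &shapes, size_t axis) {
      std::vector<int64_t> ret = shapes[0];
      for (size_t i = 1; i < shapes.size(); ++i) {
        ret[axis] += shapes[i][axis];  // the bug summed shapes[0][axis] every time
      }
      return ret;
    }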
@@ -222,9 +222,9 @@ def test_dynamic_reshape():
     assert compare(output, output_cmp)


-class ReduceSumNet(nn.Cell):
+class ReduceSumInputAxisNet(nn.Cell):
     def __init__(self):
-        super(ReduceSumNet, self).__init__()
+        super(ReduceSumInputAxisNet, self).__init__()
         self.reduce = ops.ReduceSum()

     def construct(self, x, y):
@@ -234,9 +234,9 @@ class ReduceSumInputAxisNet(nn.Cell):
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
-def test_dynamic_reduce_sum():
+def test_dynamic_reduce_sum_input_axis():
     """
-    Feature: Test ReduceSum.
+    Feature: Test ReduceSum with axis as an input.
     Description: The shape of inputs is dynamic.
     Expectation: Assert that results are consistent with the numpy results.
     """
@@ -251,7 +251,7 @@ def test_dynamic_reduce_sum_input_axis():
     dataset = ds.GeneratorDataset(data_list, column_names, shuffle=False)
     dynamic_columns = {column_names[0]: [None, 256]}
     dataset.set_dynamic_columns(columns=dynamic_columns)
-    net = ReduceSumNet()
+    net = ReduceSumInputAxisNet()
     output = dynamic_shape_sink_process(net, dataset)
     # Currently, ReduceSum with a dynamic (tensor) axis is not supported under
     # the fixed shape, so numpy is used for comparison
@@ -290,3 +290,72 @@ def test_dynamic_nop():
     output = dynamic_shape_sink_process(net, dataset)
     output_cmp = fixed_shape_process(net, dataset)
     assert compare(output, output_cmp)
+
+
+class ReduceSumNet(nn.Cell):
+    def __init__(self, axis=()):
+        super(ReduceSumNet, self).__init__()
+        self.reduce = ops.ReduceSum()
+        self.axis = axis
+
+    def construct(self, x):
+        return self.reduce(x, self.axis)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_reduce_sum():
+    """
+    Feature: Test ReduceSum and its backward.
+    Description: The shape of inputs is dynamic.
+    Expectation: Assert that results are consistent with the results with fixed shape.
+    """
+    dtype = np.float32
+    data_list = []
+    for i in [2, 96]:
+        data = []
+        data.append(np.random.rand(i, 256).astype(dtype))
+        data.append(np.array(1).astype(np.float32))
+        data_list.append(tuple(data))
+    column_names = get_columns(len(data_list[0]))
+    dataset = ds.GeneratorDataset(data_list, column_names, shuffle=False)
+    dynamic_columns = {column_names[0]: [None, 256]}
+    dataset.set_dynamic_columns(columns=dynamic_columns)
+    net = GradNetWrtX(ReduceSumNet())
+    output = dynamic_shape_sink_process(net, dataset)
+    output_cmp = fixed_shape_process(net, dataset)
+    assert compare(output, output_cmp)
+
+
+class AddNet(nn.Cell):
+    def construct(self, x, y):
+        return ops.add(x, y)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_dynamic_add():
+    """
+    Feature: Test add and its backward.
+    Description: The shape of inputs is dynamic.
+    Expectation: Assert that results are consistent with the results with fixed shape.
+    """
+    dtype = np.float32
+    data_list = []
+    for i in [2, 96]:
+        data = []
+        data.append(np.random.rand(i, 256).astype(dtype))
+        data.append(np.random.rand(i, 256).astype(dtype))
+        data.append(np.random.rand(i, 256).astype(dtype))
+        data_list.append(tuple(data))
+    column_names = get_columns(len(data_list[0]))
+    dataset = ds.GeneratorDataset(data_list, column_names, shuffle=False)
+    dynamic_columns = {column_names[0]: [None, 256], column_names[1]: [
+        None, 256], column_names[2]: [None, 256]}
+    dataset.set_dynamic_columns(columns=dynamic_columns)
+    net = GradNetWrtX(AddNet())
+    output = dynamic_shape_sink_process(net, dataset)
+    output_cmp = fixed_shape_process(net, dataset)
+    assert compare(output, output_cmp)