forked from mindspore-Ecosystem/mindspore
!8861 [MS][GPU][DynamicShapeUpdate] Converting UnsortedSegmentMax from normal to dynamic shape op
From: @danishnxt Reviewed-by: Signed-off-by:
Commit: 442217314d
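Summary: this change registers dynamic-shape GPU kernels for UnsortedSegmentMax (num_segments becomes a third, int64 tensor input), adds the C++ shape-inference routine InferImplUnsortedSegmentMax, and moves the Python primitive from PrimitiveWithInfer to PrimitiveWithCheck. The sketch below is a minimal illustration of the dynamic-shape path; it mirrors the new GPU test at the end of the diff, assumes a GPU build, and the class name DynMaxNet is illustrative only. The internal inner.GpuConvertToDynamicShape op is used solely to force dynamically shaped inputs.

import numpy as np
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

class DynMaxNet(nn.Cell):
    def __init__(self, num_segments):
        super(DynMaxNet, self).__init__()
        self.op = P.UnsortedSegmentMax()
        # marks inputs as dynamically shaped so the dynamic-shape kernel path is exercised,
        # as in the tests added below
        self.to_dyn = inner.GpuConvertToDynamicShape()
        self.num_segments = num_segments

    def construct(self, x, ids):
        return self.op(self.to_dyn(x), self.to_dyn(ids), self.num_segments)

x = Tensor(np.arange(12, dtype=np.float32).reshape(4, 3), dtype=mindspore.float32)
ids = Tensor([0, 1, 1, 0], mstype.int32)
out = DynMaxNet(2)(x, ids)  # shape (2, 3): per-segment maxima of the rows of x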
@@ -22,15 +22,35 @@ MS_REG_GPU_KERNEL_ONE(
  UnsortedSegmentMax,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
  UnsortedSegmentMaxGpuKernel, float)

MS_REG_GPU_KERNEL_ONE(
  UnsortedSegmentMax,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
  UnsortedSegmentMaxGpuKernel, half)

MS_REG_GPU_KERNEL_ONE(
  UnsortedSegmentMax,
  KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  UnsortedSegmentMaxGpuKernel, int)
// Dynamic Mode
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeFloat32),
                      UnsortedSegmentMaxGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeFloat16),
                      UnsortedSegmentMaxGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(UnsortedSegmentMax,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt32),
                      UnsortedSegmentMaxGpuKernel, int)
}  // namespace kernel
}  // namespace mindspore

@@ -28,14 +28,7 @@ namespace kernel {
template <typename T>
class UnsortedSegmentMaxGpuKernel : public GpuKernel {
 public:
  UnsortedSegmentMaxGpuKernel()
      : num_segments_(1),
        inner_size_(1),
        outer_size_(1),
        input_size_(1),
        segment_ids_size_(1),
        output_size_(1),
        is_null_input_(false) {}
  UnsortedSegmentMaxGpuKernel() { ResetResource(); }
  ~UnsortedSegmentMaxGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }

@@ -60,18 +53,24 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
  }

  bool Init(const CNodePtr &kernel_node) override {
    auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    auto input_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
    is_null_input_ = CHECK_NULL_INPUT(input_shapes);
    if (is_null_input_) {
      MS_LOG(WARNING) << "UnsortedSegmentMax input is null";
      InitSizeLists();
      return true;
    }
    auto segment_ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);  // we get that from computation
    auto segment_ids_shapes = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
    auto output_shapes = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);

    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num == 3) {
      MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 3 - dynamic mode";
    } else {
      MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 2";
    }

    num_segments_ = output_shapes[0];

    input_size_ = 1;
    for (size_t i = 0; i < input_shapes.size(); i++) {
      input_size_ *= input_shapes[i];

@@ -97,6 +96,19 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
    return true;
  }

  void ResetResource() noexcept override {
    num_segments_ = 1;
    inner_size_ = 1;
    outer_size_ = 1;
    input_size_ = 1;
    segment_ids_size_ = 1;
    output_size_ = 1;
    is_null_input_ = false;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_ * sizeof(T));

@@ -113,6 +113,8 @@ AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr
                                  const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@@ -273,6 +273,74 @@ AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const Pri
  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
}

AbstractBasePtr InferImplUnsortedSegmentMax(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 3);
  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(x->shape());
  auto segment_ids = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  MS_EXCEPTION_IF_NULL(segment_ids);
  MS_EXCEPTION_IF_NULL(segment_ids->shape());
  auto segment_ids_shape = segment_ids->shape()->shape();
  (void)CheckTensorDType(x, {kFloat16, kFloat32, kInt32}, "Input 0 (x) for UnsortedSegmentMax should be %s");
  (void)CheckTensorDType(segment_ids, {kInt32}, "Input 1 (segment_ids) for UnsortedSegmentMax should be %s");
  // check if dynamic shape
  bool x_is_dyn = (!x->shape()->min_shape().empty() && !x->shape()->max_shape().empty());
  bool ids_is_dyn = (!segment_ids->shape()->min_shape().empty() && !segment_ids->shape()->max_shape().empty());
  bool op_is_dynamic = x_is_dyn && ids_is_dyn;
  auto x_shape = x->shape()->shape();
  ShapeVector shape;
  int64_t num_segments_value = 0;
  if (args_spec_list[2]->isa<AbstractTensor>()) {  // num_segments is Tensor
    auto num_segments = args_spec_list[2]->cast<AbstractTensorPtr>();
    MS_EXCEPTION_IF_NULL(num_segments);
    auto num_segments_value_ptr = num_segments->BuildValue();
    MS_EXCEPTION_IF_NULL(num_segments_value_ptr);
    auto num_segments_tensor = num_segments_value_ptr->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(num_segments_tensor);
    num_segments_value = *static_cast<int64_t *>(num_segments_tensor->data_c());
  } else if (args_spec_list[2]->isa<AbstractScalar>()) {  // num_segments is Scalar
    auto num_segments = CheckArg<AbstractScalar>(op_name, args_spec_list, 2);
    num_segments_value = GetValue<int64_t>(num_segments->BuildValue());
  } else {
    MS_LOG(EXCEPTION) << "num_segments incorrect type in UnsortedSegmentMax";
  }
  if (num_segments_value <= 0) {
    MS_LOG(EXCEPTION) << "num_segments must be > 0 in UnsortedSegmentMax";
  }
  shape.emplace_back(num_segments_value);
  shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end());
  if (!op_is_dynamic) {
    if (x_shape[0] != segment_ids_shape[0]) {
      MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
    }
    return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape));
  }
  // is dynamic
  ShapeVector min_shape;
  ShapeVector max_shape;
  min_shape.emplace_back(num_segments_value);
  max_shape.emplace_back(num_segments_value);
  // only run validation if shape values are known
  bool x_any_shape = std::any_of(x_shape.begin(), x_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
  bool ids_any_shape =
    std::any_of(segment_ids_shape.begin(), segment_ids_shape.end(), [](int64_t dim) { return dim == Shape::SHP_ANY; });
  if (!x_any_shape && !ids_any_shape) {
    if (x_shape[0] != segment_ids_shape[0]) {
      MS_LOG(EXCEPTION) << "Length of segment_ids must match first value of x shape UnsortedSegmentMax";
    }
  }
  ShapeVector x_shape_min;
  ShapeVector x_shape_max;
  x_shape_min = (x_is_dyn) ? x->shape()->min_shape() : x->shape()->shape();
  x_shape_max = (x_is_dyn) ? x->shape()->max_shape() : x->shape()->shape();
  min_shape.insert(min_shape.end(), x_shape_min.begin() + segment_ids_shape.size(), x_shape_min.end());
  max_shape.insert(max_shape.end(), x_shape_max.begin() + segment_ids_shape.size(), x_shape_max.end());
  return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
}

AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();

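Note (editorial): the inferred output shape above is [num_segments] followed by the trailing dimensions of x beyond the segment_ids rank. For example, with x of shape (4, 5, 3), 1-D segment_ids of length 4, and num_segments = 6, InferImplUnsortedSegmentMax yields (6, 5, 3); in the dynamic case the leading dimension stays fixed at num_segments while the trailing dimensions take their bounds from x's min_shape/max_shape.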
@@ -58,6 +58,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}},
    {prim::kPrimEmbeddingLookup, {InferImplEmbeddingLookup, true}},
    {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
    {prim::kPrimUnsortedSegmentMax, {InferImplUnsortedSegmentMax, true}},
    {prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
    {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
    {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},

@@ -1981,7 +1981,7 @@ class UnsortedSegmentMin(PrimitiveWithInfer):
        return out


class UnsortedSegmentMax(PrimitiveWithInfer):
class UnsortedSegmentMax(PrimitiveWithCheck):
    """
    Computes the maximum along segments of a tensor.

@@ -2017,27 +2017,21 @@ class UnsortedSegmentMax(PrimitiveWithInfer):
    def __init__(self):
        """Initialize UnsortedSegmentMax"""
        self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
        self.add_prim_attr("dynamic_shape_depends", [2])

    def __infer__(self, x, segment_ids, num_segments):
        x_type = x['dtype']
        x_shape = x['shape']
    def __check__(self, x, segment_ids, num_segments):
        segment_ids_shape = segment_ids['shape']
        valid_type = [mstype.float16, mstype.float32, mstype.int32]
        validator.check_tensor_dtype_valid("x", x['dtype'], valid_type, self.name)
        validator.check_tensors_dtypes_same_and_valid({"segment_ids": segment_ids['dtype']}, [mstype.int32], self.name)
        validator.check_equal_int(len(segment_ids_shape), 1, "rank of segment_ids_shape", self.name)
        validator.check(f'first shape of input_x', x_shape[0],
                        'length of segments_id', segment_ids_shape[0], Rel.EQ, self.name)
        num_segments_v = num_segments['value']
        validator.check_value_type('num_segments', num_segments_v, [int], self.name)
        validator.check_positive_int(num_segments_v, "num_segments", self.name)
        segment_ids_shape_len = len(segment_ids_shape)
        out_shape = [num_segments_v]
        out_shape += x_shape[segment_ids_shape_len:]
        out = {'shape': out_shape,
               'dtype': x_type,
               'value': None}
        return out
        num_segments_type = num_segments['dtype']
        validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name)
        if isinstance(num_segments_type, type(mstype.tensor)):
            validator.check_tensor_dtype_valid("num_segments", num_segments_type, [mstype.int64],
                                               self.name)
        else:
            validator.check_value_type('num_segments', num_segments['value'], [int], self.name)


class UnsortedSegmentProd(PrimitiveWithInfer):

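Note (editorial): a hedged illustration of what the revised __check__ accepts; the scalar form matches the existing tests, while the tensor form corresponds to the int64 num_segments input of the new dynamic kernel registrations (x and segment_ids are assumed to be suitably typed Tensors).

op = P.UnsortedSegmentMax()
out_a = op(x, segment_ids, 4)                           # num_segments as a Python int
out_b = op(x, segment_ids, Tensor(4, mstype.int64))     # num_segments as an int64 Tensor (dynamic path)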
@@ -20,6 +20,7 @@ import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

@@ -139,7 +140,7 @@ def test_3d_float32():
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init():
    context.set_context(device_target='GPU')
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input_x = Tensor(np.arange(
        4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
    segment_ids = Tensor([3, 0, 1, -1], mstype.int32)

@@ -202,3 +203,112 @@ def test_3d_single_init():
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output, expect)

# For testing Dynamic Shape operation
class UnsortedSegmentMaxDynNet(nn.Cell):
    def __init__(self, num_segments):
        super(UnsortedSegmentMaxDynNet, self).__init__()
        self.unsorted_segment_max = P.UnsortedSegmentMax()
        self.gpu_convert_to_dynamic_shape = inner.GpuConvertToDynamicShape()
        self.num_segments = num_segments

    def construct(self, data, ids):
        dyn_data = self.gpu_convert_to_dynamic_shape(data)
        dyn_ids = self.gpu_convert_to_dynamic_shape(ids)
        return self.unsorted_segment_max(dyn_data, dyn_ids, self.num_segments)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_float32_dyn():
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    input_x = Tensor(np.arange(
        4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
    segment_ids = Tensor([2, 1, 1, -1], mstype.int32)
    num_segments = 3
    net = UnsortedSegmentMaxDynNet(num_segments)
    output = net(input_x, segment_ids).asnumpy()
    expect = np.array([[[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
                       [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
                        [3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
                        [3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
                        [3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
                        [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
                       [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
                        [3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
                        [6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
                        [9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
                        [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output, expect)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_3d_single_init_dyn():
    context.set_context(device_target='GPU')
    input_x = Tensor(np.arange(
        4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3), dtype=mindspore.float32)
    segment_ids = Tensor([3, 0, 1, -1], mstype.int32)
    num_segments = 4
    net = UnsortedSegmentMaxDynNet(num_segments)
    output = net(input_x, segment_ids).asnumpy()
    expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
                        [1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
                        [2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
                        [2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
                        [2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
                       [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
                        [3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
                        [3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
                        [3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
                        [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
                       [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
                        [3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
                        [6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
                        [9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
                        [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output, expect)

    num_segments = 6
    net = UnsortedSegmentMaxDynNet(num_segments)
    output = net(input_x, segment_ids).asnumpy()
    expect = np.array([[[1.5000000e+01, 1.6000000e+01, 1.7000000e+01],
                        [1.8000000e+01, 1.9000000e+01, 2.0000000e+01],
                        [2.1000000e+01, 2.2000000e+01, 2.3000000e+01],
                        [2.4000000e+01, 2.5000000e+01, 2.6000000e+01],
                        [2.7000000e+01, 2.8000000e+01, 2.9000000e+01]],
                       [[3.0000000e+01, 3.1000000e+01, 3.2000000e+01],
                        [3.3000000e+01, 3.4000000e+01, 3.5000000e+01],
                        [3.6000000e+01, 3.7000000e+01, 3.8000000e+01],
                        [3.9000000e+01, 4.0000000e+01, 4.1000000e+01],
                        [4.2000000e+01, 4.3000000e+01, 4.4000000e+01]],
                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
                       [[0.0000000e+00, 1.0000000e+00, 2.0000000e+00],
                        [3.0000000e+00, 4.0000000e+00, 5.0000000e+00],
                        [6.0000000e+00, 7.0000000e+00, 8.0000000e+00],
                        [9.0000000e+00, 1.0000000e+01, 1.1000000e+01],
                        [1.2000000e+01, 1.3000000e+01, 1.4000000e+01]],
                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]],
                       [[-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38],
                        [-3.4028235e+38, -3.4028235e+38, -3.4028235e+38]]]).astype(np.float32)
    np.testing.assert_array_almost_equal(output, expect)