range past segfault addressed comments

changed default max_output_length to 1000000

change docstring

fix ci

change max_output_length to maxlen
This commit is contained in:
Peilin Wang 2021-01-05 23:49:05 -05:00
parent 03e655f14a
commit 507cc4ab15
11 changed files with 475 additions and 7 deletions

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include "backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Range,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
DynamicRangeGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Range,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64),
DynamicRangeGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Range,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
DynamicRangeGpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(Range,
KernelAttr()
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeInt64),
DynamicRangeGpuKernel, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,121 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_
#include <cuda_runtime.h>
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class DynamicRangeGpuKernel : public GpuKernel {
public:
DynamicRangeGpuKernel() { ResetResource(); }
~DynamicRangeGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *range_start = GetDeviceAddress<T>(inputs, 0);
T *range_end = GetDeviceAddress<T>(inputs, 1);
T *range_delta = GetDeviceAddress<T>(inputs, 2);
T *output_device_address = GetDeviceAddress<T>(outputs, 0);
int64_t *output_shape_device_address = GetDeviceAddress<int64_t>(workspace, 0);
stream_ptr_ = stream_ptr;
CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address,
max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr));
// use workspace[0] for actual output shape, we know it must be 1d
CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_,
cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t),
cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Failed to copy gpu memory.");
CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");
return true;
}
void PostExecute() override {
// required synchronize for PostExecute
CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaStreamSynchronize failed");
std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(c_node_ptr_, 0)};
std::vector<std::vector<size_t>> output_shape = {{(size_t)output_shape_}};
AnfAlgo::SetOutputInferTypeAndShape(output_type, output_shape, c_node_ptr_.get());
}
void ResetResource() noexcept override {
stream_ptr_ = nullptr;
c_node_ptr_ = nullptr;
output_shape_ = 0;
max_output_length_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_count != 3) {
MS_LOG(ERROR) << input_count << " inputs were provided, but DynamicRangeGpuKernel expects 3.";
return false;
}
max_output_length_ = GetAttr<int64_t>(kernel_node, "maxlen");
c_node_ptr_ = kernel_node;
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(sizeof(T));
input_size_list_.push_back(sizeof(T));
input_size_list_.push_back(sizeof(T));
output_size_list_.push_back(max_output_length_ * sizeof(T));
// this op outputs a 1d tensor, size of one int64_t is enough space to hold the shape.
workspace_size_list_.push_back(sizeof(int64_t));
return;
}
private:
void *stream_ptr_;
CNodePtr c_node_ptr_;
int64_t output_shape_;
int64_t max_output_length_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_

View File

@ -0,0 +1,76 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "dynamic_range_impl.cuh"
#include <cuda_runtime.h>
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__device__ void CheckInputs(const T &start, const T &end, const T &delta) {
if (delta == 0) {
asm("trap;");
}
if (start < end && delta < 0) {
asm("trap;");
}
if (start > end && delta > 0) {
asm("trap;");
}
}
template <typename T>
__global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output,
int64_t *output_shape, const int64_t max_output_size) {
T start = range_start[0];
T end = range_end[0];
T delta = range_delta[0];
CheckInputs(start, end, delta);
int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta));
if (real_output_shape > max_output_size) {
asm("trap;");
}
*output_shape = real_output_shape;
size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x;
for (; gt_id < real_output_shape; gt_id += blockDim.x * gridDim.x) {
output[gt_id] = gt_id * delta + start;
}
}
template <typename T>
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
const int64_t max_output_size, cudaStream_t cuda_stream) {
Range<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, cuda_stream>>>(range_start, range_end, range_delta,
output, output_shape, max_output_size);
}
template void CalRange<int>(const int *range_start, const int *range_end, const int *range_delta, int *output,
int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream);
template void CalRange<int64_t>(const int64_t *range_start, const int64_t *range_end, const int64_t *range_delta,
int64_t *output, int64_t *output_shape, const int64_t max_output_size,
cudaStream_t cuda_stream);
template void CalRange<float>(const float *range_start, const float *range_end, const float *range_delta, float *output,
int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream);
template void CalRange<double>(const double *range_start, const double *range_end, const double *range_delta,
double *output, int64_t *output_shape, const int64_t max_output_size,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_
#include <cuda_runtime.h>
template <typename T>
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
const int64_t max_output_size, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -285,6 +285,8 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -168,6 +168,7 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
if (max_shape.empty()) {
max_shape = shape->shape();
}
auto ids =
std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
// Currently we choose the same data type as input for the idx.
@ -186,6 +187,7 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
if (idx_max_shape.empty()) {
idx_max_shape = shape->shape();
}
auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, idx_shape);
ids_idx->set_shape(std::make_shared<Shape>(idx_shape, idx_min_shape, idx_max_shape));
// outputs: ids, ids_idx
@ -951,5 +953,36 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
ShapePtr output_shape = std::make_shared<Shape>(lengths_shape, lengths_shape_min, lengths_shape_max);
return std::make_shared<AbstractTensor>(kBool, output_shape);
}
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string &op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
AbstractTensorPtr range_delta = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);
TypePtrList supported_types = {kInt64, kInt32, kFloat32, kFloat64};
TypePtr range_start_type = CheckTensorDType(range_start, supported_types, "range_start input of Range should be %s");
TypePtr range_end_type = CheckTensorDType(range_end, supported_types, "range_start input of Range should be %s");
TypePtr range_delta_type = CheckTensorDType(range_delta, supported_types, "range_start input of Range should be %s");
// check all 3 inputs are same type
if (!IsIdentidityOrSubclass(range_start_type, range_end_type) ||
!IsIdentidityOrSubclass(range_end_type, range_delta_type)) {
MS_LOG(EXCEPTION) << "All inputs must have same type, but got: " << args_spec_list[0]->type_name() << ", "
<< args_spec_list[1]->type_name() << ", and " << args_spec_list[2]->type_name();
}
int64_t max_output_length = -1;
ValuePtr max_output_length_ptr = primitive->GetAttr("maxlen");
max_output_length = GetValue<int64_t>(max_output_length_ptr);
ShapeVector output_shape = {Shape::SHP_ANY};
ShapeVector min_shape = {1};
ShapeVector max_shape = {max_output_length};
ShapePtr shape = std::make_shared<Shape>(output_shape, min_shape, max_shape);
return std::make_shared<AbstractTensor>(range_start_type, shape);
}
} // namespace abstract
} // namespace mindspore

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -81,6 +81,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimMapUniform, {InferImplMapUniform, true}},
{prim::kPrimSplit, {InferImplSplit, true}},
{prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
{prim::kPrimRange, {InferImplRange, true}},
// Structure
{prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
{prim::kPrimMakeList, {InferImplMakeList, true}},

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -125,6 +125,7 @@ inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("Scat
inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform");
inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split");
inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask");
inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range");
// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");

View File

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
Unique, GatherD, Identity)
Unique, GatherD, Identity, Range)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice,
@ -402,6 +402,7 @@ __all__ = [
"ReLUV2",
"SparseToDense",
"MatrixInverse",
"Range",
]
__all__.sort()

View File

@ -1,6 +1,6 @@
# coding: utf-8
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -4722,3 +4722,63 @@ class Identity(PrimitiveWithInfer):
'dtype': x['dtype'],
'value': None}
return out
class Range(PrimitiveWithCheck):
r"""
Creates a sequence of numbers that begins at `start` and extends by increments of
`delta` up to but not including `limit`.
The types of all 3 inputs must be the same. The type of the resulting tensor is
the same as the type of the inputs.
Args:
maxlen (int): Memory that can fit `maxlen` many elements
will be allocated for the output. Optional, must be positive, defaults to 1000000.
If the output has more than `maxlen` elements, a runtime error
will occur.
Inputs:
- **start** (Tensor) - A scalar Tensor. The first number in the sequence. Must have
type: int32 or float32
- **limit** (Tensor) - A scalar Tensor. Upper limit of the sequence, exclusive. Must
have type: int32 or float32
- **delta** (Tensor) - A scalar Tensor. Number that increments `start`. Must have
type: int32 or float32
Outputs:
A 1-D Tensor, with the same type as the inputs.
Examples:
>>> start = Tensor(0)
>>> limit = Tensor(10)
>>> delta = Tensor(4)
>>> output = ops.Range()(start, limit, delta)
>>> print(output)
[0, 4, 8]
Supported Platforms:
``Ascend`` ``GPU``
"""
@prim_attr_register
def __init__(self, maxlen=1000000):
self.init_prim_io_names(inputs=['start', 'limit', 'delta'], outputs=['output'])
validator.check_value_type("maxlen", maxlen, [int], self.name)
validator.check_positive_int(maxlen, "maxlen", self.name)
self.maxlen = maxlen
self.add_prim_attr('maxlen', maxlen)
self.add_prim_attr("dynamic_shape_depends", [0])
self.add_prim_attr("dynamic_shape_depends", [1])
self.add_prim_attr("dynamic_shape_depends", [2])
def check_shape(self, start_shape, limit_shape, delta_shape):
validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)
def check_dtype(self, start_dtype, limit_dtype, delta_dtype):
valid_dtypes = [mstype.int32, mstype.float32]
inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)

View File

@ -0,0 +1,93 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
class RangeNet(nn.Cell):
def __init__(self):
super(RangeNet, self).__init__()
self.range = P.Range()
def construct(self, s, e, d):
return self.range(s, e, d)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_int():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
range_net = RangeNet()
ms_out = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy()
np_expected = np.array([2, 3, 4])
np.testing.assert_array_equal(ms_out, np_expected)
range_net = RangeNet()
ms_out = range_net(Tensor(-24, mstype.int32), Tensor(1, mstype.int32), Tensor(4, mstype.int32)).asnumpy()
np_expected = np.array([-24, -20, -16, -12, -8, -4, 0])
np.testing.assert_array_equal(ms_out, np_expected)
range_net = RangeNet()
ms_out = range_net(Tensor(8, mstype.int32), Tensor(1, mstype.int32), Tensor(-1, mstype.int32)).asnumpy()
np_expected = np.array([8, 7, 6, 5, 4, 3, 2])
np.testing.assert_array_equal(ms_out, np_expected)
range_net = RangeNet()
ms_out = range_net(Tensor(3, mstype.int32), Tensor(-11, mstype.int32), Tensor(-5, mstype.int32)).asnumpy()
np_expected = np.array([3, -2, -7])
np.testing.assert_array_equal(ms_out, np_expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_float():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
range_net = RangeNet()
ms_out = range_net(Tensor(2.3, mstype.float32), Tensor(5.5, mstype.float32), Tensor(1.2, mstype.float32)).asnumpy()
np_expected = np.array([2.3, 3.5, 4.7])
np.testing.assert_array_almost_equal(ms_out, np_expected)
range_net = RangeNet()
ms_out = range_net(Tensor(-4, mstype.float32), Tensor(-1, mstype.float32), Tensor(1.5, mstype.float32)).asnumpy()
np_expected = np.array([-4.0, -2.5])
np.testing.assert_array_almost_equal(ms_out, np_expected)
range_net = RangeNet()
ms_out = range_net(Tensor(8.0, mstype.float32), Tensor(1.0, mstype.float32), Tensor(-1.0, mstype.float32)).asnumpy()
np_expected = np.array([8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0])
np.testing.assert_array_almost_equal(ms_out, np_expected)
range_net = RangeNet()
ms_out = range_net(Tensor(1.5, mstype.float32), Tensor(-1, mstype.float32), Tensor(-18.9, mstype.float32)).asnumpy()
np_expected = np.array([1.5])
np.testing.assert_array_almost_equal(ms_out, np_expected)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_invalid_max_output_length():
with pytest.raises(ValueError):
_ = P.Range(0)
_ = P.Range(-1)
_ = P.Range(None)
_ = P.Range('5')