forked from mindspore-Ecosystem/mindspore

Range op: fix segfault when iterating past the end of the output buffer; address review comments; change default max_output_length to 1000000; update the docstring; fix CI; rename max_output_length to maxlen
This commit is contained in:
parent 03e655f14a
commit 507cc4ab15

@@ -0,0 +1,54 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstdint>

#include "backend/kernel_compiler/gpu/arrays/dynamic_range_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Range,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      DynamicRangeGpuKernel, float)

MS_REG_GPU_KERNEL_ONE(Range,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddInputAttr(kNumberTypeFloat64)
                        .AddOutputAttr(kNumberTypeFloat64),
                      DynamicRangeGpuKernel, double)

MS_REG_GPU_KERNEL_ONE(Range,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddInputAttr(kNumberTypeInt32)
                        .AddOutputAttr(kNumberTypeInt32),
                      DynamicRangeGpuKernel, int32_t)

MS_REG_GPU_KERNEL_ONE(Range,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddInputAttr(kNumberTypeInt64)
                        .AddOutputAttr(kNumberTypeInt64),
                      DynamicRangeGpuKernel, int64_t)
}  // namespace kernel
}  // namespace mindspore

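The four registrations above cover float32, float64, int32, and int64 inputs on GPU. Note that the Python-side `check_dtype` added later in this commit only whitelists int32 and float32, so only those two are reachable end to end. A minimal sketch of exercising one of them (assuming a GPU device is available):

    import mindspore.common.dtype as mstype
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    # int32 inputs dispatch to the DynamicRangeGpuKernel<int32_t> registration above.
    range_op = P.Range()
    out = range_op(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32))
    print(out)  # expected: [2 3 4]
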
@@ -0,0 +1,121 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_

#include <cuda_runtime.h>

#include <vector>

#include "backend/kernel_compiler/gpu/cuda_impl/dynamic_range_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
template <typename T>
class DynamicRangeGpuKernel : public GpuKernel {
 public:
  DynamicRangeGpuKernel() { ResetResource(); }
  ~DynamicRangeGpuKernel() = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *range_start = GetDeviceAddress<T>(inputs, 0);
    T *range_end = GetDeviceAddress<T>(inputs, 1);
    T *range_delta = GetDeviceAddress<T>(inputs, 2);
    T *output_device_address = GetDeviceAddress<T>(outputs, 0);
    int64_t *output_shape_device_address = GetDeviceAddress<int64_t>(workspace, 0);

    stream_ptr_ = stream_ptr;

    CalRange(range_start, range_end, range_delta, output_device_address, output_shape_device_address,
             max_output_length_, reinterpret_cast<cudaStream_t>(stream_ptr));

    // use workspace[0] for actual output shape, we know it must be 1d
    CHECK_CUDA_RET_WITH_ERROR(c_node_ptr_,
                              cudaMemcpyAsync(&output_shape_, output_shape_device_address, sizeof(int64_t),
                                              cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_ptr)),
                              "Failed to copy gpu memory.");
    CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaDeviceSynchronize(), "cudaDeviceSyncFailed");

    return true;
  }

  void PostExecute() override {
    // required synchronize for PostExecute
    CHECK_CUDA_RET_WITH_EXCEPT(c_node_ptr_, cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_ptr_)),
                               "cudaStreamSynchronize failed");

    std::vector<TypeId> output_type = {AnfAlgo::GetOutputInferDataType(c_node_ptr_, 0)};
    std::vector<std::vector<size_t>> output_shape = {{static_cast<size_t>(output_shape_)}};
    AnfAlgo::SetOutputInferTypeAndShape(output_type, output_shape, c_node_ptr_.get());
  }

  void ResetResource() noexcept override {
    stream_ptr_ = nullptr;
    c_node_ptr_ = nullptr;
    output_shape_ = 0;
    max_output_length_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

  bool Init(const CNodePtr &kernel_node) override {
    size_t input_count = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_count != 3) {
      MS_LOG(ERROR) << input_count << " inputs were provided, but DynamicRangeGpuKernel expects 3.";
      return false;
    }

    max_output_length_ = GetAttr<int64_t>(kernel_node, "maxlen");
    c_node_ptr_ = kernel_node;
    InitSizeLists();

    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(sizeof(T));
    input_size_list_.push_back(sizeof(T));
    input_size_list_.push_back(sizeof(T));
    output_size_list_.push_back(max_output_length_ * sizeof(T));

    // this op outputs a 1d tensor, size of one int64_t is enough space to hold the shape.
    workspace_size_list_.push_back(sizeof(int64_t));
    return;
  }

 private:
  void *stream_ptr_;
  CNodePtr c_node_ptr_;
  int64_t output_shape_;
  int64_t max_output_length_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_DYNAMIC_RANGE_GPU_KERNEL_H_

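The flow above is worth spelling out: `InitSizeLists` allocates a worst-case buffer of `maxlen` elements, `Launch` writes the true element count into workspace[0] and copies it back to the host, and `PostExecute` shrinks the inferred output shape to that count. A rough host-side analogue in numpy (illustrative only; `dynamic_range` is a made-up name):

    import numpy as np

    def dynamic_range(start, end, delta, maxlen=1000000):
        # Worst-case allocation, as InitSizeLists does with max_output_length_ * sizeof(T).
        buf = np.empty(maxlen, dtype=np.float64)
        # Mirrors real_output_shape in the CUDA kernel.
        n = int(np.ceil((end - start) / delta))
        if n > maxlen:
            raise RuntimeError("output would exceed maxlen")  # the kernel traps instead
        buf[:n] = start + delta * np.arange(n)
        # PostExecute-style shrink: only the first n elements are the real output.
        return buf[:n]

    print(dynamic_range(2, 5, 1))  # [2. 3. 4.]
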
@@ -0,0 +1,76 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "dynamic_range_impl.cuh"

#include <cuda_runtime.h>

#include "runtime/device/gpu/cuda_common.h"

template <typename T>
__device__ void CheckInputs(const T &start, const T &end, const T &delta) {
  if (delta == 0) {
    asm("trap;");
  }

  if (start < end && delta < 0) {
    asm("trap;");
  }

  if (start > end && delta > 0) {
    asm("trap;");
  }
}

template <typename T>
__global__ void Range(const T *range_start, const T *range_end, const T *range_delta, T *output,
                      int64_t *output_shape, const int64_t max_output_size) {
  T start = range_start[0];
  T end = range_end[0];
  T delta = range_delta[0];

  CheckInputs(start, end, delta);

  int64_t real_output_shape = static_cast<int64_t>(ceil(static_cast<double>(end - start) / delta));
  if (real_output_shape > max_output_size) {
    asm("trap;");
  }
  *output_shape = real_output_shape;

  size_t gt_id = blockIdx.x * blockDim.x + threadIdx.x;
  for (; gt_id < real_output_shape; gt_id += blockDim.x * gridDim.x) {
    output[gt_id] = gt_id * delta + start;
  }
}

template <typename T>
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
              const int64_t max_output_size, cudaStream_t cuda_stream) {
  Range<<<GET_BLOCKS(max_output_size), GET_THREADS, 0, cuda_stream>>>(range_start, range_end, range_delta,
                                                                      output, output_shape, max_output_size);
}

template void CalRange<int>(const int *range_start, const int *range_end, const int *range_delta, int *output,
                            int64_t *output_shape, const int64_t max_output_size, cudaStream_t cuda_stream);

template void CalRange<int64_t>(const int64_t *range_start, const int64_t *range_end, const int64_t *range_delta,
                                int64_t *output, int64_t *output_shape, const int64_t max_output_size,
                                cudaStream_t cuda_stream);

template void CalRange<float>(const float *range_start, const float *range_end, const float *range_delta,
                              float *output, int64_t *output_shape, const int64_t max_output_size,
                              cudaStream_t cuda_stream);

template void CalRange<double>(const double *range_start, const double *range_end, const double *range_delta,
                               double *output, int64_t *output_shape, const int64_t max_output_size,
                               cudaStream_t cuda_stream);

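The element count is computed as ceil((end - start) / delta), which handles negative deltas as well, since numerator and denominator flip sign together. A quick plain-Python check of the formula against the expected outputs in the tests added below:

    import math

    def range_len(start, end, delta):
        return math.ceil((end - start) / delta)

    assert range_len(2, 5, 1) == 3        # [2, 3, 4]
    assert range_len(-24, 1, 4) == 7      # [-24, -20, -16, -12, -8, -4, 0]
    assert range_len(8, 1, -1) == 7       # [8, 7, 6, 5, 4, 3, 2]
    assert range_len(2.3, 5.5, 1.2) == 3  # [2.3, 3.5, 4.7]
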
@@ -0,0 +1,26 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_

#include <cuda_runtime.h>

template <typename T>
void CalRange(const T *range_start, const T *range_end, const T *range_delta, T *output, int64_t *output_shape,
              const int64_t max_output_size, cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_DYNAMIC_RANGE_CUH_

@@ -1,7 +1,7 @@
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -285,6 +285,8 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
                                      const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);

template <typename T>
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {

@@ -1,5 +1,5 @@
/**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -168,6 +168,7 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
  if (max_shape.empty()) {
    max_shape = shape->shape();
  }

  auto ids =
    std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(ids_shape, min_shape, max_shape));
  // Currently we choose the same data type as input for the idx.

@@ -186,6 +187,7 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p
  if (idx_max_shape.empty()) {
    idx_max_shape = shape->shape();
  }

  auto ids_idx = std::make_shared<AbstractTensor>(ids_idx_type, idx_shape);
  ids_idx->set_shape(std::make_shared<Shape>(idx_shape, idx_min_shape, idx_max_shape));
  // outputs: ids, ids_idx

@@ -951,5 +953,36 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
  ShapePtr output_shape = std::make_shared<Shape>(lengths_shape, lengths_shape_min, lengths_shape_max);
  return std::make_shared<AbstractTensor>(kBool, output_shape);
}

AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list) {
  const std::string &op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 3);
  AbstractTensorPtr range_start = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  AbstractTensorPtr range_end = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
  AbstractTensorPtr range_delta = CheckArg<AbstractTensor>(op_name, args_spec_list, 2);

  TypePtrList supported_types = {kInt64, kInt32, kFloat32, kFloat64};
  TypePtr range_start_type = CheckTensorDType(range_start, supported_types, "range_start input of Range should be %s");
  TypePtr range_end_type = CheckTensorDType(range_end, supported_types, "range_end input of Range should be %s");
  TypePtr range_delta_type = CheckTensorDType(range_delta, supported_types, "range_delta input of Range should be %s");

  // check that all 3 inputs share the same type
  if (!IsIdentidityOrSubclass(range_start_type, range_end_type) ||
      !IsIdentidityOrSubclass(range_end_type, range_delta_type)) {
    MS_LOG(EXCEPTION) << "All inputs must have the same type, but got: " << args_spec_list[0]->type_name() << ", "
                      << args_spec_list[1]->type_name() << ", and " << args_spec_list[2]->type_name();
  }

  ValuePtr max_output_length_ptr = primitive->GetAttr("maxlen");
  int64_t max_output_length = GetValue<int64_t>(max_output_length_ptr);
  ShapeVector output_shape = {Shape::SHP_ANY};
  ShapeVector min_shape = {1};
  ShapeVector max_shape = {max_output_length};
  ShapePtr shape = std::make_shared<Shape>(output_shape, min_shape, max_shape);

  return std::make_shared<AbstractTensor>(range_start_type, shape);
}
}  // namespace abstract
}  // namespace mindspore

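Since the true length is only known at runtime, the inferred output is a rank-1 tensor with an unknown dimension bounded by [1, maxlen]; downstream allocation sizes buffers from the max bound. A toy illustration of the bound logic (plain Python, hypothetical function name):

    def infer_range_shape(maxlen=1000000):
        SHP_ANY = -1               # mirrors Shape::SHP_ANY: dimension unknown until runtime
        output_shape = [SHP_ANY]   # rank-1, dynamic length
        min_shape = [1]            # at least one element
        max_shape = [maxlen]       # worst case, drives buffer allocation
        return output_shape, min_shape, max_shape
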
@@ -1,7 +1,7 @@
/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -81,6 +81,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
    {prim::kPrimMapUniform, {InferImplMapUniform, true}},
    {prim::kPrimSplit, {InferImplSplit, true}},
    {prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
    {prim::kPrimRange, {InferImplRange, true}},
    // Structure
    {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
    {prim::kPrimMakeList, {InferImplMakeList, true}},

@@ -1,5 +1,5 @@
/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -125,6 +125,7 @@ inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("Scat
inline const PrimitivePtr kPrimMapUniform = std::make_shared<Primitive>("MapUniform");
inline const PrimitivePtr kPrimSplit = std::make_shared<Primitive>("Split");
inline const PrimitivePtr kPrimSequenceMask = std::make_shared<Primitive>("SequenceMask");
inline const PrimitivePtr kPrimRange = std::make_shared<Primitive>("Range");

// NN
inline const PrimitivePtr kPrimFlatten = std::make_shared<Primitive>("Flatten");

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -33,7 +33,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentMax,
                        UnsortedSegmentProd, UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
-                        Unique, GatherD, Identity)
+                        Unique, GatherD, Identity, Range)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                       _MirrorOperator, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice,

@@ -402,6 +402,7 @@ __all__ = [
    "ReLUV2",
    "SparseToDense",
    "MatrixInverse",
    "Range",
]

__all__.sort()

@@ -1,6 +1,6 @@
# coding: utf-8

-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -4722,3 +4722,63 @@ class Identity(PrimitiveWithInfer):
                'dtype': x['dtype'],
                'value': None}
        return out


class Range(PrimitiveWithCheck):
    r"""
    Creates a sequence of numbers that begins at `start` and extends by increments of
    `delta` up to but not including `limit`.

    The types of all 3 inputs must be the same. The type of the resulting tensor is
    the same as the type of the inputs.

    Args:
        maxlen (int): Memory for `maxlen` elements will be allocated for the output.
            Optional, must be positive, defaults to 1000000. If the output has more
            than `maxlen` elements, a runtime error will occur.

    Inputs:
        - **start** (Tensor) - A scalar Tensor. The first number in the sequence. Must
          have type: int32 or float32.
        - **limit** (Tensor) - A scalar Tensor. Upper limit of the sequence, exclusive.
          Must have type: int32 or float32.
        - **delta** (Tensor) - A scalar Tensor. Number that increments `start`. Must
          have type: int32 or float32.

    Outputs:
        A 1-D Tensor, with the same type as the inputs.

    Examples:
        >>> start = Tensor(0, mstype.int32)
        >>> limit = Tensor(10, mstype.int32)
        >>> delta = Tensor(4, mstype.int32)
        >>> output = ops.Range()(start, limit, delta)
        >>> print(output)
        [0 4 8]

    Supported Platforms:
        ``Ascend`` ``GPU``
    """

    @prim_attr_register
    def __init__(self, maxlen=1000000):
        self.init_prim_io_names(inputs=['start', 'limit', 'delta'], outputs=['output'])
        validator.check_value_type("maxlen", maxlen, [int], self.name)
        validator.check_positive_int(maxlen, "maxlen", self.name)
        self.maxlen = maxlen
        self.add_prim_attr('maxlen', maxlen)

        # add_prim_attr overwrites any previous value for the same key, so all three
        # dynamic input indices must be registered in a single call.
        self.add_prim_attr("dynamic_shape_depends", [0, 1, 2])

    def check_shape(self, start_shape, limit_shape, delta_shape):
        validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
        validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
        validator.check("delta_shape", len(delta_shape), "", 0, Rel.EQ, self.name)

    def check_dtype(self, start_dtype, limit_dtype, delta_dtype):
        valid_dtypes = [mstype.int32, mstype.float32]
        inputs = {"start": start_dtype, "limit": limit_dtype, "delta": delta_dtype}
        validator.check_tensors_dtypes_same_and_valid(inputs, valid_dtypes, self.name)

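Because the output buffer is sized from `maxlen`, the default of 1000000 trades memory for safety; a tighter cap saves memory but turns oversized outputs into a runtime error. A hedged usage sketch:

    import mindspore.common.dtype as mstype
    from mindspore import Tensor
    from mindspore.ops import operations as P

    # Cap the output buffer at 100 elements instead of the 1000000 default.
    range_op = P.Range(maxlen=100)
    output = range_op(Tensor(0, mstype.int32), Tensor(10, mstype.int32), Tensor(2, mstype.int32))
    print(output)  # expected: [0 2 4 6 8]
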
@@ -0,0 +1,93 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.common.dtype as mstype
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class RangeNet(nn.Cell):
    def __init__(self):
        super(RangeNet, self).__init__()
        self.range = P.Range()

    def construct(self, s, e, d):
        return self.range(s, e, d)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_int():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    range_net = RangeNet()
    ms_out = range_net(Tensor(2, mstype.int32), Tensor(5, mstype.int32), Tensor(1, mstype.int32)).asnumpy()
    np_expected = np.array([2, 3, 4])
    np.testing.assert_array_equal(ms_out, np_expected)

    range_net = RangeNet()
    ms_out = range_net(Tensor(-24, mstype.int32), Tensor(1, mstype.int32), Tensor(4, mstype.int32)).asnumpy()
    np_expected = np.array([-24, -20, -16, -12, -8, -4, 0])
    np.testing.assert_array_equal(ms_out, np_expected)

    range_net = RangeNet()
    ms_out = range_net(Tensor(8, mstype.int32), Tensor(1, mstype.int32), Tensor(-1, mstype.int32)).asnumpy()
    np_expected = np.array([8, 7, 6, 5, 4, 3, 2])
    np.testing.assert_array_equal(ms_out, np_expected)

    range_net = RangeNet()
    ms_out = range_net(Tensor(3, mstype.int32), Tensor(-11, mstype.int32), Tensor(-5, mstype.int32)).asnumpy()
    np_expected = np.array([3, -2, -7])
    np.testing.assert_array_equal(ms_out, np_expected)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_float():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    range_net = RangeNet()
    ms_out = range_net(Tensor(2.3, mstype.float32), Tensor(5.5, mstype.float32), Tensor(1.2, mstype.float32)).asnumpy()
    np_expected = np.array([2.3, 3.5, 4.7])
    np.testing.assert_array_almost_equal(ms_out, np_expected)

    range_net = RangeNet()
    ms_out = range_net(Tensor(-4, mstype.float32), Tensor(-1, mstype.float32), Tensor(1.5, mstype.float32)).asnumpy()
    np_expected = np.array([-4.0, -2.5])
    np.testing.assert_array_almost_equal(ms_out, np_expected)

    range_net = RangeNet()
    ms_out = range_net(Tensor(8.0, mstype.float32), Tensor(1.0, mstype.float32), Tensor(-1.0, mstype.float32)).asnumpy()
    np_expected = np.array([8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0])
    np.testing.assert_array_almost_equal(ms_out, np_expected)

    range_net = RangeNet()
    ms_out = range_net(Tensor(1.5, mstype.float32), Tensor(-1, mstype.float32), Tensor(-18.9, mstype.float32)).asnumpy()
    np_expected = np.array([1.5])
    np.testing.assert_array_almost_equal(ms_out, np_expected)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_range_invalid_max_output_length():
    # Each invalid maxlen needs its own raises block; otherwise only the first
    # statement inside the context manager would ever execute.
    with pytest.raises(ValueError):
        _ = P.Range(0)
    with pytest.raises(ValueError):
        _ = P.Range(-1)
    # Non-int maxlen fails the type check rather than the positivity check.
    with pytest.raises(TypeError):
        _ = P.Range(None)
    with pytest.raises(TypeError):
        _ = P.Range('5')