forked from mindspore-Ecosystem/mindspore
Modify nn.Range for GPU.
This commit is contained in:
parent
64b0a5a497
commit
7c419f2f5d
|
@ -15,7 +15,6 @@
|
|||
"""math"""
|
||||
import math
|
||||
import numpy as np
|
||||
import mindspore.context as context
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
@ -128,7 +127,6 @@ class Range(Cell):
|
|||
|
||||
def __init__(self, start, limit=None, delta=1):
|
||||
super(Range, self).__init__()
|
||||
self.is_gpu = context.get_context("device_target") == "GPU"
|
||||
validator.check_value_type("start", start, [int, float], self.cls_name)
|
||||
validator.check_value_type("delta", delta, [int, float], self.cls_name)
|
||||
if delta == 0:
|
||||
|
@ -157,17 +155,8 @@ class Range(Cell):
|
|||
length_input = math.ceil((limit - start) / delta)
|
||||
self.input_tensor = Tensor(list(range(length_input)), self.dtype)
|
||||
|
||||
if self.is_gpu:
|
||||
self.start = Tensor(start, self.dtype)
|
||||
self.limit = Tensor(limit, self.dtype)
|
||||
self.delta = Tensor(delta, self.dtype)
|
||||
self.range_gpu = P.Range(length_input)
|
||||
|
||||
def construct(self):
|
||||
if self.is_gpu:
|
||||
range_out = self.range_gpu(self.start, self.limit, self.delta)
|
||||
else:
|
||||
range_out = self.range_x(self.input_tensor)
|
||||
range_out = self.range_x(self.input_tensor)
|
||||
return range_out
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue