Modify nn.Range for GPU.

This commit is contained in:
liuxiao93 2021-01-13 16:22:46 +08:00
parent 64b0a5a497
commit 7c419f2f5d
1 changed files with 1 additions and 12 deletions

View File

@ -15,7 +15,6 @@
"""math"""
import math
import numpy as np
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.tensor import Tensor
@ -128,7 +127,6 @@ class Range(Cell):
def __init__(self, start, limit=None, delta=1):
super(Range, self).__init__()
self.is_gpu = context.get_context("device_target") == "GPU"
validator.check_value_type("start", start, [int, float], self.cls_name)
validator.check_value_type("delta", delta, [int, float], self.cls_name)
if delta == 0:
@ -157,17 +155,8 @@ class Range(Cell):
length_input = math.ceil((limit - start) / delta)
self.input_tensor = Tensor(list(range(length_input)), self.dtype)
if self.is_gpu:
self.start = Tensor(start, self.dtype)
self.limit = Tensor(limit, self.dtype)
self.delta = Tensor(delta, self.dtype)
self.range_gpu = P.Range(length_input)
def construct(self):
if self.is_gpu:
range_out = self.range_gpu(self.start, self.limit, self.delta)
else:
range_out = self.range_x(self.input_tensor)
range_out = self.range_x(self.input_tensor)
return range_out