forked from mindspore-Ecosystem/mindspore

remove the loop_can_unroll flag, clean up some Python usage

parent 43567f9b9f
commit 88e864a4a3
@@ -29,7 +29,6 @@ const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__";
 // flag names
 const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16";
 const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
-const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
 const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
 const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
 const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
@@ -30,7 +30,6 @@ extern const char PYTHON_DATACLASS_FIELDS[];
 
 extern const char GRAPH_FLAG_MIX_PRECISION_FP16[];
 extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
-extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
 extern const char GRAPH_FLAG_HAS_EFFECT[];
 extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
 extern const char GRAPH_FLAG_RANDOM_EFFECT[];
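Note: these two hunks delete the C++ definition and declaration of the
loop_can_unroll graph flag. For context, the flag reached graphs from Python in
three forms, each removed by later hunks in this same diff (sketch only; none
of these compile once the flag is gone):

    cell.add_flags(loop_can_unroll=True)    # set on a Cell instance
    @add_flags(loop_can_unroll=True)        # set on a construct() method
    @core(loop_can_unroll=True)             # set on an ms_function body
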
@@ -286,7 +286,6 @@ class ClipByNorm(Cell):
         self.select_ = P.Select()
         self.greater_ = P.Greater()
         self.cast = P.Cast()
-        self.zero = Tensor(np.array([0.0]).astype(np.float32))
         self.sqrt = P.Sqrt()
         self.max_op = P.Maximum()
         self.shape = P.Shape()
@@ -300,7 +299,7 @@ class ClipByNorm(Cell):
         """add ms_function decorator for pynative mode"""
         mul_x = F.square(x)
         l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
-        cond = self.greater_(l2sum, self.zero)
+        cond = self.greater_(l2sum, 0)
         ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
 
         l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
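The hunk above drops ClipByNorm's pre-built self.zero tensor attribute in
favor of a Python scalar operand. A minimal standalone sketch of the pattern,
assuming PyNative mode where primitives are directly callable (the values are
illustrative):

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    greater = P.Greater()
    l2sum = Tensor(np.array([9.0]).astype(np.float32))
    cond = greater(l2sum, 0)   # scalar 0 broadcasts; no Tensor attribute needed
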
@@ -407,11 +406,13 @@ class OneHot(Cell):
         super(OneHot, self).__init__()
         self.onehot = P.OneHot(axis)
         self.depth = depth
-        self.on_value = Tensor(on_value, dtype)
-        self.off_value = Tensor(off_value, dtype)
+        self.dtype = dtype
+        self.on_value = on_value
+        self.off_value = off_value
 
     def construct(self, indices):
-        return self.onehot(indices, self.depth, self.on_value, self.off_value)
+        return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype))
 
 
 class Pad(Cell):
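With this hunk, nn.OneHot stores on_value/off_value as Python scalars and
casts them with F.cast at call time, rather than baking Tensors into
__init__. A hedged usage sketch (the depth and indices are assumed for
illustration):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    onehot = nn.OneHot(depth=4)                              # scalar on/off values
    out = onehot(Tensor(np.array([1, 3]).astype(np.int32)))  # one-hot of shape (2, 4)
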
@@ -133,7 +133,8 @@ class LSTM(Cell):
         self.transpose2 = P.Transpose()
         num_directions = 2 if self.bidirectional else 1
         self.cpu_target = False
-        if context.get_context("device_target") == "CPU":
+        enable_debug = context.get_context("enable_debug_runtime")
+        if context.get_context("device_target") == "CPU" and not enable_debug:
             self.cpu_target = True
         if not self.cpu_target:
             self.lstm = P.LSTM(input_size=self.input_size,
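The LSTM hunk adds one more gate to the CPU fast path: the dedicated CPU
kernel is skipped when the debug runtime is active. A sketch of the gating
logic, assuming an already-configured MindSpore context:

    from mindspore import context

    enable_debug = context.get_context("enable_debug_runtime")
    cpu_target = context.get_context("device_target") == "CPU" and not enable_debug
    # only when cpu_target is True does LSTM take the CPU-specific branch
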
@@ -141,7 +141,7 @@ class Optimizer(Cell):
         if self.is_group_lr:
             self.learning_rate = ParameterTuple(self.group_lr)
         else:
-            self.learning_rate = Parameter(learning_rate, name="learning_rate")
+            self.learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name="learning_rate")
 
         if self.is_group:
             self.parameters = ParameterTuple(self.group_params)
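Here the scalar learning rate is wrapped in an explicit float32 Tensor before
becoming a Parameter, rather than relying on implicit conversion. A minimal
sketch (the 0.1 value is illustrative):

    from mindspore import Parameter, Tensor
    import mindspore.common.dtype as mstype

    lr = Parameter(Tensor(0.1, mstype.float32), name="learning_rate")
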
@@ -1104,7 +1104,6 @@ class TransformerModel(nn.Cell):
                                            beam_width=config.beam_width,
                                            length_penalty_weight=config.length_penalty_weight,
                                            max_decode_length=config.max_decode_length)
-        self.tfm_decoder.add_flags(loop_can_unroll=True)
 
         self.cast = P.Cast()
         self.dtype = config.dtype
@@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
     def __init__(self, length, max_relative_position):
         super(RelaPosMatrixGenerator, self).__init__()
         self._length = length
-        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
-        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
+        self._max_relative_position = max_relative_position
+        self._min_relative_position = -max_relative_position
         self.range_length = -length + 1
 
         self.tile = P.Tile()
@@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                 max_relative_position=max_relative_position)
         self.reshape = P.Reshape()
-        self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.one_hot = nn.OneHot(depth=self.vocab_size)
         self.shape = P.Shape()
         self.gather = P.GatherV2()  # index_select
         self.matmul = P.BatchMatMul()
@@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         if self.use_one_hot_embeddings:
             flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
             one_hot_relative_positions_matrix = self.one_hot(
-                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
+                flat_relative_positions_matrix)
             embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
             my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
             embeddings = self.reshape(embeddings, my_shape)
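These two hunks swap the raw P.OneHot primitive, which needs depth and on/off
tensors at every call, for an nn.OneHot cell configured once in __init__. A
sketch contrasting the two call shapes (depth 3 and the indices are assumed
example values):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Tensor
    from mindspore.ops import operations as P

    idx = Tensor(np.array([0, 2]).astype(np.int32))
    old = P.OneHot()(idx, 3, Tensor(1.0, mstype.float32), Tensor(0.0, mstype.float32))
    new = nn.OneHot(depth=3)(idx)   # configuration captured at construction
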
@@ -372,11 +370,9 @@ class SaturateCast(nn.Cell):
     def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
         super(SaturateCast, self).__init__()
         np_type = mstype.dtype_to_nptype(dst_type)
-        min_type = np.finfo(np_type).min
-        max_type = np.finfo(np_type).max
 
-        self.tensor_min_type = Tensor([min_type], dtype=src_type)
-        self.tensor_max_type = Tensor([max_type], dtype=src_type)
+        self.tensor_min_type = float(np.finfo(np_type).min)
+        self.tensor_max_type = float(np.finfo(np_type).max)
 
         self.min_op = P.Minimum()
         self.max_op = P.Maximum()
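SaturateCast now keeps its clamp bounds as Python floats; P.Minimum and
P.Maximum broadcast them against the input. A standalone sketch of the
saturation step, assuming a float16 destination type:

    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    lo = float(np.finfo(np.float16).min)
    hi = float(np.finfo(np.float16).max)
    x = Tensor(np.array([1e6, -1e6]).astype(np.float32))
    clamped = P.Minimum()(P.Maximum()(x, lo), hi)   # saturate before the downcast
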
@@ -442,7 +438,7 @@ class BertAttention(nn.Cell):
         self.has_attention_mask = has_attention_mask
         self.use_relative_positions = use_relative_positions
 
-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.reshape = P.Reshape()
         self.shape_from_2d = (-1, from_tensor_width)
         self.shape_to_2d = (-1, to_tensor_width)
@@ -471,7 +467,7 @@ class BertAttention(nn.Cell):
         self.trans_shape = (0, 2, 1, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
         self.trans_shape_position = (1, 2, 0, 3)
-        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
+        self.multiply_data = -10000.0
         self.batch_num = batch_size * num_attention_heads
         self.matmul = P.BatchMatMul()
 
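The same scalar-instead-of-Tensor cleanup applies to BertAttention's two
constants: the attention scale and the additive mask fill value become plain
floats that broadcast in later multiplies. Sketch (a size_per_head of 64 is an
assumed example value):

    import math

    scores_mul = 1.0 / math.sqrt(float(64))   # was Tensor([...], dtype=compute_type)
    multiply_data = -10000.0                  # was Tensor([-10000.0,], dtype=compute_type)
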
@@ -17,7 +17,6 @@
 import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore.ops.composite import add_flags
 from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
     DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
 
@@ -122,7 +121,6 @@ class ASPP(nn.Cell):
         self.feature_shape = feature_shape
         self.concat = P.Concat(axis=1)
 
-    @add_flags(loop_can_unroll=True)
     def construct(self, x, scale_index=0):
         aspp0 = self.aspp0(x)
         aspp1 = self.global_poolings[scale_index](x)
@@ -275,8 +275,6 @@ class TransformerInferModel(nn.Cell):
                                     length_penalty_weight=config.length_penalty_weight,
                                     max_decode_length=config.max_decode_length)
 
-        self.decoder.add_flags(loop_can_unroll=True)
-
         self.cast = P.Cast()
         self.dtype = config.dtype
         self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
@@ -108,7 +108,7 @@ class BertAttentionRelativePositionKeys(nn.Cell):
         self.trans_shape_position = (1, 2, 0, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
 
-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=dtype)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
 
         self.reshape = P.Reshape()
         self.multiply = P.Mul()
@@ -301,7 +301,7 @@ class BertAttentionRelativePositionValues(nn.Cell):
         self.trans_shape_position = (1, 2, 0, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
 
-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=dtype)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.trans_shape = (0, 2, 1, 3)
 
         self.reshape = P.Reshape()
@@ -276,7 +276,7 @@ class SingleDeepLabV3(nn.Cell):
                              atrous_rates=atrous_rates,
                              output_stride=output_stride,
                              fine_tune_batch_norm=fine_tune_batch_norm)
-        self.aspp.add_flags(loop_can_unroll=True)
+
         atrous_rates_len = 0
         if atrous_rates is not None:
             atrous_rates_len = len(atrous_rates)
@@ -259,7 +259,7 @@ class NormalKl(nn.Cell):
     """
     def __init__(self):
         super(NormalKl, self).__init__()
-        self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
+        self.n = nn.Normal(Tensor([3.0]), Tensor([4.0]), dtype=dtype.float32)
 
     def construct(self, x_, y_):
         return self.n('kl_loss', 'Normal', x_, y_)
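The test's nn.Normal is now parameterized with MindSpore Tensors instead of
raw numpy arrays. A sketch, assuming the test's dtype name refers to
mindspore.common.dtype:

    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.common import dtype

    n = nn.Normal(Tensor([3.0]), Tensor([4.0]), dtype=dtype.float32)
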
@@ -20,7 +20,6 @@ from numpy.random import normal
 from mindspore import Tensor
 from mindspore import context
 from mindspore.common.api import ms_function
-from mindspore.ops.composite import core
 
 
 def setup_module(module):
@@ -34,7 +33,6 @@ def test_remove_phi_and_fv():
     """ test_remove_phi_and_fv """
 
     @ms_function
-    @core(loop_can_unroll=True)
     def loop(x, input_data):
         def fv_func(y):
             return x * y
@@ -60,7 +58,6 @@ def test_remove_multiple_phi():
     """ test_remove_multiple_phi """
 
     @ms_function
-    @core(loop_can_unroll=True)
     def loop(x):
         def mul(a, b):
             return a * b
@@ -83,7 +80,6 @@ def test_remove_multiple_phi_recursive():
     """ test_remove_multiple_phi_recursive """
 
     @ms_function
-    @core(loop_can_unroll=True)
     def loop(x):
         def mul(a, b):
             return a * b