remove the loop_can_unroll flag, clean up some Python usage

Wei Luning 2020-07-18 16:42:50 +08:00
parent 43567f9b9f
commit 88e864a4a3
13 changed files with 21 additions and 34 deletions
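The recurring Python cleanup in this commit is to keep constant values as plain Python scalars on a Cell instead of pre-built Tensor attributes, and let graph compilation fold them in. A minimal sketch of that pattern (the class, names and values below are illustrative, not taken from the changed files):

    import mindspore.nn as nn
    from mindspore.ops import operations as P

    class ScaledSum(nn.Cell):
        """Toy cell: the scale is kept as a plain float, not a Tensor attribute."""
        def __init__(self, scale=0.5):
            super(ScaledSum, self).__init__()
            self.reduce_sum = P.ReduceSum()
            self.scale = scale  # plain Python float

        def construct(self, x):
            # the scalar constant is folded into the compiled graph at build time
            return self.reduce_sum(x) * self.scale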

View File

@@ -29,7 +29,6 @@ const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__";
 // flag names
 const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16";
 const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
-const char GRAPH_FLAG_LOOP_CAN_UNROLL[] = "loop_can_unroll";
 const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
 const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
 const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";

View File

@@ -30,7 +30,6 @@ extern const char PYTHON_DATACLASS_FIELDS[];
 extern const char GRAPH_FLAG_MIX_PRECISION_FP16[];
 extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
-extern const char GRAPH_FLAG_LOOP_CAN_UNROLL[];
 extern const char GRAPH_FLAG_HAS_EFFECT[];
 extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
 extern const char GRAPH_FLAG_RANDOM_EFFECT[];

View File

@@ -286,7 +286,6 @@ class ClipByNorm(Cell):
         self.select_ = P.Select()
         self.greater_ = P.Greater()
         self.cast = P.Cast()
-        self.zero = Tensor(np.array([0.0]).astype(np.float32))
         self.sqrt = P.Sqrt()
         self.max_op = P.Maximum()
         self.shape = P.Shape()
@@ -300,7 +299,7 @@ class ClipByNorm(Cell):
         """add ms_function decorator for pynative mode"""
         mul_x = F.square(x)
         l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32)
-        cond = self.greater_(l2sum, self.zero)
+        cond = self.greater_(l2sum, 0)
         ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0)
         l2sum_safe = self.select_(cond, l2sum, self.cast(ones_, self.dtype(l2sum)))
@@ -407,11 +406,13 @@ class OneHot(Cell):
         super(OneHot, self).__init__()
         self.onehot = P.OneHot(axis)
         self.depth = depth
-        self.on_value = Tensor(on_value, dtype)
-        self.off_value = Tensor(off_value, dtype)
+        self.dtype = dtype
+        self.on_value = on_value
+        self.off_value = off_value

     def construct(self, indices):
-        return self.onehot(indices, self.depth, self.on_value, self.off_value)
+        return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype))


 class Pad(Cell):
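For reference, a small usage sketch of the reworked OneHot cell. The argument names follow the diff above; the default values and the PyNative-style direct call are assumptions:

    import numpy as np
    import mindspore.common.dtype as mstype
    import mindspore.nn as nn
    from mindspore import Tensor

    # on_value/off_value are now plain Python numbers; they are cast to `dtype`
    # inside construct via F.cast, as the new return statement above shows.
    onehot = nn.OneHot(axis=-1, depth=4, on_value=1.0, off_value=0.0, dtype=mstype.float32)
    out = onehot(Tensor(np.array([0, 2, 3]), mstype.int32))  # shape (3, 4), float32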

View File

@@ -133,7 +133,8 @@ class LSTM(Cell):
         self.transpose2 = P.Transpose()
         num_directions = 2 if self.bidirectional else 1
         self.cpu_target = False
-        if context.get_context("device_target") == "CPU":
+        enable_debug = context.get_context("enable_debug_runtime")
+        if context.get_context("device_target") == "CPU" and not enable_debug:
             self.cpu_target = True
         if not self.cpu_target:
             self.lstm = P.LSTM(input_size=self.input_size,
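Both context keys read above go through the same API; a minimal standalone sketch of the probe (variable names mirror the diff):

    from mindspore import context

    # choose the CPU-specific path only when not running under the debug runtime,
    # as in the LSTM change above
    enable_debug = context.get_context("enable_debug_runtime")
    cpu_target = context.get_context("device_target") == "CPU" and not enable_debug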

View File

@@ -141,7 +141,7 @@ class Optimizer(Cell):
         if self.is_group_lr:
             self.learning_rate = ParameterTuple(self.group_lr)
         else:
-            self.learning_rate = Parameter(learning_rate, name="learning_rate")
+            self.learning_rate = Parameter(Tensor(learning_rate, mstype.float32), name="learning_rate")

         if self.is_group:
             self.parameters = ParameterTuple(self.group_params)
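A hedged illustration of the wrapping now done inside Optimizer.__init__, so callers can keep passing a plain float learning rate (the value here is illustrative):

    import mindspore.common.dtype as mstype
    from mindspore import Parameter, Tensor

    learning_rate = 0.01  # plain float supplied by user code
    # promoted to a float32 Tensor before it becomes the learning_rate Parameter
    lr = Parameter(Tensor(learning_rate, mstype.float32), name="learning_rate")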

View File

@@ -1104,7 +1104,6 @@ class TransformerModel(nn.Cell):
                                               beam_width=config.beam_width,
                                               length_penalty_weight=config.length_penalty_weight,
                                               max_decode_length=config.max_decode_length)
-        self.tfm_decoder.add_flags(loop_can_unroll=True)

         self.cast = P.Cast()
         self.dtype = config.dtype

View File

@@ -277,8 +277,8 @@ class RelaPosMatrixGenerator(nn.Cell):
     def __init__(self, length, max_relative_position):
         super(RelaPosMatrixGenerator, self).__init__()
         self._length = length
-        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
-        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
+        self._max_relative_position = max_relative_position
+        self._min_relative_position = -max_relative_position
         self.range_length = -length + 1

         self.tile = P.Tile()
@@ -336,9 +336,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                 max_relative_position=max_relative_position)
         self.reshape = P.Reshape()
-        self.one_hot = P.OneHot()
-        self.on_value = Tensor(1.0, mstype.float32)
-        self.off_value = Tensor(0.0, mstype.float32)
+        self.one_hot = nn.OneHot(depth=self.vocab_size)
         self.shape = P.Shape()
         self.gather = P.GatherV2()  # index_select
         self.matmul = P.BatchMatMul()
@@ -350,7 +348,7 @@ class RelaPosEmbeddingsGenerator(nn.Cell):
         if self.use_one_hot_embeddings:
             flat_relative_positions_matrix = self.reshape(relative_positions_matrix_out, (-1,))
             one_hot_relative_positions_matrix = self.one_hot(
-                flat_relative_positions_matrix, self.vocab_size, self.on_value, self.off_value)
+                flat_relative_positions_matrix)
             embeddings = self.matmul(one_hot_relative_positions_matrix, self.embeddings_table)
             my_shape = self.shape(relative_positions_matrix_out) + (self.depth,)
             embeddings = self.reshape(embeddings, my_shape)
@@ -372,11 +370,9 @@ class SaturateCast(nn.Cell):
     def __init__(self, src_type=mstype.float32, dst_type=mstype.float32):
         super(SaturateCast, self).__init__()
         np_type = mstype.dtype_to_nptype(dst_type)
-        min_type = np.finfo(np_type).min
-        max_type = np.finfo(np_type).max
-        self.tensor_min_type = Tensor([min_type], dtype=src_type)
-        self.tensor_max_type = Tensor([max_type], dtype=src_type)
+        self.tensor_min_type = float(np.finfo(np_type).min)
+        self.tensor_max_type = float(np.finfo(np_type).max)

         self.min_op = P.Minimum()
         self.max_op = P.Maximum()
@@ -442,7 +438,7 @@ class BertAttention(nn.Cell):
         self.has_attention_mask = has_attention_mask
         self.use_relative_positions = use_relative_positions

-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.reshape = P.Reshape()
         self.shape_from_2d = (-1, from_tensor_width)
         self.shape_to_2d = (-1, to_tensor_width)
@@ -471,7 +467,7 @@ class BertAttention(nn.Cell):
         self.trans_shape = (0, 2, 1, 3)
         self.trans_shape_relative = (2, 0, 1, 3)
         self.trans_shape_position = (1, 2, 0, 3)
-        self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
+        self.multiply_data = -10000.0
         self.batch_num = batch_size * num_attention_heads
         self.matmul = P.BatchMatMul()
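The SaturateCast change above keeps the clamp bounds as plain floats. A minimal sketch of the whole cell after the change; the final Cast is assumed from the class's purpose and is not visible in the hunk:

    import numpy as np
    import mindspore.common.dtype as mstype
    import mindspore.nn as nn
    from mindspore.ops import operations as P

    class SaturateCastSketch(nn.Cell):
        """Clamp into the destination dtype's finite range, then cast."""
        def __init__(self, dst_type=mstype.float16):
            super(SaturateCastSketch, self).__init__()
            np_type = mstype.dtype_to_nptype(dst_type)
            self.tensor_min_type = float(np.finfo(np_type).min)  # plain float bounds
            self.tensor_max_type = float(np.finfo(np_type).max)
            self.min_op = P.Minimum()
            self.max_op = P.Maximum()
            self.cast = P.Cast()
            self.dst_type = dst_type

        def construct(self, x):
            out = self.max_op(x, self.tensor_min_type)    # clip from below
            out = self.min_op(out, self.tensor_max_type)  # clip from above
            return self.cast(out, self.dst_type)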

View File

@@ -17,7 +17,6 @@
 import numpy as np
 import mindspore.nn as nn
 from mindspore.ops import operations as P
-from mindspore.ops.composite import add_flags
 from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \
     DepthwiseConv2dNative, SpaceToBatch, BatchToSpace
@@ -122,7 +121,6 @@ class ASPP(nn.Cell):
         self.feature_shape = feature_shape
         self.concat = P.Concat(axis=1)

-    @add_flags(loop_can_unroll=True)
     def construct(self, x, scale_index=0):
         aspp0 = self.aspp0(x)
         aspp1 = self.global_poolings[scale_index](x)

View File

@@ -275,8 +275,6 @@ class TransformerInferModel(nn.Cell):
                                           length_penalty_weight=config.length_penalty_weight,
                                           max_decode_length=config.max_decode_length)

-        self.decoder.add_flags(loop_can_unroll=True)
-
         self.cast = P.Cast()
         self.dtype = config.dtype
         self.cast_compute_type = SaturateCast(dst_type=config.compute_type)

View File

@@ -108,7 +108,7 @@ class BertAttentionRelativePositionKeys(nn.Cell):
         self.trans_shape_position = (1, 2, 0, 3)
         self.trans_shape_relative = (2, 0, 1, 3)

-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=dtype)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.reshape = P.Reshape()
         self.multiply = P.Mul()
@@ -301,7 +301,7 @@ class BertAttentionRelativePositionValues(nn.Cell):
         self.trans_shape_position = (1, 2, 0, 3)
         self.trans_shape_relative = (2, 0, 1, 3)

-        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=dtype)
+        self.scores_mul = 1.0 / math.sqrt(float(self.size_per_head))
         self.trans_shape = (0, 2, 1, 3)
         self.reshape = P.Reshape()

View File

@@ -276,7 +276,7 @@ class SingleDeepLabV3(nn.Cell):
                              atrous_rates=atrous_rates,
                              output_stride=output_stride,
                              fine_tune_batch_norm=fine_tune_batch_norm)
-        self.aspp.add_flags(loop_can_unroll=True)
+
         atrous_rates_len = 0
         if atrous_rates is not None:
             atrous_rates_len = len(atrous_rates)

View File

@@ -259,7 +259,7 @@ class NormalKl(nn.Cell):
     """
     def __init__(self):
         super(NormalKl, self).__init__()
-        self.n = nn.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
+        self.n = nn.Normal(Tensor([3.0]), Tensor([4.0]), dtype=dtype.float32)

     def construct(self, x_, y_):
         return self.n('kl_loss', 'Normal', x_, y_)
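A hedged usage sketch of the corrected construction, mirroring only the calls visible in this hunk (input values are illustrative):

    import mindspore.common.dtype as dtype
    import mindspore.nn as nn
    from mindspore import Tensor

    # mean and sd are now passed as Tensors rather than raw numpy arrays
    n = nn.Normal(Tensor([3.0]), Tensor([4.0]), dtype=dtype.float32)
    kl = n('kl_loss', 'Normal', Tensor([2.0]), Tensor([1.0]))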

View File

@@ -20,7 +20,6 @@ from numpy.random import normal
 from mindspore import Tensor
 from mindspore import context
 from mindspore.common.api import ms_function
-from mindspore.ops.composite import core


 def setup_module(module):
@@ -34,7 +33,6 @@ def test_remove_phi_and_fv():
     """ test_remove_phi_and_fv """

     @ms_function
-    @core(loop_can_unroll=True)
     def loop(x, input_data):
         def fv_func(y):
             return x * y
@@ -60,7 +58,6 @@ def test_remove_multiple_phi():
     """ test_remove_multiple_phi """

     @ms_function
-    @core(loop_can_unroll=True)
     def loop(x):
         def mul(a, b):
             return a * b
@@ -83,7 +80,6 @@ def test_remove_multiple_phi_recursive():
     """ test_remove_multiple_phi_recursive """

     @ms_function
-    @core(loop_can_unroll=True)
     def loop(x):
         def mul(a, b):
             return a * b
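After this change the tests rely on the compiler handling the loop without an unroll hint. A minimal standalone sketch of an ms_function-compiled loop in the same style (not one of the actual tests):

    import numpy as np
    from mindspore import Tensor
    from mindspore.common.api import ms_function

    @ms_function  # no @core(loop_can_unroll=True) hint any more
    def loop(x):
        y = x
        # a plain Python loop; the compiler unrolls it while tracing the graph
        for _ in range(3):
            y = y * x
        return y

    out = loop(Tensor(np.ones((2, 2)).astype(np.float32)))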