!6374 LSTM API optimization

Merge pull request !6374 from caojian05/ms_master_lstm_api_optimation
mindspore-ci-bot 2020-09-17 17:29:15 +08:00 committed by Gitee
commit 63ae2c3639
3 changed files with 229 additions and 131 deletions

Changed file 1 of 3: the nn LSTM layer (classes LSTM and LSTMCell)

@@ -14,12 +14,12 @@
 # ============================================================================
 """lstm"""
 import math
 import numpy as np
-import mindspore.nn as nn
-from mindspore import context
 from mindspore._checkparam import Validator as validator
 from mindspore.common.initializer import initializer
-from mindspore.common.parameter import Parameter, ParameterTuple
+from mindspore.common.parameter import Parameter
 from mindspore.common.tensor import Tensor
 from mindspore.nn.cell import Cell
 from mindspore.ops import operations as P
@@ -118,83 +118,41 @@ class LSTM(Cell):
                  dropout=0,
                  bidirectional=False):
         super(LSTM, self).__init__()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.num_layers = num_layers
-        self.has_bias = has_bias
-        self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
-        self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name)
-        self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name)
-        self.dropout = float(dropout)
-        self.bidirectional = bidirectional
-        if self.batch_first:
-            self.transpose1 = P.Transpose()
-            self.transpose2 = P.Transpose()
-        num_directions = 2 if self.bidirectional else 1
-        self.cpu_target = False
-        enable_debug = context.get_context("enable_debug_runtime")
-        if context.get_context("device_target") == "CPU" and not enable_debug:
-            self.cpu_target = True
-        if not self.cpu_target:
-            self.lstm = P.LSTM(input_size=self.input_size,
-                               hidden_size=self.hidden_size,
-                               num_layers=self.num_layers,
-                               has_bias=self.has_bias,
-                               bidirectional=self.bidirectional,
-                               dropout=self.dropout)
-            weight_size = 0
-            gate_size = 4 * self.hidden_size
-            for layer in range(self.num_layers):
-                input_layer_size = self.input_size if layer == 0 else self.hidden_size * num_directions
-                increment_size = gate_size * input_layer_size
-                increment_size += gate_size * self.hidden_size
-                if self.has_bias:
-                    increment_size += 2 * gate_size
-                weight_size += increment_size * num_directions
-            stdv = 1 / math.sqrt(hidden_size)
-            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
-            self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
-        else:
-            input_size_list = []
-            input_size_list.append(self.input_size)
-            for i in range(self.num_layers - 1):
-                input_size_list.append(self.hidden_size * num_directions)
-            weights = []
-            layers = []
-            bias_size = 0 if not self.has_bias else num_directions * self.hidden_size * 4
-            stdv = 1 / math.sqrt(hidden_size)
-            for i in range(num_layers):
-                weight_size = (input_size_list[i] + self.hidden_size) * num_directions * self.hidden_size * 4
-                if has_bias:
-                    weight_size = weight_size + bias_size
-                w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
-                weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name='weight' + str(i)))
-                layers.append(nn.LSTMCell(input_size=input_size_list[i],
-                                          hidden_size=self.hidden_size,
-                                          has_bias=self.has_bias,
-                                          bidirectional=self.bidirectional,
-                                          dropout=self.dropout))
-            self.lstms = layers
-            self.weight = ParameterTuple(tuple(weights))
-        self.fill = P.Fill()
-        self.shape = P.Shape()
+        validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
+        validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.cls_name)
+        validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.cls_name)
+        self.batch_first = batch_first
+        self.transpose = P.Transpose()
+        self.lstm = P.LSTM(input_size=input_size,
+                           hidden_size=hidden_size,
+                           num_layers=num_layers,
+                           has_bias=has_bias,
+                           bidirectional=bidirectional,
+                           dropout=float(dropout))
+        weight_size = 0
+        gate_size = 4 * hidden_size
+        num_directions = 2 if bidirectional else 1
+        for layer in range(num_layers):
+            input_layer_size = input_size if layer == 0 else hidden_size * num_directions
+            increment_size = gate_size * input_layer_size
+            increment_size += gate_size * hidden_size
+            if has_bias:
+                increment_size += 2 * gate_size
+            weight_size += increment_size * num_directions
+        stdv = 1 / math.sqrt(hidden_size)
+        w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
+        self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')

     def construct(self, x, hx):
         if self.batch_first:
-            x = self.transpose1(x, (1, 0, 2))
-        if not self.cpu_target:
-            h, c = hx
-            output, h, c, _, _ = self.lstm(x, h, c, self.weight)
-            if self.batch_first:
-                output = self.transpose2(output, (1, 0, 2))
-            return (output, (h, c))
+            x = self.transpose(x, (1, 0, 2))
         h, c = hx
-        output, hn, cn, _, _ = self.lstms[0](x, h[0], c[0], self.weight[0])
-        for i in range(1, self.num_layers):
-            output, hn, cn, _, _ = self.lstms[i](output, h[i], c[i], self.weight[i])
+        x, h, c, _, _ = self.lstm(x, h, c, self.weight)
         if self.batch_first:
-            output = self.transpose2(output, (1, 0, 2))
-        return (output, (hn, cn))
+            x = self.transpose(x, (1, 0, 2))
+        return x, (h, c)

 class LSTMCell(Cell):
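Aside on the hunk above: the rewritten __init__ now sizes a single flat weight parameter instead of one parameter per layer. A minimal standalone sketch of that bookkeeping is below; the helper name and the example dimensions are illustrative additions, not part of the patch.

    def lstm_flat_weight_size(input_size, hidden_size, num_layers, has_bias, bidirectional):
        """Reproduce the weight_size loop from the new LSTM.__init__ (illustrative helper)."""
        gate_size = 4 * hidden_size
        num_directions = 2 if bidirectional else 1
        weight_size = 0
        for layer in range(num_layers):
            input_layer_size = input_size if layer == 0 else hidden_size * num_directions
            increment_size = gate_size * input_layer_size   # input-to-hidden weights
            increment_size += gate_size * hidden_size       # hidden-to-hidden weights
            if has_bias:
                increment_size += 2 * gate_size             # two bias vectors
            weight_size += increment_size * num_directions
        return weight_size

    # For example, input_size=10, hidden_size=12, one unidirectional layer with bias:
    # 48*10 + 48*12 + 2*48 = 1152 flattened weight elements.
    assert lstm_flat_weight_size(10, 12, 1, True, False) == 1152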
@@ -291,30 +249,19 @@ class LSTMCell(Cell):
                  dropout=0,
                  bidirectional=False):
         super(LSTMCell, self).__init__()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-        self.has_bias = has_bias
         self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name)
-        self.dropout = float(dropout)
-        self.bidirectional = bidirectional
-        self.num_directions = 1
-        if self.bidirectional:
-            self.num_directions = 2
-        if self.batch_first:
-            self.transpose1 = P.Transpose()
-            self.transpose2 = P.Transpose()
-        self.lstm = P.LSTM(input_size=self.input_size,
-                           hidden_size=self.hidden_size,
+        self.transpose = P.Transpose()
+        self.lstm = P.LSTM(input_size=input_size,
+                           hidden_size=hidden_size,
                            num_layers=1,
-                           has_bias=self.has_bias,
-                           bidirectional=self.bidirectional,
-                           dropout=self.dropout)
+                           has_bias=has_bias,
+                           bidirectional=bidirectional,
+                           dropout=float(dropout))

     def construct(self, x, h, c, w):
         if self.batch_first:
-            x = self.transpose1(x, (1, 0, 2))
-        output, hn, cn, _, _ = self.lstm(x, h, c, w)
+            x = self.transpose(x, (1, 0, 2))
+        x, h, c, _, _ = self.lstm(x, h, c, w)
         if self.batch_first:
-            output = self.transpose2(output, (1, 0, 2))
-        return output, hn, cn, _, _
+            x = self.transpose(x, (1, 0, 2))
+        return x, h, c, _, _
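With this change nn.LSTM always wraps the single P.LSTM primitive, so calling it looks the same regardless of backend. A minimal usage sketch of the simplified API; the shapes and sizes are assumptions for illustration only:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    # seq_len=5, batch=2, input_size=10, hidden_size=16, one unidirectional layer (illustrative sizes)
    net = nn.LSTM(input_size=10, hidden_size=16, num_layers=1, has_bias=True,
                  batch_first=False, bidirectional=False)
    x = Tensor(np.ones([5, 2, 10]).astype(np.float32))
    h0 = Tensor(np.zeros([1, 2, 16]).astype(np.float32))   # (num_layers * num_directions, batch, hidden)
    c0 = Tensor(np.zeros([1, 2, 16]).astype(np.float32))
    output, (hn, cn) = net(x, (h0, c0))                     # output: (seq_len, batch, num_directions * hidden)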

Changed file 2 of 3: model zoo LSTM example (lstm_default_state, StackLSTM, SentimentNet)

@@ -13,40 +13,108 @@
 # limitations under the License.
 # ============================================================================
 """LSTM."""
+import math
 import numpy as np

-from mindspore import Tensor, nn, context
+from mindspore import Tensor, nn, context, Parameter, ParameterTuple
+from mindspore.common.initializer import initializer
 from mindspore.ops import operations as P

+STACK_LSTM_DEVICE = ["CPU"]
+

 # Initialize short-term memory (h) and long-term memory (c) to 0
 def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
     """init default input."""
-    num_directions = 1
-    if bidirectional:
-        num_directions = 2
-
-    if context.get_context("device_target") == "CPU":
-        h_list = []
-        c_list = []
-        i = 0
-        while i < num_layers:
-            hi = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32))
-            h_list.append(hi)
-            ci = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32))
-            c_list.append(ci)
-            i = i + 1
-        h = tuple(h_list)
-        c = tuple(c_list)
-        return h, c
-
-    h = Tensor(
-        np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
-    c = Tensor(
-        np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
+    num_directions = 2 if bidirectional else 1
+    h = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
+    c = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
     return h, c

+
+def stack_lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
+    """init default input."""
+    num_directions = 2 if bidirectional else 1
+
+    h_list, c_list = [], []
+    for _ in range(num_layers):
+        h_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
+        c_list.append(Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32)))
+    h, c = tuple(h_list), tuple(c_list)
+    return h, c
+
+
+class StackLSTM(nn.Cell):
+    """
+    Stack multi-layers LSTM together.
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 num_layers=1,
+                 has_bias=True,
+                 batch_first=False,
+                 dropout=0.0,
+                 bidirectional=False):
+        super(StackLSTM, self).__init__()
+        self.num_layers = num_layers
+        self.batch_first = batch_first
+        self.transpose = P.Transpose()
+
+        # direction number
+        num_directions = 2 if bidirectional else 1
+
+        # input_size list
+        input_size_list = [input_size]
+        for i in range(num_layers - 1):
+            input_size_list.append(hidden_size * num_directions)
+
+        # layers
+        layers = []
+        for i in range(num_layers):
+            layers.append(nn.LSTMCell(input_size=input_size_list[i],
+                                      hidden_size=hidden_size,
+                                      has_bias=has_bias,
+                                      batch_first=batch_first,
+                                      bidirectional=bidirectional,
+                                      dropout=dropout))
+
+        # weights
+        weights = []
+        for i in range(num_layers):
+            # weight size
+            weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
+            if has_bias:
+                bias_size = num_directions * hidden_size * 4
+                weight_size = weight_size + bias_size
+
+            # numpy weight
+            stdv = 1 / math.sqrt(hidden_size)
+            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
+
+            # lstm weight
+            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))
+
+        self.lstms = layers
+        self.weight = ParameterTuple(tuple(weights))
+
+    def construct(self, x, hx):
+        """construct"""
+        if self.batch_first:
+            x = self.transpose(x, (1, 0, 2))
+        # stack lstm
+        h, c = hx
+        hn = cn = None
+        for i in range(self.num_layers):
+            x, hn, cn, _, _ = self.lstms[i](x, h[i], c[i], self.weight[i])
+        if self.batch_first:
+            x = self.transpose(x, (1, 0, 2))
+        return x, (hn, cn)
+

 class SentimentNet(nn.Cell):
     """Sentiment network structure."""
@@ -67,14 +135,25 @@ class SentimentNet(nn.Cell):
         self.embedding.embedding_table.requires_grad = False
         self.trans = P.Transpose()
         self.perm = (1, 0, 2)
-        self.encoder = nn.LSTM(input_size=embed_size,
-                               hidden_size=num_hiddens,
-                               num_layers=num_layers,
-                               has_bias=True,
-                               bidirectional=bidirectional,
-                               dropout=0.0)
-        self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
+        if context.get_context("device_target") in STACK_LSTM_DEVICE:
+            # stack lstm by user
+            self.encoder = StackLSTM(input_size=embed_size,
+                                     hidden_size=num_hiddens,
+                                     num_layers=num_layers,
+                                     has_bias=True,
+                                     bidirectional=bidirectional,
+                                     dropout=0.0)
+            self.h, self.c = stack_lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
+        else:
+            # standard lstm
+            self.encoder = nn.LSTM(input_size=embed_size,
+                                   hidden_size=num_hiddens,
+                                   num_layers=num_layers,
+                                   has_bias=True,
+                                   bidirectional=bidirectional,
+                                   dropout=0.0)
+            self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
         self.concat = P.Concat(1)
         if bidirectional:
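SentimentNet now selects its encoder from the device target at construction time, and both branches expose the same calling convention, so the rest of the network is unchanged. A condensed sketch of that dispatch, reusing names from the file above with illustrative sizes:

    if context.get_context("device_target") in STACK_LSTM_DEVICE:   # currently ["CPU"]
        encoder = StackLSTM(input_size=300, hidden_size=100, num_layers=2,
                            has_bias=True, bidirectional=True, dropout=0.0)
        h, c = stack_lstm_default_state(batch_size=64, hidden_size=100, num_layers=2, bidirectional=True)
    else:
        encoder = nn.LSTM(input_size=300, hidden_size=100, num_layers=2,
                          has_bias=True, bidirectional=True, dropout=0.0)
        h, c = lstm_default_state(batch_size=64, hidden_size=100, num_layers=2, bidirectional=True)

    # Either way the encoder is called identically:
    embeddings = Tensor(np.ones([500, 64, 300]).astype(np.float32))  # (seq_len, batch, embed_size)
    output, (hn, cn) = encoder(embeddings, (h, c))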

Changed file 3 of 3: CPU LSTM test (local StackLSTM helper, LstmNet, MultiLayerBiLstmNet, Net)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+import math
 import pytest
 import numpy as np
@@ -20,12 +21,83 @@ import mindspore.context as context
 from mindspore.common.api import ms_function
 from mindspore.common.initializer import initializer
 from mindspore.ops import composite as C
+from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import ParameterTuple, Parameter

 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

+
+class StackLSTM(nn.Cell):
+    """
+    Stack multi-layers LSTM together.
+    """
+
+    def __init__(self,
+                 input_size,
+                 hidden_size,
+                 num_layers=1,
+                 has_bias=True,
+                 batch_first=False,
+                 dropout=0.0,
+                 bidirectional=False):
+        super(StackLSTM, self).__init__()
+        self.num_layers = num_layers
+        self.batch_first = batch_first
+        self.transpose = P.Transpose()
+
+        # direction number
+        num_directions = 2 if bidirectional else 1
+
+        # input_size list
+        input_size_list = [input_size]
+        for i in range(num_layers - 1):
+            input_size_list.append(hidden_size * num_directions)
+
+        # layers
+        layers = []
+        for i in range(num_layers):
+            layers.append(nn.LSTMCell(input_size=input_size_list[i],
+                                      hidden_size=hidden_size,
+                                      has_bias=has_bias,
+                                      batch_first=batch_first,
+                                      bidirectional=bidirectional,
+                                      dropout=dropout))
+
+        # weights
+        weights = []
+        for i in range(num_layers):
+            # weight size
+            weight_size = (input_size_list[i] + hidden_size) * num_directions * hidden_size * 4
+            if has_bias:
+                bias_size = num_directions * hidden_size * 4
+                weight_size = weight_size + bias_size
+
+            # numpy weight
+            stdv = 1 / math.sqrt(hidden_size)
+            w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
+
+            # lstm weight
+            weights.append(Parameter(initializer(Tensor(w_np), w_np.shape), name="weight" + str(i)))
+
+        self.lstms = layers
+        self.weight = ParameterTuple(tuple(weights))
+
+    def construct(self, x, hx):
+        """construct"""
+        if self.batch_first:
+            x = self.transpose(x, (1, 0, 2))
+        # stack lstm
+        h, c = hx
+        hn = cn = None
+        for i in range(self.num_layers):
+            x, hn, cn, _, _ = self.lstms[i](x, h[i], c[i], self.weight[i])
+        if self.batch_first:
+            x = self.transpose(x, (1, 0, 2))
+        return x, (hn, cn)
+

 class LstmNet(nn.Cell):
     def __init__(self, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
         super(LstmNet, self).__init__()
@@ -34,7 +106,7 @@ class LstmNet(nn.Cell):
         if bidirectional:
             num_directions = 2

-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)
+        self.lstm = StackLSTM(input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)

         input_np = np.array([[[0.6755, -1.6607, 0.1367], [0.4276, -0.7850, -0.3758]],
                              [[-0.6424, -0.6095, 0.6639], [0.7918, 0.4147, -0.5089]],
                              [[-1.5612, 0.0120, -0.7289], [-0.6656, -0.6626, -0.5883]],
@@ -137,8 +209,8 @@ class MultiLayerBiLstmNet(nn.Cell):
         if bidirectional:
             num_directions = 2

-        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=has_bias,
-                            bidirectional=bidirectional, dropout=dropout)
+        self.lstm = StackLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=has_bias,
+                              bidirectional=bidirectional, dropout=dropout)

         input_np = np.array([[[-0.1887, -0.4144, -0.0235, 0.7489, 0.7522, 0.5969, 0.3342, 1.2198, 0.6786, -0.9404],
                               [-0.8643, -1.6835, -2.4965, 2.8093, 0.1741, 0.2707, 0.7387, -0.0939, -1.7990, 0.4765]],
@@ -264,8 +336,8 @@ class Net(nn.Cell):
         bih = np.zeros((1, 8)).astype(np.float32)
         w_np = np.concatenate((wih, whh, bih), axis=1).reshape([-1, 1, 1])
         self.w = Parameter(initializer(Tensor(w_np), w_np.shape), name='weight0')
-        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
-                            has_bias=has_bias, bidirectional=bidirectional, dropout=dropout)
+        self.lstm = StackLSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
+                              has_bias=has_bias, bidirectional=bidirectional, dropout=dropout)
         self.lstm.weight = ParameterTuple(tuple([self.w]))

     @ms_function
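End to end, the test changes keep the networks' public behaviour: the locally defined StackLSTM is a drop-in for the former CPU path of nn.LSTM, and because its weights live in a ParameterTuple, single-layer tests can still inject a handcrafted flat weight as above. A minimal smoke check in the same spirit; the sizes are illustrative and not taken from the tests:

    def smoke_test_stack_lstm():
        """Forward a dummy batch through the locally defined StackLSTM (illustrative sizes)."""
        seq_len, batch, input_size, hidden_size, num_layers = 5, 2, 10, 12, 2
        net = StackLSTM(input_size, hidden_size, num_layers, has_bias=True,
                        batch_first=False, bidirectional=False)
        x = Tensor(np.ones([seq_len, batch, input_size]).astype(np.float32))
        # one (num_directions, batch, hidden) state tensor per layer
        h = tuple(Tensor(np.zeros([1, batch, hidden_size]).astype(np.float32)) for _ in range(num_layers))
        c = tuple(Tensor(np.zeros([1, batch, hidden_size]).astype(np.float32)) for _ in range(num_layers))
        out, (hn, cn) = net(x, (h, c))
        assert out.shape == (seq_len, batch, hidden_size)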