!2106 fix BasicLSTMCell reg and bp error
Merge pull request !2106 from zhaozhenlong/fix-issue-lstm-reg-error
This commit is contained in:
commit
a0b346690f
|
@ -714,7 +714,7 @@ def get_bprop_basic_lstm_cell(self):
|
||||||
def bprop(x, h, c, w, b, out, dout):
|
def bprop(x, h, c, w, b, out, dout):
|
||||||
_, _, it, jt, ft, ot, tanhct = out
|
_, _, it, jt, ft, ot, tanhct = out
|
||||||
dct, dht, _, _, _, _, _ = dout
|
dct, dht, _, _, _, _, _ = dout
|
||||||
dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
|
dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, ft, jt, ot, tanhct)
|
||||||
dxt, dht = basic_lstm_cell_input_grad(dgate, w)
|
dxt, dht = basic_lstm_cell_input_grad(dgate, w)
|
||||||
dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
|
dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
|
||||||
return dxt, dht, dct_1, dw, db
|
return dxt, dht, dct_1, dw, db
|
||||||
|
|
|
@ -37,10 +37,10 @@ basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \
|
||||||
.output(1, "dct_1", False, "required", "all") \
|
.output(1, "dct_1", False, "required", "all") \
|
||||||
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
|
.dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
|
||||||
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
|
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
|
||||||
DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
DataType.F16_FracNZ, DataType.F32_FracNZ) \
|
||||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
|
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
|
||||||
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
|
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
|
||||||
DataType.F32_FracNZ, DataType.F16_FracNZ) \
|
DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue