A Verification of LSTM Forward Propagation in PyTorch
A short note verifying how LSTM forward propagation behaves in PyTorch.
1 Introduction
When using PyTorch's LSTM, two questions kept coming up:
- During batched training, is the LSTM's output when a sequence is fed in all at once identical to feeding the LSTM one time step at a time by hand? (Answer: yes)
- Is PyTorch's num_layers setting for an LSTM the same as manually stacking single-layer LSTM cells? (Answer: yes)
2 Verifying whole-sequence input against step-by-step input
```python
import torch
```
Result:
RNN(
(rnn): LSTM(4, 5, batch_first=True)
)
rnn param:
name: rnn.weight_ih_l0 torch.Size([20, 4])
param: Parameter containing:
tensor([[ 0.0449, 0.0707, -0.0393, 0.0395],
[ 0.0645, -0.0325, 0.0488, -0.0596],
[ 0.0624, -0.0478, -0.0269, 0.0331],
[-0.0517, -0.0195, 0.0451, -0.0473],
[-0.0484, -0.0571, -0.0639, -0.0441],
[-0.0117, 0.0692, -0.0243, 0.0735],
[-0.0595, 0.0038, -0.0086, 0.0466],
[ 0.0254, 0.0347, 0.0495, 0.0562],
[ 0.0456, 0.0009, -0.0051, -0.0497],
[-0.0673, -0.0705, -0.0513, 0.0625],
[ 0.0251, 0.0018, -0.0748, 0.0668],
[ 0.0722, 0.0215, 0.0752, 0.0411],
[ 0.0584, 0.0039, 0.0678, 0.0541],
[-0.0537, 0.0217, 0.0074, -0.0665],
[ 0.0614, -0.0069, 0.0511, 0.0226],
[-0.0080, 0.0327, -0.0524, 0.0015],
[ 0.0308, -0.0742, -0.0100, -0.0072],
[-0.0191, -0.0472, 0.0721, 0.0302],
[-0.0441, 0.0121, -0.0346, 0.0012],
[ 0.0420, 0.0477, -0.0551, 0.0526]], requires_grad=True)
name: rnn.weight_hh_l0 torch.Size([20, 5])
param: Parameter containing:
tensor([[-0.0148, -0.0516, 0.0692, 0.0709, -0.0271],
[-0.0097, 0.0695, -0.0254, 0.0593, 0.0047],
[ 0.0134, -0.0035, -0.0379, -0.0302, 0.0038],
[-0.0277, 0.0703, 0.0463, 0.0604, -0.0431],
[ 0.0381, 0.0209, -0.0193, 0.0672, 0.0152],
[ 0.0761, -0.0743, -0.0116, 0.0341, 0.0542],
[-0.0195, -0.0670, -0.0298, -0.0319, -0.0263],
[-0.0087, 0.0492, 0.0289, -0.0410, -0.0386],
[ 0.0496, -0.0090, -0.0324, 0.0696, -0.0223],
[-0.0294, -0.0665, 0.0589, -0.0306, -0.0638],
[ 0.0753, -0.0558, 0.0510, 0.0738, 0.0045],
[-0.0210, -0.0420, -0.0287, 0.0511, -0.0410],
[ 0.0017, -0.0286, -0.0728, 0.0357, 0.0564],
[-0.0152, 0.0389, -0.0335, -0.0051, 0.0425],
[ 0.0673, 0.0345, -0.0648, -0.0781, 0.0205],
[-0.0342, -0.0217, 0.0453, 0.0501, -0.0570],
[-0.0381, 0.0469, -0.0716, -0.0330, -0.0564],
[-0.0537, 0.0374, 0.0703, -0.0208, 0.0104],
[ 0.0738, 0.0166, 0.0372, -0.0009, 0.0514],
[-0.0605, -0.0516, 0.0641, -0.0779, 0.0699]], requires_grad=True)
name: rnn.bias_ih_l0 torch.Size([20])
param: Parameter containing:
tensor([ 0.0719, 0.0595, 0.0036, 0.0694, 0.0378, 0.0128, -0.0564, -0.0279,
0.0115, 0.0693, 0.0534, -0.0731, -0.0012, 0.0284, -0.0694, 0.0135,
0.0451, -0.0575, 0.0220, -0.0532], requires_grad=True)
name: rnn.bias_hh_l0 torch.Size([20])
param: Parameter containing:
tensor([ 0.0154, 0.0417, -0.0557, -0.0170, -0.0160, 0.0391, 0.0685, -0.0552,
-0.0166, 0.0263, 0.0050, -0.0282, -0.0162, 0.0333, -0.0444, 0.0584,
0.0008, 0.0734, -0.0651, 0.0419], requires_grad=True)
out: tensor([[[ 0.0808, -0.0553, -0.0342, 0.0148, -0.0718],
[ 0.0570, -0.0712, -0.0422, 0.0514, -0.0905],
[ 0.0671, -0.0729, -0.0326, 0.0359, -0.0875]],
[[ 0.0761, -0.0196, -0.0030, 0.0017, -0.0420],
[ 0.0764, -0.0191, 0.0171, -0.0125, -0.0333],
[ 0.0501, -0.0520, -0.0162, 0.0257, -0.0549]]],
grad_fn=<TransposeBackward0>)
r_out: tensor([[[ 0.0808, -0.0553, -0.0342, 0.0148, -0.0718]],
[[ 0.0761, -0.0196, -0.0030, 0.0017, -0.0420]]],
grad_fn=<TransposeBackward0>)
r_out: tensor([[[ 0.0570, -0.0712, -0.0422, 0.0514, -0.0905]],
[[ 0.0764, -0.0191, 0.0171, -0.0125, -0.0333]]],
grad_fn=<TransposeBackward0>)
r_out: tensor([[[ 0.0671, -0.0729, -0.0326, 0.0359, -0.0875]],
[[ 0.0501, -0.0520, -0.0162, 0.0257, -0.0549]]],
grad_fn=<TransposeBackward0>)
out: tensor([[[ 0.0671, -0.0729, -0.0326, 0.0359, -0.0875]],
[[ 0.0501, -0.0520, -0.0162, 0.0257, -0.0549]]],
grad_fn=<TransposeBackward0>)
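For reference, here is a minimal sketch that reproduces the comparison shown in the printout above. The RNN wrapper class, the (2, 3, 4) input shape, and the print labels are reconstructed from that printout, so treat this as an illustration of the procedure rather than the original script (the original parameter initialization may also differ):

```python
import torch
from torch import nn

class RNN(nn.Module):
    """A thin wrapper around a single-layer LSTM, as suggested by the printout."""
    def __init__(self):
        super(RNN, self).__init__()
        self.rnn = nn.LSTM(input_size=4, hidden_size=5, batch_first=True)

    def forward(self, x, state=None):
        return self.rnn(x, state)

model = RNN()
print(model)

print("rnn param:")
for name, param in model.named_parameters():
    print("name:", name, param.size())
    print("param:", param)

# a batch of 2 sequences, each 3 steps long, 4 features per step
x = torch.randn(2, 3, 4)

# (1) feed the whole sequence in a single call
out, state = model(x)
print("out:", out)

# (2) feed one time step at a time, carrying (h, c) between calls
state = None
for t in range(x.size(1)):
    r_out, state = model(x[:, t:t + 1, :], state)
    print("r_out:", r_out)

# the last step of the full-sequence output, for comparison
print("out:", out[:, -1:, :])
```

The last r_out from the loop equals the last time step of out, which answers the first question: stepping through the sequence manually while carrying the (h, c) state forward gives the same result as handing the LSTM the whole sequence at once.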
3 On multi-layer LSTM and stacked cells
To save clicking through later, the code from [link] is copied here. That link discusses the multi-layer LSTM setup in detail; this post gives only its final, correct verification code and results.
3.1 Method 1: set num_layers=2
```python
import torch
```
Result:
weight weight_ih_l0 after init: Parameter containing:
tensor([[ 0.6025, -0.1577, -0.0990],
[-0.5255, 0.4554, 0.4651],
[ 0.1428, 0.1414, -0.0291],
[ 0.1248, 0.3465, -0.5053],
[ 0.6295, -0.8635, -0.3394],
[ 0.1072, 0.0786, 0.3427],
[ 0.5352, -0.2032, 0.8816],
[ 0.3727, -0.1608, -0.6332],
[-0.3745, 0.1903, -0.1654],
[-0.0460, -0.2148, 0.7737],
[-0.1980, -0.8980, -0.3470],
[-0.1130, 0.6074, 0.1844]], requires_grad=True)
weight weight_hh_l0 after init: Parameter containing:
tensor([[-0.0719, -0.0122, 0.2626],
[ 0.3887, -0.3044, -0.4356],
[-0.8422, 0.2204, 0.1151],
[ 0.4171, 0.1116, -0.2114],
[ 0.2061, -0.3204, -0.0983],
[ 0.4791, -0.5683, -0.3928],
[-0.3196, -0.1726, -0.0732],
[-0.3058, -0.5667, -0.0211],
[-0.0832, -0.3168, 0.1241],
[-0.4197, 0.0525, 0.0741],
[ 0.3849, 0.0481, -0.3130],
[ 0.5788, 0.6312, -0.3627]], requires_grad=True)
weight weight_ih_l1 after init: Parameter containing:
tensor([[ 3.6955e-02, 7.1276e-02, -4.3073e-01],
[-5.2666e-01, 2.7323e-02, 1.2894e-01],
[ 3.7136e-01, 3.3969e-01, 1.9601e-01],
[ 3.5802e-01, -4.3600e-01, -1.7962e-01],
[ 8.3209e-01, 1.7189e-01, 2.2195e-01],
[-2.1302e-02, -1.6867e-01, -1.3460e-01],
[ 1.3446e-01, 1.7708e-01, -5.6676e-01],
[-2.3697e-01, -2.8254e-02, -2.2063e-01],
[-2.0928e-01, 3.4973e-01, 3.5858e-04],
[-5.0565e-01, -6.8619e-02, 3.7702e-01],
[-9.0796e-02, -1.7238e-01, 4.7868e-01],
[-1.1565e-01, -6.7956e-02, -2.1049e-01]], requires_grad=True)
weight weight_hh_l1 after init: Parameter containing:
tensor([[-0.3017, -0.0811, -0.6554],
[ 0.2665, -0.2052, -0.0577],
[ 0.5493, -0.5094, 0.2167],
[ 0.1210, -0.3868, -0.2293],
[-0.0991, 0.6744, -0.0114],
[-0.0343, -0.6136, 0.4856],
[ 0.0505, 0.3920, -0.1662],
[ 0.1163, -0.1296, 0.2505],
[-0.1373, -0.8803, -0.4666],
[-0.0230, -0.0346, -0.8451],
[ 0.2032, 0.1847, -0.0758],
[ 0.2533, 0.1532, 0.8224]], requires_grad=True)
inputs: [tensor([[1.5381, 1.4673, 1.5951]]), tensor([[-1.5279, 1.0156, -0.2020]]), tensor([[-1.2865, 0.8231, -0.6101]]), tensor([[-1.2960, -0.9434, 0.6684]]), tensor([[ 1.1628, -0.3229, 1.8782]])]
idx: 0
tensor([[[ 0.0374, -0.0085, -0.0240]]], grad_fn=<StackBackward>)
==========
idx: 1
tensor([[[ 0.0073, -0.0110, -0.0296]]], grad_fn=<StackBackward>)
==========
idx: 2
tensor([[[-0.0314, -0.0147, -0.0136]]], grad_fn=<StackBackward>)
==========
idx: 3
tensor([[[-0.0458, -0.0118, -0.0254]]], grad_fn=<StackBackward>)
==========
idx: 4
tensor([[[-0.0096, -0.0281, -0.0440]]], grad_fn=<StackBackward>)
==========
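The printout above corresponds to code along the following lines. This is a minimal sketch: bias=False is inferred from the absence of bias parameters in the printout, and the Xavier re-initialization is an assumption, since only the resulting weight values are shown:

```python
import torch
from torch import nn

torch.manual_seed(1)

# a single nn.LSTM with two stacked layers; bias=False is inferred from
# the printout, where no bias parameters appear
lstm = nn.LSTM(input_size=3, hidden_size=3, num_layers=2, bias=False)

# re-initialize and print every weight (the init scheme is an assumption)
for name, param in lstm.named_parameters():
    nn.init.xavier_normal_(param)
    print("weight {} after init:".format(name), param)

inputs = [torch.randn(1, 3) for _ in range(5)]
print("inputs:", inputs)

# feed the five steps one at a time, carrying the hidden state forward;
# the state tensors are sized (num_layers, batch, hidden_size)
h = torch.zeros(2, 1, 3)
c = torch.zeros(2, 1, 3)
for idx, inp in enumerate(inputs):
    out, (h, c) = lstm(inp.view(1, 1, -1), (h, c))
    print("idx:", idx)
    print(out)
    print("=" * 10)
```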
3.2 Method 2: manually stack two single-layer LSTMs
```python
torch.manual_seed(1)
```
Result:
inputs: [tensor([[1.5381, 1.4673, 1.5951]]), tensor([[-1.5279, 1.0156, -0.2020]]), tensor([[-1.2865, 0.8231, -0.6101]]), tensor([[-1.2960, -0.9434, 0.6684]]), tensor([[ 1.1628, -0.3229, 1.8782]])]
lstm weight_ih_l0 after init: Parameter containing:
tensor([[ 0.6025, -0.1577, -0.0990],
[-0.5255, 0.4554, 0.4651],
[ 0.1428, 0.1414, -0.0291],
[ 0.1248, 0.3465, -0.5053],
[ 0.6295, -0.8635, -0.3394],
[ 0.1072, 0.0786, 0.3427],
[ 0.5352, -0.2032, 0.8816],
[ 0.3727, -0.1608, -0.6332],
[-0.3745, 0.1903, -0.1654],
[-0.0460, -0.2148, 0.7737],
[-0.1980, -0.8980, -0.3470],
[-0.1130, 0.6074, 0.1844]], requires_grad=True)
lstm weight_hh_l0 after init: Parameter containing:
tensor([[-0.0719, -0.0122, 0.2626],
[ 0.3887, -0.3044, -0.4356],
[-0.8422, 0.2204, 0.1151],
[ 0.4171, 0.1116, -0.2114],
[ 0.2061, -0.3204, -0.0983],
[ 0.4791, -0.5683, -0.3928],
[-0.3196, -0.1726, -0.0732],
[-0.3058, -0.5667, -0.0211],
[-0.0832, -0.3168, 0.1241],
[-0.4197, 0.0525, 0.0741],
[ 0.3849, 0.0481, -0.3130],
[ 0.5788, 0.6312, -0.3627]], requires_grad=True)
lstm2 weight_ih_l0 after init: Parameter containing:
tensor([[ 3.6955e-02, 7.1276e-02, -4.3073e-01],
[-5.2666e-01, 2.7323e-02, 1.2894e-01],
[ 3.7136e-01, 3.3969e-01, 1.9601e-01],
[ 3.5802e-01, -4.3600e-01, -1.7962e-01],
[ 8.3209e-01, 1.7189e-01, 2.2195e-01],
[-2.1302e-02, -1.6867e-01, -1.3460e-01],
[ 1.3446e-01, 1.7708e-01, -5.6676e-01],
[-2.3697e-01, -2.8254e-02, -2.2063e-01],
[-2.0928e-01, 3.4973e-01, 3.5858e-04],
[-5.0565e-01, -6.8619e-02, 3.7702e-01],
[-9.0796e-02, -1.7238e-01, 4.7868e-01],
[-1.1565e-01, -6.7956e-02, -2.1049e-01]], requires_grad=True)
lstm2 weight_hh_l0 after init: Parameter containing:
tensor([[-0.3017, -0.0811, -0.6554],
[ 0.2665, -0.2052, -0.0577],
[ 0.5493, -0.5094, 0.2167],
[ 0.1210, -0.3868, -0.2293],
[-0.0991, 0.6744, -0.0114],
[-0.0343, -0.6136, 0.4856],
[ 0.0505, 0.3920, -0.1662],
[ 0.1163, -0.1296, 0.2505],
[-0.1373, -0.8803, -0.4666],
[-0.0230, -0.0346, -0.8451],
[ 0.2032, 0.1847, -0.0758],
[ 0.2533, 0.1532, 0.8224]], requires_grad=True)
idx: 0
tensor([[[ 0.0374, -0.0085, -0.0240]]], grad_fn=<StackBackward>)
==========
idx: 1
tensor([[[ 0.0073, -0.0110, -0.0296]]], grad_fn=<StackBackward>)
==========
idx: 2
tensor([[[-0.0314, -0.0147, -0.0136]]], grad_fn=<StackBackward>)
==========
idx: 3
tensor([[[-0.0458, -0.0118, -0.0254]]], grad_fn=<StackBackward>)
==========
idx: 4
tensor([[[-0.0096, -0.0281, -0.0440]]], grad_fn=<StackBackward>)
==========
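Again, a minimal sketch of the code behind this printout, under the same assumptions as in method 1 (bias=False, Xavier re-initialization, applied in the same parameter order so that lstm and lstm2 end up with the same weights as layers l0 and l1 of the stacked model):

```python
import torch
from torch import nn

torch.manual_seed(1)
inputs = [torch.randn(1, 3) for _ in range(5)]
print("inputs:", inputs)

# two separate single-layer LSTMs, stacked by hand
# (bias=False again inferred from the printout)
lstm = nn.LSTM(input_size=3, hidden_size=3, num_layers=1, bias=False)
lstm2 = nn.LSTM(input_size=3, hidden_size=3, num_layers=1, bias=False)

# same (assumed) init scheme as in method 1
for tag, net in (("lstm", lstm), ("lstm2", lstm2)):
    for name, param in net.named_parameters():
        nn.init.xavier_normal_(param)
        print("{} {} after init:".format(tag, name), param)

# one (h, c) state per layer, each sized (1, batch, hidden_size)
h1, c1 = torch.zeros(1, 1, 3), torch.zeros(1, 1, 3)
h2, c2 = torch.zeros(1, 1, 3), torch.zeros(1, 1, 3)
for idx, inp in enumerate(inputs):
    # the second LSTM consumes the first LSTM's output
    mid, (h1, c1) = lstm(inp.view(1, 1, -1), (h1, c1))
    out, (h2, c2) = lstm2(mid, (h2, c2))
    print("idx:", idx)
    print(out)
    print("=" * 10)
```

The five per-step outputs match those of method 1 exactly, which answers the second question: num_layers=2 in a single nn.LSTM is equivalent to manually stacking two single-layer LSTMs and feeding the first layer's output into the second.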