LSTM神经网络的详细推导及C++实现
来源:互联网 发布:网络布控坐火车 编辑:程序博客网 时间:2024/05/17 02:42
LSTM隐层神经元结构:(原文此处为结构示意图,抓取时图片未能保留)
LSTM隐层神经元详细结构:(原文此处为详细结构示意图,抓取时图片未能保留)
#include "iostream"#include "math.h"#include "stdlib.h"#include "time.h"#include "vector"#include "assert.h"using namespace std;#define innode 2 #define hidenode 26 #define outnode 1 #define alpha 0.1 #define binary_dim 8 #define randval(high) ( (double)rand() / RAND_MAX * high )#define uniform_plus_minus_one ( (double)( 2.0 * rand() ) / ((double)RAND_MAX + 1.0) - 1.0 ) int largest_number = ( pow(2, binary_dim) ); double sigmoid(double x) { return 1.0 / (1.0 + exp(-x));}double dsigmoid(double y){ return y * (1.0 - y); } double dtanh(double y){ y = tanh(y); return 1.0 - y * y; }void int2binary(int n, int *arr){ int i = 0; while(n) { arr[i++] = n % 2; n /= 2; } while(i < binary_dim) arr[i++] = 0;}class RNN{public: RNN(); virtual ~RNN(); void train();public: double W_I[innode][hidenode]; double U_I[hidenode][hidenode]; double W_F[innode][hidenode]; double U_F[hidenode][hidenode]; double W_O[innode][hidenode]; double U_O[hidenode][hidenode]; double W_G[innode][hidenode]; double U_G[hidenode][hidenode]; double W_out[hidenode][outnode]; double *x; double *y; };void winit(double w[], int n) { for(int i=0; i<n; i++) w[i] = uniform_plus_minus_one; }RNN::RNN(){ x = new double[innode]; y = new double[outnode]; winit((double*)W_I, innode * hidenode); winit((double*)U_I, hidenode * hidenode); winit((double*)W_F, innode * hidenode); winit((double*)U_F, hidenode * hidenode); winit((double*)W_O, innode * hidenode); winit((double*)U_O, hidenode * hidenode); winit((double*)W_G, innode * hidenode); winit((double*)U_G, hidenode * hidenode); winit((double*)W_out, hidenode * outnode);}RNN::~RNN(){ delete x; delete y;}void RNN::train(){ int epoch, i, j, k, m, p; vector<double*> I_vector; vector<double*> F_vector; vector<double*> O_vector; vector<double*> G_vector; vector<double*> S_vector; vector<double*> h_vector; vector<double> y_delta; for(epoch=0; epoch<11000; epoch++) { double e = 0.0; int predict[binary_dim]; memset(predict, 0, sizeof(predict)); int a_int = 
(int)randval(largest_number/2.0); int a[binary_dim]; int2binary(a_int, a); int b_int = (int)randval(largest_number/2.0); int b[binary_dim]; int2binary(b_int, b); int c_int = a_int + b_int; int c[binary_dim]; int2binary(c_int, c); double *S = new double[hidenode]; double *h = new double[hidenode]; for(i=0; i<hidenode; i++) { S[i] = 0; h[i] = 0; } S_vector.push_back(S); h_vector.push_back(h); for(p=0; p<binary_dim; p++) { x[0] = a[p]; x[1] = b[p]; double t = (double)c[p]; double *in_gate = new double[hidenode]; double *out_gate = new double[hidenode]; double *forget_gate = new double[hidenode]; double *g_gate = new double[hidenode]; double *state = new double[hidenode]; double *h = new double[hidenode]; for(j=0; j<hidenode; j++) { double inGate = 0.0; double outGate = 0.0; double forgetGate = 0.0; double gGate = 0.0; double s = 0.0; for(m=0; m<innode; m++) { inGate += x[m] * W_I[m][j]; outGate += x[m] * W_O[m][j]; forgetGate += x[m] * W_F[m][j]; gGate += x[m] * W_G[m][j]; } double *h_pre = h_vector.back(); double *state_pre = S_vector.back(); for(m=0; m<hidenode; m++) { inGate += h_pre[m] * U_I[m][j]; outGate += h_pre[m] * U_O[m][j]; forgetGate += h_pre[m] * U_F[m][j]; gGate += h_pre[m] * U_G[m][j]; } in_gate[j] = sigmoid(inGate); out_gate[j] = sigmoid(outGate); forget_gate[j] = sigmoid(forgetGate); g_gate[j] = sigmoid(gGate); double s_pre = state_pre[j]; state[j] = forget_gate[j] * s_pre + g_gate[j] * in_gate[j]; h[j] = in_gate[j] * tanh(state[j]); } for(k=0; k<outnode; k++) { double out = 0.0; for(j=0; j<hidenode; j++) out += h[j] * W_out[j][k]; y[k] = sigmoid(out); } predict[p] = (int)floor(y[0] + 0.5); I_vector.push_back(in_gate); F_vector.push_back(forget_gate); O_vector.push_back(out_gate); S_vector.push_back(state); G_vector.push_back(g_gate); h_vector.push_back(h); y_delta.push_back( (t - y[0]) * dsigmoid(y[0]) ); e += fabs(t - y[0]); } double h_delta[hidenode]; double *O_delta = new double[hidenode]; double *I_delta = new double[hidenode]; double *F_delta = 
new double[hidenode]; double *G_delta = new double[hidenode]; double *state_delta = new double[hidenode]; double *O_future_delta = new double[hidenode]; double *I_future_delta = new double[hidenode]; double *F_future_delta = new double[hidenode]; double *G_future_delta = new double[hidenode]; double *state_future_delta = new double[hidenode]; double *forget_gate_future = new double[hidenode]; for(j=0; j<hidenode; j++) { O_future_delta[j] = 0; I_future_delta[j] = 0; F_future_delta[j] = 0; G_future_delta[j] = 0; state_future_delta[j] = 0; forget_gate_future[j] = 0; } for(p=binary_dim-1; p>=0 ; p--) { x[0] = a[p]; x[1] = b[p]; double *in_gate = I_vector[p]; double *out_gate = O_vector[p]; double *forget_gate = F_vector[p]; double *g_gate = G_vector[p]; double *state = S_vector[p+1]; double *h = h_vector[p+1]; double *h_pre = h_vector[p]; double *state_pre = S_vector[p]; for(k=0; k<outnode; k++) { for(j=0; j<hidenode; j++) W_out[j][k] += alpha * y_delta[p] * h[j]; } for(j=0; j<hidenode; j++) { h_delta[j] = 0.0; for(k=0; k<outnode; k++) { h_delta[j] += y_delta[p] * W_out[j][k]; } for(k=0; k<hidenode; k++) { h_delta[j] += I_future_delta[k] * U_I[j][k]; h_delta[j] += F_future_delta[k] * U_F[j][k]; h_delta[j] += O_future_delta[k] * U_O[j][k]; h_delta[j] += G_future_delta[k] * U_G[j][k]; } O_delta[j] = 0.0; I_delta[j] = 0.0; F_delta[j] = 0.0; G_delta[j] = 0.0; state_delta[j] = 0.0; O_delta[j] = h_delta[j] * tanh(state[j]) * dsigmoid(out_gate[j]); state_delta[j] = h_delta[j] * out_gate[j] * dtanh(state[j]) + state_future_delta[j] * forget_gate_future[j]; F_delta[j] = state_delta[j] * state_pre[j] * dsigmoid(forget_gate[j]); I_delta[j] = state_delta[j] * g_gate[j] * dsigmoid(in_gate[j]); G_delta[j] = state_delta[j] * in_gate[j] * dsigmoid(g_gate[j]); for(k=0; k<hidenode; k++) { U_I[k][j] += alpha * I_delta[j] * h_pre[k]; U_F[k][j] += alpha * F_delta[j] * h_pre[k]; U_O[k][j] += alpha * O_delta[j] * h_pre[k]; U_G[k][j] += alpha * G_delta[j] * h_pre[k]; } for(k=0; k<innode; k++) 
{ W_I[k][j] += alpha * I_delta[j] * x[k]; W_F[k][j] += alpha * F_delta[j] * x[k]; W_O[k][j] += alpha * O_delta[j] * x[k]; W_G[k][j] += alpha * G_delta[j] * x[k]; } } if(p == binary_dim-1) { delete O_future_delta; delete F_future_delta; delete I_future_delta; delete G_future_delta; delete state_future_delta; delete forget_gate_future; } O_future_delta = O_delta; F_future_delta = F_delta; I_future_delta = I_delta; G_future_delta = G_delta; state_future_delta = state_delta; forget_gate_future = forget_gate; } delete O_future_delta; delete F_future_delta; delete I_future_delta; delete G_future_delta; delete state_future_delta; if(epoch % 1000 == 0) { cout << "error:" << e << endl; cout << "pred:" ; for(k=binary_dim-1; k>=0; k--) cout << predict[k]; cout << endl; cout << "true:" ; for(k=binary_dim-1; k>=0; k--) cout << c[k]; cout << endl; int out = 0; for(k=binary_dim-1; k>=0; k--) out += predict[k] * pow(2, k); cout << a_int << " + " << b_int << " = " << out << endl << endl; } for(i=0; i<I_vector.size(); i++) delete I_vector[i]; for(i=0; i<F_vector.size(); i++) delete F_vector[i]; for(i=0; i<O_vector.size(); i++) delete O_vector[i]; for(i=0; i<G_vector.size(); i++) delete G_vector[i]; for(i=0; i<S_vector.size(); i++) delete S_vector[i]; for(i=0; i<h_vector.size(); i++) delete h_vector[i]; I_vector.clear(); F_vector.clear(); O_vector.clear(); G_vector.clear(); S_vector.clear(); h_vector.clear(); y_delta.clear(); }}int main(){ srand(time(NULL)); RNN rnn; rnn.train(); return 0;}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
- 168
- 169
- 170
- 171
- 172
- 173
- 174
- 175
- 176
- 177
- 178
- 179
- 180
- 181
- 182
- 183
- 184
- 185
- 186
- 187
- 188
- 189
- 190
- 191
- 192
- 193
- 194
- 195
- 196
- 197
- 198
- 199
- 200
- 201
- 202
- 203
- 204
- 205
- 206
- 207
- 208
- 209
- 210
- 211
- 212
- 213
- 214
- 215
- 216
- 217
- 218
- 219
- 220
- 221
- 222
- 223
- 224
- 225
- 226
- 227
- 228
- 229
- 230
- 231
- 232
- 233
- 234
- 235
- 236
- 237
- 238
- 239
- 240
- 241
- 242
- 243
- 244
- 245
- 246
- 247
- 248
- 249
- 250
- 251
- 252
- 253
- 254
- 255
- 256
- 257
- 258
- 259
- 260
- 261
- 262
- 263
- 264
- 265
- 266
- 267
- 268
- 269
- 270
- 271
- 272
- 273
- 274
- 275
- 276
- 277
- 278
- 279
- 280
- 281
- 282
- 283
- 284
- 285
- 286
- 287
- 288
- 289
- 290
- 291
- 292
- 293
- 294
- 295
- 296
- 297
- 298
- 299
- 300
- 301
- 302
- 303
- 304
- 305
- 306
- 307
- 308
- 309
- 310
- 311
- 312
- 313
- 314
- 315
- 316
- 317
- 318
- 319
- 320
- 321
- 322
- 323
- 324
- 325
- 326
- 327
- 328
- 329
- 330
- 331
- 332
- 333
- 334
- 335
- 336
- 337
- 338
- 339
- 340
- 341
- 342
- 343
- 344
- 345
- 346
- 347
- 348
- 349
- 350
- 351
- 352
- 353
- 354
- 355
- 356
- 357
- 358
- 359
- 360
- 361
- 362
- 363
- 364
- 365
- 366
- 367
- 368
- 369
- 370
- 371
- 372
- 373
- 374
- 375
- 376
- 377
- 378
- 379
- 380
- 381
- 382
- 383
- 384
- 385
- 386
- 387
- 388
- 389
- 390
- 391
- 392
- 393
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
- 168
- 169
- 170
- 171
- 172
- 173
- 174
- 175
- 176
- 177
- 178
- 179
- 180
- 181
- 182
- 183
- 184
- 185
- 186
- 187
- 188
- 189
- 190
- 191
- 192
- 193
- 194
- 195
- 196
- 197
- 198
- 199
- 200
- 201
- 202
- 203
- 204
- 205
- 206
- 207
- 208
- 209
- 210
- 211
- 212
- 213
- 214
- 215
- 216
- 217
- 218
- 219
- 220
- 221
- 222
- 223
- 224
- 225
- 226
- 227
- 228
- 229
- 230
- 231
- 232
- 233
- 234
- 235
- 236
- 237
- 238
- 239
- 240
- 241
- 242
- 243
- 244
- 245
- 246
- 247
- 248
- 249
- 250
- 251
- 252
- 253
- 254
- 255
- 256
- 257
- 258
- 259
- 260
- 261
- 262
- 263
- 264
- 265
- 266
- 267
- 268
- 269
- 270
- 271
- 272
- 273
- 274
- 275
- 276
- 277
- 278
- 279
- 280
- 281
- 282
- 283
- 284
- 285
- 286
- 287
- 288
- 289
- 290
- 291
- 292
- 293
- 294
- 295
- 296
- 297
- 298
- 299
- 300
- 301
- 302
- 303
- 304
- 305
- 306
- 307
- 308
- 309
- 310
- 311
- 312
- 313
- 314
- 315
- 316
- 317
- 318
- 319
- 320
- 321
- 322
- 323
- 324
- 325
- 326
- 327
- 328
- 329
- 330
- 331
- 332
- 333
- 334
- 335
- 336
- 337
- 338
- 339
- 340
- 341
- 342
- 343
- 344
- 345
- 346
- 347
- 348
- 349
- 350
- 351
- 352
- 353
- 354
- 355
- 356
- 357
- 358
- 359
- 360
- 361
- 362
- 363
- 364
- 365
- 366
- 367
- 368
- 369
- 370
- 371
- 372
- 373
- 374
- 375
- 376
- 377
- 378
- 379
- 380
- 381
- 382
- 383
- 384
- 385
- 386
- 387
- 388
- 389
- 390
- 391
- 392
- 393
0 0