-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathLnLSTM.hpp
More file actions
61 lines (48 loc) · 1.68 KB
/
LnLSTM.hpp
File metadata and controls
61 lines (48 loc) · 1.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#pragma once
#include "LSTM.hpp"
#include "LayerNormalizer.hpp"
class LnLSTM : public LSTM{
public:
LnLSTM(){}
LnLSTM(const int inputDim, const int hiddenDim);
LnLSTM(const int inputDim, const int additionalInputDim, const int hiddenDim);
class State;
class Grad;
LayerNormalizer lnh, lnx, lnc, lna;
void init(Rand& rnd, const Real scale = 1.0);
void forward(const VecD& xt, const LSTM::State* prev, LSTM::State* cur);
void forward(const VecD& xt, LSTM::State* cur);
void backward(LSTM::State* prev, LSTM::State* cur, LSTM::Grad& grad, const VecD& xt);
void backward(LSTM::State* cur, LSTM::Grad& grad, const VecD& xt);
void sgd(const LnLSTM::Grad& grad, const Real learningRate);
void save(std::ofstream& ofs);
void load(std::ifstream& ifs);
void forward(const VecD& xt, const VecD& at, const LSTM::State* prev, LSTM::State* cur);
void backward(LSTM::State* prev, LSTM::State* cur, LSTM::Grad& grad, const VecD& xt, const VecD& at);
};
/// Per-step state for LnLSTM.
///
/// Extends LSTM::State with four heap-allocated LayerNormalizer::State objects
/// (lnsh, lnsx, lnsc, lnsa) that this object allocates in its constructor and
/// deletes in its destructor, plus extra VecD buffers.
///
/// Because the object owns raw pointers, copying is disabled. The implicitly
/// generated copy operations would copy the pointer values only, and the
/// destructors of both objects would then delete the same LayerNormalizer::State
/// instances twice.
class LnLSTM::State: public LSTM::State{
public:
  // NOTE(review): if one of the later `new` expressions throws (bad_alloc),
  // the objects allocated before it are leaked, because a constructor that
  // did not complete never runs its destructor.
  State(): lnsh(new LayerNormalizer::State), lnsx(new LayerNormalizer::State), lnsc(new LayerNormalizer::State), lnsa(new LayerNormalizer::State){}

  // Non-copyable (see class comment). No move operations are declared either,
  // so the type is neither copyable nor movable; handle it through pointers.
  State(const State&) = delete;
  State& operator=(const State&) = delete;

  /// Calls clear() and then deletes the four owned LayerNormalizer::State objects.
  ~State() {
    this->clear();
    delete this->lnsh;
    delete this->lnsx;
    delete this->lnsc;
    delete this->lnsa;
  }

  // Owned layer-normalizer states (allocated in the constructor, deleted in
  // the destructor).
  LayerNormalizer::State* lnsh;
  LayerNormalizer::State* lnsx;
  LayerNormalizer::State* lnsc;
  LayerNormalizer::State* lnsa;

  // Additional work buffers; their exact use is defined out of line -- TODO
  // document once the .cpp is reviewed.
  VecD lnhConcat, lnxConcat, lnaConcat;
  VecD delConcat;

  void clear();
};
/// Gradient container for LnLSTM.
///
/// Extends LSTM::Grad with one LayerNormalizer::Grad per layer normalizer
/// (the member names mirror LnLSTM::lnh/lnx/lnc/lna). All member functions
/// are defined out of line.
class LnLSTM::Grad: public LSTM::Grad{
public:
  Grad(){}
  // NOTE(review): this single-argument constructor is not `explicit`, so an
  // LnLSTM converts implicitly to a temporary Grad wherever a
  // `const LnLSTM::Grad&` is expected (e.g. sgd() or operator+=). Consider
  // marking it explicit after checking callers for copy-initialization.
  Grad(const LnLSTM& lnlstm);

  void init();
  Real norm();
  void operator += (const LnLSTM::Grad& grad);

  // Layer-normalizer gradients, declared in the order lnh, lnx, lnc, lna.
  LayerNormalizer::Grad lnh;
  LayerNormalizer::Grad lnx;
  LayerNormalizer::Grad lnc;
  LayerNormalizer::Grad lna;
};