-
Notifications
You must be signed in to change notification settings - Fork 76
/
seq.lua
46 lines (40 loc) · 1.06 KB
/
seq.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
--- Builders for an encoder/decoder sequence model made of `nn` modules
-- (LSTM, Sequencer, SequencerCriterion, ...).
-- NOTE(review): `nn` is read as a global and is never required here; this
-- assumes the caller has already loaded torch `nn` and the package providing
-- `nn.LSTM` / `nn.Sequencer` (presumably `rnn`) — confirm.
local seq = {}
--- Fan a single input out into `N` identical copies.
-- Builds an `nn.ConcatTable` holding `N` `nn.Identity` branches, so its
-- output is a table with `N` entries, each being the unchanged input.
-- @tparam number N number of copies (no branches are added when N < 1)
-- @return the `nn.ConcatTable` module
local function repl(N)
  local fanout = nn.ConcatTable()
  for _ = 1, N do
    fanout:add(nn.Identity())
  end
  return fanout
end
--- Build the encoder: `net`, then a sequenced LSTM, keeping only the last step.
-- @param net module whose output is fed to `nn.Sequencer(nn.LSTM(...))`;
--   NOTE(review): presumably it yields a table of per-step tensors of width
--   `hsize` — confirm against callers.
-- @tparam number hsize LSTM input size and hidden size
-- @return an `nn.Sequential` whose output is the LSTM output of the final
--   step (selected with `nn.SelectTable(-1)`)
function seq.encoder(net, hsize)
  return nn.Sequential()
    :add(net)
    :add(nn.Sequencer(nn.LSTM(hsize, hsize)))
    :add(nn.SelectTable(-1))
end
--- Build the decoder for a single time step: LSTM -> Linear -> LogSoftMax.
-- Unlike the LSTM in `seq.encoder`, the module returned here is NOT wrapped
-- in `nn.Sequencer`, so it consumes one step at a time.
-- @tparam number hsize LSTM input/hidden size; also the Linear layer's input width
-- @tparam number csize Linear output width (presumably the number of output
--   classes, given the LogSoftMax that follows — confirm)
-- @return an `nn.Sequential` producing log-probabilities of width `csize`
function seq.decoder(hsize, csize)
  local dec = nn.Sequential()
    :add(nn.LSTM(hsize, hsize))
    :add(nn.Linear(hsize, csize))
    :add(nn.LogSoftMax())
  return dec
end
--- Build a complete encoder/decoder model and its training criterion.
-- The encoder's final-step output is replicated `osize` times. The copies
-- are then run, step by step, through the decoder.
-- @param net input module passed to `seq.encoder` (NOTE(review): presumably
--   emits a table of per-step tensors of width `hsize` — confirm)
-- @tparam number hsize LSTM input/hidden size for encoder and decoder
-- @tparam number csize decoder output width (presumably number of classes)
-- @tparam number osize number of decoder steps to produce
-- @return encdec the `nn.Sequential` model; its output is a table with one
--   decoder output per step
-- @return criterion `nn.SequencerCriterion` applying `nn.ClassNLLCriterion`
--   to each step
function seq.seq(net, hsize, csize, osize)
  -- The builders are named `encoder`/`decoder`. There is no `seq.enc` or
  -- `seq.dec`, so calling those raised "attempt to call a nil value".
  local enc = seq.encoder(net, hsize)
  -- repl(osize) outputs a table of tensors, but the decoder handles a single
  -- step. nn.Sequencer applies it to every element and returns a table,
  -- which is also the output shape SequencerCriterion expects.
  local dec = nn.Sequencer(seq.decoder(hsize, csize))
  local encdec = nn.Sequential()
    :add(enc)
    :add(repl(osize))
    :add(dec)
  local criterion = nn.SequencerCriterion(nn.ClassNLLCriterion())
  return encdec, criterion
end

return seq