我正在學習LSTM網絡,並決定嘗試綜合測試。我想通過一些點(X,Y)供給LSTM網絡三個基本功能之間進行區分:如何訓練最簡單的功能識別LSTM
- 線爲:y = K * X + B
- 拋物線:Y = K * X^2 + B
- SQRT:Y = K * SQRT(X)+ b
我使用LUA +炬。
數據集是完全虛擬的 - 它是在「數據集」對象上實時創建的。當訓練週期要求樣本的另一個小樣本時,函數mt .__ index會返回動態創建的樣本。它隨機選擇三個所描述的功能併爲它們挑選一些隨機點。
想法是,LSTM網絡將學習一些功能,以識別最後點屬於哪種功能。
完整而簡單的源腳本包括:
require "torch"
require "nn"
require "rnn"
-- hyper-parameters
batchSize = 8
rho = 5 -- sequence length
hiddenSize = 100
outputSize = 3
lr = 0.001
-- Initialize synthetic dataset
-- dataset[index] returns table of the form: {inputs, targets}
-- where inputs is a set of points (x,y) of a randomly selected function: line, parabola, sqrt
-- and targets is a set of corresponding class of a function (1=line, 2=parabola, 3=sqrt)
local dataset = {}
dataset.size = function (self)
return 1000
end
local mt = {}
mt.__index = function (self, i)
local class = math.random(3)
local t = torch.Tensor(3):zero()
t[class] = 1
local targets = {}
for i = 1,batchSize do table.insert(targets, class) end
local inputs = {}
local k = math.random()
local b = math.random()*5
-- Line
if class == 1 then
for i = 1,batchSize do
local x = math.random()*10 + 5
local y = k*x + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
-- Parabola
elseif class == 2 then
for i = 1,batchSize do
local x = math.random()*10 + 5
local y = k*x*x + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
-- Sqrt
else
for i = 1,batchSize do
local x = math.random()*5 + 5
local y = k*math.sqrt(x) + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
end
return { inputs, targets }
end -- dataset.__index meta function
setmetatable(dataset, mt)
-- Initialize random number generator
math.randomseed(os.time())
-- build simple recurrent neural network
local model = nn.Sequencer(
nn.Sequential()
:add(nn.LSTM(2, hiddenSize, rho))
:add(nn.Linear(hiddenSize, outputSize))
:add(nn.LogSoftMax())
)
print(model)
-- build criterion
local criterion = nn.SequencerCriterion(nn.ClassNLLCriterion())
-- training
model:training()
local epoch = 1
while true do
print ("Epoch "..tostring(epoch).." started")
for iteration = 1, dataset:size() do
-- 1. Load minibatch of samples
local sample = dataset[iteration] -- pick random sample (dataset always returns random set)
local inputs = sample[1]
local targets = sample[2]
-- 2. Perform forward run and calculate error
local outputs = model:forward(inputs)
local err = criterion:forward(outputs, targets)
print(string.format("Epoch %d Iteration %d Error = %f", epoch, iteration, err))
-- 3. Backward sequence through model(i.e. backprop through time)
local gradOutputs = criterion:backward(outputs, targets)
-- Sequencer handles the backwardThroughTime internally
model:backward(inputs, gradOutputs)
model:updateParameters(lr)
model:zeroGradParameters()
end -- for dataset
epoch = epoch + 1
end -- while epoch
的問題是:網絡不收斂。 你能分享我做錯什麼嗎?
謝謝馬爾欽。我想我現在得到它。此外 - 所討論的代碼在每次迭代中生成隨機k和b - 這使NN無法學習任何特徵。我看到2個可能的解決方案:1.在啓動時只生成一次k和b。這意味着我們得到一些固定線,拋物線和sqrt。 2.依次生成輸入點。 雖然(2.)是可選的。我試圖實現(1.) - 它工作!經過1000次迭代後,NN能夠以99%的精度識別功能! –