2010-12-02 52 views
9

PyBrain是一個python庫,提供(除其他外)易於使用的人工神經網絡。如何序列化/反序列化的pybrain網絡?

我無法正確使用pickle或cPickle對PyBrain網絡進行序列化/反序列化。

參見下面的示例:

from pybrain.datasets   import SupervisedDataSet 
from pybrain.tools.shortcuts  import buildNetwork 
from pybrain.supervised.trainers import BackpropTrainer 
import cPickle as pickle 
import numpy as np 

#generate some data 
np.random.seed(93939393) 
data = SupervisedDataSet(2, 1) 
for x in xrange(10): 
    y = x * 3 
    z = x + y + 0.2 * np.random.randn() 
    data.addSample((x, y), (z,)) 

#build a network and train it  

net1 = buildNetwork(data.indim, 2, data.outdim) 
trainer1 = BackpropTrainer(net1, dataset=data, verbose=True) 
for i in xrange(4): 
    trainer1.trainEpochs(1) 
    print '\tvalue after %d epochs: %.2f'%(i, net1.activate((1, 4))[0]) 

這是上面的代碼的輸出:

Total error: 201.501998476 
    value after 0 epochs: 2.79 
Total error: 152.487616382 
    value after 1 epochs: 5.44 
Total error: 120.48092561 
    value after 2 epochs: 7.56 
Total error: 97.9884043452 
    value after 3 epochs: 8.41 

正如你可以看到,隨着訓練的進行網絡總誤差降低。你也可以看到,預測值接近的12

預期值,現在我們會做類似的工作,但將包括序列化/反序列化:

print 'creating net2' 
net2 = buildNetwork(data.indim, 2, data.outdim) 
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True) 
trainer2.trainEpochs(1) 
print '\tvalue after %d epochs: %.2f'%(1, net2.activate((1, 4))[0]) 

#So far, so good. Let's test pickle 
pickle.dump(net2, open('testNetwork.dump', 'w')) 
net2 = pickle.load(open('testNetwork.dump')) 
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True) 
print 'loaded net2 using pickle, continue training' 
for i in xrange(1, 4): 
     trainer2.trainEpochs(1) 
     print '\tvalue after %d epochs: %.2f'%(i, net2.activate((1, 4))[0]) 

這是此塊的輸出:

creating net2 
Total error: 176.339378639 
    value after 1 epochs: 5.45 
loaded net2 using pickle, continue training 
Total error: 123.392181859 
    value after 1 epochs: 5.45 
Total error: 94.2867637623 
    value after 2 epochs: 5.45 
Total error: 78.076711114 
    value after 3 epochs: 5.45 

正如你所看到的,似乎培訓網絡(報告的總誤差值繼續下降),但是網絡的輸出值凍結的,這是相關的第一個值上有一定的影響訓練迭代。

是否有任何需要注意的緩存機制會導致此錯誤行爲?是否有更好的方法來序列化/反序列化pybrain網絡?

相關版本號:

  • 的Python 2.6.5(R265:79096,2010年3月19日,21時48分26秒)[MSC v.1500 32位(英特爾)]
  • numpy的1.5。 1
  • cPickle的1.71
  • pybrain 0.3

PS我已經在該項目的網站上創建a bug report,將保留兩個SO和bug跟蹤系統updatedj

+0

您確定在重新加載`net2`後,不應該再次執行`trainer2 = BackpropTrainer(net2,dataset = data,verbose = True)`? – 2010-12-02 13:15:02

回答

11

原因

導致這種行爲的機制是在PyBrain參數(.params)和衍生物(.derivs)處理模塊:實際上,所有網絡參數都存儲在一個陣列中,但個別的ModuleConnection對象可以訪問「他們自己的」.params,但它們只是整個陣列的一部分的視圖。這允許在相同的數據結構上進行本地和全網絡的寫入和讀出。

很顯然,這個切片視圖鏈接會因酸洗而不知所措。

解決方案

插入

net2.sorted = False 
net2.sortModules() 

從文件(其中再現了這個共享)加載後,它應該工作。