2016-07-24

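I have a class in TensorFlow that holds weights and document embeddings, and I use it for both training and validation. My question: within a TensorFlow session, can the validation model reuse only the weights learned during training (not the embeddings) and learn new document embeddings for the validation set? Code snippet below. How can I reuse only some of the variables in TensorFlow?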

import tensorflow as tf  # TensorFlow 1.x API

class NewModel(object):
    def __init__(self, is_training, vocabulary_size, embedding_size):
        self.X = tf.placeholder("float", [None, 300])
        self.doc_int = tf.placeholder(tf.int32, shape=[None])

        self.embeddings = tf.get_variable("embedding", [vocabulary_size, embedding_size],
                                          initializer=tf.random_uniform_initializer(-0.1, 0.1))
        self.embedval = tf.nn.embedding_lookup(self.embeddings, self.doc_int)
        self.weights = tf.get_variable("weights", weight_shape, initializer=tf.random_normal_initializer())
        biases = tf.get_variable("biases", bias_shape, initializer=tf.constant_initializer(0.0))
        # Some neural network with optimizer and loss that will train the weights and embeddings...

with tf.Graph().as_default(), tf.Session() as sess:

    initializer = tf.random_uniform_initializer()
    with tf.variable_scope("foo", reuse=None, initializer=initializer):
        train = NewModel(is_training=True, vocabulary_size=4000,
                         embedding_size=50)
    with tf.variable_scope("foo", reuse=True, initializer=initializer):
        valid = NewModel(is_training=False, vocabulary_size=1000, embedding_size=50)
    # Here is where I am confused. I want to use the trained weight variables but not the
    # embeddings, and I want new embeddings to be trained for the valid set.
    tf.initialize_all_variables().run()
    # will call some function to run epochs and stuff

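Perhaps using different scope names would help, but I would still like some advice on it. Or is it possible to specify somewhere which variables should be reused?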
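A separate scope for the embeddings is a workable way to get this. In TensorFlow 1.x, `tf.variable_scope(..., reuse=True)` makes `tf.get_variable` return an existing variable instead of creating a new one, so the validation model can open the weights' scope with `reuse=True` and give its embeddings a scope of its own. The framework-free toy below sketches that get-or-create logic under stated assumptions: the `VariableStore` class, the scope names `shared`, `train_embed` and `valid_embed`, and the `(embed_size, 2)` weight shape are all hypothetical and are not TensorFlow's API. It also shows why sharing the embeddings in the question's setup cannot work: a 1000-row embedding does not match the 4000-row one that already exists.

```python
import random

# Toy stand-in for TensorFlow's variable store (NOT the real API): variables are
# keyed by "scope/name", and `reuse` decides between fetching and creating.
class VariableStore:
    def __init__(self):
        self._vars = {}  # key -> (shape, value)

    def get_variable(self, scope, name, shape, reuse, init):
        key = f"{scope}/{name}"
        if reuse:
            if key not in self._vars:
                raise ValueError(f"{key} does not exist, cannot reuse it")
            stored_shape, value = self._vars[key]
            if stored_shape != shape:
                raise ValueError(f"{key}: shape {shape} does not match {stored_shape}")
            return value
        if key in self._vars:
            raise ValueError(f"{key} already exists")
        rows, cols = shape
        value = [[init() for _ in range(cols)] for _ in range(rows)]
        self._vars[key] = (shape, value)
        return value


def build_model(store, embed_scope, vocab_size, embed_size, reuse_weights):
    # Weights live in one shared scope; each model gets its own embedding scope.
    weights = store.get_variable("shared", "weights", (embed_size, 2),
                                 reuse=reuse_weights, init=random.random)
    embeddings = store.get_variable(embed_scope, "embedding",
                                    (vocab_size, embed_size), reuse=False,
                                    init=lambda: random.uniform(-0.1, 0.1))
    return weights, embeddings


store = VariableStore()
train_w, train_emb = build_model(store, "train_embed", 4000, 50, reuse_weights=False)
valid_w, valid_emb = build_model(store, "valid_embed", 1000, 50, reuse_weights=True)

# Sharing the embeddings as well would fail, because 1000 rows != 4000 rows.
try:
    store.get_variable("train_embed", "embedding", (1000, 50), reuse=True, init=random.random)
    shared_embeddings_ok = True
except ValueError:
    shared_embeddings_ok = False
```

Here `valid_w` is the very same object as `train_w`, while `valid_emb` is a fresh table sized for the validation vocabulary. In real TensorFlow 1.x code, the same split would be expressed with two `tf.variable_scope` blocks: one for the weights (opened with `reuse=True` for the validation model) and one per model for the embeddings.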

Answer


I would probably restructure the NewModel class.

import tensorflow as tf  # TensorFlow 1.x API

class NewModel(object):
    def __init__(self, vocabulary_size, embedding_size, initializer):
        self.X = tf.placeholder("float", [None, 300])
        self.doc_int = tf.placeholder(tf.int32, shape=[None])
        self.vocabulary_size = vocabulary_size
        self.embedding_size = embedding_size
        self.initializer = initializer

    def initialize_embeddings(self):
        # Embeddings get their own scope so they can be handled separately from the weights.
        with tf.variable_scope("embed"):
            self.embeddings = tf.get_variable("embedding", [self.vocabulary_size, self.embedding_size],
                                              initializer=self.initializer)
            self.embedval = tf.nn.embedding_lookup(self.embeddings, self.doc_int)

    def initialize_weights(self, weight_shape, bias_shape):
        # Weights and biases live in a separate scope from the embeddings.
        with tf.variable_scope("weight"):
            self.weights = tf.get_variable("weights", weight_shape, initializer=self.initializer)
            self.biases = tf.get_variable("biases", bias_shape, initializer=tf.constant_initializer(0.0))

    def train_network(self):
        # Some neural network with optimizer and loss that will train the weights and embeddings...
        pass

    def validate_network(self):
        # A function for the validation process
        pass

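This way you separate the embedding initialization from the weight and bias initialization. Using the new class would look something like this: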

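with tf.Graph().as_default(), tf.Session() as sess:

    initializer = tf.random_uniform_initializer()
    model = NewModel(vocabulary_size=4000, embedding_size=50, initializer=initializer)  # construct a model instance
    model.initialize_weights(weight_shape, bias_shape)  # create the weights and biases
    model.initialize_embeddings()  # create the embeddings
    sess.run(tf.global_variables_initializer())  # give every variable its initial value
    model.train_network()  # train the network
    # Before starting validation, re-initialize only the embeddings. Calling initialize_embeddings()
    # again would try to create embed/embedding a second time and raise an error, so run that
    # variable's initializer op instead; the trained weights and biases keep their values.
    sess.run(model.embeddings.initializer)
    model.validate_network()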
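The point of this workflow is that only the embedding group receives fresh values while the trained weights carry over into validation. A plain-Python sketch of that idea, assuming parameters are kept in a dict of nested lists; the helpers `init_embeddings`, `init_weights` and `fake_train_step` are hypothetical stand-ins and not part of the answer's code:

```python
import copy
import random

# Hypothetical initializers standing in for the TF variables in the answer.
def init_embeddings(vocab_size, embed_size, rng):
    return [[rng.uniform(-0.1, 0.1) for _ in range(embed_size)] for _ in range(vocab_size)]

def init_weights(in_size, out_size, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(out_size)] for _ in range(in_size)]

def fake_train_step(params, lr=0.01):
    # Placeholder for real training: shift every parameter in place so that
    # trained values differ from freshly initialized ones.
    for group in params.values():
        for row in group:
            for j in range(len(row)):
                row[j] -= lr

rng = random.Random(0)
params = {
    "weights": init_weights(50, 2, rng),
    "embeddings": init_embeddings(4000, 50, rng),
}
fake_train_step(params)
trained_weights = copy.deepcopy(params["weights"])

# Before validation: re-initialize only the embeddings (sized for the
# validation vocabulary); the trained weights are left untouched.
params["embeddings"] = init_embeddings(1000, 50, rng)
```

In TensorFlow 1.x the analogous step is running just the embedding variable's initializer op (`sess.run(model.embeddings.initializer)`) rather than an initializer over all variables.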