2016-12-13 64 views
3

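Problem with KNN implementation in TensorFlow

I'm struggling to implement K-Nearest Neighbor in TensorFlow. I think I'm either overlooking a mistake or doing something terribly wrong.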

The code below always predicts the MNIST label as 0:

from __future__ import print_function 

import numpy as np 
import tensorflow as tf 

# Import MNIST data 
from tensorflow.examples.tutorials.mnist import input_data 

K = 4 
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) 

# Load the full MNIST training and test sets, one batch each
Xtr, Ytr = mnist.train.next_batch(55000) # whole training set 
Xte, Yte = mnist.test.next_batch(10000) # whole test set 

# tf Graph Input 
xtr = tf.placeholder("float", [None, 784]) 
ytr = tf.placeholder("float", [None, 10]) 
xte = tf.placeholder("float", [784]) 

# Euclidean Distance 
distance = tf.neg(tf.sqrt(tf.reduce_sum(tf.square(tf.sub(xtr, xte)), reduction_indices=1))) 
# Prediction: Get min distance neighbors 
values, indices = tf.nn.top_k(distance, k=K, sorted=False) 
nearest_neighbors = [] 
for i in range(K): 
    nearest_neighbors.append(np.argmax(ytr[indices[i]])) 

sorted_neighbors, counts = np.unique(nearest_neighbors, return_counts=True) 

pred = tf.Variable(nearest_neighbors[np.argmax(counts)]) 

# This doesn't work either
# neighbors_tensor = tf.pack(nearest_neighbors) 
# y, idx, count = tf.unique_with_counts(neighbors_tensor) 
# pred = tf.slice(y, begin=[tf.arg_max(count, 0)], size=tf.constant([1], dtype=tf.int64))[0] 

accuracy = 0. 

# Initializing the variables 
init = tf.initialize_all_variables() 

# Launch the graph 
with tf.Session() as sess: 
    sess.run(init) 

    # loop over test data 
    for i in range(len(Xte)):
        # Get nearest neighbor
        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i, :]})
        # Get nearest neighbor class label and compare it to its true label
        print("Test", i, "Prediction:", nn_index,
              "True Class:", np.argmax(Yte[i]))
        # Calculate accuracy
        if nn_index == np.argmax(Yte[i]):
            accuracy += 1./len(Xte)
    print("Done!") 
    print("Accuracy:", accuracy) 

Any help is greatly appreciated.

You shouldn't use 'numpy' functions inside the graph. I've corrected your code in my answer. – martianwars

Answer

7

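In general, it's not a good idea to use numpy functions while defining a TensorFlow model, and that is precisely why your code isn't working. I made just two changes to your code: I replaced np.argmax with tf.argmax, and I uncommented the block marked #This doesn't work either. (The listing also feeds ytr: Ytr in feed_dict, since the labels are now looked up inside the graph.)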
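The voting step that those uncommented lines perform can be illustrated outside any graph. The following is only a minimal NumPy sketch of a majority vote, not the code from this answer: `majority_vote` is a hypothetical helper name, and the first-occurrence tie-break is an assumption chosen for the illustration.

```python
import numpy as np

def majority_vote(neighbor_labels):
    """Return the most frequent label; ties go to the label seen first.

    Hypothetical helper for illustration only.
    """
    labels = np.asarray(neighbor_labels)
    # np.unique returns sorted values, so recover first-occurrence order explicitly.
    uniques, first_idx, counts = np.unique(
        labels, return_index=True, return_counts=True)
    order = np.argsort(first_idx)      # unique labels in order of first appearance
    uniques, counts = uniques[order], counts[order]
    # argmax indexes into the array of unique labels, not the original list.
    return int(uniques[np.argmax(counts)])

print(majority_vote([3, 7, 3, 1]))
```

Note that the index returned by argmax over the counts refers to the array of unique labels. Using it to index the original neighbor list, as `nearest_neighbors[np.argmax(counts)]` does in the question, picks an unrelated element.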

Here is the complete working code:

from __future__ import print_function 

import numpy as np 
import tensorflow as tf 

# Import MNIST data 
from tensorflow.examples.tutorials.mnist import input_data 

K = 4 
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) 

# Load the full MNIST training and test sets, one batch each
Xtr, Ytr = mnist.train.next_batch(55000) # whole training set 
Xte, Yte = mnist.test.next_batch(10000) # whole test set 

# tf Graph Input 
xtr = tf.placeholder("float", [None, 784]) 
ytr = tf.placeholder("float", [None, 10]) 
xte = tf.placeholder("float", [784]) 

# Euclidean Distance 
distance = tf.negative(tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(xtr, xte)), reduction_indices=1))) 
# Prediction: Get min distance neighbors 
values, indices = tf.nn.top_k(distance, k=K, sorted=False) 

nearest_neighbors = [] 
for i in range(K): 
    nearest_neighbors.append(tf.argmax(ytr[indices[i]], 0)) 

neighbors_tensor = tf.stack(nearest_neighbors) 
y, idx, count = tf.unique_with_counts(neighbors_tensor) 
pred = tf.slice(y, begin=[tf.argmax(count, 0)], size=tf.constant([1], dtype=tf.int64))[0] 

accuracy = 0. 

# Initializing the variables 
init = tf.initialize_all_variables() 

# Launch the graph 
with tf.Session() as sess: 
    sess.run(init) 

    # loop over test data 
    for i in range(len(Xte)):
        # Get nearest neighbor
        nn_index = sess.run(pred, feed_dict={xtr: Xtr, ytr: Ytr, xte: Xte[i, :]})
        # Get nearest neighbor class label and compare it to its true label
        print("Test", i, "Prediction:", nn_index,
              "True Class:", np.argmax(Yte[i]))
        # Calculate accuracy
        if nn_index == np.argmax(Yte[i]):
            accuracy += 1./len(Xte)
    print("Done!") 
    print("Accuracy:", accuracy) 
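
One way to sanity-check the graph's predictions is to compare them with a plain NumPy version of the same procedure on a tiny dataset. The sketch below is a reference under stated assumptions, not part of the code above: `knn_predict` and the toy arrays are hypothetical, and ties are broken toward the smallest class index, which may differ from the tf.unique_with_counts version.

```python
import numpy as np

def knn_predict(x_train, y_train_onehot, x_test, k):
    """Predict a class index for each row of x_test by majority vote
    among its k nearest training rows (Euclidean distance).

    Hypothetical reference helper; ties go to the smallest class index.
    """
    num_classes = y_train_onehot.shape[1]
    train_labels = np.argmax(y_train_onehot, axis=1)  # one-hot -> class index
    preds = []
    for x in x_test:
        # Distance from this test point to every training point.
        dist = np.sqrt(np.sum((x_train - x) ** 2, axis=1))
        # Indices of the k smallest distances.
        nearest = np.argsort(dist, kind="stable")[:k]
        # Count votes per class and pick the most frequent one.
        votes = np.bincount(train_labels[nearest], minlength=num_classes)
        preds.append(int(np.argmax(votes)))
    return np.array(preds)

# Toy data (made up for illustration): two well-separated 2-D clusters.
x_train = np.array([[0., 0.], [0., 1.], [1., 0.],
                    [5., 5.], [5., 6.], [6., 5.]])
y_train = np.eye(2)[[0, 0, 0, 1, 1, 1]]   # one-hot labels
x_test = np.array([[0.2, 0.1], [5.5, 5.5]])
y_test = np.array([0, 1])

pred = knn_predict(x_train, y_train, x_test, k=3)
print("Accuracy:", np.mean(pred == y_test))
```

Running the same small arrays through the graph with matching K and comparing the predictions would show whether the two versions agree.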

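I don't have time to dig into it, but I wanted to let you know as soon as possible: Traceback (most recent call last): line 28, in nearest_neighbors.append(tf.argmax(ytr[indices[i], :], 0)) TypeError: [...] of type [...] when running with Python 3.5 and TensorFlow r0.12 – wrecker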

I don't recall seeing this error in TensorFlow 0.11. I'll take a closer look, but in the meantime you could try removing the ':' – martianwars

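I've updated the code; let me know if this works better. Both this and the previous code worked for me in r0.11. Basically, there are some inconsistencies due to https://github.com/tensorflow/tensorflow/issues/206. – martianwars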