2017-08-09

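I see that there is an implementation of k-means clustering in the tensorflow contrib library. However, I can't get it to do something as simple as estimating the cluster centers of 2D points. How does k-means clustering work in TensorFlow?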

Code:

import numpy as np
import tensorflow as tf

## Generate synthetic data
N, D = 1000, 2  # number of points per cluster and dimensionality

means = np.array([[0.5, 0.0], 
        [0, 0], 
        [-0.5, -0.5], 
        [-0.8, 0.3]]) 
covs = np.array([np.diag([0.01, 0.01]), 
       np.diag([0.01, 0.01]), 
       np.diag([0.01, 0.01]), 
       np.diag([0.01, 0.01])]) 
n_clusters = means.shape[0] 

points = [] 
for i in range(n_clusters): 
    x = np.random.multivariate_normal(means[i], covs[i], N) 
    points.append(x) 
points = np.concatenate(points) 

## construct model 
kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters) 
kmeans.fit(points.astype(np.float32)) 

I get the following error:

InvalidArgumentError (see above for traceback): Shape [-1,2] has negative dimensions 
    [[Node: input = Placeholder[dtype=DT_FLOAT, shape=[?,2], _device="/job:localhost/replica:0/task:0/cpu:0"]()]] 

I think I'm doing something wrong, but I can't figure out what from the documentation.

Edit

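I solved it by using input_fn, but it is really slow (I had to reduce the number of points in each cluster to 10 to see any results). Why is that, and how can I make it faster?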

def input_fn(): 
    return tf.constant(points, dtype=tf.float32), None 

## construct model 
kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters, relative_tolerance=0.0001) 
kmeans.fit(input_fn=input_fn) 
centers = kmeans.clusters() 
print(centers) 

Solved:

It seems that relative_tolerance has to be set. I changed only this one line and it works fine:

kmeans = tf.contrib.learn.KMeansClustering(num_clusters = n_clusters, relative_tolerance=0.0001)
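The thread does not show how relative_tolerance is applied inside TensorFlow. As a hedged illustration only, assume it is compared against the change in loss between consecutive iterations, relative to the previous loss (the exact formula TF uses is an assumption here, and the helper names below are hypothetical). Under that assumption a run stops at the first iteration whose loss barely differs from the previous one, while with no effective tolerance nothing ever triggers a stop, which may explain the unbounded runtime described in the edit.

```python
def should_stop(prev_loss, loss, relative_tolerance):
    """Hypothetical stopping rule: True once the loss change is small relative to the previous loss."""
    if prev_loss is None:  # nothing to compare against on the first iteration
        return False
    return abs(prev_loss - loss) < relative_tolerance * abs(prev_loss)


def iterations_until_converged(losses, relative_tolerance):
    """Return the 1-based iteration at which the rule fires, or None if it never does."""
    prev_loss = None
    for step, loss in enumerate(losses, start=1):
        if should_stop(prev_loss, loss, relative_tolerance):
            return step
        prev_loss = loss
    return None


# Made-up loss curve for illustration only (not output from TensorFlow).
losses = [120.0, 40.0, 12.5, 10.01, 10.0, 10.0]
print(iterations_until_converged(losses, 0.0001))  # 6
print(iterations_until_converged(losses, 0.01))    # 5
print(iterations_until_converged(losses, 0.0))     # None: a zero tolerance never fires
```

A looser tolerance stops earlier at the cost of slightly less converged centers. Bounding the number of training steps (for example via the steps argument of fit) is the other usual way to cap runtime.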

What version of TF are you running?

Answer

With TensorFlow 1.2, your original code produces the following deprecation warning:

WARNING:tensorflow:From <stdin>:1: calling BaseEstimator.fit (from   
    tensorflow.contrib.learn.python.learn.estimators.estimator) with x 
    is deprecated and will be removed after 2016-12-01. 
    Instructions for updating: 
    Estimator is decoupled from Scikit Learn interface by moving into 
    separate class SKCompat. Arguments x, y and batch_size are only 
    available in the SKCompat class, Estimator will only accept input_fn. 

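Based on your edit, it looks like you figured out that input_fn is the only accepted input. If you really want to use TF, I would upgrade to r1.2 and wrap the Estimator in the SKCompat class, as the warning message suggests. Otherwise, I would just use the scikit-learn package. You can also implement your own clustering algorithm in TF by hand, as shown in this blog.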
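For reference, here is a hedged sketch of what a hand-rolled k-means computes, written in plain NumPy rather than TensorFlow, so it is neither the contrib estimator's code nor the linked blog's. It assumes Lloyd's algorithm with k-means++ seeding, a few random restarts, and a relative-tolerance stopping rule on the loss; the function names are illustrative only.

```python
import numpy as np


def kmeans_plus_plus_init(points, n_clusters, rng):
    """Pick initial centers with k-means++ (D^2) sampling."""
    centers = [points[rng.integers(len(points))]]
    for _ in range(1, n_clusters):
        # Squared distance from each point to its nearest already-chosen center.
        d2 = ((points[:, None, :] - np.array(centers)[None, :, :]) ** 2).sum(axis=2).min(axis=1)
        centers.append(points[rng.choice(len(points), p=d2 / d2.sum())])
    return np.array(centers)


def lloyd_kmeans(points, n_clusters, relative_tolerance=1e-4, max_iter=100, rng=None):
    """One run of Lloyd's algorithm; returns (centers, labels, loss)."""
    rng = np.random.default_rng() if rng is None else rng
    centers = kmeans_plus_plus_init(points, n_clusters, rng)
    prev_loss = None
    for _ in range(max_iter):
        # Assignment step: each point goes to its nearest center (squared Euclidean distance).
        d2 = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1)
        loss = d2[np.arange(len(points)), labels].sum()
        # Update step: move each center to the mean of its points; an empty cluster keeps its center.
        centers = np.array([
            points[labels == k].mean(axis=0) if np.any(labels == k) else centers[k]
            for k in range(n_clusters)
        ])
        # Stop when the loss changes by less than relative_tolerance relative to the previous loss.
        if prev_loss is not None and abs(prev_loss - loss) < relative_tolerance * prev_loss:
            break
        prev_loss = loss
    return centers, labels, loss


def kmeans(points, n_clusters, n_init=5, seed=0, **kwargs):
    """Run Lloyd's algorithm n_init times and keep the run with the lowest loss."""
    rng = np.random.default_rng(seed)
    runs = [lloyd_kmeans(points, n_clusters, rng=rng, **kwargs) for _ in range(n_init)]
    return min(runs, key=lambda run: run[2])


# Same kind of synthetic data as in the question: 4 Gaussian blobs of 1000 points each.
data_rng = np.random.default_rng(42)
means = np.array([[0.5, 0.0], [0.0, 0.0], [-0.5, -0.5], [-0.8, 0.3]])
points = np.concatenate(
    [data_rng.multivariate_normal(m, np.diag([0.01, 0.01]), 1000) for m in means]
)
centers, labels, loss = kmeans(points, n_clusters=4)
print(centers)
```

The restarts guard against Lloyd's algorithm settling into a poor local minimum (for example, two centers inside one blob); scikit-learn's KMeans handles the same issue through its n_init parameter.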

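Thanks, I figured it out. One question, though: what if my points are in a variable? Does it work the same way, or do I need to do something different (such as evaluating it before feeding it into the k-means clustering)?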

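The SKCompat wrapper does not accept TF tensors as input to the estimator, which rules out placeholders and variables. So evaluating the variable first (to get a NumPy array) and feeding that in should work!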