2016-07-27 84 views
2

我編寫了for循環來枚舉包含n行28x28像素值的多維ndarray。在python中查找重複行的索引ndarray

我正在尋找重複的每一行的索引以及沒有冗餘的重複索引。

我發現此代碼here(感謝unutbu)並將其修改爲讀取ndarray,它可以工作70%的時間,但30%的時間將錯誤的圖像識別爲重複項。

如何改進以檢測正確的行?

def overlap_same(arr): 
seen = [] 
dups = collections.defaultdict(list) 
for i, item in enumerate(arr): 
    for j, orig in enumerate(seen): 
     if np.array_equal(item, orig): 
      dups[j].append(i) 
      break 
    else: 
     seen.append(item) 
return dups 

例如,返回overlap_same(火車)返回:

defaultdict(<type 'list'>, {34: [1388], 35: [1815], 583: [3045], 3208: 
[4426], 626: [824], 507: [4438], 188: [338, 431, 540, 757, 765, 806, 
808, 834, 882, 1515, 1539, 1715, 1725, 1789, 1841, 2038, 2081, 2165, 
2170, 2300, 2455, 2683, 2733, 2957, 3290, 3293, 3311, 3373, 3446, 3542, 
3565, 3890, 4110, 4197, 4206, 4364, 4371, 4734, 4851]}) 

繪製在matplotlib正確的情況下一些樣品給:

fig = plt.figure() 
a=fig.add_subplot(1,2,1) 
plt.imshow(train[35]) 
a.set_title('train[35]') 
a=fig.add_subplot(1,2,2) 
plt.imshow(train[1815]) 
a.set_title('train[1815]') 
plt.show 

train data 35 vs 1815

這是正確的

但是:

fig = plt.figure() 
a=fig.add_subplot(1,2,1) 
plt.imshow(train[3208]) 
a.set_title('train[3208]') 
a=fig.add_subplot(1,2,2) 
plt.imshow(train[4426]) 
a.set_title('train[4426]') 
plt.show 

enter image description here

是不正確的,它們不匹配

樣本數據(火車[:3])

array([[[-0.5  , -0.5  , -0.5  , ..., 0.48823529, 
     0.5  , 0.17058824], 
    [-0.5  , -0.5  , -0.5  , ..., 0.48823529, 
     0.5  , -0.0372549 ], 
    [-0.5  , -0.5  , -0.5  , ..., 0.5  , 
     0.47647059, -0.24509804], 
    ..., 
    [-0.49215686, 0.34705883, 0.5  , ..., -0.5  , 
    -0.5  , -0.5  ], 
    [-0.31176472, 0.44901961, 0.5  , ..., -0.5  , 
    -0.5  , -0.5  ], 
    [-0.11176471, 0.5  , 0.49215686, ..., -0.5  , 
    -0.5  , -0.5  ]], 

    [[-0.24509804, 0.2764706 , 0.5  , ..., 0.5  , 
     0.25294119, -0.36666667], 
    [-0.5  , -0.47254902, -0.02941176, ..., 0.20196079, 
    -0.46862745, -0.5  ], 
    [-0.49215686, -0.5  , -0.5  , ..., -0.47647059, 
    -0.5  , -0.49607843], 
    ..., 
    [-0.49215686, -0.49607843, -0.5  , ..., -0.5  , 
    -0.5  , -0.49215686], 
    [-0.5  , -0.5  , -0.26862746, ..., 0.13137256, 
    -0.46470588, -0.5  ], 
    [-0.30000001, 0.11960784, 0.48823529, ..., 0.5  , 
     0.28431374, -0.24117647]], 

    [[-0.5  , -0.5  , -0.5  , ..., -0.5  , 
    -0.5  , -0.5  ], 
    [-0.5  , -0.5  , -0.5  , ..., -0.5  , 
    -0.5  , -0.5  ], 
    [-0.5  , -0.5  , -0.5  , ..., -0.5  , 
    -0.5  , -0.5  ], 
    ..., 
    [-0.5  , -0.5  , -0.5  , ..., 0.48431373, 
     0.5  , 0.31568629], 
    [-0.5  , -0.49215686, -0.5  , ..., 0.49215686, 
     0.5  , 0.04901961], 
    [-0.5  , -0.5  , -0.5  , ..., 0.04117647, 
    -0.17450981, -0.45686275]]], dtype=float32) 
+1

您可以添加一個最小的代表性樣例和預期的輸出嗎? – Divakar

+0

當然,我會馬上添加一個! – thedlade

回答

1

numpy_indexed封裝具有很多的功能,以有效地解決這些類型的問題。

例如,(不像numpy的的內置唯一的)這個會發現你的獨特的畫面:

import numpy_indexed as npi 
unique_training_images = npi.unique(train) 

或者,如果你想找到每個唯一組的所有索引,您可以使用:

indices = npi.group_by(train).split(np.arange(len(train))) 

請注意,這些函數沒有二次時間複雜度,就像在原始文章中那樣,並且完全向量化,因此極有可能效率更高。此外,與熊貓不同,它不具備首選的數據格式,並且具有完全nd陣列功能,所以對形狀爲[n_images,28,28]的陣列進行作用「正常」。

+0

乾杯,我會試試這個 – thedlade

+0

非常感謝,它返回重複的索引,實際上只是實現了一切分組。 – thedlade

+0

是的,我認爲這個輸出與你之後的結果非常相似。但最方便的格式取決於後續步驟的細節。這是否符合您的目的? –