2017-04-26 70 views
1

我需要在一個繪圖中繪製多個數據集。數據集的數量有所不同,所以我不知道會有多少。設置圖例中的最大行數

如果我只畫了傳說,我得到這個(以下MCVE):

我怎麼能告訴plt.legend()只繪製說出的第10個傳說?我環顧了plt.legends()類,但似乎沒有任何理由來設置這樣的值。

MCVE

import numpy as np 
import matplotlib.pyplot as plt 

dataset = [] 
for _ in range(20): 
    dataset.append(np.random.uniform(0, 1, 2)) 

lbl = ['adfg', 'dfgb', 'cgfg', 'rtbd', 'etryt', 'frty', 'jklg', 'jklh', 
     'ijkl', 'dfgj', 'kbnm', 'bnmbl', 'qweqw', 'fghfn', 'dfg', 'hjt', 'dfb', 
     'sdgdas', 'werwe', 'dghfg'] 

for i, xy in enumerate(dataset): 
    plt.scatter(xy[0], xy[1], label=lbl[i]) 
plt.legend() 
plt.savefig('test.png') 
+1

我不知道這樣做的正確方法,但是可能的解決方法可能是將if語句放在for循環中,並且只在'_'小於10時使用'label ='? – DavidG

+0

謝謝@DavidG。這確實是一個解決方法,但我想知道是否可能有一種「駭人」的方式來做到這一點。 – Gabriel

回答

2

你可以只限制顯示標籤的數量。

import matplotlib.pyplot as plt 

maxn = 16 
for i in range(25): 
    plt.scatter(.5, .5, label=(i//maxn)*"_"+str(i)) 
plt.legend() 
plt.show() 

enter image description here

此方法也是當然的文本標籤:

import numpy as np 
import matplotlib.pyplot as plt 

labels = ["".join(np.random.choice(list("ABCDEFGHIJK"), size=8)) for k in range(25)] 
maxn = 16 
for i,l in enumerate(labels): 
    plt.scatter(.5, .5, label=(i//maxn)*"_"+l) 
plt.legend() 
plt.show() 

enter image description here

之所以這樣工作原理是,首先是"_"標籤圖例中被忽略。這在內部用於爲對象提供標籤而不在圖例中顯示它們,但當然也可以用於限制圖例中元素的數量。

+0

只有標籤是數字「i」,這似乎才起作用,事實並非如此。這是我的錯,我簡化了我的MCVE。我現在要修復它。 – Gabriel

+1

@Gabriel即使你的標籤不是數字,你也可以保留一個動態構建的標籤列表,並使用這種方法來選擇該標籤列表的索引以顯示輸出 – goofd

+1

非常好,它確實有效!你能否解釋爲什麼在標籤上加上'_'來防止它被顯示?我不知道這個功能。 – Gabriel

1

我想建議一種替代方法來獲得您想要的輸出,我覺得這種方法對圖例標籤的「破解」依賴較少。

您可以使用function Axes.get_legend_handles_labels()來獲取將放入圖例中的對象的手柄和標籤列表。 但是,在將它們傳遞給plt.legend()之前,您可以截斷這些列表,但您感覺像。例如:

import numpy as np 
import matplotlib.pyplot as plt 

dataset = [] 
for _ in range(20): 
    dataset.append(np.random.uniform(0, 1, 2)) 

lbl = ['adfg', 'dfgb', 'cgfg', 'rtbd', 'etryt', 'frty', 'jklg', 'jklh', 
     'ijkl', 'dfgj', 'kbnm', 'bnmbl', 'qweqw', 'fghfn', 'dfg', 'hjt', 'dfb', 
     'sdgdas', 'werwe', 'dghfg'] 

fig, ax = plt.subplots() 
for i, xy in enumerate(dataset): 
    ax.scatter(xy[0], xy[1], label=lbl[i]) 

h,l = ax.get_legend_handles_labels() 
plt.legend(h[:3], l[:3]) # <<<<<<<< This is where the magic happens 
plt.show() 

你甚至可以顯示其他任何標籤plt.legend(h[::2], l[::2])或任何你想要的。

+0

這也是一個非常好的答案,謝謝Diziet! – Gabriel