1
功能common_precision
需要兩個numpy陣列,如x
和y
。我想確保它們處於同一個最高精確度。看來,dtypes的關係比較做一些事情,以我想要的東西,而是:更高精度的選擇類型
- 我不知道它實際上是比較
- 它認爲
numpy.int64
<numpy.float16
,這我不知道,如果我同意
def common_precision(x, y):
if x.dtype > y.dtype:
y = y.astype(x.dtype)
else:
x = x.astype(y.dtype)
return (x, y)
編輯: 由於kennytm的答案,我發現與NumPy的find_common_type
不正是我想要的。
def common_precision(self, x, y):
dtype = np.find_common_type([x.dtype, y.dtype], [])
if x.dtype != dtype: x = x.astype(dtype)
if y.dtype != dtype: y = y.astype(dtype)
return x, y
這並沒有完全解決我的問題,但您的答案和您的鏈接讓我走上了正確的道路。謝謝! –