2016-09-06 281 views
1

我讀過Shuffle Tips and Tricks紙,但我不知道究竟是如何將其應用到一些狡猾的代碼,我繼承:瞭解CUDA SHFL指令

extern __shared__ unsigned int lpSharedMem[]; 
int tid = threadIdx.x; 
lpSharedMem[tid] = startValue; 
volatile unsigned int *srt = lpSharedMem; 

// ...various stuff 
srt[tid] = min(srt[tid], srt[tid+32]); 
srt[tid] = min(srt[tid], srt[tid+16]); 
srt[tid] = min(srt[tid], srt[tid+8]); 
srt[tid] = min(srt[tid], srt[tid+4]); 
srt[tid] = min(srt[tid], srt[tid+2]); 
srt[tid] = min(srt[tid], srt[tid+1]); 
__syncthreads(); 

即使沒有CUDA,這個代碼是模模糊糊,但看着this implementation我看到:

__device__ inline int min_warp(int val) { 
    val = min(val, __shfl_xor(val, 16)); 
    val = min(val, __shfl_xor(val, 8)); 
    val = min(val, __shfl_xor(val, 4)); 
    val = min(val, __shfl_xor(val, 2)); 
    val = min(val, __shfl_xor(val, 1)); 
    return __shfl(val, 0); 
} 

此代碼可能是調用與:

int minVal = min_warp(startValue); 

因此,我可以用上面的代碼替換我相當不利的volatile。但是,我無法真正理解正在發生的事情;有人可以解釋我是否正確,以及min_warp()函數中究竟發生了什麼。

+2

看看這個https://devblogs.nvidia.com/parallelforall/faster-平行削減-開普勒/ – Hopobcn

回答

6

int __shfl_xor(int var, int laneMask, int width=warpSize);的描述:()

__shfl_xor通過與laneMask執行呼叫者的車道ID的按位XOR來計算源極線ID:返回通過將得到的車道ID保持var值。 (...)

車道ID是線程的索引的經內,從0到31因此,硬件執行用於每個線程一個按位XOR:sourceLaneId XOR laneMask => destinationLaneId

例如,對於線程0和:

__shfl_xor(val, 16) 

laneMask = 0b00000000000000000000000000010000 = 16(十進制)

srclaneID = 0b00000000000000000000000000000000 = 0(十進制)

XOR ------------------------------------ ----------------------

dstLaneID = 0b00000000000000000000000000010000 = 16(十進制)

然後線程0得到螺紋16的值。

螺紋4

現在laneMask = 0b00000000000000000000000000010000 = 16(十進制)

srclaneID = 0b00000000000000000000000000000100 = 4(十進制)

XOR ------------------------- ---------------------------------

dstLaneID = 0b00000000000000000000000000010100 = 20(十進制)

因此線程4獲得線程20的值。等等...

如果我們回到實際的算法米,我們看到這是一個並行減少,其中應用了min運算符。在步驟:

  1. 32個線程將它們的值累加到較低的16個線程中。
  2. 16個線程累積到較低的8個線程中。 (其他線程對於實際算法無關緊要)
  3. 8個線程累積到較低的4個線程中。
  4. 4線程acumulate進入下2個線程...

PD:請注意,這兩個代碼是不完全一樣的。這個'32'的偏移告訴我們你的共享內存數組是2 * WARP長。 (你正在減少2個* WARP值到1)

srt[tid] = min(srt[tid], srt[tid+32]); 

而洗牌一個降低WARP值到1