首先,你沒有使用正確的算法。如果兩點之間的距離大於boxWidth,會怎麼樣?其次,如果你有多個粒子,調用單個函數一次完成所有距離計算,並將結果寫入輸出緩衝區,將會顯著提高效率。內聯有助於減少一些開銷,但不能消除全部。任何預先計算(例如你的算法中將盒子長度除以2)都會在每次調用時被不必要地重複。
下面是一些用SIMD來做這個計算的代碼。您需要使用-msse4進行編譯。使用-O3時,在我的機器上(macbook pro,llvm-gcc-4.2),速度提高了約2倍。這確實需要使用32位浮點數而不是雙精度算術。
SSE其實沒有那麼複雜,它只是看起來可怕。例如,你不能直接寫 a * b,而必須寫成笨重的 _mm_mul_ps(a, b)。
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <smmintrin.h>
// `real` is the scalar type of the unoptimized reference code: float by
// default, double when the program is built with -DDOUBLE.
// The SSE code path works on floats only.
#ifndef DOUBLE
typedef float real;
#else
typedef double real;
#endif
/*
 * Load three consecutive floats into an SSE register.
 *
 * value: pointer to at least three readable floats (x, y, z). Only
 *        value[0], value[1] and value[2] are read, one scalar at a time,
 *        so no 16-byte alignment is needed and nothing past value[2] is
 *        touched.
 *
 * Returns a vector whose lanes are (x, y, z, 0), lowest lane first.
 */
static inline __m128 loadFloat3(const float* const value) {
    // _mm_load_ss places the scalar in lane 0 and zeroes lanes 1-3.
    __m128 x = _mm_load_ss(&value[0]);
    __m128 y = _mm_load_ss(&value[1]);
    __m128 z = _mm_load_ss(&value[2]);
    // Low half of x followed by low half of y: xy = (x, 0, y, 0).
    __m128 xy = _mm_movelh_ps(x, y);
    // Lanes 0 and 2 of xy, then lanes 0 and 2 of z: (x, y, z, 0).
    return _mm_shuffle_ps(xy, z, _MM_SHUFFLE(2, 0, 2, 0));
}
int fdistanceSqPeriodic(float* position1, float* position2, const float boxWidth,
float* out, const int n_points) {
int i;
__m128 r1, r2, r12, s12, r12_2, s, box, invBox;
box = _mm_set1_ps(boxWidth);
invBox = _mm_div_ps(_mm_set1_ps(1.0f), box);
for (i = 0; i < n_points; i++) {
r1 = loadFloat3(position1);
r2 = loadFloat3(position1);
r12 = _mm_sub_ps(r1, r2);
s12 = _mm_mul_ps(r12, invBox);
s12 = _mm_sub_ps(s12, _mm_round_ps(s12, _MM_FROUND_TO_NEAREST_INT));
r12 = _mm_mul_ps(box, s12);
r12_2 = _mm_mul_ps(r12, r12);
// double horizontal add instruction accumulates the sum of
// all four elements into each of the elements
// (e.g. s.x = s.y = s.z = s.w = r12_2.x + r12_2.y + r12_2.z + r12_2.w)
s = _mm_hadd_ps(r12_2, r12_2);
s = _mm_hadd_ps(s, s);
_mm_store_ss(out++, s);
position1 += 3;
position2 += 3;
}
return 1;
}
/*
 * Squared periodic distance between two 3D points (reference version).
 *
 * position1, position2: pointers to three reals each (x, y, z).
 * boxWidth:             period applied to each of the three components.
 *
 * Returns x*x + y*y + z*z of the wrapped difference position2 - position1.
 *
 * Each component is shifted by at most one boxWidth, so the wrap is only
 * complete when the raw difference lies within 1.5 box widths of zero.
 *
 * Declared `static inline`: a plain `inline` definition provides no external
 * definition in C99/C11, so the program fails to link whenever the compiler
 * chooses not to inline the call (for example at -O0).
 */
static inline real distanceSqPeriodic(real const * const position1, real const * const position2, real boxWidth) {
    const real half = boxWidth/2.0;

    real x = position2[0] - position1[0];
    if (x > half) {
        x -= boxWidth;
    } else if (x < -half) {
        x += boxWidth;
    }

    real y = position2[1] - position1[1];
    if (y > half) {
        y -= boxWidth;
    } else if (y < -half) {
        y += boxWidth;
    }

    real z = position2[2] - position1[2];
    if (z > half) {
        z -= boxWidth;
    } else if (z < -half) {
        z += boxWidth;
    }

    return x * x + y * y + z * z;
}
int main(void) {
real* position1;
real* position2;
real* output;
int n_runs = 10000000;
posix_memalign((void**) &position1, 16, n_runs*3*sizeof(real));
posix_memalign((void**) &position2, 16, n_runs*3*sizeof(real));
posix_memalign((void**) &output, 16, n_runs*sizeof(real));
real boxWidth = 1.8;
real result = 0;
int i;
clock_t t;
#ifdef OPT
printf("Timing optimized SSE implementation\n");
#else
printf("Timinig original implementation\n");
#endif
#ifdef DOUBLE
printf("Using double precision\n");
#else
printf("Using single precision\n");
#endif
t = clock();
#ifdef OPT
fdistanceSqPeriodic(position1, position2, boxWidth, output, n_runs);
#else
for (i = 0; i < n_runs; i++) {
*output = distanceSqPeriodic(position1, position2, boxWidth);
position1 += 3;
position2 += 3;
output++;
}
#endif
t = clock() - t;
printf("It took me %d clicks (%f seconds).\n", (int) t, ((float)t)/CLOCKS_PER_SEC);
}
如果能看到調用代碼會很有幫助,特別是如果你想走SIMD路線。了解參數也會有幫助,例如每次調用時'boxWidths []'中的值是否不同? – 2013-02-25 15:01:15
調用方法次數更少(即算法改進)。當然,不知道它是否適用;) – 2013-02-25 15:02:12
@Paul R No ... d'oh!讓我解決這個問題 – Nick 2013-02-25 15:02:26