我想在MEX文件中使用cublasSgemmBatched,以便從MATLAB中批量相乘多個矩陣,但在MEX中調用cublasSgemmBatched時出現錯誤。
我的MATLAB代碼非常簡單:
% Driver script: builds two batches of single-precision matrices on the GPU
% and passes them to the MatCuda MEX function.
% Select GPU device 1 for the computations that follow.
gpuDevice(1);
% a: 400-by-10-by-1500 gpuArray of random values, converted to single.
a = single(rand(400,10,1500,'gpuArray'));
% b: 10-by-12-by-1500 gpuArray of random values, converted to single.
b = single(rand(10,12,1500,'gpuArray'));
% Call the MEX function; no trailing semicolon, so the result c is displayed.
c = MatCuda(a,b)
我得到以下錯誤:使用 gpuArray/subsref 時出錯
CUDA執行過程中發生意外錯誤。CUDA的錯誤是:未知錯誤(unknown error)
以下是mexFunction的代碼:
void mexFunction(int nlhs, mxArray *plhs[],
int nrhs, const mxArray *prhs[]){
char const * const errId = "parallel:gpu:mexGPUExample:InvalidInput";
char const * const errMsg = "Invalid input to MEX file.";
/* Declare all variables.*/
mxGPUArray const *A;
mxGPUArray const *B;
mxGPUArray *C;
/* Initialize the MathWorks GPU API. */
mxInitGPU();
/* Throw an error if the input is not a GPU array. */
if ((nrhs != 2) || !(mxIsGPUArray(prhs[0])) || !(mxIsGPUArray(prhs[1]))) {
mexErrMsgIdAndTxt(errId, errMsg);
}
A = mxGPUCreateFromMxArray(prhs[0]);
B = mxGPUCreateFromMxArray(prhs[1]);
if ((mxGPUGetClassID(A) != mxSINGLE_CLASS) || (mxGPUGetClassID(B) != mxSINGLE_CLASS)) {
mexErrMsgIdAndTxt(errId, errMsg);
}
float const *d_A;
float const *d_B;
d_A = (float const *)(mxGPUGetDataReadOnly(A));
d_B = (float const *)(mxGPUGetDataReadOnly(B));
const mwSize *dimsA = mxGPUGetDimensions(A);
size_t nrowsA = dimsA[0];
size_t ncolsA = dimsA[1];
size_t nMatricesA = dimsA[2];
mxFree((void*) dimsA);
const mwSize *dimsB = mxGPUGetDimensions(B);
size_t nrowsB = dimsB[0];
size_t ncolsB = dimsB[1];
size_t nMatricesB = dimsB[2];
mxFree((void*)dimsB);
size_t nrowsC = nrowsA;
size_t ncolsC = ncolsB;
mwSize dimsC[3] = { nrowsA, ncolsB, nMatricesB };
C = mxGPUCreateGPUArray(mxGPUGetNumberOfDimensions(A),
dimsC,
mxGPUGetClassID(A),
mxGPUGetComplexity(A),
MX_GPU_DO_NOT_INITIALIZE);
float *d_C;
d_C = (float *)(mxGPUGetData(C));
cublasHandle_t handle;
cublasStatus_t ret;
ret = cublasCreate(&handle);
if (ret != CUBLAS_STATUS_SUCCESS)
{
printf("cublasCreate returned error code %d, line(%d)\n", ret, __LINE__);
exit(EXIT_FAILURE);
}
const float alpha = 1.0f;
const float beta = 0.0f;
ret = cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, nrowsA, ncolsB, ncolsA, &alpha, &d_A, nrowsA, &d_B, nrowsB, &beta, &d_C, nrowsC, nMatricesA);
if (ret != CUBLAS_STATUS_SUCCESS)
{
printf("cublasSgemm returned error code %d, line(%d)\n", ret, __LINE__);
exit(EXIT_FAILURE);
}
ret = cublasDestroy(handle);
if (ret != CUBLAS_STATUS_SUCCESS)
{
printf("cublasCreate returned error code %d, line(%d)\n", ret, __LINE__);
exit(EXIT_FAILURE);
}
plhs[0] = mxGPUCreateMxArrayOnGPU(C);
mxGPUDestroyGPUArray(A);
mxGPUDestroyGPUArray(B);
mxGPUDestroyGPUArray(C);
}
我懷疑問題與函數cublasSgemmBatched有關,因爲當我把它從代碼中刪除後,就不再出現這個錯誤。
幫助將非常感謝! 謝謝!
gemmBatched比大多數cublas函數要複雜得多。您不僅需要把要相乘的矩陣數據複製到設備上,還必須把指向這些矩陣的指針數組也複製到設備上。你可以先編寫一個正確使用該函數的普通C/C++程序來檢驗自己的理解,或者參考[這個回答](http://stackoverflow.com/questions/23743384/how-performing-multiple-matrix-multiplications-in-cuda/23743838#23743838)。 –
您尚未掌握gemmBatched函數的要求,因此您的調用肯定是不正確的。不能把'd_A'、'd_B'和'd_C'這幾個變量的(主機)地址直接傳給這些參數。這些參數是指向指針的指針,必須先把它們正確設置爲一個由設備指針組成的數組,再把該數組複製到設備上。我認爲你的做法完全忽略了這一點。 –
謝謝羅伯特! – nonobrez