1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
|
/// Returns true when a BLAS op character requests a transpose.
/// BLAS accepts 'N'/'n' (no transpose), 'T'/'t' (transpose) and 'C'/'c'
/// (conjugate transpose); the latter two both swap the operand's row/column
/// roles, so they must be treated identically when deriving sizes.
inline bool gmmIsTransposed(char op)
{
    return op == 'T' || op == 't' || op == 'C' || op == 'c';
}

/// Derives the size arguments for a BLAS xGEMM call computing
/// C := alpha*op(A)*op(B) + beta*C.
///
/// @param A, B, C  operand matrices. Leading dimensions are taken from each
///                 matrix's numrows(), i.e. column-major storage is assumed
///                 (NOTE(review): consistent with the original LDA = M choice
///                 for 'N' — confirm against Fmatrix's layout).
/// @param TRANSA, TRANSB  BLAS op characters ('N', 'T' or 'C', either case);
///                 only read, never modified.
/// @param M  out: rows of op(A) (and of C).
/// @param N  out: columns of op(B) (and of C).
/// @param K  out: inner dimension (columns of op(A) / rows of op(B)).
/// @param LDA, LDB, LDC  out: leading dimensions, clamped to at least 1 since
///                 BLAS rejects a zero leading dimension even for empty operands.
///
/// Shapes of B and C are not validated against M/N/K here.
template <class T>
void gmmSetup(const Fmatrix<T>& A, const Fmatrix<T>& B, const Fmatrix<T>& C,
              char& TRANSA, char& TRANSB,
              integer& M, integer& N, integer& K,
              integer& LDA, integer& LDB, integer& LDC)
{
    const bool transA = gmmIsTransposed(TRANSA);
    const bool transB = gmmIsTransposed(TRANSB);

    // op(A) is M x K: stored A is K x M when transposed, M x K otherwise.
    M = transA ? A.numcols() : A.numrows();
    K = transA ? A.numrows() : A.numcols();
    // op(B) is K x N: stored B is N x K when transposed, K x N otherwise.
    N = transB ? B.numrows() : B.numcols();

    // The leading dimension is the stored row count of each matrix, regardless
    // of the op applied to it.
    LDA = A.numrows();
    LDB = B.numrows();
    LDC = C.numrows();
    if (LDA < 1) LDA = 1;
    if (LDB < 1) LDB = 1;
    if (LDC < 1) LDC = 1;
}
/// Single-precision general matrix multiply, C := alpha*op(A)*op(B) + beta*C,
/// forwarded to BLAS sgemm_.
/// TRANSA/TRANSB are op characters passed straight through to sgemm_
/// (default 'N'); sizes and leading dimensions are derived by gmmSetup.
/// alpha defaults to 1 and beta to 0, i.e. C is overwritten by op(A)*op(B).
void gmm(Fmatrix<float>& A, Fmatrix<float>& B, Fmatrix<float>& C,
         char TRANSA='N', char TRANSB='N', float alpha=1, float beta=0)
{
    integer rows, cols, inner;
    integer ldA, ldB, ldC;
    gmmSetup(A, B, C, TRANSA, TRANSB, rows, cols, inner, ldA, ldB, ldC);

    // Fortran-style interface: every argument, scalars included, goes by address.
    sgemm_(&TRANSA, &TRANSB,
           &rows, &cols, &inner,
           &alpha,
           A.begin(), &ldA,
           B.begin(), &ldB,
           &beta,
           C.begin(), &ldC);
}
/// Double-precision general matrix multiply, C := alpha*op(A)*op(B) + beta*C,
/// forwarded to BLAS dgemm_.
/// TRANSA/TRANSB are op characters passed straight through to dgemm_
/// (default 'N'); sizes and leading dimensions are derived by gmmSetup.
/// alpha defaults to 1 and beta to 0, i.e. C is overwritten by op(A)*op(B).
void gmm(Fmatrix<double>& A, Fmatrix<double>& B, Fmatrix<double>& C,
         char TRANSA='N', char TRANSB='N', double alpha=1, double beta=0)
{
    integer rows, cols, inner;
    integer ldA, ldB, ldC;
    gmmSetup(A, B, C, TRANSA, TRANSB, rows, cols, inner, ldA, ldB, ldC);

    // Fortran-style interface: every argument, scalars included, goes by address.
    dgemm_(&TRANSA, &TRANSB,
           &rows, &cols, &inner,
           &alpha,
           A.begin(), &ldA,
           B.begin(), &ldB,
           &beta,
           C.begin(), &ldC);
}
/// Single-precision complex general matrix multiply,
/// C := alpha*op(A)*op(B) + beta*C, forwarded to BLAS cgemm_.
/// TRANSA/TRANSB are op characters passed straight through to cgemm_
/// (default 'N'); sizes and leading dimensions are derived by gmmSetup.
///
/// Because TRANSA/TRANSB have defaults, every following parameter must have
/// one too, or the declaration is ill-formed. alpha/beta therefore default to
/// 1+0i and 0+0i, matching the real-valued overloads. The braced defaults
/// require C++11 or later.
void gmm(Fmatrix<complex>& A, Fmatrix<complex>& B, Fmatrix<complex>& C,
         char TRANSA='N', char TRANSB='N',
         complex alpha = {1.0f, 0.0f}, complex beta = {0.0f, 0.0f})
{
    integer M, N, K, LDA, LDB, LDC;
    gmmSetup(A, B, C, TRANSA, TRANSB, M, N, K, LDA, LDB, LDC);
    // Fortran-style interface: every argument, scalars included, goes by address.
    cgemm_(&TRANSA, &TRANSB, &M, &N, &K,
           &alpha, A.begin(), &LDA, B.begin(), &LDB,
           &beta, C.begin(), &LDC );
}
/// Double-precision complex general matrix multiply,
/// C := alpha*op(A)*op(B) + beta*C, forwarded to BLAS zgemm_.
/// TRANSA/TRANSB are op characters passed straight through to zgemm_
/// (default 'N'); sizes and leading dimensions are derived by gmmSetup.
///
/// Because TRANSA/TRANSB have defaults, every following parameter must have
/// one too, or the declaration is ill-formed. alpha/beta therefore default to
/// 1+0i and 0+0i, matching the real-valued overloads. The braced defaults
/// require C++11 or later.
void gmm(Fmatrix<doublecomplex>& A, Fmatrix<doublecomplex>& B, Fmatrix<doublecomplex>& C,
         char TRANSA='N', char TRANSB='N',
         doublecomplex alpha = {1.0, 0.0}, doublecomplex beta = {0.0, 0.0})
{
    integer M, N, K, LDA, LDB, LDC;
    gmmSetup(A, B, C, TRANSA, TRANSB, M, N, K, LDA, LDB, LDC);
    // Fortran-style interface: every argument, scalars included, goes by address.
    zgemm_(&TRANSA, &TRANSB, &M, &N, &K,
           &alpha, A.begin(), &LDA, B.begin(), &LDB,
           &beta, C.begin(), &LDC );
}
template <class T>
void gmm(Fmatrix<T>& A, Fmatrix<T>& B, Fmatrix<T>& C,
char TRANSA='N', char TRANSB='N', T alpha=1, T beta=0)
{
std::cout << "\nERROR: Unrecognized data Type.\n\n";
}
|