// gemm-ker.cpp
// Example code showing optimization of a GEMM micro-kernel.
#include <cstdlib>
#include <iostream>
#include <vector>

#include <omp.h>
#include <sctl.hpp>
  5. template <int M, int N, int K>
  6. void GEMM_ker_naive(double* C, double* A, double* B) {
  7. for (int k = 0; k < K; k++)
  8. for (int j = 0; j < N; j++)
  9. for (int i = 0; i < M; i++)
  10. C[i+j*M] += A[i+k*M] * B[k+K*j];
  11. }
  12. template <int M, int N, int K>
  13. void GEMM_ker_vec(double* C, double* A, double* B) {
  14. using Vec = sctl::Vec<double,M>;
  15. Vec C_vec[N];
  16. for (int j = 0; j < N; j++)
  17. C_vec[j] = Vec::Load(C+j*M);
  18. for (int k = 0; k < K; k++) {
  19. const Vec A_vec = Vec::Load(A+k*M);
  20. double* B_ = B + k;
  21. for (int j = 0; j < N; j++) {
  22. C_vec[j] = A_vec * B_[K*j] + C_vec[j];
  23. }
  24. }
  25. for (int j = 0; j < N; j++)
  26. C_vec[j].Store(C+j*M);
  27. }
  28. template <int M, int N, int K>
  29. void GEMM_ker_vec_unrolled(double* C, double* A, double* B) {
  30. using Vec = sctl::Vec<double,M>;
  31. Vec C_vec[N];
  32. #pragma GCC unroll (10)
  33. for (int j = 0; j < N; j++)
  34. C_vec[j] = Vec::Load(C+j*M);
  35. #pragma GCC unroll (40)
  36. for (int k = 0; k < K; k++) {
  37. const Vec A_vec = Vec::Load(A+k*M);
  38. double* B_ = B + k;
  39. #pragma GCC unroll (10)
  40. for (int j = 0; j < N; j++) {
  41. C_vec[j] = A_vec * B_[j*K] + C_vec[j];
  42. }
  43. }
  44. #pragma GCC unroll (10)
  45. for (int j = 0; j < N; j++)
  46. C_vec[j].Store(C+j*M);
  47. }
  48. int main(int argc, char** argv) {
  49. long L = 1e6;
  50. constexpr int M = 8, N = 10, K = 40;
  51. double* C = new double[M*N];
  52. double* A = new double[M*K];
  53. double* B = new double[K*N];
  54. for (long i = 0; i < M*N; i++) C[i] = 0;
  55. for (long i = 0; i < M*K; i++) A[i] = drand48();
  56. for (long i = 0; i < K*N; i++) B[i] = drand48();
  57. std::cout<<"M = "<<M<<", N = "<<N<<", K = "<<K<<"\n\n";
  58. std::cout<<"GEMM (naive)\n";
  59. double T = -omp_get_wtime();
  60. for(long i = 0; i < L; i++) GEMM_ker_naive<M,N,K>(C, A, B);
  61. T += omp_get_wtime();
  62. std::cout<<"FLOP rate = "<< 2*M*N*K*L/T/1e9 <<" GFLOP/s\n\n\n";
  63. std::cout<<"GEMM (vectorized)\n";
  64. T = -omp_get_wtime();
  65. for(long i = 0; i < L; i++) GEMM_ker_vec<M,N,K>(C, A, B);
  66. std::cout<<"FLOP rate = "<< 2*M*N*K*L/(T+omp_get_wtime())/1e9 <<" GFLOP/s\n\n\n";
  67. std::cout<<"GEMM (vectorized & unrolled)\n";
  68. T = -omp_get_wtime();
  69. for(long i = 0; i < L; i++) GEMM_ker_vec_unrolled<M,N,K>(C, A, B);
  70. std::cout<<"FLOP rate = "<< 2*M*N*K*L/(T+omp_get_wtime())/1e9 <<" GFLOP/s\n\n\n";
  71. double sum = 0;
  72. for (long i = 0; i < M*N; i++) sum += C[i];
  73. std::cout<<"result = "<<sum<<'\n';
  74. delete[] A;
  75. delete[] B;
  76. delete[] C;
  77. return 0;
  78. }