Dhairya Malhotra hace 2 años
padre
commit
ce33cd1f84
Se han modificado 3 ficheros con 107 adiciones y 7 borrados
  1. 5 6
      code/Makefile
  2. 102 0
      code/src/gemm-blocking.cpp
  3. 0 1
      code/src/gemm-ker.cpp

+ 5 - 6
code/Makefile

@@ -3,11 +3,6 @@ CXXFLAGS = -O3 -march=native -std=c++11 -fopenmp # need C++11 and OpenMP
 
 #CXXFLAGS += -DSCTL_HAVE_MPI #use MPI
 
-#CXXFLAGS += -lblas -DSCTL_HAVE_BLAS # use BLAS
-#CXXFLAGS += -llapack -DSCTL_HAVE_LAPACK # use LAPACK
-#CXXFLAGS += -qmkl -DSCTL_HAVE_BLAS -DSCTL_HAVE_LAPACK -DSCTL_HAVE_FFTW3_MKL # use MKL BLAS, LAPACK and FFTW (Intel compiler)
-#CXXFLAGS += -lmkl_intel_lp64 -lmkl_sequential -lmkl_core -lpthread -DSCTL_HAVE_BLAS -DSCTL_HAVE_LAPACK # use MKL BLAS and LAPACK (non-Intel compiler)
-
 RM = rm -f
 MKDIRS = mkdir -p
 
@@ -19,13 +14,17 @@ INCDIR = ./SCTL/include
 TARGET_BIN = \
        $(BINDIR)/instruction \
        $(BINDIR)/poly-eval \
-       $(BINDIR)/gemm-ker
+       $(BINDIR)/gemm-ker \
+       $(BINDIR)/bandwidth \
+       $(BINDIR)/gemm
 
 all : $(TARGET_BIN)
 
 $(BINDIR)/%: $(OBJDIR)/%.o
 	-@$(MKDIRS) $(dir $@)
 	$(CXX) $^ $(CXXFLAGS) $(LDLIBS) -o $@
+#perf stat -e L1-dcache-load-misses -e L1-dcache-loads -e l2_rqsts.miss -e l2_rqsts.references -e LLC-load-misses -e LLC-loads mpiexec -n 1 --map-by slot:pe=16 ./$@
+
 
 $(OBJDIR)/%.o: $(SRCDIR)/%.cpp
 	-@$(MKDIRS) $(dir $@)

+ 102 - 0
code/src/gemm-blocking.cpp

@@ -0,0 +1,102 @@
+// example code showing blocking of GEMM to optimize memory access
+
+#include <iostream>
+#include <omp.h>
+#include <sctl.hpp>
+
+void GEMM_naive(int M, int N, int K, double* A, int LDA, double* B, int LDB, double* C, int LDC) {
+  for (int j = 0; j < N; j++)
+    for (int k = 0; k < K; k++)
+      for (int i = 0; i < M; i++)
+        C[i+j*LDC] += A[i+k*LDA] * B[k+j*LDB];
+}
+
+template <int M, int N, int K>
+void GEMM_ker_vec_unrolled(double* A, int LDA, double* B, int LDB, double* C, int LDC) {
+  using Vec = sctl::Vec<double,M>;
+
+  Vec C_vec[N];
+  #pragma GCC unroll (10)
+  for (int j = 0; j < N; j++)
+    C_vec[j] = Vec::Load(C+j*LDC);
+
+  #pragma GCC unroll (40)
+  for (int k = 0; k < K; k++) {
+    const Vec A_vec = Vec::Load(A+k*LDA);
+    double* B_ = B + k;
+    #pragma GCC unroll (10)
+    for (int j = 0; j < N; j++) {
+      C_vec[j] = A_vec * B_[j*LDB] + C_vec[j];
+    }
+  }
+
+  #pragma GCC unroll (10)
+  for (int j = 0; j < N; j++)
+    C_vec[j].Store(C+j*LDC);
+}
+
+template <int M, int N, int K>
+void GEMM_blocked(double* A, int LDA, double* B, int LDB, double* C, int LDC) {
+  if (M == sctl::DefaultVecLen<double>()) {
+    GEMM_ker_vec_unrolled<M,N,K>(A,LDA, B,LDB, C,LDC);
+    return;
+  }
+
+  for (int j = 0; j < N; j++)
+    for (int k = 0; k < K; k++)
+      for (int i = 0; i < M; i++)
+        C[i+j*LDC] += A[i+k*LDA] * B[k+j*LDB];
+}
+
+template <int M, int N, int K, int Mb, int Nb, int Kb, int... NN>
+void GEMM_blocked(double* A, int LDA, double* B, int LDB, double* C, int LDC) {
+  static_assert(M % Mb == 0);
+  static_assert(N % Nb == 0);
+  static_assert(K % Kb == 0);
+  for (int j = 0; j < N; j+=Nb)
+    for (int i = 0; i < M; i+=Mb)
+      for (int k = 0; k < K; k+=Kb)
+        GEMM_blocked<Mb,Nb,Kb, NN...>(A+i+k*LDA,LDA, B+k+j*LDB,LDB, C+i+j*LDC,LDC);
+}
+
+int main(int argc, char** argv) {
+  constexpr long M = 2000, N = 2000, K = 2000, iter = 10;
+  double* C_ref = new double[M*N];
+  double* C = new double[M*N];
+  double* A = new double[M*K];
+  double* B = new double[K*N];
+  double T = 0;
+
+  for (long i = 0; i < M*N; i++) C[i] = 0;
+  for (long i = 0; i < M*N; i++) C_ref[i] = 0;
+  for (long i = 0; i < M*K; i++) A[i] = drand48();
+  for (long i = 0; i < K*N; i++) B[i] = drand48();
+
+  T = -omp_get_wtime();
+  for (long i = 0; i < iter; i++)
+    //GEMM_naive(M,N,K, A,M, B,K, C,M);
+    GEMM_blocked<M,N,K, 200,200,200, 40,40,40, 8,10,40>(A,M, B,K, C,M);
+  T += omp_get_wtime();
+  std::cout<<"T = "<<T<<"    GFLOPS = "<<2*M*N*K*iter/T/1e9<<'\n';
+
+  if (0) { // check
+    T = -omp_get_wtime();
+    for (long i = 0; i < iter; i++)
+      GEMM_naive(M,N,K, A,M, B,K, C_ref,M);
+    T += omp_get_wtime();
+    std::cout<<"T = "<<T<<"    GFLOPS = "<<2*M*N*K*iter/T/1e9<<'\n';
+
+    double max_err = 0, max_val = 0;
+    for (long i = 0; i < M*N; i++) {
+      max_err = std::max(max_err, fabs(C[i]-C_ref[i]));
+      max_val = std::max(max_val, fabs(C_ref[i]));
+    }
+    std::cout<<"Error = "<<max_err/max_val<<'\n';
+  }
+
+  delete[] A;
+  delete[] B;
+  delete[] C;
+  delete[] C_ref;
+  return 0;
+}

+ 0 - 1
code/src/gemm-ker.cpp

@@ -56,7 +56,6 @@ void GEMM_ker_vec_unrolled(double* C, double* A, double* B) {
     C_vec[j].Store(C+j*M);
 }
 
-
 int main(int argc, char** argv) {
   long L = 1e6;
   constexpr int M = 8, N = 10, K = 40;