Dhairya Malhotra 7 年之前
父節點
當前提交
a764e60ea8
共有 1 個文件被更改,包括 6 次插入3 次删除
  1. 6 3
      include/sctl/comm.txx

+ 6 - 3
include/sctl/comm.txx

@@ -1012,6 +1012,10 @@ template <class Type> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector
           splt_count = (100 * nelem) / totSize;
           if (npes > 100) splt_count = (drand48() * totSize) < (100 * nelem) ? 1 : 0;
           if (splt_count > nelem) splt_count = nelem;
+          MPI_Allreduce  (&splt_count, &glb_splt_count, 1, CommDatatype<Integer>::value(), CommDatatype<Integer>::sum(), mpi_comm_);
+          if (!glb_splt_count) splt_count = std::min<Long>(1, nelem);
+          MPI_Allreduce  (&splt_count, &glb_splt_count, 1, CommDatatype<Integer>::value(), CommDatatype<Integer>::sum(), mpi_comm_);
+          SCTL_ASSERT(glb_splt_count);
         }
 
         Vector<Type> splitters(splt_count);
@@ -1020,12 +1024,11 @@ template <class Type> void Comm::HyperQuickSort(const Vector<Type>& arr_, Vector
         }
 
         Vector<Integer> glb_splt_cnts(npes), glb_splt_disp(npes);
-        {  // Set glb_splt_count, glb_splt_cnts, glb_splt_disp
+        {  // Set glb_splt_cnts, glb_splt_disp
           MPI_Allgather(&splt_count, 1, CommDatatype<Integer>::value(), &glb_splt_cnts[0], 1, CommDatatype<Integer>::value(), comm);
           glb_splt_disp[0] = 0;
           omp_par::scan(glb_splt_cnts.begin(), glb_splt_disp.begin(), npes);
-          glb_splt_count = glb_splt_cnts[npes - 1] + glb_splt_disp[npes - 1];
-          SCTL_ASSERT(glb_splt_count);
+          SCTL_ASSERT(glb_splt_count == glb_splt_cnts[npes - 1] + glb_splt_disp[npes - 1]);
         }
 
         {  // Gather all splitters. O( log(p) )