// ompUtils.txx -- OpenMP-parallel merge, merge sort, reduction and prefix-scan templates.
#include <omp.h>

#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <vector>
  5. namespace pvfmm{
  6. template <class T,class StrictWeakOrdering>
  7. void omp_par::merge(T A_,T A_last,T B_,T B_last,T C_,int p,StrictWeakOrdering comp){
  8. typedef typename std::iterator_traits<T>::difference_type _DiffType;
  9. typedef typename std::iterator_traits<T>::value_type _ValType;
  10. _DiffType N1=A_last-A_;
  11. _DiffType N2=B_last-B_;
  12. if(N1==0 && N2==0) return;
  13. if(N1==0 || N2==0){
  14. _ValType* A=(N1==0? &B_[0]: &A_[0]);
  15. _DiffType N=(N1==0? N2 : N1 );
  16. #pragma omp parallel for
  17. for(int i=0;i<p;i++){
  18. _DiffType indx1=( i *N)/p;
  19. _DiffType indx2=((i+1)*N)/p;
  20. memcpy(&C_[indx1], &A[indx1], (indx2-indx1)*sizeof(_ValType));
  21. }
  22. return;
  23. }
  24. //Split both arrays ( A and B ) into n equal parts.
  25. //Find the position of each split in the final merged array.
  26. int n=10;
  27. _ValType* split=new _ValType[p*n*2];
  28. _DiffType* split_size=new _DiffType[p*n*2];
  29. #pragma omp parallel for
  30. for(int i=0;i<p;i++){
  31. for(int j=0;j<n;j++){
  32. int indx=i*n+j;
  33. _DiffType indx1=(indx*N1)/(p*n);
  34. split [indx]=A_[indx1];
  35. split_size[indx]=indx1+(std::lower_bound(B_,B_last,split[indx],comp)-B_);
  36. indx1=(indx*N2)/(p*n);
  37. indx+=p*n;
  38. split [indx]=B_[indx1];
  39. split_size[indx]=indx1+(std::lower_bound(A_,A_last,split[indx],comp)-A_);
  40. }
  41. }
  42. //Find the closest split position for each thread that will
  43. //divide the final array equally between the threads.
  44. _DiffType* split_indx_A=new _DiffType[p+1];
  45. _DiffType* split_indx_B=new _DiffType[p+1];
  46. split_indx_A[0]=0;
  47. split_indx_B[0]=0;
  48. split_indx_A[p]=N1;
  49. split_indx_B[p]=N2;
  50. #pragma omp parallel for
  51. for(int i=1;i<p;i++){
  52. _DiffType req_size=(i*(N1+N2))/p;
  53. int j=std::lower_bound(&split_size[0],&split_size[p*n],req_size,std::less<_DiffType>())-&split_size[0];
  54. if(j>=p*n)
  55. j=p*n-1;
  56. _ValType split1 =split [j];
  57. _DiffType split_size1=split_size[j];
  58. j=(std::lower_bound(&split_size[p*n],&split_size[p*n*2],req_size,std::less<_DiffType>())-&split_size[p*n])+p*n;
  59. if(j>=2*p*n)
  60. j=2*p*n-1;
  61. if(abs(split_size[j]-req_size)<abs(split_size1-req_size)){
  62. split1 =split [j];
  63. split_size1=split_size[j];
  64. }
  65. split_indx_A[i]=std::lower_bound(A_,A_last,split1,comp)-A_;
  66. split_indx_B[i]=std::lower_bound(B_,B_last,split1,comp)-B_;
  67. }
  68. delete[] split;
  69. delete[] split_size;
  70. //Merge for each thread independently.
  71. #pragma omp parallel for
  72. for(int i=0;i<p;i++){
  73. T C=C_+split_indx_A[i]+split_indx_B[i];
  74. std::merge(A_+split_indx_A[i],A_+split_indx_A[i+1],B_+split_indx_B[i],B_+split_indx_B[i+1],C,comp);
  75. }
  76. delete[] split_indx_A;
  77. delete[] split_indx_B;
  78. }
  79. template <class T,class StrictWeakOrdering>
  80. void omp_par::merge_sort(T A,T A_last,StrictWeakOrdering comp){
  81. typedef typename std::iterator_traits<T>::difference_type _DiffType;
  82. typedef typename std::iterator_traits<T>::value_type _ValType;
  83. int p=omp_get_max_threads();
  84. _DiffType N=A_last-A;
  85. if(N<2*p){
  86. std::sort(A,A_last,comp);
  87. return;
  88. }
  89. //Split the array A into p equal parts.
  90. _DiffType* split=new _DiffType[p+1];
  91. split[p]=N;
  92. #pragma omp parallel for
  93. for(int id=0;id<p;id++){
  94. split[id]=(id*N)/p;
  95. }
  96. //Sort each part independently.
  97. #pragma omp parallel for
  98. for(int id=0;id<p;id++){
  99. std::sort(A+split[id],A+split[id+1],comp);
  100. }
  101. //Merge two parts at a time.
  102. _ValType* B=new _ValType[N];
  103. _ValType* A_=&A[0];
  104. _ValType* B_=&B[0];
  105. for(int j=1;j<p;j=j*2){
  106. for(int i=0;i<p;i=i+2*j){
  107. if(i+j<p){
  108. omp_par::merge(A_+split[i],A_+split[i+j],A_+split[i+j],A_+split[(i+2*j<=p?i+2*j:p)],B_+split[i],p,comp);
  109. }else{
  110. #pragma omp parallel for
  111. for(int k=split[i];k<split[p];k++)
  112. B_[k]=A_[k];
  113. }
  114. }
  115. _ValType* tmp_swap=A_;
  116. A_=B_;
  117. B_=tmp_swap;
  118. }
  119. //The final result should be in A.
  120. if(A_!=&A[0]){
  121. #pragma omp parallel for
  122. for(int i=0;i<N;i++)
  123. A[i]=A_[i];
  124. }
  125. //Free memory.
  126. delete[] split;
  127. delete[] B;
  128. }
  129. template <class T>
  130. void omp_par::merge_sort(T A,T A_last){
  131. typedef typename std::iterator_traits<T>::value_type _ValType;
  132. omp_par::merge_sort(A,A_last,std::less<_ValType>());
  133. }
  134. template <class T, class I>
  135. T omp_par::reduce(T* A, I cnt){
  136. T sum=0;
  137. #pragma omp parallel for reduction(+:sum)
  138. for(I i = 0; i < cnt; i++)
  139. sum+=A[i];
  140. return sum;
  141. }
  142. template <class T, class I>
  143. void omp_par::scan(T* A, T* B,I cnt){
  144. int p=omp_get_max_threads();
  145. if(cnt<(I)100*p){
  146. for(I i=1;i<cnt;i++)
  147. B[i]=B[i-1]+A[i-1];
  148. return;
  149. }
  150. I step_size=cnt/p;
  151. #pragma omp parallel for
  152. for(int i=0; i<p; i++){
  153. int start=i*step_size;
  154. int end=start+step_size;
  155. if(i==p-1) end=cnt;
  156. if(i!=0)B[start]=0;
  157. for(I j=(I)start+1; j<(I)end; j++)
  158. B[j]=B[j-1]+A[j-1];
  159. }
  160. T* sum=new T[p];
  161. sum[0]=0;
  162. for(int i=1;i<p;i++)
  163. sum[i]=sum[i-1]+B[i*step_size-1]+A[i*step_size-1];
  164. #pragma omp parallel for
  165. for(int i=1; i<p; i++){
  166. int start=i*step_size;
  167. int end=start+step_size;
  168. if(i==p-1) end=cnt;
  169. T sum_=sum[i];
  170. for(I j=(I)start; j<(I)end; j++)
  171. B[j]+=sum_;
  172. }
  173. delete[] sum;
  174. }
  175. }//end namespace