ompUtils.txx 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. #include <omp.h>
  2. #include <cstring>
  3. #include <algorithm>
  4. namespace pvfmm{
  5. template <class T,class StrictWeakOrdering>
  6. void omp_par::merge(T A_,T A_last,T B_,T B_last,T C_,int p,StrictWeakOrdering comp){
  7. typedef typename std::iterator_traits<T>::difference_type _DiffType;
  8. typedef typename std::iterator_traits<T>::value_type _ValType;
  9. _DiffType N1=A_last-A_;
  10. _DiffType N2=B_last-B_;
  11. if(N1==0 && N2==0) return;
  12. if(N1==0 || N2==0){
  13. _ValType* A=(N1==0? &B_[0]: &A_[0]);
  14. _DiffType N=(N1==0? N2 : N1 );
  15. #pragma omp parallel for
  16. for(int i=0;i<p;i++){
  17. _DiffType indx1=( i *N)/p;
  18. _DiffType indx2=((i+1)*N)/p;
  19. memcpy(&C_[indx1], &A[indx1], (indx2-indx1)*sizeof(_ValType));
  20. }
  21. return;
  22. }
  23. //Split both arrays ( A and B ) into n equal parts.
  24. //Find the position of each split in the final merged array.
  25. int n=10;
  26. _ValType* split=new _ValType[p*n*2];
  27. _DiffType* split_size=new _DiffType[p*n*2];
  28. #pragma omp parallel for
  29. for(int i=0;i<p;i++){
  30. for(int j=0;j<n;j++){
  31. int indx=i*n+j;
  32. _DiffType indx1=(indx*N1)/(p*n);
  33. split [indx]=A_[indx1];
  34. split_size[indx]=indx1+(std::lower_bound(B_,B_last,split[indx],comp)-B_);
  35. indx1=(indx*N2)/(p*n);
  36. indx+=p*n;
  37. split [indx]=B_[indx1];
  38. split_size[indx]=indx1+(std::lower_bound(A_,A_last,split[indx],comp)-A_);
  39. }
  40. }
  41. //Find the closest split position for each thread that will
  42. //divide the final array equally between the threads.
  43. _DiffType* split_indx_A=new _DiffType[p+1];
  44. _DiffType* split_indx_B=new _DiffType[p+1];
  45. split_indx_A[0]=0;
  46. split_indx_B[0]=0;
  47. split_indx_A[p]=N1;
  48. split_indx_B[p]=N2;
  49. #pragma omp parallel for
  50. for(int i=1;i<p;i++){
  51. _DiffType req_size=(i*(N1+N2))/p;
  52. int j=std::lower_bound(&split_size[0],&split_size[p*n],req_size,std::less<_DiffType>())-&split_size[0];
  53. if(j>=p*n)
  54. j=p*n-1;
  55. _ValType split1 =split [j];
  56. _DiffType split_size1=split_size[j];
  57. j=(std::lower_bound(&split_size[p*n],&split_size[p*n*2],req_size,std::less<_DiffType>())-&split_size[p*n])+p*n;
  58. if(j>=2*p*n)
  59. j=2*p*n-1;
  60. if(abs(split_size[j]-req_size)<abs(split_size1-req_size)){
  61. split1 =split [j];
  62. split_size1=split_size[j];
  63. }
  64. split_indx_A[i]=std::lower_bound(A_,A_last,split1,comp)-A_;
  65. split_indx_B[i]=std::lower_bound(B_,B_last,split1,comp)-B_;
  66. }
  67. delete[] split;
  68. delete[] split_size;
  69. //Merge for each thread independently.
  70. #pragma omp parallel for
  71. for(int i=0;i<p;i++){
  72. T C=C_+split_indx_A[i]+split_indx_B[i];
  73. std::merge(A_+split_indx_A[i],A_+split_indx_A[i+1],B_+split_indx_B[i],B_+split_indx_B[i+1],C,comp);
  74. }
  75. delete[] split_indx_A;
  76. delete[] split_indx_B;
  77. }
  78. template <class T,class StrictWeakOrdering>
  79. void omp_par::merge_sort(T A,T A_last,StrictWeakOrdering comp){
  80. typedef typename std::iterator_traits<T>::difference_type _DiffType;
  81. typedef typename std::iterator_traits<T>::value_type _ValType;
  82. int p=omp_get_max_threads();
  83. _DiffType N=A_last-A;
  84. if(N<2*p){
  85. std::sort(A,A_last,comp);
  86. return;
  87. }
  88. //Split the array A into p equal parts.
  89. _DiffType* split=new _DiffType[p+1];
  90. split[p]=N;
  91. #pragma omp parallel for
  92. for(int id=0;id<p;id++){
  93. split[id]=(id*N)/p;
  94. }
  95. //Sort each part independently.
  96. #pragma omp parallel for
  97. for(int id=0;id<p;id++){
  98. std::sort(A+split[id],A+split[id+1],comp);
  99. }
  100. //Merge two parts at a time.
  101. _ValType* B=new _ValType[N];
  102. _ValType* A_=&A[0];
  103. _ValType* B_=&B[0];
  104. for(int j=1;j<p;j=j*2){
  105. for(int i=0;i<p;i=i+2*j){
  106. if(i+j<p){
  107. omp_par::merge(A_+split[i],A_+split[i+j],A_+split[i+j],A_+split[(i+2*j<=p?i+2*j:p)],B_+split[i],p,comp);
  108. }else{
  109. #pragma omp parallel for
  110. for(int k=split[i];k<split[p];k++)
  111. B_[k]=A_[k];
  112. }
  113. }
  114. _ValType* tmp_swap=A_;
  115. A_=B_;
  116. B_=tmp_swap;
  117. }
  118. //The final result should be in A.
  119. if(A_!=&A[0]){
  120. #pragma omp parallel for
  121. for(int i=0;i<N;i++)
  122. A[i]=A_[i];
  123. }
  124. //Free memory.
  125. delete[] split;
  126. delete[] B;
  127. }
  128. template <class T>
  129. void omp_par::merge_sort(T A,T A_last){
  130. typedef typename std::iterator_traits<T>::value_type _ValType;
  131. omp_par::merge_sort(A,A_last,std::less<_ValType>());
  132. }
  133. template <class T, class I>
  134. T omp_par::reduce(T* A, I cnt){
  135. T sum=0;
  136. #pragma omp parallel for reduction(+:sum)
  137. for(I i = 0; i < cnt; i++)
  138. sum+=A[i];
  139. return sum;
  140. }
  141. template <class T, class I>
  142. void omp_par::scan(T* A, T* B,I cnt){
  143. int p=omp_get_max_threads();
  144. if(cnt<(I)100*p){
  145. for(I i=1;i<cnt;i++)
  146. B[i]=B[i-1]+A[i-1];
  147. return;
  148. }
  149. I step_size=cnt/p;
  150. #pragma omp parallel for
  151. for(int i=0; i<p; i++){
  152. int start=i*step_size;
  153. int end=start+step_size;
  154. if(i==p-1) end=cnt;
  155. if(i!=0)B[start]=0;
  156. for(I j=(I)start+1; j<(I)end; j++)
  157. B[j]=B[j-1]+A[j-1];
  158. }
  159. T* sum=new T[p];
  160. sum[0]=0;
  161. for(int i=1;i<p;i++)
  162. sum[i]=sum[i-1]+B[i*step_size-1]+A[i*step_size-1];
  163. #pragma omp parallel for
  164. for(int i=1; i<p; i++){
  165. int start=i*step_size;
  166. int end=start+step_size;
  167. if(i==p-1) end=cnt;
  168. T sum_=sum[i];
  169. for(I j=(I)start; j<(I)end; j++)
  170. B[j]+=sum_;
  171. }
  172. delete[] sum;
  173. }
  174. }//end namespace