thread-comm.hpp 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. #ifndef _SCTL_THREAD_COMM_HPP_
  2. #define _SCTL_THREAD_COMM_HPP_
  3. #include SCTL_INCLUDE(common.hpp)
  4. #include <thread>
  5. #include <mutex>
  6. #include <atomic>
  7. #include <condition_variable>
  8. namespace SCTL_NAMESPACE {
  9. class ThreadComm;
  10. class ShMem {
  11. public:
  12. ShMem(Integer count) {
  13. size = count;
  14. thread_counter = 0;
  15. thread_data = aligned_new<ThreadData>(size * BlockSize); // TODO: on stack
  16. for (Integer i = 0; i < count; i++) {
  17. thread_data[i*BlockSize].ptr = nullptr;
  18. thread_data[i*BlockSize].init_flag = 0;
  19. thread_data[i*BlockSize].sync_flag = 0;
  20. }
  21. sync_counter0=0;
  22. sync_counter1=0;
  23. }
  24. ~ShMem() {
  25. // TODO: uncomment // SCTL_ASSERT(thread_counter == 0);
  26. aligned_delete<ThreadData>(thread_data);
  27. }
  28. ShMem(ShMem const&) = delete;
  29. void operator=(ShMem const &) = delete;
  30. private:
  31. friend class ThreadComm;
  32. struct ThreadData {
  33. void* ptr;
  34. Integer init_flag;
  35. std::atomic<Integer> sync_flag;
  36. };
  37. static constexpr Long BlockSize = SCTL_CACHE_LINE_SIZE / sizeof(ThreadData) + 2;
  38. ThreadData& GetThreadData(Integer i) const {
  39. SCTL_ASSERT_MSG(i < size, "invalid thread id");
  40. return thread_data[i * BlockSize];
  41. }
  42. Integer InitThread(Integer id) const {
  43. lck.lock();
  44. if (id == -1) id = thread_counter;
  45. auto& tdata = GetThreadData(id);
  46. SCTL_ASSERT_MSG(!tdata.init_flag, "duplicate thread id.");
  47. tdata.init_flag = 1;
  48. thread_counter++;
  49. lck.unlock();
  50. SyncThreads(id);
  51. return id;
  52. }
  53. void SyncThreads0(Integer rank) const {
  54. #pragma omp barrier
  55. }
  56. void SyncThreads1(Integer rank) const {
  57. auto& mydata = GetThreadData(rank);
  58. Integer mask0, mask;
  59. mask0 = 1;
  60. while (mask0 < size) mask0 = mask0 << 1;
  61. mask = 1;
  62. while (mask < size) {
  63. Integer partner = rank ^ mask;
  64. const auto& partner_data = GetThreadData(partner);
  65. mydata.sync_flag = mask0*0 + mask;
  66. while (partner_data.sync_flag < mask0*0 + mask || partner_data.sync_flag > mask0*2);
  67. mask = mask << 1;
  68. }
  69. mask = 1;
  70. while (mask < size) {
  71. Integer partner = rank ^ mask;
  72. const auto& partner_data = GetThreadData(partner);
  73. mydata.sync_flag = mask0*1 + mask;
  74. while (partner_data.sync_flag < mask0*1 + mask);
  75. mask = mask << 1;
  76. }
  77. mask = 1;
  78. while (mask < size) {
  79. Integer partner = rank ^ mask;
  80. const auto& partner_data = GetThreadData(partner);
  81. mydata.sync_flag = mask0*2 + mask;
  82. while (partner_data.sync_flag < mask0*2 + mask && !(partner_data.sync_flag <= mask0*1));
  83. mask = mask << 1;
  84. }
  85. mydata.sync_flag = 0;
  86. }
  87. void SyncThreads2(Integer rank) const {
  88. auto& mydata = GetThreadData(rank);
  89. for (Long i = 0; i < size; i++) {
  90. while (GetThreadData(i).sync_flag == 2);
  91. }
  92. mydata.sync_flag = 1;
  93. for (Long i = 0; i < size; i++) {
  94. while (GetThreadData(i).sync_flag == 0);
  95. }
  96. mydata.sync_flag = 2;
  97. for (Long i = 0; i < size; i++) {
  98. while (GetThreadData(i).sync_flag == 1);
  99. }
  100. mydata.sync_flag = 0;
  101. // TODO: hypercube
  102. }
  103. void SyncThreads3(Integer rank) const {
  104. sync_counter0++;
  105. while(sync_counter0 != 0 && sync_counter0 != size) {}
  106. if (!rank) sync_counter0 = 0;
  107. sync_counter1++;
  108. while(sync_counter1 != 0 && sync_counter1 != size) {}
  109. if (!rank) sync_counter1 = 0;
  110. }
  111. void SyncThreads4(Integer rank) const {
  112. Integer sync_counter1_ = sync_counter1.load(std::memory_order_relaxed);
  113. if (sync_counter0++ == size - 1) {
  114. sync_counter0 = 0;
  115. sync_counter1.store(1 - sync_counter1_, std::memory_order_release);
  116. } else {
  117. while (sync_counter1 == sync_counter1_) {}
  118. }
  119. //while(sync_counter0 != 0 && sync_counter0 != size) {}
  120. //if (!rank) sync_counter0 = 0;
  121. //sync_counter1++;
  122. //while(sync_counter1 != 0 && sync_counter1 != size) {}
  123. //if (!rank) sync_counter1 = 0;
  124. }
  125. void SyncThreads5(Integer rank) const {
  126. //std::atomic_thread_fence(std::memory_order_seq_cst);
  127. //std::atomic_thread_fence(std::memory_order_seq_cst);
  128. while (sync_counter0 != rank) {}
  129. sync_counter0 = rank+1;
  130. if (rank==0) {
  131. while (sync_counter0 < size) {}
  132. sync_counter0 = 0;
  133. }
  134. while (sync_counter1 != rank) {}
  135. sync_counter1 = rank+1;
  136. if (rank==0) {
  137. while (sync_counter1 < size) {}
  138. sync_counter1 = 0;
  139. }
  140. }
  141. void SyncThreads(Integer rank) const {
  142. Integer sync_counter1_ = sync_counter1.load(std::memory_order_relaxed);
  143. if(sync_counter0.fetch_add(1) == (size - 1)) {
  144. sync_counter0 = 0;
  145. sync_counter1.store(sync_counter1_+1, std::memory_order_release);
  146. } else {
  147. while(sync_counter1.load(std::memory_order_relaxed) == sync_counter1_) {};
  148. }
  149. std::atomic_thread_fence(std::memory_order_acq_rel);
  150. //std::atomic_thread_fence(std::memory_order_seq_cst);
  151. }
  152. void FinalizeThread(Integer id) const {
  153. SyncThreads(id);
  154. lck.lock();
  155. auto& tdata = GetThreadData(id);
  156. SCTL_ASSERT(tdata.init_flag);
  157. tdata.init_flag = 0;
  158. thread_counter--;
  159. lck.unlock();
  160. }
  161. mutable Iterator<ThreadData> thread_data;
  162. mutable Integer thread_counter;
  163. mutable std::mutex lck;
  164. Integer size;
  165. mutable std::atomic<Integer> sync_counter0;
  166. mutable std::atomic<Integer> sync_counter1;
  167. };
  168. class ThreadComm {
  169. public:
  170. ThreadComm(const ShMem& m, Integer id = -1) {
  171. smem = Ptr2ConstItr<ShMem>(&m, 1);
  172. rank = smem->InitThread(id);
  173. size = smem->size;
  174. }
  175. ~ThreadComm() {
  176. smem->FinalizeThread(rank);
  177. }
  178. ThreadComm(ThreadComm const&) = delete;
  179. void operator=(ThreadComm const&) = delete;
  180. Integer Rank() const { return rank; }
  181. Integer Size() const { return size; }
  182. void Sync() const { smem->SyncThreads(rank); }
  183. //template <class DataType> void Gather(Iterator<DataType> array, const DataType& a);
  184. //template <class DataType> void Broadcast(Iterator<DataType> array, const DataType& a);
  185. //Comm Split(Integer clr) const;
  186. //template <class SType> void* Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag = 0) const;
  187. //template <class RType> void* Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag = 0) const;
  188. //template <class SType, class RType> void Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  189. //template <class SType, class RType> void Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  190. //template <class SType, class RType> void Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
  191. //template <class SType, class RType> void* Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag = 0) const;
  192. //template <class Type> void Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
  193. //template <class Type> void Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const;
  194. //template <class Type> void Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const;
  195. static void test() {
  196. auto fn = [](const ShMem& m) {
  197. ThreadComm c(m);
  198. Long i=0;
  199. for (Long i=0; i< c.Size(); i++){
  200. std::cout<<i;
  201. c.Sync();
  202. if (!c.Rank()) std::cout<<'\n';
  203. c.Sync();
  204. }
  205. double tt[2]={0,0};
  206. while (1) {
  207. c.Sync();
  208. i++;
  209. if (c.Rank() ==0 && i%10000000 == 0) {
  210. tt[1] = tt[0];
  211. tt[0] = omp_get_wtime();
  212. std::cout<<tt[0]-tt[1]<<'\n';
  213. }
  214. }
  215. };
  216. Long np = 4;
  217. ShMem m(np);
  218. if (1) {
  219. std::vector<std::thread> threads;
  220. for (Integer i = 0; i < np; i++) threads.push_back(std::thread(fn, std::ref(m)));
  221. for (auto& t : threads) t.join();
  222. } else {
  223. omp_set_num_threads(np);
  224. #pragma omp parallel
  225. {
  226. fn(m);
  227. }
  228. }
  229. }
  230. private:
  231. Integer rank, size;
  232. ConstIterator<ShMem> smem;
  233. };
  234. } // end namespace
  235. //#include SCTL_INCLUDE(thread-comm.txx)
  236. #endif //_SCTL_THREAD_COMM_HPP_