// thread-comm: shared state (ShMem) plus a per-thread handle (ThreadComm) that
// lets a fixed number of threads register themselves and meet at a barrier.
//
// NOTE(review): this copy of the file appears to have lost its line breaks and
// every `<...>` span: the targets of the four bare #include directives, the
// template arguments of std::atomic / aligned_new / Iterator / ConstIterator,
// and part of the body of ThreadComm::test(). Preprocessor directives and `//`
// comments run to the end of the physical line, so the text cannot compile in
// this form. Restore the original formatting before building. The code below is
// left exactly as found.
//
// class ShMem
//   ShMem(count)      Stores count in size, allocates size * BlockSize
//                     ThreadData entries with aligned_new, and sets ptr,
//                     init_flag and sync_flag of each slot to null / 0. Both
//                     sync counters start at 0. The destructor releases the
//                     array with aligned_delete; the thread_counter == 0
//                     assertion there is commented out (see its TODO).
//                     Copy construction and copy assignment are deleted.
//   BlockSize         SCTL_CACHE_LINE_SIZE / sizeof(ThreadData) + 2; slot i is
//                     thread_data[i * BlockSize]. Presumably this spacing keeps
//                     different threads' slots on separate cache lines --
//                     TODO confirm.
//   GetThreadData(i)  Asserts i < size ("invalid thread id") and returns slot i.
//                     NOTE(review): only the upper bound is checked.
//   InitThread(id)    While holding lck: replaces id == -1 with the current
//                     thread_counter, asserts the slot's init_flag is clear
//                     ("duplicate thread id."), sets it and increments
//                     thread_counter. After unlocking it calls SyncThreads(id)
//                     and returns id.
//   FinalizeThread(id) Calls SyncThreads(id) first, then while holding lck
//                     asserts init_flag is set, clears it and decrements
//                     thread_counter.
//   SyncThreads(rank) The barrier used by InitThread, FinalizeThread and
//                     ThreadComm::Sync. Each caller reads sync_counter1, then
//                     does fetch_add(1) on sync_counter0. The caller that sees
//                     the old value size - 1 resets sync_counter0 to 0 and
//                     stores sync_counter1 + 1 with release order; every other
//                     caller spins on relaxed loads until sync_counter1
//                     changes. All callers then issue an acq_rel fence. The
//                     rank argument is not used.
//   SyncThreads0..5   Other barrier variants; nothing in this file calls them.
//                     SyncThreads0 is a bare `#pragma omp barrier`.
//                     SyncThreads1 trades sync_flag values with partner
//                     rank ^ mask for mask = 1, 2, 4, ... < size in three
//                     passes. NOTE(review): rank ^ mask can be >= size when
//                     size is not a power of two, which would fail the check in
//                     GetThreadData. SyncThreads2 polls every slot's sync_flag
//                     through the states 1, 2, 0. SyncThreads3/4/5 use only
//                     sync_counter0 and sync_counter1.
//
// class ThreadComm
//   ThreadComm(m, id = -1)  Keeps a const iterator to m (Ptr2ConstItr(&m, 1)),
//                     sets rank = m.InitThread(id) and size = m.size. The
//                     destructor calls FinalizeThread(rank). Both paths go
//                     through SyncThreads, so construction and destruction each
//                     wait until `size` callers have reached the barrier.
//                     Copying is deleted.
//   Rank(), Size()    Return the stored rank and size.
//   Sync()            Calls smem->SyncThreads(rank).
//   test()            Runs a lambda that builds a ThreadComm on np threads,
//                     either via std::thread + join or inside
//                     `#pragma omp parallel`. The branch condition, the
//                     definition of np and m, and part of the lambda are
//                     missing from this copy.
//   The commented-out template declarations (Gather, Broadcast, Isend, Irecv,
//   Allgather, Alltoall, Allreduce, Scan, ...) are not implemented in this file.
#ifndef _SCTL_THREAD_COMM_HPP_ #define _SCTL_THREAD_COMM_HPP_ #include SCTL_INCLUDE(common.hpp) #include #include #include #include namespace SCTL_NAMESPACE { class ThreadComm; class ShMem { public: ShMem(Integer count) { size = count; thread_counter = 0; thread_data = aligned_new(size * BlockSize); // TODO: on stack for (Integer i = 0; i < count; i++) { thread_data[i*BlockSize].ptr = nullptr; thread_data[i*BlockSize].init_flag = 0; thread_data[i*BlockSize].sync_flag = 0; } sync_counter0=0; sync_counter1=0; } ~ShMem() { // TODO: uncomment // SCTL_ASSERT(thread_counter == 0); aligned_delete(thread_data); } ShMem(ShMem const&) = delete; void operator=(ShMem const &) = delete; private: friend class ThreadComm; struct ThreadData { void* ptr; Integer init_flag; std::atomic sync_flag; }; static constexpr Long BlockSize = SCTL_CACHE_LINE_SIZE / sizeof(ThreadData) + 2; ThreadData& GetThreadData(Integer i) const { SCTL_ASSERT_MSG(i < size, "invalid thread id"); return thread_data[i * BlockSize]; } Integer InitThread(Integer id) const { lck.lock(); if (id == -1) id = thread_counter; auto& tdata = GetThreadData(id); SCTL_ASSERT_MSG(!tdata.init_flag, "duplicate thread id."); tdata.init_flag = 1; thread_counter++; lck.unlock(); SyncThreads(id); return id; } void SyncThreads0(Integer rank) const { #pragma omp barrier } void SyncThreads1(Integer rank) const { auto& mydata = GetThreadData(rank); Integer mask0, mask; mask0 = 1; while (mask0 < size) mask0 = mask0 << 1; mask = 1; while (mask < size) { Integer partner = rank ^ mask; const auto& partner_data = GetThreadData(partner); mydata.sync_flag = mask0*0 + mask; while (partner_data.sync_flag < mask0*0 + mask || partner_data.sync_flag > mask0*2); mask = mask << 1; } mask = 1; while (mask < size) { Integer partner = rank ^ mask; const auto& partner_data = GetThreadData(partner); mydata.sync_flag = mask0*1 + mask; while (partner_data.sync_flag < mask0*1 + mask); mask = mask << 1; } mask = 1; while (mask < size) { Integer partner = 
rank ^ mask; const auto& partner_data = GetThreadData(partner); mydata.sync_flag = mask0*2 + mask; while (partner_data.sync_flag < mask0*2 + mask && !(partner_data.sync_flag <= mask0*1)); mask = mask << 1; } mydata.sync_flag = 0; } void SyncThreads2(Integer rank) const { auto& mydata = GetThreadData(rank); for (Long i = 0; i < size; i++) { while (GetThreadData(i).sync_flag == 2); } mydata.sync_flag = 1; for (Long i = 0; i < size; i++) { while (GetThreadData(i).sync_flag == 0); } mydata.sync_flag = 2; for (Long i = 0; i < size; i++) { while (GetThreadData(i).sync_flag == 1); } mydata.sync_flag = 0; // TODO: hypercube } void SyncThreads3(Integer rank) const { sync_counter0++; while(sync_counter0 != 0 && sync_counter0 != size) {} if (!rank) sync_counter0 = 0; sync_counter1++; while(sync_counter1 != 0 && sync_counter1 != size) {} if (!rank) sync_counter1 = 0; } void SyncThreads4(Integer rank) const { Integer sync_counter1_ = sync_counter1.load(std::memory_order_relaxed); if (sync_counter0++ == size - 1) { sync_counter0 = 0; sync_counter1.store(1 - sync_counter1_, std::memory_order_release); } else { while (sync_counter1 == sync_counter1_) {} } //while(sync_counter0 != 0 && sync_counter0 != size) {} //if (!rank) sync_counter0 = 0; //sync_counter1++; //while(sync_counter1 != 0 && sync_counter1 != size) {} //if (!rank) sync_counter1 = 0; } void SyncThreads5(Integer rank) const { //std::atomic_thread_fence(std::memory_order_seq_cst); //std::atomic_thread_fence(std::memory_order_seq_cst); while (sync_counter0 != rank) {} sync_counter0 = rank+1; if (rank==0) { while (sync_counter0 < size) {} sync_counter0 = 0; } while (sync_counter1 != rank) {} sync_counter1 = rank+1; if (rank==0) { while (sync_counter1 < size) {} sync_counter1 = 0; } } void SyncThreads(Integer rank) const { Integer sync_counter1_ = sync_counter1.load(std::memory_order_relaxed); if(sync_counter0.fetch_add(1) == (size - 1)) { sync_counter0 = 0; sync_counter1.store(sync_counter1_+1, std::memory_order_release); 
} else { while(sync_counter1.load(std::memory_order_relaxed) == sync_counter1_) {}; } std::atomic_thread_fence(std::memory_order_acq_rel); //std::atomic_thread_fence(std::memory_order_seq_cst); } void FinalizeThread(Integer id) const { SyncThreads(id); lck.lock(); auto& tdata = GetThreadData(id); SCTL_ASSERT(tdata.init_flag); tdata.init_flag = 0; thread_counter--; lck.unlock(); } mutable Iterator thread_data; mutable Integer thread_counter; mutable std::mutex lck; Integer size; mutable std::atomic sync_counter0; mutable std::atomic sync_counter1; }; class ThreadComm { public: ThreadComm(const ShMem& m, Integer id = -1) { smem = Ptr2ConstItr(&m, 1); rank = smem->InitThread(id); size = smem->size; } ~ThreadComm() { smem->FinalizeThread(rank); } ThreadComm(ThreadComm const&) = delete; void operator=(ThreadComm const&) = delete; Integer Rank() const { return rank; } Integer Size() const { return size; } void Sync() const { smem->SyncThreads(rank); } //template void Gather(Iterator array, const DataType& a); //template void Broadcast(Iterator array, const DataType& a); //Comm Split(Integer clr) const; //template void* Isend(ConstIterator sbuf, Long scount, Integer dest, Integer tag = 0) const; //template void* Irecv(Iterator rbuf, Long rcount, Integer source, Integer tag = 0) const; //template void Allgather(ConstIterator sbuf, Long scount, Iterator rbuf, Long rcount) const; //template void Allgatherv(ConstIterator sbuf, Long scount, Iterator rbuf, ConstIterator rcounts, ConstIterator rdispls) const; //template void Alltoall(ConstIterator sbuf, Long scount, Iterator rbuf, Long rcount) const; //template void* Ialltoallv_sparse(ConstIterator sbuf, ConstIterator scounts, ConstIterator sdispls, Iterator rbuf, ConstIterator rcounts, ConstIterator rdispls, Integer tag = 0) const; //template void Alltoallv(ConstIterator sbuf, ConstIterator scounts, ConstIterator sdispls, Iterator rbuf, ConstIterator rcounts, ConstIterator rdispls) const; //template void Allreduce(ConstIterator 
sbuf, Iterator rbuf, Long count, CommOp op) const; //template void Scan(ConstIterator sbuf, Iterator rbuf, int count, CommOp op) const; static void test() { auto fn = [](const ShMem& m) { ThreadComm c(m); Long i=0; for (Long i=0; i< c.Size(); i++){ std::cout< threads; for (Integer i = 0; i < np; i++) threads.push_back(std::thread(fn, std::ref(m))); for (auto& t : threads) t.join(); } else { omp_set_num_threads(np); #pragma omp parallel { fn(m); } } } private: Integer rank, size; ConstIterator smem; }; } // end namespace //#include SCTL_INCLUDE(thread-comm.txx) #endif //_SCTL_THREAD_COMM_HPP_