123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278 |
- #ifndef _SCTL_THREAD_COMM_HPP_
- #define _SCTL_THREAD_COMM_HPP_
- #include SCTL_INCLUDE(common.hpp)
- #include <thread>
- #include <mutex>
- #include <atomic>
- #include <condition_variable>
- namespace SCTL_NAMESPACE {
- class ThreadComm;
- class ShMem {
- public:
- ShMem(Integer count) {
- size = count;
- thread_counter = 0;
- thread_data = aligned_new<ThreadData>(size * BlockSize); // TODO: on stack
- for (Integer i = 0; i < count; i++) {
- thread_data[i*BlockSize].ptr = nullptr;
- thread_data[i*BlockSize].init_flag = 0;
- thread_data[i*BlockSize].sync_flag = 0;
- }
- sync_counter0=0;
- sync_counter1=0;
- }
- ~ShMem() {
- // TODO: uncomment // SCTL_ASSERT(thread_counter == 0);
- aligned_delete<ThreadData>(thread_data);
- }
- ShMem(ShMem const&) = delete;
- void operator=(ShMem const &) = delete;
- private:
- friend class ThreadComm;
- struct ThreadData {
- void* ptr;
- Integer init_flag;
- std::atomic<Integer> sync_flag;
- };
- static constexpr Long BlockSize = SCTL_CACHE_LINE_SIZE / sizeof(ThreadData) + 2;
- ThreadData& GetThreadData(Integer i) const {
- SCTL_ASSERT_MSG(i < size, "invalid thread id");
- return thread_data[i * BlockSize];
- }
- Integer InitThread(Integer id) const {
- lck.lock();
- if (id == -1) id = thread_counter;
- auto& tdata = GetThreadData(id);
- SCTL_ASSERT_MSG(!tdata.init_flag, "duplicate thread id.");
- tdata.init_flag = 1;
- thread_counter++;
- lck.unlock();
- SyncThreads(id);
- return id;
- }
- void SyncThreads0(Integer rank) const {
- #pragma omp barrier
- }
- void SyncThreads1(Integer rank) const {
- auto& mydata = GetThreadData(rank);
- Integer mask0, mask;
- mask0 = 1;
- while (mask0 < size) mask0 = mask0 << 1;
- mask = 1;
- while (mask < size) {
- Integer partner = rank ^ mask;
- const auto& partner_data = GetThreadData(partner);
- mydata.sync_flag = mask0*0 + mask;
- while (partner_data.sync_flag < mask0*0 + mask || partner_data.sync_flag > mask0*2);
- mask = mask << 1;
- }
- mask = 1;
- while (mask < size) {
- Integer partner = rank ^ mask;
- const auto& partner_data = GetThreadData(partner);
- mydata.sync_flag = mask0*1 + mask;
- while (partner_data.sync_flag < mask0*1 + mask);
- mask = mask << 1;
- }
- mask = 1;
- while (mask < size) {
- Integer partner = rank ^ mask;
- const auto& partner_data = GetThreadData(partner);
- mydata.sync_flag = mask0*2 + mask;
- while (partner_data.sync_flag < mask0*2 + mask && !(partner_data.sync_flag <= mask0*1));
- mask = mask << 1;
- }
- mydata.sync_flag = 0;
- }
- void SyncThreads2(Integer rank) const {
- auto& mydata = GetThreadData(rank);
- for (Long i = 0; i < size; i++) {
- while (GetThreadData(i).sync_flag == 2);
- }
- mydata.sync_flag = 1;
- for (Long i = 0; i < size; i++) {
- while (GetThreadData(i).sync_flag == 0);
- }
- mydata.sync_flag = 2;
- for (Long i = 0; i < size; i++) {
- while (GetThreadData(i).sync_flag == 1);
- }
- mydata.sync_flag = 0;
- // TODO: hypercube
- }
- void SyncThreads3(Integer rank) const {
- sync_counter0++;
- while(sync_counter0 != 0 && sync_counter0 != size) {}
- if (!rank) sync_counter0 = 0;
- sync_counter1++;
- while(sync_counter1 != 0 && sync_counter1 != size) {}
- if (!rank) sync_counter1 = 0;
- }
- void SyncThreads4(Integer rank) const {
- Integer sync_counter1_ = sync_counter1.load(std::memory_order_relaxed);
- if (sync_counter0++ == size - 1) {
- sync_counter0 = 0;
- sync_counter1.store(1 - sync_counter1_, std::memory_order_release);
- } else {
- while (sync_counter1 == sync_counter1_) {}
- }
- //while(sync_counter0 != 0 && sync_counter0 != size) {}
- //if (!rank) sync_counter0 = 0;
- //sync_counter1++;
- //while(sync_counter1 != 0 && sync_counter1 != size) {}
- //if (!rank) sync_counter1 = 0;
- }
- void SyncThreads5(Integer rank) const {
- //std::atomic_thread_fence(std::memory_order_seq_cst);
- //std::atomic_thread_fence(std::memory_order_seq_cst);
- while (sync_counter0 != rank) {}
- sync_counter0 = rank+1;
- if (rank==0) {
- while (sync_counter0 < size) {}
- sync_counter0 = 0;
- }
- while (sync_counter1 != rank) {}
- sync_counter1 = rank+1;
- if (rank==0) {
- while (sync_counter1 < size) {}
- sync_counter1 = 0;
- }
- }
- void SyncThreads(Integer rank) const {
- Integer sync_counter1_ = sync_counter1.load(std::memory_order_relaxed);
- if(sync_counter0.fetch_add(1) == (size - 1)) {
- sync_counter0 = 0;
- sync_counter1.store(sync_counter1_+1, std::memory_order_release);
- } else {
- while(sync_counter1.load(std::memory_order_relaxed) == sync_counter1_) {};
- }
- std::atomic_thread_fence(std::memory_order_acq_rel);
- //std::atomic_thread_fence(std::memory_order_seq_cst);
- }
- void FinalizeThread(Integer id) const {
- SyncThreads(id);
- lck.lock();
- auto& tdata = GetThreadData(id);
- SCTL_ASSERT(tdata.init_flag);
- tdata.init_flag = 0;
- thread_counter--;
- lck.unlock();
- }
- mutable Iterator<ThreadData> thread_data;
- mutable Integer thread_counter;
- mutable std::mutex lck;
- Integer size;
- mutable std::atomic<Integer> sync_counter0;
- mutable std::atomic<Integer> sync_counter1;
- };
- class ThreadComm {
- public:
- ThreadComm(const ShMem& m, Integer id = -1) {
- smem = Ptr2ConstItr<ShMem>(&m, 1);
- rank = smem->InitThread(id);
- size = smem->size;
- }
- ~ThreadComm() {
- smem->FinalizeThread(rank);
- }
- ThreadComm(ThreadComm const&) = delete;
- void operator=(ThreadComm const&) = delete;
- Integer Rank() const { return rank; }
- Integer Size() const { return size; }
- void Sync() const { smem->SyncThreads(rank); }
- //template <class DataType> void Gather(Iterator<DataType> array, const DataType& a);
- //template <class DataType> void Broadcast(Iterator<DataType> array, const DataType& a);
- //Comm Split(Integer clr) const;
- //template <class SType> void* Isend(ConstIterator<SType> sbuf, Long scount, Integer dest, Integer tag = 0) const;
- //template <class RType> void* Irecv(Iterator<RType> rbuf, Long rcount, Integer source, Integer tag = 0) const;
- //template <class SType, class RType> void Allgather(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
- //template <class SType, class RType> void Allgatherv(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
- //template <class SType, class RType> void Alltoall(ConstIterator<SType> sbuf, Long scount, Iterator<RType> rbuf, Long rcount) const;
- //template <class SType, class RType> void* Ialltoallv_sparse(ConstIterator<SType> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<RType> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls, Integer tag = 0) const;
- //template <class Type> void Alltoallv(ConstIterator<Type> sbuf, ConstIterator<Long> scounts, ConstIterator<Long> sdispls, Iterator<Type> rbuf, ConstIterator<Long> rcounts, ConstIterator<Long> rdispls) const;
- //template <class Type> void Allreduce(ConstIterator<Type> sbuf, Iterator<Type> rbuf, Long count, CommOp op) const;
- //template <class Type> void Scan(ConstIterator<Type> sbuf, Iterator<Type> rbuf, int count, CommOp op) const;
- static void test() {
- auto fn = [](const ShMem& m) {
- ThreadComm c(m);
- Long i=0;
- for (Long i=0; i< c.Size(); i++){
- std::cout<<i;
- c.Sync();
- if (!c.Rank()) std::cout<<'\n';
- c.Sync();
- }
- double tt[2]={0,0};
- while (1) {
- c.Sync();
- i++;
- if (c.Rank() ==0 && i%10000000 == 0) {
- tt[1] = tt[0];
- tt[0] = omp_get_wtime();
- std::cout<<tt[0]-tt[1]<<'\n';
- }
- }
- };
- Long np = 4;
- ShMem m(np);
- if (1) {
- std::vector<std::thread> threads;
- for (Integer i = 0; i < np; i++) threads.push_back(std::thread(fn, std::ref(m)));
- for (auto& t : threads) t.join();
- } else {
- omp_set_num_threads(np);
- #pragma omp parallel
- {
- fn(m);
- }
- }
- }
- private:
- Integer rank, size;
- ConstIterator<ShMem> smem;
- };
- } // end namespace
- //#include SCTL_INCLUDE(thread-comm.txx)
- #endif //_SCTL_THREAD_COMM_HPP_
|