dtypes.h 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. #include <mpi.h>
  2. #include <complex>
  3. #ifndef __PVFMM_DTYPES_H_
  4. #define __PVFMM_DTYPES_H_
/**
 * \file dtypes.h
 * \brief Traits to determine the MPI_Datatype corresponding to a C++ datatype.
 * \author Hari Sundar, hsundar@gmail.com
 *
 * Traits to determine the MPI_Datatype corresponding to a C++ datatype. Non-standard
 * C++ datatypes (such as classes) require additional specializations; an example
 * is given below for std::complex. Further specializations can be added as required.
 */
  14. namespace pvfmm{
  15. namespace par{
  16. /**
  17. * \class Mpi_datatype
  18. * \brief An abstract class used for communicating messages using user-defined
  19. * datatypes. The user must implement the static member function "value()" that
  20. * returns the MPI_Datatype corresponding to this user-defined datatype.
  21. * \author Hari Sundar, hsundar@gmail.com
  22. * \see Mpi_datatype<bool>
  23. */
  24. template <typename T>
  25. class Mpi_datatype{
  26. public:
  27. static MPI_Datatype value() {
  28. static bool first = true;
  29. static MPI_Datatype datatype;
  30. if (first) {
  31. first = false;
  32. MPI_Type_contiguous(sizeof(T), MPI_BYTE, &datatype);
  33. MPI_Type_commit(&datatype);
  34. }
  35. return datatype;
  36. }
  37. static MPI_Op sum() {
  38. static bool first = true;
  39. static MPI_Op myop;
  40. if (first) {
  41. first = false;
  42. int commune=1;
  43. MPI_Op_create(sum_fn, commune, &myop);
  44. }
  45. return myop;
  46. }
  47. static MPI_Op max() {
  48. static bool first = true;
  49. static MPI_Op myop;
  50. if (first) {
  51. first = false;
  52. int commune=1;
  53. MPI_Op_create(max_fn, commune, &myop);
  54. }
  55. return myop;
  56. }
  57. private:
  58. static void sum_fn( void * a_, void * b_, int * len_, MPI_Datatype * datatype){
  59. T* a=(T*)a_;
  60. T* b=(T*)b_;
  61. int len=*len_;
  62. for(int i=0;i<len;i++){
  63. b[i]=a[i]+b[i];
  64. }
  65. }
  66. static void max_fn( void * a_, void * b_, int * len_, MPI_Datatype * datatype){
  67. T* a=(T*)a_;
  68. T* b=(T*)b_;
  69. int len=*len_;
  70. for(int i=0;i<len;i++){
  71. if(a[i]>b[i]) b[i]=a[i];
  72. }
  73. }
  74. };
  75. #define HS_MPIDATATYPE(CTYPE, MPITYPE) \
  76. template <> \
  77. class Mpi_datatype<CTYPE> { \
  78. public: \
  79. static MPI_Datatype value() { \
  80. return MPITYPE; \
  81. } \
  82. static MPI_Op sum() { \
  83. return MPI_SUM; \
  84. } \
  85. static MPI_Op max() { \
  86. return MPI_MAX; \
  87. } \
  88. };
  89. HS_MPIDATATYPE(short, MPI_SHORT)
  90. HS_MPIDATATYPE(int, MPI_INT)
  91. HS_MPIDATATYPE(long, MPI_LONG)
  92. HS_MPIDATATYPE(unsigned short, MPI_UNSIGNED_SHORT)
  93. HS_MPIDATATYPE(unsigned int, MPI_UNSIGNED)
  94. HS_MPIDATATYPE(unsigned long, MPI_UNSIGNED_LONG)
  95. HS_MPIDATATYPE(float, MPI_FLOAT)
  96. HS_MPIDATATYPE(double, MPI_DOUBLE)
  97. HS_MPIDATATYPE(long double, MPI_LONG_DOUBLE)
  98. HS_MPIDATATYPE(long long, MPI_LONG_LONG_INT)
  99. HS_MPIDATATYPE(char, MPI_CHAR)
  100. HS_MPIDATATYPE(unsigned char, MPI_UNSIGNED_CHAR)
  101. //PetscScalar is simply a typedef for double. Hence no need to explicitly
  102. //define an mpi_datatype for it.
  103. #undef HS_MPIDATATYPE
  104. template <typename T>
  105. class Mpi_datatype<std::complex<T> > {
  106. public:
  107. static MPI_Datatype value() {
  108. static bool first = true;
  109. static MPI_Datatype datatype;
  110. if (first) {
  111. first = false;
  112. MPI_Type_contiguous(2, Mpi_datatype<T>::value(), &datatype);
  113. MPI_Type_commit(&datatype);
  114. }
  115. return datatype;
  116. }
  117. };
  118. } //end namespace
  119. } //end namespace
  120. #endif //__PVFMM_DTYPES_H_