Просмотр исходного кода

Add optimized sctl::pow(Real b, Integer e)

Dhairya Malhotra 6 лет назад
Родитель
Сommit
83df10a9f9
3 измененных файлов с 21 добавлено и 3 удалено
  1. 4 2
      include/sctl/math_utils.hpp
  2. 16 0
      include/sctl/math_utils.txx
  3. 1 1
      include/sctl/vec.hpp

+ 4 - 2
include/sctl/math_utils.hpp

@@ -26,9 +26,11 @@ template <class Real> inline Real exp(const Real a) { return (Real)::exp(a); }
 
 template <class Real> inline Real log(const Real a) { return (Real)::log(a); }
 
-template <class Real> inline Real pow(const Real b, const Real e) { return (Real)::pow(b, e); }
+template <class Real, class ExpType> inline constexpr Real pow(const Real b, const ExpType e) { return (Real)std::pow(b, e); }
 
-template <Integer N, class T> constexpr T pow(const T& x) { return N > 1 ? x * pow<(N - 1) * (N > 1)>(x) : N < 0 ? T(1) / pow<(-N) * (N < 0)>(x) : N == 1 ? x : T(1); }
+template <Integer e, class Real> inline constexpr Real pow(Real b);
+
+template <class Real> inline constexpr Real pow(Real b, Integer e);
 
 }  // end namespace
 

+ 16 - 0
include/sctl/math_utils.txx

@@ -262,6 +262,22 @@ template <class Real> inline std::ostream& ostream_insertion_generic(std::ostrea
   return output;
 }
 
+template <Integer e, class Real> static inline constexpr Real pow_helper(Real b) {
+  return (e > 0) ? ((e & 1) ? b : Real(1)) * pow_helper<(e>>1),Real>(b*b) : Real(1);
+}
+
+template <Integer e, class Real> inline constexpr Real pow(Real b) {
+  return (e > 0) ? pow_helper<e,Real>(b) : 1/pow_helper<-e,Real>(b);
+}
+
+template <class Real> static inline constexpr Real pow_helper(Real b, Integer e) {
+  return (e > 0) ? ((e & 1) ? b : Real(1)) * pow_helper(b*b, e>>1) : Real(1);
+}
+
+template <class Real> inline constexpr Real pow(Real b, Integer e) {
+  return (e > 0) ? pow_helper(b, e) : 1/pow_helper(b, -e);
+}
+
 }  // end namespace
 
 namespace SCTL_NAMESPACE {

+ 1 - 1
include/sctl/vec.hpp

@@ -247,7 +247,7 @@ namespace SCTL_NAMESPACE {
       }
       friend Vec approx_rsqrt(const Vec& x) {
         Vec r;
-        for (int i = 0; i < N; i++) r.v[i] = 1.0 / sqrt(x.v[i]);
+        for (int i = 0; i < N; i++) r.v[i] = 1 / sqrt<ValueType>(x.v[i]);
         return r;
       }