From 38b602a5791b0847f2a6c4599de7cdb8ad59da42 Mon Sep 17 00:00:00 2001
From: Ian Bell <ian.bell@nist.gov>
Date: Mon, 29 Mar 2021 15:43:50 -0400
Subject: [PATCH] Begin transition to compile-time selection of algorithmic
 differentiation backend

---
 include/teqp/derivs.hpp | 42 ++++++++++++++++++++---------------------
 src/multifluid.cpp      |  6 +++---
 2 files changed, 24 insertions(+), 24 deletions(-)

diff --git a/include/teqp/derivs.hpp b/include/teqp/derivs.hpp
index 17a3394..3149bd3 100644
--- a/include/teqp/derivs.hpp
+++ b/include/teqp/derivs.hpp
@@ -82,28 +82,28 @@ typename ContainerType::value_type get_Ar10(const Model& model, const TType T, c
     return -T*model.alphar(std::complex<TType>(T, h), rho, molefrac); // Complex step derivative
 }
 
-template <typename Model, typename TType, typename RhoType, typename MoleFracType>
-auto get_Ar01(const Model& model, const TType &T, const RhoType &rho, const MoleFracType& molefrac) {
-    double h = 1e-100;
-    auto der = model.alphar(T, std::complex<double>(rho, h), molefrac).imag() / h;
-    return der*rho;
-}
+enum class ADBackends { autodiff, multicomplex, complex_step } ;
 
-template <typename Model, typename TType, typename RhoType, typename MoleFracType>
-auto get_Ar01mcx(const Model& model, const TType& T, const RhoType& rho, const MoleFracType& molefrac) {
-    using fcn_t = std::function<mcx::MultiComplex<double>(const mcx::MultiComplex<double>&)>;
-    bool and_val = true;
-    fcn_t f = [&model, &T, &molefrac](const auto& rho_) { return model.alphar(T, rho_, molefrac); };
-    auto ders = diff_mcx1(f, rho, 1, and_val);
-    return ders[1] * rho;
-}
-
-template <typename Model, typename TType, typename RhoType, typename MoleFracType>
-auto get_Ar01ad(const Model& model, const TType& T, const RhoType& rho, const MoleFracType& molefrac) {
-    autodiff::dual rhodual = rho;
-    auto f = [&model, &T, &molefrac](const auto& rho_) { return eval(model.alphar(T, rho_, molefrac)); };
-    auto der = derivative(f, wrt(rhodual), at(rhodual));
-    return der * rho;
+template <ADBackends be = ADBackends::autodiff, typename Model, typename TType, typename RhoType, typename MoleFracType>
+auto get_Ar01(const Model& model, const TType &T, const RhoType &rho, const MoleFracType& molefrac) {
+    if constexpr(be == ADBackends::complex_step){
+        double h = 1e-100;
+        auto der = model.alphar(T, std::complex<double>(rho, h), molefrac).imag() / h;
+        return der*rho;
+    }
+    else if constexpr(be == ADBackends::multicomplex){
+        using fcn_t = std::function<mcx::MultiComplex<double>(const mcx::MultiComplex<double>&)>;
+        bool and_val = true;
+        fcn_t f = [&model, &T, &molefrac](const auto& rho_) { return model.alphar(T, rho_, molefrac); };
+        auto ders = diff_mcx1(f, rho, 1, and_val);
+        return ders[1] * rho;
+    }
+    else if constexpr(be == ADBackends::autodiff){
+        autodiff::dual rhodual = rho;
+        auto f = [&model, &T, &molefrac](const auto& rho_) { return eval(model.alphar(T, rho_, molefrac)); };
+        auto der = derivative(f, wrt(rhodual), at(rhodual));
+        return der * rho;
+    }
 }
 
 template <typename Model, typename TType, typename RhoType, typename MoleFracType>
diff --git a/src/multifluid.cpp b/src/multifluid.cpp
index 03b9f56..db3d212 100644
--- a/src/multifluid.cpp
+++ b/src/multifluid.cpp
@@ -259,21 +259,21 @@ int main(){
         {
             Timer t(N);
             for (auto i = 0; i < N; ++i) {
-                alphar = get_Ar01(model, T, rho, molefrac);
+                alphar = get_Ar01<ADBackends::complex_step>(model, T, rho, molefrac);
             }
             std::cout << alphar << "; 1st CSD" << std::endl;
         }
         {
             Timer t(N);
             for (auto i = 0; i < N; ++i) {
-                alphar = get_Ar01ad(model, T, rho, molefrac);
+                alphar = get_Ar01<ADBackends::autodiff>(model, T, rho, molefrac);
             }
             std::cout << alphar << "; 1st autodiff::autodiff" << std::endl;
         }
         {
             Timer t(N);
             for (auto i = 0; i < N; ++i) {
-                alphar = get_Ar01mcx(model, T, rho, molefrac);
+                alphar = get_Ar01<ADBackends::multicomplex>(model, T, rho, molefrac);
             }
             std::cout << alphar << "; 1st MCX" << std::endl;
         }
-- 
GitLab