From fe6e62a2b2708cf21b9ed6d8c92e605e8d0e1ce0 Mon Sep 17 00:00:00 2001
From: Ian Bell <ian.bell@nist.gov>
Date: Sat, 6 Nov 2021 18:54:56 -0400
Subject: [PATCH] Expose more virial coefficients to Python

---
 include/teqp/derivs.hpp        | 17 +++++++++++++++++
 interface/pybind11_wrapper.hpp |  4 +++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/include/teqp/derivs.hpp b/include/teqp/derivs.hpp
index 3bf97db..445b0a5 100644
--- a/include/teqp/derivs.hpp
+++ b/include/teqp/derivs.hpp
@@ -330,6 +330,23 @@ struct VirialDerivatives {
         return o;
     }
 
+    /// This version of the get_Bnvir takes the maximum number of derivatives as a runtime argument
+    /// and then forwards all arguments to the templated function
+    template <ADBackends be = ADBackends::autodiff>
+    static auto get_Bnvir_runtime(const int Nderiv, const Model& model, const Scalar &T, const VectorType& molefrac) {
+        switch(Nderiv){
+            case 2: return get_Bnvir<2,be>(model, T, molefrac);
+            case 3: return get_Bnvir<3,be>(model, T, molefrac);
+            case 4: return get_Bnvir<4,be>(model, T, molefrac);
+            case 5: return get_Bnvir<5,be>(model, T, molefrac);
+            case 6: return get_Bnvir<6,be>(model, T, molefrac);
+            case 7: return get_Bnvir<7,be>(model, T, molefrac);
+            case 8: return get_Bnvir<8,be>(model, T, molefrac);
+            case 9: return get_Bnvir<9,be>(model, T, molefrac);
+            default: throw std::invalid_argument("Only Nderiv up to 9 is supported, get_Bnvir templated function allows more");
+        }
+    }
+
     static auto get_B12vir(const Model& model, const Scalar &T, const VectorType& molefrac) {
     
         auto B2 = get_B2vir(model, T, molefrac); // Overall B2 for mixture
diff --git a/interface/pybind11_wrapper.hpp b/interface/pybind11_wrapper.hpp
index cc11df5..ef65dde 100644
--- a/interface/pybind11_wrapper.hpp
+++ b/interface/pybind11_wrapper.hpp
@@ -37,7 +37,9 @@ void add_derivatives(py::module &m, Wrapper &cls) {
 
     using vd = VirialDerivatives<Model, double, Eigen::Array<double,Eigen::Dynamic,1>>;
     m.def("get_B2vir", &vd::get_B2vir, py::arg("model"), py::arg("T"), py::arg("molefrac").noconvert());
-    cls.def("get_B2vir", [](const Model& m, const double T, const RAX molefrac) { return vd::get_B2vir(m, T, molefrac); }, py::arg("T"), py::arg("molefrac").noconvert());
+    cls.def("get_B2vir", [](const Model& m, const double T, const RAX molefrac) { return vd::get_B2vir(m, T, molefrac); }, py::arg("T"), py::arg("molefrac").noconvert());    
+    cls.def("get_Bnvir", [](const Model& m, const int Nderiv, const double T, const RAX molefrac) { return vd::get_Bnvir_runtime(Nderiv, m, T, molefrac); }, py::arg("Nderiv"), py::arg("T"), py::arg("molefrac").noconvert());
+    
     m.def("get_B12vir", &vd::get_B12vir, py::arg("model"), py::arg("T"), py::arg("molefrac").noconvert());
 
     using ct = CriticalTracing<Model, double, Eigen::Array<double, Eigen::Dynamic, 1>>;
-- 
GitLab