From 187b8c9836cb1fb065706bbdce88038877b8bcaa Mon Sep 17 00:00:00 2001
From: Ian Bell <ian.bell@nist.gov>
Date: Wed, 3 Aug 2022 14:20:14 -0400
Subject: [PATCH] Simplify and expand the Arxy method wrapping

---
 interface/pybind11_wrapper.hpp | 24 +++++++++++++++++++-----
 1 file changed, 19 insertions(+), 5 deletions(-)

diff --git a/interface/pybind11_wrapper.hpp b/interface/pybind11_wrapper.hpp
index c573128..ed2d21d 100644
--- a/interface/pybind11_wrapper.hpp
+++ b/interface/pybind11_wrapper.hpp
@@ -16,6 +16,24 @@
 namespace py = pybind11;
 using namespace teqp;
 
+template<typename Model, int iT, int iD, typename Class>
+void add_res_deriv_impl(Class& cls) {
+    using idx = TDXDerivatives<Model>;
+    using RAX = Eigen::Ref<Eigen::ArrayXd>;
+    const std::string fname = "get_Ar" + std::to_string(iT) + std::to_string(iD);
+    cls.def(fname.c_str(),
+        [](const Model& m, const double T, const double rho, const RAX molefrac) { return idx::template get_Arxy<iT, iD, ADBackends::autodiff>(m, T, rho, molefrac); },
+        py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()
+    );
+}
+
+template<typename Model, typename Class>
+void add_res_derivatives(Class& cls) {
+    add_res_deriv_impl<Model, 0, 0>(cls); add_res_deriv_impl<Model, 0, 1>(cls); add_res_deriv_impl<Model, 0, 2>(cls); add_res_deriv_impl<Model, 0, 3>(cls); add_res_deriv_impl<Model, 0, 4>(cls);
+    add_res_deriv_impl<Model, 1, 0>(cls); add_res_deriv_impl<Model, 1, 1>(cls); add_res_deriv_impl<Model, 1, 2>(cls); add_res_deriv_impl<Model, 1, 3>(cls); add_res_deriv_impl<Model, 1, 4>(cls);
+    add_res_deriv_impl<Model, 2, 0>(cls); add_res_deriv_impl<Model, 2, 1>(cls); add_res_deriv_impl<Model, 2, 2>(cls); add_res_deriv_impl<Model, 2, 3>(cls); add_res_deriv_impl<Model, 2, 4>(cls);
+}
+
 template<typename Model, typename Wrapper>
 void add_derivatives(py::module &m, Wrapper &cls) {
 
@@ -75,11 +93,7 @@ void add_derivatives(py::module &m, Wrapper &cls) {
     
     cls.def("get_R", [](const Model& m, const RAX molefrac) { return m.R(molefrac); }, py::arg("molefrac").noconvert());
     cls.def("get_Ar00", &tdx::get_Ar00, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert());
-    cls.def("get_Ar01", &(tdx::template get_Ar01<ADBackends::autodiff>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert());
-    cls.def("get_Ar10", &(tdx::template get_Ar10<ADBackends::autodiff>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert());
-    cls.def("get_Ar11", &(tdx::template get_Ar11<ADBackends::autodiff>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert());
-    cls.def("get_Ar12", &(tdx::template get_Ar12<ADBackends::autodiff>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert());
-    cls.def("get_Ar20", &(tdx::template get_Ar20<ADBackends::autodiff>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert());
+    add_res_derivatives<Model>(cls); // All the residual derivatives
 
     cls.def("get_Ar01n", &(tdx::template get_Ar0n<1>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert());
     cls.def("get_Ar02n", &(tdx::template get_Ar0n<2>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert());
-- 
GitLab