From 187b8c9836cb1fb065706bbdce88038877b8bcaa Mon Sep 17 00:00:00 2001 From: Ian Bell <ian.bell@nist.gov> Date: Wed, 3 Aug 2022 14:20:14 -0400 Subject: [PATCH] Simplify and expand the Arxy method wrapping --- interface/pybind11_wrapper.hpp | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/interface/pybind11_wrapper.hpp b/interface/pybind11_wrapper.hpp index c573128..ed2d21d 100644 --- a/interface/pybind11_wrapper.hpp +++ b/interface/pybind11_wrapper.hpp @@ -16,6 +16,24 @@ namespace py = pybind11; using namespace teqp; +template<typename Model, int iT, int iD, typename Class> +void add_res_deriv_impl(Class& cls) { + using idx = TDXDerivatives<Model>; + using RAX = Eigen::Ref<Eigen::ArrayXd>; + const std::string fname = "get_Ar" + std::to_string(iT) + std::to_string(iD); + cls.def(fname.c_str(), + [](const Model& m, const double T, const double rho, const RAX molefrac) { return idx::template get_Arxy<iT, iD, ADBackends::autodiff>(m, T, rho, molefrac); }, + py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert() + ); +} + +template<typename Model, typename Class> +void add_res_derivatives(Class& cls) { + add_res_deriv_impl<Model, 0, 0>(cls); add_res_deriv_impl<Model, 0, 1>(cls); add_res_deriv_impl<Model, 0, 2>(cls); add_res_deriv_impl<Model, 0, 3>(cls); add_res_deriv_impl<Model, 0, 4>(cls); + add_res_deriv_impl<Model, 1, 0>(cls); add_res_deriv_impl<Model, 1, 1>(cls); add_res_deriv_impl<Model, 1, 2>(cls); add_res_deriv_impl<Model, 1, 3>(cls); add_res_deriv_impl<Model, 1, 4>(cls); + add_res_deriv_impl<Model, 2, 0>(cls); add_res_deriv_impl<Model, 2, 1>(cls); add_res_deriv_impl<Model, 2, 2>(cls); add_res_deriv_impl<Model, 2, 3>(cls); add_res_deriv_impl<Model, 2, 4>(cls); +} + template<typename Model, typename Wrapper> void add_derivatives(py::module &m, Wrapper &cls) { @@ -75,11 +93,7 @@ void add_derivatives(py::module &m, Wrapper &cls) { cls.def("get_R", [](const Model& m, const RAX molefrac) { return m.R(molefrac); }, py::arg("molefrac").noconvert()); cls.def("get_Ar00", &tdx::get_Ar00, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); - cls.def("get_Ar01", &(tdx::template get_Ar01<ADBackends::autodiff>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); - cls.def("get_Ar10", &(tdx::template get_Ar10<ADBackends::autodiff>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); - cls.def("get_Ar11", &(tdx::template get_Ar11<ADBackends::autodiff>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); - cls.def("get_Ar12", &(tdx::template get_Ar12<ADBackends::autodiff>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); - cls.def("get_Ar20", &(tdx::template get_Ar20<ADBackends::autodiff>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + add_res_derivatives<Model>(cls); // All the residual derivatives cls.def("get_Ar01n", &(tdx::template get_Ar0n<1>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); cls.def("get_Ar02n", &(tdx::template get_Ar0n<2>), py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); -- GitLab