From fe6e62a2b2708cf21b9ed6d8c92e605e8d0e1ce0 Mon Sep 17 00:00:00 2001 From: Ian Bell <ian.bell@nist.gov> Date: Sat, 6 Nov 2021 18:54:56 -0400 Subject: [PATCH] Expose more virial coefficients to Python --- include/teqp/derivs.hpp | 17 +++++++++++++++++ interface/pybind11_wrapper.hpp | 4 +++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/include/teqp/derivs.hpp b/include/teqp/derivs.hpp index 3bf97db..445b0a5 100644 --- a/include/teqp/derivs.hpp +++ b/include/teqp/derivs.hpp @@ -330,6 +330,23 @@ struct VirialDerivatives { return o; } + /// This version of the get_Bnvir takes the maximum number of derivatives as a runtime argument + /// and then forwards all arguments to the templated function + template <ADBackends be = ADBackends::autodiff> + static auto get_Bnvir_runtime(const int Nderiv, const Model& model, const Scalar &T, const VectorType& molefrac) { + switch(Nderiv){ + case 2: return get_Bnvir<2,be>(model, T, molefrac); + case 3: return get_Bnvir<3,be>(model, T, molefrac); + case 4: return get_Bnvir<4,be>(model, T, molefrac); + case 5: return get_Bnvir<5,be>(model, T, molefrac); + case 6: return get_Bnvir<6,be>(model, T, molefrac); + case 7: return get_Bnvir<7,be>(model, T, molefrac); + case 8: return get_Bnvir<8,be>(model, T, molefrac); + case 9: return get_Bnvir<9,be>(model, T, molefrac); + default: throw std::invalid_argument("Only Nderiv up to 9 is supported, get_Bnvir templated function allows more"); + } + } + static auto get_B12vir(const Model& model, const Scalar &T, const VectorType& molefrac) { auto B2 = get_B2vir(model, T, molefrac); // Overall B2 for mixture diff --git a/interface/pybind11_wrapper.hpp b/interface/pybind11_wrapper.hpp index cc11df5..ef65dde 100644 --- a/interface/pybind11_wrapper.hpp +++ b/interface/pybind11_wrapper.hpp @@ -37,7 +37,9 @@ void add_derivatives(py::module &m, Wrapper &cls) { using vd = VirialDerivatives<Model, double, Eigen::Array<double,Eigen::Dynamic,1>>; m.def("get_B2vir", &vd::get_B2vir, py::arg("model"), py::arg("T"), py::arg("molefrac").noconvert()); - cls.def("get_B2vir", [](const Model& m, const double T, const RAX molefrac) { return vd::get_B2vir(m, T, molefrac); }, py::arg("T"), py::arg("molefrac").noconvert()); + cls.def("get_B2vir", [](const Model& m, const double T, const RAX molefrac) { return vd::get_B2vir(m, T, molefrac); }, py::arg("T"), py::arg("molefrac").noconvert()); + cls.def("get_Bnvir", [](const Model& m, const int Nderiv, const double T, const RAX molefrac) { return vd::get_Bnvir_runtime(Nderiv, m, T, molefrac); }, py::arg("Nderiv"), py::arg("T"), py::arg("molefrac").noconvert()); + m.def("get_B12vir", &vd::get_B12vir, py::arg("model"), py::arg("T"), py::arg("molefrac").noconvert()); using ct = CriticalTracing<Model, double, Eigen::Array<double, Eigen::Dynamic, 1>>; -- GitLab