diff --git a/include/teqp/derivs.hpp b/include/teqp/derivs.hpp index 3bf97dba14d6c79ae36c7f43fc6fde98d5ccce55..445b0a5f386672252c45d184f0124d6da5d9a867 100644 --- a/include/teqp/derivs.hpp +++ b/include/teqp/derivs.hpp @@ -330,6 +330,23 @@ struct VirialDerivatives { return o; } + /// This version of the get_Bnvir takes the maximum number of derivatives as a runtime argument + /// and then forwards all arguments to the templated function + template <ADBackends be = ADBackends::autodiff> + static auto get_Bnvir_runtime(const int Nderiv, const Model& model, const Scalar &T, const VectorType& molefrac) { + switch(Nderiv){ + case 2: return get_Bnvir<2,be>(model, T, molefrac); + case 3: return get_Bnvir<3,be>(model, T, molefrac); + case 4: return get_Bnvir<4,be>(model, T, molefrac); + case 5: return get_Bnvir<5,be>(model, T, molefrac); + case 6: return get_Bnvir<6,be>(model, T, molefrac); + case 7: return get_Bnvir<7,be>(model, T, molefrac); + case 8: return get_Bnvir<8,be>(model, T, molefrac); + case 9: return get_Bnvir<9,be>(model, T, molefrac); + default: throw std::invalid_argument("Only Nderiv up to 9 is supported, get_Bnvir templated function allows more"); + } + } + static auto get_B12vir(const Model& model, const Scalar &T, const VectorType& molefrac) { auto B2 = get_B2vir(model, T, molefrac); // Overall B2 for mixture diff --git a/interface/pybind11_wrapper.hpp b/interface/pybind11_wrapper.hpp index cc11df5f76b4daa8073e245d208626b98774c421..ef65ddea967aca38f1136c59925134b48936055a 100644 --- a/interface/pybind11_wrapper.hpp +++ b/interface/pybind11_wrapper.hpp @@ -37,7 +37,9 @@ void add_derivatives(py::module &m, Wrapper &cls) { using vd = VirialDerivatives<Model, double, Eigen::Array<double,Eigen::Dynamic,1>>; m.def("get_B2vir", &vd::get_B2vir, py::arg("model"), py::arg("T"), py::arg("molefrac").noconvert()); - cls.def("get_B2vir", [](const Model& m, const double T, const RAX molefrac) { return vd::get_B2vir(m, T, molefrac); }, py::arg("T"), py::arg("molefrac").noconvert()); + cls.def("get_B2vir", [](const Model& m, const double T, const RAX molefrac) { return vd::get_B2vir(m, T, molefrac); }, py::arg("T"), py::arg("molefrac").noconvert()); + cls.def("get_Bnvir", [](const Model& m, const int Nderiv, const double T, const RAX molefrac) { return vd::get_Bnvir_runtime(Nderiv, m, T, molefrac); }, py::arg("Nderiv"), py::arg("T"), py::arg("molefrac").noconvert()); + m.def("get_B12vir", &vd::get_B12vir, py::arg("model"), py::arg("T"), py::arg("molefrac").noconvert()); using ct = CriticalTracing<Model, double, Eigen::Array<double, Eigen::Dynamic, 1>>;