From 4fe3f60d7eafbd2ccc189e0a81698b2e6605495b Mon Sep 17 00:00:00 2001 From: Ian Bell <ian.bell@nist.gov> Date: Sat, 8 May 2021 15:26:06 -0400 Subject: [PATCH] Clean up pybind11 wrapper to avoid copies where possible --- interface/pybind11_wrapper.cpp | 54 ++++++++++++++-------------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/interface/pybind11_wrapper.cpp b/interface/pybind11_wrapper.cpp index bb368bc..7b13eb2 100644 --- a/interface/pybind11_wrapper.cpp +++ b/interface/pybind11_wrapper.cpp @@ -17,18 +17,6 @@ namespace py = pybind11; -//template<typename Model> -//void add_TDx_derivatives(py::module& m) { -// using id = TDXDerivatives<Model, double, Eigen::Array<double, Eigen::Dynamic, 1> >; -// //m.def("get_Ar00", &id::get_Ar00, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac")); -// m.def("get_Ar10", &id::get_Ar10<ADBackends::autodiff>, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac")); -// m.def("get_Ar01", &id::get_Ar01<ADBackends::autodiff>, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac")); -// m.def("get_Ar11", &id::get_Ar11<ADBackends::autodiff>, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac")); -// m.def("get_Ar02", &id::get_Ar02<ADBackends::autodiff>, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac")); -// m.def("get_Ar20", &id::get_Ar20<ADBackends::autodiff>, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac")); -// m.def("get_neff", &id::get_neff<ADBackends::autodiff>, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac")); -//} - template<typename Model, typename Wrapper> void add_derivatives(py::module &m, Wrapper &cls) { using id = IsochoricDerivatives<Model, double, Eigen::Array<double,Eigen::Dynamic,1> >; @@ -43,10 +31,8 @@ void add_derivatives(py::module &m, Wrapper &cls) { m.def("build_Psir_gradient_autodiff", &id::build_Psir_gradient_autodiff, py::arg("model"), py::arg("T"), py::arg("rho")); using vd = VirialDerivatives<Model, double, Eigen::Array<double,Eigen::Dynamic,1>>; - m.def("get_B2vir", &vd::get_B2vir, py::arg("model"), py::arg("T"), py::arg("molefrac")); - m.def("get_B12vir", &vd::get_B12vir, py::arg("model"), py::arg("T"), py::arg("molefrac")); - - //add_TDx_derivatives<Model>(m); + m.def("get_B2vir", &vd::get_B2vir, py::arg("model"), py::arg("T"), py::arg("molefrac").noconvert()); + m.def("get_B12vir", &vd::get_B12vir, py::arg("model"), py::arg("T"), py::arg("molefrac").noconvert()); using ct = CriticalTracing<Model, double, Eigen::Array<double, Eigen::Dynamic, 1>>; m.def("trace_critical_arclength_binary", &ct::trace_critical_arclength_binary); @@ -57,19 +43,20 @@ void add_derivatives(py::module &m, Wrapper &cls) { //cls.def("get_Ar01", [](const Model& m, const double T, const Eigen::ArrayXd& rhovec) { return id::get_Ar01(m, T, rhovec); }); //cls.def("get_Ar10", [](const Model& m, const double T, const Eigen::ArrayXd& rhovec) { return id::get_Ar10(m, T, rhovec); }); using tdx = TDXDerivatives<Model, double, Eigen::Array<double, Eigen::Dynamic, 1> >; - cls.def("get_Ar00", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::get_Ar00(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - cls.def("get_Ar01", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::template get_Ar01<ADBackends::autodiff>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - cls.def("get_Ar10", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::template get_Ar10<ADBackends::autodiff>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - cls.def("get_Ar11", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::template get_Ar11<ADBackends::autodiff>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - cls.def("get_Ar12", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::template get_Ar12<ADBackends::autodiff>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - - cls.def("get_Ar01n", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::template get_Ar0n<1>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - cls.def("get_Ar02n", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::template get_Ar0n<2>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - cls.def("get_Ar03n", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::template get_Ar0n<3>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - cls.def("get_Ar04n", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::template get_Ar0n<4>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - cls.def("get_Ar05n", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::template get_Ar0n<5>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - cls.def("get_Ar06n", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::template get_Ar0n<6>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); - cls.def("get_neff", [](const Model& m, const double T, const double rho, const Eigen::ArrayXd& molefrac) { return tdx::get_neff(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); + using RAX = Eigen::Ref<Eigen::ArrayXd>; + cls.def("get_Ar00", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::get_Ar00(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + cls.def("get_Ar01", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::template get_Ar01<ADBackends::autodiff>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + cls.def("get_Ar10", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::template get_Ar10<ADBackends::autodiff>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + cls.def("get_Ar11", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::template get_Ar11<ADBackends::autodiff>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + cls.def("get_Ar12", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::template get_Ar12<ADBackends::autodiff>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + + cls.def("get_Ar01n", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::template get_Ar0n<1>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + cls.def("get_Ar02n", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::template get_Ar0n<2>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + cls.def("get_Ar03n", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::template get_Ar0n<3>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + cls.def("get_Ar04n", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::template get_Ar0n<4>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + cls.def("get_Ar05n", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::template get_Ar0n<5>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + cls.def("get_Ar06n", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::template get_Ar0n<6>(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac").noconvert()); + cls.def("get_neff", [](const Model& m, const double T, const double rho, const RAX molefrac) { return tdx::get_neff(m, T, rho, molefrac); }, py::arg("T"), py::arg("rho"), py::arg("molefrac")); } /// Instantiate "instances" of models (really wrapped Python versions of the models), and then attach all derivative methods @@ -116,9 +103,12 @@ void init_teqp(py::module& m) { ; add_derivatives<MultiFluid>(m, wMF); - // for timing testing - m.def("mysummer", [](const double &c, const Eigen::ArrayXd &x) { return c*x.sum(); }); - m.def("myadder", [](const double& c, const double& d) { return c+d; }); + // Some functions for timing overhead of interface + m.def("___mysummer", [](const double &c, const Eigen::ArrayXd &x) { return c*x.sum(); }); + using RAX = Eigen::Ref<Eigen::ArrayXd>; + using namespace pybind11::literals; // for "arg"_a + m.def("___mysummerref", [](const double& c, const RAX x) { return c * x.sum(); }, "c"_a, "x"_a.noconvert()); + m.def("___myadder", [](const double& c, const double& d) { return c+d; }); } -- GitLab