diff --git a/include/teqp/derivs.hpp b/include/teqp/derivs.hpp index 3149bd376f3a81d33bda765cf065aa95542a0b50..41df6906e943797ea51a404b607a7c2d69df5f68 100644 --- a/include/teqp/derivs.hpp +++ b/include/teqp/derivs.hpp @@ -143,16 +143,33 @@ typename ContainerType::value_type get_B2vir(const Model& model, const TType T, * \param molefrac The mole fractions */ -template <typename Model, typename TType, typename ContainerType> -auto get_Bnvir(const Model& model, int Nderiv, const TType T, const ContainerType& molefrac) { - - using namespace mcx; - using fcn_t = std::function<MultiComplex<double>(const MultiComplex<double>&)>; - fcn_t f = [&model, &T, &molefrac](const auto& rho_) { return model.alphar(T, rho_, molefrac); }; +template <int Nderiv, ADBackends be = ADBackends::autodiff, typename Model, typename TType, typename ContainerType> +auto get_Bnvir(const Model& model, const TType T, const ContainerType& molefrac) +{ + std::map<int, double> dnalphardrhon; + if constexpr(be == ADBackends::multicomplex){ + using namespace mcx; + using fcn_t = std::function<MultiComplex<double>(const MultiComplex<double>&)>; + fcn_t f = [&model, &T, &molefrac](const auto& rho_) { return model.alphar(T, rho_, molefrac); }; + auto derivs = diff_mcx1(f, 0.0, Nderiv+1, true /* and_val */); + for (auto n = 1; n <= Nderiv; ++n){ + dnalphardrhon[n] = derivs[n]; + } + } + else if constexpr(be == ADBackends::autodiff){ + autodiff::HigherOrderDual<Nderiv+1, double> rhodual = 0.0; + auto f = [&model, &T, &molefrac](const auto& rho_) { return model.alphar(T, rho_, molefrac); }; + auto derivs = derivatives(f, wrt(rhodual), at(rhodual)); + for (auto n = 1; n <= Nderiv; ++n){ + dnalphardrhon[n] = derivs[n]; + } + } + else{ + static_assert("algorithmic differentiation backend is invalid"); + } std::map<int, TType> o; - auto dalphardrhon = diff_mcx1(f, 0.0, Nderiv+1, true /* and_val */); for (int n = 2; n < Nderiv+1; ++n) { - o[n] = dalphardrhon[n-1]; + o[n] = dnalphardrhon[n-1]; // 0!=1, 1!=1, so only n>3 terms need factorial correction if (n > 3) { auto factorial = [](int N) {return tgamma(N + 1); }; diff --git a/src/tests/catch_tests.cxx b/src/tests/catch_tests.cxx index c0e00a91f7cc862911e597271ed920a33230b6a9..240d88548855888170b58d2e8dbe2a743bf2f0a8 100644 --- a/src/tests/catch_tests.cxx +++ b/src/tests/catch_tests.cxx @@ -46,10 +46,11 @@ TEST_CASE("Check virial coefficients for vdW", "[virial]") auto T = 300.0; std::valarray<double> molefrac = { 1.0 }; - auto Nvir = 8; + constexpr int Nvir = 8; // Numerical solutions from alphar - auto Bn = get_Bnvir(vdW, Nvir, T, molefrac); + auto Bn = get_Bnvir<Nvir, ADBackends::autodiff>(vdW, T, molefrac); + auto Bnmcx = get_Bnvir<Nvir, ADBackends::multicomplex>(vdW, T, molefrac); // Exact solutions for virial coefficients for van der Waals auto get_vdW_exacts = [a, b, R, T](int Nmax) { @@ -116,8 +117,8 @@ TEST_CASE("Check p three ways for vdW", "[virial][p]") auto pfromderiv = rho*model.R*T + get_pr(model, T, rhovec); // Numerical solution from virial expansion - auto Nvir = 8; - auto Bn = get_Bnvir(model, 8, T, molefrac); + constexpr int Nvir = 8; + auto Bn = get_Bnvir<Nvir>(model, T, molefrac); auto Z = 1.0; for (auto i = 2; i <= Nvir; ++i){ Z += Bn[i]*pow(rho, i-1);