Skip to content
Snippets Groups Projects
Commit 9502e341 authored by Ian Bell's avatar Ian Bell
Browse files

Move more methods into their own namespace

parent 85f57bd3
No related branches found
No related tags found
No related merge requests found
......@@ -62,48 +62,45 @@ typename ContainerType::value_type derivrhoi(const FuncType& f, TType T, const C
return f(T, rhocom).imag() / h;
}
enum class ADBackends { autodiff, multicomplex, complex_step };
template<typename Model, typename Scalar = double, typename VectorType = Eigen::ArrayXd>
struct TDXDerivatives {
template <typename Model, typename TType, typename RhoType, typename ContainerType>
typename ContainerType::value_type get_Ar10(const Model& model, const TType T, const RhoType &rho, const ContainerType& molefrac) {
double h = 1e-100;
return -T*model.alphar(std::complex<TType>(T, h), rho, molefrac).imag()/h; // Complex step derivative
}
enum class ADBackends { autodiff, multicomplex, complex_step } ;
template <ADBackends be = ADBackends::autodiff, typename Model, typename TType, typename RhoType, typename MoleFracType>
auto get_Ar01(const Model& model, const TType &T, const RhoType &rho, const MoleFracType& molefrac) {
if constexpr(be == ADBackends::complex_step){
static auto get_Ar10(const Model& model, const Scalar &T, const Scalar &rho, const VectorType& molefrac) {
double h = 1e-100;
auto der = model.alphar(T, std::complex<double>(rho, h), molefrac).imag() / h;
return der*rho;
return -T*model.alphar(std::complex<Scalar>(T, h), rho, molefrac).imag()/h; // Complex step derivative
}
else if constexpr(be == ADBackends::multicomplex){
using fcn_t = std::function<mcx::MultiComplex<double>(const mcx::MultiComplex<double>&)>;
bool and_val = true;
fcn_t f = [&model, &T, &molefrac](const auto& rho_) { return model.alphar(T, rho_, molefrac); };
auto ders = diff_mcx1(f, rho, 1, and_val);
return ders[1] * rho;
template<ADBackends be = ADBackends::autodiff>
static auto get_Ar01(const Model& model, const Scalar&T, const Scalar &rho, const VectorType& molefrac) {
if constexpr(be == ADBackends::complex_step){
double h = 1e-100;
auto der = model.alphar(T, std::complex<Scalar>(rho, h), molefrac).imag() / h;
return der*rho;
}
else if constexpr(be == ADBackends::multicomplex){
using fcn_t = std::function<mcx::MultiComplex<Scalar>(const mcx::MultiComplex<Scalar>&)>;
bool and_val = true;
fcn_t f = [&model, &T, &molefrac](const auto& rho_) { return model.alphar(T, rho_, molefrac); };
auto ders = diff_mcx1(f, rho, 1, and_val);
return ders[1] * rho;
}
else if constexpr(be == ADBackends::autodiff){
autodiff::dual rhodual = rho;
auto f = [&model, &T, &molefrac](const auto& rho_) { return eval(model.alphar(T, rho_, molefrac)); };
auto der = derivative(f, wrt(rhodual), at(rhodual));
return der * rho;
}
}
else if constexpr(be == ADBackends::autodiff){
autodiff::dual rhodual = rho;
static auto get_Ar02(const Model& model, const Scalar& T, const Scalar& rho, const VectorType& molefrac) {
autodiff::dual2nd rhodual = rho;
auto f = [&model, &T, &molefrac](const auto& rho_) { return eval(model.alphar(T, rho_, molefrac)); };
auto der = derivative(f, wrt(rhodual), at(rhodual));
return der * rho;
auto ders = derivatives(f, wrt(rhodual), at(rhodual));
return ders[2]*rho*rho;
}
}
template <typename Model, typename TType, typename RhoType, typename MoleFracType>
auto get_Ar02(const Model& model, const TType& T, const RhoType& rho, const MoleFracType& molefrac) {
autodiff::dual2nd rhodual = rho;
auto f = [&model, &T, &molefrac](const auto& rho_) { return eval(model.alphar(T, rho_, molefrac)); };
auto ders = derivatives(f, wrt(rhodual), at(rhodual));
return ders[2]*rho*rho;
}
};
template<typename Model, typename Scalar = double, typename VectorType = Eigen::ArrayXd>
struct VirialDerivatives {
......
......@@ -9,6 +9,15 @@
namespace py = pybind11;
template<typename Model>
void add_TDx_derivatives(py::module& m) {
using id = TDXDerivatives<Model, double, Eigen::Array<double, Eigen::Dynamic, 1> >;
//m.def("get_Ar00", &id::get_Ar00, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac"));
m.def("get_Ar10", &id::get_Ar10, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac"));
m.def("get_Ar01", &id::get_Ar01<ADBackends::autodiff>, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac"));
m.def("get_Ar02", &id::get_Ar02, py::arg("model"), py::arg("T"), py::arg("rho"), py::arg("molefrac"));
}
template<typename Model>
void add_virials(py::module& m) {
using vd = VirialDerivatives<Model>;
......@@ -30,6 +39,7 @@ void add_derivatives(py::module &m) {
m.def("build_Psir_gradient_autodiff", &id::build_Psir_gradient_autodiff, py::arg("model"), py::arg("T"), py::arg("rho"));
add_virials<Model>(m);
add_TDx_derivatives<Model>(m);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment