ode-toolbox

ODE integration tools
git clone https://git.0xfab.ch/ode-toolbox.git
Log | Files | Refs | README | LICENSE

RKF45.h (7164B)


      1 /* File:   RKF45.h */
      2 /* Date:   Wed Feb 10 17:33:42 2016 */
      3 /* Author: Fabian Wermelinger */
      4 /* Tag:    Runge-Kutta-Fehlberg 4th-5th order variable step method */
      5 /* Copyright 2016 ETH Zurich. All Rights Reserved. */
      6 #ifndef RKF45_H_UWN6WM1V
      7 #define RKF45_H_UWN6WM1V
      8 
#include <ODETB/TimeStepper/explicit/Euler.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <limits>
     15 
     16 template <typename Tinput, typename Trhs>
     17 class RKF45 : public Euler<Tinput, Trhs>
     18 {
     19     using Euler<Tinput, Trhs>::m_settings;
     20     using Euler<Tinput, Trhs>::m_U;
     21     using Euler<Tinput, Trhs>::m_rhs;
     22     using Euler<Tinput, Trhs>::m_kernel;
     23 
     24     Real m_errold;
     25     const Real m_alpha;
     26     const Real m_beta;
     27     bool m_reject;
     28     bool m_firstPass;
     29 
     30     Trhs m_rhs2;
     31     Trhs m_rhs3;
     32     Trhs m_rhs4;
     33     Trhs m_rhs5;
     34     Trhs m_rhs6;
     35 
     36     Tinput m_output;
     37 
     38     static constexpr Real b1  = 16./135., b3 = 6656./12825., b4 = 28561./56430., b5 = -9./50., b6 = 2./55.;
     39     static constexpr Real c21 = 0.25, ct2 = 0.25;
     40     static constexpr Real c31 = 3./32., c32 = 9./32., ct3 = 3./8.;
     41     static constexpr Real c41 = 1932./2197., c42 = -7200./2197., c43 = 7296./2197., ct4 = 12./13.;
     42     static constexpr Real c51 = 439./216., c52 = -8., c53 = 3680./513., c54 = -845./4104., ct5 = 1.0;
     43     static constexpr Real c61 = -8./27., c62 = 2., c63 = -3544./2565., c64 = 1859./4104., c65 = -11./40., ct6 = 0.5;
     44     static constexpr Real e1 = 1./360., e3 = -128./4275., e4 = -2197./75240., e5 = 1./50., e6 = 2./55.;
     45 
     46     void _computeRHS(const Real t, const Real dt, void const* const data)
     47     {
     48         if (!m_reject) m_kernel->compute(m_U, m_rhs, t, data);
     49         m_kernel->compute(m_U + dt*(c21*m_rhs), m_rhs2, t + ct2*dt, data);
     50         m_kernel->compute(m_U + dt*(c31*m_rhs + c32*m_rhs2), m_rhs3, t + ct3*dt, data);
     51         m_kernel->compute(m_U + dt*(c41*m_rhs + c42*m_rhs2 + c43*m_rhs3), m_rhs4, t + ct4*dt, data);
     52         m_kernel->compute(m_U + dt*(c51*m_rhs + c52*m_rhs2 + c53*m_rhs3 + c54*m_rhs4), m_rhs5, t + ct5*dt, data);
     53         m_kernel->compute(m_U + dt*(c61*m_rhs + c62*m_rhs2 + c63*m_rhs3 + c64*m_rhs4 + c65*m_rhs5), m_rhs6, t + ct6*dt, data);
     54     }
     55 
     56     inline Real _error() const
     57     {
     58         assert(m_U.size() == m_rhs.size());
     59         Real maxerr = 0.0;
     60         // using inf-norm (could also do L2-norm)
     61         for (size_t i = 0; i < m_U.size(); ++i)
     62         {
     63             const typename Trhs::DataType E = e1*m_rhs[i] + e3*m_rhs3[i] + e4*m_rhs4[i] + e5*m_rhs5[i] + e6*m_rhs6[i];
     64             for (int j = 0; j < Trhs::DataType::SIZE; ++j)
     65             {
     66                 const Real IepsI = std::abs(E[j]);
     67                 const Real scale = m_settings.aTol + std::abs(m_U[i][j])*m_settings.rTol;
     68                 maxerr = std::max(maxerr, IepsI/scale);
     69             }
     70         }
     71         assert(!isnan(maxerr));
     72         return maxerr;
     73     }
     74 
     75     bool _inBound(const Real t, Real& dt, void const* const data, Real& dt_next)
     76     {
     77         const Real err = _error();
     78 
     79         const Real dtMin = m_settings.dtMin;
     80         if (dt <= dtMin) {
     81             dt_next = 1.1 * dtMin;
     82             m_errold = std::max(err, static_cast<Real>(1.0e-4));
     83             m_reject = false;
     84             m_firstPass = true;
     85             return true;
     86         }
     87 
     88         if (dt < std::numeric_limits<Real>::epsilon())
     89         {
     90             // if you request the impossible
     91             dt_next = 1.5*std::numeric_limits<Real>::epsilon();
     92             m_errold = std::max(err, static_cast<Real>(1.0e-4));
     93             m_reject = false;
     94             m_firstPass = true;
     95             return true;
     96         }
     97 
     98         const Real safety = m_settings.safety;
     99         const Real minScale = m_settings.minScale;
    100         const Real maxScale = m_settings.maxScale;
    101         Real scale;
    102         if (err <= 1.0)
    103         {
    104             if (err == 0.0)
    105                 scale = maxScale;
    106             else
    107             {
    108                 scale = safety*std::pow(err,-m_alpha)*std::pow(m_errold,m_beta);
    109                 scale = (scale < minScale) ? minScale : ((scale > maxScale) ? maxScale : scale);
    110             }
    111             if (m_reject)
    112                 dt_next = dt*std::min(scale, static_cast<Real>(1.0));
    113             else
    114                 dt_next = dt*scale;
    115             m_errold = std::max(err, static_cast<Real>(1.0e-4));
    116             m_reject = false;
    117             m_firstPass = true;
    118             return true;
    119         }
    120         else
    121         {
    122             scale = safety*std::pow(err,-m_alpha);
    123             scale = std::max(scale, minScale);
    124             dt *= scale;
    125             m_reject = true;
    126             m_firstPass = false;
    127             _computeRHS(t, dt, data);
    128             return false;
    129         }
    130     }
    131 
    132     void _dumpOutput(const void * const data)
    133     {
    134         const Real t  = m_settings.t;
    135         const Real dt = m_settings.dt;
    136         Real& tDump   = m_settings.tDump;
    137 
    138         const Tinput Uold = m_output;
    139         const Trhs rhsOld = m_rhs;
    140 
    141         m_kernel->compute(m_U, m_rhs, t+dt, data);
    142         m_firstPass = false;
    143 
    144         while(tDump <= (t+dt))
    145         {
    146             const Real phi = (tDump - t)/dt;
    147             // Hermite 3rd order
    148             m_output = (1.0-phi)*Uold + phi*m_U + phi*(phi-1.0)*((1.0-2.0*phi)*(m_U - Uold) + (phi-1.0)*dt*rhsOld + phi*dt*m_rhs);
    149             m_kernel->write(m_settings.step, t+phi*dt, phi*dt, m_output, data);
    150             tDump += m_settings.dtDump;
    151         }
    152     }
    153 
    154 public:
    155     RKF45(Tinput& U, StepperSettings& S, KernelBase<Tinput,Trhs> * const kern) :
    156         Euler<Tinput,Trhs>(U,S,kern), m_errold(1.0e-4), m_alpha(0.25*(S.alpha-0.75*S.beta)), m_beta(0.25*S.beta), m_reject(false), m_firstPass(true),
    157         m_rhs2(U.size()), m_rhs3(U.size()), m_rhs4(U.size()), m_rhs5(U.size()), m_rhs6(U.size()), m_output(U.size()) {}
    158     virtual ~RKF45() {}
    159 
    160     virtual void step(void const* const data=nullptr)
    161     {
    162         size_t& step       = m_settings.step;
    163         Real& t            = m_settings.t;
    164         Real& dt           = m_settings.dt;
    165         const Real& tFinal = m_settings.tFinal;
    166         Real& tDump        = m_settings.tDump;
    167         if (!m_settings.bRestart && m_settings.dtDump > 0.0) tDump = t + m_settings.dtDump;
    168 
    169         while(t < tFinal)
    170         {
    171             Real dt_next;
    172             dt = (tFinal-t) < dt ? (tFinal-t) : dt;
    173             _computeRHS(t, dt, data);
    174             while (!_inBound(t, dt, data, dt_next)) {}
    175             m_output = m_U;
    176             m_U += dt*(b1*m_rhs + b3*m_rhs3 + b4*m_rhs4 + b5*m_rhs5 + b6*m_rhs6);
    177             if (m_settings.dtDump > 0.0)
    178             {
    179                 if (tDump <= (t+dt)) _dumpOutput(data);
    180             }
    181             else
    182                 if ((step+1) % m_settings.writeGranularity == 0)
    183                     m_kernel->write((step+1), t+dt, dt, m_U, data);
    184             t += dt;
    185             dt = dt_next;
    186             ++step;
    187 
    188             if (m_settings.step % m_settings.reportGranularity == 0)
    189                 std::printf(
    190                     "Time = %e;\tStep = %zu\n", m_settings.t, m_settings.step);
    191 
    192             if (m_settings.restartstep > 0 && step % m_settings.restartstep == 0)
    193                 serialize<Tinput>(m_U, m_settings);
    194         }
    195     }
    196 };
    197 
    198 #endif /* RKF45_H_UWN6WM1V */