ode-toolbox

ODE integration tools
git clone https://git.0xfab.ch/ode-toolbox.git
Log | Files | Refs | README | LICENSE

StepperBase.h (7064B)


      1 // File       : StepperBase.h
      2 // Date       : Thu Sep 22 13:27:19 2016
      3 // Author     : Fabian Wermelinger
      4 // Description: Time stepper base class
      5 // Copyright 2016 ETH Zurich. All Rights Reserved.
      6 #ifndef STEPPERBASE_H_FCCSM4HL
      7 #define STEPPERBASE_H_FCCSM4HL
      8 
#include <ODETB/TimeStepper/KernelBase.h>

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>
     13 
     14 // forward declarations
     15 template <typename Tinput, typename Trhs=Tinput>
     16 class StepperBase;
     17 template <typename Tinput, typename Trhs=Tinput>
     18 class Euler;
     19 template <typename Tinput, typename Trhs=Tinput>
     20 class RK4;
     21 template <typename Tinput, typename Trhs=Tinput>
     22 class LSRK3;
     23 template <typename Tinput, typename Trhs=Tinput>
     24 class RKF45;
     25 template <typename Tinput, typename Trhs=Tinput>
     26 class RKV56;
     27 #ifdef _USE_SUNDIALS_
     28 template <typename Tinput, typename Trhs=Tinput>
     29 class BDF;
     30 #endif /* _USE_SUNDIALS_ */
     31 
     32 #ifdef _FLOAT_PRECISION_
     33 using Real = float;
     34 #else
     35 using Real = double;
     36 #endif
     37 
     38 struct StepperSettings
     39 {
     40     StepperSettings(ArgumentParser& p)
     41     {
     42         aTol             = p("-ts_aTol").asDouble(1.0e-8);
     43         rTol             = p("-ts_rTol").asDouble(1.0e-8);
     44         minScale         = p("-ts_minScale").asDouble(0.2);
     45         maxScale         = p("-ts_maxScale").asDouble(10.0);
     46         alpha            = p("-ts_alpha").asDouble(1.0);
     47         beta             = p("-ts_beta").asDouble(0.0);
     48         safety           = p("-ts_safety").asDouble(0.9);
     49         t                = p("-ts_t0").asDouble(0.0);
     50         dt               = p("-ts_dt").asDouble(1.0e-4);
     51         dtMin            = p("-ts_dtMin").asDouble(0.0);
     52         step             = 0;
     53         nsteps           = p("-ts_nsteps").asInt(0);
     54         p.set_strict_mode();
     55         tFinal           = p("-ts_tend").asDouble();
     56         p.unset_strict_mode();
     57         dtDump           = p("-ts_dtDump").asDouble(-1.0);
     58         tDump            = t;
     59         writeGranularity = p("-ts_writeGranularity").asInt(1);
     60         writeCount       = 0;
     61         reportGranularity= p("-ts_reportGranularity").asInt(100);
     62         bFixedStep       = true;
     63 
     64         // restarts
     65         bRestart    = p("-ts_restart").asBool(false);
     66         restartstep = p("-ts_restart_step").asInt(0);
     67 
     68     }
     69 
     70     Real aTol;
     71     Real rTol;
     72     Real minScale;
     73     Real maxScale;
     74     Real alpha;
     75     Real beta;
     76     Real safety;
     77     Real t;
     78     Real dt;
     79     Real dtMin;
     80     size_t step;
     81     size_t nsteps;
     82     Real tFinal;
     83     Real dtDump;
     84     Real tDump;
     85     size_t writeGranularity;
     86     size_t writeCount;
     87     size_t reportGranularity;
     88     bool bFixedStep;
     89 
     90     // restarts
     91     bool bRestart;
     92     int restartstep;
     93 
     94 
     95     // helper
     96     void print() const
     97     {
     98         printf("Time Stepper :\n");
     99         printf("\tAbsolute tolerance      = %e\n", aTol);
    100         printf("\tRelative tolerance      = %e\n", rTol);
    101         printf("\tMinimum time step scale = %e\n", minScale);
    102         printf("\tMaximum time step scale = %e\n", maxScale);
    103         printf("\tError PI control alpha  = %e\n", alpha);
    104         printf("\tError PI control beta   = %e\n", beta);
    105     }
    106 
    107     template <typename Tinput, typename Trhs=Tinput>
    108     StepperBase<Tinput,Trhs>* stepperFactory(ArgumentParser& p, Tinput& U, KernelBase<Tinput,Trhs> * const kern=nullptr)
    109     {
    110         assert(kern != nullptr);
    111 
    112         if (p("-ts").asString("rkv56") == "euler")
    113         {
    114             std::cout << "Allocating Euler time stepper...\n";
    115             return new Euler<Tinput, Trhs>(U, *this, kern);
    116         }
    117         else if (p("-ts").asString("rkv56") == "rk4")
    118         {
    119             std::cout << "Allocating RK4 time stepper...\n";
    120             return new RK4<Tinput, Trhs>(U, *this, kern);
    121         }
    122         else if (p("-ts").asString("rkv56") == "lsrk3")
    123         {
    124             std::cout << "Allocating LSRK3 time stepper...\n";
    125             return new LSRK3<Tinput, Trhs>(U, *this, kern);
    126         }
    127         else if (p("-ts").asString("rkv56") == "rkf45")
    128         {
    129             std::cout << "Allocating RKF45 adaptive time stepper...\n";
    130             this->bFixedStep = false;
    131             this->print();
    132             return new RKF45<Tinput, Trhs>(U, *this, kern);
    133         }
    134         else if (p("-ts").asString("rkv56") == "rkv56")
    135         {
    136             std::cout << "Allocating RKV56 adaptive time stepper...\n";
    137             this->bFixedStep = false;
    138             this->print();
    139             return new RKV56<Tinput, Trhs>(U, *this, kern);
    140         }
    141 #ifdef _USE_SUNDIALS_
    142         else if (p("-ts").asString("rkv56") == "bdf")
    143         {
    144             std::cout
    145                 << "Allocating Sundials BDF/Newton implicit time stepper...\n";
    146             return new BDF<Tinput, Trhs>(U, *this, kern);
    147         }
    148 #endif /* _USE_SUNDIALS_ */
    149         else
    150             return nullptr;
    151     }
    152 };
    153 
    154 
    155 template <typename Tinput, typename Trhs>
    156 class StepperBase
    157 {
    158     StepperBase(Tinput const& rhs) = delete;
    159     StepperBase& operator=(Tinput const& rhs) = delete;
    160 
    161 protected:
    162     StepperSettings& m_settings;
    163     Tinput& m_U;
    164     KernelBase<Tinput,Trhs> * const m_kernel;
    165 
    166 public:
    167     StepperBase(Tinput& U, StepperSettings& settings, KernelBase<Tinput,Trhs> * const kern=nullptr) : m_settings(settings), m_U(U), m_kernel(kern) { }
    168     virtual ~StepperBase() { }
    169 
    170     inline Tinput& U() { return m_U; }
    171     inline Tinput const& U() const { return m_U; }
    172 
    173     virtual void step(void const* const data=nullptr) = 0;
    174 
    175     inline void write(const size_t step, const Real t, const Real dt, const Tinput& U, const void * const data=nullptr)
    176     {
    177         m_kernel->write(step, t, dt, U, data);
    178     }
    179 };
    180 
    181 // serializer
    182 struct _simulation_state_t
    183 {
    184     Real t;
    185     Real dt;
    186     Real tFinal;
    187     Real dtDump;
    188     Real tDump;
    189     size_t step;
    190     size_t nsteps;
    191     size_t writeCount;
    192 } __attribute__((packed));
    193 
    194 
    195 template <typename Tinput>
    196 void serialize(const Tinput& U, const StepperSettings& S, const std::string fname="restart.bin")
    197 {
    198     _simulation_state_t state;
    199     state.t = S.t;
    200     state.dt = S.dt;
    201     state.tFinal = S.tFinal;
    202     state.dtDump = S.dtDump;
    203     state.tDump = S.tDump;
    204     state.step = S.step;
    205     state.nsteps = S.nsteps;
    206     state.writeCount = S.writeCount;
    207 
    208     ofstream _stream(fname, ios::out | ios::binary);
    209     _stream.write((const char*)&state.t, sizeof(_simulation_state_t));
    210 
    211     const char* const _src = (const char*)U.data();
    212     _stream.write(_src, sizeof(typename Tinput::DataType)*U.size());
    213 
    214     _stream.close();
    215 }
    216 
    217 template <typename Tinput>
    218 void deserialize(Tinput& U, StepperSettings& S, const std::string fname="restart.bin")
    219 {
    220     ifstream _stream(fname, ios::in | ios::binary);
    221 
    222     _simulation_state_t state;
    223     _stream.read((char*)&state.t, sizeof(_simulation_state_t));
    224     S.t          = state.t;
    225     S.dt         = state.dt;
    226     S.tFinal     = state.tFinal;
    227     S.dtDump     = state.dtDump;
    228     S.tDump      = state.tDump;
    229     S.step       = state.step;
    230     S.nsteps     = state.nsteps;
    231     S.writeCount = state.writeCount;
    232 
    233     char* const _dst = (char*)U.data();
    234     _stream.read(_dst, sizeof(typename Tinput::DataType)*U.size());
    235 
    236     _stream.close();
    237 }
    238 
    239 #endif /* STEPPERBASE_H_FCCSM4HL */