![]() |
Prev | Next | ode_evaluate.hpp | Headings |
# ifndef CPPAD_ODE_EVALUATE_INCLUDED
# define CPPAD_ODE_EVALUATE_INCLUDED
# include <cppad/vector.hpp>
# include <cppad/runge_45.hpp>
namespace CppAD { // BEGIN CppAD namespace
/*
 * Right-hand side of the ODE integrated by ode_evaluate.
 *
 * Let n = x.size().  When m == 0 the state is y (length n) and
 *     y_k'(t) = (k+1) * x_k * y_{k-1}(t),   with y_{-1} replaced by 1.
 * When m == 1 the state is z = [ y ; y_x ] (length n + n*n), where row r
 * of the n x n block y_x starts at offset n + r*n and holds the partials
 * of y_r with respect to x_0, ..., x_{n-1}; the derivative of that block
 * is propagated alongside y.
 */
template <class Float>
class ode_evaluate_fun {
private:
	// selects the system: 0 -> y only, 1 -> [ y ; y_x ]
	const size_t m_;
	// ODE parameters, copied at construction; their count fixes n
	const CppAD::vector<Float> x_;
public:
	// m : system selector (values other than 0 or 1 make Ode do nothing)
	// x : ODE parameters
	ode_evaluate_fun(size_t m, const CppAD::vector<Float> &x)
	: m_(m), x_(x)
	{ }

	// Evaluate h = F(t, z) for the system chosen by m_.
	// h is written by index and never resized, so the caller sizes it.
	void Ode(
		const Float &t,
		const CppAD::vector<Float> &z,
		CppAD::vector<Float> &h)
	{
		switch( m_ )
		{
			case 0:
			ode_y(t, z, h);
			break;

			case 1:
			ode_z(t, z, h);
			break;

			default:
			break;
		}
	}

	// y_t = g(t, x, y) with g_k = (k+1) * x_k * y_{k-1} (1 in place of y_{-1}).
	// Starting from y(0) = 0 the solution is
	//     y_0 (t) = x_0 * t
	//     y_1 (t) = x_1 * x_0 * t^2
	//     y_2 (t) = x_2 * x_1 * x_0 * t^3
	//     ...
	void ode_y(
		const Float &t,
		const CppAD::vector<Float> &y,
		CppAD::vector<Float> &g)
	{
		CPPAD_ASSERT_UNKNOWN( y.size() == x_.size() );
		const size_t n = x_.size();
		for(size_t k = 0; k < n; ++k)
		{	Float factor = Float(int(k+1)) * x_[k];
			// the first equation has no predecessor; use the constant 1
			if( k == 0 )
				g[k] = factor * Float(1);
			else	g[k] = factor * y[k-1];
		}
	}

	// z = [ y ; y_x ],  z_t = h(t, x, z) = [ y_t ; y_x_t ]
	void ode_z(
		const Float &t ,
		const CppAD::vector<Float> &z ,
		CppAD::vector<Float> &h )
	{
		const size_t n = x_.size();
		CPPAD_ASSERT_UNKNOWN( z.size() == n + n * n );

		// y_t, and clear row `row` of y_x_t so it can be accumulated below
		for(size_t row = 0; row < n; ++row)
		{	Float factor = Float(int(row+1)) * x_[row];
			if( row == 0 )
				h[row] = factor * Float(1);
			else	h[row] = factor * z[row-1];
			for(size_t col = 0; col < n; ++col)
				h[n + row * n + col] = 0.;
		}

		// g_0 = x_0, so d g_0 / d x_0 = 1
		h[n] += 1.;

		// rows 1..n-1 : chain rule for g_row = (row+1) * x_row * y_{row-1}
		for(size_t row = 1; row < n; ++row)
		{	const Float coef = Float(int(row+1));
			// explicit dependence on x_row
			h[n + row * n + row] += coef * z[row-1];
			// dependence through y_{row-1}, times d y_{row-1} / d x_col
			const Float dg_dy = coef * x_[row];
			for(size_t col = 0; col < n; ++col)
				h[n + row * n + col] += dg_dy * z[n + (row-1) * n + col];
		}
	}
};
template <class Float>
void ode_evaluate(
CppAD::vector<Float> &x ,
size_t m ,
CppAD::vector<Float> &fm )
{
typedef CppAD::vector<Float> Vector;
size_t n = x.size();
size_t ell;
CPPAD_ASSERT_KNOWN( m == 0 || m == 1,
"ode_evaluate: m is not zero or one"
);
CPPAD_ASSERT_KNOWN(
((m==0) & (fm.size()==1) ) || ((m==1) & (fm.size()==n)),
"ode_evaluate: the size of fm is not correct"
);
if( m == 0 )
ell = n;
else ell = n + n * n;
// set up the case we are integrating
size_t M = 10;
Float ti = 0.;
Float tf = 1.;
Vector yi(ell);
Vector yf(ell);
size_t i;
for(i = 0; i < ell; i++)
yi[i] = Float(0);
// construct ode equation
ode_evaluate_fun<Float> f(m, x);
// solve differential equation
yf = Runge45(f, M, ti, tf, yi);
if( m == 0 )
fm[0] = yf[n-1];
else
{ for(i = 0; i < n; i++)
fm[i] = yf[n + (n-1) * n + i];
}
return;
}
} // END CppAD namespace
# endif