1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
|
#include <iostream>
#include <vector>
#include <cmath> // EDITED, to define abs() properly
using namespace std;
const double SMALL = 1.0E-30; // pivots with |value| below this are treated as zero (stops divide-by-zero)
//======================================================================
bool tdma( const std::vector<double> &a, const std::vector<double> &b, const std::vector<double> &c, const std::vector<double> &d,
           std::vector<double> &X )
//*********************************************************************************
// Solve, using the Thomas Algorithm (TDMA), the tri-diagonal system
//     a_i X_{i-1} + b_i X_i + c_i X_{i+1} = d_i,   i = 0, ..., n-1
// i.e. the n x n matrix equation A X = d, where n = d.size().
//
// a[i], b[i], c[i] are the lower, main and upper diagonals and d[i] is the rhs.
// a[0] and c[n-1] lie outside the matrix and do not affect the result, but
// a, b and c must each hold at least n entries.
//
// No pivoting is performed: if any pivot (b_0, or b_i + a_i P_{i-1}) has
// magnitude below SMALL, the solve is abandoned.
//
// Returns true and fills X (resized to n) with the solution on success.
// Returns false if a diagonal is shorter than d or a pivot is too small;
// X is then left as n zeros.
// An empty system (n == 0) is trivially solved: returns true with X empty.
//*********************************************************************************
{
   const std::size_t n = d.size();
   X.assign( n, 0.0 );

   if ( n == 0 ) return true;                                      // nothing to solve
   if ( a.size() < n || b.size() < n || c.size() < n ) return false; // would read out of bounds

   // Recurrence X_i = P_i X_{i+1} + Q_i
   std::vector<double> P( n );
   std::vector<double> Q( n );

   // Forward pass: the first pivot gets the same zero check as all the others
   double denominator = b[0];
   if ( std::abs( denominator ) < SMALL ) return false;
   P[0] = -c[0] / denominator;
   Q[0] =  d[0] / denominator;
   for ( std::size_t i = 1; i < n; i++ )
   {
      denominator = b[i] + a[i] * P[i-1];
      if ( std::abs( denominator ) < SMALL ) return false;
      P[i] = -c[i] / denominator;
      Q[i] = ( d[i] - a[i] * Q[i-1] ) / denominator;
   }

   // Backward pass (i runs n-2 down to 0; unsigned-safe countdown)
   X[n-1] = Q[n-1];
   for ( std::size_t i = n - 1; i-- > 0; ) X[i] = P[i] * X[i+1] + Q[i];
   return true;
}
//======================================================================
int main() // Solve Ax = b, where A is tri-diagonal
{
vector<double> lower = { 0.0, -1.0, -2.0, -3.0, -4.0 }; // lower diagonal
vector<double> diagonal = { 9.0, 8.0, 7.0, 6.0, 5.0 }; // main diagonal
vector<double> upper = { -1.0, -2.0, -3.0, -4.0, 0.0 }; // upper diagonal
vector<double> rhs = { 41.0, 21.0, 7.0, -1.0, -3.0 }; // RHS
vector<double> X;
int n = rhs.size();
if ( tdma( lower, diagonal, upper, rhs, X ) )
{
cout << "X\tLHS\tRHS\n";
for ( int i = 0; i < n; i++ )
{
double lhs = diagonal[i] * X[i];
if ( i > 0 ) lhs += lower[i] * X[i-1];
if ( i < n - 1 ) lhs += upper[i] * X[i+1];
cout << X[i] << '\t' << lhs << '\t' << rhs[i] << '\n';
}
}
else
{
cout << "Unable to solve\n";
}
}
|