// Expression-template example: lazy element-wise addition of vec objects.
#include <cassert>
#include <iostream>
namespace expr
{
typedef unsigned int uint;

// Global counters of vec constructor / destructor calls, used to check that
// evaluating an expression does not create temporary vec objects.
uint ctor_calls = 0;
uint dtor_calls = 0;

// Print the counters as "ctor/dtor = C/D".
void printctor()
{
    std::cout << "ctor/dtor = " << ctor_calls
              << "/" << dtor_calls << "\n";
}

// A container of three successive floats. Every constructor increments
// ctor_calls and the destructor increments dtor_calls.
class vec
{
public:
    float v[3];

    // Array members are brace-initialized directly (C++11 member list-init);
    // the parenthesized form v({...}) is not valid standard C++ for an array.
    vec () : v{0, 0, 0} { ++ctor_calls; }
    vec (const vec& o) : v{o.v[0], o.v[1], o.v[2]} { ++ctor_calls; }
    vec (float x, float y, float z) : v{x, y, z} { ++ctor_calls; }
    ~vec() { ++dtor_calls; }

    // Element-wise copy; constructs nothing. Declared explicitly because the
    // template operator= below never acts as a copy assignment operator and the
    // implicit one is deprecated once a copy constructor/destructor is declared.
    vec& operator= (const vec&) = default;

    // Return indexed value / reference; index must be 0, 1 or 2.
    float operator[](uint index) const { assert(index < 3); return v[index]; }
    float& operator[](uint index) { assert(index < 3); return v[index]; }

    // Print as "<x,y,z>" followed by a newline.
    void print() const { std::cout << "<" << v[0] << "," << v[1] << "," << v[2] << ">\n"; }

    // Assign from an expression, evaluating it element by element straight into
    // this vec (no temporary vec is built).
    // E must have operator[](uint) returning something convertible to float.
    template <class E>
    vec& operator= (const E& x)
    {
        for (uint i = 0; i != 3; ++i) (*this)[i] = x[i];
        return *this;
    }
};

// Forward declarations needed by operand_ref below.
template <class L, class O, class R> struct expression;
class scalar;

// Selects how an operand of type T is stored inside an expression node.
// By default (e.g. vec) the operand is held by const reference: it is a named
// object, or a temporary of the caller's full-expression, and outlives the node.
// A scalar operand is created *inside* operator+ and is destroyed when operator+
// returns, so it must be held by value; holding it by reference is what made
// `a = b + 1.f` read a dangling object. Nested expression nodes are also held by
// value so an expression can safely be stored (e.g. `auto e = b + c + d;`).
template <class T>
struct operand_ref { typedef const T& type; };

template <>
struct operand_ref<scalar> { typedef scalar type; };

template <class L, class O, class R>
struct operand_ref<expression<L, O, R> > { typedef expression<L, O, R> type; };

// Basic catch-all expression node; element i evaluates to O::eval(l[i], r[i]).
// L and R must provide operator[](uint).
// O must provide static function float eval(float, float).
template <class L, class O, class R>
struct expression
{
    expression(const L& lhs, const R& rhs)
        : l(lhs), r(rhs) { }
    float operator[](const uint index) const
    {
        return O::eval(l[index], r[index]);
    }
    typename operand_ref<L>::type l;
    typename operand_ref<R>::type r;
};

// Wraps a float into an operator[](uint) entity that acts like an endless
// vector whose every element is that float.
class scalar
{
public:
    scalar(const float& t) : t(t) { }
    float operator[](uint) const { return t; }
    // Stored by value: a reference to the caller's float would dangle once an
    // expression outlives the statement that created it.
    float t;
};

// An operation function object: element-wise addition.
struct plus
{
    static float eval(const float a, const float b) { return a + b; }
};

// anything + anything
template <class L, class R>
expression<L,plus,R> operator+(const L& l, const R& r)
{
    return expression<L,plus,R>(l, r);
}

// anything + scalar (more specialized than the catch-all, so preferred for floats)
template <class L>
expression<L,plus,scalar> operator+(const L& l, const float& r)
{
    return expression<L,plus,scalar>(l, scalar(r));
}

// scalar + anything (symmetric counterpart of the overload above)
template <class R>
expression<scalar,plus,R> operator+(const float& l, const R& r)
{
    return expression<scalar,plus,R>(scalar(l), r);
}
}
void do_some()
{
using namespace expr;
vec a(1,2,3), b(2,3,4), c(3,4,5);
a.print(); b.print(); c.print();
// works
a = b + c;
a.print();
assert( a.v[0] == 5 && a.v[1] == 7 && a.v[2] == 9 );
// does not work -> segfault
a = b + 1.f;
a.print();
assert( a.v[0] == 3 && a.v[1] == 4 && a.v[2] == 5 );
}
int main()
{
do_some();
// check ctor calls
expr::printctor();
return 0;
}