1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
|
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <stdexcept>
#include <type_traits>
#include <valarray>
// A dense nrows x ncols matrix of arithmetic values stored row-major in a
// std::valarray: element (r, c) lives at index r*ncols + c.
// Rule of zero: http://en.cppreference.com/w/cpp/language/rule_of_three
template < typename T > struct matrix
{
    static_assert( std::is_arithmetic<T>::value, "non-arithmetic types are not supported" ) ;
    static_assert( !std::is_reference<T>::value, "can't have a matrix of references" ) ;
    static_assert( !std::is_const<T>::value &&
                   !std::is_volatile<T>::value, "can't have a matrix of cv-qualified types" ) ;

    // Creates an nrows x ncols matrix with every element value-initialised (zero).
    matrix( std::size_t nrows, std::size_t ncols ) : nrows(nrows), ncols(ncols), data(nrows*ncols) {}

    // Creates an nrows x ncols matrix filled in row-major order from ilist.
    // Surplus values in ilist are ignored; if ilist is too short the remaining
    // elements stay zero.
    matrix( std::size_t nrows, std::size_t ncols, std::initializer_list<T> ilist )
        : nrows(nrows), ncols(ncols), data(nrows*ncols)
    {
        std::size_t i = 0 ;
        for( auto it = ilist.begin() ; it != ilist.end() && i < data.size() ; ++it ) data[i++] = *it ;
    }

    // dimensions
    std::size_t rows() const noexcept { return nrows ; }
    std::size_t cols() const noexcept { return ncols ; }

    // Row and column access. Preconditions: row_num < rows(), col_num < cols();
    // like std::valarray::operator[], these are NOT bounds-checked.
    // The non-const overloads return assignable views into this matrix;
    // the const overloads return copies of the selected elements.
    std::slice_array<T> operator[] ( std::size_t row_num ) { return row(row_num) ; }
    std::slice_array<T> row( std::size_t row_num ) { return data[ std::slice( row_num*ncols, ncols, 1 ) ] ; }
    std::slice_array<T> col( std::size_t col_num ) { return data[ std::slice( col_num, nrows, ncols ) ] ; }
    std::valarray<T> operator[] ( std::size_t row_num ) const { return row(row_num) ; }
    std::valarray<T> row( std::size_t row_num ) const { return data[ std::slice( row_num*ncols, ncols, 1 ) ] ; }
    std::valarray<T> col( std::size_t col_num ) const { return data[ std::slice( col_num, nrows, ncols ) ] ; }

    // Compound arithmetic operators (element-wise).
    // The matrix-matrix forms throw std::invalid_argument if the shapes differ
    // (valarray itself would give UB on a size mismatch, and silently accept
    // e.g. a 2x3 + 3x2 because the element counts agree).
    matrix& operator+= ( const matrix& that ) { check_same_shape(that) ; data += that.data ; return *this ; }
    matrix& operator+= ( const T& v ) { data += v ; return *this ; }
    matrix& operator-= ( const matrix& that ) { check_same_shape(that) ; data -= that.data ; return *this ; }
    matrix& operator-= ( const T& v ) { data -= v ; return *this ; }
    matrix& operator*= ( const T& v ) { data *= v ; return *this ; }
    matrix& operator/= ( const T& v ) { data /= v ; return *this ; }

    // Exposes the underlying storage for the rest of the valarray API
    // (masks, apply, sum, ...). Do not resize it through the non-const
    // overload: that would break the rows()*cols() == size() invariant.
    std::valarray<T>& view_as_valarray() { return data ; }
    const std::valarray<T>& view_as_valarray() const { return data ; }

    // Binary arithmetic operators (hidden friends, found via ADL).
    // Operands are taken by value and modified in place; "op; return a;"
    // (rather than "return a op= b;") lets the result be moved, not copied.
    friend matrix operator+ ( matrix a, const matrix& b ) { a += b ; return a ; }
    friend matrix operator+ ( matrix a, const T& b ) { a += b ; return a ; }
    friend matrix operator+ ( const T& a, matrix b ) { b += a ; return b ; }
    friend matrix operator- ( matrix a, const matrix& b ) { a -= b ; return a ; }
    friend matrix operator- ( matrix a, const T& b ) { a -= b ; return a ; }
    friend matrix operator- ( const T& a, matrix b ) { b.data = a - b.data ; return b ; }
    friend matrix operator* ( matrix a, const T& b ) { a *= b ; return a ; }
    friend matrix operator* ( const T& a, matrix b ) { b *= a ; return b ; }
    friend matrix operator/ ( matrix a, const T& b ) { a /= b ; return a ; }
    // TO DO: classic operations transpose, matrix multiply etc.

    // Writes the matrix one row per line, elements separated by a space.
    friend std::ostream& operator<< ( std::ostream& stm, const matrix& mtx )
    {
        std::size_t cnt = 0 ;
        for( const auto& v : mtx.data ) stm << v << ( ++cnt % mtx.ncols == 0 ? '\n' : ' ' ) ;
        return stm ;
    }

private:
    std::size_t nrows ;
    std::size_t ncols ;
    std::valarray<T> data ; // row-major, size nrows*ncols

    // Throws std::invalid_argument unless that has exactly this matrix's shape.
    void check_same_shape( const matrix& that ) const
    {
        if( nrows != that.nrows || ncols != that.ncols )
            throw std::invalid_argument( "matrix dimensions do not match" ) ;
    }
};
// Sanity checks: with the rule of zero, matrix's implicit move operations
// inherit valarray's noexcept moves; if not, the toolchain is at fault.
static_assert( std::is_nothrow_move_constructible< matrix<double> >::value,
               "broken compiler and/or library" ) ;
static_assert( std::is_nothrow_move_assignable< matrix<int> >::value,
               "broken compiler and/or library" ) ;
int main()
{
matrix<int> m1( 3, 5, { 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 } ) ;
const matrix<int> m2( 3, 5, { 33, 38, 32, 35, 39, 37, 34, 36, 31, 40, 41, 42, 43, 44, 45 } ) ;
std::cout << "m1:\n" << m1 << "\nm2:\n" << m2 << "\nm1+m2:\n" << m1+m2
<< "\nm1-m2:\n" << m1-m2 << "\nm1-1:\n" << m1-1 << "\n100+m1:\n" << 100+m1 << '\n' ;
m1.row(1) = 17 ;
std::cout << "row(1) = 17 ;\n" << m1 << '\n' ;
m1.col(3) = -678 ;
std::cout << "col(3) = -678 ;\n"<< m1 << '\n' ;
m1.col(2) *= std::valarray<int>{ 20, 30, 40 } ;
std::cout << "col(2) *= std::valarray<int>{ 20, 30, 40 } ;\n"<< m1 << '\n' ;
auto& va = m1.view_as_valarray() ;
va[ va>21 ] = 99 ; // assign 99 to all elements greater than 21
std::cout << "va[ va>21 ] = 99 ;\n" << m1 << '\n' ;
}
|