00001 #ifndef __MATRIX_H 00002 #define __MATRIX_H 00003 00004 #include <valarray> 00005 #include<numeric> 00006 00011 namespace Teem 00012 { 00013 template<typename T> class SliceIter; 00014 00017 template<typename T> bool operator==(const SliceIter<T>& p, const SliceIter<T>& q) 00018 { 00019 return p.curr == q.curr && p.s.stride() == q.s.stride() && p.s.start() == q.s.start(); 00020 } 00021 00024 template<typename T> bool operator!=(const SliceIter<T>& p, const SliceIter<T>& q) 00025 { 00026 return !(p==q); 00027 } 00028 00031 template<typename T> bool operator<(const SliceIter<T>& p, const SliceIter<T>& q) 00032 { 00033 return p.curr < q.curr && p.s.stride() == q.s.stride() && p.s.start() == q.s.start(); 00034 } 00035 00038 template<typename T> class SliceIter 00039 { 00040 public: 00042 SliceIter(std::valarray<T> *vv, const std::slice &ss) : v(vv), s(ss), curr(0) { } 00043 00045 SliceIter end() const 00046 { 00047 SliceIter t = *this; 00048 t.curr = s.size(); 00049 return t; 00050 } 00051 00053 SliceIter& operator++() { curr++; return *this; } 00055 SliceIter operator++(int) { SliceIter t = *this; curr++; return t; } 00056 00058 T& operator[](size_t i) { return ref(i); } 00060 T& operator()(size_t i) { return ref(i); } 00062 T& operator*() { return ref(curr); } 00063 00065 friend bool operator== <>(const SliceIter& p, const SliceIter& q); 00067 friend bool operator!= <>(const SliceIter& p, const SliceIter& q); 00069 friend bool operator< <>(const SliceIter& p, const SliceIter& q); 00070 00071 protected: 00072 std::valarray<T> *v; 00073 const std::slice s; 00074 size_t curr; 00075 00077 T& ref(size_t i) const { return (*v)[s.start() + i * s.stride()]; } 00078 }; 00079 00080 00081 template<typename T> class ConstSliceIter; 00082 00085 template<typename T> bool operator==(const ConstSliceIter<T>& p, const ConstSliceIter<T>& q) 00086 { 00087 return p.curr == q.curr && p.s.stride() == q.s.stride() && p.s.start() == q.s.start(); 00088 } 00089 00092 template<typename T> bool operator!=(const ConstSliceIter<T>& p, const ConstSliceIter<T>& q) 00093 { 00094 return !(p==q); 00095 } 00096 00099 template<typename T> bool operator<(const ConstSliceIter<T>& p, const ConstSliceIter<T>& q) 00100 { 00101 return p.curr < q.curr && p.s.stride() == q.s.stride() && p.s.start() == q.s.start(); 00102 } 00103 00106 template<typename T> class ConstSliceIter 00107 { 00108 public: 00110 ConstSliceIter(const std::valarray<T> *vv, const std::slice &ss) : v(vv), s(ss), curr(0) { } 00111 00113 ConstSliceIter end() const 00114 { 00115 ConstSliceIter t = *this; 00116 t.curr = s.size(); 00117 return t; 00118 } 00119 00121 ConstSliceIter& operator++() { curr++; return *this; } 00123 ConstSliceIter operator++(int) { ConstSliceIter t = *this; curr++; return t; } 00124 00126 const T& operator[](size_t i) const { return ref(i); } 00128 const T& operator()(size_t i) const { return ref(i); } 00130 const T& operator*() const { return ref(curr); } 00131 00133 friend bool operator== <>(const ConstSliceIter& p, const ConstSliceIter& q); 00135 friend bool operator!= <>(const ConstSliceIter& p, const ConstSliceIter& q); 00137 friend bool operator< <>(const ConstSliceIter& p, const ConstSliceIter& q); 00138 00139 protected: 00140 const std::valarray<T> *v; 00141 const std::slice s; 00142 size_t curr; 00143 00145 const T& ref(size_t i) const { return (*v)[s.start() + i * s.stride()]; } 00146 }; 00147 00150 template<typename T> class Matrix 00151 { 00152 public: 00154 Matrix(size_t nx = 0, size_t ny = 0, T c = T()) : items(c, nx*ny), xDim(nx), yDim(ny) { } 00156 ~Matrix() { } 00157 00159 size_t size() const { return xDim * yDim; } 00161 size_t columnNum() const { return xDim; } 00163 size_t rowNum() const { return yDim; } 00164 00166 std::valarray<T>& flat() { return items; } 00167 00169 const std::valarray<T>& const_flat() const { return items; } 00170 00172 void resize(size_t nx, size_t ny, T c = T()) { items.resize(nx*ny, c); xDim = nx; yDim = ny; } 00173 00175 SliceIter<T> row(size_t i) { return SliceIter<T>(&items, std::slice(i, xDim, yDim)); } 00177 ConstSliceIter<T> row(size_t i) const { return ConstSliceIter<T>(&items, std::slice(i, xDim, yDim)); } 00178 00180 SliceIter<T> column(size_t i) { return SliceIter<T>(&items, std::slice(i*yDim, yDim, 1)); } 00182 ConstSliceIter<T> column(size_t i) const { return ConstSliceIter<T>(&items, std::slice(i*yDim, yDim, 1)); } 00183 00185 T& operator() (size_t x, size_t y) { return column(x)[y]; } 00187 const T& operator() (size_t x, size_t y) const { return column(x)[y]; } 00188 00190 SliceIter<T> operator() (size_t i) { return column(i); } 00192 ConstSliceIter<T> operator() (size_t i) const { return column(i); } 00193 00194 protected: 00195 std::valarray<T> items; 00196 size_t xDim; 00197 size_t yDim; 00198 }; 00199 00202 template<typename T> T operator*(const ConstSliceIter<T> &v1, const std::valarray<T> &v2) 00203 { 00204 T res = 0; 00205 for(size_t i = 0; i < v2.size(); i++) 00206 res += v1[i] * v2[i]; 00207 return res; 00208 } 00209 00215 template<typename T> std::valarray<T> operator*(const Matrix<T>& m, const std::valarray<T>& v) 00216 { 00217 assert(m.columnNum() == v.size()); 00218 00219 std::valarray<T> res(m.rowNum()); 00220 for(size_t i = 0; i < m.rowNum(); i++) 00221 res[i] = m.row(i) * v; 00222 return res; 00223 } 00224 00230 template<typename T> std::valarray<T> operator*(const std::valarray<T>& v, const Matrix<T>& m) 00231 { 00232 assert(m.rowNum() == v.size()); 00233 00234 std::valarray<T> res(m.columnNum()); 00235 for(size_t i = 0; i < m.columnNum(); i++) 00236 res[i] = m.column(i) * v; 00237 return res; 00238 } 00239 } 00240 00241 #endif