00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #ifndef __SVD_H
00031 #define __SVD_H
00032
#include <algorithm>
#include <vector>

#include <TooN/TooN.h>
#include <TooN/lapack.h>
00035
00036 namespace TooN {
00037
00038
/// Default condition-number threshold used by the SVD helpers in this header:
/// a singular value s_i is treated as zero when s_i * condition_no <= s_0.
static const double condition_no = 1.0e9;
00040
00041
00042
00043
00044
00045
00046
00087 template<int Rows=Dynamic, int Cols=Rows, typename Precision=DefaultPrecision>
00088 class SVD {
00089 public:
00090
00091
00092 static const int Min_Dim = Rows<Cols?Rows:Cols;
00093
00095 SVD() {}
00096
00098 SVD(int rows, int cols)
00099 : my_copy(rows,cols),
00100 my_diagonal(std::min(rows,cols)),
00101 my_square(std::min(rows,cols), std::min(rows,cols))
00102 {}
00103
00106 template <int R2, int C2, typename P2, typename B2>
00107 SVD(const Matrix<R2,C2,P2,B2>& m)
00108 : my_copy(m),
00109 my_diagonal(std::min(m.num_rows(),m.num_cols())),
00110 my_square(std::min(m.num_rows(),m.num_cols()),std::min(m.num_rows(),m.num_cols()))
00111 {
00112 do_compute();
00113 }
00114
00116 template <int R2, int C2, typename P2, typename B2>
00117 void compute(const Matrix<R2,C2,P2,B2>& m){
00118 my_copy=m;
00119 do_compute();
00120 }
00121
00122 void do_compute(){
00123 Precision* const a = my_copy.my_data;
00124 int lda = my_copy.num_cols();
00125 int m = my_copy.num_cols();
00126 int n = my_copy.num_rows();
00127 Precision* const uorvt = my_square.my_data;
00128 Precision* const s = my_diagonal.my_data;
00129 int ldu;
00130 int ldvt = lda;
00131 int LWORK;
00132 int INFO;
00133 char JOBU;
00134 char JOBVT;
00135
00136 if(is_vertical()){
00137 JOBU='O';
00138 JOBVT='S';
00139 ldu = lda;
00140 } else {
00141 JOBU='S';
00142 JOBVT='O';
00143 ldu = my_square.num_cols();
00144 }
00145
00146 Precision* wk;
00147
00148 Precision size;
00149 LWORK = -1;
00150
00151
00152
00153 dgesvd_( &JOBVT, &JOBU, &m, &n, a, &lda, s, uorvt,
00154 &ldvt, uorvt, &ldu, &size, &LWORK, &INFO);
00155
00156 LWORK = (long int)(size);
00157 wk = new Precision[LWORK];
00158
00159 dgesvd_( &JOBVT, &JOBU, &m, &n, a, &lda, s, uorvt,
00160 &ldvt, uorvt, &ldu, wk, &LWORK, &INFO);
00161
00162 delete[] wk;
00163 }
00164
00165 bool is_vertical(){
00166 return (my_copy.num_rows() >= my_copy.num_cols());
00167 }
00168
00169 int min_dim(){ return std::min(my_copy.num_rows(), my_copy.num_cols()); }
00170
00171
00176 template <int Rows2, int Cols2, typename P2, typename B2>
00177 Matrix<Cols,Cols2, typename Internal::MultiplyType<Precision,P2>::type >
00178 backsub(const Matrix<Rows2,Cols2,P2,B2>& rhs, const Precision condition=condition_no)
00179 {
00180 Vector<Min_Dim> inv_diag(min_dim());
00181 get_inv_diag(inv_diag,condition);
00182 return (get_VT().T() * diagmult(inv_diag, (get_U().T() * rhs)));
00183 }
00184
00189 template <int Size, typename P2, typename B2>
00190 Vector<Cols, typename Internal::MultiplyType<Precision,P2>::type >
00191 backsub(const Vector<Size,P2,B2>& rhs, const Precision condition=condition_no)
00192 {
00193 Vector<Min_Dim> inv_diag(min_dim());
00194 get_inv_diag(inv_diag,condition);
00195 return (get_VT().T() * diagmult(inv_diag, (get_U().T() * rhs)));
00196 }
00197
00202 Matrix<Cols,Rows> get_pinv(const Precision condition = condition_no){
00203 Vector<Min_Dim> inv_diag(min_dim());
00204 get_inv_diag(inv_diag,condition);
00205 return diagmult(get_VT().T(),inv_diag) * get_U().T();
00206 }
00207
00210 Precision determinant() {
00211 Precision result = my_diagonal[0];
00212 for(int i=1; i<my_diagonal.size(); i++){
00213 result *= my_diagonal[i];
00214 }
00215 return result;
00216 }
00217
00220 int rank(const Precision condition = condition_no) {
00221 if (my_diagonal[0] == 0) return 0;
00222 int result=1;
00223 for(int i=0; i<min_dim(); i++){
00224 if(my_diagonal[i] * condition <= my_diagonal[0]){
00225 result++;
00226 }
00227 }
00228 return result;
00229 }
00230
00234 Matrix<Rows,Min_Dim,Precision,Reference::RowMajor> get_U(){
00235 if(is_vertical()){
00236 return Matrix<Rows,Min_Dim,Precision,Reference::RowMajor>
00237 (my_copy.my_data,my_copy.num_rows(),my_copy.num_cols());
00238 } else {
00239 return Matrix<Rows,Min_Dim,Precision,Reference::RowMajor>
00240 (my_square.my_data, my_square.num_rows(), my_square.num_cols());
00241 }
00242 }
00243
00245 Vector<Min_Dim,Precision>& get_diagonal(){ return my_diagonal; }
00246
00250 Matrix<Min_Dim,Cols,Precision,Reference::RowMajor> get_VT(){
00251 if(is_vertical()){
00252 return Matrix<Min_Dim,Cols,Precision,Reference::RowMajor>
00253 (my_square.my_data, my_square.num_rows(), my_square.num_cols());
00254 } else {
00255 return Matrix<Min_Dim,Cols,Precision,Reference::RowMajor>
00256 (my_copy.my_data,my_copy.num_rows(),my_copy.num_cols());
00257 }
00258 }
00259
00260
00261 void get_inv_diag(Vector<Min_Dim>& inv_diag, const Precision condition){
00262 for(int i=0; i<min_dim(); i++){
00263 if(my_diagonal[i] * condition <= my_diagonal[0]){
00264 inv_diag[i]=0;
00265 } else {
00266 inv_diag[i]=static_cast<Precision>(1)/my_diagonal[i];
00267 }
00268 }
00269 }
00270
00271 private:
00272 Matrix<Rows,Cols,Precision,RowMajor> my_copy;
00273 Vector<Min_Dim,Precision> my_diagonal;
00274 Matrix<Min_Dim,Min_Dim,Precision,RowMajor> my_square;
00275 };
00276
00277
00278
00279
00280
00281
00285 template<int Size, typename Precision>
00286 struct SQSVD : public SVD<Size, Size, Precision> {
00287
00288 SQSVD() {}
00289 SQSVD(int size) : SVD<Size,Size,Precision>(size, size) {}
00290
00291 template <int R2, int C2, typename P2, typename B2>
00292 SQSVD(const Matrix<R2,C2,P2,B2>& m) : SVD<Size,Size,Precision>(m) {}
00293 };
00294
00295
00296 }
00297
00298
00299 #endif