SVD.h

00001 // -*- c++ -*-
00002 
00003 // Copyright (C) 2005,2009 Tom Drummond (twd20@cam.ac.uk)
00004 //
00005 // This file is part of the TooN Library.  This library is free
00006 // software; you can redistribute it and/or modify it under the
00007 // terms of the GNU General Public License as published by the
00008 // Free Software Foundation; either version 2, or (at your option)
00009 // any later version.
00010 
00011 // This library is distributed in the hope that it will be useful,
00012 // but WITHOUT ANY WARRANTY; without even the implied warranty of
00013 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
00014 // GNU General Public License for more details.
00015 
00016 // You should have received a copy of the GNU General Public License along
00017 // with this library; see the file COPYING.  If not, write to the Free
00018 // Software Foundation, 59 Temple Place - Suite 330, Boston, MA 02111-1307,
00019 // USA.
00020 
00021 // As a special exception, you may use this file as part of a free software
00022 // library without restriction.  Specifically, if other files instantiate
00023 // templates or use macros or inline functions from this file, or you compile
00024 // this file and link it with other files to produce an executable, this
00025 // file does not by itself cause the resulting executable to be covered by
00026 // the GNU General Public License.  This exception does not however
00027 // invalidate any other reasons why the executable file might be covered by
00028 // the GNU General Public License.
00029 
00030 #ifndef __SVD_H
00031 #define __SVD_H
00032 
#include <algorithm>
#include <vector>

#include <TooN/TooN.h>
#include <TooN/lapack.h>
00035 
00036 namespace TooN {
00037 
// Default threshold ratio used when deciding which singular values count as
// zero: SVD::get_inv_diag() zeroes 1/s[i] whenever s[i] * condition <= s[0].
// TODO: should this depend on the Precision template parameter?
// (original note: "GK HACK TO GLOBAL")
static double const condition_no = 1.0e9;
00040 
00041 
00042 
00043 
00044 
00045 
00046 
00087 template<int Rows=Dynamic, int Cols=Rows, typename Precision=DefaultPrecision>
00088 class SVD {
00089 public:
00090     // this is the size of the diagonal
00091     // NB works for semi-dynamic sizes because -1 < +ve ints
00092     static const int Min_Dim = Rows<Cols?Rows:Cols;
00093     
00095     SVD() {}
00096 
00098     SVD(int rows, int cols)
00099         : my_copy(rows,cols),
00100           my_diagonal(std::min(rows,cols)),
00101           my_square(std::min(rows,cols), std::min(rows,cols))
00102     {}
00103 
00106     template <int R2, int C2, typename P2, typename B2>
00107     SVD(const Matrix<R2,C2,P2,B2>& m)
00108         : my_copy(m),
00109           my_diagonal(std::min(m.num_rows(),m.num_cols())),
00110           my_square(std::min(m.num_rows(),m.num_cols()),std::min(m.num_rows(),m.num_cols()))
00111     {
00112         do_compute();
00113     }
00114 
00116     template <int R2, int C2, typename P2, typename B2>
00117     void compute(const Matrix<R2,C2,P2,B2>& m){
00118         my_copy=m;
00119         do_compute();
00120     }
00121         
00122     void do_compute(){
00123         Precision* const a = my_copy.my_data;
00124         int lda = my_copy.num_cols();
00125         int m = my_copy.num_cols();
00126         int n = my_copy.num_rows();
00127         Precision* const uorvt = my_square.my_data;
00128         Precision* const s = my_diagonal.my_data;
00129         int ldu;
00130         int ldvt = lda;
00131         int LWORK;
00132         int INFO;
00133         char JOBU;
00134         char JOBVT;
00135 
00136         if(is_vertical()){ // u is a
00137             JOBU='O';
00138             JOBVT='S';
00139             ldu = lda;
00140         } else { // vt is a
00141             JOBU='S';
00142             JOBVT='O';
00143             ldu = my_square.num_cols();
00144         }
00145 
00146         Precision* wk;
00147 
00148         Precision size;
00149         LWORK = -1;
00150 
00151         // arguments are scrambled because we use rowmajor and lapack uses colmajor
00152         // thus u and vt play each other's roles.
00153         dgesvd_( &JOBVT, &JOBU, &m, &n, a, &lda, s, uorvt,
00154                  &ldvt, uorvt, &ldu, &size, &LWORK, &INFO);
00155     
00156         LWORK = (long int)(size);
00157         wk = new Precision[LWORK];
00158 
00159         dgesvd_( &JOBVT, &JOBU, &m, &n, a, &lda, s, uorvt,
00160                  &ldvt, uorvt, &ldu, wk, &LWORK, &INFO);
00161     
00162         delete[] wk;
00163     }
00164 
00165     bool is_vertical(){ 
00166         return (my_copy.num_rows() >= my_copy.num_cols()); 
00167     }
00168 
00169     int min_dim(){ return std::min(my_copy.num_rows(), my_copy.num_cols()); }
00170 
00171 
00176     template <int Rows2, int Cols2, typename P2, typename B2>
00177     Matrix<Cols,Cols2, typename Internal::MultiplyType<Precision,P2>::type >
00178     backsub(const Matrix<Rows2,Cols2,P2,B2>& rhs, const Precision condition=condition_no)
00179     {
00180         Vector<Min_Dim> inv_diag(min_dim());
00181         get_inv_diag(inv_diag,condition);
00182         return (get_VT().T() * diagmult(inv_diag, (get_U().T() * rhs)));
00183     }
00184 
00189     template <int Size, typename P2, typename B2>
00190     Vector<Cols, typename Internal::MultiplyType<Precision,P2>::type >
00191     backsub(const Vector<Size,P2,B2>& rhs, const Precision condition=condition_no)
00192     {
00193         Vector<Min_Dim> inv_diag(min_dim());
00194         get_inv_diag(inv_diag,condition);
00195         return (get_VT().T() * diagmult(inv_diag, (get_U().T() * rhs)));
00196     }
00197 
00202     Matrix<Cols,Rows> get_pinv(const Precision condition = condition_no){
00203         Vector<Min_Dim> inv_diag(min_dim());
00204         get_inv_diag(inv_diag,condition);
00205         return diagmult(get_VT().T(),inv_diag) * get_U().T();
00206     }
00207 
00210     Precision determinant() {
00211         Precision result = my_diagonal[0];
00212         for(int i=1; i<my_diagonal.size(); i++){
00213             result *= my_diagonal[i];
00214         }
00215         return result;
00216     }
00217     
00220     int rank(const Precision condition = condition_no) {
00221         if (my_diagonal[0] == 0) return 0;
00222         int result=1;
00223         for(int i=0; i<min_dim(); i++){
00224             if(my_diagonal[i] * condition <= my_diagonal[0]){
00225                 result++;
00226             }
00227         }
00228         return result;
00229     }
00230 
00234     Matrix<Rows,Min_Dim,Precision,Reference::RowMajor> get_U(){
00235         if(is_vertical()){
00236             return Matrix<Rows,Min_Dim,Precision,Reference::RowMajor>
00237                 (my_copy.my_data,my_copy.num_rows(),my_copy.num_cols());
00238         } else {
00239             return Matrix<Rows,Min_Dim,Precision,Reference::RowMajor>
00240                 (my_square.my_data, my_square.num_rows(), my_square.num_cols());
00241         }
00242     }
00243 
00245     Vector<Min_Dim,Precision>& get_diagonal(){ return my_diagonal; }
00246 
00250     Matrix<Min_Dim,Cols,Precision,Reference::RowMajor> get_VT(){
00251         if(is_vertical()){
00252             return Matrix<Min_Dim,Cols,Precision,Reference::RowMajor>
00253                 (my_square.my_data, my_square.num_rows(), my_square.num_cols());
00254         } else {
00255             return Matrix<Min_Dim,Cols,Precision,Reference::RowMajor>
00256                 (my_copy.my_data,my_copy.num_rows(),my_copy.num_cols());
00257         }
00258     }
00259 
00260 
00261     void get_inv_diag(Vector<Min_Dim>& inv_diag, const Precision condition){
00262         for(int i=0; i<min_dim(); i++){
00263             if(my_diagonal[i] * condition <= my_diagonal[0]){
00264                 inv_diag[i]=0;
00265             } else {
00266                 inv_diag[i]=static_cast<Precision>(1)/my_diagonal[i];
00267             }
00268         }
00269     }
00270 
00271 private:
00272     Matrix<Rows,Cols,Precision,RowMajor> my_copy;
00273     Vector<Min_Dim,Precision> my_diagonal;
00274     Matrix<Min_Dim,Min_Dim,Precision,RowMajor> my_square; // square matrix (U or V' depending on the shape of my_copy)
00275 };
00276 
00277 
00278 
00279 
00280 
00281 
00285 template<int Size, typename Precision>
00286 struct SQSVD : public SVD<Size, Size, Precision> {
00287     // forward all constructors to SVD
00288     SQSVD() {}
00289     SQSVD(int size) : SVD<Size,Size,Precision>(size, size) {}
00290     
00291     template <int R2, int C2, typename P2, typename B2>
00292     SQSVD(const Matrix<R2,C2,P2,B2>& m) : SVD<Size,Size,Precision>(m) {}
00293 };
00294 
00295 
00296 }
00297 
00298 
00299 #endif

Generated on Thu May 7 20:28:41 2009 for TooN by  doxygen 1.5.3