TooN 2.1
QR_Lapack.h
00001 #ifndef TOON_INCLUDE_QR_LAPACK_H
00002 #define TOON_INCLUDE_QR_LAPACK_H
00003 
00004 
00005 #include <TooN/TooN.h>
00006 #include <TooN/lapack.h>
00007 #include <utility>
00008 
00009 namespace TooN{
00010 
00011 /**
00012 Performs %QR decomposition.
00013 
00014 @warning this will only work if the number of columns is greater than 
00015 the number of rows!
00016 
00017 The QR decomposition operates on a matrix A. It can be performed with
00018 or without column pivoting. In general:
00019 \f[
00020 AP = QR
00021 \f]
00022 Where \f$P\f$ is a permutation matrix constructed to permute the columns
00023 of A. In practise, \f$P\f$ is stored as a vector of integer elements.
00024 
00025 With column pivoting, the elements of the leading diagonal of \f$R\f$ will
00026 be sorted from largest in magnitude to smallest in magnitude.
00027 
00028 @ingroup gDecomps
00029 */
00030 template<int Rows=Dynamic, int Cols=Rows, class Precision=double>
00031 class QR_Lapack{
00032 
00033     private:
00034         static const int square_Size = (Rows>=0 && Cols>=0)?(Rows<Cols?Rows:Cols):Dynamic;
00035 
00036     public: 
00037         /// Construct the %QR decomposition of a matrix. This initialises the class, and
00038         /// performs the decomposition immediately.
00039         /// @param m The matrix to decompose
00040         /// @param p Whether or not to perform pivoting
00041         template<int R, int C, class P, class B> 
00042         QR_Lapack(const Matrix<R,C,P,B>& m, bool p=0)
00043         :copy(m),tau(square_size()), 
00044          Q(square_size(), square_size()), 
00045          do_pivoting(p), 
00046          pivot(Zeros(m.num_cols()))
00047         {
00048             //pivot is set to all zeros, which means all columns are free columns
00049             //and can take part in column pivoting.
00050 
00051             compute();
00052         }
00053         
00054         ///Return R
00055         const Matrix<Rows, Cols, Precision, ColMajor>& get_R()
00056         {
00057             return copy;
00058         }
00059         
00060         ///Return Q
00061         const Matrix<square_Size, square_Size, Precision, ColMajor>& get_Q()
00062         {
00063             return Q;
00064         }   
00065 
00066         ///Return the permutation vector. The definition is that column \f$i\f$ of A is
00067         ///column \f$P(i)\f$ of \f$QR\f$.
00068         const Vector<Cols, int>& get_P()
00069         {
00070             return pivot;
00071         }
00072 
00073     private:
00074 
00075         void compute()
00076         {   
00077             FortranInteger M = copy.num_rows();
00078             FortranInteger N = copy.num_cols();
00079             
00080             FortranInteger LWORK=-1;
00081             FortranInteger INFO;
00082             FortranInteger lda = M;
00083 
00084             Precision size;
00085             
00086             //Set up the pivot vector
00087             if(do_pivoting)
00088                 pivot = Zeros;
00089             else
00090                 for(int i=0; i < pivot.size(); i++)
00091                     pivot[i] = i+1;
00092 
00093             
00094             //Compute the working space
00095             geqp3_(&M, &N, copy.get_data_ptr(), &lda, pivot.get_data_ptr(), tau.get_data_ptr(), &size, &LWORK, &INFO);
00096 
00097             LWORK = (FortranInteger) size;
00098 
00099             Precision* work = new Precision[LWORK];
00100             
00101             geqp3_(&M, &N, copy.get_data_ptr(), &lda, pivot.get_data_ptr(), tau.get_data_ptr(), work, &LWORK, &INFO);
00102 
00103 
00104             if(INFO < 0)
00105                 std::cerr << "error in QR, INFO was " << INFO << std::endl;
00106 
00107             //The upper "triangle+" of copy is R
00108             //The lower right and tau contain enough information to reconstruct Q
00109             
00110             //LAPACK provides a handy function to do the reconstruction
00111             Q = copy.template slice<0,0,square_Size, square_Size>(0,0,square_size(), square_size());
00112             
00113             FortranInteger K = square_size();
00114             M=K;
00115             N=K;
00116             lda = K;
00117             orgqr_(&M, &N, &K, Q.get_data_ptr(), &lda, tau.get_data_ptr(), work, &LWORK, &INFO);
00118 
00119             if(INFO < 0)
00120                 std::cerr << "error in QR, INFO was " << INFO << std::endl;
00121 
00122             delete [] work;
00123             
00124             //Now zero out the lower triangle
00125             for(int r=1; r < square_size(); r++)
00126                 for(int c=0; c<r; c++)
00127                     copy[r][c] = 0;
00128 
00129             //Now fix the pivot matrix.
00130             //We need to go from FORTRAN to C numbering. 
00131             for(int i=0; i < pivot.size(); i++)
00132                 pivot[i]--;
00133         }
00134 
00135         Matrix<Rows, Cols, Precision, ColMajor> copy;
00136         Vector<square_Size, Precision> tau;
00137         Matrix<square_Size, square_Size, Precision, ColMajor> Q;
00138         bool do_pivoting;
00139         Vector<Cols, FortranInteger> pivot;
00140         
00141 
00142         int square_size()
00143         {
00144             return std::min(copy.num_rows(), copy.num_cols());  
00145         }
00146 };
00147 
00148 }
00149 
00150 
00151 #endif