LU.h

00001 // -*- c++ -*-
00002 
00003 // Copyright (C) 2005,2009 Tom Drummond (twd20@cam.ac.uk),
00004 // Ed Rosten (er258@cam.ac.uk)
00005 //
00006 // This file is part of the TooN Library.   This library is free
00007 // software; you can redistribute it and/or modify it under the
00008 // terms of the GNU General Public License as published by the
00009 // Free Software Foundation; either version 2, or (at your option)
00010 // any later version.
00011 
00012 // This library is distributed in the hope that it will be useful,
00013 // but WITHOUT ANY WARRANTY; without even the implied warranty of
00014 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
00015 // GNU General Public License for more details.
00016 
00017 // You should have received a copy of the GNU General Public License along
00018 // with this library; see the file COPYING. If not, write to the Free
00019 // Software Foundation, 59 Temple Place - Suite 330, Boston, MA 02111-1307,
00020 // USA.
00021 
00022 // As a special exception, you may use this file as part of a free software
00023 // library without restriction. Specifically, if other files instantiate
00024 // templates or use macros or inline functions from this file, or you compile
00025 // this file and link it with other files to produce an executable, this
00026 // file does not by itself cause the resulting executable to be covered by
00027 // the GNU General Public License.  This exception does not however
00028 // invalidate any other reasons why the executable file might be covered by
00029 // the GNU General Public License.
00030 
00031 #ifndef TOON_INCLUDE_LU_H
00032 #define TOON_INCLUDE_LU_H
00033 
00034 #include <iostream>
00035 
00036 #include <TooN/lapack.h>
00037 
00038 #include <TooN/TooN.h>
00039 
00040 namespace TooN {
00066 template <int Size=-1, class Precision=double>
00067 class LU {
00068     public:
00069 
00072     template<int S1, int S2, class Base>
00073     LU(const Matrix<S1,S2,Precision, Base>& m)
00074     :my_lu(m.num_rows(),m.num_cols()),my_IPIV(m.num_rows()){
00075         compute(m);
00076     }
00077     
00079     template<int S1, int S2, class Base>
00080     void compute(const Matrix<S1,S2,Precision,Base>& m){
00081         //check for consistency with Size
00082         SizeMismatch<Size, S1>::test(my_lu.num_rows(),m.num_rows());
00083         SizeMismatch<Size, S2>::test(my_lu.num_rows(),m.num_cols());
00084     
00085         //Make a local copy. This is guaranteed contiguous
00086         my_lu=m;
00087         int lda = m.num_rows();
00088         int M = m.num_rows();
00089         int N = m.num_rows();
00090 
00091         getrf_(&M,&N,&my_lu[0][0],&lda,&my_IPIV[0],&my_info);
00092 
00093         if(my_info < 0){
00094             std::cerr << "error in LU, INFO was " << my_info << std::endl;
00095         }
00096     }
00097 
00100     template <int Rows, int NRHS, class Base>
00101     Matrix<Size,NRHS,Precision> backsub(const Matrix<Rows,NRHS,Precision,Base>& rhs){
00102         //Check the number of rows is OK.
00103         SizeMismatch<Size, Rows>::test(my_lu.num_rows(), rhs.num_rows());
00104     
00105         Matrix<Size, NRHS, Precision> result(rhs);
00106 
00107         int M=rhs.num_cols();
00108         int N=my_lu.num_rows();
00109         double alpha=1;
00110         int lda=my_lu.num_rows();
00111         int ldb=rhs.num_cols();
00112         trsm_("R","U","N","N",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0][0],&ldb);
00113         trsm_("R","L","N","U",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0][0],&ldb);
00114 
00115         // now do the row swapping (lapack dlaswp.f only shuffles fortran rows = Rowmajor cols)
00116         for(int i=N-1; i>=0; i--){
00117             const int swaprow = my_IPIV[i]-1; // fortran arrays start at 1
00118             for(int j=0; j<NRHS; j++){
00119                 Precision temp = result[i][j];
00120                 result[i][j] = result[swaprow][j];
00121                 result[swaprow][j] = temp;
00122             }
00123         }
00124         return result;
00125     }
00126 
00129     template <int Rows, class Base>
00130     Vector<Size,Precision> backsub(const Vector<Rows,Precision,Base>& rhs){
00131         //Check the number of rows is OK.
00132         SizeMismatch<Size, Rows>::test(my_lu.num_rows(), rhs.size());
00133     
00134         Vector<Size, Precision> result(rhs);
00135 
00136         int M=1;
00137         int N=my_lu.num_rows();
00138         double alpha=1;
00139         int lda=my_lu.num_rows();
00140         int ldb=1;
00141         trsm_("R","U","N","N",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0],&ldb);
00142         trsm_("R","L","N","U",&M,&N,&alpha,&my_lu[0][0],&lda,&result[0],&ldb);
00143 
00144         // now do the row swapping (lapack dlaswp.f only shuffles fortran rows = Rowmajor cols)
00145         for(int i=N-1; i>=0; i--){
00146             const int swaprow = my_IPIV[i]-1; // fortran arrays start at 1
00147             Precision temp = result[i];
00148             result[i] = result[swaprow];
00149             result[swaprow] = temp;
00150         }
00151         return result;
00152     }
00153 
00156     Matrix<Size,Size,Precision> get_inverse(){
00157         Matrix<Size,Size,Precision> Inverse(my_lu);
00158         int N = my_lu.num_rows();
00159         int lda=my_lu.num_rows();
00160         int lwork=-1;
00161         Precision size;
00162         getri_(&N, &Inverse[0][0], &lda, &my_IPIV[0], &size, &lwork, &my_info);
00163         lwork=int(size);
00164         Precision* WORK = new Precision[lwork];
00165         getri_(&N, &Inverse[0][0], &lda, &my_IPIV[0], WORK, &lwork, &my_info);
00166         delete [] WORK;
00167         return Inverse;
00168     }
00169 
00175     const Matrix<Size,Size,Precision>& get_lu()const {return my_lu;}
00176 
00177     inline int get_sign() const {
00178         int result=1;
00179         for(int i=0; i<my_lu.num_rows()-1; i++){
00180             if(my_IPIV[i] > i+1){
00181                 result=-result;
00182             }
00183         }
00184         return result;
00185     }
00186 
00188     inline Precision determinant() const {
00189         Precision result = get_sign();
00190         for (int i=0; i<my_lu.num_rows(); i++){
00191             result*=my_lu(i,i);
00192         }
00193         return result;
00194     }
00195     
00197     int get_info() const { return my_info; }
00198 
00199  private:
00200 
00201     Matrix<Size,Size,Precision> my_lu;
00202     int my_info;
00203     Vector<Size, int> my_IPIV;  //Convenient static-or-dynamic array of ints :-)
00204 
00205 };
00206 }
00207     
00208 
00209 #endif

Generated on Thu May 7 20:28:40 2009 for TooN by  doxygen 1.5.3