Can ScaLAPACK GEMM performance be tuned or optimized?

computational-science linear-algebra matrices parallel-computing mpi blas
2021-12-10 00:41:56

I am comparing the performance of distributed GEMM, using ScaLAPACK on top of OpenBLAS, with threaded GEMM using OpenBLAS directly. I am finding it hard to get ScaLAPACK to give better results than multithreaded BLAS. Do I need to do some tuning/optimization/configuration, or do I need to do something differently?

My hardware:

  • 4 blades in a chassis, which I will call "nodes" below
  • connected by gigabit Ethernet, not the fastest I know of, and perhaps the reason for the poor results here?
  • each node has 12 physical cores
  • Software: mpich2 2.3, Ubuntu 12.04 64-bit, gfortran

I am comparing three software configurations, for a GEMM of two square matrices of dimension n:

  • multithreaded BLAS, using OpenBLAS, with 12 threads; a single process on a single node
  • ScaLAPACK over OpenBLAS, 48 MPI processes, 1 thread per MPI process, across 4 nodes
  • ScaLAPACK over OpenBLAS, 4 MPI processes, 12 threads per MPI process, across 4 nodes

Here is a graph of some results:

(figure: ScaLAPACK vs. multithreaded BLAS)

The graph seems to show:

  • plain non-distributed BLAS is actually very fast
  • distributed ScaLAPACK with 1 MPI process per node gives the best distributed performance, but is not much faster than plain non-distributed BLAS, even for very large matrices (the largest n in the graph is 30,000)
  • distributed ScaLAPACK with 1 thread per MPI process gives the worst results, always worse than simple non-distributed BLAS

Some questions:

  • How typical are these results?
  • Am I doing something wrong, or is there something I could do better, that might improve the distributed ScaLAPACK results?
  • Is it normal for the 1-thread-per-MPI-process configuration to perform worse than the other configurations tested?

Here is the test code used:

For GEMM:

#include <cstdio>
#include <cstdlib>   // for malloc/free
#include <iostream>
#include <cmath>
using namespace std;

#include "mycblas.h"
#include "utils/NanoTimer.h"
#include "utils/stringhelper.h"
#include "args.h"

extern "C" {
    void openblas_set_num_threads(int num_threads);
}

int main( int argc, char *argv[] ) {
    int N, its, threads;
    Args( argc, argv ).arg("N", &N).arg("its", &its).arg("threads",&threads).go();

    openblas_set_num_threads( threads );    

    NanoTimer timer;
    double *A = (double*)malloc(sizeof(double)*N*N);
    double *B = (double*)malloc(sizeof(double)*N*N);
    int linsize = N * N;
    for( int i = 0; i < linsize; i++ ) {
        A[i] = i + 3;
        B[i] = i * 2;
    }
    int m = N;
    int n = N;
    int k = N;
    double alpha = 1;
    double beta = 0;
    double *C = (double*)malloc(sizeof(double)*N*N);
    timer.toc("setup input matrices");
    for( int it = 0; it < its; it++ ) {
        dgemm(false,false,N,N,N, 1, A, N, B, N, 0, C, N );
        timer.toc("it " + toString(it) );
    }
    int sum = 0;
    for( int mult = 0; mult < log(N)/log(10); mult++ ) {
        int offset = pow(10,mult);
        sum += C[offset];
    }
    cout << "sum, to prevent short-cut optimization " << sum << endl;
    return 0;
}
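
As an aside (not part of the original test code), it can help to convert the per-iteration timings above into a sustained GFLOP/s figure, so the three configurations can be compared on the same scale. A minimal sketch, assuming the usual 2*N^3 flop count for a square double-precision GEMM; the helper name and the example numbers are made up:

#include <cstdio>

// Sustained rate for a square N x N GEMM that took 'seconds' for one iteration.
// A square double-precision GEMM performs roughly 2*N^3 floating point operations.
double gflops( int N, double seconds ) {
    double flops = 2.0 * (double)N * (double)N * (double)N;
    return flops / seconds * 1e-9;
}

int main() {
    // illustrative values only: n = 30000, 500 seconds per iteration
    printf( "%.1f GFLOP/s\n", gflops( 30000, 500.0 ) );
    return 0;
}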

For ScaLAPACK:

#include <iostream>
#include <stdexcept>
#include <cstring>
#include <cmath>
using namespace std;

#include "mpi.h"

#include "utils/NanoTimer.h"
#include "utils/stringhelper.h"
#include "args.h"
#include "scalapack.h"

extern "C" {
    void openblas_set_num_threads(int num_threads);
}

int getRootFactor( int n ) {
    for( int t = sqrt(n); t > 0; t-- ) {
        if( n % t == 0 ) {
            return t;
        }
    }
    return 1;
}

// conventions:
// M_ by N_ matrix block-partitioned into MB_ by NB_ blocks, then
// distributed according to 2d block-cyclic scheme

// based on http://acts.nersc.gov/scalapack/hands-on/exercise3/pspblasdriver.f.html

int main( int argc, char *argv[] ) {
    int p, P;
    blacs_pinfo( &p, &P );
//    mpi_print( toString(p) + " / " + toString(P) );

    int n;
    int numthreads;
    int its;
    int blocksize;
    Args( argc, argv ).arg("N", &n ).arg("num iterations", &its ).arg("numthreads", &numthreads ).arg("blocksize", &blocksize).go();
    openblas_set_num_threads( numthreads );

    int nprows = getRootFactor(P);
    int npcols = P / nprows;
    if( p == 0 ) cout << "grid: " << nprows << " x " << npcols << endl;

    int system = blacs_get( -1, 0 );
    int grid = blacs_gridinit( system, true, nprows, npcols );
    if( p == 0 ) cout << "system context " << system << " grid context: " << grid << endl;

    int myrow, mycol;
    blacs_gridinfo( grid, nprows, npcols, &myrow, &mycol );
//    mpi_print("grid, me: " + toString(myrow) + ", " + toString(mycol) );

    if( myrow >= nprows || mycol >= npcols ) {
//        mpi_print("not needed, exiting");
        blacs_gridexit( grid );
        blacs_exit(0);
        exit(0);
    }

    // A     B       C
    // m x k k x n = m x n
    // nb: blocksize

    // nprows: process grid, number rows
    // npcols: process grid, number cols
    // myrow: process grid, our row
    // mycol: process grid, our col
    int m = n;
    int k = n;
//    int nb = min(n,128); // nb is column block size for A, and row blocks size for B
    int nb=min(n/P,128);

    int mp = numroc( m, nb, myrow, 0, nprows ); // mp number rows A owned by this process
    int kp = numroc( k, nb, myrow, 0, nprows ); // kp number rows B owned by this process
    int kq = numroc( k, nb, mycol, 0, npcols ); // kq number cols A owned by this process
    int nq = numroc( n, nb, mycol, 0, npcols ); // nq number cols B owned by this process
//    mpi_print( "mp " + toString(mp) + " kp " + toString(kp) + " kq " + toString(kq) + " nq " + toString(nq) );

    struct DESC desca, descb, descc;
    descinit( (&desca), m, k, nb, nb, 0, 0, grid, max(1, mp) );
    descinit( (&descb), k, n, nb, nb, 0, 0, grid, max(1, kp) );
    descinit( (&descc), m, n, nb, nb, 0, 0, grid, max(1, mp) );
//    mpi_print( "desca.LLD_ " + toString(desca.LLD_) + " kq " + toString(kq) );
    double *ipa = new double[desca.LLD_ * kq];
    double *ipb = new double[descb.LLD_ * nq];
    double *ipc = new double[descc.LLD_ * nq];

    for( int i = 0; i < desca.LLD_ * kq; i++ ) {
        ipa[i] = p;
    }
    for( int i = 0; i < descb.LLD_ * nq; i++ ) {
        ipb[i] = p;
    }

    if( p == 0 ) cout << "created matrices" << endl;
    double *work = new double[nb];
    if( n <=5 ) {
        pdlaprnt( n, n, ipa, 1, 1, &desca, 0, 0, "A", 6, work );
        pdlaprnt( n, n, ipb, 1, 1, &descb, 0, 0, "B", 6, work );
    }

    NanoTimer timer;
    for( int it = 0; it < its; it++ ) {
        // beta set to 0 so the uninitialized ipc buffer is never read
        // (this also matches the beta used in the non-distributed BLAS test)
        pdgemm( false, false, m, n, k, 1,
                      ipa, 1, 1, &desca, ipb, 1, 1, &descb,
                      0, ipc, 1, 1, &descc );
        MPI_Barrier( MPI_COMM_WORLD );
        if( p == 0 ) timer.toc("it " + toString(it) + " pdgemm");
    }

    blacs_gridexit( grid );
    blacs_exit(0);

    return 0;
}
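
For reference (my addition, not part of the driver above), the 0-based arithmetic behind the 2D block-cyclic comment, numroc and the descriptors looks roughly like this; applied once to rows (over nprows) and once to columns (over npcols), it tells you which process owns a global element and where it lands locally:

#include <cstdio>

// 1-D block-cyclic distribution with block size nb over nprocs processes,
// where process isrcproc owns the first block: global index g lives on...
int ownerOf( int g, int nb, int isrcproc, int nprocs ) {
    return ( g / nb + isrcproc ) % nprocs;
}

// ...and maps to this local index on its owner (independent of isrcproc):
int localIndexOf( int g, int nb, int nprocs ) {
    int globalBlock = g / nb;
    int localBlock = globalBlock / nprocs;
    return localBlock * nb + g % nb;
}

int main() {
    // e.g. global element (i, j) with nb = 128 on a 4 x 12 process grid
    int i = 5000, j = 20000, nb = 128;
    printf( "owner (row %d, col %d), local (row %d, col %d)\n",
            ownerOf( i, nb, 0, 4 ), ownerOf( j, nb, 0, 12 ),
            localIndexOf( i, nb, 4 ), localIndexOf( j, nb, 12 ) );
    return 0;
}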

Other files:

mycblas.h:

extern "C" {
#define ADD_
    // blas:
    #include <cblas_f77.h>

    // lapack:
    void dpotrf_( char *uplo, int *n, double *A, int *lda, int *info );
    void dtrtrs_( char *uplo, char *trans, char *diag, int *n, int *nrhs, double *A, int *lda,
         double *B, int *ldb, int *info );
    void dtrsm_( char *side, char *uplo, char *transA, char *diag, const int *m, const int *n, 
                 const double *alpha, const double *A, const int *lda, double *B, const int *ldb );
}

char boolToChar( bool value ) {
    return value ? 't' : 'n';
}

// double, general matrix multiply
void dgemm( bool transa, bool transb, int m, int n, int k, double alpha, double *A,
    int lda, double *B, int ldb, double beta, double *C, int ldc ) {
    char transachar = boolToChar( transa );
    char transbchar = boolToChar( transb );
    dgemm_(&transachar,&transbchar,&m,&n,&k, &alpha, A, &lda, B, &ldb, &beta, C, &ldc );    
}

// double, triangular, solve matrix
void dtrsm( bool XA, bool isUpper, bool transA, bool isUnitTriangular, int m, int n,
       double alpha, double *A, int lda, double *B, int ldb ) {
    char sideChar = XA ? 'R' : 'L';
    char isUpperChar = isUpper ? 'U' : 'L';
    char transAChar = transA ? 'T' : 'N';
    char isUnitTriangularChar = isUnitTriangular ? 'U' : 'N';
    dtrsm_( &sideChar, &isUpperChar, &transAChar, &isUnitTriangularChar, &m, &n,
        &alpha, A, &lda, B, &ldb );
}

// double, symmetric positive definite, triangular factorization (=cholesky)
int dpotrf( bool isUpper, int N, double *A, int lda ) {
    int info;
    char uplo = isUpper ? 'U' : 'L';
    dpotrf_( &uplo, &N, A, &lda, &info );
    return info;
}

// double, triangular, triangular solve
int dtrtrs( bool isUpper, bool transA, bool isUnitTriangular, int n, int nrhs, double *A, int lda,
             double *B, int ldb ) {
    int info;
    char isUpperChar = isUpper ? 'U' : 'L';
    char transChar = transA ? 'T' : 'N';
    char isUnitTriangularChar = isUnitTriangular ? 'U' : 'N';
    dtrtrs_( &isUpperChar, &transChar, &isUnitTriangularChar, &n, &nrhs, A, &lda, B, &ldb, &info );
    return info;
}

scalapack.h:

#pragma once

extern "C" {
    struct DESC{
        int DTYPE_;
        int CTXT_;
        int M_;
        int N_;
        int MB_;
        int NB_;
        int RSRC_;
        int CSRC_;
        int LLD_;
    } ;

    void blacs_pinfo_( int *iam, int *nprocs );
    void blacs_get_( int *icontxt, int *what, int *val );
    void blacs_gridinit_( int *icontxt, char *order, int *nprow, int *npcol );
    void blacs_gridinfo_( int *context, int *nprow, int *npcol, int *myrow, int *mycol );
    void blacs_gridexit_( int *context );
    void blacs_exit_( int *code );

    int numroc_( int *n, int *nb, int *iproc, int *isrcproc, int *nprocs );
    void descinit_( struct DESC *desc, int *m, int *n, int *mb, int *nb, int *irsrc, int *icsrc, int *ictxt, int *lld, int *info );
    void pdlaprnt_( int *m, int *n, double *a, int *ia, int *ja, struct DESC *desca, int *irprnt,
        int *icprnt, const char *cmatnm, int *nout, double *work, int cmtnmlen );
    void pdgemm_( char *transa, char *transb, int *m, int *n, int *k, double *alpha,
         double *a, int *ia, int *ja, struct DESC *desca, double *b, int *ib, int *jb,
        struct DESC *descb, double *beta, double *c, int *ic, int *jc, struct DESC *descc );
}

void blacs_pinfo( int *p, int *P ) {
    blacs_pinfo_( p, P );
}

int blacs_get( int icontxt, int what ) {
    int val;
    blacs_get_( &icontxt, &what, &val );
    return val;
}

int blacs_gridinit( int icontxt, bool isColumnMajor, int nprow, int npcol ) {
    int newcontext = icontxt;
    char order = isColumnMajor ? 'C' : 'R';
    blacs_gridinit_( &newcontext, &order, &nprow, &npcol );
    return newcontext;
}

void blacs_gridinfo( int context, int nprow, int npcol, int *myrow, int *mycol ) {
    blacs_gridinfo_( &context, &nprow, &npcol, myrow, mycol );
}

void blacs_gridexit( int context ) {
    blacs_gridexit_( &context );
}

void blacs_exit( int code ) {
    blacs_exit_( &code );
}

int numroc( int n, int nb, int iproc, int isrcproc, int nprocs ) {
    return numroc_( &n, &nb, &iproc, &isrcproc, &nprocs );
}

void descinit( struct DESC *desc, int m, int n, int mb, int nb, int irsrc, int icsrc, int ictxt, int lld ) {
    int info;
    descinit_( desc, &m, &n, &mb, &nb, &irsrc, &icsrc, &ictxt, &lld, &info );
    if( info != 0 ) {
        throw runtime_error( "non zero info: " + toString( info ) );
    }
//    return info;
}

void pdlaprnt( int m, int n, double *A, int ia, int ja, struct DESC *desc, int irprnt,
    int icprnt, const char *cmatnm, int nout, double *work ) {
    int cmatnmlen = strlen(cmatnm);
    pdlaprnt_( &m, &n, A, &ia, &ja, desc, &irprnt, &icprnt, cmatnm, &nout, work, cmatnmlen );
}

void pdgemm( bool isTransA, bool isTransB, int m, int n, int k, double alpha,
     double *a, int ia, int ja, struct DESC *desca, double *b, int ib, int jb,
    struct DESC *descb, double beta, double *c, int ic, int jc, struct DESC *descc ) {
    char transa = isTransA ? 'T' : 'N';
    char transb = isTransB ? 'T' : 'N';
    pdgemm_( &transa, &transb, &m, &n, &k, &alpha, a, &ia, &ja, desca, b, &ib, &jb,
        descb, &beta, c, &ic, &jc, descc );
}

utils/NanoTimer.h:

#pragma once
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <sys/time.h>

class NanoTimer {
public:
   struct timespec start;

   NanoTimer() {
      clock_gettime(CLOCK_MONOTONIC,  &start);

   }
   double elapsedSeconds() {
      struct timespec now;
      clock_gettime(CLOCK_MONOTONIC,  &now);
      double time = (now.tv_sec - start.tv_sec) + (double) (now.tv_nsec - start.tv_nsec) * 1e-9;
      start = now;
      return time;
   }
    void toc(string label) {
        double elapsed = elapsedSeconds();
        cout << label << ": " << elapsed << "s" << endl;        
    }
};

utils/stringhelper.h:

#pragma once

#include <vector>
#include <string>
#include <sstream>
#include <iostream>
#include <cstdlib>

template<typename T>
std::string toString(T val ) { // not terribly efficient, but works...
   std::ostringstream myostringstream;
   myostringstream << val;
   return myostringstream.str();
}

args.h:

#pragma once

// usage:
// int N, its;
// arg( "N", &N );
// arg( "its", &its );
// args( argc, argv );

class Arg {
public:
    virtual void assign( const char *argvalue ) = 0;
    virtual void print( ostream &os ) const = 0;
};
ostream &operator<<( ostream &os, const Arg &arg ) {
    arg.print( os );
    return os;
}

class IntArg : public Arg {
public:
    int *argptr;
    IntArg( int *_argptr ) : argptr(_argptr ) {
    }
    void assign( const char *argvalue ) {
        *argptr = atoi( argvalue );
    }
    void print( ostream &os ) const {
        os << (*argptr );
    }
};

vector<string> argnames;
vector<Arg *> argptrs;

void arg_usage(string cmd) {
    cout << "Usage: " << cmd;
    for( int i = 0; i < argnames.size(); i++ ) {
        cout << " [" << argnames[i] << "]";
    }
    cout << endl;
    exit(1);
}

void arg( string name, int *p_value ) {
    argnames.push_back(name);
    argptrs.push_back( new IntArg( p_value ) );
}

void args( int argc, char *argv[] ) {
    if( argc - 1 != argnames.size() ) {
        arg_usage(argv[0]);
    }
    for( int i = 0; i < argnames.size(); i++ ) {
        argptrs[i]->assign( argv[i+1] );
        cout << argnames[i] << ": " << (*argptrs[i]) << endl;
    }
}

class Args {
public:
    int argc;
    char **argv;
    Args( int _argc, char *_argv[] ) : argc(_argc), argv(_argv) {
    }
    void go() {
        args( argc, argv );
    }
    Args &_( string name, int *pvalue ) {
        ::arg( name, pvalue );
        return *this;
    }
    Args &arg( string name, int *pvalue ) {
        ::arg( name, pvalue );
        return *this;
    }
};

Edit: in answer to Jeff's question about my Elemental code, here is my Elemental GEMM code:

#include "mpi.h"

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <sys/time.h>
using namespace std;

#include "elemental.hpp"
using namespace elem;

extern "C" {
    void openblas_set_num_threads(int num_threads);
}

class NanoTimer {
public:
   struct timespec start;

   NanoTimer() {
      clock_gettime(CLOCK_MONOTONIC,  &start);

   }
   double elapsedSeconds() {
      struct timespec now;
      clock_gettime(CLOCK_MONOTONIC,  &now);
      double time = (now.tv_sec - start.tv_sec) + (double) (now.tv_nsec - start.tv_nsec) * 1e-9;
      start = now;
      return time;
   }
    void toc(string label) {
        double elapsed = elapsedSeconds();
        cout << label << ": " << elapsed << endl;        
    }
};

int sum = 0;
void readMatrix( Matrix<double> &A ) {
    for( int i = 1; i < A.Width(); i *= 10 ) {
        sum += A.Get(i,i);
    }
}
void readMatrix( DistMatrix<double,CIRC,CIRC> &A ) {
    for( int i = 1; i < A.Width(); i *= 10 ) {
        sum += A.Get(i,i);
    }
}

int main( int argc, char *argv[] ) {
    elem::Initialize( argc, argv );
    int p = mpi::CommRank(mpi::COMM_WORLD);
    int P = mpi::CommSize(mpi::COMM_WORLD);

    if( argc < 3 ) {
        if( p == 0 ) cout << "Usage: " << argv[0] << " [N] [multithreaded: 1|0]" << endl;
        return -1;
    }
    int n = atoi( argv[1] );
    int multithreaded = atoi( argv[2] );

    Matrix<double> A(n,n);
    Matrix<double> B(n,n);
    for( int i = 0; i < n; i++ ) {
        for( int j = 0; j < n; j++ ) {
            A.Set(i,j, i*j + 2 );
            B.Set(i,j, i*j + 4 );
        }
    }

    Matrix<double> C(n,n);
    MPI_Barrier(MPI_COMM_WORLD);
    NanoTimer timer;
    if( p == 0 ) {
        Gemm<double>(NORMAL, NORMAL, 1, A, B, 0, C );  
        // read some values, to prevent being optimized out :-P
        timer.toc("blas multithreaded");
        readMatrix(C);

        if( !multithreaded ) {
            openblas_set_num_threads(1);
            Gemm<double>(NORMAL, NORMAL, 1, A, B, 0, C );    
            timer.toc("blas singlethreaded");
            readMatrix(C);
        }
    }
    if( !multithreaded ) {
        openblas_set_num_threads(1);
    }

    Grid g;
    DistMatrix<double,CIRC,CIRC> Aroot(n,n,g);
    DistMatrix<double,CIRC,CIRC> Broot(n,n,g);
    DistMatrix<double,CIRC,CIRC> Croot(n,n,g);
    Aroot.SetRoot(0);
    Broot.SetRoot(0);
    Croot.SetRoot(0);

    if( p == 0 ) {
        for( int i = 0; i < n; i++ ) {
            for( int j = 0; j < n; j++ ) {
                Aroot.Set(i,j,i*j+2);
                Broot.Set(i,j,i*j+2);
            }
        }
    }
    if( p == 0 ) timer.toc("populate root node");

//    DistMatrix<double,MC,STAR> Adist( Aroot );
//    DistMatrix<double,STAR,MR> Bdist( Broot );
//    DistMatrix<double,MC,MR> Cdist(n,n,g);
    DistMatrix<double> Adist( Aroot );
    DistMatrix<double> Bdist( Broot );
    DistMatrix<double> Cdist(n,n,g);
    MPI_Barrier(MPI_COMM_WORLD);
    if( p == 0 ) timer.toc("distributed to slaves");

    Gemm<double>(NORMAL, NORMAL, 1, Adist, Bdist, 0, Cdist );    
    MPI_Barrier(MPI_COMM_WORLD);
    if( p == 0 ) timer.toc("distmatrix gemm");

    Croot = Cdist;
    MPI_Barrier(MPI_COMM_WORLD);
    if( p == 0 ) timer.toc("gathered to master");
    readMatrix(Croot);

    if( p == 0 ) cout << "sum, to prevent optimization out: " << sum << endl;

    elem::Finalize();
    return 0;
}
3 Answers

These results are not surprising. Matrix multiplication is well known to be communication-intensive, and the network connecting your four nodes is relatively slow.

Using MPI between two processes on the same node is certainly faster than between processes on different nodes, since you are not limited by the Ethernet bandwidth. However, calling the MPI library still has a cost. By contrast, a multithreaded BLAS communicates between threads very quickly, because memory is fast and because you do not pay the overhead of calling MPI library routines and having the operating system shuttle messages around, as you do with MPI.

Ideally, for large n, the 4-node, 1-process-per-node, 12-threads-per-process approach should be four times as fast as 12 threads on a single node. The two curves do diverge toward a ratio of 4, but not very quickly (you are at roughly 3x at n = 30,000), which suggests that gigabit Ethernet bandwidth really is holding you back.
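
To make that argument a bit more quantitative (the model and the plugged-in numbers below are mine, not the answerer's): for a SUMMA/PBLAS-style distributed GEMM on P processes, each computing at f flop/s over links of \beta bytes/s, rough cost estimates are

\[
T_{\mathrm{comp}} \approx \frac{2n^{3}}{P\,f},
\qquad
T_{\mathrm{comm}} \approx c\,\frac{8n^{2}}{\sqrt{P}\,\beta},
\]

with c a small algorithm-dependent constant, so communication stops dominating only once

\[
n \;\gg\; \frac{4\,c\,\sqrt{P}\,f}{\beta}.
\]

Plugging in very rough guesses for this machine (P = 4 processes, f on the order of 100 GFLOP/s per node, \beta \approx 1.25\times10^{8} bytes/s for gigabit Ethernet) puts that crossover around n of order 10^4, which is consistent with seeing only about 3x of the ideal 4x even at n = 30,000.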

These results look normal. Your network is terrible for distributed BLAS. I cannot speak for the MPICH2 shared-memory implementation, but it may also be sub-par and limit the 1-thread-per-process / 12-processes-per-node results. You could try OpenMPI or MVAPICH2 (although I have never run the latter over plain Ethernet).

The algorithm in ScaLAPACK is not very good. Look at SUMMA (http://www.cs.utexas.edu/ftp/techreports/tr95-13.pdf) and its implementations in Elemental (code.google.com/p/elemental/) or CTF (http://ctf.eecs.berkeley.edu/) instead. The papers by Solomonik and coworkers show significant performance improvements over ScaLAPACK.
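
For orientation (a minimal sketch of my own, based on the SUMMA report above rather than on Elemental or CTF code), the core of SUMMA on a square q x q process grid with one block of each matrix per process looks roughly like this; a real implementation would use a tuned panel width and call dgemm for the local update:

// summa_sketch.cpp: illustrative SUMMA outer loop.
// Assumptions (mine): P is a perfect square q*q, q divides n, one bs x bs
// block of A, B and C per process, column-major storage.
#include <mpi.h>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

// naive local update C += A * B on bs x bs blocks (stand-in for dgemm)
static void localGemm( int bs, const double *A, const double *B, double *C ) {
    for( int j = 0; j < bs; j++ )
        for( int l = 0; l < bs; l++ )
            for( int i = 0; i < bs; i++ )
                C[i + j*bs] += A[i + l*bs] * B[l + j*bs];
}

int main( int argc, char *argv[] ) {
    MPI_Init( &argc, &argv );
    int p, P;
    MPI_Comm_rank( MPI_COMM_WORLD, &p );
    MPI_Comm_size( MPI_COMM_WORLD, &P );
    int q = (int)( sqrt( (double)P ) + 0.5 );   // process grid is q x q
    int n = argc > 1 ? atoi( argv[1] ) : 512;   // q must divide n
    int bs = n / q;
    int myrow = p / q, mycol = p % q;

    // row and column communicators of the process grid
    MPI_Comm rowComm, colComm;
    MPI_Comm_split( MPI_COMM_WORLD, myrow, mycol, &rowComm );
    MPI_Comm_split( MPI_COMM_WORLD, mycol, myrow, &colComm );

    // each process owns one bs x bs block of A, B and C
    std::vector<double> Aloc( bs*bs, 1.0 ), Bloc( bs*bs, 2.0 ), Cloc( bs*bs, 0.0 );
    std::vector<double> Abuf( bs*bs ), Bbuf( bs*bs );

    for( int t = 0; t < q; t++ ) {
        // process column t broadcasts its block of A along each process row
        if( mycol == t ) Abuf = Aloc;
        MPI_Bcast( &Abuf[0], bs*bs, MPI_DOUBLE, t, rowComm );
        // process row t broadcasts its block of B along each process column
        if( myrow == t ) Bbuf = Bloc;
        MPI_Bcast( &Bbuf[0], bs*bs, MPI_DOUBLE, t, colComm );
        // rank-bs update of the local block of C
        localGemm( bs, &Abuf[0], &Bbuf[0], &Cloc[0] );
    }

    // with A all ones and B all twos, every entry of C should equal 2*n
    if( p == 0 ) printf( "C[0] = %f (expected %f)\n", Cloc[0], 2.0*n );

    MPI_Comm_free( &rowComm );
    MPI_Comm_free( &colComm );
    MPI_Finalize();
    return 0;
}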

A colleague of mine, a physicist rather than a computer scientist, implemented SUMMA himself and found it faster than ScaLAPACK, so beating the latter is not particularly hard.