MatrixMultiplication.java


Below is the syntax highlighted version of MatrixMultiplication.java from §9.4 Numerical Solutions to Differential Equations.


/*************************************************************************
 *  Compilation:  javac MatrixMultiplication.java
 *  Execution:    java MatrixMultiplication N
 * 
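 *  8 different ways to multiply two N-by-N matrices: the 6 orderings of
 *  the i, j, k loops plus 2 hand-optimized variants (JAMA-style jik and
 *  pure-row ikj). Illustrates importance of row-major vs. column-major
 *  ordering.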
 *
 *  % java MatrixMultiplication 500
 *  Generating input:  0.048 seconds
 *  Order ijk:   3.562 seconds
 *  Order ikj:   1.348 seconds
 *  Order jik:   2.368 seconds
 *  Order jki:   4.846 seconds
 *  Order kij:   1.407 seconds
 *  Order kji:   4.91 seconds
 *  Order jik JAMA optimized:   0.571 seconds
 *  Order ikj pure row:   0.483 seconds
 *
 *  These timings are on a SUN-FIRE-X4100 running Linux.
 *
 *************************************************************************/

/**
 * Times eight ways of multiplying two random N-by-N matrices: the six
 * orderings of the i, j, k loops plus two hand-optimized variants. Each
 * variant computes C = A * B into a freshly allocated C and reports its
 * elapsed time on standard output.
 */
public class MatrixMultiplication {

    /** Nanoseconds per millisecond; timings are reported to the millisecond. */
    private static final long NANOS_PER_MILLI = 1000000L;

    /**
     * Prints a matrix to standard output, one row per line, each entry
     * formatted with {@code "%6.4f "}, followed by a blank line.
     *
     * <p>Each row is printed with its own length, so rectangular (and
     * ragged) arrays are printed in full rather than being truncated to
     * {@code a.length} columns.
     *
     * @param a the matrix to print
     */
    public static void show(double[][] a) {
        for (double[] row : a) {
            for (double value : row) {
                System.out.printf("%6.4f ", value);
            }
            System.out.println();
        }
        System.out.println();
    }

    /**
     * Returns the time elapsed since {@code startNanos}, in seconds,
     * truncated to whole milliseconds.
     *
     * @param startNanos an earlier reading of {@link System#nanoTime()}
     * @return elapsed seconds with millisecond granularity
     */
    private static double secondsSince(long startNanos) {
        long millis = (System.nanoTime() - startNanos) / NANOS_PER_MILLI;
        return millis / 1000.0;
    }

    /**
     * Generates two random N-by-N matrices and times each multiplication
     * variant in turn. When N is less than 10 the product is also printed
     * after each variant.
     *
     * @param args {@code args[0]} is the matrix dimension N
     * @throws IllegalArgumentException if N is missing, is not an integer
     *         (as {@link NumberFormatException}), or is negative
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            throw new IllegalArgumentException("usage: java MatrixMultiplication N");
        }
        int N = Integer.parseInt(args[0]);
        if (N < 0) {
            throw new IllegalArgumentException("N must be non-negative: " + N);
        }

        // nanoTime is used for interval timing; it is not affected by
        // adjustments to the system wall clock.
        long start;

        // generate input
        start = System.nanoTime();

        double[][] A = new double[N][N];
        double[][] B = new double[N][N];
        double[][] C;

        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                A[i][j] = Math.random();
            }
        }

        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                B[i][j] = Math.random();
            }
        }

        System.out.println("Generating input:  " + secondsSince(start) + " seconds");

        // order 1: ijk = dot product version
        C = new double[N][N];
        start = System.nanoTime();
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                for (int k = 0; k < N; k++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        System.out.println("Order ijk:   " + secondsSince(start) + " seconds");
        if (N < 10) show(C);

        // order 2: ikj
        C = new double[N][N];
        start = System.nanoTime();
        for (int i = 0; i < N; i++) {
            for (int k = 0; k < N; k++) {
                for (int j = 0; j < N; j++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        System.out.println("Order ikj:   " + secondsSince(start) + " seconds");
        if (N < 10) show(C);

        // order 3: jik
        C = new double[N][N];
        start = System.nanoTime();
        for (int j = 0; j < N; j++) {
            for (int i = 0; i < N; i++) {
                for (int k = 0; k < N; k++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        System.out.println("Order jik:   " + secondsSince(start) + " seconds");
        if (N < 10) show(C);

        // order 4: jki = GAXPY version
        C = new double[N][N];
        start = System.nanoTime();
        for (int j = 0; j < N; j++) {
            for (int k = 0; k < N; k++) {
                for (int i = 0; i < N; i++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        System.out.println("Order jki:   " + secondsSince(start) + " seconds");
        if (N < 10) show(C);

        // order 5: kij
        C = new double[N][N];
        start = System.nanoTime();
        for (int k = 0; k < N; k++) {
            for (int i = 0; i < N; i++) {
                for (int j = 0; j < N; j++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        System.out.println("Order kij:   " + secondsSince(start) + " seconds");
        if (N < 10) show(C);

        // order 6: kji = outer product version
        C = new double[N][N];
        start = System.nanoTime();
        for (int k = 0; k < N; k++) {
            for (int j = 0; j < N; j++) {
                for (int i = 0; i < N; i++) {
                    C[i][j] += A[i][k] * B[k][j];
                }
            }
        }
        System.out.println("Order kji:   " + secondsSince(start) + " seconds");
        if (N < 10) show(C);

        // order 7: jik optimized ala JAMA
        // Column j of B is copied into a 1D array once per j so the inner
        // loop reads two plain 1D arrays instead of indexing B[k][j].
        C = new double[N][N];
        start = System.nanoTime();
        double[] bcolj = new double[N];
        for (int j = 0; j < N; j++) {
            for (int k = 0; k < N; k++) {
                bcolj[k] = B[k][j];
            }
            for (int i = 0; i < N; i++) {
                double[] arowi = A[i];
                double sum = 0.0;
                for (int k = 0; k < N; k++) {
                    sum += arowi[k] * bcolj[k];
                }
                C[i][j] = sum;
            }
        }
        System.out.println("Order jik JAMA optimized:   " + secondsSince(start) + " seconds");
        if (N < 10) show(C);

        // order 8: ikj pure row
        // Row references and A[i][k] are hoisted out of the inner loop so
        // it touches only the 1D rows crowi and browk.
        C = new double[N][N];
        start = System.nanoTime();
        for (int i = 0; i < N; i++) {
            double[] arowi = A[i];
            double[] crowi = C[i];
            for (int k = 0; k < N; k++) {
                double[] browk = B[k];
                double aik = arowi[k];
                for (int j = 0; j < N; j++) {
                    crowi[j] += aik * browk[j];
                }
            }
        }
        System.out.println("Order ikj pure row:   " + secondsSince(start) + " seconds");
        if (N < 10) show(C);

    }

}


Copyright © 2007, Robert Sedgewick and Kevin Wayne.
Last updated: Tue Sep 29 16:17:41 EDT 2009.