1. 概要

このチュートリアルでは、Javaで2つの行列を乗算する方法を見ていきます。

行列の概念は言語にネイティブに存在しないため、自分で実装します。また、いくつかのライブラリを使用して、行列の乗算をどのように処理するかを確認します。

最後に、最速のソリューションを決定するために、調査したさまざまなソリューションのベンチマークを少し行います。

2. 例

このチュートリアル全体で参照できる例を設定することから始めましょう。

まず、3×2の行列を想像します。

今回は2行4列の2番目の行列を想像してみましょう。

次に、最初の行列に2番目の行列を乗算すると、3×4の行列になります。

注意として、この結果は、次の式を使用して結果の行列の各セルを計算することによって取得されます。

ここで、rは行列Aの行数、cは行列Bおよびn[の列数です。 X127X]は、行列 A の列数であり、行列Bの行数と一致する必要があります。

3. 行列の乗算

3.1. 独自の実装

行列の独自の実装から始めましょう。

シンプルに保ち、2次元の二重配列を使用します。

double[][] firstMatrix = {
  new double[]{1d, 5d},
  new double[]{2d, 3d},
  new double[]{1d, 7d}
};

double[][] secondMatrix = {
  new double[]{1d, 2d, 3d, 7d},
  new double[]{5d, 2d, 8d, 1d}
};

これらは、この例の2つのマトリックスです。 それらの乗算の結果として期待されるものを作成しましょう:

double[][] expected = {
  new double[]{26d, 12d, 43d, 12d},
  new double[]{17d, 10d, 30d, 17d},
  new double[]{36d, 16d, 59d, 14d}
};

すべての設定が完了したので、乗算アルゴリズムを実装しましょう。 最初に空の結果配列を作成し、そのセルを反復処理して、各セルに期待値を格納します。

double[][] multiplyMatrices(double[][] firstMatrix, double[][] secondMatrix) {
    double[][] result = new double[firstMatrix.length][secondMatrix[0].length];

    for (int row = 0; row < result.length; row++) {
        for (int col = 0; col < result[row].length; col++) {
            result[row][col] = multiplyMatricesCell(firstMatrix, secondMatrix, row, col);
        }
    }

    return result;
}

最後に、単一セルの計算を実装しましょう。 これを実現するために、例のプレゼンテーションで前に示した式を使用します

double multiplyMatricesCell(double[][] firstMatrix, double[][] secondMatrix, int row, int col) {
    double cell = 0;
    for (int i = 0; i < secondMatrix.length; i++) {
        cell += firstMatrix[row][i] * secondMatrix[i][col];
    }
    return cell;
}

最後に、アルゴリズムの結果が期待される結果と一致することを確認しましょう。

double[][] actual = multiplyMatrices(firstMatrix, secondMatrix);
assertThat(actual).isEqualTo(expected);

3.2. EJML

最初に見るライブラリはEJMLで、これは Efficient Java MatrixLibraryの略です。 このチュートリアルを書いている時点では、最近更新されたJavaマトリックスライブラリの1つです。 その目的は、計算とメモリ使用量に関して可能な限り効率的にすることです。

pom.xmlのライブラリ依存関係を追加する必要があります。

<dependency>
    <groupId>org.ejml</groupId>
    <artifactId>ejml-all</artifactId>
    <version>0.38</version>
</dependency>

以前とほぼ同じパターンを使用します。例に従って2つの行列を作成し、それらの乗算の結果が前に計算したものであることを確認します。

それでは、EJMLを使用してマトリックスを作成しましょう。 これを実現するために、ライブラリが提供するSimpleMatrixクラスを使用します。

コンストラクターの入力として、2次元のdouble配列を取ることができます。

SimpleMatrix firstMatrix = new SimpleMatrix(
  new double[][] {
    new double[] {1d, 5d},
    new double[] {2d, 3d},
    new double[] {1d ,7d}
  }
);

SimpleMatrix secondMatrix = new SimpleMatrix(
  new double[][] {
    new double[] {1d, 2d, 3d, 7d},
    new double[] {5d, 2d, 8d, 1d}
  }
);

それでは、乗算に期待される行列を定義しましょう。

SimpleMatrix expected = new SimpleMatrix(
  new double[][] {
    new double[] {26d, 12d, 43d, 12d},
    new double[] {17d, 10d, 30d, 17d},
    new double[] {36d, 16d, 59d, 14d}
  }
);

これですべての設定が完了したので、2つの行列を乗算する方法を見てみましょう。 SimpleMatrixクラスは、別の SimpleMatrix をパラメーターとして受け取り、2つの行列の乗算を返すmult()メソッドを提供します。

SimpleMatrix actual = firstMatrix.mult(secondMatrix);

得られた結果が期待した結果と一致するかどうかを確認しましょう。

SimpleMatrixequals()メソッドをオーバーライドしないため、検証を行うためにこれに依存することはできません。 ただし、は代替手段を提供します。isIdentical()メソッドは、別のマトリックスパラメーターだけでなく、 double フォールトトレランスを使用して、倍精度による小さな違いを無視します。

assertThat(actual).matches(m -> m.isIdentical(expected, 0d));

これで、EJMLライブラリとの行列の乗算は終了です。 他のものが何を提供しているか見てみましょう。

3.3. ND4J

ND4Jライブラリを試してみましょう。 ND4Jは計算ライブラリであり、deeplearning4jプロジェクトの一部です。 とりわけ、ND4Jは行列計算機能を提供します。

まず、ライブラリの依存関係を取得する必要があります。

<dependency>
    <groupId>org.nd4j</groupId>
    <artifactId>nd4j-native</artifactId>
    <version>1.0.0-beta4</version>
</dependency>

GAリリースにはいくつかのバグがあるように思われるため、ここではベータ版を使用していることに注意してください。

簡潔にするために、2次元の double 配列を書き直さず、各ライブラリでの使用方法に焦点を当てます。 したがって、ND4Jでは、INDArrayを作成する必要があります。 これを行うために、 Nd4j.create()ファクトリメソッドを呼び出して、行列を表すdouble配列を渡します。

INDArray matrix = Nd4j.create(/* a two dimensions double array */);

前のセクションと同様に、3つの行列を作成します。2つは一緒に乗算し、1つは期待される結果です。

その後、 INDArray.mmul()メソッドを使用して、最初の2つの行列間の乗算を実際に実行します。

INDArray actual = firstMatrix.mmul(secondMatrix);

次に、実際の結果が期待される結果と一致することを再度確認します。 今回は、同等性チェックに頼ることができます。

assertThat(actual).isEqualTo(expected);

これは、ND4Jライブラリを使用して行列計算を行う方法を示しています。

3.4. Apache Commons

次に、 Apache Commons Math3モジュールについて説明します。このモジュールは、行列操作を含む数学的な計算を提供します。

ここでも、pom.xml依存関係を指定する必要があります。

<dependency>
    <groupId>org.apache.commons</groupId>
    <artifactId>commons-math3</artifactId>
    <version>3.6.1</version>
</dependency>

セットアップが完了すると、RealMatrixインターフェイスとそのArray2DRowRealMatrix実装を使用して通常のマトリックスを作成できます。 実装クラスのコンストラクターは、パラメーターとして2次元の doublearrayを取ります。

RealMatrix matrix = new Array2DRowRealMatrix(/* a two dimensions double array */);

行列の乗算に関しては、 RealMatrixインターフェイスは、別の RealMatrix パラメーターを使用するmultiply()メソッドを提供します。

RealMatrix actual = firstMatrix.multiply(secondMatrix);

最終的に、結果が期待したものと等しいことを確認できます。

assertThat(actual).isEqualTo(expected);

次の図書館を見てみましょう!

3.5. LA4J

これはLA4Jという名前で、Java線形代数の略です。

これにも依存関係を追加しましょう。

<dependency>
    <groupId>org.la4j</groupId>
    <artifactId>la4j</artifactId>
    <version>0.6.0</version>
</dependency>

現在、LA4Jは他のライブラリとほとんど同じように機能します。 2次元double配列を入力として受け取るBasic2DMatrix実装を備えたMatrixインターフェースを提供します:

Matrix matrix = new Basic2DMatrix(/* a two dimensions double array */);

Apache Commons Math3モジュールと同様に、乗算メソッドはmultiply()であり、パラメーターとして別のMatrixを取ります。

Matrix actual = firstMatrix.multiply(secondMatrix);

もう一度、結果が期待と一致することを確認できます。

assertThat(actual).isEqualTo(expected);

最後のライブラリであるColtを見てみましょう。

3.6. コルト

Colt は、CERNによって開発されたライブラリです。 高性能の科学技術コンピューティングを可能にする機能を提供します。

以前のライブラリと同様に、正しい依存関係を取得する必要があります。

<dependency>
    <groupId>colt</groupId>
    <artifactId>colt</artifactId>
    <version>1.2.0</version>
</dependency>

Coltで行列を作成するには、DoubleFactory2Dクラスを使用する必要があります。 高密度、スパースrowCompressedの3つのファクトリインスタンスが付属しています。 それぞれが最適化されて、一致する種類の行列が作成されます。

ここでは、高密度インスタンスを使用します。 今回は、呼び出すメソッドはmake()であり、2次元の double配列を再度取り、DoubleMatrix2Dオブジェクトを生成します。

DoubleMatrix2D matrix = doubleFactory2D.make(/* a two dimensions double array */);

行列がインスタンス化されたら、それらを乗算する必要があります。 今回は、マトリックスオブジェクトにそれを行うメソッドはありません。 パラメータに2つの行列をとるmult()メソッドを持つAlgebraクラスのインスタンスを作成する必要があります。

Algebra algebra = new Algebra();
DoubleMatrix2D actual = algebra.mult(firstMatrix, secondMatrix);

次に、実際の結果を期待される結果と比較できます。

assertThat(actual).isEqualTo(expected);

4. ベンチマーク

行列の乗算のさまざまな可能性の調査が完了したので、どれが最もパフォーマンスが高いかを確認しましょう。

4.1. 小さな行列

小さな行列から始めましょう。 ここでは、3×2と2×4の行列です。

パフォーマンステストを実装するために、JMHベンチマークライブラリを使用します。 次のオプションを使用してベンチマーククラスを構成しましょう。

public static void main(String[] args) throws Exception {
    Options opt = new OptionsBuilder()
      .include(MatrixMultiplicationBenchmarking.class.getSimpleName())
      .mode(Mode.AverageTime)
      .forks(2)
      .warmupIterations(5)
      .measurementIterations(10)
      .timeUnit(TimeUnit.MICROSECONDS)
      .build();

    new Runner(opt).run();
}

このように、JMHは、 @Benchmark で注釈が付けられたメソッドごとに2回の完全実行を行い、それぞれ5回のウォームアップ反復(平均計算には含まれません)と10回の測定反復を行います。 測定値については、さまざまなライブラリの平均実行時間をマイクロ秒単位で収集します。

次に、配列を含む状態オブジェクトを作成する必要があります。

@State(Scope.Benchmark)
public class MatrixProvider {
    private double[][] firstMatrix;
    private double[][] secondMatrix;

    public MatrixProvider() {
        firstMatrix =
          new double[][] {
            new double[] {1d, 5d},
            new double[] {2d, 3d},
            new double[] {1d ,7d}
          };

        secondMatrix =
          new double[][] {
            new double[] {1d, 2d, 3d, 7d},
            new double[] {5d, 2d, 8d, 1d}
          };
    }
}

このようにして、アレイの初期化がベンチマークの一部ではないことを確認します。 その後も、 MatrixProvider オブジェクトをデータソースとして使用して、行列の乗算を行うメソッドを作成する必要があります。 以前に各ライブラリを見たので、ここではコードを繰り返しません。

最後に、mainメソッドを使用してベンチマークプロセスを実行します。 これにより、次の結果が得られます。

Benchmark                                                           Mode  Cnt   Score   Error  Units
MatrixMultiplicationBenchmarking.apacheCommonsMatrixMultiplication  avgt   20   1,008 ± 0,032  us/op
MatrixMultiplicationBenchmarking.coltMatrixMultiplication           avgt   20   0,219 ± 0,014  us/op
MatrixMultiplicationBenchmarking.ejmlMatrixMultiplication           avgt   20   0,226 ± 0,013  us/op
MatrixMultiplicationBenchmarking.homemadeMatrixMultiplication       avgt   20   0,389 ± 0,045  us/op
MatrixMultiplicationBenchmarking.la4jMatrixMultiplication           avgt   20   0,427 ± 0,016  us/op
MatrixMultiplicationBenchmarking.nd4jMatrixMultiplication           avgt   20  12,670 ± 2,582  us/op

ご覧のとおり、 EJMLとColtは、操作あたり約5分の1マイクロ秒で非常に良好に機能していますが、ND4jは、操作あたり10マイクロ秒を少し超えるとパフォーマンスが低下します。 他の図書館はその間に公演があります。

また、ウォームアップの反復回数を5から10に増やすと、すべてのライブラリのパフォーマンスが向上することにも注意してください。

4.2. 大きな行列

さて、3000×3000のようなより大きな行列を取るとどうなりますか? 何が起こるかを確認するために、まず、そのサイズの生成された行列を提供する別の状態クラスを作成しましょう。

@State(Scope.Benchmark)
public class BigMatrixProvider {
    private double[][] firstMatrix;
    private double[][] secondMatrix;

    public BigMatrixProvider() {}

    @Setup
    public void setup(BenchmarkParams parameters) {
        firstMatrix = createMatrix();
        secondMatrix = createMatrix();
    }

    private double[][] createMatrix() {
        Random random = new Random();

        double[][] result = new double[3000][3000];
        for (int row = 0; row < result.length; row++) {
            for (int col = 0; col < result[row].length; col++) {
                result[row][col] = random.nextDouble();
            }
        }
        return result;
    }
}

ご覧のとおり、ランダムな実数で満たされた3000×3000の2次元二重配列を作成します。

次に、ベンチマーククラスを作成しましょう。

public class BigMatrixMultiplicationBenchmarking {
    public static void main(String[] args) throws Exception {
        Map<String, String> parameters = parseParameters(args);

        ChainedOptionsBuilder builder = new OptionsBuilder()
          .include(BigMatrixMultiplicationBenchmarking.class.getSimpleName())
          .mode(Mode.AverageTime)
          .forks(2)
          .warmupIterations(10)
          .measurementIterations(10)
          .timeUnit(TimeUnit.SECONDS);

        new Runner(builder.build()).run();
    }

    @Benchmark
    public Object homemadeMatrixMultiplication(BigMatrixProvider matrixProvider) {
        return HomemadeMatrix
          .multiplyMatrices(matrixProvider.getFirstMatrix(), matrixProvider.getSecondMatrix());
    }

    @Benchmark
    public Object ejmlMatrixMultiplication(BigMatrixProvider matrixProvider) {
        SimpleMatrix firstMatrix = new SimpleMatrix(matrixProvider.getFirstMatrix());
        SimpleMatrix secondMatrix = new SimpleMatrix(matrixProvider.getSecondMatrix());

        return firstMatrix.mult(secondMatrix);
    }

    @Benchmark
    public Object apacheCommonsMatrixMultiplication(BigMatrixProvider matrixProvider) {
        RealMatrix firstMatrix = new Array2DRowRealMatrix(matrixProvider.getFirstMatrix());
        RealMatrix secondMatrix = new Array2DRowRealMatrix(matrixProvider.getSecondMatrix());

        return firstMatrix.multiply(secondMatrix);
    }

    @Benchmark
    public Object la4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
        Matrix firstMatrix = new Basic2DMatrix(matrixProvider.getFirstMatrix());
        Matrix secondMatrix = new Basic2DMatrix(matrixProvider.getSecondMatrix());

        return firstMatrix.multiply(secondMatrix);
    }

    @Benchmark
    public Object nd4jMatrixMultiplication(BigMatrixProvider matrixProvider) {
        INDArray firstMatrix = Nd4j.create(matrixProvider.getFirstMatrix());
        INDArray secondMatrix = Nd4j.create(matrixProvider.getSecondMatrix());

        return firstMatrix.mmul(secondMatrix);
    }

    @Benchmark
    public Object coltMatrixMultiplication(BigMatrixProvider matrixProvider) {
        DoubleFactory2D doubleFactory2D = DoubleFactory2D.dense;

        DoubleMatrix2D firstMatrix = doubleFactory2D.make(matrixProvider.getFirstMatrix());
        DoubleMatrix2D secondMatrix = doubleFactory2D.make(matrixProvider.getSecondMatrix());

        Algebra algebra = new Algebra();
        return algebra.mult(firstMatrix, secondMatrix);
    }
}

このベンチマークを実行すると、まったく異なる結果が得られます。

Benchmark                                                              Mode  Cnt    Score    Error  Units
BigMatrixMultiplicationBenchmarking.apacheCommonsMatrixMultiplication  avgt   20  511.140 ± 13.535   s/op
BigMatrixMultiplicationBenchmarking.coltMatrixMultiplication           avgt   20  197.914 ±  2.453   s/op
BigMatrixMultiplicationBenchmarking.ejmlMatrixMultiplication           avgt   20   25.830 ±  0.059   s/op
BigMatrixMultiplicationBenchmarking.homemadeMatrixMultiplication       avgt   20  497.493 ±  2.121   s/op
BigMatrixMultiplicationBenchmarking.la4jMatrixMultiplication           avgt   20   35.523 ±  0.102   s/op
BigMatrixMultiplicationBenchmarking.nd4jMatrixMultiplication           avgt   20    0.548 ±  0.006   s/op

ご覧のとおり、自家製の実装とApacheライブラリは以前よりもはるかに悪くなり、2つの行列の乗算を実行するのに10分近くかかります。

コルトは3分より少し長くかかっています。これはより良いですが、それでも非常に長いです。 EJMLとLA4Jは、30秒近くで実行されるため、かなり良好に機能しています。 しかし、CPUバックエンドで1秒未満で実行されるこのベンチマークに勝つのはND4Jです。

4.3. 分析

これは、ベンチマークの結果が実際にはマトリックスの特性に依存していることを示しているため、1人の勝者を指摘するのは難しいことです。

5. 結論

この記事では、Javaで、自分自身または外部ライブラリを使用して行列を乗算する方法を学習しました。 すべてのソリューションを調査した後、それらすべてのベンチマークを実行し、ND4Jを除いて、すべてが小さなマトリックスで非常に良好に機能することを確認しました。 一方、より大きなマトリックスでは、ND4Jが主導権を握っています。

いつものように、この記事の完全なコードはGitHubにあります。