Javaでのロジスティック回帰
1. 序章
ロジスティック回帰は、機械学習(ML)実践者ツールボックスの重要な手段です。
このチュートリアルでは、ロジスティック回帰の背後にある主なアイデアを探ります。
まず、MLのパラダイムとアルゴリズムの概要から始めましょう。
2. 概要
MLを使用すると、人間にわかりやすい言葉で定式化できる問題を解決できます。 ただし、この事実は、ソフトウェア開発者にとって課題となる可能性があります。 私たちは、コンピューターに優しい用語で定式化できる問題に対処することに慣れています。 たとえば、人間として、写真上のオブジェクトを簡単に検出したり、フレーズのムードを確立したりできます。 このような問題をコンピューターでどのように定式化できるでしょうか。
解決策を考え出すために、MLにはトレーニングと呼ばれる特別なステージがあります。 この段階では、入力データをアルゴリズムにフィードして、最適なパラメーターのセット(いわゆる重み)を考え出そうとします。 アルゴリズムに入力するデータが多ければ多いほど、アルゴリズムから期待できる予測はより正確になります。
トレーニングは、反復的なMLワークフローの一部です。
まず、データの取得から始めます。 多くの場合、データはさまざまなソースから取得されます。 したがって、同じ形式にする必要があります。 データセットが研究領域を公正に表すことも制御する必要があります。 モデルが赤いリンゴで訓練されたことがない場合、それを予測することはほとんどできません。
次に、データを消費し、予測を行うことができるモデルを構築する必要があります。 MLには、すべての状況で適切に機能する事前定義されたモデルはありません。
正しいモデルを検索するとき、モデルを作成してトレーニングし、その予測を確認して、モデルの予測に満足できないためにモデルを破棄することが簡単に発生する可能性があります。 この場合、一歩下がって別のモデルを作成し、プロセスをもう一度繰り返す必要があります。
3. MLパラダイム
MLでは、自由に使用できる入力データの種類に基づいて、次の3つの主要なパラダイムを選択できます。
- 教師あり学習(画像分類、オブジェクト認識、感情分析)
- 教師なし学習(異常検出)
- 強化学習(ゲーム戦略)
このチュートリアルでについて説明するケースは、教師あり学習に属します。
4. MLツールボックス
MLには、モデルを構築するときに適用できる一連のツールがあります。 それらのいくつかに言及しましょう:
- 線形回帰
- ロジスティック回帰
- ニューラルネットワーク
- サポートベクターマシン
- k-最近傍
予測性の高いモデルを構築する際に、いくつかのツールを組み合わせる場合があります。実際、このチュートリアルでは、モデルはロジスティック回帰とニューラルネットワークを使用します。
5. MLライブラリ
JavaはMLモデルのプロトタイピングに最も人気のある言語ではありませんが、 は、MLを含む多くの分野で堅牢なソフトウェアを作成するための信頼できるツールとしての評判があります。 したがって、Javaで記述されたMLライブラリが見つかる場合があります。
これに関連して、JavaバージョンもあるデファクトスタンダードライブラリTensorflowについて言及する場合があります。 もう1つ言及する価値があるのは、Deeplearning4jと呼ばれるディープラーニングライブラリです。 これは非常に強力なツールであり、このチュートリアルでも使用します。
6. 数字認識のロジスティック回帰
ロジスティック回帰の主なアイデアは、入力データのラベルを可能な限り正確に予測するモデルを構築することです。
いわゆる損失関数または目的関数が最小値に達するまで、モデルをトレーニングします。 損失関数は、実際のモデルの予測と予想される予測(入力データのラベル)によって異なります。 私たちの目標は、実際のモデル予測と予想される予測の相違を最小限に抑えることです。
その最小値に満足できない場合は、別のモデルを作成してトレーニングを再実行する必要があります。
ロジスティック回帰の動作を確認するために、手書き数字の認識について説明します。 この問題はすでに古典的な問題になっています。 Deeplearning4jライブラリには、APIの使用方法を示す一連の現実的な例があります。 このチュートリアルのコード関連の部分は、MNIST分類子に大きく基づいています。
6.1. 入力データ
入力データとして、よく知られている手書き数字のMNISTデータベースを使用します。 入力データとして、28×28ピクセルのグレースケール画像があります。 各画像には、画像が表す数字である自然なラベルがあります。
構築するモデルの効率を推定するために、入力データをトレーニングセットとテストセットに分割します。
DataSetIterator train = new RecordReaderDataSetIterator(...);
DataSetIterator test = new RecordReaderDataSetIterator(...);
入力画像にラベルを付けて2つのセットに分割すると、「データの作成」段階が終了し、「モデル構築」に進むことができます。
6.2. モデル構築
すでに述べたように、あらゆる状況でうまく機能するモデルはありません。 それにもかかわらず、MLでの長年の研究の後、科学者は手書きの数字を認識するのに非常にうまく機能するモデルを発見しました。 ここでは、いわゆるLeNet-5モデルを使用します。
LeNet-5は、28×28ピクセルの画像を10次元のベクトルに変換する一連のレイヤーで構成されるニューラルネットワークです。
10次元の出力ベクトルには、入力画像のラベルが0、1、または2のいずれかである確率が含まれています。
たとえば、出力ベクトルの形式が次の場合:
{0.1, 0.0, 0.3, 0.2, 0.1, 0.1, 0.0, 0.1, 0.1, 0.0}
これは、入力画像がゼロになる確率が0.1、1が0、2になる確率が0.3などであることを意味します。 最大確率(0.3)がラベル3に対応していることがわかります。
モデル構築の詳細を詳しく見ていきましょう。 Java固有の詳細は省略し、MLの概念に集中します。
MultiLayerNetwork オブジェクトを作成して、モデルを設定します。
MultiLayerNetwork model = new MultiLayerNetwork(config);
そのコンストラクターで、MultiLayerConfigurationオブジェクトを渡す必要があります。 これは、ニューラルネットワークのジオメトリを記述するまさにオブジェクトです。 ネットワークジオメトリを定義するには、すべてのレイヤーを定義する必要があります。
1つ目と2つ目でこれを行う方法を示しましょう。
ConvolutionLayer layer1 = new ConvolutionLayer
.Builder(5, 5).nIn(channels)
.stride(1, 1)
.nOut(20)
.activation(Activation.IDENTITY)
.build();
SubsamplingLayer layer2 = new SubsamplingLayer
.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build();
レイヤーの定義には、ネットワーク全体のパフォーマンスに大きな影響を与えるかなりの量のアドホックパラメーターが含まれていることがわかります。 これこそが、すべての人の風景の中で良いモデルを見つける能力が重要になるところです。
これで、MultiLayerConfigurationオブジェクトを作成する準備が整いました。
MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
// preparation steps
.list()
.layer(layer1)
.layer(layer2)
// other layers and final steps
.build();
MultiLayerNetworkコンストラクターに渡します。
6.3. トレーニング
構築したモデルには、431080のパラメーターまたは重みが含まれています。 ここでは、この数値の正確な計算については説明しませんが、最初のレイヤーだけで24x24x20=11520以上の重みがあることに注意してください。
トレーニング段階は次のように簡単です。
model.fit(train);
当初、431080パラメータにはいくつかのランダムな値がありますが、トレーニング後に、モデルのパフォーマンスを決定するいくつかの値を取得します。 モデルの予測性を評価する場合があります。
Evaluation eval = model.evaluate(test);
logger.info(eval.stats());
LeNet-5モデルは、1回のトレーニング反復(エポック)でほぼ99% eの非常に高い精度を達成します。 より高い精度を実現したい場合は、プレーンなfor-loopを使用してより多くの反復を行う必要があります。
for (int i = 0; i < epochs; i++) {
model.fit(train);
train.reset();
test.reset();
}
6.4. 予測
これで、モデルをトレーニングし、テストデータでの予測に満足したので、まったく新しい入力でモデルを試すことができます。 この目的のために、ファイルシステムから選択したファイルから画像をロードする新しいクラスMnistPredictionを作成しましょう。
INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
new ImagePreProcessingScaler(0, 1).transform(image);
変数imageには、28×28グレースケールに縮小された画像が含まれています。 モデルにフィードできます。
INDArray output = model.output(image);
変数outputには、画像が0、1、2などになる確率が含まれます。
少し遊んで、数字2を書いて、この画像をデジタル化してモデルにフィードしてみましょう。 次のようなものが得られる可能性があります。
ご覧のとおり、最大値が0.99のコンポーネントのインデックスは2です。 これは、モデルが手書きの数字を正しく認識したことを意味します。
7. 結論
このチュートリアルでは、機械学習の一般的な概念について説明しました。 これらの概念を、手書き数字認識に適用したロジスティック回帰の例で説明しました。
いつものように、対応するコードスニペットはGitHubリポジトリにあります。