TensorflowforJavaの概要
1. 概要
TensorFlow は、データフロープログラミング用のオープンソースライブラリです。 これは元々Googleによって開発されたもので、さまざまなプラットフォームで利用できます。 TensorFlowはシングルコアで動作できますが、利用可能な複数のCPU、GPU、またはTPUの恩恵を受けることができます。
このチュートリアルでは、TensorFlowの基本と、Javaでの使用方法について説明します。 TensorFlow Java APIは実験的なAPIであるため、安定性の保証は受けられないことに注意してください。 チュートリアルの後半で、TensorFlowJavaAPIを使用するための可能なユースケースについて説明します。
2. 基本
TensorFlowの計算は、基本的に2つの基本的な概念であるグラフとセッションを中心に展開されます。 チュートリアルの残りの部分を実行するために必要な背景を取得するために、それらをすばやく確認してみましょう。
2.1. TensorFlowグラフ
まず、TensorFlowプログラムの基本的な構成要素を理解しましょう。 計算はTensorFlowでグラフとして表されます。 グラフは通常、操作とデータの有向非巡回グラフです。次に例を示します。
上の図は、次の方程式の計算グラフを表しています。
f(x, y) = z = a*x + b*y
TensorFlow計算グラフは、次の2つの要素で構成されています。
- Tensor:これらはTensorFlowのデータのコアユニットです。これらは計算グラフのエッジとして表され、グラフを通過するデータの流れを示します。 テンソルは、任意の数の次元を持つ形状を持つことができます。 テンソルの次元数は通常、そのランクと呼ばれます。 したがって、スカラーはランク0のテンソル、ベクトルはランク1のテンソル、行列はランク2のテンソルなどです。
- 操作:これらは計算グラフのノードです。これらは、操作にフィードするテンソルで発生する可能性のあるさまざまな計算を指します。 多くの場合、計算グラフの演算から発生するテンソルも発生します。
2.2. TensorFlowセッション
現在、TensorFlowグラフは、実際には値を保持しない計算の単なる概略図です。 このようなグラフは、評価されるグラフのテンソルに対して、いわゆるTensorFlowセッション内で実行する必要があります。 セッションは、入力パラメーターとしてグラフから評価するためにテンソルの束を取ることができます。 次に、グラフ内で逆方向に実行され、それらのテンソルを評価するために必要なすべてのノードが実行されます。
この知識があれば、これをJavaAPIに適用する準備が整います。
3. Mavenのセットアップ
JavaでTensorFlowグラフを作成して実行するための簡単なMavenプロジェクトを設定します。 tensorflow依存関係が必要です。
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.12.0</version>
</dependency>
4. グラフの作成
ここで、TensorFlow Java APIを使用して、前のセクションで説明したグラフを作成してみましょう。 より正確には、このチュートリアルでは、TensorFlow Java APIを使用して、次の方程式で表される関数を解きます。
z = 3*x + 2*y
最初のステップは、グラフを宣言して初期化することです。
Graph graph = new Graph()
ここで、必要なすべての操作を定義する必要があります。 TensorFlowの操作は、0個以上のテンソルを消費および生成することに注意してください。 さらに、グラフ内のすべてのノードは、定数とプレースホルダーを含む操作です。 これは直感に反するように思えるかもしれませんが、しばらくは我慢してください。
クラスGraphには、TensorFlowであらゆる種類の操作を構築するための opBuilder()と呼ばれる汎用関数があります。
4.1. 定数の定義
まず、上のグラフで定数演算を定義しましょう。 定数操作では、その値に対してテンソルが必要になることに注意してください。
Operation a = graph.opBuilder("Const", "a")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(3.0, Double.class))
.build();
Operation b = graph.opBuilder("Const", "b")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.<Double>create(2.0, Double.class))
.build();
ここでは、定数タイプの Operation を定義し、[X97X]TensorにDouble値2.0および3.0を供給します。 そもそも少し圧倒されるように思えるかもしれませんが、それが今のところJavaAPIにある方法です。 これらの構造は、Pythonなどの言語でははるかに簡潔です。
4.2. プレースホルダーの定義
定数に値を指定する必要がありますが、プレースホルダーはdefinition-timeでは値を必要としません。 グラフをセッション内で実行する場合は、プレースホルダーに値を指定する必要があります。 その部分については、チュートリアルの後半で説明します。
今のところ、プレースホルダーをどのように定義できるか見てみましょう。
Operation x = graph.opBuilder("Placeholder", "x")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
Operation y = graph.opBuilder("Placeholder", "y")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
プレースホルダーに値を提供する必要がないことに注意してください。 これらの値は、実行時にTensorsとして提供されます。
4.3. 関数の定義
最後に、方程式の数学演算、つまり結果を得るための乗算と加算を定義する必要があります。
これらもTensorFlowのOperationに過ぎず、 Graph.opBuilder()はもう一度便利です。
Operation ax = graph.opBuilder("Mul", "ax")
.addInput(a.output(0))
.addInput(x.output(0))
.build();
Operation by = graph.opBuilder("Mul", "by")
.addInput(b.output(0))
.addInput(y.output(0))
.build();
Operation z = graph.opBuilder("Add", "z")
.addInput(ax.output(0))
.addInput(by.output(0))
.build();
ここでは、 Operation を定義しました。2つは入力を乗算するためのもので、最後の1つは中間結果を合計するためのものです。 ここでの操作は、以前の操作の出力に他ならないテンソルを受け取ることに注意してください。
インデックス「0」を使用して、Operationから出力Tensorを取得していることに注意してください。 前に説明したように、操作は1つ以上のTensor をもたらす可能性があるため、そのハンドルを取得するときに、インデックスについて言及する必要があります。 操作が1つのTensorのみを返すことがわかっているので、「0」は問題なく機能します。
5. グラフの視覚化
グラフのサイズが大きくなるにつれて、グラフのタブを維持することは困難です。 これにより、何らかの方法で視覚化することが重要になります。 以前に作成した小さなグラフのようにいつでも手描きを作成できますが、大きなグラフには実用的ではありません。 TensorFlowは、これを容易にするTensorBoardと呼ばれるユーティリティを提供します。
残念ながら、Java APIには、TensorBoardによって消費されるイベントファイルを生成する機能がありません。 ただし、PythonでAPIを使用すると、次のようなイベントファイルを生成できます。
writer = tf.summary.FileWriter('.')
......
writer.add_graph(tf.get_default_graph())
writer.flush()
これがJavaのコンテキストで意味をなさない場合でも、気にしないでください。これは、完全を期すためにここに追加されており、チュートリアルの残りの部分を続行する必要はありません。
これで、次のようにTensorBoardにイベントファイルをロードして視覚化できます。
tensorboard --logdir .
TensorBoardは、TensorFlowインストールの一部として提供されます。
これと以前に手動で描画したグラフとの類似性に注意してください。
6. セッションでの作業
これで、TensorFlowJavaAPIで簡単な方程式の計算グラフが作成されました。 しかし、どのように実行しますか? これに対処する前に、この時点で作成したGraphの状態を見てみましょう。 最終的な操作「z」の出力を印刷しようとすると、次のようになります。
System.out.println(z.output(0));
これにより、次のようになります。
<Add 'z:0' shape=<unknown> dtype=DOUBLE>
これは私たちが期待したものではありません! しかし、前に説明したことを思い出すと、これは実際には理にかなっています。 定義したグラフはまだ実行されていないため、その中のテンソルは実際には実際の値を保持していません。上記の出力は、これがタイプテンソルになることを示しています。 X195X]ダブル。
次に、セッションを定義して、グラフを実行します。
Session sess = new Session(graph)
これで、グラフを実行して、期待していた出力を取得する準備が整いました。
Tensor<Double> tensor = sess.runner().fetch("z")
.feed("x", Tensor.<Double>create(3.0, Double.class))
.feed("y", Tensor.<Double>create(6.0, Double.class))
.run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());
では、ここで何をしているのでしょうか。 かなり直感的である必要があります。
- セッションからランナーを取得します
- 名前「z」でフェッチするOperationを定義します
- プレースホルダー「x」と「y」のテンソルをフィードします
- セッションでグラフを実行します
そして今、スカラー出力が表示されます。
21.0
これは私たちが期待していたことですよね!
7. JavaAPIのユースケース
この時点で、TensorFlowは基本的な操作を実行するにはやり過ぎのように聞こえるかもしれません。 ただし、もちろん、 TensorFlowは、これよりもはるかに大きいグラフを実行することを目的としています。
さらに、実世界のモデルで扱うテンソルは、サイズとランクがはるかに大きくなります。 これらは、TensorFlowが実際に使用される実際の機械学習モデルです。
グラフのサイズが大きくなると、TensorFlowでコアAPIを操作するのが非常に面倒になる可能性があることを理解するのは難しくありません。 この目的のために、 TensorFlowは、複雑なモデルで動作するKerasのような高レベルのAPIを提供します。 残念ながら、JavaでのKerasの公式サポートはまだほとんどまたはまったくありません。
ただし、 Pythonを使用して、TensorFlowで直接、またはKerasなどの高レベルAPIを使用して、複雑なモデルを定義およびトレーニングできます。 その後、トレーニング済みモデルをエクスポートし、TensorFlowJavaAPIを使用してJavaで使用できます。
さて、なぜ私たちはそのようなことをしたいのですか? これは、Javaで実行されている既存のクライアントで機械学習対応の機能を使用する場合に特に便利です。 たとえば、Androidデバイスのユーザー画像にキャプションを推奨します。 それでも、機械学習モデルの出力に関心があるが、必ずしもJavaでそのモデルを作成してトレーニングしたくない場合がいくつかあります。
これは、TensorFlowJavaAPIがその使用の大部分を見つける場所です。 次のセクションでは、これをどのように実現できるかについて説明します。
8. 保存されたモデルの使用
これで、TensorFlowのモデルをファイルシステムに保存し、それを完全に異なる言語とプラットフォームでロードする方法を理解できます。 TensorFlowは、ProtocolBufferと呼ばれる言語およびプラットフォームに依存しない構造でモデルファイルを生成するためのAPIを提供します。
8.1. モデルをファイルシステムに保存する
まず、Pythonで以前に作成したものと同じグラフを定義し、それをファイルシステムに保存します。
Pythonでこれを実行できることを見てみましょう。
import tensorflow as tf
graph = tf.Graph()
builder = tf.saved_model.builder.SavedModelBuilder('./model')
with graph.as_default():
a = tf.constant(2, name='a')
b = tf.constant(3, name='b')
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
z = tf.math.add(a*x, b*y, name='z')
sess = tf.Session()
sess.run(z, feed_dict = {x: 2, y: 3})
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
builder.save()
Javaでのこのチュートリアルの焦点として、「saved_model.pb」というファイルを生成するという事実を除いて、Pythonでのこのコードの詳細にはあまり注意を払わないでください。 Javaと比較して同様のグラフを定義する際に簡潔にすることに注意してください!
8.2. ファイルシステムからのモデルのロード
ここで、「saved_model.pb」をJavaにロードします。 Java TensorFlow APIには、保存されたモデルを操作するためのSavedModelBundleがあります。
SavedModelBundle model = SavedModelBundle.load("./model", "serve");
Tensor<Integer> tensor = model.session().runner().fetch("z")
.feed("x", Tensor.<Integer>create(3, Integer.class))
.feed("y", Tensor.<Integer>create(3, Integer.class))
.run().get(0).expect(Integer.class);
System.out.println(tensor.intValue());
これで、上記のコードが何をしているのかを理解するのはかなり直感的になるはずです。 プロトコルバッファからモデルグラフをロードし、その中のセッションを利用できるようにするだけです。 そこから先は、ローカルで定義されたグラフの場合と同じように、このグラフでほとんど何でもできます。
9. 結論
要約すると、このチュートリアルでは、TensorFlow計算グラフに関連する基本的な概念について説明しました。 TensorFlow Java APIを使用して、このようなグラフを作成して実行する方法を見てきました。 次に、TensorFlowに関するJavaAPIのユースケースについて説明しました。
その過程で、TensorBoardを使用してグラフを視覚化し、ProtocolBufferを使用してモデルを保存および再ロードする方法も理解しました。
いつものように、例のコードはGitHubでから入手できます。