JavaのK-Meansクラスタリングアルゴリズム
1. 概要
クラスタリングは、互いに密接に関連しているもの、人、またはアイデアのグループを発見するための教師なしアルゴリズムのクラスの総称です。
この一見単純なワンライナーの定義では、いくつかの流行語が見られました。 クラスタリングとは正確には何ですか? 教師なしアルゴリズムとは何ですか?
このチュートリアルでは、最初に、これらの概念にいくつかの光を当てます。 次に、それらがJavaでどのように現れるかを見ていきます。
2. 教師なしアルゴリズム
ほとんどの学習アルゴリズムを使用する前に、何らかのサンプルデータをそれらにフィードし、アルゴリズムがそれらのデータから学習できるようにする必要があります。 機械学習の用語では、
ともかく、
- 教師あり学習:教師ありアルゴリズムでは、トレーニングデータに各ポイントの実際のソリューションを含める必要があります。 たとえば、スパムフィルタリングアルゴリズムをトレーニングしようとしている場合は、サンプルメールとそのラベルの両方をフィードします。 アルゴリズムにスパムまたは非スパム。 数学的に言えば、 xsとysの両方を含むトレーニングセットからf(x)を推測します。
- 教師なし学習:トレーニングデータにラベルがない場合、アルゴリズムは教師なし学習です。 たとえば、ミュージシャンに関するデータがたくさんあり、データ内に類似したミュージシャンのグループを見つけます。
3. クラスタリング
クラスタリングは、類似したもの、アイデア、または人々のグループを発見するための教師なしアルゴリズムです。 教師ありアルゴリズムとは異なり、既知のラベルの例を使用してクラスタリングアルゴリズムをトレーニングしていません。 代わりに、クラスタリングは、データのポイントがラベルではないトレーニングセット内の構造を見つけようとします。
3.1. K-Meansクラスタリング
K-Meansは、1つの基本的なプロパティを持つクラスタリングアルゴリズムです。クラスターの数は事前に定義されています。 K-Meansに加えて、階層的クラスタリング、アフィニティ伝搬、スペクトルクラスタリングなどの他のタイプのクラスタリングアルゴリズムがあります。
3.2. K-Meansのしくみ
私たちの目標が、次のようなデータセット内のいくつかの類似したグループを見つけることであると仮定します。
K-Meansは、ランダムに配置されたk個の重心から始まります。 重心は、その名前が示すように、クラスターの中心点です。 たとえば、ここでは4つのランダムな重心を追加しています。
次に、既存の各データポイントを最も近い重心に割り当てます。
割り当て後、重心を割り当てられたポイントの平均位置に移動します。 重心はクラスターの中心点であると想定されていることを忘れないでください。
現在の反復は、図心の再配置が完了するたびに終了します。 複数の連続する反復間の割り当ての変更が停止するまで、これらの反復を繰り返します。
アルゴリズムが終了すると、これらの4つのクラスターが期待どおりに検出されます。 K-Meansがどのように機能するかがわかったので、Javaで実装してみましょう。
3.3. 機能表現
さまざまなトレーニングデータセットをモデル化する場合、モデルの属性とそれに対応する値を表すデータ構造が必要です。 たとえば、ミュージシャンはRockのような値を持つジャンル属性を持つことができます。通常、属性とその値の組み合わせを指すために機能という用語を使用します。
特定の学習アルゴリズムのデータセットを準備するために、通常、さまざまな項目を比較するために使用できる数値属性の共通セットを使用します。 たとえば、ユーザーに各アーティストにジャンルのタグを付けるようにすると、1日の終わりに、各アーティストに特定のジャンルのタグが付けられた回数を数えることができます。
リンキンパークのようなアーティストの特徴ベクトルは
数値ベクトルは非常に用途の広いデータ構造であるため、それらを使用して特徴を表現します。
public class Record {
private final String description;
private final Map<String, Double> features;
// constructor, getter, toString, equals and hashcode
}
3.4. 類似アイテムの検索
K-Meansの各反復では、データセット内の各アイテムに最も近い重心を見つける方法が必要です。 2つの特徴ベクトル間の距離を計算する最も簡単な方法の1つは、ユークリッド距離を使用することです。 [p1、q1]と[p2、q2]のような2つのベクトル間のユークリッド距離は次のようになります。
この関数をJavaに実装しましょう。 まず、抽象化:
public interface Distance {
double calculate(Map<String, Double> f1, Map<String, Double> f2);
}
ユークリッド距離に加えて、ピアソン相関係数のような異なるアイテム間の距離または類似性を計算する他のアプローチがあります。 この抽象化により、異なる距離メトリックを簡単に切り替えることができます。
ユークリッド距離の実装を見てみましょう。
public class EuclideanDistance implements Distance {
@Override
public double calculate(Map<String, Double> f1, Map<String, Double> f2) {
double sum = 0;
for (String key : f1.keySet()) {
Double v1 = f1.get(key);
Double v2 = f2.get(key);
if (v1 != null && v2 != null) {
sum += Math.pow(v1 - v2, 2);
}
}
return Math.sqrt(sum);
}
}
まず、対応するエントリ間の差の2乗の合計を計算します。 次に、 sqrt 関数を適用して、実際のユークリッド距離を計算します。
3.5. 図心表現
図心は通常のフィーチャと同じスペースにあるため、フィーチャと同様に表すことができます。
public class Centroid {
private final Map<String, Double> coordinates;
// constructors, getter, toString, equals and hashcode
}
必要な抽象化がいくつか整ったので、K-Means実装を作成します。 メソッドシグネチャを簡単に見てみましょう。
public class KMeans {
private static final Random random = new Random();
public static Map<Centroid, List<Record>> fit(List<Record> records,
int k,
Distance distance,
int maxIterations) {
// omitted
}
}
このメソッドシグネチャを分解してみましょう。
- データセットは、特徴ベクトルのセットです。 各特徴ベクトルは記録、 その場合、データセットタイプはリスト
- k パラメーターは、クラスターの数を決定します。クラスターの数は、事前に提供する必要があります。
- 距離は、2つの機能の差を計算する方法をカプセル化します
- K-Meansは、割り当てが数回の連続した反復で変更を停止すると終了します。 この終了条件に加えて、反復回数の上限を設定することもできます。 maxIterations 引数は、上限を決定します
- K-Meansが終了すると、各重心にいくつかの機能が割り当てられるはずなので、 地図
>> リターンタイプとして。 基本的に、各マップエントリはクラスターに対応します
3.6. セントロイド生成
最初のステップは、kランダムに配置された重心を生成することです。
各重心には完全にランダムな座標を含めることができますが、各属性の可能な最小値と最大値の間でランダムな座標を生成することをお勧めします。 可能な値の範囲を考慮せずにランダムな重心を生成すると、アルゴリズムの収束が遅くなります。
まず、各属性の最小値と最大値を計算してから、それらの各ペア間でランダムな値を生成する必要があります。
private static List<Centroid> randomCentroids(List<Record> records, int k) {
List<Centroid> centroids = new ArrayList<>();
Map<String, Double> maxs = new HashMap<>();
Map<String, Double> mins = new HashMap<>();
for (Record record : records) {
record.getFeatures().forEach((key, value) -> {
// compares the value with the current max and choose the bigger value between them
maxs.compute(key, (k1, max) -> max == null || value > max ? value : max);
// compare the value with the current min and choose the smaller value between them
mins.compute(key, (k1, min) -> min == null || value < min ? value : min);
});
}
Set<String> attributes = records.stream()
.flatMap(e -> e.getFeatures().keySet().stream())
.collect(toSet());
for (int i = 0; i < k; i++) {
Map<String, Double> coordinates = new HashMap<>();
for (String attribute : attributes) {
double max = maxs.get(attribute);
double min = mins.get(attribute);
coordinates.put(attribute, random.nextDouble() * (max - min) + min);
}
centroids.add(new Centroid(coordinates));
}
return centroids;
}
これで、各レコードをこれらのランダムな重心の1つに割り当てることができます。
3.7. 割り当て
まず、 Record が与えられた場合、それに最も近い重心を見つける必要があります。
private static Centroid nearestCentroid(Record record, List<Centroid> centroids, Distance distance) {
double minimumDistance = Double.MAX_VALUE;
Centroid nearest = null;
for (Centroid centroid : centroids) {
double currentDistance = distance.calculate(record.getFeatures(), centroid.getCoordinates());
if (currentDistance < minimumDistance) {
minimumDistance = currentDistance;
nearest = centroid;
}
}
return nearest;
}
各レコードは、最も近い重心クラスターに属します。
private static void assignToCluster(Map<Centroid, List<Record>> clusters,
Record record,
Centroid centroid) {
clusters.compute(centroid, (key, list) -> {
if (list == null) {
list = new ArrayList<>();
}
list.add(record);
return list;
});
}
3.8. セントロイドの再配置
1回の反復の後、図心に割り当てが含まれていない場合、図心は再配置されません。 それ以外の場合は、各属性の重心座標を、割り当てられたすべてのレコードの平均位置に再配置する必要があります。
private static Centroid average(Centroid centroid, List<Record> records) {
if (records == null || records.isEmpty()) {
return centroid;
}
Map<String, Double> average = centroid.getCoordinates();
records.stream().flatMap(e -> e.getFeatures().keySet().stream())
.forEach(k -> average.put(k, 0.0));
for (Record record : records) {
record.getFeatures().forEach(
(k, v) -> average.compute(k, (k1, currentValue) -> v + currentValue)
);
}
average.forEach((k, v) -> average.put(k, v / records.size()));
return new Centroid(average);
}
単一の図心を再配置できるため、relocateCentroidsメソッドを実装できるようになりました。
private static List<Centroid> relocateCentroids(Map<Centroid, List<Record>> clusters) {
return clusters.entrySet().stream().map(e -> average(e.getKey(), e.getValue())).collect(toList());
}
この単純なワンライナーは、すべての図心を繰り返し処理し、それらを再配置して、新しい図心を返します。
3.9. すべてを一緒に入れて
各反復では、すべてのレコードを最も近い重心に割り当てた後、最初に、現在の割り当てを最後の反復と比較する必要があります。
割り当てが同一の場合、アルゴリズムは終了します。 それ以外の場合は、次の反復にジャンプする前に、図心を再配置する必要があります。
public static Map<Centroid, List<Record>> fit(List<Record> records,
int k,
Distance distance,
int maxIterations) {
List<Centroid> centroids = randomCentroids(records, k);
Map<Centroid, List<Record>> clusters = new HashMap<>();
Map<Centroid, List<Record>> lastState = new HashMap<>();
// iterate for a pre-defined number of times
for (int i = 0; i < maxIterations; i++) {
boolean isLastIteration = i == maxIterations - 1;
// in each iteration we should find the nearest centroid for each record
for (Record record : records) {
Centroid centroid = nearestCentroid(record, centroids, distance);
assignToCluster(clusters, record, centroid);
}
// if the assignments do not change, then the algorithm terminates
boolean shouldTerminate = isLastIteration || clusters.equals(lastState);
lastState = clusters;
if (shouldTerminate) {
break;
}
// at the end of each iteration we should relocate the centroids
centroids = relocateCentroids(clusters);
clusters = new HashMap<>();
}
return lastState;
}
4. 例:Last.fmで類似のアーティストを見つける
Last.fmは、ユーザーが聴いているものの詳細を記録することにより、各ユーザーの音楽の好みの詳細なプロファイルを作成します。 このセクションでは、類似したアーティストのクラスターを見つけます。 このタスクに適したデータセットを構築するために、Last.fmの3つのAPIを使用します。
- Last.fmでトップアーティストのコレクションを取得するためのAPI。
- 人気のタグを見つけるための別のAPI。 各ユーザーは、アーティストに何かをタグ付けできます。
石。 したがって、Last.fmはそれらのタグとその頻度のデータベースを維持しています。 - また、アーティストのトップタグを人気順に取得するためのAPI。 このようなタグは多数あるため、上位のグローバルタグに含まれるタグのみを保持します。
4.1. Last.fmのAPI
これらのAPIを使用するには、Last.fmからAPIキーを取得し、すべてのHTTPリクエストで送信する必要があります。 これらのAPIを呼び出すために、次のRetrofitサービスを使用します。
public interface LastFmService {
@GET("/2.0/?method=chart.gettopartists&format=json&limit=50")
Call<Artists> topArtists(@Query("page") int page);
@GET("/2.0/?method=artist.gettoptags&format=json&limit=20&autocorrect=1")
Call<Tags> topTagsFor(@Query("artist") String artist);
@GET("/2.0/?method=chart.gettoptags&format=json&limit=100")
Call<TopTags> topTags();
// A few DTOs and one interceptor
}
それでは、Last.fmで最も人気のあるアーティストを見つけましょう:
// setting up the Retrofit service
private static List<String> getTop100Artists() throws IOException {
List<String> artists = new ArrayList<>();
// Fetching the first two pages, each containing 50 records.
for (int i = 1; i <= 2; i++) {
artists.addAll(lastFm.topArtists(i).execute().body().all());
}
return artists;
}
同様に、上位のタグを取得できます。
private static Set<String> getTop100Tags() throws IOException {
return lastFm.topTags().execute().body().all();
}
最後に、アーティストのデータセットとそのタグ頻度を作成できます。
private static List<Record> datasetWithTaggedArtists(List<String> artists,
Set<String> topTags) throws IOException {
List<Record> records = new ArrayList<>();
for (String artist : artists) {
Map<String, Double> tags = lastFm.topTagsFor(artist).execute().body().all();
// Only keep popular tags.
tags.entrySet().removeIf(e -> !topTags.contains(e.getKey()));
records.add(new Record(artist, tags));
}
return records;
}
4.2. アーティストクラスターの形成
これで、準備したデータセットをK-Means実装にフィードできます。
List<String> artists = getTop100Artists();
Set<String> topTags = getTop100Tags();
List<Record> records = datasetWithTaggedArtists(artists, topTags);
Map<Centroid, List<Record>> clusters = KMeans.fit(records, 7, new EuclideanDistance(), 1000);
// Printing the cluster configuration
clusters.forEach((key, value) -> {
System.out.println("-------------------------- CLUSTER ----------------------------");
// Sorting the coordinates to see the most significant tags first.
System.out.println(sortedCentroid(key));
String members = String.join(", ", value.stream().map(Record::getDescription).collect(toSet()));
System.out.print(members);
System.out.println();
System.out.println();
});
このコードを実行すると、クラスターがテキスト出力として視覚化されます。
------------------------------ CLUSTER -----------------------------------
Centroid {classic rock=65.58333333333333, rock=64.41666666666667, british=20.333333333333332, ... }
David Bowie, Led Zeppelin, Pink Floyd, System of a Down, Queen, blink-182, The Rolling Stones, Metallica,
Fleetwood Mac, The Beatles, Elton John, The Clash
------------------------------ CLUSTER -----------------------------------
Centroid {Hip-Hop=97.21428571428571, rap=64.85714285714286, hip hop=29.285714285714285, ... }
Kanye West, Post Malone, Childish Gambino, Lil Nas X, A$AP Rocky, Lizzo, xxxtentacion,
Travi$ Scott, Tyler, the Creator, Eminem, Frank Ocean, Kendrick Lamar, Nicki Minaj, Drake
------------------------------ CLUSTER -----------------------------------
Centroid {indie rock=54.0, rock=52.0, Psychedelic Rock=51.0, psychedelic=47.0, ... }
Tame Impala, The Black Keys
------------------------------ CLUSTER -----------------------------------
Centroid {pop=81.96428571428571, female vocalists=41.285714285714285, indie=22.785714285714285, ... }
Ed Sheeran, Taylor Swift, Rihanna, Miley Cyrus, Billie Eilish, Lorde, Ellie Goulding, Bruno Mars,
Katy Perry, Khalid, Ariana Grande, Bon Iver, Dua Lipa, Beyoncé, Sia, P!nk, Sam Smith, Shawn Mendes,
Mark Ronson, Michael Jackson, Halsey, Lana Del Rey, Carly Rae Jepsen, Britney Spears, Madonna,
Adele, Lady Gaga, Jonas Brothers
------------------------------ CLUSTER -----------------------------------
Centroid {indie=95.23076923076923, alternative=70.61538461538461, indie rock=64.46153846153847, ... }
Twenty One Pilots, The Smiths, Florence + the Machine, Two Door Cinema Club, The 1975, Imagine Dragons,
The Killers, Vampire Weekend, Foster the People, The Strokes, Cage the Elephant, Arcade Fire,
Arctic Monkeys
------------------------------ CLUSTER -----------------------------------
Centroid {electronic=91.6923076923077, House=39.46153846153846, dance=38.0, ... }
Charli XCX, The Weeknd, Daft Punk, Calvin Harris, MGMT, Martin Garrix, Depeche Mode, The Chainsmokers,
Avicii, Kygo, Marshmello, David Guetta, Major Lazer
------------------------------ CLUSTER -----------------------------------
Centroid {rock=87.38888888888889, alternative=72.11111111111111, alternative rock=49.16666666, ... }
Weezer, The White Stripes, Nirvana, Foo Fighters, Maroon 5, Oasis, Panic! at the Disco, Gorillaz,
Green Day, The Cure, Fall Out Boy, OneRepublic, Paramore, Coldplay, Radiohead, Linkin Park,
Red Hot Chili Peppers, Muse
重心の調整は平均タグ頻度でソートされるため、各クラスターの主要なジャンルを簡単に見つけることができます。 たとえば、最後のクラスターは古き良きロックバンドのクラスターであるか、2番目のクラスターはラップスターで満たされています。
このクラスタリングは理にかなっていますが、データはユーザーの行動から収集されるだけなので、ほとんどの場合、完全ではありません。
5. 視覚化
少し前に、私たちのアルゴリズムは、ターミナルフレンドリーな方法でアーティストのクラスターを視覚化しました。 クラスター構成をJSONに変換してD3.jsにフィードすると、JavaScriptを数行使用するだけで、人間に優しい RadialTidy-Treeが得られます。
変換する必要があります地図
6. クラスターの数
K-Meansの基本的な特性の1つは、クラスターの数を事前に定義する必要があるという事実です。 これまで、 k に静的な値を使用しましたが、この値を決定することは困難な問題になる可能性があります。 クラスターの数を計算する一般的な方法は2つあります。
- 領域知識
- 数学的ヒューリスティック
運が良ければ、ドメインについてよく知っているので、正しい数を簡単に推測できるかもしれません。 それ以外の場合は、ElbowメソッドやSilhouetteメソッドなどのいくつかのヒューリスティックを適用して、クラスターの数を把握できます。
先に進む前に、これらのヒューリスティックは有用ではありますが、単なるヒューリスティックであり、明確な答えを提供しない可能性があることを知っておく必要があります。
6.1. エルボー法
エルボー法を使用するには、最初に各クラスター重心とそのすべてのメンバーの差を計算する必要があります。クラスター内の無関係なメンバーをグループ化すると、重心とそのメンバー間の距離が長くなり、クラスターの品質が低下します。 。
この距離計算を実行する1つの方法は、二乗誤差の合計を使用することです。
public static double sse(Map<Centroid, List<Record>> clustered, Distance distance) {
double sum = 0;
for (Map.Entry<Centroid, List<Record>> entry : clustered.entrySet()) {
Centroid centroid = entry.getKey();
for (Record record : entry.getValue()) {
double d = distance.calculate(centroid.getCoordinates(), record.getFeatures());
sum += Math.pow(d, 2);
}
}
return sum;
}
次に、 k のさまざまな値に対してK-Meansアルゴリズムを実行し、それぞれのSSEを計算できます。
List<Record> records = // the dataset;
Distance distance = new EuclideanDistance();
List<Double> sumOfSquaredErrors = new ArrayList<>();
for (int k = 2; k <= 16; k++) {
Map<Centroid, List<Record>> clusters = KMeans.fit(records, k, distance, 1000);
double sse = Errors.sse(clusters, distance);
sumOfSquaredErrors.add(sse);
}
1日の終わりに、SSEに対してクラスターの数をプロットすることにより、適切なkを見つけることができます。
通常、クラスターの数が増えると、クラスターメンバー間の距離は短くなります。 ただし、
7. 結論
このチュートリアルでは、最初に、機械学習のいくつかの重要な概念について説明しました。 次に、K-Meansクラスタリングアルゴリズムの仕組みに精通しました。 最後に、K-Meansの簡単な実装を作成し、Last.fmの実際のデータセットを使用してアルゴリズムをテストし、クラスタリングの結果を優れたグラフィカルな方法で視覚化しました。
いつものように、サンプルコードは GitHub プロジェクトで入手できるので、ぜひチェックしてください。