JavaのK-Meansクラスタリングアルゴリズム

1. 概要

クラスタリングは、教師なしアルゴリズムのクラスの包括的な用語であり、互いに密接に関連するもの、人、またはアイデアのグループを発見します*。
この一見簡単な一行定義では、いくつかの流行語を見ました。 クラスタリングとは正確には何ですか? 教師なしアルゴリズムとは何ですか?
このチュートリアルでは、まず、これらの概念に光を当てます。 次に、それらがJavaでどのように現れるかを見ていきます。

2. 教師なしアルゴリズム

ほとんどの学習アルゴリズムを使用する前に、何らかの方法でサンプルデータを提供し、アルゴリズムがそれらのデータから学習できるようにする必要があります。 Machine Learningの用語では、*このサンプルデータセットトレーニングデータを呼び出します。*また、*プロセス全体をtrainingトレーニングプロセスと呼びます。*
とにかく、**トレーニングプロセス中に必要な監督の量に基づいて学習アルゴリズムを分類できます。 **このカテゴリの学習アルゴリズムの主な2つのタイプは次のとおりです。
  • 教師あり学習:教師ありアルゴリズムでは、トレーニングデータ
    各ポイントの実際のソリューションを含める必要があります。 たとえば、スパムフィルタリングアルゴリズムをトレーニングしようとしている場合、サンプルメールとそのラベルの両方をフィードします。 アルゴリズムに対するスパムまたは非スパム。 数学的に言えば、xs と_ys._の両方を含むトレーニングセットから_f(x)_を推測します

  • 教師なし学習:トレーニングデータにラベルがない場合、
    アルゴリズムは教師なしのアルゴリズムです。 たとえば、ミュージシャンに関するデータが豊富にあり、そのデータから類似したミュージシャンのグループを発見する予定です。

3. クラスタリング

クラスタリングは、類似したもの、アイデア、または人々のグループを発見するための教師なしアルゴリズムです。 監視アルゴリズムとは異なり、既知のラベルの例を使用してクラスタリングアルゴリズムをトレーニングしているわけではありません。 代わりに、クラスタリングは、データのポイントがラベルになっていないトレーニングセット内の構造を見つけようとします。

3.1. K-Meansクラスタリング

K-Meansは、1つの基本的なプロパティを持つクラスター化アルゴリズムです:*クラスターの数は事前に定義されています*。 K-Meansに加えて、Hierarchical Clustering、Affinity Propagation、または

3.2. K-Meansの仕組み

次のようなデータセット内のいくつかの類似グループを見つけることが目標だとします。
link:/uploads/Date-6.png []
K-Meansは、ランダムに配置されたk個の重心から始まります。 *セントロイドは、その名前が示すように、クラスターの中心点です*。 たとえば、ここでは4つのランダムな重心を追加しています。
link:/uploads/Date-7.png []
次に、既存の各データポイントを最も近い重心に割り当てます。
link:/uploads/Date-8.png []
割り当て後、割り当てられたポイントの平均位置に重心を移動します。 重心はクラスターの中心点であることを忘れないでください:
link:/uploads/Date-10.png []
 
現在の反復は、重心の再配置が完了するたびに終了します。 *連続する複数の反復間の割り当てが変化しなくなるまで、これらの反復を繰り返します。*
link:/uploads/Date-copy.png []
アルゴリズムが終了すると、これらの4つのクラスターが期待どおりに見つかります。 K-Meansの仕組みがわかったので、Javaで実装してみましょう。

3.3. 機能表現

さまざまなトレーニングデータセットをモデル化する場合、モデルの属性とそれに対応する値を表すデータ構造が必要です。 たとえば、ミュージシャンはRock __.__のようなa valueを持つa genre属性を持つことができます*通常、「feature」という用語を使用して、属性とその値の組み合わせを指します。*
特定の学習アルゴリズム用のデータセットを準備するには、通常、さまざまなアイテムを比較するために使用できる数値属性の共通セットを使用します。 たとえば、ユーザーに各アーティストにジャンルのタグを付けさせた場合、一日の終わりに、各アーティストに特定のジャンルのタグを付けた回数をカウントできます。
link:/uploads/Screen-Shot-1398-04-29-at-22.30.58.png []
Linkin Parkのようなアーティストの機能ベクトルは、__ [rock-> 7890、nu-metal-> 700、alternative-> 520、pop-> 3]です。 __したがって、属性を数値として表現する方法を見つけることができれば、2つの異なる項目を単純に比較できます。 アーティスト、対応するベクトルエントリを比較します。
数値ベクトルは非常に用途の広いデータ構造であるため、それらを使用してフィーチャを表現します__。 __Javaで機能ベクトルを実装する方法は次のとおりです。
public class Record {
    private final String description;
    private final Map<String, Double> features;

    // constructor, getter, toString, equals and hashcode
}

3.4. 類似アイテムを見つける

K-Meansの各反復では、データセット内の各アイテムに最も近い重心を見つける方法が必要です。 2つの特徴ベクトル間の距離を計算する最も簡単な方法の1つは、https://en.wikipedia.org/wiki/Euclidean_distance [Euclidean Distance]を使用することです。 _ [p1、q1] _や_ [p2、q2] _のような2つのベクトル間のユークリッド距離は次と等しくなります。
link:/uploads/4febdae84cbc320c19dd13eac5060a984fd438d8.svg []
この関数をJavaで実装しましょう。 まず、抽象化:
public interface Distance {
    double calculate(Map<String, Double> f1, Map<String, Double> f2);
}
ユークリッド距離に加えて、*ピアソン相関係数のような異なるアイテム間の距離または類似性を計算する他のアプローチがあります*。 この抽象化により、異なる距離メトリックを簡単に切り替えることができます。
ユークリッド距離の実装を見てみましょう。
public class EuclideanDistance implements Distance {

    @Override
    public double calculate(Map<String, Double> f1, Map<String, Double> f2) {
        double sum = 0;
        for (String key : f1.keySet()) {
            Double v1 = f1.get(key);
            Double v2 = f2.get(key);

            if (v1 != null && v2 != null) {
                sum += Math.pow(v1 - v2, 2);
            }
        }

        return Math.sqrt(sum);
    }
}
最初に、対応するエントリ間の差の二乗和を計算します。 次に、__sqrt __関数を適用することにより、実際のユークリッド距離を計算します。

3.5. 重心表現

重心は通常のフィーチャと同じ空間にあるため、フィーチャと同様に表現できます。
public class Centroid {

    private final Map<String, Double> coordinates;

    // constructors, getter, toString, equals and hashcode
}
必要な抽象化がいくつか用意できたので、K-Means実装を作成します。 メソッドシグネチャの概要を以下に示します。
public class KMeans {

    private static final Random random = new Random();

    public static Map<Centroid, List<Record>> fit(List<Record> records,
      int k,
      Distance distance,
      int maxIterations) {
        // omitted
    }
}
このメソッドシグネチャを分類しましょう。
  • データセット_は、特徴ベクトルのセットです。 各特徴ベクトル
    Recordである場合、thenデータセットタイプは_List <Record> _です

  • k _parameterはクラスターの数を決定します。
    事前に提供する

  • distance _は、計算する方法をカプセル化します
    2つの機能の違い

  • 割り当てが数回変更を停止すると、K-Meansは終了します
    連続した反復。 この終了条件に加えて、反復回数の上限も設定できます。 maxIterations argumentは、その上限を決定します

  • K-Meansが終了すると、各重心にいくつかの割り当てが必要になります
    機能のため、戻り値の型としてaMap <Centroid、List <Record >> を使用しています。 基本的に、各マップエントリはクラスターに対応します

3.6. 重心の生成

最初のステップは、ランダムに配置された重心を生成することです。
各セントロイドには完全にランダムな座標を含めることができますが、*各属性の可能な最小値と最大値の間でランダム座標を生成することをお勧めします*。 可能な値の範囲を考慮せずにランダムな重心を生成すると、アルゴリズムの収束が遅くなります。
最初に、各属性の最小値と最大値を計算し、次にそれらの各ペア間でランダムな値を生成する必要があります。
private static List<Centroid> randomCentroids(List<Record> records, int k) {
    List<Centroid> centroids = new ArrayList<>();
    Map<String, Double> maxs = new HashMap<>();
    Map<String, Double> mins = new HashMap<>();

    for (Record record : records) {
        record.getFeatures().forEach((key, value) -> {
            // compares the value with the current max and choose the bigger value between them
            maxs.compute(key, (k1, max) -> max == null || value > max ? value : max);

            // compare the value with the current min and choose the smaller value between them
            mins.compute(key, (k1, min) -> min == null || value < min ? value : min);
        });
    }

    Set<String> attributes = records.stream()
      .flatMap(e -> e.getFeatures().keySet().stream())
      .collect(toSet());
    for (int i = 0; i < k; i++) {
        Map<String, Double> coordinates = new HashMap<>();
        for (String attribute : attributes) {
            double max = maxs.get(attribute);
            double min = mins.get(attribute);
            coordinates.put(attribute, random.nextDouble() * (max - min) + min);
        }

        centroids.add(new Centroid(coordinates));
    }

    return centroids;
}
これで、各レコードをこれらのランダムな重心のいずれかに割り当てることができます。

3.7. 割り当て

まず、, _ Record_が与えられた場合、それに最も近い重心を見つける必要があります。
private static Centroid nearestCentroid(Record record, List<Centroid> centroids, Distance distance) {
    double minimumDistance = Double.MAX_VALUE;
    Centroid nearest = null;

    for (Centroid centroid : centroids) {
        double currentDistance = distance.calculate(record.getFeatures(), centroid.getCoordinates());

        if (currentDistance < minimumDistance) {
            minimumDistance = currentDistance;
            nearest = centroid;
        }
    }

    return nearest;
}
各レコードは、最も近い重心クラスターに属します。
private static void assignToCluster(Map<Centroid, List<Record>> clusters,
  Record record,
  Centroid centroid) {
    clusters.compute(centroid, (key, list) -> {
        if (list == null) {
            list = new ArrayList<>();
        }

        list.add(record);
        return list;
    });
}

3.8. 重心の再配置

1回の反復の後、重心に割り当てが含まれていない場合、それを再配置しません。 そうでない場合は、各属性の重心座標を、割り当てられたすべてのレコードの平均位置に再配置する必要があります。
private static Centroid average(Centroid centroid, List<Record> records) {
    if (records == null || records.isEmpty()) {
        return centroid;
    }

    Map<String, Double> average = centroid.getCoordinates();
    records.stream().flatMap(e -> e.getFeatures().keySet().stream())
      .forEach(k -> average.put(k, 0.0));

    for (Record record : records) {
        record.getFeatures().forEach(
          (k, v) -> average.compute(k, (k1, currentValue) -> v + currentValue)
        );
    }

    average.forEach((k, v) -> average.put(k, v / records.size()));

    return new Centroid(average);
}
1つの重心を再配置できるため、__relocateCentroids __methodを実装できるようになりました。
private static List<Centroid> relocateCentroids(Map<Centroid, List<Record>> clusters) {
    return clusters.entrySet().stream().map(e -> average(e.getKey(), e.getValue())).collect(toList());
}
このシンプルなワンライナーは、すべての重心を反復処理し、それらを再配置して、新しい重心を返します。

3.9. すべてを一緒に入れて

各反復で、すべてのレコードを最も近い重心に割り当てた後、まず、現在の割り当てを最後の反復と比較する必要があります。
割り当てが同一であった場合、アルゴリズムは終了します。 そうでなければ、次の反復にジャンプする前に、重心を再配置する必要があります。
public static Map<Centroid, List<Record>> fit(List<Record> records,
  int k,
  Distance distance,
  int maxIterations) {

    List<Centroid> centroids = randomCentroids(records, k);
    Map<Centroid, List<Record>> clusters = new HashMap<>();
    Map<Centroid, List<Record>> lastState = new HashMap<>();

    // iterate for a pre-defined number of times
    for (int i = 0; i < maxIterations; i++) {
        boolean isLastIteration = i == maxIterations - 1;

        // in each iteration we should find the nearest centroid for each record
        for (Record record : records) {
            Centroid centroid = nearestCentroid(record, centroids, distance);
            assignToCluster(clusters, record, centroid);
        }

        // if the assignments do not change, then the algorithm terminates
        boolean shouldTerminate = isLastIteration || clusters.equals(lastState);
        lastState = clusters;
        if (shouldTerminate) {
            break;
        }

        // at the end of each iteration we should relocate the centroids
        centroids = relocateCentroids(clusters);
        clusters = new HashMap<>();
    }

    return lastState;
}

4. 例:Last.fmで同様のアーティストを見つける

Last.fmは、ユーザーが聞いている内容の詳細を記録することにより、各ユーザーの音楽の好みの詳細なプロファイルを作成します。 このセクションでは、同様のアーティストのクラスターを見つけます。 このタスクに適したデータセットを構築するには、Last.fmの3つのAPIを使用します。
  1. 取得するAPI
    Last.fmのhttps://www.last.fm/api/show/chart.getTopArtists [トップアーティストのコレクション]。

  2. 検索する別のAPI
    popular tags。 各ユーザーは、アーティストに何か、たとえば 岩。 So、Last.fmは、これらのタグとその頻度のデータベースを保持しています。

  3. https://www.last.fm/api/show/artist.getTopTagsへのAPI [get the
    アーティストのトップタグ]、人気順。 このようなタグは多数あるため、上位のグローバルタグの中にあるタグのみを保持します。

4.1. Last.fmのAPI

これらのAPIを使用するには、https://www.last.fm/api/authentication [Last.fmからのAPIキー]を取得し、すべてのHTTPリクエストで送信する必要があります。 これらのAPIを呼び出すために、次のlink:/retrofit[Retrofit]サービスを使用します。
public interface LastFmService {

    @GET("/2.0/?method=chart.gettopartists&format=json&limit=50")
    Call<Artists> topArtists(@Query("page") int page);

    @GET("/2.0/?method=artist.gettoptags&format=json&limit=20&autocorrect=1")
    Call<Tags> topTagsFor(@Query("artist") String artist);

    @GET("/2.0/?method=chart.gettoptags&format=json&limit=100")
    Call<TopTags> topTags();

    // A few DTOs and one interceptor
}
それでは、Last.fmで最も人気のあるアーティストを見つけましょう。
// setting up the Retrofit service

private static List<String> getTop100Artists() throws IOException {
    List<String> artists = new ArrayList<>();
    // Fetching the first two pages, each containing 50 records.
    for (int i = 1; i <= 2; i++) {
        artists.addAll(lastFm.topArtists(i).execute().body().all());
    }

    return artists;
}
同様に、トップタグを取得できます。
private static Set<String> getTop100Tags() throws IOException {
    return lastFm.topTags().execute().body().all();
}
最後に、アーティストのタグデータセットとタグの頻度を作成できます。
private static List<Record> datasetWithTaggedArtists(List<String> artists,
  Set<String> topTags) throws IOException {
    List<Record> records = new ArrayList<>();
    for (String artist : artists) {
        Map<String, Double> tags = lastFm.topTagsFor(artist).execute().body().all();

        // Only keep popular tags.
        tags.entrySet().removeIf(e -> !topTags.contains(e.getKey()));

        records.add(new Record(artist, tags));
    }

    return records;
}

4.2. アーティストクラスターの形成

これで、準備したデータセットをK-Means実装にフィードできます。
List<String> artists = getTop100Artists();
Set<String> topTags = getTop100Tags();
List<Record> records = datasetWithTaggedArtists(artists, topTags);

Map<Centroid, List<Record>> clusters = KMeans.fit(records, 7, new EuclideanDistance(), 1000);
// Printing the cluster configuration
clusters.forEach((key, value) -> {
    System.out.println("-------------------------- CLUSTER ----------------------------");

    // Sorting the coordinates to see the most significant tags first.
    System.out.println(sortedCentroid(key));
    String members = String.join(", ", value.stream().map(Record::getDescription).collect(toSet()));
    System.out.print(members);

    System.out.println();
    System.out.println();
});
このコードを実行すると、クラスターがテキスト出力として視覚化されます。
------------------------------ CLUSTER -----------------------------------
Centroid {classic rock=65.58333333333333, rock=64.41666666666667, british=20.333333333333332, ... }
David Bowie, Led Zeppelin, Pink Floyd, System of a Down, Queen, blink-182, The Rolling Stones, Metallica,
Fleetwood Mac, The Beatles, Elton John, The Clash

------------------------------ CLUSTER -----------------------------------
Centroid {Hip-Hop=97.21428571428571, rap=64.85714285714286, hip hop=29.285714285714285, ... }
Kanye West, Post Malone, Childish Gambino, Lil Nas X, A$AP Rocky, Lizzo, xxxtentacion,
Travi$ Scott, Tyler, the Creator, Eminem, Frank Ocean, Kendrick Lamar, Nicki Minaj, Drake

------------------------------ CLUSTER -----------------------------------
Centroid {indie rock=54.0, rock=52.0, Psychedelic Rock=51.0, psychedelic=47.0, ... }
Tame Impala, The Black Keys

------------------------------ CLUSTER -----------------------------------
Centroid {pop=81.96428571428571, female vocalists=41.285714285714285, indie=22.785714285714285, ... }
Ed Sheeran, Taylor Swift, Rihanna, Miley Cyrus, Billie Eilish, Lorde, Ellie Goulding, Bruno Mars,
Katy Perry, Khalid, Ariana Grande, Bon Iver, Dua Lipa, Beyoncé, Sia, P!nk, Sam Smith, Shawn Mendes,
Mark Ronson, Michael Jackson, Halsey, Lana Del Rey, Carly Rae Jepsen, Britney Spears, Madonna,
Adele, Lady Gaga, Jonas Brothers

------------------------------ CLUSTER -----------------------------------
Centroid {indie=95.23076923076923, alternative=70.61538461538461, indie rock=64.46153846153847, ... }
Twenty One Pilots, The Smiths, Florence + the Machine, Two Door Cinema Club, The 1975, Imagine Dragons,
The Killers, Vampire Weekend, Foster the People, The Strokes, Cage the Elephant, Arcade Fire,
Arctic Monkeys

------------------------------ CLUSTER -----------------------------------
Centroid {electronic=91.6923076923077, House=39.46153846153846, dance=38.0, ... }
Charli XCX, The Weeknd, Daft Punk, Calvin Harris, MGMT, Martin Garrix, Depeche Mode, The Chainsmokers,
Avicii, Kygo, Marshmello, David Guetta, Major Lazer

------------------------------ CLUSTER -----------------------------------
Centroid {rock=87.38888888888889, alternative=72.11111111111111, alternative rock=49.16666666, ... }
Weezer, The White Stripes, Nirvana, Foo Fighters, Maroon 5, Oasis, Panic! at the Disco, Gorillaz,
Green Day, The Cure, Fall Out Boy, OneRepublic, Paramore, Coldplay, Radiohead, Linkin Park,
Red Hot Chili Peppers, Muse
重心の調整は平均タグ頻度でソートされるため、各クラスターの主要なジャンルを簡単に見つけることができます。 たとえば、最後のクラスターは古き良きロックバンドのクラスターであり、2番目のクラスターはラップスターで満たされています。
このクラスタリングは理にかなっていますが、ほとんどの場合、データはユーザーの行動から収集されるだけなので、完全ではありません。

5. 可視化

少し前に、私たちのアルゴリズムは、ターミナルフレンドリーな方法でアーティストのクラスターを視覚化しました。 クラスター構成をJSONに変換し、それをD3.jsにフィードすると、数行のJavaScriptで、人に優しいhttps://observablehq.com/@d3/radial-tidy-tree? collection = @ d3 / d3-hierarchy [Radial Tidy-Tree]:
link:/uploads/Screen-Shot-1398-05-04-at-12.09.40.png []
_Map <Centroid、List <Record >> _をhttps://raw.githubusercontent.com/d3/d3-hierarchy/v1.1.8/test/data/flare.jsonのような類似のスキーマを持つJSONに変換する必要があります[このd3.jsの例]。

6. クラスターの数

K-Meansの基本的な特性の1つは、クラスターの数を事前に定義する必要があるという事実です。 これまで、_k_に静的な値を使用していましたが、この値を決定するのは難しい問題です。 *クラスターの数を計算する一般的な方法は2つあります:*
  1. 領域知識

  2. 数学的ヒューリスティック

    幸運にもドメインについて多くのことを知っていれば、正しい数字を簡単に推測できるかもしれません。 それ以外の場合は、エルボー法やシルエット法などのいくつかのヒューリスティックを適用して、クラスターの数を把握できます。
    先に進む前に、これらのヒューリスティックは有用ではありますが、単なるヒューリスティックであり、明確な答えを提供しない場合があることを知っておく必要があります。

6.1. エルボ法

エルボー法を使用するには、まず各クラスター重心とそのすべてのメンバーの差を計算する必要があります。 *クラスター内のより関連のないメンバーをグループ化すると、重心とそのメンバー間の距離が長くなるため、クラスターの品質が低下します。*
この距離の計算を実行する1つの方法は、Sum of Squared Errors__を使用することです。 __ *平方誤差の合計またはSSEは、重心とそのすべてのメンバーとの差の平方の合計に等しい*:
public static double sse(Map<Centroid, List<Record>> clustered, Distance distance) {
    double sum = 0;
    for (Map.Entry<Centroid, List<Record>> entry : clustered.entrySet()) {
        Centroid centroid = entry.getKey();
        for (Record record : entry.getValue()) {
            double d = distance.calculate(centroid.getCoordinates(), record.getFeatures());
            sum += Math.pow(d, 2);
        }
    }

    return sum;
}
その後、* _ k_ *の異なる値に対してK-Meansアルゴリズムを実行し、それぞれのSSEを計算できます。
List<Record> records = // the dataset;
Distance distance = new EuclideanDistance();
List<Double> sumOfSquaredErrors = new ArrayList<>();
for (int k = 2; k <= 16; k++) {
    Map<Centroid, List<Record>> clusters = KMeans.fit(records, k, distance, 1000);
    double sse = Errors.sse(clusters, distance);
    sumOfSquaredErrors.add(sse);
}
1日の終わりに、クラスターの数をSSEに対してプロットすることにより、適切な__k ___を見つけることができます。
link:/uploads/Screen-Shot-1398-05-04-at-17.01.36.png []
通常、クラスターの数が増えると、クラスターメンバー間の距離は短くなります。 ただし、__k、___に任意の大きな値を選択することはできません。メンバーが1つだけのクラスターが複数あるため、クラスタリングの目的全体が無効になるためです。
*エルボー法の背後にある考え方は、__k __の適切な値を見つけることで、SSEがその値を中心に劇的に減少するようになります。

7. 結論

このチュートリアルでは、まず、機械学習のいくつかの重要な概念を取り上げました。 それから、K-Meansクラスタリングアルゴリズムのメカニズムに精通しました。 最後に、K-Meansの簡単な実装を作成し、Last.fmの実世界のデータセットでアルゴリズムをテストし、クラスタリング結果をグラフィカルな方法で視覚化しました。
いつものように、サンプルコードはhttps://github.com/eugenp/tutorials/tree/master/algorithms-miscellaneous-3[GitHub]プロジェクトで入手できます。必ずチェックしてください。