K平均法によるクラスタリング

 今回はK平均法を紹介します。
 『はじめてのパターン認識』の表記を参考に話を進めていくと、{d}次元の{N}個のデータ{\mathcal{D}=\{x_1,\ldots,x_N\}}があり、{x_i\in\mathbb{R}^d}とします。この時、データの類似度を基に、{K}個のクラスタに分類する方法です。{K}は予め決まったクラスタ数で、データの特性から判断しモデルに与えます。そして各クラスタの代表ベクトルの集合を{\mathcal{M}=\{\mu_1,\ldots,\mu_K\}}とし、{i}番目のサンプルが{k}番目の代表ベクトルが支配するクラスタに帰属するか否かを帰属変数{q_{ik}}(帰属すれば1、そうでなければ0)と表すとすると、K平均法の評価関数を次のように定式化できます:
{J(q_{ik},\mu_k)=\sum^N_{i=1}\sum^K_{k=1} q_{ik}\|x_i-\mu_k\|^2}
これを{\mu_k}に関して最適化してみましょう。
{\frac{\partial J(q_{ik},\mu_k)}{\partial \mu_k}}=2\sum^N_{i=1}q_{ik}(x_i-\mu_k)=0
よって、
{u_k=\frac{\sum^N_{i=1}q_{ik}x_i}{\sum^N_{i=1}q_{ik}}}
クラスタの代表ベクトルは、帰属するデータの平均ベクトルになることがわかります。ただし、{q_{ik},\mu_k}を同時に最適化するのは難しいので、K平均法アルゴリズムによって逐次最適化を行うことが一般的です。

クラスタ数が3、各クラスのデータ数を100、データの次元を2とした場合について、アルゴリズムPythonにて実装すると、次のようになります。

# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt

mean1 = [-5,-5]
mean2 = [0,7]
mean3 = [5,-5]
cov = [[3.5,0],[0,3.5]]

instances = 100

#generate 2-dimensional data for three classes
x = np.random.multivariate_normal(mean1,cov,instances)
y = np.random.multivariate_normal(mean2,cov,instances)
z = np.random.multivariate_normal(mean3,cov,instances)
X = np.vstack((x,y))
X = np.vstack((X,z))

closest = np.floor(np.random.rand(instances*3)*3)
iteration = 10000

#clustering plot in the first iteration
f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex='col', sharey='row')
ax1.plot(X[closest==0,0],X[closest==0,1],'bo')
ax1.plot(X[closest==1,0],X[closest==1,1],'ro')
ax1.plot(X[closest==2,0],X[closest==2,1],'go')
ax1.set_title('1st iteration')

for i in range(iteration):
    mu1 = np.mean(X[closest==0,:],axis=0)
    mu2 = np.mean(X[closest==1,:],axis=0)
    mu3 = np.mean(X[closest==2,:],axis=0)
    dist = np.zeros((instances*3,3))
    for j in range(instances*3):
        dist[j,0] = np.linalg.norm(X[j,:] - mu1)
        dist[j,1] = np.linalg.norm(X[j,:] - mu2)
        dist[j,2] = np.linalg.norm(X[j,:] - mu3)
    if np.linalg.norm(closest - np.argmin(dist,axis=1)) == 0:
        break
    closest = np.argmin(dist,axis=1)
    if i ==1:
        ax2.plot(X[closest==0,0],X[closest==0,1],'bo')
        ax2.plot(X[closest==1,0],X[closest==1,1],'ro')
        ax2.plot(X[closest==2,0],X[closest==2,1],'go')
        ax2.set_title('2nd iteration')
    if i ==2:
        ax3.plot(X[closest==0,0],X[closest==0,1],'bo')
        ax3.plot(X[closest==1,0],X[closest==1,1],'ro')
        ax3.plot(X[closest==2,0],X[closest==2,1],'go')
        ax3.set_title('3rd iteration')        

ax4.plot(x[:,0],x[:,1],'bo')
ax4.plot(y[:,0],y[:,1],'ro')
ax4.plot(z[:,0],z[:,1],'go')
ax4.set_title('After convergence')

図示をすると、以下のようになります。
f:id:decompose:20160526001019p:plain