【Python】機械学習におけるkNN(最近傍法)を解説

  • URLをコピーしました!

今回は、機械学習アルゴリズムのなかでもシンプルで、直感的に理解しやすいkNNを紹介します。

kNNの特徴と実際にPythonを使用し、kNNを実装する方法を紹介します。

目次

kNNとは

kNN(k Nearest Neighbor)とは、近くにいるデータは仲間という考え方で生まれたアルゴリズムです。

教師あり学習のひとつで、主に分類問題に使用します。

周辺にあるk個のデータから対象のデータがどのクラスに属するのかを多数決で決定していきます。

このように直感的にも理解しやすいアルゴリズムであることが、kNNの特徴です。

kNNのアルゴリズム

kNNのアルゴリズムはシンプルで、予測データ(予測したいデータ)の周辺にあるk個のデータからどのクラスに属するのかを決定します。

kNNは特に数式などを使用せず、仮定を設けないアルゴリズムになっています。

このようなアルゴリズムをノンパラメトリックといいます。

ノンパラメトリックは、直感的に理解しやすいという特徴があります。

距離の求め方

周辺にあるデータを見つけるには、まずは各データとの距離を計算する必要があります。

距離を計算する方法は、ユークリッド距離が一般的です。

計算式は以下の通りです。

\[ L = \sqrt{a^{2} + b^{2}} \]

高校数学でも習う計算式なので、難しくないと思います。

kNNの特徴

メリット

  • 直感的に理解しやすい
    データ間の距離に基づいて予測を行うので、直感的に理解しやすい
  • 非線形にも対応
    ノンパラメトリックのため、非線形のような複雑な分布にも対応

デメリット

  • 計算コストが高い
    予測を行うためにデータ全体との距離を算出するので、データが大きいと計算量が膨大になる
  • ノイズに弱い
    ノイズや外れ値の影響を受けやすい
  • 次元の呪い
    次元数が増加することでデータの特性や解析が困難になることを次元の呪いといい、この現象がkNNでも発生してしまう

Python実践 kNNの実装

それでは、Pythonを使用しkNNを実装してみましょう!

kNNをPythonで実装するには、scikit-learnライブラリを使用します。

kNNは以下方法で、モデルを作成します。

from sklearn.neighbors import KNeighborsClassifier

モデル = KNeighborsClassifier()
モデル.fit(X_train, y_train)

サンプルコードでkNNを実装してみます。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split

# サンプルデータの生成
X, y = make_blobs(n_samples=200, centers=3, random_state=6)

# データの分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# kNNモデルの生成と学習
model = KNeighborsClassifier()
model.fit(X_train, y_train)

# 境界線を描画するためのメッシュグリッドの作成
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.01),
                     np.arange(y_min, y_max, 0.01))

# メッシュグリッドの各点でクラスを予測
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# 境界線のプロット
plt.figure(figsize=(8, 6))
plt.contourf(xx, yy, Z, alpha=0.8, cmap='autumn')  # 境界線を塗り分け
plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor='k', cmap='autumn')  # データ点
plt.title("kNN")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()

実行結果

まとめ

kNNは、「近くのものは仲間」という発送から生まれたアルゴリズムですので、機械学習アルゴリズムのなかでも直感的に理解しやすいアルゴリズムです。

理解しやすい一方で、計算コストが高いことや次元の呪いを考慮する必要があるので、注意してください。

kNNをPythonで実装する方法は以下の通りです。

from sklearn.neighbors import KNeighborsClassifier

モデル = KNeighborsClassifier()
モデル.fit(X_train, y_train)

簡単に実装できるので、ぜひ試してみてください!

ここまで読んでくださりありがとうございます。

よかったらシェアしてね!
  • URLをコピーしました!

この記事を書いた人

エンジニア。20代。組み込みエンジニアとして働き始めるも、働き方や業務内容に限界を感じ、 AI,Web3エンジニアを目指して勉強中。 エンジニアとして思うことや、学んだことを発信します。

目次