はじめに
今回はUnityでロジスティック回帰を使った2値分類を実装してみたいと思います。
前回や前々回で実装したものは出力値が数値でしたが、今回は分類(離散値)をします。
【Unity】Unityでもゼロから機械学習を作る【単回帰モデル】 - はなちるのマイノート
【Unity】Unityでもゼロから機械学習を作る【重回帰モデル】 - はなちるのマイノート
早速みていきましょう。
ロジスティック回帰モデルの概要
単回帰モデルは出力値が数値なので、一次関数を予想しました。
ロジスティック回帰モデルは2つのグループの境界線として直線を引きます。ちなみにこの境界線を決定境界と呼ぶそうです。
予想モデル
入力データを,正解値を
yt = 0 or 1
としたときに決定境界は直線なので前回同様に以下の式で表せます。
ここで勾配降下法を用いることを考えると、損失関数がパラメータで微分可能、つまりパラメータの変化に伴い連続的に変化する関数でなければなりません。
このことからシグモイド関数を用いて、予測値を確率値(0から1の範囲の値)に変換することで解決します。
ja.wikipedia.org
またシグモイド関数の0<x<1
の値になるという性質と,点(0,0.5)に関して点対称から以下のように考えられます。
- f(u)の値を確率とみなし、「該当する点がclass=1(yt=1)に属している確率」と考えられる
これより、f(u)
の値をu
の予測値yp
とします。
つまり予測値が0.5より大きいか、それ以外かで分類を行います。
まとめると以下の3つの式により予測関数を構築します。
損失関数
シグモイド関数の点(0,0.5)に関して点対称という性質を書きましたが、これよりyt=1
(正解がclass1)の確率がyp
なら、yt=0
(正解がclass0)の確率は1-yp
になります。
これを数式で表すとこのようにかけます。
この確率変数を用いて、尤度関数Lk
を求めます。
機械学習で必要な確率・統計の基礎知識 - はなちるのマイノート
最尤推定を行うのに、対数をとるのが一般的です。これは微分をしやすくするのとアンダーフロー(値が小さくなりすぎる)をしにくくするためですね。
さらにを求めるには以下により求められます。
この尤度関数において、変数yp
を変化させながら値が最大になったときが一番正解に近いということになります。
また今回実装するにあたって勾配降下法を用いるのですが、勾配降下法では値を最小にすることを目的とするので尤度関数に-1をかけます。
さらにデータ件数が多くなると値が果てしなく大きくなってしまうので、平均を用いてデータ件数の影響をなくします。
これらをまとめると損失関数は以下のように定義できます。
損失関数の微分
細かい証明はしませんが、損失関数を微分すると以下のようになります。
本当に不思議に感じますが、これほどまでに簡単になります。
これでとなる
yd
を定義します。
さらにのように変形することができることを用いて、以下のようにそれぞれの入力変数の偏微分を求めることができます。
(i=0,1,2)
勾配降下法
勾配降下法は以前紹介したものとほぼ同じように用いることができます。
【Unity】Unityでもゼロから機械学習を作る【単回帰モデル】 - はなちるのマイノート
【Unity】Unityでもゼロから機械学習を作る【重回帰モデル】 - はなちるのマイノート
これまでで必要な数式をまとめるとこんな感じ。また今までも使っていましたがm
,k
は「m番目のデータ系列,繰り返し計算k回目」という意味です。
繰り返しのアルゴリズム
(i=0,1,2)
データ
やっとのこと理論が一通り紹介し終わったので、それ通りに実装していきます。
ただデータがないことにはなにも始まらないので、定番のあやめの識別のデータセットを用います。
UCI Machine Learning Repository: Iris Data Set
これを加工して以下のようにしました。
実装してみる
コードはこんな感じ。
VectorN
とかは以前作成したものを流用していますので、知りたい方は以前の記事をみてみてください。
using System.Linq; using System.Text; using UnityEngine; namespace RegressionModel { public class RegressionModel : MonoBehaviour { // 繰り返し回数(この書き方だとインスペクター優先になります) [SerializeField] private int _iters = 10000; // 学習率 [SerializeField] private float _alpha = 0.0001f; // データ系列総数 private int _dataSize; void Start() { var datasets = DatasetReader.Read(); // 正解データ(0 or 1) var yt = datasets.Select(t => t.Yt) .ToArray(); // 入力データ(x0=1, x1=がく片の長さ, x2=がく片の幅)の設定 x0はダミー変数 VectorN[] x = datasets.Select(d => { var list = d.x.ToList(); list.Insert(0, 1); // ダミー変数を追加 return new VectorN(list); }) .ToArray(); _dataSize = yt.Length; // 重みベクトルの初期化 初期値は(1,1,1) VectorN w = VectorN.GetOne(x[0].Length); for (int i = 0; i < _iters; i++) { // 予測値の計算 var yp = Pred(x, w); // 誤差の計算 var yd = CalsulateError(yp, yt); // 勾配降下法の実装 w = GradientDescentMethod(x, w, yd); } // 学習後の回帰直線 var builder = new StringBuilder(); builder.Append($"u= {w[0]}"); for (int i = 1; i < w.Length; i++) { if (w[i] >= 0) builder.Append(" +"); builder.Append($" {w[i]} x{i}"); } Debug.Log(builder.ToString()); // 最終的な正解率(本当は別に検証データを作らなければなりませんが、めんどくさいので訓練データの正解率を調べる) int correct = 0; for(int i = 0; i < _dataSize; i++) { if (yt[i] == Classify(x[i], w)) correct++; } Debug.Log($"Correct answer rate: {(float)correct / _dataSize}"); } private float Sigmoid(float x) => 1.0f / (1 + Mathf.Exp(-x)); /// <summary> /// 予測値ypを計算する /// </summary> private float[] Pred(VectorN[] x, VectorN w) { float[] yp = new float[_dataSize]; for (int m = 0; m < _dataSize; m++) { yp[m] = Sigmoid(VectorN.Dot(x[m], w)); } return yp; } /// <summary> /// 誤差を計算する /// </summary> private float[] CalsulateError(float[] yp, float[] yt) { float[] yd = new float[_dataSize]; for (int m = 0; m < _dataSize; m++) { yd[m] = yp[m] - yt[m]; } return yd; } /// <summary> /// 勾配降下法の実装 /// </summary> private VectorN GradientDescentMethod(VectorN[] x, VectorN w, float[] yd) { for (int i = 0; i < w.Length; i++) { float sum = 0; for (int m = 0; m < _dataSize; m++) { sum += yd[m] * x[m][i]; } w[i] = w[i] - (_alpha / _dataSize * sum); } return w; } /// <summary> /// 予測結果から分類(0 or 1)を行う /// </summary> private int Classify(VectorN x, VectorN w) => VectorN.Dot(x, w) > 0.5 ? 1 : 0; } }
結果
本当はダメなのですが、検証用データを用意するのがめんどくさかったので学習用データで正解率を調べてみました。
学習率と繰り返し回数を色々と変化させながら実験してみると正解率が0.97
まであげることできました。
かなり良い結果がでてくれたので満足です。
さいごに
本当はグラフに出力したり学習途中の損失関数を求めたりしたかったのですが、結局やれませんでした。
ただせっかくUnityで実装しているので、もっと視覚的にわかりやすく表現して3Dで動く動画なんかを作ったりしようかなと密かに思っています。
Youtubeにも少し興味があるのでも挙げてみたりもしてみたいです。(いつも通りのやるやる詐欺になるかも)
ではまた。