juliaで最小二乗法を勾配降下法で解く
前回のデータを使って勾配降下法で最小二乗法を解いてみた。
juliaもわからないしjupyter notebookもわからないしPlotsもわからない中、もがきながらなんとか形になりました。
Jは誤差関数として二乗誤差を使う。
\[
J=\frac{1}{2}\sum_{i=1}^n(y_i - (ax_i+b))^2
\]
ここで、を次のように置く。
\[
w = \left(
\begin{array}{cc}
a \\
b \\
\end{array}
\right),
\nabla J = \left(
\begin{array}{cc}
\frac{\partial J}{\partial a} \\
\frac{\partial J}{\partial b} \\
\end{array}
\right)
\]
このとき、勾配降下法では、次のようにパラメータを更新していく。
\[
w^{(t+1)} = w^{(t)} - \alpha \nabla J
\]
は乱数で初期値を与えておいて、学習率を十分小さくすれば、傾きに沿って極値へと向かっていくはず。
前回の結果からも分かる通り、は次のように計算される。
\[
\nabla J =
\left(
\begin{array}{cc}
\sum_{i=1}^nx_i^2 & \sum_{i=i}^nx_i \\
\sum_{i=1}^nx_i & n \\
\end{array}
\right)
w-
\left(
\begin{array}{cc}
\sum_{i=1}^nx_iy_i \\
\sum_{i=1}^ny_i \\
\end{array}
\right)
\]
ここで、
\[
P = \left(
\begin{array}{cc}
\sum_{i=1}^nx_i^2 & \sum_{i=i}^nx_i \\
\sum_{i=1}^nx_i & n \\
\end{array}
\right),
q = \left(
\begin{array}{cc}
\sum_{i=1}^nx_iy_i \\
\sum_{i=1}^ny_i \\
\end{array}
\right)
\]
と置くと、
\[
\nabla J = Pw - q
\]
と書ける。
また、は、次のように変形できる。
\[
J = \frac{1}{2}(a\sum_{i=i}^nx_i+b - \sum_{i=1}^ny_i)^2
\]
シグマの値は定数だから、先に計算しておけばを与えればすぐに計算できる。
今回は勾配降下の様子をグラフ化してみた。2変数だからなんとかグラフにできる。
の値を乱数で与えて勾配降下というルーチンを20回やって図にしたのがこれ
縦線になっているのがの等高線のハズだから、谷底みたいに細長い部分ができてしまっていて、そこから降下しなくなっているんだろう。
データの正規化をすればいいのかな?
この図を見るとなんかある直線上に収束してるように見えるけど等高線と傾きが違うような気がする。
なにか計算ミスしてるんだろうか?
作ったプログラムは以下
#描画するための設定 using Plots gr() #誤差関数 function J(w, X, Y) a = w[1] b = w[2] 0.5*(a*X+b-Y)^2 end #ファイル読み込み io = open("data.txt", "r") n = countlines(io) #サンプル数 data = zeros(Float64, n, 2) # n*2行列 seekstart(io) for i = 1:n line = split(readline(io)) data[i, 1] = parse(Float64, line[1]) data[i, 2] = parse(Float64, line[2]) end x = data[:,1] y = data[:,2] X = sum(x) # Σxi Y = sum(y) # Σyi X2 = x'x # Σ(xi^2) XY = x'y # Σxiyi P = [X2 X; X n] q = [XY Y]' α = 0.001#学習率 plot() for i = 1:20 #20回勾配降下を繰り返す W = zeros(2,10) #各ステップごとの結果を格納 2*1ベクトル?を10列 w = 40*[rand()-0.5 rand()-0.5]' #初期化 for j = 1:10 #10ステップ勾配降下する W[:,j] = w w = w - α * (P*w-q) end plot!(W[1,:], W[2,:], marker=:circle) #勾配降下が終わったらプロット end #誤差関数をプロット S = -20:1:20 T = -20:1:20 U = [J([s,t],X,Y) for s in S, t in T]' plot!(S,T,U)
juliaの行列というか配列の取扱いに苦戦したのでメモ。
[1,2]とすると配列 これは列ベクトルみたいな扱いになる
[1 2]とすると1*2行列
[1; 2]とすると2*1行列になる。