Lasso, Ridge, Elastic Netのグラフ

Pythonで図形を書く練習として、Lasso、Ridge、Elastic Netの3種類の正則化項をグラフに書いてみました。

Elastic Net正則化項は、{\theta_1\|w\|_1+\theta_2\|w\|_2^2}として、{l_1}ノルムと{l_2}ノルムを{\theta_1}{\theta_2}の比率で組み合わせています。

中心から順番に、Lasso、Elastic Net({\theta_1 = 0.25})、Elastic Net({\theta_1 = 0.50})、Elastic Net({\theta_1 = 0.75})、Ridgeとなります。

f:id:decompose:20160325234659p:plain

import numpy as np
import matplotlib.pyplot as plt

def elasticnet(theta1,col):
	theta2 = 1 - theta1
	w2 = np.linspace(-1,1,num=100)
	inside = theta1 ** 2 - 4 * theta2 * (theta2 * abs(w2) ** 2 + theta1 * abs(w2) -1)
	w1 = (-theta1+np.sqrt(inside))/(2 * theta2)
	wn1 = -w1
	plt.plot(w2,w1,col)
	plt.plot(w2,wn1,col)

def lasso(col):
	w2 = np.linspace(-1,1,num=100)
	w1 = 1 - abs(w2)
	wn1 = -w1
	plt.plot(w2,w1,col)
	plt.plot(w2,wn1,col)

elasticnet(.0,"b")
elasticnet(.25,"g")
elasticnet(.5,"r")
elasticnet(.75,"c")
lasso("m")
plt.figure(figsize=(800,800))
plt.show()