import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
mush_df = pd.read_csv('mushrooms.csv')
mush_df2 = pd.get_dummies(mush_df)
X_mush = mush_df2.iloc[:,2:]
y_mush = mush_df2.iloc[:,1]
# use the variables X_train2, y_train2 for Question 5
X_train2, X_test2, y_train2, y_test2 = train_test_split(X_mush, y_mush, random_state=0)
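As a quick illustration of what `pd.get_dummies` does to the mushroom data, here is a sketch on a tiny made-up frame (the `toy` names are hypothetical, not from the dataset):

```python
import pandas as pd

# A toy frame mimicking mushrooms.csv: every column is categorical.
toy = pd.DataFrame({'cap_color': ['brown', 'white', 'brown'],
                    'odor': ['almond', 'none', 'none']})

# get_dummies expands each categorical column into one 0/1 indicator
# column per distinct value; the models below consume this encoding.
toy_encoded = pd.get_dummies(toy)
```

Each of the two categorical columns becomes two indicator columns, so a 3x2 frame turns into a 3x4 frame.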
Using a decision tree classifier
We train a decision tree classifier and find the five most important features of the resulting model:

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0).fit(X_train2, y_train2)

# plot the tree
from adspy_shared_utilities import plot_decision_tree
import matplotlib.pyplot as plt
plot_decision_tree(clf, list(X_train2), np.asarray(['poison','non-poison']))
Below is the visualization of the whole decision tree; it does not fit on a single screen:

Finding the five most important features:
zipped = list(zip(list(X_train2), list(clf.feature_importances_)))
zipped.sort(key=lambda t: t[1], reverse=True)
result = [x[0] for x in zipped[0:5]]
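The same top-k extraction can be written more compactly with `pd.Series.nlargest`. A runnable sketch on synthetic data (the `X_demo`/`f0`..`f3` names are made up, standing in for `X_train2` and its columns):

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Synthetic 0/1 features where only f0 and f1 actually determine the label.
rng = np.random.RandomState(0)
X_demo = pd.DataFrame(rng.randint(0, 2, size=(200, 4)),
                      columns=['f0', 'f1', 'f2', 'f3'])
y_demo = (X_demo['f0'] | X_demo['f1'])

clf_demo = DecisionTreeClassifier(random_state=0).fit(X_demo, y_demo)

# A Series indexed by feature name lets nlargest do the zip/sort/slice.
top2 = (pd.Series(clf_demo.feature_importances_, index=X_demo.columns)
        .nlargest(2).index.tolist())
```

Here the two informative features dominate the importances, so `top2` recovers `f0` and `f1`.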
They are the following:
Using an SVM classifier
We then use an SVM classifier (RBF kernel) and compute a validation curve over the gamma parameter, to see in which range of gamma the model achieves the best accuracy:

def svc_validation_curve():
    from sklearn.svm import SVC
    from sklearn.model_selection import validation_curve

    # use the training split from above as the subset for the curve
    X_subset, y_subset = X_train2, y_train2

    param_range = np.logspace(-4, 1, 6)
    train_scores, test_scores = validation_curve(SVC(kernel='rbf', C=1, random_state=0), X_subset, y_subset,
                                                 param_name='gamma',
                                                 param_range=param_range, cv=3,
                                                 scoring='accuracy')
    trains = np.mean(train_scores, axis=1)
    tests = np.mean(test_scores, axis=1)
    return trains, tests
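The shapes `validation_curve` returns are worth noting: one row per parameter value, one column per fold. A self-contained sketch on synthetic data (via `make_classification`, so it runs without mushrooms.csv):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

# Synthetic stand-in for the mushroom features.
X_demo, y_demo = make_classification(n_samples=300, n_features=10,
                                     random_state=0)

param_range = np.logspace(-4, 1, 6)
train_scores, test_scores = validation_curve(
    SVC(kernel='rbf', C=1, random_state=0), X_demo, y_demo,
    param_name='gamma', param_range=param_range,
    cv=3, scoring='accuracy')
```

With 6 gamma values and cv=3, both score arrays have shape (6, 3); averaging over axis 1, as in the function above, collapses the three folds into one mean accuracy per gamma.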
The gamma range examined is 0.0001 to 10, so six gamma values are checked, and each gamma value is evaluated with 3-fold cross-validation. Averaging the three folds for each gamma value gives the following results:
test set:
training set:
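A sketch of how these averaged curves can be plotted and the best gamma read off programmatically. The score arrays here are made up, only shaped like what the function above returns (one mean accuracy per gamma):

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend so the script runs anywhere
import matplotlib.pyplot as plt

param_range = np.logspace(-4, 1, 6)
# Hypothetical fold-averaged accuracies, one value per gamma.
trains = np.array([0.57, 0.93, 0.99, 1.00, 1.00, 1.00])
tests = np.array([0.57, 0.92, 0.99, 1.00, 0.99, 0.52])

# Gamma spans five orders of magnitude, so use a log-scaled x axis.
fig, ax = plt.subplots()
ax.semilogx(param_range, trains, label='training accuracy')
ax.semilogx(param_range, tests, label='cross-validated accuracy')
ax.set_xlabel('gamma')
ax.set_ylabel('mean accuracy')
ax.legend()

# The gamma with the highest mean cross-validated accuracy.
best_gamma = param_range[np.argmax(tests)]
```

With these illustrative numbers the cross-validated accuracy peaks at gamma = 0.1, matching the conclusion below.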
We can see that at gamma index 0 (i.e. 0.0001) the training set has its worst accuracy, so the model is underfitting there, while at gamma index 5 (i.e. 10) the test set has its worst accuracy even though the training set has its best accuracy, so the model is overfitting there. A better choice is gamma = 0.1.