- import pandas as pd
- import numpy as np
- from sklearn.model_selection import train_test_split
- # load the UCI mushroom dataset and one-hot encode all categorical columns
- mush_df = pd.read_csv('mushrooms.csv')
- mush_df2 = pd.get_dummies(mush_df)
- # column 1 is the one-hot indicator for the poisonous class; columns 2 onward are the features
- X_mush = mush_df2.iloc[:,2:]
- y_mush = mush_df2.iloc[:,1]
- # use the variables X_train2, y_train2 for Question 5
- X_train2, X_test2, y_train2, y_test2 = train_test_split(X_mush, y_mush, random_state=0)
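As a quick sanity check (an optional sketch, not part of the original code), the shape of the one-hot encoded data can be inspected:
- # get_dummies expands every categorical column into one indicator column per category value
- print(mush_df2.shape, X_train2.shape, X_test2.shape)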
Using a decision tree classifier
We train a decision tree classifier and then find the five most important features of the resulting model:
- from sklearn.tree import DecisionTreeClassifier
- clf = DecisionTreeClassifier(random_state=0).fit(X_train2, y_train2)
- # plot the tree (plot_decision_tree is the helper from the course's adspy_shared_utilities module)
- import matplotlib.pyplot as plt
- from adspy_shared_utilities import plot_decision_tree
- # y_train2 is the 'poisonous' indicator, so class 0 = non-poison and class 1 = poison
- plot_decision_tree(clf, list(X_train2), np.asarray(['non-poison', 'poison']))
Below is the full decision tree visualization; it is too large to fit on a single screen:
Next we find the five most important features:
- # pair each feature name with its importance, sort by importance in descending order,
- # and keep the names of the top five features
- zipped = list(zip(list(X_train2), list(clf.feature_importances_)))
- zipped.sort(key=lambda t: t[1], reverse=True)
- result = [x[0] for x in zipped[0:5]]
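For reference, the same top-five list can be obtained more concisely with a pandas Series (a minimal alternative sketch; the printed feature names depend on the one-hot encoded columns):
- # rank feature importances by name and keep the five largest
- top5 = pd.Series(clf.feature_importances_, index=X_train2.columns).nlargest(5)
- print(top5.index.tolist())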
The five features are the following:
Using an SVM classifier
Next we use an SVM classifier (RBF kernel) and compute a validation curve over the gamma parameter, to see which range of gamma values gives the best accuracy:
- def validation_curve():
-     # X_subset and y_subset are assumed to be defined earlier as a smaller,
-     # representative subset of the data (e.g. the held-out split)
-     from sklearn.svm import SVC
-     from sklearn.model_selection import validation_curve
-     # six gamma values, logarithmically spaced from 0.0001 to 10
-     param_range = np.logspace(-4, 1, 6)
-     train_scores, test_scores = validation_curve(
-         SVC(kernel='rbf', C=1, random_state=0), X_subset, y_subset,
-         param_name='gamma', param_range=param_range,
-         cv=3, scoring='accuracy')
-     # average the three cross-validation folds for each gamma value
-     trains = np.mean(train_scores, axis=1)
-     tests = np.mean(test_scores, axis=1)
-     return trains, tests
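Calling the function gives the averaged training and test accuracies for each gamma value (a minimal usage sketch; the actual numbers are not reproduced here):
- trains, tests = validation_curve()
- for g, tr, te in zip(np.logspace(-4, 1, 6), trains, tests):
-     print(f'gamma={g:g}: train={tr:.3f}, test={te:.3f}')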
The gamma range examined is 0.0001 to 10, so six gamma values are evaluated, each with 3-fold cross-validation. Averaging the three folds for each gamma value gives the following results:
Test set (mean accuracy per gamma):
Training set (mean accuracy per gamma):
You can see that at the first gamma value (i.e. 0.0001) the training set has its worst accuracy, so the model is underfitting; at the last gamma value (i.e. 10) the test set has its worst accuracy while the training set has its best accuracy, so the model is overfitting. A better choice is therefore gamma = 0.1.
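The same conclusion can be read off programmatically by taking the gamma value with the highest averaged test accuracy (a short sketch using the trains/tests values returned by validation_curve() above):
- param_range = np.logspace(-4, 1, 6)
- best_gamma = param_range[np.argmax(tests)]  # gamma with the highest mean test accuracy
- print('best gamma:', best_gamma)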