import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

mush_df = pd.read_csv('mushrooms.csv')
mush_df2 = pd.get_dummies(mush_df)

X_mush = mush_df2.iloc[:, 2:]
y_mush = mush_df2.iloc[:, 1]

# use the variables X_train2, y_train2 for Question 5
X_train2, X_test2, y_train2, y_test2 = train_test_split(X_mush, y_mush, random_state=0)
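Since the whole pipeline depends on `pd.get_dummies` turning every categorical column into 0/1 indicator columns, here is a minimal sketch on a toy frame (the column names and values are hypothetical, chosen only to mimic the mushroom data's style):

```python
import pandas as pd

# Toy categorical frame (hypothetical values); pd.get_dummies expands each
# categorical column into one 0/1 indicator column per category.
toy = pd.DataFrame({'cap-color': ['n', 'y', 'n'],
                    'odor': ['p', 'a', 'a']})
toy_dummies = pd.get_dummies(toy)
print(list(toy_dummies.columns))
# -> ['cap-color_n', 'cap-color_y', 'odor_a', 'odor_p']
```

This is why `mush_df2` has many more columns than `mush_df`: each original column contributes one indicator per distinct value.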
Using a decision tree classifier
We train a decision tree classifier and then find the five most important features of the resulting model:

from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(random_state=0).fit(X_train2, y_train2)

# plot the tree
import matplotlib.pyplot as plt
from adspy_shared_utilities import plot_decision_tree
plot_decision_tree(clf, list(X_train2), np.asarray(['poison', 'non-poison']))
The full decision tree visualization is shown below; it is too large to fit in a single view:
Find the five most important features:
# pair each feature name with its importance, sort descending, keep the top five
zipped = list(zip(list(X_train2), clf.feature_importances_))
zipped.sort(key=lambda t: t[1], reverse=True)
result = [x[0] for x in zipped[:5]]
They are the following:
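The zip-and-sort above can also be written as a one-liner with a pandas Series. The feature names and importance values below are hypothetical stand-ins for `list(X_train2)` and `clf.feature_importances_`; the real values come from the fitted tree:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for clf.feature_importances_ and list(X_train2)
feature_names = ['odor_n', 'stalk-root_c', 'cap-color_w',
                 'odor_l', 'spore-print-color_r', 'gill-size_b']
importances = np.array([0.62, 0.17, 0.09, 0.06, 0.04, 0.02])

# Index the importances by feature name, then take the five largest
top5 = pd.Series(importances, index=feature_names).nlargest(5).index.tolist()
print(top5)
# -> ['odor_n', 'stalk-root_c', 'cap-color_w', 'odor_l', 'spore-print-color_r']
```

Both approaches give the same answer; the Series version is just easier to inspect interactively.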
Using an SVM classifier
Next we use an SVM classifier (RBF kernel) and build a validation curve over the gamma parameter, to see in which range gamma gives the best accuracy (the helper function is renamed so it does not shadow sklearn's validation_curve):

def gamma_validation_curve():
    from sklearn.svm import SVC
    from sklearn.model_selection import validation_curve

    # X_subset, y_subset: a subset of the training data defined earlier in the assignment
    param_range = np.logspace(-4, 1, 6)
    train_scores, test_scores = validation_curve(
        SVC(kernel='rbf', C=1, random_state=0), X_subset, y_subset,
        param_name='gamma', param_range=param_range,
        cv=3, scoring='accuracy')

    # average the three folds for each gamma value
    trains = np.mean(train_scores, axis=1)
    tests = np.mean(test_scores, axis=1)
    return trains, tests
The gamma range examined is 0.0001 to 10, so there are six gamma values to check, and for each gamma we run 3-fold cross-validation. Averaging the three folds for each gamma gives the results below:
test set:
training set:
We can see that at the 0th gamma value (i.e. 0.0001) the training set has its worst accuracy, so the model is underfitting there; at the 5th gamma value (i.e. 10) the test set has its worst accuracy while the training set has its best, so the model is overfitting. A better choice is gamma = 0.1.
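The "pick gamma = 0.1" step can be automated by taking the gamma whose mean cross-validated test accuracy peaks. The scores below are illustrative stand-ins for the `tests` array returned above, shaped to match the under/overfitting pattern just described:

```python
import numpy as np

# Six gamma values examined, from 1e-4 to 10
param_range = np.logspace(-4, 1, 6)

# Hypothetical mean test accuracies (stand-ins for the real `tests` array):
# low at both extremes (underfit / overfit), peaking in the middle
tests = np.array([0.56, 0.89, 0.92, 0.95, 0.90, 0.52])

# Best gamma = the one with the highest mean cross-validated test accuracy
best_gamma = param_range[np.argmax(tests)]
print(best_gamma)
# -> 0.1
```

With the real scores, the same `argmax` over the test-set means lands on gamma = 0.1, matching the visual reading of the validation curve.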