code

Monday, September 18, 2017

Applied Machine Learning in Python 9 - Lab: Supervised Classification

Suppose we have the following dataset (poisonous mushrooms):

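  import pandas as pd
  import numpy as np
  from sklearn.model_selection import train_test_split

  mush_df = pd.read_csv('mushrooms.csv')
  # One-hot encode every categorical column; 'class' becomes class_e and class_p
  mush_df2 = pd.get_dummies(mush_df)

  # Columns 2 onward are the one-hot features; column 1 (class_p) is the label, 1 = poisonous
  X_mush = mush_df2.iloc[:,2:]
  y_mush = mush_df2.iloc[:,1]

  # use the variables X_train2, y_train2 for Question 5
  X_train2, X_test2, y_train2, y_test2 = train_test_split(X_mush, y_mush, random_state=0)

  # use the variables X_subset, y_subset for Questions 6 and 7
  # (the 25% test split is reused as a smaller subset to keep the SVM runs fast)
  X_subset = X_test2
  y_subset = y_test2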
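To show why the code takes column 1 as the label and columns 2 onward as the features, here is a minimal sketch on a hypothetical three-row frame. The frame imitates the layout of mushrooms.csv: a `class` column with values `e`/`p` first, followed by categorical features. The `odor` and `cap-color` values are made up for illustration. `pd.get_dummies` keeps the original column order and sorts categories within each column, so `class_e` and `class_p` come first.

```python
import pandas as pd

# Hypothetical toy frame mimicking mushrooms.csv: a 'class' label column first,
# followed by categorical feature columns (values here are illustrative only).
toy_df = pd.DataFrame({
    "class": ["p", "e", "e"],
    "odor": ["p", "a", "n"],
    "cap-color": ["n", "y", "w"],
})

# One-hot encode every column; dummy columns keep the original column order,
# and categories within a column are sorted, so 'class_e' and 'class_p' come first.
toy_dummies = pd.get_dummies(toy_df)
print(list(toy_dummies.columns))

X_toy = toy_dummies.iloc[:, 2:]  # all one-hot feature columns
y_toy = toy_dummies.iloc[:, 1]   # 'class_p': 1 (True) = poisonous, 0 (False) = edible
print(y_toy.name, y_toy.astype(int).tolist())
```

Depending on the pandas version, the dummy columns are `bool` or `uint8`, which is why the label is cast with `astype(int)` before printing.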

Using a decision tree classifier

We train a decision tree classifier and then find the five most important features of the resulting model:

  from sklearn.tree import DecisionTreeClassifier
  # plot_decision_tree is a helper from the course's adspy_shared_utilities module
  from adspy_shared_utilities import plot_decision_tree

  clf = DecisionTreeClassifier(random_state=0).fit(X_train2, y_train2)

  # plot tree; class names must follow clf.classes_ order: 0 = edible, 1 = poisonous
  plot_decision_tree(clf, list(X_train2), np.asarray(['non-poison','poison']))
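If the course helper is not available, or the graphical tree is too large to read, scikit-learn's `sklearn.tree.export_text` prints the tree as indented text. The sketch below fits a tree on hypothetical one-hot data that stands in for `X_train2` and `y_train2`; it is not the real mushroom data. It also prints `classes_`, which shows the order any `class_names` argument should follow.

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical one-hot features standing in for X_train2 (not the real mushroom data).
X_demo = pd.DataFrame({
    "odor_n": [1, 1, 0, 0, 1, 0],
    "odor_f": [0, 0, 1, 1, 0, 0],
    "spore-print-color_r": [0, 0, 0, 0, 0, 1],
})
# 1 = poisonous, 0 = edible (same convention as the class_p column)
y_demo = np.array([0, 0, 1, 1, 0, 1])

clf_demo = DecisionTreeClassifier(random_state=0).fit(X_demo, y_demo)

# classes_ is sorted, so class names passed to a plotting helper
# should be listed in this order: index 0 = edible, index 1 = poisonous.
print(clf_demo.classes_)

# Text rendering of the tree, handy when the graphical plot does not fit on one screen.
print(export_text(clf_demo, feature_names=list(X_demo.columns)))
```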

Below is the full decision tree visualization; it is too large to fit on one screen:


Finding the five most important features:

  # Pair each feature name with its importance, sort by importance (descending), keep the top five
  zipped = list(zip(list(X_train2), list(clf.feature_importances_)))
  zipped.sort(key=lambda t: t[1], reverse=True)
  result = [x[0] for x in zipped[0:5]]
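The same top-five selection can be written more compactly with pandas. Index a `Series` by feature name and call `nlargest`. The names and importances below are hypothetical placeholders for `list(X_train2)` and `clf.feature_importances_`; they are not the lab's results.

```python
import numpy as np
import pandas as pd

# Hypothetical feature names and importances standing in for
# list(X_train2) and clf.feature_importances_.
feature_names = ["feat_a", "feat_b", "feat_c", "feat_d", "feat_e", "feat_f"]
importances = np.array([0.60, 0.00, 0.18, 0.12, 0.03, 0.07])

# Label each importance with its feature name and keep the five largest (descending order).
top5 = pd.Series(importances, index=feature_names).nlargest(5).index.tolist()
print(top5)

# Cross-check against the zip/sort approach.
zipped = sorted(zip(feature_names, importances), key=lambda t: t[1], reverse=True)
assert top5 == [name for name, _ in zipped[:5]]
```

The two approaches agree here because the importances have no ties. With ties, the order among equal values may differ.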

For this model, the five most important features are:




Using an SVM classifier

Next, we use an SVM classifier with an RBF kernel and compute a validation curve over the gamma parameter. This shows which range of gamma values gives the best accuracy:

  def validation_curve():
      from sklearn.svm import SVC
      # This local import shadows the enclosing function's name inside the body,
      # so the call below uses scikit-learn's validation_curve
      from sklearn.model_selection import validation_curve

      # Six gamma values: 1e-4, 1e-3, 1e-2, 1e-1, 1, 10
      param_range = np.logspace(-4,1,6)
      train_scores, test_scores = validation_curve(SVC(kernel='rbf', C=1, random_state=0), X_subset, y_subset,
                                                   param_name='gamma',
                                                   param_range=param_range, cv=3,
                                                   scoring='accuracy')

      # Scores have shape (6, 3): average the 3 folds for each gamma value
      trains = np.mean(train_scores, axis=1)
      tests = np.mean(test_scores, axis=1)

      return trains, tests
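mushrooms.csv may not be at hand, so here is a self-contained sketch of the same `validation_curve` call on synthetic data from `make_classification`. The synthetic data is an assumption that stands in for `X_subset` and `y_subset`. The sketch shows the shape of the returned scores: one row per gamma value and one column per fold. The call passes `param_name` and `param_range` as keywords, which newer scikit-learn versions require.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

# Synthetic stand-in for X_subset, y_subset (the lab uses the mushroom data instead).
X_demo, y_demo = make_classification(n_samples=300, n_features=10, random_state=0)

param_range = np.logspace(-4, 1, 6)  # 1e-4, 1e-3, 1e-2, 1e-1, 1, 10

# train_scores and test_scores have shape (len(param_range), cv):
# one row per gamma value, one column per fold.
train_scores, test_scores = validation_curve(
    SVC(kernel="rbf", C=1, random_state=0), X_demo, y_demo,
    param_name="gamma", param_range=param_range, cv=3, scoring="accuracy",
)

# Average over the folds (axis=1) to get one mean score per gamma value.
train_mean = train_scores.mean(axis=1)
test_mean = test_scores.mean(axis=1)
for gamma, tr, te in zip(param_range, train_mean, test_mean):
    print(f"gamma={gamma:g}  train={tr:.3f}  test={te:.3f}")
```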


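The gamma values examined span 0.0001 to 10, one per power of ten, which gives six gamma values. Each gamma value is evaluated with 3-fold cross-validation. Averaging the three folds for each gamma value gives the following results: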

test set:

training set:


The training set has its worst accuracy at gamma index 0 (i.e. gamma = 0.0001), which indicates underfitting. At gamma index 5 (i.e. gamma = 10), the test set has its worst accuracy while the training set has its best, which indicates overfitting. A better choice is gamma = 0.1.
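The reasoning above can be turned into a rough rule of thumb. The sketch below uses a hypothetical helper, `summarize_validation_curve`, which is not part of the lab or scikit-learn. It picks three gamma values:

- the underfitting value, where mean training accuracy is lowest;
- the overfitting value, where the train/test gap is largest;
- a good value, where mean test accuracy is highest.

The "largest gap" rule is a stand-in for the "worst test, best training" description above. The scores are toy numbers, not the lab's results.

```python
import numpy as np


def summarize_validation_curve(param_range, train_mean, test_mean):
    """Hypothetical helper: map mean validation-curve scores to
    (underfitting, overfitting, good) parameter values.

    Heuristics (assumptions, not the only reasonable choices):
      underfitting -> lowest mean training accuracy
      overfitting  -> largest gap between mean training and test accuracy
      good         -> highest mean test accuracy
    """
    param_range = np.asarray(param_range)
    train_mean = np.asarray(train_mean)
    test_mean = np.asarray(test_mean)
    underfit = param_range[np.argmin(train_mean)]
    overfit = param_range[np.argmax(train_mean - test_mean)]
    good = param_range[np.argmax(test_mean)]
    return underfit, overfit, good


# Toy scores for a six-point gamma curve (illustrative only, not real results).
gammas = np.logspace(-4, 1, 6)
toy_train = np.array([0.55, 0.85, 0.93, 0.99, 1.00, 1.00])
toy_test = np.array([0.54, 0.84, 0.92, 0.98, 0.80, 0.52])
print(summarize_validation_curve(gammas, toy_train, toy_test))
```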


