
Monday, September 18, 2017

Applied Machine Learning in Python 9 - Lab: Supervised Classification

Suppose we have the following dataset (poisonous mushrooms):

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split


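# Load the mushroom data and one-hot encode every categorical column
mush_df = pd.read_csv('mushrooms.csv')
mush_df2 = pd.get_dummies(mush_df)

# After encoding, column 0 is class_e and column 1 is class_p (1 = poisonous);
# the remaining columns are the features
X_mush = mush_df2.iloc[:,2:]
y_mush = mush_df2.iloc[:,1]

# use the variables X_train2, y_train2 for Question 5
X_train2, X_test2, y_train2, y_test2 = train_test_split(X_mush, y_mush, random_state=0)

# For the SVM validation curve, reuse the 25% test split as a smaller subset
X_subset = X_test2
y_subset = y_test2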
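To see why `iloc[:, 1]` is the label and `iloc[:, 2:]` holds the features, here is a minimal sketch on a tiny hand-made frame. The three rows and the `odor` values are hypothetical, not taken from mushrooms.csv. `get_dummies` one-hot encodes each column and orders each column's dummies alphabetically, so `class` becomes `class_e` followed by `class_p`:

```python
import pandas as pd

# Hypothetical toy data with the same layout as mushrooms.csv:
# the first column is the label ('e' = edible, 'p' = poisonous).
toy = pd.DataFrame({
    "class": ["p", "e", "e"],
    "odor": ["p", "a", "n"],
})

# One-hot encode every column; each column's dummies come out in sorted order.
toy2 = pd.get_dummies(toy)
print(list(toy2.columns))  # ['class_e', 'class_p', 'odor_a', 'odor_n', 'odor_p']

# Column 0 is class_e, column 1 is class_p (1 = poisonous), the rest are features.
y_toy = toy2.iloc[:, 1].astype(int)
X_toy = toy2.iloc[:, 2:]
print(y_toy.tolist())  # [1, 0, 0]
```

Recent pandas versions return boolean dummy columns, so the `astype(int)` is only there to make the labels print as 0/1.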



Using a decision tree classifier

We train a decision tree classifier and then find the five most important features of the resulting model:

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0).fit(X_train2, y_train2)

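# Plot the tree (plot_decision_tree is a helper from the course's adspy_shared_utilities module).
# Class names follow clf.classes_ order: 0 = edible (non-poison), 1 = poisonous
from adspy_shared_utilities import plot_decision_tree
plot_decision_tree(clf, list(X_train2), np.asarray(['non-poison', 'poison']))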

Below is the full decision tree visualization; it is too large to fit on one screen:
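Since the full graph does not fit on one screen, a text view cut off at the top levels can be easier to read. One option (a swapped-in technique, not the one used in this post) is scikit-learn's `export_text`. The sketch uses synthetic data from `make_classification` as a stand-in for the mushroom data, and the feature names `f0` to `f4` are made up:

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic stand-in for the mushroom data (assumption: 200 rows, 5 features).
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
feature_names = [f"f{i}" for i in range(X.shape[1])]

clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Show only the top two levels; deeper branches are reported as truncated.
report = export_text(clf, feature_names=feature_names, max_depth=2)
print(report)
```

On the real model you would pass `list(X_train2)` as `feature_names`.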


Finding the five most important features:

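# Pair each feature name with its importance, sort by importance (descending),
# and keep the names of the top five
zipped = list(zip(list(X_train2), list(clf.feature_importances_)))
zipped.sort(key=lambda t: t[1], reverse=True)
result = [x[0] for x in zipped[0:5]]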
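The same ranking can be written more compactly by wrapping `feature_importances_` in a pandas Series indexed by column name and calling `nlargest`. Here is a minimal sketch on synthetic stand-in data with hypothetical column names (`feat_0` to `feat_7`), not the real mushroom columns:

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data (assumption); column names are hypothetical.
X_arr, y = make_classification(n_samples=300, n_features=8, random_state=0)
X = pd.DataFrame(X_arr, columns=[f"feat_{i}" for i in range(8)])

clf = DecisionTreeClassifier(random_state=0).fit(X, y)

# Importances are non-negative and, for a tree with at least one split, sum to 1.
importances = pd.Series(clf.feature_importances_, index=X.columns)

# Names of the five most important features, highest first.
top5 = importances.nlargest(5).index.tolist()
print(top5)
```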

They are:




Using an SVM classifier

Next, we use an SVM classifier with an RBF kernel to compute a validation curve over the gamma parameter, to find the range of gamma values that gives the best accuracy:

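def gamma_validation_curve():
    from sklearn.svm import SVC
    from sklearn.model_selection import validation_curve

    # Six gamma values from 1e-4 to 10, one decade apart
    param_range = np.logspace(-4,1,6)
    # 3-fold cross-validated accuracy of an RBF SVM for each gamma value,
    # computed on X_subset, y_subset (the smaller subset of the mushroom data)
    train_scores, test_scores = validation_curve(SVC(kernel='rbf', C=1, random_state=0), X_subset, y_subset,
                                                 param_name='gamma',
                                                 param_range=param_range, cv=3,
                                                 scoring='accuracy')

    # Average the three fold scores for each gamma value
    trains = np.mean(train_scores, axis=1)
    tests = np.mean(test_scores, axis=1)

    return trains, tests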
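For reference, the RBF kernel that gamma controls is the standard Gaussian kernel below. A larger gamma makes each training point's influence more local (a more flexible boundary, with a risk of overfitting); a smaller gamma spreads it out (a smoother boundary, with a risk of underfitting):

```latex
K(\mathbf{x}, \mathbf{x}') = \exp\left(-\gamma \,\lVert \mathbf{x} - \mathbf{x}' \rVert^{2}\right), \qquad \gamma > 0
```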


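The gamma range examined is 0.0001 to 10 (np.logspace(-4, 1, 6)), which gives six gamma values: 0.0001, 0.001, 0.01, 0.1, 1 and 10. Each gamma value is evaluated with 3-fold cross-validation. Averaging the three fold scores for each gamma gives the following results: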
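To make the grid and the fold structure concrete, here is a small sketch. For a classifier, an integer `cv=3` makes scikit-learn use an unshuffled `StratifiedKFold(n_splits=3)`, so every gamma value is fitted three times. The 12-sample label array is hypothetical and only serves to show how the folds are formed:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Six gamma values, one decade apart: 1e-4, 1e-3, 1e-2, 1e-1, 1, 10.
param_range = np.logspace(-4, 1, 6)
print(param_range)

# Hypothetical balanced labels; the features do not affect how folds are formed.
y = np.array([0, 1] * 6)
X = np.zeros((len(y), 1))

# What cv=3 means for a classifier: 3 stratified, unshuffled folds.
cv = StratifiedKFold(n_splits=3)
folds = [(train_idx, test_idx) for train_idx, test_idx in cv.split(X, y)]

# Each gamma is fitted once per fold: 6 x 3 = 18 fits in total.
n_fits = len(param_range) * len(folds)
print(n_fits)
```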

test set:

training set:


On the training set, the worst accuracy occurs at index 0 (gamma = 0.0001), so that setting underfits. At index 5 (gamma = 10), the test folds have their worst accuracy while the training set has its best, so that setting overfits. A better choice is gamma = 0.1 (index 3).
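The choice above can be made mechanical: take the gamma with the lowest mean training score as the underfitting candidate, the gamma with the largest train/test gap as the overfitting candidate (a gap-based variant of the post's "worst test, best train" rule), and the gamma with the highest mean test score as the pick. A sketch on synthetic data from `make_classification`, so the values it selects need not match the mushroom results:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

# Synthetic stand-in data (assumption), not the mushroom subset.
X, y = make_classification(n_samples=150, n_features=6, random_state=0)
param_range = np.logspace(-4, 1, 6)

train_scores, test_scores = validation_curve(
    SVC(kernel="rbf", C=1), X, y,
    param_name="gamma", param_range=param_range, cv=3, scoring="accuracy",
)

# Score arrays have shape (n_gamma_values, n_folds); average over folds.
train_mean = train_scores.mean(axis=1)
test_mean = test_scores.mean(axis=1)

# Train minus test accuracy per gamma: a large gap suggests overfitting.
gap = train_mean - test_mean

underfit_gamma = param_range[np.argmin(train_mean)]  # weakest fit on training data
overfit_gamma = param_range[np.argmax(gap)]          # largest generalization gap
best_gamma = param_range[np.argmax(test_mean)]       # best mean test accuracy

for g, tr, te in zip(param_range, train_mean, test_mean):
    print(f"gamma={g:g}  train={tr:.3f}  test={te:.3f}")
print("underfit:", underfit_gamma, "overfit:", overfit_gamma, "best:", best_gamma)
```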


