Drzewo decyzyjne w R - Drzewo klasyfikacji & Kod w R z przykładem

Spisie treści:

Anonim

Co to są drzewa decyzyjne?

Drzewa decyzyjne to wszechstronny algorytm uczenia maszynowego, który może wykonywać zarówno zadania klasyfikacji, jak i regresji. Są to bardzo potężne algorytmy, zdolne do dopasowywania złożonych zbiorów danych. Poza tym drzewa decyzyjne są podstawowymi składnikami lasów losowych, które należą do najsilniejszych dostępnych obecnie algorytmów uczenia maszynowego.

Szkolenie i wizualizacja drzew decyzyjnych

Aby zbudować pierwsze drzewo decyzyjne w przykładzie R, postępujemy w następujący sposób w tym samouczku Drzewo decyzyjne:

  • Krok 1: Zaimportuj dane
  • Krok 2: Wyczyść zbiór danych
  • Krok 3: Utwórz pociąg / zestaw testowy
  • Krok 4: Zbuduj model
  • Krok 5: Przewiduj
  • Krok 6: Zmierz wydajność
  • Krok 7: Dostrój hiperparametry

Krok 1) Zaimportuj dane

Jeśli jesteście ciekawi losów Titanica, możecie obejrzeć ten film na Youtube. Celem tego zbioru danych jest przewidywanie, którzy ludzie z większym prawdopodobieństwem przeżyją po zderzeniu z górą lodową. Zbiór danych zawiera 13 zmiennych i 1309 obserwacji. Zbiór danych jest uporządkowany według zmiennej X.

set.seed(678)path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'titanic <-read.csv(path)head(titanic)

Wynik:

## X pclass survived name sex## 1 1 1 1 Allen, Miss. Elisabeth Walton female## 2 2 1 1 Allison, Master. Hudson Trevor male## 3 3 1 0 Allison, Miss. Helen Loraine female## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female## 6 6 1 1 Anderson, Mr. Harry male## age sibsp parch ticket fare cabin embarked## 1 29.0000 0 0 24160 211.3375 B5 S## 2 0.9167 1 2 113781 151.5500 C22 C26 S## 3 2.0000 1 2 113781 151.5500 C22 C26 S## 4 30.0000 1 2 113781 151.5500 C22 C26 S## 5 25.0000 1 2 113781 151.5500 C22 C26 S## 6 48.0000 0 0 19952 26.5500 E12 S## home.dest## 1 St Louis, MO## 2 Montreal, PQ / Chesterville, ON## 3 Montreal, PQ / Chesterville, ON## 4 Montreal, PQ / Chesterville, ON## 5 Montreal, PQ / Chesterville, ON## 6 New York, NY
tail(titanic)

Wynik:

## X pclass survived name sex age sibsp## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0## parch ticket fare cabin embarked home.dest## 1304 0 2627 14.4583 C## 1305 0 2665 14.4542 C## 1306 0 2665 14.4542 C## 1307 0 2656 7.2250 C## 1308 0 2670 7.2250 C## 1309 0 315082 7.8750 S

Z wyjścia głowy i ogona można zauważyć, że dane nie są przetasowane. To jest duży problem! Kiedy podzielisz swoje dane między skład pociągu i zestaw testowy, wybierzesz tylko pasażera z klasy 1 i 2 (żaden pasażer z klasy 3 nie znajduje się w górnych 80 procentach obserwacji), co oznacza, że ​​algorytm nigdy nie zobaczy cechy pasażera klasy 3. Ten błąd doprowadzi do złego przewidywania.

Aby rozwiązać ten problem, możesz użyć funkcji sample ().

shuffle_index <- sample(1:nrow(titanic))head(shuffle_index)

Drzewo decyzyjne Kod R Objaśnienie

  • sample (1: nrow (titanic)): generuje losową listę indeksów od 1 do 1309 (tj. maksymalna liczba wierszy).

Wynik:

## [1] 288 874 1078 633 887 992 

Będziesz używać tego indeksu do tasowania zbioru danych Titanic.

titanic <- titanic[shuffle_index, ]head(titanic)

Wynik:

## X pclass survived## 288 288 1 0## 874 874 3 0## 1078 1078 3 1## 633 633 3 0## 887 887 3 1## 992 992 3 1## name sex age## 288 Sutton, Mr. Frederick male 61## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42## 1078 O'Driscoll, Miss. Bridget female NA## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39## 887 Jermyn, Miss. Annie female NA## 992 Mamee, Mr. Hanna male NA## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ## 874 0 0 348121 7.6500 F G63 S## 1078 0 0 14311 7.7500 Q## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN## 887 0 0 14313 7.7500 Q## 992 0 0 2677 7.2292 C

Krok 2) Wyczyść zbiór danych

Struktura danych pokazuje, że niektóre zmienne mają NA. Wyczyść dane w następujący sposób

  • Upuść zmienne home.dest, cabin, name, X i ticket
  • Utwórz zmienne czynnikowe dla pclass i przeżyłem
  • Porzuć NA
library(dplyr)# Drop variablesclean_titanic <- titanic % > %select(-c(home.dest, cabin, name, X, ticket)) % > %#Convert to factor levelmutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %na.omit()glimpse(clean_titanic)

Objaśnienie kodu

  • select (-c (home.dest, kabina, nazwa, X, bilet)): Usuń niepotrzebne zmienne
  • pclass = factor (pclass, levels = c (1,2,3), labels = c ('Upper', 'Middle', 'Lower')): Dodaj etykietę do zmiennej pclass. 1 staje się górny, 2 staje się miękki, a 3 staje się niższy
  • czynnik (przeżył, poziomy = c (0,1), etykiety = c („Nie”, „Tak”)): Dodaj etykietę do zmiennej, która przeżyła. 1 staje się nie, a 2 staje się tak
  • na.omit (): Usuń obserwacje NA

Wynik:

## Observations: 1,045## Variables: 8## $ pclass  Upper, Lower, Lower, Upper, Middle, Upper, Middle, U… ## $ survived  No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y… ## $ sex  male, male, female, female, male, male, female, male… ## $ age  61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0,… ## $ sibsp  0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,… ## $ parch  0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,… ## $ fare  32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542,… ## $ embarked  S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C… 

Krok 3) Utwórz pociąg / zestaw testowy

Zanim wytrenujesz model, musisz wykonać dwa kroki:

  • Utwórz pociąg i zestaw testowy: trenujesz model w zestawie pociągu i testujesz prognozę na zestawie testowym (tj. Niewidoczne dane)
  • Zainstaluj rpart.plot z konsoli

Powszechną praktyką jest dzielenie danych na 80/20, 80% danych służy do trenowania modelu, a 20% do prognozowania. Musisz utworzyć dwie oddzielne ramki danych. Nie chcesz dotykać zestawu testowego, dopóki nie skończysz budować modelu. Możesz utworzyć nazwę funkcji create_train_test (), która przyjmuje trzy argumenty.

create_train_test(df, size = 0.8, train = TRUE)arguments:-df: Dataset used to train the model.-size: Size of the split. By default, 0.8. Numerical value-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample < - 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}

Objaśnienie kodu

  • function (data, size = 0.8, train = TRUE): Dodaj argumenty w funkcji
  • n_row = nrow (dane): Policz liczbę wierszy w zbiorze danych
  • total_row = size * n_row: Zwróć n-ty wiersz, aby skonstruować zestaw pociągów
  • train_sample <- 1: total_row: Wybierz wiersz od pierwszego do n-tego
  • if (train == TRUE) {} else {}: Jeśli warunek ma wartość true, zwraca zestaw pociągu, w przeciwnym razie zestaw testowy.

Możesz przetestować swoją funkcję i sprawdzić wymiar.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)dim(data_train)

Wynik:

## [1] 836 8
dim(data_test)

Wynik:

## [1] 209 8 

Zestaw danych pociągu ma 1046 wierszy, podczas gdy zestaw danych testowych ma 262 wiersze.

Używasz funkcji prop.table () w połączeniu z table (), aby sprawdzić, czy proces randomizacji jest prawidłowy.

prop.table(table(data_train$survived))

Wynik:

#### No Yes## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Wynik:

#### No Yes## 0.5789474 0.4210526

W obu zestawach danych liczba osób, które przeżyły, jest taka sama, około 40 procent.

Zainstaluj rpart.plot

rpart.plot nie jest dostępny w bibliotekach Conda. Możesz go zainstalować z konsoli:

install.packages("rpart.plot") 

Krok 4) Zbuduj model

Jesteś gotowy do zbudowania modelu. Składnia funkcji drzewa decyzyjnego Rpart jest następująca:

rpart(formula, data=, method='')arguments:- formula: The function to predict- data: Specifies the data frame- method:- "class" for a classification tree- "anova" for a regression tree

Używasz metody class, ponieważ przewidujesz klasę.

library(rpart)library(rpart.plot)fit <- rpart(survived~., data = data_train, method = 'class')rpart.plot(fit, extra = 106

Objaśnienie kodu

  • rpart (): Funkcja dopasowująca do modelu. Argumentami są:
    • przeżył ~ .: Formuła drzew decyzyjnych
    • data = data_train: zbiór danych
    • metoda = 'klasa': Dopasuj model binarny
  • rpart.plot (fit, extra = 106): Wykreśl drzewo. Dodatkowe funkcje są ustawione na 101, aby wyświetlić prawdopodobieństwo drugiej klasy (przydatne w przypadku odpowiedzi binarnych). Więcej informacji na temat innych opcji można znaleźć w winiecie.

Wynik:

Zaczynasz od węzła głównego (głębokość 0 ponad 3, góra wykresu):

  1. U góry jest to ogólne prawdopodobieństwo przeżycia. Pokazuje odsetek pasażerów, którzy przeżyli katastrofę. Przeżyło 41 procent pasażerów.
  2. Ten węzeł pyta, czy pasażer jest płci męskiej. Jeśli tak, to schodzisz do lewego węzła potomnego korzenia (głębokość 2). 63 procent to mężczyźni z prawdopodobieństwem przeżycia wynoszącym 21 procent.
  3. W drugim węźle pytasz, czy pasażer płci męskiej ma powyżej 3,5 roku życia. Jeśli tak, to szansa na przeżycie wynosi 19 procent.
  4. Idziesz dalej w ten sposób, aby zrozumieć, jakie cechy wpływają na prawdopodobieństwo przeżycia.

Zauważ, że jedną z wielu cech drzew decyzyjnych jest to, że wymagają one bardzo niewielkiego przygotowania danych. W szczególności nie wymagają skalowania ani centrowania funkcji.

Domyślnie funkcja rpart () używa miary nieczystości Giniego do podziału nuty. Im wyższy współczynnik Giniego, tym więcej różnych instancji w węźle.

Krok 5) Dokonaj prognozy

Możesz przewidzieć swój testowy zestaw danych. Aby dokonać prognozy, możesz użyć funkcji predykcji (). Podstawowa składnia predykcji dla drzewa decyzyjnego R to:

predict(fitted_model, df, type = 'class')arguments:- fitted_model: This is the object stored after model estimation.- df: Data frame used to make the prediction- type: Type of prediction- 'class': for classification- 'prob': to compute the probability of each class- 'vector': Predict the mean response at the node level

Chcesz przewidzieć, którzy pasażerowie z większym prawdopodobieństwem przeżyją po zderzeniu z zestawu testowego. Oznacza to, że spośród tych 209 pasażerów będziesz wiedział, który przeżyje, czy nie.

predict_unseen <-predict(fit, data_test, type = 'class')

Objaśnienie kodu

  • Predict (fit, data_test, type = 'class'): Przewiduj klasę (0/1) zestawu testowego

Testowanie pasażera, któremu się nie udało i tych, którym się to udało.

table_mat <- table(data_test$survived, predict_unseen)table_mat

Objaśnienie kodu

  • table (data_test $ Survived, Predict_unseen): Utwórz tabelę, aby policzyć, ilu pasażerów zostało sklasyfikowanych jako ocalałych i zmarłych w porównaniu z poprawną klasyfikacją drzewa decyzyjnego w R

Wynik:

## predict_unseen## No Yes## No 106 15## Yes 30 58

Model poprawnie przewidział 106 zabitych pasażerów, ale sklasyfikował 15 ocalałych jako martwych. Analogicznie, model błędnie sklasyfikował 30 pasażerów jako ocalałych, podczas gdy okazało się, że nie żyją.

Krok 6) Zmierz wydajność

Miarę dokładności dla zadania klasyfikacyjnego można obliczyć za pomocą macierzy pomyłki :

Matryca zamieszanie jest lepszym wyborem, aby ocenić skuteczność klasyfikacji. Ogólną ideą jest policzenie, ile razy instancje True są klasyfikowane jako fałszywe.

Każdy wiersz w macierzy nieporozumień reprezentuje rzeczywisty cel, podczas gdy każda kolumna przedstawia przewidywany cel. Pierwszy wiersz tej macierzy uwzględnia zmarłych pasażerów (klasa Fałsz): 106 zostało poprawnie zaklasyfikowanych jako zmarłych ( Prawdziwie ujemne ), podczas gdy pozostały został błędnie zaklasyfikowany jako ocalały ( Fałszywie dodatni ). Drugi rząd dotyczy osób, które przeżyły, klasa pozytywna wynosiła 58 ( prawdziwie pozytywnych ), podczas gdy prawdziwie negatywna klasa wynosiła 30.

Możesz obliczyć test dokładności z macierzy nieporozumień:

Jest to stosunek wartości prawdziwie dodatnich i prawdziwie ujemnych do sumy macierzy. Za pomocą R możesz kodować w następujący sposób:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Objaśnienie kodu

  • sum (diag (table_mat)): Suma przekątnej
  • sum (table_mat): Suma macierzy.

Możesz wydrukować dokładność zestawu testowego:

print(paste('Accuracy for test', accuracy_Test))

Wynik:

## [1] "Accuracy for test 0.784688995215311" 

Masz 78% punktów za zestaw testowy. Możesz powtórzyć to samo ćwiczenie z zestawem danych treningowych.

Krok 7) Dostrój hiperparametry

Drzewo decyzyjne w R ma różne parametry, które kontrolują aspekty dopasowania. W bibliotece drzewa decyzyjnego rpart można sterować parametrami za pomocą funkcji rpart.control (). W poniższym kodzie wprowadzasz parametry, które będziesz stroić. Inne parametry można znaleźć w winiecie.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)Arguments:-minsplit: Set the minimum number of observations in the node before the algorithm perform a split-minbucket: Set the minimum number of observations in the final note i.e. the leaf-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Będziemy postępować następująco:

  • Skonstruuj funkcję, aby zwrócić dokładność
  • Dostrój maksymalną głębokość
  • Dostrój minimalną liczbę próbek, które musi mieć węzeł, zanim będzie można go podzielić
  • Dostrój minimalną liczbę próbek, które musi mieć węzeł liścia

Możesz napisać funkcję wyświetlającą dokładność. Po prostu zawijasz kod, którego użyłeś wcześniej:

  1. Predict: Predictive_unseen <- Predict (fit, data_test, type = 'class')
  2. Wyprodukuj tabelę: table_mat <- table (data_test $ przetrwała, predykcja_niepostrzeżenie)
  3. Oblicz dokładność: dokładność_Test <- suma (diag (table_mat)) / sum (table_mat)
accuracy_tune <- function(fit) {predict_unseen <- predict(fit, data_test, type = 'class')table_mat <- table(data_test$survived, predict_unseen)accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test}

Możesz spróbować dostroić parametry i sprawdzić, czy możesz ulepszyć model w stosunku do wartości domyślnej. Przypominamy, że musisz uzyskać dokładność wyższą niż 0,78

control <- rpart.control(minsplit = 4,minbucket = round(5 / 3),maxdepth = 3,cp = 0)tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)accuracy_tune(tune_fit)

Wynik:

## [1] 0.7990431 

Z następującym parametrem:

minsplit = 4minbucket= round(5/3)maxdepth = 3cp=0 

Otrzymujesz wyższą wydajność niż w poprzednim modelu. Gratulacje!

Podsumowanie

Możemy podsumować funkcje trenujące algorytm drzewa decyzyjnego w R

Biblioteka

Cel

funkcjonować

klasa

parametry

Detale

rpart

Drzewo klasyfikacyjne pociągów w R

rpart ()

klasa

formuła, df, metoda

rpart

Drzewo regresji pociągu

rpart ()

anova

formuła, df, metoda

rpart

Działka drzew

rpart.plot ()

dopasowany model

baza

przepowiadać, wywróżyć

przepowiadać, wywróżyć()

klasa

model dopasowany, typ

baza

przepowiadać, wywróżyć

przepowiadać, wywróżyć()

prawd

model dopasowany, typ

baza

przepowiadać, wywróżyć

przepowiadać, wywróżyć()

wektor

model dopasowany, typ

rpart

Parametry dotyczące kontroli

rpart.control ()

minsplit

Ustaw minimalną liczbę obserwacji w węźle, zanim algorytm przeprowadzi podział

minbucket

Ustaw minimalną liczbę obserwacji w ostatniej nucie, czyli liścia

maksymalna głębokość

Ustaw maksymalną głębokość dowolnego węzła końcowego drzewa. Węzeł główny jest traktowany jako głębokość 0

rpart

Trenuj model z parametrem kontrolnym

rpart ()

formuła, df, metoda, kontrola

Uwaga: Wytrenuj model na danych uczących i przetestuj wydajność na niewidocznym zbiorze danych, tj. Zbiorze testowym.