Rozhodovací strom v R - Klasifikační strom & Kód v R s příkladem

Obsah:

Anonim

Co jsou rozhodovací stromy?

Rozhodovací stromy jsou univerzální algoritmus strojového učení, který může provádět klasifikační i regresní úlohy. Jsou to velmi výkonné algoritmy schopné přizpůsobit složité datové sady. Kromě toho jsou rozhodovací stromy základní složkou náhodných doménových struktur, které dnes patří k nejúčinnějším algoritmům strojového učení.

Školení a vizualizace rozhodovacích stromů

Chcete-li vytvořit svůj první rozhodovací strom v příkladu R, budeme v tomto tutoriálu rozhodovacího stromu postupovat následovně:

  • Krok 1: Importujte data
  • Krok 2: Vyčistěte datovou sadu
  • Krok 3: Vytvořte vlak / testovací sadu
  • Krok 4: Sestavte model
  • Krok 5: Proveďte předpověď
  • Krok 6: Měření výkonu
  • Krok 7: Nalaďte hyperparametry

Krok 1) Importujte data

Pokud vás zajímá osud Titanicu, můžete sledovat toto video na Youtube. Účelem této datové sady je předpovědět, u kterých lidí je větší pravděpodobnost, že po srážce s ledovcem přežijí. Datová sada obsahuje 13 proměnných a 1309 pozorování. Datová sada je uspořádána podle proměnné X.

set.seed(678)path <- 'https://raw.githubusercontent.com/guru99-edu/R-Programming/master/titanic_data.csv'titanic <-read.csv(path)head(titanic)

Výstup:

## X pclass survived name sex## 1 1 1 1 Allen, Miss. Elisabeth Walton female## 2 2 1 1 Allison, Master. Hudson Trevor male## 3 3 1 0 Allison, Miss. Helen Loraine female## 4 4 1 0 Allison, Mr. Hudson Joshua Creighton male## 5 5 1 0 Allison, Mrs. Hudson J C (Bessie Waldo Daniels) female## 6 6 1 1 Anderson, Mr. Harry male## age sibsp parch ticket fare cabin embarked## 1 29.0000 0 0 24160 211.3375 B5 S## 2 0.9167 1 2 113781 151.5500 C22 C26 S## 3 2.0000 1 2 113781 151.5500 C22 C26 S## 4 30.0000 1 2 113781 151.5500 C22 C26 S## 5 25.0000 1 2 113781 151.5500 C22 C26 S## 6 48.0000 0 0 19952 26.5500 E12 S## home.dest## 1 St Louis, MO## 2 Montreal, PQ / Chesterville, ON## 3 Montreal, PQ / Chesterville, ON## 4 Montreal, PQ / Chesterville, ON## 5 Montreal, PQ / Chesterville, ON## 6 New York, NY
tail(titanic)

Výstup:

## X pclass survived name sex age sibsp## 1304 1304 3 0 Yousseff, Mr. Gerious male NA 0## 1305 1305 3 0 Zabour, Miss. Hileni female 14.5 1## 1306 1306 3 0 Zabour, Miss. Thamine female NA 1## 1307 1307 3 0 Zakarian, Mr. Mapriededer male 26.5 0## 1308 1308 3 0 Zakarian, Mr. Ortin male 27.0 0## 1309 1309 3 0 Zimmerman, Mr. Leo male 29.0 0## parch ticket fare cabin embarked home.dest## 1304 0 2627 14.4583 C## 1305 0 2665 14.4542 C## 1306 0 2665 14.4542 C## 1307 0 2656 7.2250 C## 1308 0 2670 7.2250 C## 1309 0 315082 7.8750 S

Z výstupu hlavy a ocasu si můžete všimnout, že data nejsou zamíchána. To je velký problém! Když rozdělíte svá data mezi vlakovou soupravu a zkušební soupravu, vyberete pouze cestujícího ze třídy 1 a 2 (žádný cestující ze třídy 3 není v horních 80 procentech pozorování), což znamená, že algoritmus nikdy neuvidí vlastnosti cestujícího třídy 3. Tato chyba povede ke špatné předpovědi.

K překonání tohoto problému můžete použít funkci sample ().

shuffle_index <- sample(1:nrow(titanic))head(shuffle_index)

Rozhodovací strom R kód Vysvětlení

  • sample (1: nrow (titanic)): Generuje náhodný seznam indexů od 1 do 1309 (tj. maximální počet řádků).

Výstup:

## [1] 288 874 1078 633 887 992 

Tento index použijete k zamíchání titanové datové sady.

titanic <- titanic[shuffle_index, ]head(titanic)

Výstup:

## X pclass survived## 288 288 1 0## 874 874 3 0## 1078 1078 3 1## 633 633 3 0## 887 887 3 1## 992 992 3 1## name sex age## 288 Sutton, Mr. Frederick male 61## 874 Humblen, Mr. Adolf Mathias Nicolai Olsen male 42## 1078 O'Driscoll, Miss. Bridget female NA## 633 Andersson, Mrs. Anders Johan (Alfrida Konstantia Brogren) female 39## 887 Jermyn, Miss. Annie female NA## 992 Mamee, Mr. Hanna male NA## sibsp parch ticket fare cabin embarked home.dest## 288 0 0 36963 32.3208 D50 S Haddenfield, NJ## 874 0 0 348121 7.6500 F G63 S## 1078 0 0 14311 7.7500 Q## 633 1 5 347082 31.2750 S Sweden Winnipeg, MN## 887 0 0 14313 7.7500 Q## 992 0 0 2677 7.2292 C

Krok 2) Vyčistěte datovou sadu

Struktura dat ukazuje, že některé proměnné mají NA. Vyčištění dat je třeba provést následujícím způsobem

  • Přetáhněte proměnné home.dest, kabina, jméno, X a lístek
  • Vytvořte proměnné faktoru pro pclass a přežili
  • Zrušte NA
library(dplyr)# Drop variablesclean_titanic <- titanic % > %select(-c(home.dest, cabin, name, X, ticket)) % > %#Convert to factor levelmutate(pclass = factor(pclass, levels = c(1, 2, 3), labels = c('Upper', 'Middle', 'Lower')),survived = factor(survived, levels = c(0, 1), labels = c('No', 'Yes'))) % > %na.omit()glimpse(clean_titanic)

Vysvětlení kódu

  • select (-c (home.dest, cabin, name, X, ticket)): Zrušit nepotřebné proměnné
  • pclass = factor (pclass, levels = c (1,2,3), labels = c ('Upper', 'Middle', 'Lower')): Přidat štítek do proměnné pclass. 1 se změní na horní, 2 se změní na střední a 3 se sníží
  • faktor (přežilo, úrovně = c (0,1), štítky = c ('ne', 'ano')): Přidat štítek do proměnné přežilo. 1 se stává Ne a 2 se stává Ano
  • na.omit (): Odstraní pozorování NA

Výstup:

## Observations: 1,045## Variables: 8## $ pclass  Upper, Lower, Lower, Upper, Middle, Upper, Middle, U… ## $ survived  No, No, No, Yes, No, Yes, Yes, No, No, No, No, No, Y… ## $ sex  male, male, female, female, male, male, female, male… ## $ age  61.0, 42.0, 39.0, 49.0, 29.0, 37.0, 20.0, 54.0, 2.0,… ## $ sibsp  0, 0, 1, 0, 0, 1, 0, 0, 4, 0, 0, 1, 1, 0, 0, 0, 1, 1,… ## $ parch  0, 0, 5, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 2, 0, 4, 0,… ## $ fare  32.3208, 7.6500, 31.2750, 25.9292, 10.5000, 52.5542,… ## $ embarked  S, S, S, S, S, S, S, S, S, C, S, S, S, Q, C, S, S, C… 

Krok 3) Vytvořte vlak / testovací sadu

Před trénováním modelu musíte provést dva kroky:

  • Vytvoření vlaku a testovací sady: Trénujete model na vlakové soupravě a testujete předpověď na testovací sadě (tj. Neviditelná data)
  • Nainstalujte rpart.plot z konzoly

Běžnou praxí je rozdělit data 80/20, 80 procent dat slouží k trénování modelu a 20 procent k předpovědi. Musíte vytvořit dva samostatné datové rámce. Nechcete se dotknout testovací sady, dokud nedokončíte vytváření modelu. Můžete vytvořit název funkce create_train_test (), který má tři argumenty.

create_train_test(df, size = 0.8, train = TRUE)arguments:-df: Dataset used to train the model.-size: Size of the split. By default, 0.8. Numerical value-train: If set to `TRUE`, the function creates the train set, otherwise the test set. Default value sets to `TRUE`. Boolean value.You need to add a Boolean parameter because R does not allow to return two data frames simultaneously.
create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample < - 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}

Vysvětlení kódu

  • funkce (data, velikost = 0,8, vlak = PRAVDA): Přidejte argumenty do funkce
  • n_row = nrow (data): Spočítá počet řádků v datové sadě
  • total_row = size * n_row: Vraťte n-tý řádek a postavte vlakovou soupravu
  • train_sample <- 1: total_row: Vyberte první řádek do n-tého řádku
  • if (train == TRUE) {} else {}: Pokud je podmínka nastavena na hodnotu true, vrať vlakovou soupravu, jinak testovací sadu.

Můžete otestovat svou funkci a zkontrolovat rozměr.

data_train <- create_train_test(clean_titanic, 0.8, train = TRUE)data_test <- create_train_test(clean_titanic, 0.8, train = FALSE)dim(data_train)

Výstup:

## [1] 836 8
dim(data_test)

Výstup:

## [1] 209 8 

Vlaková datová sada má 1046 řádků, zatímco testovací datová sada má 262 řádků.

Pomocí funkce prop.table () v kombinaci s table () ověříte, zda je proces randomizace správný.

prop.table(table(data_train$survived))

Výstup:

#### No Yes## 0.5944976 0.4055024
prop.table(table(data_test$survived))

Výstup:

#### No Yes## 0.5789474 0.4210526

V obou souborech dat je počet přeživších stejný, přibližně 40 procent.

Nainstalujte si rpart.plot

rpart.plot není k dispozici z knihoven conda. Můžete jej nainstalovat z konzoly:

install.packages("rpart.plot") 

Krok 4) Sestavte model

Jste připraveni model postavit. Syntaxe funkce rozhodovacího stromu Rpart je:

rpart(formula, data=, method='')arguments:- formula: The function to predict- data: Specifies the data frame- method:- "class" for a classification tree- "anova" for a regression tree

Metodu třídy použijete, protože předpovídáte třídu.

library(rpart)library(rpart.plot)fit <- rpart(survived~., data = data_train, method = 'class')rpart.plot(fit, extra = 106

Vysvětlení kódu

  • rpart (): Funkce přizpůsobená modelu. Argumenty jsou:
    • přežilo ~ .: Vzorec rozhodovacích stromů
    • data = data_train: datová sada
    • method = 'class': Přizpůsobit binární model
  • rpart.plot (fit, extra = 106): Vynese strom. Extra funkce jsou nastaveny na 101 pro zobrazení pravděpodobnosti 2. třídy (užitečné pro binární odpovědi). Další informace o dalších možnostech najdete na dálniční známce.

Výstup:

Začínáte v kořenovém uzlu (hloubka 0 až 3, horní část grafu):

  1. Nahoře je to celková pravděpodobnost přežití. Ukazuje podíl cestujících, kteří nehodu přežili. Přežilo 41 procent cestujících.
  2. Tento uzel se ptá, zda je pohlaví cestujícího mužské. Pokud ano, pak jdete dolů do levého podřízeného uzlu root (hloubka 2). 63 procent jsou muži s pravděpodobností přežití 21 procent.
  3. Ve druhém uzlu se zeptáte, zda je cestujícímu muži více než 3,5 roku. Pokud ano, pak je šance na přežití 19 procent.
  4. Pokračujte tak, abyste pochopili, jaké vlastnosti mají vliv na pravděpodobnost přežití.

Všimněte si, že jednou z mnoha kvalit rozhodovacích stromů je, že vyžadují velmi malou přípravu dat. Zejména nevyžadují změnu měřítka nebo centrování funkcí.

Ve výchozím nastavení používá funkce rpart () k rozdělení noty míru Giniho nečistoty. Čím vyšší je Giniho koeficient, tím více různých instancí v uzlu.

Krok 5) Proveďte předpověď

Můžete předpovědět svoji testovací datovou sadu. Chcete-li vytvořit předpověď, můžete použít funkci predikce (). Základní syntaxe predikce pro rozhodovací strom R je:

predict(fitted_model, df, type = 'class')arguments:- fitted_model: This is the object stored after model estimation.- df: Data frame used to make the prediction- type: Type of prediction- 'class': for classification- 'prob': to compute the probability of each class- 'vector': Predict the mean response at the node level

Chcete z testovací sady předpovědět, u kterých cestujících je pravděpodobnější, že po srážce přežijí. To znamená, že mezi těmi 209 cestujícími budete vědět, který z nich přežije nebo ne.

predict_unseen <-predict(fit, data_test, type = 'class')

Vysvětlení kódu

  • predikce (fit, data_test, type = 'class'): Predikce třídy (0/1) testovací sady

Testování cestujícího, který to nezvládl, a těch, kteří to nezvládli.

table_mat <- table(data_test$survived, predict_unseen)table_mat

Vysvětlení kódu

  • tabulka (data_test $ přežil, předpovědět_unseen): Vytvořte tabulku, která spočítá, kolik cestujících je klasifikováno jako přeživších a zemřelo v porovnání se správnou klasifikací stromu rozhodnutí v R

Výstup:

## predict_unseen## No Yes## No 106 15## Yes 30 58

Model správně předpověděl 106 mrtvých cestujících, ale 15 přeživších klasifikoval jako mrtvé. Analogicky model nesprávně klasifikoval 30 cestujících jako přeživších, zatímco se ukázalo, že jsou mrtví.

Krok 6) Změřte výkon

Pomocí matice zmatku můžete vypočítat míru přesnosti pro úkol klasifikace :

Zmatek matrix je lepší volbou pro hodnocení výkonu klasifikace. Obecná myšlenka je spočítat, kolikrát jsou instance True klasifikovány jako False.

Každý řádek v matici zmatku představuje skutečný cíl, zatímco každý sloupec představuje předpokládaný cíl. První řádek této matice uvažuje o mrtvých pasažérech (třída False): 106 bylo správně klasifikováno jako mrtví ( True negativní ), zatímco zbývající byl nesprávně klasifikován jako přeživší ( False pozitivní ). Druhá řada považuje přeživší, pozitivní třída byla 58 ( skutečně pozitivní ), zatímco skutečná záporná byla 30.

Test přesnosti můžete vypočítat z matice záměny:

Je to podíl skutečných kladných a skutečných záporných na součtu matice. S R můžete kódovat takto:

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)

Vysvětlení kódu

  • sum (diag (table_mat)): Součet úhlopříčky
  • sum (table_mat): Součet matice.

Můžete vytisknout přesnost testovací sady:

print(paste('Accuracy for test', accuracy_Test))

Výstup:

## [1] "Accuracy for test 0.784688995215311" 

U testovací sady máte skóre 78 procent. Stejné cvičení můžete replikovat pomocí datové sady školení.

Krok 7) Nalaďte hyperparametry

Rozhodovací strom v R má různé parametry, které řídí aspekty přizpůsobení. V knihovně rozhodovacího stromu rpart můžete ovládat parametry pomocí funkce rpart.control (). V následujícím kódu uvedete parametry, které naladíte. Další parametry najdete na vinětě.

rpart.control(minsplit = 20, minbucket = round(minsplit/3), maxdepth = 30)Arguments:-minsplit: Set the minimum number of observations in the node before the algorithm perform a split-minbucket: Set the minimum number of observations in the final note i.e. the leaf-maxdepth: Set the maximum depth of any node of the final tree. The root node is treated a depth 0

Budeme postupovat následovně:

  • Sestavte funkci pro návrat přesnosti
  • Nalaďte maximální hloubku
  • Nalaďte minimální počet vzorků, které musí mít uzel, než se může rozdělit
  • Nalaďte minimální počet vzorků, které musí mít listový uzel

Pro zobrazení přesnosti můžete napsat funkci. Jednoduše zabalíte kód, který jste použili dříve:

  1. predikovat: predict_unseen <- predikovat (fit, data_test, type = 'class')
  2. Produkovat tabulku: table_mat <- tabulka (data_test $ přežil, předpovídat_unseen)
  3. Přesnost výpočtu: přesnost_Test <- součet (diag (tabulka_mat)) / součet (tabulka_mat)
accuracy_tune <- function(fit) {predict_unseen <- predict(fit, data_test, type = 'class')table_mat <- table(data_test$survived, predict_unseen)accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test}

Můžete zkusit vyladit parametry a zjistit, zda můžete vylepšit model oproti výchozí hodnotě. Připomínáme, že musíte dosáhnout přesnosti vyšší než 0,78

control <- rpart.control(minsplit = 4,minbucket = round(5 / 3),maxdepth = 3,cp = 0)tune_fit <- rpart(survived~., data = data_train, method = 'class', control = control)accuracy_tune(tune_fit)

Výstup:

## [1] 0.7990431 

S následujícím parametrem:

minsplit = 4minbucket= round(5/3)maxdepth = 3cp=0 

Získáte vyšší výkon než u předchozího modelu. Blahopřejeme!

souhrn

Můžeme shrnout funkce pro trénování algoritmu rozhodovacího stromu v R.

Knihovna

Objektivní

funkce

třída

parametry

podrobnosti

rpart

Strom klasifikace vlaků v R.

rpart ()

třída

vzorec, df, metoda

rpart

Trénujte regresní strom

rpart ()

anova

vzorec, df, metoda

rpart

Nakreslete stromy

rpart.plot ()

namontovaný model

základna

předpovědět

předpovědět()

třída

namontovaný model, typ

základna

předpovědět

předpovědět()

prob

namontovaný model, typ

základna

předpovědět

předpovědět()

vektor

namontovaný model, typ

rpart

Kontrolní parametry

rpart.control ()

minisplit

Nastavte minimální počet pozorování v uzlu, než algoritmus provede rozdělení

minbucket

V závěrečné poznámce, tj. V listu, nastavte minimální počet pozorování

maximální hloubka

Nastavte maximální hloubku libovolného uzlu konečného stromu. Kořenový uzel je zpracován v hloubce 0

rpart

Model vlaku s ovládacím parametrem

rpart ()

vzorec, df, metoda, ovládání

Poznámka: Trénujte model na tréninkových datech a otestujte výkon na neviditelné datové sadě, tj. Testovací sadě.