Co je logistická regrese?
Logistická regrese se používá k předpovědi třídy, tj. Pravděpodobnosti. Logistická regrese může přesně předpovědět binární výsledek.
Představte si, že chcete předpovědět, zda je půjčka zamítnuta / přijata na základě mnoha atributů. Logistická regrese má tvar 0/1. y = 0, pokud je půjčka odmítnuta, y = 1, pokud je přijata.
Logistický regresní model se od lineárního regresního modelu liší dvěma způsoby.
- Nejprve logistická regrese přijímá pouze dichotomický (binární) vstup jako závislou proměnnou (tj. Vektor 0 a 1).
- Zadruhé, výsledek je měřen následující pravděpodobnostní spojovací funkcí zvanou sigmoid kvůli jejímu tvaru S .:
Výstup funkce je vždy mezi 0 a 1. Zkontrolujte obrázek níže
Funkce sigmoid vrací hodnoty od 0 do 1. Pro úkol klasifikace potřebujeme diskrétní výstup 0 nebo 1.
Chcete-li převést spojitý tok na diskrétní hodnotu, můžeme nastavit rozhodnutí vázané na 0,5. Všechny hodnoty nad touto prahovou hodnotou jsou klasifikovány jako 1
V tomto výukovém programu se naučíte
- Co je logistická regrese?
- Jak vytvořit model zobecněné linky (GLM)
- Krok 1) Zkontrolujte spojité proměnné
- Krok 2) Zkontrolujte proměnné faktoru
- Krok 3) Inženýrství funkcí
- Krok 4) Souhrnná statistika
- Krok 5) Vlak / testovací sada
- Krok 6) Sestavte model
- Krok 7) Posuďte výkon modelu
Jak vytvořit model zobecněné linky (GLM)
Pojďme použít sadu dat pro dospělé k ilustraci logistické regrese. „Dospělý“ je skvělý soubor dat pro úkol klasifikace. Cílem je předpovědět, zda roční příjem v dolarech jednotlivce přesáhne 50 000. Datová sada obsahuje 46 033 pozorování a deset funkcí:
- věk: věk jedince. Číselné
- vzdělání: Vzdělávací úroveň jednotlivce. Faktor.
- marital.status: Rodinný stav jednotlivce. Faktor, tj. Nikdy ženatý, ženatý, manžel,…
- gender: pohlaví jednotlivce. Faktor, tj. Muž nebo žena
- příjem: cílová proměnná. Příjem nad nebo pod 50 tis. Faktor tj.> 50K, <= 50K
mimo jiné
library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)
Výstup:
Observations: 48,842Variables: 10$ x1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass Private, Private, Local-gov, Private, ?, Private,… $ education 11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status Never-married, Married-civ-spouse, Married-civ-sp… $ race Black, White, White, Black, White, White, Black,… $ gender Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5…
Budeme postupovat následovně:
- Krok 1: Zkontrolujte spojité proměnné
- Krok 2: Zkontrolujte proměnné faktoru
- Krok 3: Inženýrství funkcí
- Krok 4: Souhrnná statistika
- Krok 5: Trénink / testovací sada
- Krok 6: Sestavte model
- Krok 7: Posouzení výkonu modelu
- krok 8: Vylepšení modelu
Vaším úkolem je předpovědět, který jedinec bude mít tržby vyšší než 50 tis.
V tomto kurzu bude každý krok podrobně popsán, aby se provedla analýza na skutečné datové sadě.
Krok 1) Zkontrolujte spojité proměnné
V prvním kroku můžete vidět distribuci spojitých proměnných.
continuous <-select_if(data_adult, is.numeric)summary(continuous)
Vysvětlení kódu
- spojitý <- select_if (data_adult, is.numeric): Pomocí funkce select_if () z knihovny dplyr vyberte pouze číselné sloupce
- souhrn (průběžně): Vytiskne souhrnnou statistiku
Výstup:
## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Z výše uvedené tabulky můžete vidět, že data mají úplně jiná měřítka a hodiny. Per.weeks má velké odlehlé hodnoty (tj. Podívejte se na poslední kvartil a maximální hodnotu).
Můžete to vyřešit dvěma kroky:
- 1: Nakreslete distribuci hodin.za.týden
- 2: Standardizace spojitých proměnných
- Nakreslete distribuci
Podívejme se blíže na distribuci hours.per.week
# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")
Výstup:
Proměnná má spoustu odlehlých hodnot a není dobře definovaná distribuce. Tento problém můžete částečně vyřešit odstraněním horních 0,01 procenta hodin týdně.
Základní syntaxe kvantilu:
quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.
Vypočítáme horní 2 procentní percentil
top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent
Vysvětlení kódu
- quantile (data_adult $ hours.per.week, 0,99): Vypočítá hodnotu 99 procent pracovní doby
Výstup:
## 99%## 80
98 procent populace pracuje do 80 hodin týdně.
Pozorování můžete pustit nad tuto hranici. Používáte filtr z knihovny dplyr.
data_adult_drop <-data_adult %>%filter(hours.per.weekVýstup:
## [1] 45537 10
- Standardizujte spojité proměnné
Každý sloupec můžete standardizovat, abyste zlepšili výkon, protože vaše data nemají stejné měřítko. Můžete použít funkci mutate_if z knihovny dplyr. Základní syntaxe je:
mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the functionČíselné sloupce můžete standardizovat následujícím způsobem:
data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)Vysvětlení kódu
- mutate_if (is.numeric, funs (scale)): Podmínkou je pouze číselný sloupec a funkcí je scale
Výstup:
## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50KKrok 2) Zkontrolujte proměnné faktoru
Tento krok má dva cíle:
- Zkontrolujte úroveň v každém kategorickém sloupci
- Definujte nové úrovně
Tento krok rozdělíme na tři části:
- Vyberte kategorické sloupce
- Uložte sloupcový graf každého sloupce do seznamu
- Vytiskněte grafy
Můžeme vybrat sloupce faktoru pomocí níže uvedeného kódu:
# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)Vysvětlení kódu
- data.frame (select_if (data_adult, is.factor)): Ukládáme sloupce faktoru v faktoru v typu datového rámce. Knihovna ggplot2 vyžaduje objekt datového rámce.
Výstup:
## [1] 6Datová sada obsahuje 6 kategorických proměnných
Druhý krok je zručnější. Chcete vykreslit sloupcový graf pro každý sloupec ve faktoru datového rámce. Je pohodlnější proces automatizovat, zejména v situaci, kdy existuje spousta sloupců.
library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))Vysvětlení kódu
- lapply (): Pomocí funkce lapply () předáte funkci ve všech sloupcích datové sady. Výstup uložíte do seznamu
- funkce (x): Funkce bude zpracována pro každé x. Zde x jsou sloupce
- ggplot (factor, aes (get (x))) + geom_bar () + téma (axis.text.x = element_text (angle = 90)): Vytvořte sloupcový char chart pro každý x prvek. Chcete-li vrátit x jako sloupec, musíte jej zahrnout dovnitř get ()
Poslední krok je relativně snadný. Chcete vytisknout 6 grafů.
# Print the graphgraphVýstup:
## [[1]]## ## [[2]]## ## [[3]]## ## [[4]]## ## [[5]]## ## [[6]]Poznámka: Pomocí tlačítka Další přejděte na další graf
Krok 3) Inženýrství funkcí
Přepracované vzdělávání
Z výše uvedeného grafu můžete vidět, že proměnná vzdělání má 16 úrovní. To je podstatné a některé úrovně mají relativně nízký počet pozorování. Pokud chcete zlepšit množství informací, které můžete z této proměnné získat, můžete ji přepracovat na vyšší úroveň. Jmenovitě vytváříte větší skupiny s podobnou úrovní vzdělání. Například nízká úroveň vzdělání bude převedena do předčasného ukončení. Vyšší úrovně vzdělání se změní na magisterské.
Tady je detail:
Stará úroveň
Nová úroveň
Předškolní
výpadek
10.
Výpadek
11.
Výpadek
12
Výpadek
1.-4
Výpadek
5.-6
Výpadek
7.-8
Výpadek
9
Výpadek
HS-Grad
HighGrad
Některé vysoké školy
Společenství
Assoc-acdm
Společenství
Doc
Společenství
Bakaláři
Bakaláři
Mistři
Mistři
Profesní škola
Mistři
Doktorát
PhD
recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))Vysvětlení kódu
- Používáme sloveso mutovat z knihovny dplyr. Hodnoty vzdělávání měníme tvrzením ifelse
V níže uvedené tabulce vytvoříte souhrnnou statistiku, abyste v průměru zjistili, kolik let vzdělání (hodnota z) trvá k dosažení bakaláře, magisterského nebo doktorského studia.
recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)Výstup:
## # A tibble: 6 x 3## education average_educ_year count#### 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557 Přepracovat rodinný stav
Je také možné vytvořit nižší úrovně pro rodinný stav. V následujícím kódu změníte úroveň následujícím způsobem:
Stará úroveň
Nová úroveň
Se nikdy neoženil
Není vdaná
Ženatý manžel chybí
Není vdaná
Ženatý-AF-manžel
Ženatý
Manželský občan
Oddělené
Oddělené
Rozvedený
Vdovy
Vdova
# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))Můžete zkontrolovat počet jednotlivců v každé skupině.table(recast_data$marital.status)Výstup:
## ## Married Not_married Separated Widow## 21165 15359 7727 1286Krok 4) Souhrnná statistika
Je čas zkontrolovat některé statistiky o našich cílových proměnných. V níže uvedeném grafu spočítáte procento jednotlivců vydělávajících více než 50 tis. Vzhledem k jejich pohlaví.
# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()Výstup:
Dále zkontrolujte, zda původ jednotlivce ovlivňuje jejich výdělky.
# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))Výstup:
Počet hodin práce podle pohlaví.
# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()Výstup:
Krabicový graf potvrzuje, že rozložení pracovní doby odpovídá různým skupinám. Na grafu pole nemají obě pohlaví homogenní pozorování.
Hustotu týdenní pracovní doby můžete zkontrolovat podle typu vzdělání. Distribuce mají mnoho odlišných tipů. Pravděpodobně to lze vysvětlit typem smlouvy v USA.
# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()Vysvětlení kódu
- ggplot (recast_data, aes (x = hours.per.week)): Graf hustoty vyžaduje pouze jednu proměnnou
- geom_density (aes (color = education), alpha = 0,5): Geometrický objekt pro řízení hustoty
Výstup:
Chcete-li potvrdit své myšlenky, můžete provést jednosměrný test ANOVA:
anova <- aov(hours.per.week~education, recast_data)summary(anova)Výstup:
## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1Test ANOVA potvrzuje průměrný rozdíl mezi skupinami.
Nelineárnost
Před spuštěním modelu můžete zjistit, zda počet odpracovaných hodin souvisí s věkem.
library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()Vysvětlení kódu
- ggplot (recast_data, aes (x = věk, y = hodiny. za týden)): Nastavit estetiku grafu
- geom_point (aes (color = income), size = 0,5): Sestavte tečkovaný graf
- stat_smooth (): Přidejte trendovou čáru s následujícími argumenty:
- method = 'lm': Vynese vynesenou hodnotu, pokud je lineární regrese
- vzorec = y ~ poly (x, 2): Přizpůsobí polynomickou regresi
- se = TRUE: Přidejte standardní chybu
- aes (barva = příjem): Rozdělte model podle příjmu
Výstup:
Stručně řečeno, můžete otestovat podmínky interakce v modelu, abyste získali efekt nelinearity mezi týdenní pracovní dobou a dalšími funkcemi. Je důležité zjistit, za jakých podmínek se pracovní doba liší.
Korelace
Další kontrolou je vizualizace korelace mezi proměnnými. Převedete typ úrovně faktoru na numerický, abyste mohli vykreslit teplotní mapu obsahující koeficient korelace vypočítaný metodou Spearman.
library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")Vysvětlení kódu
- data.frame (lapply (recast_data, as.integer)): Převést data na číselná
- ggcorr () vykreslí teplotní mapu s následujícími argumenty:
- metoda: Metoda pro výpočet korelace
- nbreaks = 6: Počet přestávek
- hjust = 0,8: Kontrolní pozice názvu proměnné v grafu
- label = TRUE: Přidejte štítky do středu oken
- label_size = 3: Štítky velikosti
- color = "grey50"): Barva štítku
Výstup:
Krok 5) Vlak / testovací sada
Jakýkoli úkol strojového učení pod dohledem vyžaduje rozdělení dat mezi vlakovou soupravu a testovací sadu. Můžete použít "funkci", kterou jste vytvořili v jiných supervizních výukových programech, k vytvoření vlakové / testovací sady.
set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)Výstup:
## [1] 36429 9dim(data_test)Výstup:
## [1] 9108 9Krok 6) Sestavte model
Chcete-li zjistit, jak algoritmus funguje, použijte balíček glm (). Generalized Linear Model je sbírka modelů. Základní syntaxe je:
glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")Jste připraveni odhadnout logistický model a rozdělit úroveň příjmu mezi sadu funkcí.
formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)Vysvětlení kódu
- vzorec <- příjem ~.: Vytvořte model tak, aby se vešel
- logit <- glm (formula, data = data_train, family = 'binomial'): Upravte logistický model (family = 'binomial') s daty data_train.
- summary (logit): Vytiskne souhrn modelu
Výstup:
#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6Souhrn našeho modelu odhaluje zajímavé informace. Výkon logistické regrese se hodnotí pomocí konkrétních klíčových metrik.
- AIC (Akaike Information Criteria): Toto je ekvivalent R2 v logistické regrese. Měří přizpůsobení, když je na počet parametrů uplatněn trest. Menší hodnoty AIC naznačují, že model je blíže pravdě.
- Nulová odchylka: Hodí se k modelu pouze s průsečíkem. Stupeň volnosti je n-1. Můžeme to interpretovat jako hodnotu chí-kvadrát (přizpůsobená hodnota odlišná od testování hypotézy skutečné hodnoty).
- Zbytková odchylka: Model se všemi proměnnými. Je také interpretováno jako testování hypotézy chí-kvadrát.
- Počet iterací Fisher Scoring: Počet iterací před konvergováním.
Výstup funkce glm () je uložen v seznamu. Níže uvedený kód ukazuje všechny položky dostupné v proměnné logit, kterou jsme vytvořili k vyhodnocení logistické regrese.
# Seznam je velmi dlouhý, tiskne pouze první tři prvky
lapply(logit, class)[1:3]Výstup:
## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"Každou hodnotu lze extrahovat znakem $, za kterým následuje název metriky. Například jste model uložili jako logit. K extrakci kritérií AIC použijete:
logit$aicVýstup:
## [1] 27086.65Krok 7) Posuďte výkon modelu
Matice zmatku
Zmatek matrix je lepší volbou pro hodnocení výkonu klasifikace ve srovnání s různými metrikami jsi viděl předtím. Obecná myšlenka je spočítat, kolikrát jsou instance True klasifikovány jako False.
Chcete-li vypočítat matici zmatku, musíte nejprve mít sadu předpovědí, aby je bylo možné porovnat se skutečnými cíli.
predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_matVysvětlení kódu
- predikce (logit, data_test, type = 'response'): Vypočítá predikci na testovací sadě. Nastavit type = 'response' pro výpočet pravděpodobnosti odpovědi.
- tabulka (data_test $ příjem, předpovědět> 0,5): Vypočítejte matici záměny. predikce> 0,5 znamená, že vrátí 1, pokud jsou předpokládané pravděpodobnosti vyšší než 0,5, jinak 0.
Výstup:
#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229Každý řádek v matici zmatku představuje skutečný cíl, zatímco každý sloupec představuje předpokládaný cíl. První řádek této matice považuje příjem nižší než 50k (třída False): 6241 bylo správně klasifikováno jako jednotlivci s příjmem nižším než 50k ( skutečná záporná ), zatímco zbývající část byla nesprávně klasifikována jako vyšší než 50k ( falešně pozitivní ). Druhá řada zvažuje příjem nad 50 tis., Pozitivní třída byla 1229 ( skutečná kladná ), zatímco skutečná záporná byla 1074.
Přesnost modelu můžete vypočítat součtem skutečných kladných a záporných hodnot za celé pozorování
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_TestVysvětlení kódu
- sum (diag (table_mat)): Součet úhlopříčky
- sum (table_mat): Součet matice.
Výstup:
## [1] 0.8277339Zdá se, že model trpí jedním problémem, nadhodnocuje počet falešných negativů. Tomu se říká paradox testu přesnosti . Uvedli jsme, že přesnost je poměr správných předpovědí k celkovému počtu případů. Můžeme mít relativně vysokou přesnost, ale zbytečný model. Stává se to, když existuje dominantní třída. Pokud se podíváte zpět na matici zmatku, uvidíte, že většina případů je klasifikována jako skutečně negativní. Představte si nyní, že model klasifikoval všechny třídy jako negativní (tj. Nižší než 50k). Měli byste přesnost 75 procent (6718/6718 + 2257). Váš model funguje lépe, ale snaží se rozlišit skutečné pozitivní od skutečného negativního.
V takové situaci je lepší mít stručnější metriku. Můžeme se podívat na:
- Přesnost = TP / (TP + FP)
- Odvolání = TP / (TP + FN)
Precision vs Recall
Přesnost sleduje přesnost pozitivní predikce. Recall je poměr pozitivních instancí, které jsou správně detekovány klasifikátorem;
Pro výpočet těchto dvou metrik můžete vytvořit dvě funkce
- Vytvořte přesnost
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}Vysvětlení kódu
- mat [1,1]: Vrátí první buňku prvního sloupce datového rámce, tj. skutečnou kladnou hodnotu
- rohož [1,2]; Vrátí první buňku druhého sloupce datového rámce, tj. Falešně pozitivní
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}Vysvětlení kódu
- mat [1,1]: Vrátí první buňku prvního sloupce datového rámce, tj. skutečnou kladnou hodnotu
- rohož [2,1]; Vrátí druhou buňku prvního sloupce datového rámce, tj. Falešně negativní
Můžete otestovat své funkce
prec <- precision(table_mat)precrec <- recall(table_mat)recVýstup:
## [1] 0.712877## [2] 0.5336518Když model říká, že se jedná o jednotlivce nad 50 tis., Je to správné pouze v 54 procentech případů a v 72 procentech případů může nárokovat jednotlivce nad 50 tis.
harmonický průměr těchto dvou metrik, což znamená, že kladou větší důraz na nižší hodnoty.
f1 <- 2 * ((prec * rec) / (prec + rec))f1Výstup:
## [1] 0.6103799Přesnost vs Vyvolání kompromisu
Je nemožné mít vysokou přesnost i vysokou vybavenost.
Pokud zvýšíme přesnost, bude správný jedinec lépe předvídán, ale spousta z nich by nám chyběla (nižší odvolání). V některých situacích dáváme přednost vyšší přesnosti než vyvolání. Mezi přesností a odvoláním existuje konkávní vztah.
- Představte si, že musíte předvídat, zda má pacient nemoc. Chcete být co nejpřesnější.
- Pokud potřebujete detekovat potenciální podvodníky na ulici pomocí rozpoznávání obličeje, bylo by lepší chytit mnoho lidí označených jako podvodníci, i když přesnost je nízká. Policie bude schopna nepodvodnou osobu propustit.
Křivka ROC
Charakteristika Přijímač Provozní křivka je další častý nástroj používaný s binárním klasifikaci. Je to velmi podobné křivce přesnosti / odvolání, ale místo vykreslení přesnosti versus odvolání ukazuje křivka ROC skutečnou pozitivní míru (tj. Vyvolání) proti falešně pozitivní rychlosti. Míra falešně pozitivních výsledků je poměr negativních případů, které jsou nesprávně klasifikovány jako pozitivní. Rovná se jedné minus skutečná záporná sazba. Skutečná záporná sazba se také nazývá specificita . Křivka ROC proto vykresluje citlivost (odvolání) proti 1 specificitě
K vykreslení křivky ROC musíme nainstalovat knihovnu s názvem RORC. Najdeme v knihovně conda. Můžete zadat kód:
conda install -cr r-rocr - ano
Můžeme vykreslit ROC pomocí funkcí predikce () a výkonu ().
library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))Vysvětlení kódu
- predikce (predikce, data_test $ příjem): Knihovna ROCR musí vytvořit objekt predikce k transformaci vstupních dat
- performance (ROCRpred, 'tpr', 'fpr'): Vrátí dvě kombinace, které mají být vytvořeny v grafu. Zde jsou konstruovány tpr a fpr. Přesnost a celkové vyvolání pozemku, použijte „prec“, „rec“.
Výstup:
Krok 8) Vylepšete model
Můžete se pokusit přidat do modelu nelinearitu s interakcí mezi
- věk a hodiny. za týden
- pohlaví a hodiny. za týden.
K porovnání obou modelů musíte použít test skóre
formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2Výstup:
## [1] 0.6109181Skóre je o něco vyšší než předchozí. Na datech můžete dále pracovat a pokusit se překonat skóre.
souhrn
V následující tabulce můžeme shrnout funkci pro trénování logistické regrese:
Balík
Objektivní
funkce
argument
-
Vytvořte datovou sadu vlaku / testu
create_train_set ()
data, velikost, vlak
glm
Trénujte zobecněný lineární model
glm ()
vzorec, data, rodina *
glm
Shrňte model
souhrn()
namontovaný model
základna
Udělejte předpověď
předpovědět()
vybavený model, datová sada, type = 'response'
základna
Vytvořte matici záměny
stůl()
y, předpovídat ()
základna
Vytvořte skóre přesnosti
součet (diag (tabulka ()) / součet (tabulka ()
ROCR
Vytvořit ROC: Krok 1 Vytvořte předpověď
předpověď()
predikovat (), r
ROCR
Vytvořit ROC: Krok 2 Vytvořte výkon
výkon()
prediction (), 'tpr', 'fpr'
ROCR
Vytvořte ROC: Krok 3 Vyneste graf
spiknutí()
výkon()
Další typy modelů GLM jsou:
- binomický: (link = "logit")
- gaussian: (link = "identita")
- Gama: (link = "inverzní")
- inverse.gaussian: (link = "1 / mu 2")
- poisson: (link = "log")
- kvazi: (link = "identity", variance = "konstantní")
- quasibinomial: (link = "logit")
- quasipoisson: (link = "log")