GLM v R: Zobecněný lineární model s příkladem

Obsah:

Anonim

Co je logistická regrese?

Logistická regrese se používá k předpovědi třídy, tj. Pravděpodobnosti. Logistická regrese může přesně předpovědět binární výsledek.

Představte si, že chcete předpovědět, zda je půjčka zamítnuta / přijata na základě mnoha atributů. Logistická regrese má tvar 0/1. y = 0, pokud je půjčka odmítnuta, y = 1, pokud je přijata.

Logistický regresní model se od lineárního regresního modelu liší dvěma způsoby.

  • Nejprve logistická regrese přijímá pouze dichotomický (binární) vstup jako závislou proměnnou (tj. Vektor 0 a 1).
  • Zadruhé, výsledek je měřen následující pravděpodobnostní spojovací funkcí zvanou sigmoid kvůli jejímu tvaru S .:

Výstup funkce je vždy mezi 0 a 1. Zkontrolujte obrázek níže

Funkce sigmoid vrací hodnoty od 0 do 1. Pro úkol klasifikace potřebujeme diskrétní výstup 0 nebo 1.

Chcete-li převést spojitý tok na diskrétní hodnotu, můžeme nastavit rozhodnutí vázané na 0,5. Všechny hodnoty nad touto prahovou hodnotou jsou klasifikovány jako 1

V tomto výukovém programu se naučíte

  • Co je logistická regrese?
  • Jak vytvořit model zobecněné linky (GLM)
  • Krok 1) Zkontrolujte spojité proměnné
  • Krok 2) Zkontrolujte proměnné faktoru
  • Krok 3) Inženýrství funkcí
  • Krok 4) Souhrnná statistika
  • Krok 5) Vlak / testovací sada
  • Krok 6) Sestavte model
  • Krok 7) Posuďte výkon modelu

Jak vytvořit model zobecněné linky (GLM)

Pojďme použít sadu dat pro dospělé k ilustraci logistické regrese. „Dospělý“ je skvělý soubor dat pro úkol klasifikace. Cílem je předpovědět, zda roční příjem v dolarech jednotlivce přesáhne 50 000. Datová sada obsahuje 46 033 pozorování a deset funkcí:

  • věk: věk jedince. Číselné
  • vzdělání: Vzdělávací úroveň jednotlivce. Faktor.
  • marital.status: Rodinný stav jednotlivce. Faktor, tj. Nikdy ženatý, ženatý, manžel,…
  • gender: pohlaví jednotlivce. Faktor, tj. Muž nebo žena
  • příjem: cílová proměnná. Příjem nad nebo pod 50 tis. Faktor tj.> 50K, <= 50K

mimo jiné

library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)

Výstup:

Observations: 48,842Variables: 10$ x  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age  25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass  Private, Private, Local-gov, Private, ?, Private,… $ education  11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num  7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status  Never-married, Married-civ-spouse, Married-civ-sp… $ race  Black, White, White, Black, White, White, Black,… $ gender  Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week  40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income  <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5… 

Budeme postupovat následovně:

  • Krok 1: Zkontrolujte spojité proměnné
  • Krok 2: Zkontrolujte proměnné faktoru
  • Krok 3: Inženýrství funkcí
  • Krok 4: Souhrnná statistika
  • Krok 5: Trénink / testovací sada
  • Krok 6: Sestavte model
  • Krok 7: Posouzení výkonu modelu
  • krok 8: Vylepšení modelu

Vaším úkolem je předpovědět, který jedinec bude mít tržby vyšší než 50 tis.

V tomto kurzu bude každý krok podrobně popsán, aby se provedla analýza na skutečné datové sadě.

Krok 1) Zkontrolujte spojité proměnné

V prvním kroku můžete vidět distribuci spojitých proměnných.

continuous <-select_if(data_adult, is.numeric)summary(continuous)

Vysvětlení kódu

  • spojitý <- select_if (data_adult, is.numeric): Pomocí funkce select_if () z knihovny dplyr vyberte pouze číselné sloupce
  • souhrn (průběžně): Vytiskne souhrnnou statistiku

Výstup:

## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00

Z výše uvedené tabulky můžete vidět, že data mají úplně jiná měřítka a hodiny. Per.weeks má velké odlehlé hodnoty (tj. Podívejte se na poslední kvartil a maximální hodnotu).

Můžete to vyřešit dvěma kroky:

  • 1: Nakreslete distribuci hodin.za.týden
  • 2: Standardizace spojitých proměnných
  1. Nakreslete distribuci

Podívejme se blíže na distribuci hours.per.week

# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")

Výstup:

Proměnná má spoustu odlehlých hodnot a není dobře definovaná distribuce. Tento problém můžete částečně vyřešit odstraněním horních 0,01 procenta hodin týdně.

Základní syntaxe kvantilu:

quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.

Vypočítáme horní 2 procentní percentil

top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent

Vysvětlení kódu

  • quantile (data_adult $ hours.per.week, 0,99): Vypočítá hodnotu 99 procent pracovní doby

Výstup:

## 99%## 80 

98 procent populace pracuje do 80 hodin týdně.

Pozorování můžete pustit nad tuto hranici. Používáte filtr z knihovny dplyr.

data_adult_drop <-data_adult %>%filter(hours.per.week

Výstup:

## [1] 45537 10 
  1. Standardizujte spojité proměnné

Každý sloupec můžete standardizovat, abyste zlepšili výkon, protože vaše data nemají stejné měřítko. Můžete použít funkci mutate_if z knihovny dplyr. Základní syntaxe je:

mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the function

Číselné sloupce můžete standardizovat následujícím způsobem:

data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)

Vysvětlení kódu

  • mutate_if (is.numeric, funs (scale)): Podmínkou je pouze číselný sloupec a funkcí je scale

Výstup:

## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50K

Krok 2) Zkontrolujte proměnné faktoru

Tento krok má dva cíle:

  • Zkontrolujte úroveň v každém kategorickém sloupci
  • Definujte nové úrovně

Tento krok rozdělíme na tři části:

  • Vyberte kategorické sloupce
  • Uložte sloupcový graf každého sloupce do seznamu
  • Vytiskněte grafy

Můžeme vybrat sloupce faktoru pomocí níže uvedeného kódu:

# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)

Vysvětlení kódu

  • data.frame (select_if (data_adult, is.factor)): Ukládáme sloupce faktoru v faktoru v typu datového rámce. Knihovna ggplot2 vyžaduje objekt datového rámce.

Výstup:

## [1] 6 

Datová sada obsahuje 6 kategorických proměnných

Druhý krok je zručnější. Chcete vykreslit sloupcový graf pro každý sloupec ve faktoru datového rámce. Je pohodlnější proces automatizovat, zejména v situaci, kdy existuje spousta sloupců.

library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))

Vysvětlení kódu

  • lapply (): Pomocí funkce lapply () předáte funkci ve všech sloupcích datové sady. Výstup uložíte do seznamu
  • funkce (x): Funkce bude zpracována pro každé x. Zde x jsou sloupce
  • ggplot (factor, aes (get (x))) + geom_bar () + téma (axis.text.x = element_text (angle = 90)): Vytvořte sloupcový char chart pro každý x prvek. Chcete-li vrátit x jako sloupec, musíte jej zahrnout dovnitř get ()

Poslední krok je relativně snadný. Chcete vytisknout 6 grafů.

# Print the graphgraph

Výstup:

## [[1]]

## ## [[2]]

## ## [[3]]

## ## [[4]]

## ## [[5]]

## ## [[6]]

Poznámka: Pomocí tlačítka Další přejděte na další graf

Krok 3) Inženýrství funkcí

Přepracované vzdělávání

Z výše uvedeného grafu můžete vidět, že proměnná vzdělání má 16 úrovní. To je podstatné a některé úrovně mají relativně nízký počet pozorování. Pokud chcete zlepšit množství informací, které můžete z této proměnné získat, můžete ji přepracovat na vyšší úroveň. Jmenovitě vytváříte větší skupiny s podobnou úrovní vzdělání. Například nízká úroveň vzdělání bude převedena do předčasného ukončení. Vyšší úrovně vzdělání se změní na magisterské.

Tady je detail:

Stará úroveň

Nová úroveň

Předškolní

výpadek

10.

Výpadek

11.

Výpadek

12

Výpadek

1.-4

Výpadek

5.-6

Výpadek

7.-8

Výpadek

9

Výpadek

HS-Grad

HighGrad

Některé vysoké školy

Společenství

Assoc-acdm

Společenství

Doc

Společenství

Bakaláři

Bakaláři

Mistři

Mistři

Profesní škola

Mistři

Doktorát

PhD

recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Vysvětlení kódu

  • Používáme sloveso mutovat z knihovny dplyr. Hodnoty vzdělávání měníme tvrzením ifelse

V níže uvedené tabulce vytvoříte souhrnnou statistiku, abyste v průměru zjistili, kolik let vzdělání (hodnota z) trvá k dosažení bakaláře, magisterského nebo doktorského studia.

recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)

Výstup:

## # A tibble: 6 x 3## education average_educ_year count##   ## 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557

Přepracovat rodinný stav

Je také možné vytvořit nižší úrovně pro rodinný stav. V následujícím kódu změníte úroveň následujícím způsobem:

Stará úroveň

Nová úroveň

Se nikdy neoženil

Není vdaná

Ženatý manžel chybí

Není vdaná

Ženatý-AF-manžel

Ženatý

Manželský občan

Oddělené

Oddělené

Rozvedený

Vdovy

Vdova

# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
Můžete zkontrolovat počet jednotlivců v každé skupině.
table(recast_data$marital.status)

Výstup:

## ## Married Not_married Separated Widow## 21165 15359 7727 1286 

Krok 4) Souhrnná statistika

Je čas zkontrolovat některé statistiky o našich cílových proměnných. V níže uvedeném grafu spočítáte procento jednotlivců vydělávajících více než 50 tis. Vzhledem k jejich pohlaví.

# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()

Výstup:

Dále zkontrolujte, zda původ jednotlivce ovlivňuje jejich výdělky.

# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))

Výstup:

Počet hodin práce podle pohlaví.

# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()

Výstup:

Krabicový graf potvrzuje, že rozložení pracovní doby odpovídá různým skupinám. Na grafu pole nemají obě pohlaví homogenní pozorování.

Hustotu týdenní pracovní doby můžete zkontrolovat podle typu vzdělání. Distribuce mají mnoho odlišných tipů. Pravděpodobně to lze vysvětlit typem smlouvy v USA.

# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()

Vysvětlení kódu

  • ggplot (recast_data, aes (x = hours.per.week)): Graf hustoty vyžaduje pouze jednu proměnnou
  • geom_density (aes (color = education), alpha = 0,5): Geometrický objekt pro řízení hustoty

Výstup:

Chcete-li potvrdit své myšlenky, můžete provést jednosměrný test ANOVA:

anova <- aov(hours.per.week~education, recast_data)summary(anova)

Výstup:

## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Test ANOVA potvrzuje průměrný rozdíl mezi skupinami.

Nelineárnost

Před spuštěním modelu můžete zjistit, zda počet odpracovaných hodin souvisí s věkem.

library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()

Vysvětlení kódu

  • ggplot (recast_data, aes (x = věk, y = hodiny. za týden)): Nastavit estetiku grafu
  • geom_point (aes (color = income), size = 0,5): Sestavte tečkovaný graf
  • stat_smooth (): Přidejte trendovou čáru s následujícími argumenty:
    • method = 'lm': Vynese vynesenou hodnotu, pokud je lineární regrese
    • vzorec = y ~ poly (x, 2): Přizpůsobí polynomickou regresi
    • se = TRUE: Přidejte standardní chybu
    • aes (barva = příjem): Rozdělte model podle příjmu

Výstup:

Stručně řečeno, můžete otestovat podmínky interakce v modelu, abyste získali efekt nelinearity mezi týdenní pracovní dobou a dalšími funkcemi. Je důležité zjistit, za jakých podmínek se pracovní doba liší.

Korelace

Další kontrolou je vizualizace korelace mezi proměnnými. Převedete typ úrovně faktoru na numerický, abyste mohli vykreslit teplotní mapu obsahující koeficient korelace vypočítaný metodou Spearman.

library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")

Vysvětlení kódu

  • data.frame (lapply (recast_data, as.integer)): Převést data na číselná
  • ggcorr () vykreslí teplotní mapu s následujícími argumenty:
    • metoda: Metoda pro výpočet korelace
    • nbreaks = 6: Počet přestávek
    • hjust = 0,8: Kontrolní pozice názvu proměnné v grafu
    • label = TRUE: Přidejte štítky do středu oken
    • label_size = 3: Štítky velikosti
    • color = "grey50"): Barva štítku

Výstup:

Krok 5) Vlak / testovací sada

Jakýkoli úkol strojového učení pod dohledem vyžaduje rozdělení dat mezi vlakovou soupravu a testovací sadu. Můžete použít "funkci", kterou jste vytvořili v jiných supervizních výukových programech, k vytvoření vlakové / testovací sady.

set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)

Výstup:

## [1] 36429 9
dim(data_test)

Výstup:

## [1] 9108 9 

Krok 6) Sestavte model

Chcete-li zjistit, jak algoritmus funguje, použijte balíček glm (). Generalized Linear Model je sbírka modelů. Základní syntaxe je:

glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")

Jste připraveni odhadnout logistický model a rozdělit úroveň příjmu mezi sadu funkcí.

formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)

Vysvětlení kódu

  • vzorec <- příjem ~.: Vytvořte model tak, aby se vešel
  • logit <- glm (formula, data = data_train, family = 'binomial'): Upravte logistický model (family = 'binomial') s daty data_train.
  • summary (logit): Vytiskne souhrn modelu

Výstup:

#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6

Souhrn našeho modelu odhaluje zajímavé informace. Výkon logistické regrese se hodnotí pomocí konkrétních klíčových metrik.

  • AIC (Akaike Information Criteria): Toto je ekvivalent R2 v logistické regrese. Měří přizpůsobení, když je na počet parametrů uplatněn trest. Menší hodnoty AIC naznačují, že model je blíže pravdě.
  • Nulová odchylka: Hodí se k modelu pouze s průsečíkem. Stupeň volnosti je n-1. Můžeme to interpretovat jako hodnotu chí-kvadrát (přizpůsobená hodnota odlišná od testování hypotézy skutečné hodnoty).
  • Zbytková odchylka: Model se všemi proměnnými. Je také interpretováno jako testování hypotézy chí-kvadrát.
  • Počet iterací Fisher Scoring: Počet iterací před konvergováním.

Výstup funkce glm () je uložen v seznamu. Níže uvedený kód ukazuje všechny položky dostupné v proměnné logit, kterou jsme vytvořili k vyhodnocení logistické regrese.

# Seznam je velmi dlouhý, tiskne pouze první tři prvky

lapply(logit, class)[1:3]

Výstup:

## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"

Každou hodnotu lze extrahovat znakem $, za kterým následuje název metriky. Například jste model uložili jako logit. K extrakci kritérií AIC použijete:

logit$aic

Výstup:

## [1] 27086.65

Krok 7) Posuďte výkon modelu

Matice zmatku

Zmatek matrix je lepší volbou pro hodnocení výkonu klasifikace ve srovnání s různými metrikami jsi viděl předtím. Obecná myšlenka je spočítat, kolikrát jsou instance True klasifikovány jako False.

Chcete-li vypočítat matici zmatku, musíte nejprve mít sadu předpovědí, aby je bylo možné porovnat se skutečnými cíli.

predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_mat

Vysvětlení kódu

  • predikce (logit, data_test, type = 'response'): Vypočítá predikci na testovací sadě. Nastavit type = 'response' pro výpočet pravděpodobnosti odpovědi.
  • tabulka (data_test $ příjem, předpovědět> 0,5): Vypočítejte matici záměny. predikce> 0,5 znamená, že vrátí 1, pokud jsou předpokládané pravděpodobnosti vyšší než 0,5, jinak 0.

Výstup:

#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229

Každý řádek v matici zmatku představuje skutečný cíl, zatímco každý sloupec představuje předpokládaný cíl. První řádek této matice považuje příjem nižší než 50k (třída False): 6241 bylo správně klasifikováno jako jednotlivci s příjmem nižším než 50k ( skutečná záporná ), zatímco zbývající část byla nesprávně klasifikována jako vyšší než 50k ( falešně pozitivní ). Druhá řada zvažuje příjem nad 50 tis., Pozitivní třída byla 1229 ( skutečná kladná ), zatímco skutečná záporná byla 1074.

Přesnost modelu můžete vypočítat součtem skutečných kladných a záporných hodnot za celé pozorování

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test

Vysvětlení kódu

  • sum (diag (table_mat)): Součet úhlopříčky
  • sum (table_mat): Součet matice.

Výstup:

## [1] 0.8277339 

Zdá se, že model trpí jedním problémem, nadhodnocuje počet falešných negativů. Tomu se říká paradox testu přesnosti . Uvedli jsme, že přesnost je poměr správných předpovědí k celkovému počtu případů. Můžeme mít relativně vysokou přesnost, ale zbytečný model. Stává se to, když existuje dominantní třída. Pokud se podíváte zpět na matici zmatku, uvidíte, že většina případů je klasifikována jako skutečně negativní. Představte si nyní, že model klasifikoval všechny třídy jako negativní (tj. Nižší než 50k). Měli byste přesnost 75 procent (6718/6718 + 2257). Váš model funguje lépe, ale snaží se rozlišit skutečné pozitivní od skutečného negativního.

V takové situaci je lepší mít stručnější metriku. Můžeme se podívat na:

  • Přesnost = TP / (TP + FP)
  • Odvolání = TP / (TP + FN)

Precision vs Recall

Přesnost sleduje přesnost pozitivní predikce. Recall je poměr pozitivních instancí, které jsou správně detekovány klasifikátorem;

Pro výpočet těchto dvou metrik můžete vytvořit dvě funkce

  1. Vytvořte přesnost
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}

Vysvětlení kódu

  • mat [1,1]: Vrátí první buňku prvního sloupce datového rámce, tj. skutečnou kladnou hodnotu
  • rohož [1,2]; Vrátí první buňku druhého sloupce datového rámce, tj. Falešně pozitivní
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}

Vysvětlení kódu

  • mat [1,1]: Vrátí první buňku prvního sloupce datového rámce, tj. skutečnou kladnou hodnotu
  • rohož [2,1]; Vrátí druhou buňku prvního sloupce datového rámce, tj. Falešně negativní

Můžete otestovat své funkce

prec <- precision(table_mat)precrec <- recall(table_mat)rec

Výstup:

## [1] 0.712877## [2] 0.5336518

Když model říká, že se jedná o jednotlivce nad 50 tis., Je to správné pouze v 54 procentech případů a v 72 procentech případů může nárokovat jednotlivce nad 50 tis.

harmonický průměr těchto dvou metrik, což znamená, že kladou větší důraz na nižší hodnoty.

f1 <- 2 * ((prec * rec) / (prec + rec))f1

Výstup:

## [1] 0.6103799 

Přesnost vs Vyvolání kompromisu

Je nemožné mít vysokou přesnost i vysokou vybavenost.

Pokud zvýšíme přesnost, bude správný jedinec lépe předvídán, ale spousta z nich by nám chyběla (nižší odvolání). V některých situacích dáváme přednost vyšší přesnosti než vyvolání. Mezi přesností a odvoláním existuje konkávní vztah.

  • Představte si, že musíte předvídat, zda má pacient nemoc. Chcete být co nejpřesnější.
  • Pokud potřebujete detekovat potenciální podvodníky na ulici pomocí rozpoznávání obličeje, bylo by lepší chytit mnoho lidí označených jako podvodníci, i když přesnost je nízká. Policie bude schopna nepodvodnou osobu propustit.

Křivka ROC

Charakteristika Přijímač Provozní křivka je další častý nástroj používaný s binárním klasifikaci. Je to velmi podobné křivce přesnosti / odvolání, ale místo vykreslení přesnosti versus odvolání ukazuje křivka ROC skutečnou pozitivní míru (tj. Vyvolání) proti falešně pozitivní rychlosti. Míra falešně pozitivních výsledků je poměr negativních případů, které jsou nesprávně klasifikovány jako pozitivní. Rovná se jedné minus skutečná záporná sazba. Skutečná záporná sazba se také nazývá specificita . Křivka ROC proto vykresluje citlivost (odvolání) proti 1 specificitě

K vykreslení křivky ROC musíme nainstalovat knihovnu s názvem RORC. Najdeme v knihovně conda. Můžete zadat kód:

conda install -cr r-rocr - ano

Můžeme vykreslit ROC pomocí funkcí predikce () a výkonu ().

library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Vysvětlení kódu

  • predikce (predikce, data_test $ příjem): Knihovna ROCR musí vytvořit objekt predikce k transformaci vstupních dat
  • performance (ROCRpred, 'tpr', 'fpr'): Vrátí dvě kombinace, které mají být vytvořeny v grafu. Zde jsou konstruovány tpr a fpr. Přesnost a celkové vyvolání pozemku, použijte „prec“, „rec“.

Výstup:

Krok 8) Vylepšete model

Můžete se pokusit přidat do modelu nelinearitu s interakcí mezi

  • věk a hodiny. za týden
  • pohlaví a hodiny. za týden.

K porovnání obou modelů musíte použít test skóre

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2

Výstup:

## [1] 0.6109181 

Skóre je o něco vyšší než předchozí. Na datech můžete dále pracovat a pokusit se překonat skóre.

souhrn

V následující tabulce můžeme shrnout funkci pro trénování logistické regrese:

Balík

Objektivní

funkce

argument

-

Vytvořte datovou sadu vlaku / testu

create_train_set ()

data, velikost, vlak

glm

Trénujte zobecněný lineární model

glm ()

vzorec, data, rodina *

glm

Shrňte model

souhrn()

namontovaný model

základna

Udělejte předpověď

předpovědět()

vybavený model, datová sada, type = 'response'

základna

Vytvořte matici záměny

stůl()

y, předpovídat ()

základna

Vytvořte skóre přesnosti

součet (diag (tabulka ()) / součet (tabulka ()

ROCR

Vytvořit ROC: Krok 1 Vytvořte předpověď

předpověď()

predikovat (), r

ROCR

Vytvořit ROC: Krok 2 Vytvořte výkon

výkon()

prediction (), 'tpr', 'fpr'

ROCR

Vytvořte ROC: Krok 3 Vyneste graf

spiknutí()

výkon()

Další typy modelů GLM jsou:

- binomický: (link = "logit")

- gaussian: (link = "identita")

- Gama: (link = "inverzní")

- inverse.gaussian: (link = "1 / mu 2")

- poisson: (link = "log")

- kvazi: (link = "identity", variance = "konstantní")

- quasibinomial: (link = "logit")

- quasipoisson: (link = "log")