Co to jest regresja logistyczna?
Regresja logistyczna służy do przewidywania klasy, tj. Prawdopodobieństwa. Regresja logistyczna pozwala dokładnie przewidzieć wynik binarny.
Wyobraź sobie, że chcesz przewidzieć, czy pożyczka zostanie odrzucona / przyjęta na podstawie wielu cech. Regresja logistyczna ma postać 0/1. y = 0, jeśli pożyczka zostanie odrzucona, y = 1, jeśli zostanie przyjęta.
Model regresji logistycznej różni się od modelu regresji liniowej na dwa sposoby.
- Przede wszystkim regresja logistyczna przyjmuje tylko dychotomiczne (binarne) dane wejściowe jako zmienną zależną (tj. Wektor 0 i 1).
- Po drugie, wynik jest mierzony przez następującą probabilistyczną funkcję łączącą, zwaną sigmoidą ze względu na jej kształt litery S:
Wartość wyjściowa funkcji zawsze mieści się w zakresie od 0 do 1. Sprawdź obraz poniżej
Funkcja sigmoida zwraca wartości od 0 do 1. Do zadania klasyfikacji potrzebujemy dyskretnego wyjścia o wartości 0 lub 1.
Aby przekształcić ciągły przepływ w dyskretną wartość, możemy ustawić decyzję ograniczoną na 0,5. Wszystkie wartości powyżej tego progu są klasyfikowane jako 1
W tym samouczku dowiesz się
- Co to jest regresja logistyczna?
- Jak utworzyć uogólniony model liniowy (GLM)
- Krok 1) Sprawdź zmienne ciągłe
- Krok 2) Sprawdź zmienne czynnikowe
- Krok 3) Inżynieria funkcji
- Krok 4) Statystyka podsumowująca
- Krok 5) Trenuj / zestaw testowy
- Krok 6) Zbuduj model
- Krok 7) Oceń wydajność modelu
Jak utworzyć uogólniony model liniowy (GLM)
Użyjmy zestawu danych dla dorosłych, aby zilustrować regresję logistyczną. „Dorosły” to świetny zbiór danych do zadania klasyfikacyjnego. Celem jest przewidzenie, czy roczny dochód jednostki w dolarach przekroczy 50 000. Zbiór danych zawiera 46033 obserwacje i dziesięć funkcji:
- wiek: wiek osoby. Numeryczne
- edukacja: poziom wykształcenia jednostki. Czynnik.
- stan cywilny: stan cywilny osoby. Czynnik, tj. Osoba nigdy nie będąca w związku małżeńskim, małżonek cywilny,…
- płeć: płeć osoby. Czynnik, czyli mężczyzna lub kobieta
- dochód: zmienna docelowa. Dochód powyżej lub poniżej 50 tys. Czynnik tj.> 50K, <= 50K
wśród innych
library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)
Wynik:
Observations: 48,842Variables: 10$ x1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass Private, Private, Local-gov, Private, ?, Private,… $ education 11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status Never-married, Married-civ-spouse, Married-civ-sp… $ race Black, White, White, Black, White, White, Black,… $ gender Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5…
Będziemy postępować następująco:
- Krok 1: Sprawdź zmienne ciągłe
- Krok 2: Sprawdź zmienne czynnikowe
- Krok 3: Inżynieria funkcji
- Krok 4: Statystyka podsumowująca
- Krok 5: Trenuj / zestaw testowy
- Krok 6: Zbuduj model
- Krok 7: Oceń wydajność modelu
- Krok 8: Popraw model
Twoim zadaniem jest przewidzenie, która osoba będzie miała przychody większe niż 50 tys.
W tym samouczku każdy krok zostanie szczegółowo omówiony w celu przeprowadzenia analizy na rzeczywistym zbiorze danych.
Krok 1) Sprawdź zmienne ciągłe
W pierwszym kroku możesz zobaczyć rozkład zmiennych ciągłych.
continuous <-select_if(data_adult, is.numeric)summary(continuous)
Objaśnienie kodu
- ciągły <- select_if (data_adult, is.numeric): Użyj funkcji select_if () z biblioteki dplyr, aby wybrać tylko kolumny liczbowe
- podsumowanie (ciągłe): drukowanie statystyki podsumowującej
Wynik:
## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00
Z powyższej tabeli widać, że dane mają zupełnie inne skale i godziny. Na tydzień mają duże wartości odstające (. Tj. Spójrz na ostatni kwartyl i maksymalną wartość).
Możesz sobie z tym poradzić wykonując dwa kroki:
- 1: Wykreśl rozkład godzin. Na tydzień
- 2: Standaryzuj zmienne ciągłe
- Wykreśl dystrybucję
Przyjrzyjmy się bliżej rozkładowi godzin na tydzień
# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")
Wynik:
Zmienna ma wiele wartości odstających i nie jest dobrze zdefiniowana dystrybucja. Możesz częściowo rozwiązać ten problem, usuwając górne 0,01 procent godzin tygodniowo.
Podstawowa składnia kwantyla:
quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.
Obliczamy górny centyl 2
top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent
Objaśnienie kodu
- kwantyl (data_adult $ hours.per.week, .99): Oblicz wartość 99 procent czasu pracy
Wynik:
## 99%## 80
98 procent populacji pracuje poniżej 80 godzin tygodniowo.
Możesz obniżyć obserwacje powyżej tego progu. Używasz filtru z biblioteki dplyr.
data_adult_drop <-data_adult %>%filter(hours.per.weekWynik:
## [1] 45537 10
- Standaryzuj zmienne ciągłe
Możesz ustandaryzować każdą kolumnę, aby poprawić wydajność, ponieważ dane nie mają tej samej skali. Możesz użyć funkcji mutate_if z biblioteki dplyr. Podstawowa składnia to:
mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the functionMożesz ustandaryzować kolumny numeryczne w następujący sposób:
data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)Objaśnienie kodu
- mutate_if (is.numeric, funs (scale)): Warunek to tylko kolumna numeryczna, a funkcja to skala
Wynik:
## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50KKrok 2) Sprawdź zmienne czynnikowe
Ten krok ma dwa cele:
- Sprawdź poziom w każdej kolumnie kategorialnej
- Zdefiniuj nowe poziomy
Podzielimy ten krok na trzy części:
- Wybierz kolumny kategorialne
- Przechowuj wykres słupkowy każdej kolumny na liście
- Wydrukuj wykresy
Możemy wybrać kolumny współczynników za pomocą poniższego kodu:
# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)Objaśnienie kodu
- data.frame (select_if (data_adult, is.factor)): Przechowujemy kolumny współczynnika we współczynniku w typie ramki danych. Biblioteka ggplot2 wymaga obiektu ramki danych.
Wynik:
## [1] 6Zbiór danych zawiera 6 zmiennych kategorialnych
Drugi krok jest bardziej wymagający. Chcesz wykreślić wykres słupkowy dla każdej kolumny współczynnika ramki danych. Wygodniej jest zautomatyzować proces, zwłaszcza w sytuacji, gdy kolumn jest dużo.
library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))Objaśnienie kodu
- lapply (): użyj funkcji lapply (), aby przekazać funkcję we wszystkich kolumnach zestawu danych. Dane wyjściowe przechowujesz na liście
- function (x): Funkcja zostanie przetworzona dla każdego x. Tutaj x to kolumny
- ggplot (factor, aes (get (x))) + geom_bar () + theme (axis.text.x = element_text (angle = 90)): Utwórz wykres słupkowy dla każdego elementu x. Uwaga, aby zwrócić x jako kolumnę, musisz uwzględnić ją wewnątrz get ()
Ostatni krok jest stosunkowo łatwy. Chcesz wydrukować 6 wykresów.
# Print the graphgraphWynik:
## [[1]]## ## [[2]]## ## [[3]]## ## [[4]]## ## [[5]]## ## [[6]]Uwaga: Użyj następnego przycisku, aby przejść do następnego wykresu
Krok 3) Inżynieria funkcji
Przekształcona edukacja
Z powyższego wykresu widać, że zmienna edukacja ma 16 poziomów. Jest to istotne, a niektóre poziomy mają stosunkowo niewielką liczbę obserwacji. Jeśli chcesz zwiększyć ilość informacji, które możesz uzyskać z tej zmiennej, możesz przekształcić ją na wyższy poziom. Mianowicie, tworzysz większe grupy o podobnym poziomie wykształcenia. Na przykład niski poziom edukacji zostanie przekształcony w osoby przedwcześnie kończące naukę. Wyższe poziomy edukacji zostaną zmienione na mistrzowskie.
Oto szczegóły:
Stary poziom
Nowy poziom
Przedszkole
spadkowicz
10
Spadkowicz
11th
Spadkowicz
12
Spadkowicz
1-4
Spadkowicz
5-6
Spadkowicz
7-8
Spadkowicz
9
Spadkowicz
HS-Grad
HighGrad
Uczelnia
Społeczność
Assoc-acdm
Społeczność
Assoc-voc
Społeczność
Licencjat
Licencjat
Mistrzowie
Mistrzowie
Szkoła prof
Mistrzowie
Doktorat
Dr
recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))Objaśnienie kodu
- Używamy czasownika mutate z biblioteki dplyr. Zmieniamy wartości edukacji stwierdzeniem ifelse
W poniższej tabeli tworzysz statystykę podsumowującą, aby zobaczyć, ile średnio lat edukacji (wartość z) potrzeba, aby uzyskać tytuł licencjata, magistra lub doktora.
recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)Wynik:
## # A tibble: 6 x 3## education average_educ_year count#### 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557 Przekształcenie stanu cywilnego
Możliwe jest również tworzenie niższych poziomów stanu cywilnego. W poniższym kodzie możesz zmienić poziom w następujący sposób:
Stary poziom
Nowy poziom
Nigdy się nie ożenił
Niezamężny
Żonaty-małżonek-nieobecny
Niezamężny
Żonaty-małżonek-AF
Żonaty
Żonaty-cywilny-małżonek
Rozdzielony
Rozdzielony
Rozwiedziony
Wdowy
Wdowa
# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))Możesz sprawdzić liczbę osób w każdej grupie.table(recast_data$marital.status)Wynik:
## ## Married Not_married Separated Widow## 21165 15359 7727 1286Krok 4) Statystyka podsumowująca
Czas sprawdzić statystyki dotyczące naszych zmiennych docelowych. Na poniższym wykresie policzysz odsetek osób zarabiających powyżej 50 tys. Osób, biorąc pod uwagę ich płeć.
# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()Wynik:
Następnie sprawdź, czy pochodzenie osoby wpływa na jej zarobki.
# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))Wynik:
Liczba godzin pracy według płci.
# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()Wynik:
Wykres pudełkowy potwierdza, że rozkład czasu pracy pasuje do różnych grup. Na wykresie pudełkowym obie płcie nie mają jednorodnych obserwacji.
Możesz sprawdzić gęstość tygodniowego czasu pracy według rodzaju wykształcenia. Dystrybucje mają wiele różnych typów. Można to prawdopodobnie wyjaśnić rodzajem umowy w USA.
# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()Objaśnienie kodu
- ggplot (recast_data, aes (x = hours.per.week)): Wykres gęstości wymaga tylko jednej zmiennej
- geom_density (aes (kolor = edukacja), alfa = 0,5): Obiekt geometryczny do kontrolowania gęstości
Wynik:
Aby potwierdzić swoje myśli, możesz wykonać jednokierunkowy test ANOVA:
anova <- aov(hours.per.week~education, recast_data)summary(anova)Wynik:
## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1Test ANOVA potwierdza różnicę w średniej między grupami.
Nieliniowość
Przed uruchomieniem modelu możesz sprawdzić, czy liczba przepracowanych godzin jest związana z wiekiem.
library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()Objaśnienie kodu
- ggplot (recast_data, aes (x = wiek, y = hours.per.week)): Ustaw estetykę wykresu
- geom_point (aes (kolor = dochód), rozmiar = 0,5): Skonstruuj wykres kropkowy
- stat_smooth (): Dodaj linię trendu z następującymi argumentami:
- metoda = „lm”: Wykreśl dopasowaną wartość w przypadku regresji liniowej
- formuła = y ~ poly (x, 2): Dopasuj regresję wielomianową
- se = TRUE: Dodaj błąd standardowy
- aes (kolor = dochód): Podziel model według dochodu
Wynik:
Krótko mówiąc, możesz przetestować warunki interakcji w modelu, aby wykryć efekt nieliniowości między tygodniowym czasem pracy a innymi funkcjami. Ważne jest, aby wykryć, w jakich warunkach różni się czas pracy.
Korelacja
Następnym sprawdzeniem jest wizualizacja korelacji między zmiennymi. Konwertujesz typ poziomu czynnika na numeryczny, aby można było wykreślić mapę cieplną zawierającą współczynnik korelacji obliczony metodą Spearmana.
library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")Objaśnienie kodu
- data.frame (lapply (recast_data, as.integer)): Konwertuj dane na numeryczne
- ggcorr () wykreśla mapę cieplną z następującymi argumentami:
- metoda: metoda obliczania korelacji
- nbreaks = 6: Liczba przerw
- hjust = 0,8: Kontroluje pozycję nazwy zmiennej na wykresie
- label = TRUE: Dodaj etykiety na środku okien
- label_size = 3: etykiety rozmiarów
- color = "grey50"): kolor etykiety
Wynik:
Krok 5) Trenuj / zestaw testowy
Każde nadzorowane zadanie uczenia maszynowego wymaga podzielenia danych między zestaw pociągów i zestaw testowy. Możesz użyć „funkcji”, którą utworzyłeś w innych nadzorowanych samouczkach do tworzenia zestawu treningowego / testowego.
set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)Wynik:
## [1] 36429 9dim(data_test)Wynik:
## [1] 9108 9Krok 6) Zbuduj model
Aby zobaczyć, jak działa algorytm, użyj pakietu glm (). Ogólny model liniowy jest zbiorem modeli. Podstawowa składnia to:
glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")Jesteś gotowy do oszacowania modelu logistycznego w celu podzielenia poziomu dochodów między zestaw cech.
formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)Objaśnienie kodu
- formuła <- dochód ~.: Stwórz dopasowany model
- logit <- glm (formuła, dane = pociąg_danych, rodzina = 'dwumian'): Dopasuj model logistyczny (rodzina = 'dwumian') do danych z pociągu_danych.
- podsumowanie (logit): Wydrukuj podsumowanie modelu
Wynik:
#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6Podsumowanie naszego modelu ujawnia ciekawe informacje. Wydajność regresji logistycznej jest oceniana za pomocą określonych kluczowych wskaźników.
- AIC (Akaike Information Criteria): jest to odpowiednik R2 w regresji logistycznej. Mierzy dopasowanie, gdy kara jest nakładana na liczbę parametrów. Mniejsze wartości AIC wskazują, że model jest bliższy prawdy.
- Odchylenie zerowe: pasuje do modelu tylko z punktem przecięcia z osią. Stopień swobody wynosi n-1. Możemy to zinterpretować jako wartość Chi-kwadrat (dopasowana wartość różni się od rzeczywistej wartości testowania hipotezy).
- Resztkowe odchylenie: Modeluj ze wszystkimi zmiennymi. Jest również interpretowane jako testowanie hipotezy Chi-kwadrat.
- Liczba iteracji punktacji Fishera: liczba iteracji przed zbieżnością.
Wynik funkcji glm () jest przechowywany na liście. Poniższy kod przedstawia wszystkie elementy dostępne w zmiennej logit, którą skonstruowaliśmy w celu oceny regresji logistycznej.
# Lista jest bardzo długa, wypisz tylko pierwsze trzy elementy
lapply(logit, class)[1:3]Wynik:
## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"Każdą wartość można wyodrębnić ze znakiem $, po którym następuje nazwa metryki. Na przykład zapisałeś model jako logit. Aby wyodrębnić kryteria AIC, użyj:
logit$aicWynik:
## [1] 27086.65Krok 7) Oceń wydajność modelu
Macierz zamieszania
Matryca zamieszanie jest lepszym wyborem, aby ocenić skuteczność klasyfikacji w porównaniu z różnych metryk widziałeś wcześniej. Ogólną ideą jest policzenie, ile razy instancje True są klasyfikowane jako fałszywe.
Aby obliczyć macierz nieporozumień, musisz najpierw mieć zestaw prognoz, aby można je było porównać z rzeczywistymi celami.
predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_matObjaśnienie kodu
- Predict (logit, data_test, type = 'response'): Oblicz prognozę na zbiorze testowym. Ustaw type = 'response', aby obliczyć prawdopodobieństwo odpowiedzi.
- table (data_test $ dochodu, prognoza> 0,5): Oblicz macierz pomyłki. przewidzieć> 0,5 oznacza, że zwraca 1, jeśli przewidywane prawdopodobieństwa są powyżej 0,5, w przeciwnym razie 0.
Wynik:
#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229Każdy wiersz w macierzy nieporozumień reprezentuje rzeczywisty cel, podczas gdy każda kolumna przedstawia przewidywany cel. W pierwszym wierszu tej macierzy uwzględniono dochody poniżej 50 tys. (Klasa Fałsz): 6241 zostało poprawnie zaklasyfikowanych jako osoby z dochodem poniżej 50 tys. ( Prawda ujemna ), a pozostała została błędnie zaklasyfikowana jako powyżej 50 tys. ( Fałszywie dodatnia ). Drugi rząd uwzględnia dochody powyżej 50 tys., Klasa dodatnia to 1229 ( prawdziwie dodatnie ), a prawdziwie ujemna 1074.
Możesz obliczyć dokładność modelu , sumując wartość prawdziwie dodatnią + prawdziwie ujemną przez całą obserwację
accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_TestObjaśnienie kodu
- sum (diag (table_mat)): Suma przekątnej
- sum (table_mat): Suma macierzy.
Wynik:
## [1] 0.8277339Model wydaje się cierpieć z powodu jednego problemu - przeszacowuje liczbę fałszywie negatywnych wyników. Nazywa się to paradoksem testu dokładności . Stwierdziliśmy, że dokładność to stosunek prawidłowych prognoz do całkowitej liczby przypadków. Możemy mieć stosunkowo dużą dokładność, ale model bezużyteczny. Dzieje się tak, gdy istnieje klasa dominująca. Jeśli spojrzysz wstecz na matrycę nieporozumień, zobaczysz, że większość przypadków jest sklasyfikowanych jako prawdziwie negatywna. Wyobraź sobie teraz, że model sklasyfikował wszystkie klasy jako negatywne (tj. Poniżej 50 tys.). Miałbyś dokładność 75 procent (6718/6718 + 2257). Twój model działa lepiej, ale ma trudności z odróżnieniem prawdziwego pozytywu od prawdziwego negatywu.
W takiej sytuacji lepiej jest mieć bardziej zwięzłą metrykę. Możemy spojrzeć na:
- Precyzja = TP / (TP + FP)
- Przypomnijmy = TP / (TP + FN)
Precyzja a przywołanie
Precyzja polega na dokładności pozytywnej prognozy. Przypomnienie to stosunek pozytywnych instancji, które są poprawnie wykrywane przez klasyfikator;
Możesz skonstruować dwie funkcje, aby obliczyć te dwie metryki
- Konstruuj precyzję
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}Objaśnienie kodu
- mat [1,1]: Zwraca pierwszą komórkę pierwszej kolumny ramki danych, czyli wartość prawdziwie dodatnią
- mata [1, 2]; Zwróć pierwszą komórkę drugiej kolumny ramki danych, czyli wynik fałszywie dodatni
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}Objaśnienie kodu
- mat [1,1]: Zwraca pierwszą komórkę pierwszej kolumny ramki danych, czyli wartość prawdziwie dodatnią
- mata [2,1]; Zwróć drugą komórkę pierwszej kolumny ramki danych, czyli wynik fałszywie ujemny
Możesz przetestować swoje funkcje
prec <- precision(table_mat)precrec <- recall(table_mat)recWynik:
## [1] 0.712877## [2] 0.5336518Kiedy model mówi, że jest to osoba powyżej 50 tys., Jest to poprawne tylko w 54 procentach przypadków i może zgłaszać roszczenia osób powyżej 50 tys. W 72 procentach.
Możesz stworzyć Jest średnia harmoniczna tych dwóch wskaźników, co oznacza, że daje większą wagę do niższych wartości.
f1 <- 2 * ((prec * rec) / (prec + rec))f1Wynik:
## [1] 0.6103799Kompromis między precyzją a przywołaniem
Niemożliwe jest uzyskanie zarówno wysokiej precyzji, jak i wysokiej przywołania.
Jeśli zwiększymy precyzję, właściwa osoba będzie lepiej przewidziana, ale wiele z nich przeoczylibyśmy (mniejsza pamięć). W niektórych sytuacjach wolimy większą precyzję niż przywołanie. Istnieje wklęsły związek między precyzją a pamięcią.
- Wyobraź sobie, że musisz przewidzieć, czy pacjent ma chorobę. Chcesz być tak precyzyjny, jak to tylko możliwe.
- Jeśli musisz wykryć potencjalnych oszukańczych ludzi na ulicy za pomocą rozpoznawania twarzy, lepiej byłoby złapać wiele osób oznaczonych jako oszukańcze, mimo że precyzja jest niska. Policja będzie mogła zwolnić osobę nieuczciwą.
Krzywa ROC
Charakterystyczny Receiver Operating krzywa jest kolejnym wspólnym narzędziem używanym przy klasyfikacji binarnej. Jest bardzo podobna do krzywej precyzji / przypomnienia, ale zamiast wykreślania precyzji w porównaniu z pamięcią, krzywa ROC przedstawia odsetek prawdziwie pozytywnych wyników (tj. Przypominanie) w porównaniu z odsetkiem wyników fałszywie dodatnich. Współczynnik wyników fałszywie pozytywnych to stosunek negatywnych przypadków, które są nieprawidłowo sklasyfikowane jako pozytywne. Równa się jeden minus prawdziwa stopa ujemna. Prawdziwa ujemna stopa jest również nazywana specyficznością . Stąd krzywa ROC przedstawia czułość (przypominanie) w porównaniu z 1-specyficznością
Aby wykreślić krzywą ROC, musimy zainstalować bibliotekę o nazwie RORC. Możemy znaleźć w bibliotece Conda. Możesz wpisać kod:
conda install -cr r-rocr --yes
Możemy wykreślić ROC za pomocą funkcji predykcji () i wydajności ().
library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))Objaśnienie kodu
- prediction (predykcja, data_test $ dochód): Biblioteka ROCR musi utworzyć obiekt predykcji, aby przekształcić dane wejściowe
- wydajność (ROCRpred, 'tpr', 'fpr'): Zwróć dwie kombinacje do utworzenia na wykresie. Tutaj konstruowane są tpr i fpr. Tot plot precyzja i przywołaj razem, użyj "prec", "rec".
Wynik:
Krok 8) Popraw model
Możesz spróbować dodać nieliniowość do modelu za pomocą interakcji między
- wiek i godziny. na tydzień
- płeć i godziny na tydzień.
Musisz użyć testu punktacji, aby porównać oba modele
formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2Wynik:
## [1] 0.6109181Wynik jest nieco wyższy niż poprzedni. Możesz kontynuować pracę nad danymi i spróbować pobić wynik.
Podsumowanie
Możemy podsumować funkcję trenowania regresji logistycznej w poniższej tabeli:
Pakiet
Cel
funkcjonować
argument
-
Utwórz zestaw danych pociągu / testu
create_train_set ()
dane, rozmiar, pociąg
glm
Wytrenuj uogólniony model liniowy
glm ()
wzór, dane, rodzina *
glm
Podsumuj model
Podsumowanie()
dopasowany model
baza
Przewidzieć
przepowiadać, wywróżyć()
dopasowany model, zbiór danych, typ = 'odpowiedź'
baza
Stwórz macierz zamieszania
stół()
y, przewidywanie ()
baza
Utwórz wynik dokładności
sum (diag (table ()) / sum (table ()
ROCR
Utwórz ROC: Krok 1 Utwórz prognozę
Prognoza()
przewidzieć (), y
ROCR
Utwórz ROC: Krok 2 Utwórz wydajność
wydajność()
prediction (), „tpr”, „fpr”
ROCR
Utwórz ROC: Krok 3 Wykreśl wykres
wątek()
wydajność()
Inne modele GLM to:
- dwumian: (link = "logit")
- gaussian: (link = "tożsamość")
- Gamma: (link = "odwrotność")
- inverse.gaussian: (link = "1 / mu 2")
- poisson: (link = "log")
- quasi: (link = "tożsamość", wariancja = "stała")
- quasibinomial: (link = "logit")
- quasipoisson: (link = "log")