Predykcja poziomu dochodów z użyciem modeli klasyfikacyjnych
1 Cel projektu
Celem projektu jest zbudowanie modelu klasyfikacyjnego, który na podstawie cech przewidzi, czy dana osoba zarabia ponad 50 tysięcy dolarów rocznie.
Model ma pomóc w identyfikacji czynników wpływających na wysokie dochody oraz umożliwić skuteczną klasyfikację przyszłych obserwacji.
Wykorzystane zostaną metody uczenia maszynowego, takie jak regresja logistyczna, drzewo decyzyjne, las losowy, naiwny klasyfikator bayesowski, algorytm K najbliższych sąsaidów oraz XGBoost, a także techniki przetwarzania danych i strojenia hiperparametrów.
2 Zbiór danych
Pierwsze przykładowe 100 obserwacji:
2.1 Opis zbioru
Źródło: https://www.kaggle.com/datasets/uciml/adult-census-income/data
Zbiór danych Adult Census Income pochodzi z badania amerykańskiego spisu powszechnego i zawiera 32 561 obserwacji. Dane obejmują cechy m.in. takie jak wiek, wykształcenie, stan cywilny, zawód, liczba przepracowanych godzin tygodniowo oraz kraj pochodzenia. Celem analizy jest przewidywanie, czy dochód osoby przekracza próg 50 tysięcy dolarów rocznie (zmienna binarna).
Zbiór zawiera 15 zmiennych:
Opis zmiennych | ||||
---|---|---|---|---|
Zmienna | Opis | Typ | Zakres | Uwagi |
age | Wiek | numeric | 17-90 | - |
workclass | Typ zatrudnienia | factor | Federal-gov, Local-gov, Never-worked, Private, Self-emp-inc, Self-emp-not-inc, State-gov, Without-pay | '?' - 1836; 8 poziomów |
fnlwgt | Waga obserwacji nadana przez amerykańskie Biuro Spisu Ludności (U.S. Census Bureau) | numeric | 12285 – 1484705 | - |
education | Poziom wykształcenia | factor | 10th, 11th, 12th, 1st-4th, 5th-6th, 7th-8th, 9th, Assoc-acdm, Assoc-voc, Bachelors, Doctorate, HS-grad, Masters, Preschool, Prof-school, Some-college | 16 poziomów |
education.num | Liczbowa wersja wykształcenia | numeric | 1 - 16 | - |
marital.status | Stan cywilny | factor | Divorced, Married-AF-spouse, Married-civ-spouse, Married-spouse-absent, Never-married, Separated, Widowed | 7 poziomów |
occupation | Zawód | factor | Adm-clerical, Armed-Forces, Craft-repair, Exec-managerial, Farming-fishing, Handlers-cleaners, Machine-op-inspct, Other-service, Priv-house-serv, Prof-specialty, Protective-serv, Sales, Tech-support, Transport-moving | '?' - 1843; 14 poziomów |
relationship | Relacje | factor | Husband, Not-in-family, Other-relative, Own-child, Unmarried, Wife | 6 poziomów |
race | Rasa/pochodzenie etniczne | factor | Amer-Indian-Eskimo, Asian-Pac-Islander, Black, Other, White | 5 poziomów |
sex | Płeć | factor | Female, Male | 2 poziomy |
capital.gain | Dochód kapitałowy | numeric | 0-99999 | - |
capital.loss | Straty kapitałowe | numeric | 0-4356 | - |
hours.per.week | Liczba przepracowanych godzin tygodniowo | numeric | 1-99 | - |
native.country | Kraj pochodzenia | factor | Cambodia, Canada, China, Columbia, Cua, Dominican-Republic, Ecuador, El-Salvador, England, France, Germany, Greece, Guatemala, Haiti, Holand-Netherlands, Honduras, Hong, Hungary, India, Iran, Ireland, Italy, Jamaica, Japan, Laos, Mexico, Nicaragua,Outlying-US(Guam-USVI-etc), Peru, Philippines, Poland, PPortugal, Puerto-Rico, Scotland, South, Taiwan, Thailand, Trinadad&Tobago, United-States, Vietnam, Yugoslavia | '?'- 583; 41 poziomów |
income | Dochód; Zmienna docelowa (target) – czy dana osoba zarabia mniej czy więcej niż $50K rocznie | factor | <=50K – do 50 000 USD, >50K – powyżej 50 000 USD rocznie | 2 poziomy |
2.2 Braki danych
W zbiorze danych występuje 30 162 wiersze bez braków danych. Jest również:
7 wierszy, które mają braki danych tylko w zmiennej
occupation
, a pozostałe zmienne mają uzupełnione wartości.1809 wierszy gdzie występuje brak danych w zmiennych
workclass
ioccupation
, a pozostałe wartości są uzupełnione.556 wierszy zawierających brakującą wartość tylko dla zmiennej
native.country
.27 wierszy gdzie są braki danych jednocześnie w
native.country
,workclass
ioccupation
.
Około 7,4% całego zbioru danych zawiera wartości brakujące. Usuniemy jedynie 27 wierszy, które posiadają braki we wszystkich trzech kolumnach jednocześnie. Pozostałe brakujące wartości zostaną zastąpione wartością „Unknown”. Dzięki temu unikniemy usuwania dużej liczby obserwacji, w tym także tych nietypowych np. nietypowego zawodu.
Po usunięciu wspomnianych 27 rekordów, zbiór danych liczy teraz 32 534 obserwacje.
2.3 Analiza zbioru danych
Zanim przystąpimy do budowy modeli, przyjrzymy się bliżej danym. Sprawdzimy m.in. jak wygląda rozkład dochodów oraz jakie zależności występują między zmiennymi. Taka analiza pozwoli nam lepiej zrozumieć dane i przygotować je do dalszych etapów pracy.
2.3.1 Procentowy rozkład zmiennej income
Na podstawie wykresu możemy zauważyć, że zmienna docelowa income
jest niezbalansowana. Większość obserwacji – 75,9% – dotyczy dochodów równych lub poniżej $50K (<=50K), natomiast 24,1% stanowią dochody powyżej $50K (>50K). Taka dysproporcja klas może mieć istotny wpływ na skuteczność modeli klasyfikacyjnych i wymaga odpowiednich technik balansowania danych, takich jak np. oversampling.
2.3.2 Zależność dochodu od wieku
Młodsze osoby najczęściej zarabiają do 50 tysięcy dolarów rocznie. Wraz z wiekiem, szczególnie w przedziale 25–50 lat, widoczny jest wzrost zarobków, a coraz więcej osób przekracza próg 50 tysięcy dolarów.
2.3.3 Hours per week
Rozkład godzin pracy jest dość symetryczny, bez dużych odchyleń w jedną stronę.
zmienna | średnia | mediana | min | max | odch. standardowe | Q1 | Q3 |
---|---|---|---|---|---|---|---|
hours.per.week | 40.44 | 40 | 1 | 99 | 12.34 | 40 | 45 |
Mediana i średnia liczba godzin pracy wynosi około 40. Większość osób realizuje standardowy wymiar pracy (około 40 godz./tydz), ale istnieją też skrajne przypadki z mniejszą lub znacząco większą liczbą godzin pracy.
2.3.4 Płeć a dochód
<=50K | >50K | |
---|---|---|
Female | 9579 | 1177 |
Male | 15118 | 6660 |
Wśród kobiet zdecydowana większość (9579) zarabia nie więcej niż 50K (<=50K), a tylko 1177 osiąga wyższy dochód (>50). U mężczyzn również przeważają zarobki do 50K (15 118 osób), ale aż 6660 osiąga dochód powyżej tej granicy. Mężczyźni są liczniejsi w obu grupach, szczególnie w tej lepiej zarabiającej. Procent kobiet zarabiających powyżej 50 tys. w stosunku do wszytskich kobiet jest mniejszy niż procent mężczyzn zarabiających powyżej 50 tys. w stosunku do wszystkich mężczyzn. Dane sugerują, że kobiety rzadziej należą do grupy z wyższym dochodem.
2.3.5 Typ zatrudnienia
Typ zatrudnienia | Liczba osób |
---|---|
Federal-gov | 960 |
Local-gov | 2093 |
Never-worked | 7 |
Private | 22696 |
Self-emp-inc | 1116 |
Self-emp-not-inc | 2541 |
State-gov | 1298 |
Without-pay | 14 |
Unknown | 1809 |
W zbiorze dominują osoby zatrudnione w sektorze prywatnym (Private - 22 696 osób), co stanowi zdecydowaną większość. Najmniej liczną grupą są osoby nigdy niepracujące (Never-worked - 7 osób) oraz bez wynagrodzenia (Without-pay 14) . Nieznany tryb zatrudnienia (Unknown) ma 1809 osób.
2.3.6 Korelacja zmiennych liczbowych
Współczynniki korelacji między zmiennymi są niskie, co wskazuje na słabe powiązania liniowe między nimi. Najwyższa dodatnia korelacja występuje między education.num
a hours.per.week
(około 0.15), sugerując, że osoby z wyższym wykształceniem nieco częściej pracują więcej godzin. Jednak wszystkie korelacje są bardzo słabe lub niemal zerowe.
2.3.7 Rozkłady zmiennych numerycznych
Rozkład zmiennej age
jest zbliżony do normalnego, rozkład fnlwgt
jest prawostronnie asymetryczny. education.num
odzwierciedla różne poziomy wyksztacenia. Zmienne capital.gain
oraz capital.loss
wyróżniają się dużą asymetrią, większość wartości to 0 lub małe liczby.
2.3.8 capital.gain
zmienna | średnia | mediana | min | max | odch. standardowe | Q1 | Q3 |
---|---|---|---|---|---|---|---|
capital.gain | 1078.08 | 0 | 0 | 99999 | 7388.18 | 0 | 0 |
Mediana i oba kwartyle są równe 0, co oznacza, że więcej niż 75% osób nie ma żadnych zysków kapitałowych. Średnia (1077.65) jest znacznie wyższa niż mediana, co sugeruje obecność nielicznych, ale bardzo wysokich wartości. Wysokie odchylenie standardowe (7385.29) i maksymalna wartość 99999 (159 takich przypadków) wskazują na ekstremalne przypadki.
2.3.9 capital.loss
zmienna | średnia | mediana | min | max | odch. standardowe | Q1 | Q3 |
---|---|---|---|---|---|---|---|
capital.loss | 87.27 | 0 | 0 | 4356 | 402.93 | 0 | 0 |
Ponownie zarówno mediana, jak i oba kwartyle (Q1 i Q3) wynoszą 0, co oznacza, że co najmniej 75% obserwacji ma capital.loss
równe 0. Tutaj podobnie pojawiają się nieliczne wysokie wartości.
3 Metoda analizy
3.1 Regresja logistyczna (Logistic Regression)
Regresja logistyczna to algorytm klasyfikacji, służący do przewidywania prawdopodobieństwa przynależności obserwacji do jednej z dwóch klas (sukces i porażka). Model opiera się na funkcji logistycznej (sigmoidzie), która przekształca liniową kombinację cech (predyktorów \(X\)) w wartość z zakresu od 0 do 1, interpretowaną jako prawdopodobieństwo sukcesu.
Ogólna postać modelu:
\[ Y \sim B(1,p) \]
\[ p(X) = E(Y|X) = \frac{exp(\beta X)}{1 + exp(\beta X)} \]
gdzie \(B(1,p)\) jest rozkładem dwumianowym o prawdopodobieństwie sukcesu \(p\), a \(\beta X\) oznacza kombinację liniową parametrów modelu i wartości zmiennych niezależnych, przyjmując, że \(x_0 = 1\).
3.2 Drzewo decyzyjne (Decision Tree)
Drzewa decyzyjne stosowane są szczególnie wtedy, gdy funkcyjna postać związku pomiędzy predyktorami a zmienną wynikową jest nieznana lub ciężka do ustalenia. Każde drzewo decyzyjne składa się z korzenia (ang. root), węzłów (ang. nodes) i liści (ang. leaves). Korzeń to początkowy węzeł drzewa, z którego poprzez podziały (ang. splits) powstają kolejne węzły potomne. Końcowe węzły, które nie podlegają podziałom nazywamy liśćmi, a linie łączące węzły nazywamy gałęziami (ang. branches).
Drzewa decyzyjne można stosować zarówno do zadań klasyfikacyjnych, jak i regresyjnych. W przypadku klasyfikacji, każdy liść drzewa wskazuje, do której klasy należy dana obserwacja - czyli która klasa jest najbardziej prawdopodobna po przejściu przez kolejne podziały. W przypadku regresji, liście zawieracją zazwyczaj średnią wartość zmiennej wynikowej dla obserwacji, które tam trafiły.
Do zalet drzew decyzyjnych należy przede wszystkim łatwość interpretacji, niewielkie wymagania dotyczące przygotowania danych oraz możliwość pracy zarówno na zmiennych jakościowych, jak i ilościowych.
3.3 Las losowy (Random Forest)
Las losowy to metoda polegająca na losowym wyborze dla każdego drzewa wchodzącego w skład lasu \(m\) predyktorów spośród \(p\) dostępnych, a następnie budowaniu drzew z wykorzystaniem tylko tych predyktorów. Dzięki temu za każdym razem drzewo jest budowane w oparciu o nowy zestaw cech (najczęściej przyjmujemy \(m = \sqrt{p}\) ).
Dzięki uśrednianiu wyników poszczególnych drzew model ten uzyskuje wysoką stabilość i dokładność. Las losowy skutecznie zmniejsza ryzyko przeuczenia i jest jedną z najpopularniejszych metod stosowanych w zadaniach klasyfikacyjnych.
3.4 Naiwny klasyfikator Bayesa (Naive Bayes)
Jest to dość prosty klasyfikator probabilistyczny bazujący na twierdzeniu Bayesa. Naiwne klasyfikatory bayesowskie są oparte na założeniu o wzajemnej niezależności predyktorów (zmiennych niezależnych). Założenie to często nie jest spełnone i stąd nazwa przymiotnik “naiwny”.
Mimo swojej prostoty, często daje zaskakująco dobre wyniki w klasyfikacji, szczególnie przy dużych zbiorach danych i niezrównoważonych klasach.
3.5 XGBoost (Extreme Gradient Boosting)
XGBoost (Extreme Gradient Boosting) to wydajna i popularna metoda uczenia maszynowego oparta na metodzie boostingowej drzew decyzyjnych.
Boosting polega na tworzeniu wielu słabych modeli (np. prostych drzew decyzyjnych), gdzie każdy kolejny model uczy się na błędach poprzednich. Proces treningu jest iteracyjny - każde nowe drzewo stara sę poprawić przewidywania wcześniejszych modeli.
XGBoost może być używany zarówno w zadaniach klasyfikacyjnych, jak i regresyjnych, gdzie często osiąga dobre wyniki.
3.6 K-Nearest Neighbors (KNN)
Technika \(k\) najbliższych sąsiadów przewiduje wartość zmiennej wynikowej na podstawie \(k\) najbliższych obserwacji zbioru treningowego. Ten algorytm nie posiada jawnej postaci i zalicza się go do klasy technik nazywanych czarnymi skrzynkami (ang. black box). Może być wykorzystywany, zarówno do zadań klasyfikacyjnych, jak i regresyjnych. W obu przypadkach predykcja dla nowych wartości predyktorów przebiega podobnie.
Jeśli zadanie ma charakter klasyfikacyjny, to zmiennej wynikowej przypisuje się modę obserwacji (najczęstsza klasa) spośród k najbliższych sąsiadów. W przypadku zadań regresyjnych przypisuje się średnią lub medianę.
Parametrem, który ma znaczący wpływ na predykcję, jest liczba sąsiadów \(k\). Wybór zbyt małej liczby \(k\) może doprowadzić do przeuczenia modelu (overfitting), natomiast zbyt duża liczba sąsiadów powoduje obciążenie wyników.
4 Modelowanie klasyfikacyjne
4.1 Podział na zbiór treningowy i testowy
Proces modelowania rozpoczynamy od podziału danych na zbiór uczący i testowy. Zbiór treningowy stanowi 70% całkowitej liczby obserwacji, a testowy 30%.
Analiza rozkładu zmiennej docelowej (income
) ujawniła znaczną nierównowagę klas — około 75,9% obserwacji to dochód nie większy niż $50K (<=50
), a 24,1% to dochody przekraczające $50K (>50
). Taka dysproporcja może prowadzić do uprzywilejowania klasy dominującej i zaniżenia skuteczności klasyfikacji dla mniejszościowej klasy.
Aby zniwelować ten problem, zastosujemy technikę oversamplingu na zbiorze treningowym. Polega ona na sztucznym zwiększeniu liczby przykładów należących do klasy mniejszościowej poprzez ich losowe powielanie. Dzięki temu model uczy się na bardziej zbalansowanym zbiorze.
Dodatkowo, aby uzyskać rzetelną ocenę wydajności modelu, zastosowano walidację krzyżową (cross-validation). Dane uczące podzielone zostały na 5 części (tzw. foldy) przy użyciu 5-krotnej walidacji krzyżowej ze stratyfikacją względem zmiennej income
, co zapewnia zachowanie proporcji klas w każdym foldzie.
<- dt %>%
dt select(-fnlwgt) #nie opisuje właściwości osoby, tylko wynik ważenia próby - dlatego do modelowania usuwamy
set.seed(2025)
<- initial_split(dt, strata = income, prop = 0.7)
split <- training(split)
train <- testing(split) test
# 5-krotna walidacja krzyżowa
set.seed(2025)
<- vfold_cv(data = train, strata = income, v = 5)
cv_folds
# metryki
<- metric_set(yardstick::accuracy,
metryki ::f_meas,
yardstick::sensitivity,
yardstick::specificity,
yardstick::roc_auc,
yardstick::kap)
yardstick
## recipes
# Dla Regresji Logistycznej, Naive Bayes,XGboost, KNN
<- recipe(income ~ ., data = train) %>%
rec_dummy #step_best_normalize(capital.gain, capital.loss) %>%
#step_normalize(all_numeric_predictors(), -capital.gain, -capital.loss) %>%
step_dummy(all_nominal_predictors()) %>%
step_upsample(income, seed = 1234)
# Dla Decision tree, Random Forest
<- recipe(income ~ ., data = train) %>%
rec step_upsample(income, seed = 1234)
4.2 Definiujemy modele, które chcemy przetestować:
# 1. Logistic Regression (Regresja logistyczna)
<- logistic_reg(mode = "classification") %>%
lr set_engine("glm")
# 2. Decision Tree (Drzewo decyzyjne)
<- decision_tree(mode = "classification") %>%
dt set_engine("rpart")
# 3. Random Forest (Las losowy)
<- rand_forest(mode = "classification") %>%
rf set_engine("ranger", importance = "impurity")
# 4. Naive Bayes
<- naive_Bayes(mode = "classification") %>%
nb set_engine("klaR")
# 5. XGBoost
<- boost_tree(mode = "classification") %>%
xgb set_engine("xgboost")
# 6. K-Nearest Neighbors (KNN)
<- nearest_neighbor(mode = "classification") %>%
knn set_engine("kknn")
4.3 Tworzymy zestawy przepływów pracy:
<- workflow_set(
base preproc = list(rec = rec),
models = list(dt = dt, rf = rf))
<- workflow_set(
dummy preproc = list(rec_d = rec_dummy),
models = list(lr = lr,
nb = nb,
xgb = xgb,
knn = knn))
<- bind_rows(base, dummy)
models #models
4.4 Trenowanie modeli
#cl <- makePSOCKcluster(4)
#registerDoParallel(cl) #Paralelizacja - przyspiesza
#set.seed(44)
#results <- models %>%
# workflow_map(
# "fit_resamples",
# verbose = TRUE,
# resamples = cv_folds,
# metrics = metryki,
# control = control_resamples(save_pred = TRUE,
# save_workflow = TRUE)
# )
#save(results, file = "training_results_3base.rda")
load("training_results_3base.rda")
#stopCluster(cl)
4.5 Porównanie modeli
Porównamy modele ze względu na 6 metryk:
Accuracy - odsetek poprawnie zaklasyfikowanych obserwacji
F1-score - \(\frac{2PPV*TPR}{PPV+TPR}\)
Sensitivity (recall) - stosunek true positive do wszystkich przypadków positive
Specificity - stosunem pozycji TN (true negative) do wszytskich obserwacji negative
AUC - pole pod krzywą ROC
Kappa - podobna co accuracy, przydatna, gdy jedna lub więcej klas dominuje. \(\kappa = \frac{p_0-p_e}{1-p_e}\)
Na podstawie powyższych wyników, do dalszej analizy wybrałam algorytmy Random Forest (rec_rf) oraz XGBoost (rec_d_xgb), które wykazały się najlepszymi wynikami.
4.6 Tuning modeli
<- rand_forest(mtry = tune(), min_n = tune(), trees = 500) %>%
rf_spec set_engine("ranger") %>%
set_mode("classification")
#mtry - Liczba predyktorów losowo wybieranych przy każdym podziale drzewa.
#min_n - Minimalna liczba obserwacji w liściu drzewa.
#trees - Liczba drzew w lesie.
<- boost_tree(tree_depth = tune(), #Maksymalna głębokość drzewa
xgb_spec learn_rate = tune(), #Szybkość uczenia
loss_reduction = tune(), #Minimalna poprawa wymagana do dalszego podziału
min_n = tune(), #Minimalna liczba obserwacji w liściu
sample_size = tune(), #Proporcja obserwacji używana do trenowania każdego drzewa
trees = tune()) %>% #Liczba drzew
set_engine("xgboost") %>%
set_mode("classification")
4.6.1 Tuning Random Forest
Tworzymy zestaw parametrów do strojenia modelu lasu losowego (rf_spec
) — ustalamy zakresy dla dwóch ważnych hiperparametrów:
mtry
(liczba zmiennych losowo wybieranych przy podziale drzewa),min_n
(minimalna liczba obserwacji w węźle, poniżej której węzeł nie jest dzielony).
Kolejnym krokiem jest zdefiniowanie workflow, czyli połączenia modelu lasu losowego z przetwarzaniem danych (receptą rec
), co umożliwia łatwą integrację procesu trenowania.
Zakomentowany fragment kodu pokazuje, jak można przeprowadzić strojenie hiperparametrów za pomocą walidacji krzyżowej (tune_grid
), wykorzystując siatkę parametrów oraz metryki oceny modelu.
Na końcu wyniki tuningu zapisujemy do pliku RDS, aby móc je później wykorzystać bez konieczności ponownego trenowania.
<- extract_parameter_set_dials(rf_spec)
rf_param
<-
rf_param %>%
rf_spec extract_parameter_set_dials() %>%
update(
mtry = mtry(c(1, 10)),
min_n = min_n(c(2,60)))
%>% extract_parameter_dials("mtry")
rf_param %>% extract_parameter_dials("min_n")
rf_param
<- workflow() %>%
rf_wflow add_model(rf_spec) %>%
add_recipe(rec)
#set.seed(2025)
#rf_reg_tune <- rf_wflow %>%
# tune_grid(
# cv_folds,
# grid = rf_param %>% grid_regular(levels = 5),
# metrics = metryki
# )
#saveRDS(rf_reg_tune, "rf_reg_tune.rds")
4.6.1.1 Finalizowanie modelu
Na podstawie wyników strojenia wybieramy najlepsze parametry modelu lasu losowego, korzystając z metryki ROC AUC.
select_best(rf_reg_tune, metric = "roc_auc")
# A tibble: 1 × 3
mtry min_n .config
<int> <int> <chr>
1 4 60 Preprocessor1_Model22
Następnie tworzymy nowy workflow, który finalizuje workflow.
<- tibble(
rf_best_param mtry = 4,
min_n = 60)
<- rf_wflow %>%
final_rf_wflow finalize_workflow(rf_best_param)
final_rf_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
1 Recipe Step
• step_upsample()
── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (classification)
Main Arguments:
mtry = 4
trees = 500
min_n = 60
Computational engine: ranger
Dopasowujemy ostateczny model do całego zbioru treningowego, przygotowując go do predykcji na zbiorze testowym.
<-
final_rf_fit %>%
final_rf_wflow fit(train)
4.6.1.2 Predykcja na zbiorze testowym
Metryki modelu Random Forest | |
---|---|
Miara | Wartość |
Dokładność (Accuracy) | 0.837 |
F1-score | 0.887 |
Czułość (Sensitivity) | 0.842 |
Specyficzność (Specificity) | 0.822 |
ROC AUC | 0.919 |
Kappa | 0.599 |
Model Random Forest osiągnął dokładność na poziomie 83,3%, co oznacza, że poprawnie klasyfikuje większość obserwacji (jendak jest to zbiór niezbalansowany). F1-score (0,884) i czułość (0.841) wskazują, że model dość dobrze radzi sobie z wykrywaniem pozytywnych przypadków, a specyficzność (0.806) pokazuje skuteczność w identyfikacji negatywnych. Wartość ROC AUC (0,911) potwierdza dobrą zdolność modelu do rozróżniania klas. Kappa na poziomie 0,585 sugeruje umiarkowaną zgodność ponad losowe przypisanie klas.
4.6.1.3 Confusion Matrix
Model poprawnie sklasyfikował 6232 przypadki jako zarabiające ≤50K oraz 1895 jako >50K. Jednak 457 osób zarabiających ponad 50K zostało błędnie zakwalifikowanych do grupy ≤50K, a 1178 osób zarabiających ≤50K trafiło do grupy >50K, co wskazuje na pewne błędy klasyfikacji.
4.6.2 Tuning XGBoost
Definiujemy zakresy hiperparametrów XGBoost i tworzymy workflow z modelem oraz przepisem na dane.
Następnie generujemy siatkę parametrów do strojenia metodą Latin Hypercube (hipersześciany łacińskie).
Wyniki strojenia ponownie zapisujemy do pliku, aby nie trzeba było powtarzać tego czasochłonnego procesu jeszcze raz.
<- extract_parameter_set_dials(xgb_spec)
xgb_param
<-
xgb_param %>%
xgb_spec extract_parameter_set_dials() %>%
update(
tree_depth = tree_depth(c(3L, 10L)),
learn_rate = learn_rate(range = c(-2, -0.5)),
loss_reduction = loss_reduction(range = c(-5, 1)),
min_n = min_n(c(5L, 60L)),
sample_size = sample_prop(c(0.5, 1)),
trees = trees(c(100L, 800L))
)%>% extract_parameter_dials("tree_depth")
xgb_param %>% extract_parameter_dials("learn_rate")
xgb_param %>% extract_parameter_dials("loss_reduction")
xgb_param %>% extract_parameter_dials("min_n")
xgb_param %>% extract_parameter_dials("sample_size")
xgb_param %>% extract_parameter_dials("trees")
xgb_param
#grid_regular(xgb_param, levels = 2)
<- workflow() %>%
xgb_wflow add_model(xgb_spec) %>%
add_recipe(rec_dummy)
<- grid_latin_hypercube(xgb_param, size = 10) #grid_max_entropy(xgb_param, size = 20)
xgb_grid
#cl <- makePSOCKcluster(4)
#registerDoParallel(cl)
#set.seed(2025)
#xgb_reg_tune <- xgb_wflow %>%
# tune_grid(
# cv_folds,
# grid = xgb_grid,
# metrics = metryki
# )
#saveRDS(xgb_reg_tune, "xgb_reg_tune.rds")
#stopCluster(cl)
4.6.2.1 Finalizowanie modelu
Na podstawie wyników strojenia wybieramy najlepsze parametry modelu.
select_best(xgb_reg_tune, metric = "roc_auc")
# A tibble: 1 × 7
trees min_n tree_depth learn_rate loss_reduction sample_size .config
<int> <int> <int> <dbl> <dbl> <dbl> <chr>
1 631 16 6 0.111 8.53 0.921 Preprocessor1_Mo…
Tworzymy nowy workflow finalizujący.
<- select_best(xgb_reg_tune, metric = "roc_auc")
best_params
<- xgb_wflow %>%
final_xgb_wflow finalize_workflow(best_params)
Dopasowujemy ostateczny model do całego zbioru treningowego.
<-
final_xgb_fit %>%
final_xgb_wflow fit(train)
4.6.2.2 Predykcja na zbiorze testowym
Metryki modelu XGBoost | |
---|---|
Miara | Wartość |
Dokładność (Accuracy) | 0.835 |
F1-score | 0.884 |
Czułość (Sensitivity) | 0.824 |
Specyficzność (Specificity) | 0.871 |
ROC AUC | 0.931 |
Kappa | 0.607 |
Model XGBoost osiągnął dokładność na poziomie 83%. F1-score (0.88) świadczy o dość dobrej równowadze między precyzją a czułością modelu. Czułość na poziomie 0.819 pokazuje, że model dobrze wykrywa przypadki pozytywne, natomiast specyficzność 0.867 oznacza skuteczność rozpoznawania przypadków negatywnych. ROC AUC (0.925) sugeruje, że model jest efektywny w rozróżnianiu klas.
4.6.2.3 Confusion matrix
Macierz pomyłek pokazuje, że model poprawnie sklasyfikował większość obserwacji, zwłaszcza tych zarabiających ponad 50K (2040 poprawnych przewidywań). Jednak popełnił też błędy, myląc 1344 osoby zarabiające <=50K jako >50K.
5 Podsumowanie
Przeprowadzona została analiza zbioru mająca na celu przewidzenie, czy dana osoba zarabia powyżej czy poniżej 50 tysięcy dolarów rocznie.
Z uwagi na niezbalansowany zbiór danych (więcej obserwacji klasy <=50K), zastosowano oversampling, aby poprawić jakość predykcji i zrównoważyć klasy.
Zbudowane zostały, a następnie porównane modele klasyfikacyjne: regresja logistyczna, model drzewa decyzyjnego, las losowy, naiwny klasyfikator bayesowski, XGBoost oraz algorytm k-najbliższych sąsiadów. Spośród nich najwyższe wyniki osiągnęły modele Random Forest i XGBoost, które zostały następnie stuningowane przy użyciu walidacji krzyżowej i przeszukiwania siatki parametrów.
Ostatecznie, oba modele osiągnęły dość dobre wyniki. Jednak mimo to nie idealnie radzą sobie z klasyfikacją i mają tendencję do błędnej klasyfikacji osób zarabiających <=50 jako osób które zarazabiają > 50K.
6 Bibliografia
https://pl.wikipedia.org/wiki/Wikipedia:Strona_g%C5%82%C3%B3wna
Wiedza teoretyczna i praktyczna zdobyta podczas wykładów i laboratoriów z przedmiotu Metody walidacji modeli statystycznych.