Bank Churn Binary Classification

Author

Olamide Adu

Introduction

This is a beginner-friendly note on predicting bank churn. Bank churning refers to repeatedly opening and closing bank accounts, typically to take advantage of financial incentives such as sign-up bonuses and promotional offers, or for any other motive. In the dataset used here, churn is recorded per customer in the Exited column, which marks whether the customer left the bank.

Aim of Project

The aim of this project is to predict the bank customers that will churn. Three models will be trained to predict customer churning. The models are:

  • logistic regression,

  • decision trees, and

  • random forest.

Load Libraries

library(tidyverse)
library(ggthemes) 
library(tidymodels)
library(janitor)
library(GGally)
library(patchwork)
library(wordcloud2)
library(webshot2)
library(gt)
library(scales)
library(ggridges)

Data Definition

The bank customer churn dataset is a commonly used dataset for predicting customer churn in the banking industry. It contains information on bank customers who either left the bank or continue to be a customer. The dataset includes the following attributes:

  1. Customer ID: A unique identifier for each customer

  2. Surname: The customer’s surname or last name

  3. Credit Score: A numerical value representing the customer’s credit score

  4. Geography: The country where the customer resides (France, Spain or Germany)

  5. Gender: The customer’s gender (Male or Female)

  6. Age: The customer’s age.

  7. Tenure: The number of years the customer has been with the bank

  8. Balance: The customer’s account balance

  9. NumOfProducts: The number of bank products the customer uses (e.g., savings account, credit card)

  10. HasCrCard: Whether the customer has a credit card (1 = yes, 0 = no)

  11. IsActiveMember: Whether the customer is an active member (1 = yes, 0 = no)

  12. EstimatedSalary: The estimated salary of the customer

  13. Exited: Whether the customer has churned (1 = yes, 0 = no).

Import the data

train_data <- read_csv("data/train.csv") |> 
  clean_names() |> 
  mutate(exited = factor(exited))

dim(train_data)
[1] 165034     14

train_data has 165,034 rows and 14 columns.

Exploratory Data Analysis

Target Variable

train_data |> 
  mutate(exited = factor(exited, levels = c(0, 1), labels = c("No", "Yes"))) |> 
  ggplot(aes(exited)) +
  geom_bar(fill = c("brown", "orange")) +
  theme_pander() +
  labs(
    x = "Churned",
    y = "Count",
    title = "Retained accounts outnumber churned accounts"
  ) +
  scale_y_continuous(labels = label_comma())

train_data |> 
  mutate(exited = factor(exited, levels = c(0, 1), labels = c("No", "Yes"))) |> 
  ggplot(aes(exited, fill = geography)) +
  geom_bar(position = "dodge") +
  theme_pander() +
  labs(
    x = "Churned",
    y = "Count",
    fill = "Geography",
    title = "France banks have more members than other banks",
    subtitle = "Germany would have higher churn-rate than other countries"
  )

(a) churned and retained members

(b) churned and retained members according to countries

Figure 1: Frequency of churned and retained account

Figure 1 (a) shows that there are more retained accounts than churned accounts. Figure 1 (b) shows that France has the most customers overall. Germany's count of churned accounts is close to France's even though Germany has fewer customers, which points to a higher churn rate in Germany.
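Counts alone do not give a churn rate, so the comparison between countries can be checked numerically. The sketch below uses a small invented data frame (toy and its values are hypothetical, not taken from train_data) to show one way of computing the churn rate per country with dplyr.

```r
library(dplyr)

# Hypothetical toy data: the values are invented for illustration only
toy <- tibble(
  geography = rep(c("France", "Germany", "Spain"), times = c(6, 4, 4)),
  exited    = c(1, 0, 0, 0, 0, 0,   1, 1, 0, 0,   1, 0, 0, 0)
)

# Customers, churned customers and churn rate per country
churn_by_country <- toy |>
  group_by(geography) |>
  summarise(
    customers  = n(),
    churned    = sum(exited),
    churn_rate = mean(exited)
  ) |>
  arrange(desc(churn_rate))

churn_by_country
```

With train_data, where exited is stored as a factor, the rate would be computed as mean(exited == 1) instead of mean(exited).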

train_data |> 
  mutate(
    exited = factor(exited, levels = c(0, 1), labels = c("No", "Yes")),
    has_cr_card = factor(has_cr_card, levels = c(0, 1), labels = c("No", "Yes"))
    ) |> 
  ggplot(aes(exited, after_stat(prop), group = has_cr_card, fill = has_cr_card)) +
  geom_bar(position = "dodge") +
  scale_fill_tableau() +
  labs(
    x = "Churned", y = "Proportion", 
    fill = "Has credit card",
    title = "Owning a credit card doesn't influence churning"
  ) +
  scale_y_continuous(labels = label_percent()) +
  theme_clean()
  
train_data |> 
  mutate(
    exited = factor(exited, levels = c(0, 1), labels = c("No", "Yes")),
    is_active_member = factor(is_active_member, levels = c(0, 1), labels = c("No", "Yes"))
    ) |>
  ggplot(
    aes(exited, after_stat(prop), 
        group = is_active_member, 
        fill = is_active_member)
  ) +
  geom_bar(position = "dodge") +
  scale_fill_tableau() +
  labs(
    x = "Churned", y = "Proportion", 
    fill = "Active member",
    title = "Inactive members churn more than active members"
  ) +
  scale_y_continuous(labels = label_percent()) +
  theme_clean()

Bar plots of churned customers. (a): proportion of churned customers by credit card ownership; (b): proportion of churned customers by whether they actively use the bank and its services.

(a) credit card

(b) active member

Figure 2: Proportion of Churned Customer

Figure 2 (a) shows that about 20% of customers churn whether or not they have a credit card, and Figure 2 (b) shows that inactive members churn more than active members.

train_data |> 
  ggplot(aes(estimated_salary, credit_score, color = factor(exited))) +
  geom_point() + 
  scale_color_tableau(labels = c("No", "Yes")) +
  facet_wrap(~geography) +
  labs(
    x = "Estimated Salary",
    y = "Credit Score",
    color = "Churned",
    title = "Churning is not affected by credit score and estimated salary",
    subtitle = "Germany has high churn compared to France and Spain"
  ) +
  scale_x_continuous(labels = label_currency(prefix = "€")) +
  theme_pander() +
  theme(axis.text.x = element_text(angle = 45))

Figure 3: Scatter plot of estimated salary and credit score of members across the three countries.

Figure 3 shows that German members churn across the full range of estimated salary and credit score.

Feature Variable

Categorical Variables

customer_name <- train_data |> 
  filter(exited == 1) |> 
  group_by(surname) |> 
  count() |> 
  ungroup() |> 
  rename(freq = n) |> 
  arrange(desc(freq))

wordcloud2(customer_name, size = 1.6, color = "random-light", backgroundColor = "black")

Figure 4: Word cloud of customer surnames sized by the number of churned accounts

Figure 4 shows the surnames attached to churned accounts. The surname Hsia is associated with the most churned accounts.

train_data |> 
  ggplot(aes(gender, fill = gender)) +
  geom_bar(show.legend = FALSE) +
  labs(x = "Gender", y = "Count", title = "There are more male members than female members") +
  scale_fill_calc() +
  theme_pander()

train_data |> 
  ggplot(aes(gender, fill = geography)) +
  geom_bar(position = "dodge") +
  labs(
    x = "Gender",
    y = "Count",
    fill = "Geography",
    title = "Gender distribution across countries"
  ) +
  scale_fill_calc() +
  theme_pander()

(a) General member population according to sex

(b) Member population according to country

Figure 5: Distribution of gender shows that there are more male members than female members

Figure 5 shows that there are more male customers than female customers.

train_data |> 
  mutate(is_active_member = factor(
    is_active_member,
    levels = c(0, 1),
    labels = c("No", "Yes"))
  ) |> 
  ggplot(aes(is_active_member, fill = is_active_member)) +
  geom_bar(show.legend = F) +
  labs(
    x = "Active member",
    y = "Count",
    title = "No difference in the count of active and inactive members"
  ) +
  scale_fill_tableau() +
  theme_pander()

train_data |> 
  mutate(is_active_member = factor(
    is_active_member,
    levels = c(0, 1),
    labels = c("No", "Yes"))
  ) |> 
  ggplot(aes(is_active_member, fill = gender)) +
  geom_bar(position = "dodge") +
  labs(
    title = "Active members according to sex",
    fill = "Gender",
    x = "Active member", y = "count"
  ) +
  scale_fill_colorblind() +
  theme_pander()

(a) Members status if active or not

(b) Members status according to their gender

Figure 6: Members activity status

Figure 6 shows no difference between the number of active and inactive members, even when they are differentiated by gender (Figure 6 (b)).

train_data |> 
  mutate(has_cr_card = factor(has_cr_card, levels = c(0, 1), labels = c("No", "Yes"))) |> 
  ggplot(aes(has_cr_card, fill = has_cr_card)) +
  geom_bar(show.legend = F) +
  labs(x = "Has Credit Card", y = "Count") +
  scale_fill_tableau() +
  theme_pander()

Figure 7: Members with Credit Cards

Numeric Variables

age_distribution <- train_data |> 
  ggplot(aes(age)) +
  geom_histogram(binwidth = 5, fill = "tomato3") +
  geom_density(stat = "bin", binwidth = 5, col = "black") +
  expand_limits(x = c(0, 100)) +
  labs(
    x = "Age",
    title = "General age distribution",
    y = "Count"
  ) +
  theme_pander() +
  theme(plot.title = element_text(size = 12),
        plot.subtitle = element_text(size = 8))

age_distribution_sex <- train_data |> 
  ggplot(aes(age, col = gender)) +
  geom_freqpoly(binwidth = 15) +
  labs(
    x = "Age",
    y = "",
    col = "Gender",
    title = "Members aged 30 to 40 are the largest group for both sexes",
    subtitle = "Males are more than Females across all ages"
  ) +
  scale_color_calc() +
  theme_pander() +
  theme(plot.title = element_text(size = 12),
        plot.subtitle = element_text(size = 8))

age_distribution + age_distribution_sex

(a) General age distribution

Figure 8: Age distribution of customers

The highest number of account owners are between ages 30 and 40, as shown in Figure 8 (left panel), with male members outnumbering female members across all ages (right panel).

train_data |> 
  ggplot(aes(credit_score)) +
  geom_histogram(binwidth = 50, alpha = 0.7, fill = "tomato3") +
  geom_density(stat = "bin", col = "black", binwidth = 50) +
  labs(
    x = "Credit Score",
    y = "Count", 
    title = "General credit score distribution"
  ) +
  theme_pander()

train_data |> 
  ggplot(aes(credit_score, geography, fill = geography)) +
  geom_density_ridges(alpha = 0.5, show.legend = F) +
  scale_fill_calc() +
  labs(
    x = "Credit Score",
    y = "Geography",
    title = "Regional credit score distribution",
    subtitle = "No difference in credit score distribution across regions"
  ) +
  theme_pander()

(a) Credit score distribution

(b) Credit score according to geography

Figure 9: Credit score across regions

Figure 9 (a) shows the general credit score distribution, and Figure 9 (b) shows no differences in credit score across regions.

train_data |> 
  mutate(tenure = factor(tenure)) |> 
  ggplot(aes(fct_infreq(tenure), fill = tenure)) +
  geom_bar(show.legend = F) +
  labs(x = "Tenure", y = "count", title = "Count of members according to tenure") +
  scale_fill_calc() +
  theme_pander()

Figure 10: Members tenure

Figure 10 shows a good spread of old and new members.

train_data |> 
  ggplot(aes(estimated_salary)) +
  geom_histogram(binwidth = 10000, fill = "tomato3", alpha = 0.5) +
  geom_density(stat = "bin", binwidth = 10000, col = "black") +
  labs(
    x = "Estimated salary",
    y = " Count",
    title = "Multimodal estimated salary distribution"
  ) +
  theme_pander()

train_data |> 
  ggplot(aes(estimated_salary, gender, fill = gender)) +
  geom_density_ridges(alpha = 0.4) +
  labs(
    x = "Estimated salary",
    y = "Density",
    fill = "Gender",
    title = "Estimated salary distribution of the genders"
  ) +
  scale_x_continuous(labels = label_currency(prefix = "€")) +
  theme_pander() +
  theme(legend.position = "none")
Picking joint bandwidth of 4720

train_data |> 
  ggplot(aes(estimated_salary, geography, fill = geography)) +
  geom_density_ridges(alpha = 0.4) +
  labs(
    x = "Estimated salary",
    y = "Density",
    fill = "Geography",
    title = "Estimated salary distribution according to regions"
  ) +
  scale_x_continuous(labels = label_currency(prefix = "€")) +
  theme_pander() +
  theme(legend.position = "none")
Picking joint bandwidth of 5250

(a) General estimated salary distribution

(b) Estimated salary distribution according to gender

(c) Estimated salary distribution according to region

Figure 11: Estimated salary distribution

Figure 11 shows a multimodal distribution of estimated salary.

Correlation Plot

ggcorr(train_data)

Correlation plot

Model Workflow

Three models were used for predicting members that churned. They are:

  • logistic model

  • decision trees

  • random forest

The only metric used to compare the models will be roc_auc.

Data Pre-processing for Modelling

train_data_preproc <- train_data |> 
  select(!c(id, customer_id, surname)) |> 
  mutate(across(where(is.character), factor))

head(train_data_preproc)
# A tibble: 6 × 11
  credit_score geography gender   age tenure balance num_of_products has_cr_card
         <dbl> <fct>     <fct>  <dbl>  <dbl>   <dbl>           <dbl>       <dbl>
1          668 France    Male      33      3      0                2           1
2          627 France    Male      33      1      0                2           1
3          678 France    Male      40     10      0                2           1
4          581 France    Male      34      2 148883.               1           1
5          716 Spain     Male      33      5      0                2           1
6          588 Germany   Male      36      4 131779.               1           1
# ℹ 3 more variables: is_active_member <dbl>, estimated_salary <dbl>,
#   exited <fct>

Data Splitting

The data is split into two parts, a training set (80%) and a validation set (20%), stratified on exited. The separate test.csv file, imported later, serves as the testing set.

set.seed(234)

train_data_split <- initial_split(train_data_preproc, prop = 0.8, strata = exited)

training_data <- training(train_data_split)
validation_data <- testing(train_data_split)
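To see what the strata argument does, the sketch below repeats the same kind of split on a small invented data frame (toy is hypothetical, with 20% positives by construction) and compares the share of churned rows in each part.

```r
library(rsample)

set.seed(234)

# Hypothetical toy data: 800 retained rows and 200 churned rows
toy <- data.frame(
  x      = rnorm(1000),
  exited = factor(rep(c(0, 1), times = c(800, 200)))
)

toy_split <- initial_split(toy, prop = 0.8, strata = exited)
toy_train <- training(toy_split)
toy_valid <- testing(toy_split)

# Share of churned rows in each part; stratification keeps them close
train_share <- mean(toy_train$exited == 1)
valid_share <- mean(toy_valid$exited == 1)

c(train = train_share, validation = valid_share)
```

The same two lines applied to training_data and validation_data would confirm that the churn share is preserved in the real split.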

Model Development

The engines for the pre-selected algorithms are set below.

log_reg_model  <- 
  logistic_reg() |> 
  set_engine("glm")

decision_tree_model <-
  decision_tree(
    tree_depth = tune(),
    cost_complexity = tune()
  ) |> 
  set_engine("rpart") |> 
  set_mode("classification")

rf_model <- 
  rand_forest(
    trees = 1000,
    mtry = tune(),
    min_n = tune()
  ) |> 
  set_engine("ranger") |> 
  set_mode("classification")

Feature Engineering

Some feature engineering will be carried out for the different models, based on their specifications

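log_reg_rec <-
  recipe(
    exited ~ ., 
    data = training_data
  )  |> 
  step_dummy(geography) |>
  # without a selector the filter has no columns to check, so name the predictors
  step_zv(all_predictors())

decision_tree_rec <- 
  recipe(exited ~ .,
         data = training_data) |> 
  step_nzv(all_predictors())

rf_rec <-
  recipe(
    exited ~ .,
    data = training_data
  ) |> 
  step_zv(all_predictors())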

Recipe object and model specification will be added to each model workflow

log_reg_wf <-
  workflow() |> 
  add_model(log_reg_model) |> 
  add_recipe(log_reg_rec)
  

rand_forest_wf <-
  workflow() |> 
  add_recipe(rf_rec) |> 
  add_model(rf_model)

decision_tree_wf <-
  workflow() |> 
  add_recipe(decision_tree_rec) |> 
  add_model(decision_tree_model)

Logistic Regression

log_reg_fit <- 
  log_reg_wf |>  
  fit(training_data)

Next, we test the fit on the validation set.

log_reg_validation_pred <- predict(log_reg_fit, validation_data, type = "class") |> 
  bind_cols(validation_data) |> 
  relocate(.pred_class, .after = exited)
conf_mat(log_reg_validation_pred, truth = exited, estimate = .pred_class)
          Truth
Prediction     0     1
         0 24871  4335
         1  1152  2650

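Treating churn (1) as the event, there are more false negatives (4,335 churners predicted as retained) than false positives (1,152 retained customers predicted as churned). The model predicts the non-event better than the event, yet the overall accuracy is still 83%.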
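The figures in the confusion matrix can be turned into rates by hand. The sketch below copies the four counts from the output above into a base R matrix (the object names are illustrative) and derives accuracy, sensitivity and specificity with churn (1) taken as the event.

```r
# Confusion matrix copied from the output above
# (rows = prediction, columns = truth; matrix() fills column by column)
cm <- matrix(
  c(24871, 1152, 4335, 2650),
  nrow = 2,
  dimnames = list(prediction = c("0", "1"), truth = c("0", "1"))
)

# Share of all validation rows that were classified correctly
accuracy <- sum(diag(cm)) / sum(cm)

# Share of actual churners (truth = 1) that the model caught
sensitivity <- cm["1", "1"] / sum(cm[, "1"])

# Share of actual retained customers (truth = 0) predicted as retained
specificity <- cm["0", "0"] / sum(cm[, "0"])

round(c(accuracy = accuracy, sensitivity = sensitivity, specificity = specificity), 3)
```

Accuracy comes out at about 0.83, while sensitivity is only about 0.38 against a specificity of about 0.96, which is the imbalance described in the text.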

accuracy(log_reg_validation_pred, truth = exited, estimate = .pred_class)

# A tibble: 1 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.834

log_reg_validation_pred |> 
  mutate(.pred_class = as.numeric(.pred_class)) |> 
  roc_curve(exited, .pred_class) |> 
  autoplot() +
  labs(
    title = str_wrap(
      "Logistic regression shows low predictive capacity for churned class",
      width  = 70)
  ) + 
  theme_pander()

Figure 12: Area under the curve of Logistic regression

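Figure 12 shows a curve that falls below the diagonal, which would suggest a predictive capability for churning below 50%, even though the accuracy is high (0.834). This curve was built from hard class labels converted to numbers and not from predicted probabilities, and yardstick treats the first factor level (0, retained) as the event by default, so the below-diagonal shape most likely reflects how the curve was constructed and not a model that is worse than chance. The high accuracy mostly reflects the ability to predict retention, which is not our interest. We will check the other models to see which gives our desired result.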
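An ROC curve is normally computed from predicted probabilities, with the event level stated explicitly. The sketch below is a minimal illustration on simulated data (x, exited and toy are hypothetical and unrelated to the bank data): it fits a plain glm and computes the AUC with yardstick, once with the second level (1) as the event and once with the default first level.

```r
library(yardstick)

set.seed(1)
n <- 500
x <- rnorm(n)

# Hypothetical toy outcome: 1 = churned, 0 = retained
exited <- factor(rbinom(n, 1, plogis(-1 + 1.5 * x)), levels = c(0, 1))

# Plain logistic regression; keep the predicted probability of class 1
fit <- glm(exited ~ x, family = binomial)
toy <- data.frame(exited = exited, .pred_1 = predict(fit, type = "response"))

# The probability column belongs to the second factor level, so say so
auc_churn <- roc_auc(toy, truth = exited, .pred_1, event_level = "second")

# The same scores read against the default first level give the mirrored AUC
auc_mirrored <- roc_auc(toy, truth = exited, .pred_1, event_level = "first")

c(churn = auc_churn$.estimate, mirrored = auc_mirrored$.estimate)
```

For the fitted workflow, the probabilities would come from predict(log_reg_fit, validation_data, type = "prob"), which returns .pred_0 and .pred_1 columns.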

training_data_folds <- vfold_cv(training_data, v = 5, strata = exited)

Decision Trees

For the decision tree, we will use a grid that is not too complex, because of limited computational resources. After that we will tune our model, collect the metrics and evaluate performance. Parallel computing will be employed to reduce training time.

Fit Regular Grid

set.seed(345)

tree_grid <- 
  grid_regular(
    cost_complexity(),
    tree_depth(),
    levels = 5
  )

tree_grid |> 
  count(tree_depth)
# A tibble: 5 × 2
  tree_depth     n
       <int> <int>
1          1     5
2          4     5
3          8     5
4         11     5
5         15     5
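The grid values can be reproduced by hand to see where they come from. The sketch below is a base R approximation: it assumes cost complexity is spaced evenly on a log10 scale from 1e-10 to 0.1 and tree depth evenly from 1 to 15, which matches the values printed in this document, though whether dials computes them exactly this way is an assumption.

```r
# Five log-spaced cost-complexity values between 1e-10 and 1e-1
cost_complexity_values <- 10^seq(-10, -1, length.out = 5)

# Five evenly spaced depths between 1 and 15, truncated to integers
tree_depth_values <- as.integer(seq(1, 15, length.out = 5))

# Every combination of the two parameters
manual_grid <- expand.grid(
  cost_complexity = cost_complexity_values,
  tree_depth      = tree_depth_values
)

n_candidates <- nrow(manual_grid)   # candidate models in the grid
n_folds      <- 5                   # folds in training_data_folds
n_fits       <- n_candidates * n_folds

signif(cost_complexity_values, 3)
tree_depth_values
n_fits
```

Each candidate is fitted once per fold, so tuning this grid takes roughly n_fits tree fits, which is why the grid is kept small.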

Tune Grid

set.seed(234)
doParallel::registerDoParallel(cores = 5)
dt_tune <-
  tune_grid(
    decision_tree_wf,
    grid = tree_grid,
    resamples = training_data_folds
  )
dt_tune |> 
  collect_metrics() |> 
  head(n = 10)
# A tibble: 10 × 8
   cost_complexity tree_depth .metric  .estimator  mean     n    std_err .config
             <dbl>      <int> <chr>    <chr>      <dbl> <int>      <dbl> <chr>  
 1    0.0000000001          1 accuracy binary     0.788     5 0.00000597 Prepro…
 2    0.0000000001          1 roc_auc  binary     0.5       5 0          Prepro…
 3    0.0000000178          1 accuracy binary     0.788     5 0.00000597 Prepro…
 4    0.0000000178          1 roc_auc  binary     0.5       5 0          Prepro…
 5    0.00000316            1 accuracy binary     0.788     5 0.00000597 Prepro…
 6    0.00000316            1 roc_auc  binary     0.5       5 0          Prepro…
 7    0.000562              1 accuracy binary     0.788     5 0.00000597 Prepro…
 8    0.000562              1 roc_auc  binary     0.5       5 0          Prepro…
 9    0.1                   1 accuracy binary     0.788     5 0.00000597 Prepro…
10    0.1                   1 roc_auc  binary     0.5       5 0          Prepro…

Visualizing the result will be better than the normal print

dt_tune |> 
  collect_metrics() |> 
  filter(.metric == "roc_auc") |> 
  mutate(tree_depth = factor(tree_depth)) |> 
  ggplot(aes(cost_complexity, mean, color = tree_depth)) +
  geom_line(linewidth = 1, alpha = 0.4) +
  labs( 
    x = "Cost Complexity",
    y = "AUC",
    col = "Tree depth",
    title = "Model improves as tree depth increases",
    subtitle = str_wrap("Tree depth 15 performed well, but 11 is the best", width = 50)
  ) +
  geom_point(size = 2) +
  scale_x_log10(labels = label_number()) +
  scale_color_colorblind() +
  theme_clean() +
  expand_limits(y = c(0.5, 0.9)) +
  theme(legend.position = "bottom")

Figure 13: Decision tree AUC by cost complexity and tree depth

Our short tree with a depth of 1 performed poorly across all values of cost_complexity. The model improves as tree depth increases, but the best depth is 11 and not the maximum tree depth of 15 (Figure 13). The top 5 candidate models are shown below.

dt_tune |> 
  show_best(metric = "roc_auc")
# A tibble: 5 × 8
  cost_complexity tree_depth .metric .estimator  mean     n  std_err .config    
            <dbl>      <int> <chr>   <chr>      <dbl> <int>    <dbl> <chr>      
1    0.0000000001         11 roc_auc binary     0.877     5 0.000900 Preprocess…
2    0.0000000178         11 roc_auc binary     0.877     5 0.000900 Preprocess…
3    0.00000316           11 roc_auc binary     0.877     5 0.000900 Preprocess…
4    0.0000000001          8 roc_auc binary     0.868     5 0.00360  Preprocess…
5    0.0000000178          8 roc_auc binary     0.868     5 0.00360  Preprocess…

We will select the best of the hyperparameter values and use it to finalize our model

best_tree <- dt_tune |> 
  select_best(metric = "roc_auc")

best_tree
# A tibble: 1 × 3
  cost_complexity tree_depth .config              
            <dbl>      <int> <chr>                
1    0.0000000001         11 Preprocessor1_Model16

Before we proceed with finalizing, we will fit a random forest, compare its performance with the decision tree model, and then finalize whichever is best.

Random Forest

The right hyperparameter values to use when training a single model are unknown, but we can train a group of candidates and see which turns out best.

Data Sampling for Random Forest

Our training data is large (132,026 rows and 11 columns), so a sample of this data is drawn and used to tune the random forest model. Parallelization will also be used here to speed up processing.

set.seed(222)
# draw a stratified 10% sample of the training data
rf_sample_split <- initial_split(training_data, prop = 0.1, strata = exited)

rf_sample <- training(rf_sample_split)

# 10-fold cross-validation on the sample
rand_for_resamples <- vfold_cv(rf_sample, v = 10)
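A rough count of the work involved shows why the sample is kept small. The sketch below is plain arithmetic; the training size is derived from the row counts reported in this document, and the variable names are illustrative.

```r
n_training   <- 165034 - 33008   # all rows minus the validation rows
sample_prop  <- 0.1              # share kept for random forest tuning
n_folds      <- 10               # v in vfold_cv
n_candidates <- 20               # grid = 20 in the first tuning call
n_trees      <- 1000             # trees per forest

# Approximate rows in the sample and in each analysis set (9 of 10 folds)
n_sampled         <- round(n_training * sample_prop)
rows_per_analysis <- round(n_sampled * (n_folds - 1) / n_folds)

# Forests fitted during tuning, and trees grown in total
n_forest_fits <- n_folds * n_candidates
n_trees_grown <- n_forest_fits * n_trees

c(sampled = n_sampled, per_fit = rows_per_analysis,
  forests = n_forest_fits, trees = n_trees_grown)
```

Tuning on the full training set would multiply the rows seen by every one of those trees by about ten.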

Tuning Random Forest

set.seed(456)
doParallel::registerDoParallel(cores = 5)

rf_tune <-
  tune_grid(
    rand_forest_wf,
    resamples = rand_for_resamples,
    grid = 20
  )
i Creating pre-processing data to finalize unknown parameter: mtry

Let’s see the result

rf_tune |> 
  collect_metrics() |> 
  filter(.metric == "roc_auc") |> 
  select(mtry, min_n, mean) |> 
  pivot_longer(
    mtry:min_n,
    values_to = "value",
    names_to = "parameter"
  ) |> 
  ggplot(aes(value, mean, col = parameter)) +
  geom_point() +
  labs(
    x = "Value",
    y = "AUC",
    title = "Low mtry and high min_n are likely to make good model"
  ) +
  facet_wrap(~parameter, scales = "free") +
  scale_color_colorblind() +
  theme_clean()

Tune Random Forest with Regular Grid

From the plot above, low mtry (from 2 to 5) and high min_n (25 to 35) seem to give good models. We should set a range around these values using grid_regular and tune one more time.

rf_grid <- 
  grid_regular(
    min_n(range = c(25, 35)),
    mtry(range = c(2, 6)),
    levels = 5
  )

rf_grid
# A tibble: 25 × 2
   min_n  mtry
   <int> <int>
 1    25     2
 2    27     2
 3    30     2
 4    32     2
 5    35     2
 6    25     3
 7    27     3
 8    30     3
 9    32     3
10    35     3
# ℹ 15 more rows

Now we retune with the regular grid.

set.seed(347)

rand_forest_tune_reg <-
  tune_grid(
    rand_forest_wf,
    resamples = rand_for_resamples,
    grid = rf_grid
  )
rand_forest_tune_reg |> 
  collect_metrics() |> 
  filter(.metric == "roc_auc") |> 
  mutate(mtry = factor(mtry)) |> 
  ggplot(aes(min_n, mean, col = mtry)) +
  geom_point() +
  geom_line() +
  labs(
    y = "AUC", 
    title = "Model performance reduces as mtry increases",
    subtitle = "Model performance generally increases as min_n increases"
  ) +
  theme_clean()

The best combination of parameters is mtry = 2 and min_n = 30, but such a low mtry may lead to underfitting, as only two features are evaluated as candidates for each split.

rand_forest_tune_reg |> 
  show_best(metric = "roc_auc", n = 10)

# A tibble: 10 × 8
    mtry min_n .metric .estimator  mean     n std_err .config              
   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>                
 1     2    30 roc_auc binary     0.888    10 0.00383 Preprocessor1_Model03
 2     2    27 roc_auc binary     0.888    10 0.00391 Preprocessor1_Model02
 3     2    25 roc_auc binary     0.888    10 0.00391 Preprocessor1_Model01
 4     2    32 roc_auc binary     0.888    10 0.00390 Preprocessor1_Model04
 5     2    35 roc_auc binary     0.888    10 0.00387 Preprocessor1_Model05
 6     3    35 roc_auc binary     0.887    10 0.00397 Preprocessor1_Model10
 7     3    32 roc_auc binary     0.887    10 0.00396 Preprocessor1_Model09
 8     3    30 roc_auc binary     0.886    10 0.00405 Preprocessor1_Model08
 9     3    25 roc_auc binary     0.886    10 0.00402 Preprocessor1_Model06
10     3    27 roc_auc binary     0.886    10 0.00400 Preprocessor1_Model07

best_forest <- 
  rand_forest_tune_reg |> 
  select_best(metric = "roc_auc")
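The candidates in the table above are very close to each other. The sketch below copies the rounded mean and std_err of the top row and of the best mtry = 3 row into base R vectors (the object names are illustrative) and checks whether the gap between them is smaller than one standard error.

```r
# Values copied from the table above (printed to three decimals)
best      <- c(mean = 0.888, std_err = 0.00383)   # mtry = 2, min_n = 30
runner_up <- c(mean = 0.887, std_err = 0.00397)   # mtry = 3, min_n = 35

# Difference in mean AUC between the two candidates
gap <- best[["mean"]] - runner_up[["mean"]]

# Is the runner-up within one standard error of the best candidate?
within_one_se <- gap <= best[["std_err"]]

c(gap = gap, one_se = best[["std_err"]])
within_one_se
```

When the gap is this small relative to the standard error, the resampling results alone do not separate mtry = 2 from mtry = 3.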

Model Choice Decision

Judged by roc_auc, the best model seems to be the random forest (0.888 against 0.877 for the decision tree), but given the tendency of the random forest to underfit with such a low mtry, the decision tree will be selected.

final_tree <- finalize_model(
  decision_tree_model,
  best_tree
)

final_tree
Decision Tree Model Specification (classification)

Main Arguments:
  cost_complexity = 1e-10
  tree_depth = 11

Computational engine: rpart 

We can check the variable importance, from our model. We will first prep and juice our recipe for decision tree.

dec_prep <- prep(decision_tree_rec)

dec_tree <- juice(dec_prep)

Feature Importance

library(vip)

tree_fit <- final_tree |> 
  set_engine("rpart") |> 
  fit(exited ~ ., data = dec_tree)

vip(tree_fit, aes = list(fill = "mediumaquamarine")) +
  labs(
    x = "importance",
    y = "Features",
    title = "Feature importance"
  ) +
  theme_clean()

Figure 14: Variable importance

Age, num_of_products, balance, and is_active_member are the most important features in the model (Figure 14).
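Importance scores for an rpart tree are stored on the fitted object itself. The sketch below fits a tree on simulated data (toy, age and noise are hypothetical) where the outcome depends only on age, and reads the variable.importance element directly; as far as I know this is the quantity that model-based importance plots for rpart draw on, but treat that as an assumption.

```r
library(rpart)

set.seed(42)
n <- 400

# Hypothetical toy data: churn depends on age only, noise is irrelevant
toy <- data.frame(
  age   = runif(n, 18, 90),
  noise = rnorm(n)
)
toy$exited <- factor(ifelse(toy$age > 50, 1, 0))

# Classification tree on both predictors
toy_tree <- rpart(exited ~ age + noise, data = toy, method = "class")

# Named vector of importance scores, sorted from most to least important
toy_importance <- toy_tree$variable.importance
toy_importance
```

For the workflow in this document, the underlying rpart object could be reached with extract_fit_engine() on the fitted model before reading the same element.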

With the model specification finalized, we finalize the workflow and then make the last fit.

Finalizing Models

tree_final_wf <- workflow() |> 
  add_recipe(decision_tree_rec) |> 
  add_model(final_tree)

tree_final_fit <- 
  last_fit(
    tree_final_wf,
    train_data_split
  )

tree_final_fit |> 
  collect_metrics()
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.856 Preprocessor1_Model1
2 roc_auc  binary         0.882 Preprocessor1_Model1

The validation roc_auc (0.882) is close to the cross-validated estimate (0.877), which suggests we did not overfit during tuning.

tree_final_fit |> 
  collect_predictions() |> 
  mutate(prediction_status = case_when(
    exited == .pred_class ~ "Correct",
    .default = "Incorrect"
  )) |> 
  bind_cols(validation_data) |> 
  clean_names() |>
  filter(exited_19 == 1) |> 
  ggplot(aes(estimated_salary, credit_score, color = prediction_status)) +
  geom_point() + 
  scale_color_tableau() +
  facet_wrap(~geography) +
  labs(
    x = "Estimated Salary",
    y = "Credit Score",
    col = "Prediction",
    title = "Prediction of Exited members",
    subtitle = "The model has many incorrect predictions in France"
  ) +
  scale_x_continuous(labels = label_currency(prefix = "€")) +
  theme_pander() +
  theme(axis.text.x = element_text(angle = 320))

Prediction on New Data

First, we need to extract our model

decision_tree_final_model <- tree_final_fit |> 
  extract_fit_parsnip()

Next, we import the new data.

test_data <- read_csv("data/test.csv") |> 
  clean_names()

head(test_data)
# A tibble: 6 × 13
      id customer_id surname  credit_score geography gender   age tenure balance
   <dbl>       <dbl> <chr>           <dbl> <chr>     <chr>  <dbl>  <dbl>   <dbl>
1 165034    15773898 Lucchese          586 France    Female    23      2      0 
2 165035    15782418 Nott              683 France    Female    46      2      0 
3 165036    15807120 K?                656 France    Female    34      7      0 
4 165037    15808905 O'Donne…          681 France    Male      36      8      0 
5 165038    15607314 Higgins           752 Germany   Male      38     10 121264.
6 165039    15672704 Pearson           593 France    Female    22      9      0 
# ℹ 4 more variables: num_of_products <dbl>, has_cr_card <dbl>,
#   is_active_member <dbl>, estimated_salary <dbl>

compare_df_cols(train_data, test_data) |> 
  gt()

Table 1: Comparison of the variable types between the train and test data

column_name       train_data  test_data
age               numeric     numeric
balance           numeric     numeric
credit_score      numeric     numeric
customer_id       numeric     numeric
estimated_salary  numeric     numeric
exited            factor      NA
gender            character   character
geography         character   character
has_cr_card       numeric     numeric
id                numeric     numeric
is_active_member  numeric     numeric
num_of_products   numeric     numeric
surname           character   character
tenure            numeric     numeric

Table 1 shows that only the target variable is missing from the test data.
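The same check can be made programmatically before predicting. The sketch below uses the column names listed in Table 1 as plain character vectors (train_cols and test_cols are illustrative stand-ins for names(train_data) and names(test_data)) and lists the columns present in one set but not the other.

```r
# Column names as listed in Table 1
train_cols <- c(
  "id", "customer_id", "surname", "credit_score", "geography", "gender",
  "age", "tenure", "balance", "num_of_products", "has_cr_card",
  "is_active_member", "estimated_salary", "exited"
)

# The test data has the same columns except the target
test_cols <- setdiff(train_cols, "exited")

# Columns that appear in only one of the two data sets
missing_in_test  <- setdiff(train_cols, test_cols)
missing_in_train <- setdiff(test_cols, train_cols)

missing_in_test
missing_in_train
```

An empty missing_in_train means every predictor the model may need is available in the new data.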

Prediction

exited <- predict(decision_tree_final_model, new_data = test_data)
exited_customers <- exited |> 
  bind_cols(test_data) |>
  rename("exited" = .pred_class) |> 
  relocate(exited, .after = id) |> 
  select(id, exited)
write_csv(exited_customers, "olamide_submission.csv")
read_csv("olamide_submission.csv")
Rows: 110023 Columns: 2
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
dbl (2): id, exited

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# A tibble: 110,023 × 2
       id exited
    <dbl>  <dbl>
 1 165034      0
 2 165035      1
 3 165036      0
 4 165037      0
 5 165038      0
 6 165039      0
 7 165040      0
 8 165041      0
 9 165042      1
10 165043      0
# ℹ 110,013 more rows
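Before uploading a submission file, a few structural checks are cheap to run. The sketch below applies them to the first six rows shown above, copied into a small data frame; the set of checks is only a suggestion and is not tied to any particular competition's rules.

```r
# First six rows of the submission, copied from the output above
submission <- data.frame(
  id     = 165034:165039,
  exited = c(0, 1, 0, 0, 0, 0)
)

# Each check returns TRUE when the file looks as expected
checks <- c(
  two_columns = identical(names(submission), c("id", "exited")),
  unique_ids  = !anyDuplicated(submission$id),
  binary      = all(submission$exited %in% c(0, 1)),
  no_missing  = !anyNA(submission)
)

checks
```

On the full file, the same checks would be run on the data frame returned by read_csv("olamide_submission.csv").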