#!pip install jupyterthemes
# Importing libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Activation, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import Accuracy
from jupyterthemes import jtplot
#jtplot.style(theme = 'monokai', context = 'notebook', ticks = True, grid = False)
# Load the Universal Bank customer dataset and preview its first five rows
bank_df = pd.read_csv(filepath_or_buffer = 'UniversalBank.csv')
bank_df.head(n = 5)
ID | Age | Experience | Income | ZIP Code | Family | CCAvg | Education | Mortgage | Personal Loan | Securities Account | CD Account | Online | CreditCard | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 25 | 1 | 49 | 91107 | 4 | 1.6 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
1 | 2 | 45 | 19 | 34 | 90089 | 3 | 1.5 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
2 | 3 | 39 | 15 | 11 | 94720 | 1 | 1.0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
3 | 4 | 35 | 9 | 100 | 94112 | 1 | 2.7 | 2 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 5 | 35 | 8 | 45 | 91330 | 4 | 1.0 | 2 | 0 | 0 | 0 | 0 | 0 | 1 |
# Print column names, non-null counts and dtypes of the dataframe
# (buf = None writes the summary to stdout, which is the default)
bank_df.info(buf = None)
<class 'pandas.core.frame.DataFrame'> RangeIndex: 5000 entries, 0 to 4999 Data columns (total 14 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 ID 5000 non-null int64 1 Age 5000 non-null int64 2 Experience 5000 non-null int64 3 Income 5000 non-null int64 4 ZIP Code 5000 non-null int64 5 Family 5000 non-null int64 6 CCAvg 5000 non-null float64 7 Education 5000 non-null int64 8 Mortgage 5000 non-null int64 9 Personal Loan 5000 non-null int64 10 Securities Account 5000 non-null int64 11 CD Account 5000 non-null int64 12 Online 5000 non-null int64 13 CreditCard 5000 non-null int64 dtypes: float64(1), int64(13) memory usage: 547.0 KB
# Statistical summary (count, mean, std, min, quartiles, max) of every numeric column;
# the percentiles listed are pandas' defaults
bank_df.describe(percentiles = [0.25, 0.5, 0.75])
ID | Age | Experience | Income | ZIP Code | Family | CCAvg | Education | Mortgage | Personal Loan | Securities Account | CD Account | Online | CreditCard | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 | 5000.000000 | 5000.00000 | 5000.000000 | 5000.000000 |
mean | 2500.500000 | 45.338400 | 20.104600 | 73.774200 | 93152.503000 | 2.396400 | 1.937938 | 1.881000 | 56.498800 | 0.096000 | 0.104400 | 0.06040 | 0.596800 | 0.294000 |
std | 1443.520003 | 11.463166 | 11.467954 | 46.033729 | 2121.852197 | 1.147663 | 1.747659 | 0.839869 | 101.713802 | 0.294621 | 0.305809 | 0.23825 | 0.490589 | 0.455637 |
min | 1.000000 | 23.000000 | -3.000000 | 8.000000 | 9307.000000 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.00000 | 0.000000 | 0.000000 |
25% | 1250.750000 | 35.000000 | 10.000000 | 39.000000 | 91911.000000 | 1.000000 | 0.700000 | 1.000000 | 0.000000 | 0.000000 | 0.000000 | 0.00000 | 0.000000 | 0.000000 |
50% | 2500.500000 | 45.000000 | 20.000000 | 64.000000 | 93437.000000 | 2.000000 | 1.500000 | 2.000000 | 0.000000 | 0.000000 | 0.000000 | 0.00000 | 1.000000 | 0.000000 |
75% | 3750.250000 | 55.000000 | 30.000000 | 98.000000 | 94608.000000 | 3.000000 | 2.500000 | 3.000000 | 101.000000 | 0.000000 | 0.000000 | 0.00000 | 1.000000 | 1.000000 |
max | 5000.000000 | 67.000000 | 43.000000 | 224.000000 | 96651.000000 | 4.000000 | 10.000000 | 3.000000 | 635.000000 | 1.000000 | 1.000000 | 1.00000 | 1.000000 | 1.000000 |
# Transposed summary: one row per column, which is easier to read with many columns
bank_df.describe().transpose()
count | mean | std | min | 25% | 50% | 75% | max | |
---|---|---|---|---|---|---|---|---|
ID | 5000.0 | 2500.500000 | 1443.520003 | 1.0 | 1250.75 | 2500.5 | 3750.25 | 5000.0 |
Age | 5000.0 | 45.338400 | 11.463166 | 23.0 | 35.00 | 45.0 | 55.00 | 67.0 |
Experience | 5000.0 | 20.104600 | 11.467954 | -3.0 | 10.00 | 20.0 | 30.00 | 43.0 |
Income | 5000.0 | 73.774200 | 46.033729 | 8.0 | 39.00 | 64.0 | 98.00 | 224.0 |
ZIP Code | 5000.0 | 93152.503000 | 2121.852197 | 9307.0 | 91911.00 | 93437.0 | 94608.00 | 96651.0 |
Family | 5000.0 | 2.396400 | 1.147663 | 1.0 | 1.00 | 2.0 | 3.00 | 4.0 |
CCAvg | 5000.0 | 1.937938 | 1.747659 | 0.0 | 0.70 | 1.5 | 2.50 | 10.0 |
Education | 5000.0 | 1.881000 | 0.839869 | 1.0 | 1.00 | 2.0 | 3.00 | 3.0 |
Mortgage | 5000.0 | 56.498800 | 101.713802 | 0.0 | 0.00 | 0.0 | 101.00 | 635.0 |
Personal Loan | 5000.0 | 0.096000 | 0.294621 | 0.0 | 0.00 | 0.0 | 0.00 | 1.0 |
Securities Account | 5000.0 | 0.104400 | 0.305809 | 0.0 | 0.00 | 0.0 | 0.00 | 1.0 |
CD Account | 5000.0 | 0.060400 | 0.238250 | 0.0 | 0.00 | 0.0 | 0.00 | 1.0 |
Online | 5000.0 | 0.596800 | 0.490589 | 0.0 | 0.00 | 1.0 | 1.00 | 1.0 |
CreditCard | 5000.0 | 0.294000 | 0.455637 | 0.0 | 0.00 | 0.0 | 1.00 | 1.0 |
# Count missing values per column (every count is zero for this dataset)
bank_df.isna().sum()
ID 0 Age 0 Experience 0 Income 0 ZIP Code 0 Family 0 CCAvg 0 Education 0 Mortgage 0 Personal Loan 0 Securities Account 0 CD Account 0 Online 0 CreditCard 0 dtype: int64
# Visualize the Personal Loan target column.
# About 9.6% of customers accepted a personal loan (the mean of the 0/1 column
# is 0.096), so the two classes are heavily imbalanced.
# The column is passed as x= explicitly: since seaborn 0.12 the first
# positional argument of countplot is `data`, not `x`.
plt.figure(figsize = (10, 7))
sns.countplot(x = bank_df['Personal Loan']);
# Visualize the Education feature (integer levels from 1 to 3).
# Pass the column as x= explicitly (seaborn >= 0.12 treats the first
# positional argument of countplot as `data`).
plt.figure(figsize = (10, 7))
sns.countplot(x = bank_df['Education'])
<matplotlib.axes._subplots.AxesSubplot at 0x203d6af3320>
# Visualize Age: counts look roughly uniform between about 30 and 60 years.
# Pass the column as x= explicitly (seaborn >= 0.12 treats the first
# positional argument of countplot as `data`).
plt.figure(figsize = (20, 10))
sns.countplot(x = bank_df['Age'])
<matplotlib.axes._subplots.AxesSubplot at 0x203d6af3fd0>
# Visualize credit card availability: ~29% of customers have a credit card
# (the mean of the 0/1 CreditCard column is 0.294).
# Pass the column as x= explicitly (seaborn >= 0.12 treats the first
# positional argument of countplot as `data`).
plt.figure(figsize = (10, 7))
sns.countplot(x = bank_df['CreditCard']);
# Distribution of annual income (in thousands).
# The median income is about 64K and half of the customers earn between ~39K
# and ~98K; the distribution is right-skewed, with fewer customers above 100K.
plt.figure(figsize = [20, 10])
sns.distplot(a = bank_df['Income'])
<matplotlib.axes._subplots.AxesSubplot at 0x203d6f54e48>
# Split the customers by target class: took a personal loan (1) vs did not (0)
personalloans = bank_df.loc[bank_df['Personal Loan'].eq(1)]
no_personalloans = bank_df.loc[bank_df['Personal Loan'].eq(0)]
personalloans
ID | Age | Experience | Income | ZIP Code | Family | CCAvg | Education | Mortgage | Personal Loan | Securities Account | CD Account | Online | CreditCard | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
9 | 10 | 34 | 9 | 180 | 93023 | 1 | 8.9 | 3 | 0 | 1 | 0 | 0 | 0 | 0 |
16 | 17 | 38 | 14 | 130 | 95010 | 4 | 4.7 | 3 | 134 | 1 | 0 | 0 | 0 | 0 |
18 | 19 | 46 | 21 | 193 | 91604 | 2 | 8.1 | 3 | 0 | 1 | 0 | 0 | 0 | 0 |
29 | 30 | 38 | 13 | 119 | 94104 | 1 | 3.3 | 2 | 0 | 1 | 0 | 1 | 1 | 1 |
38 | 39 | 42 | 18 | 141 | 94114 | 3 | 5.0 | 3 | 0 | 1 | 1 | 1 | 1 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
4883 | 4884 | 38 | 13 | 129 | 92646 | 3 | 4.1 | 3 | 0 | 1 | 0 | 1 | 1 | 1 |
4927 | 4928 | 43 | 19 | 121 | 94720 | 1 | 0.7 | 2 | 0 | 1 | 0 | 1 | 1 | 1 |
4941 | 4942 | 28 | 4 | 112 | 90049 | 2 | 1.6 | 2 | 0 | 1 | 0 | 0 | 1 | 0 |
4962 | 4963 | 46 | 20 | 122 | 90065 | 3 | 3.0 | 3 | 0 | 1 | 0 | 1 | 1 | 1 |
4980 | 4981 | 29 | 5 | 135 | 95762 | 3 | 5.3 | 1 | 0 | 1 | 0 | 1 | 1 | 1 |
480 rows × 14 columns
# Customers with personal loans have a high mean income (~144.7K) and a mean
# credit-card average spend (CCAvg) of ~3.9K
personalloans.describe(percentiles = [0.25, 0.5, 0.75])
ID | Age | Experience | Income | ZIP Code | Family | CCAvg | Education | Mortgage | Personal Loan | Securities Account | CD Account | Online | CreditCard | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 480.000000 | 480.000000 | 480.000000 | 480.000000 | 480.000000 | 480.000000 | 480.000000 | 480.000000 | 480.000000 | 480.0 | 480.000000 | 480.000000 | 480.00000 | 480.000000 |
mean | 2390.650000 | 45.066667 | 19.843750 | 144.745833 | 93153.202083 | 2.612500 | 3.905354 | 2.233333 | 100.845833 | 1.0 | 0.125000 | 0.291667 | 0.60625 | 0.297917 |
std | 1394.393674 | 11.590964 | 11.582443 | 31.584429 | 1759.223753 | 1.115393 | 2.097681 | 0.753373 | 160.847862 | 0.0 | 0.331064 | 0.455004 | 0.48909 | 0.457820 |
min | 10.000000 | 26.000000 | 0.000000 | 60.000000 | 90016.000000 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 1.0 | 0.000000 | 0.000000 | 0.00000 | 0.000000 |
25% | 1166.500000 | 35.000000 | 9.000000 | 122.000000 | 91908.750000 | 2.000000 | 2.600000 | 2.000000 | 0.000000 | 1.0 | 0.000000 | 0.000000 | 0.00000 | 0.000000 |
50% | 2342.000000 | 45.000000 | 20.000000 | 142.500000 | 93407.000000 | 3.000000 | 3.800000 | 2.000000 | 0.000000 | 1.0 | 0.000000 | 0.000000 | 1.00000 | 0.000000 |
75% | 3566.000000 | 55.000000 | 30.000000 | 172.000000 | 94705.500000 | 4.000000 | 5.347500 | 3.000000 | 192.500000 | 1.0 | 0.000000 | 1.000000 | 1.00000 | 1.000000 |
max | 4981.000000 | 65.000000 | 41.000000 | 203.000000 | 96008.000000 | 4.000000 | 10.000000 | 3.000000 | 617.000000 | 1.0 | 1.000000 | 1.000000 | 1.00000 | 1.000000 |
# Preview the customers who did not take a personal loan
no_personalloans
ID | Age | Experience | Income | ZIP Code | Family | CCAvg | Education | Mortgage | Personal Loan | Securities Account | CD Account | Online | CreditCard | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 25 | 1 | 49 | 91107 | 4 | 1.6 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
1 | 2 | 45 | 19 | 34 | 90089 | 3 | 1.5 | 1 | 0 | 0 | 1 | 0 | 0 | 0 |
2 | 3 | 39 | 15 | 11 | 94720 | 1 | 1.0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 |
3 | 4 | 35 | 9 | 100 | 94112 | 1 | 2.7 | 2 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 5 | 35 | 8 | 45 | 91330 | 4 | 1.0 | 2 | 0 | 0 | 0 | 0 | 0 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
4995 | 4996 | 29 | 3 | 40 | 92697 | 1 | 1.9 | 3 | 0 | 0 | 0 | 0 | 1 | 0 |
4996 | 4997 | 30 | 4 | 15 | 92037 | 4 | 0.4 | 1 | 85 | 0 | 0 | 0 | 1 | 0 |
4997 | 4998 | 63 | 39 | 24 | 93023 | 2 | 0.3 | 3 | 0 | 0 | 0 | 0 | 0 | 0 |
4998 | 4999 | 65 | 40 | 49 | 90034 | 3 | 0.5 | 2 | 0 | 0 | 0 | 0 | 1 | 0 |
4999 | 5000 | 28 | 4 | 83 | 92612 | 3 | 0.8 | 1 | 0 | 0 | 0 | 0 | 1 | 1 |
4520 rows × 14 columns
# Customers without personal loans have a much lower mean income (~66K) and a
# mean credit-card average spend (CCAvg) of ~1.7K
no_personalloans.describe(percentiles = [0.25, 0.5, 0.75])
ID | Age | Experience | Income | ZIP Code | Family | CCAvg | Education | Mortgage | Personal Loan | Securities Account | CD Account | Online | CreditCard | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 4520.000000 | 4520.000000 | 4520.000000 | 4520.000000 | 4520.000000 | 4520.000000 | 4520.000000 | 4520.000000 | 4520.000000 | 4520.0 | 4520.000000 | 4520.000000 | 4520.000000 | 4520.000000 |
mean | 2512.165487 | 45.367257 | 20.132301 | 66.237389 | 93152.428761 | 2.373451 | 1.729009 | 1.843584 | 51.789381 | 0.0 | 0.102212 | 0.035841 | 0.595796 | 0.293584 |
std | 1448.299331 | 11.450427 | 11.456672 | 40.578534 | 2156.949654 | 1.148771 | 1.567647 | 0.839975 | 92.038931 | 0.0 | 0.302961 | 0.185913 | 0.490792 | 0.455454 |
min | 1.000000 | 23.000000 | -3.000000 | 8.000000 | 9307.000000 | 1.000000 | 0.000000 | 1.000000 | 0.000000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
25% | 1259.750000 | 35.000000 | 10.000000 | 35.000000 | 91911.000000 | 1.000000 | 0.600000 | 1.000000 | 0.000000 | 0.0 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
50% | 2518.500000 | 45.000000 | 20.000000 | 59.000000 | 93437.000000 | 2.000000 | 1.400000 | 2.000000 | 0.000000 | 0.0 | 0.000000 | 0.000000 | 1.000000 | 0.000000 |
75% | 3768.250000 | 55.000000 | 30.000000 | 84.000000 | 94608.000000 | 3.000000 | 2.300000 | 3.000000 | 98.000000 | 0.0 | 0.000000 | 0.000000 | 1.000000 | 1.000000 |
max | 5000.000000 | 67.000000 | 43.000000 | 224.000000 | 96651.000000 | 4.000000 | 8.800000 | 3.000000 | 635.000000 | 0.0 | 1.000000 | 1.000000 | 1.000000 | 1.000000 |
# Overlay the income distributions of the two classes.
# Customers who took personal loans tend to have higher income.
# Each series gets a label and a legend is drawn; otherwise the green and red
# curves cannot be told apart.
# NOTE: distplot is deprecated since seaborn 0.11; on newer versions
# sns.histplot(..., kde=True, stat='density') is the replacement.
plt.figure(figsize = (20, 10))
sns.distplot(personalloans['Income'], color = 'g', label = 'Personal Loan = 1')
sns.distplot(no_personalloans['Income'], color = 'r', label = 'Personal Loan = 0')
plt.legend()
<matplotlib.axes._subplots.AxesSubplot at 0x203d7030550>
# Pairwise scatter plots of all columns, with per-column distributions on the diagonal.
# pairplot creates and sizes its own figure (via its `height` argument, per
# facet), so no plt.figure() call is needed beforehand; such a call only
# leaves an extra empty figure behind.
sns.pairplot(bank_df)
<seaborn.axisgrid.PairGrid at 0x203d7071ba8>
<Figure size 2160x2160 with 0 Axes>
# Correlation heatmap of all columns.
# Strong positive correlation between Experience and Age.
# Strong positive correlation between CCAvg (credit-card average) and Income.
cm = bank_df.corr()
ax = plt.figure(figsize = (20, 20)).add_subplot(111)
sns.heatmap(data = cm, annot = True, ax = ax)
<matplotlib.axes._subplots.AxesSubplot at 0x203dd2c7c18>
# Column labels of the dataframe (DataFrame.keys() returns the columns Index)
bank_df.keys()
Index(['ID', 'Age', 'Experience', 'Income', 'ZIP Code', 'Family', 'CCAvg', 'Education', 'Mortgage', 'Personal Loan', 'Securities Account', 'CD Account', 'Online', 'CreditCard'], dtype='object')
# Model input features: every column except the target 'Personal Loan'.
# NOTE(review): 'ID' looks like a row identifier rather than a predictive
# feature; consider dropping it (the first Dense layer's input size would
# then need to change as well).
X = bank_df.drop('Personal Loan', axis = 1)
X
ID | Age | Experience | Income | ZIP Code | Family | CCAvg | Education | Mortgage | Securities Account | CD Account | Online | CreditCard | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 25 | 1 | 49 | 91107 | 4 | 1.6 | 1 | 0 | 1 | 0 | 0 | 0 |
1 | 2 | 45 | 19 | 34 | 90089 | 3 | 1.5 | 1 | 0 | 1 | 0 | 0 | 0 |
2 | 3 | 39 | 15 | 11 | 94720 | 1 | 1.0 | 1 | 0 | 0 | 0 | 0 | 0 |
3 | 4 | 35 | 9 | 100 | 94112 | 1 | 2.7 | 2 | 0 | 0 | 0 | 0 | 0 |
4 | 5 | 35 | 8 | 45 | 91330 | 4 | 1.0 | 2 | 0 | 0 | 0 | 0 | 1 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
4995 | 4996 | 29 | 3 | 40 | 92697 | 1 | 1.9 | 3 | 0 | 0 | 0 | 1 | 0 |
4996 | 4997 | 30 | 4 | 15 | 92037 | 4 | 0.4 | 1 | 85 | 0 | 0 | 1 | 0 |
4997 | 4998 | 63 | 39 | 24 | 93023 | 2 | 0.3 | 3 | 0 | 0 | 0 | 0 | 0 |
4998 | 4999 | 65 | 40 | 49 | 90034 | 3 | 0.5 | 2 | 0 | 0 | 0 | 1 | 0 |
4999 | 5000 | 28 | 4 | 83 | 92612 | 3 | 0.8 | 1 | 0 | 0 | 0 | 1 | 1 |
5000 rows × 13 columns
# Model output: the binary 'Personal Loan' target column
y = bank_df.loc[:, 'Personal Loan']
y
0 0 1 0 2 0 3 0 4 0 .. 4995 0 4996 0 4997 0 4998 0 4999 0 Name: Personal Loan, Length: 5000, dtype: int64
# One-hot encode the binary target: label k becomes a length-2 row with a 1 in
# column k, matching the 2-unit softmax output layer of the model.
from tensorflow.keras.utils import to_categorical

y = to_categorical(y)
y
array([[1., 0.], [1., 0.], [1., 0.], ..., [1., 0.], [1., 0.], [1., 0.]], dtype=float32)
# Split into train and test sets BEFORE scaling. This way the scaler's
# mean/std are estimated from the training rows only; fitting the scaler on
# all rows would leak test-set statistics into training.
from sklearn import metrics
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split

# Stratify on the class label (recovered from the one-hot rows): only ~9.6% of
# customers are positive, so an unstratified 10% test split can end up with a
# noticeably different class ratio. A fixed random_state makes the split, and
# therefore the reported metrics, reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size = 0.1, stratify = np.argmax(y, axis = 1), random_state = 42)

# Fit the scaler on the training features only, then apply the same
# transformation to the test features.
scaler_x = StandardScaler()
X_train = scaler_x.fit_transform(X_train)
X_test = scaler_x.transform(X_test)

# print the shapes
X_train.shape, X_test.shape, y_train.shape, y_test.shape
((4500, 13), (500, 13), (4500, 2), (500, 2))
# Create a Keras Sequential model: a stack of fully connected (Dense) layers,
# each followed by Dropout, ending in a 2-unit softmax over the two classes
# (no personal loan / personal loan).
ANN_model = keras.Sequential()
# First hidden layer. The input width is taken from the training matrix
# instead of being hard-coded, so it stays correct if feature columns are
# added or dropped.
ANN_model.add(Dense(250, input_dim = X_train.shape[1], kernel_initializer = 'normal', activation = 'relu'))
# Dropout randomly zeroes a fraction of the activations during training only,
# which regularises the network and helps keep it from overfitting
ANN_model.add(Dropout(0.3))
# Hidden dense layer
ANN_model.add(Dense(500, activation = 'relu'))
# Dropout layer
ANN_model.add(Dropout(0.3))
# Hidden dense layer
ANN_model.add(Dense(500, activation = 'relu'))
# Dropout layer (higher rate deeper in the network)
ANN_model.add(Dropout(0.4))
# Hidden dense layer.
# NOTE(review): this layer is 'linear', so it adds no non-linearity on top of
# the next Dense layer; confirm whether 'relu' was intended.
ANN_model.add(Dense(250, activation = 'linear'))
# Dropout layer
ANN_model.add(Dropout(0.5))
# Output layer: softmax gives one probability per class
ANN_model.add(Dense(2, activation = 'softmax'))
ANN_model.summary()
Model: "sequential_4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_20 (Dense) (None, 250) 3500 _________________________________________________________________ dropout_16 (Dropout) (None, 250) 0 _________________________________________________________________ dense_21 (Dense) (None, 500) 125500 _________________________________________________________________ dropout_17 (Dropout) (None, 500) 0 _________________________________________________________________ dense_22 (Dense) (None, 500) 250500 _________________________________________________________________ dropout_18 (Dropout) (None, 500) 0 _________________________________________________________________ dense_23 (Dense) (None, 250) 125250 _________________________________________________________________ dropout_19 (Dropout) (None, 250) 0 _________________________________________________________________ dense_24 (Dense) (None, 2) 502 ================================================================= Total params: 505,252 Trainable params: 505,252 Non-trainable params: 0 _________________________________________________________________
# Compile the model. Categorical cross-entropy matches the one-hot targets and
# the 2-unit softmax output; accuracy is reported during training.
ANN_model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
# Hold out the last 20% of the training rows as a validation set. This does
# not by itself prevent overfitting. It lets us monitor overfitting by
# comparing the training and validation loss per epoch.
history = ANN_model.fit(X_train, y_train, epochs = 20, validation_split = 0.2, verbose = 1)
Train on 3600 samples, validate on 900 samples Epoch 1/20 3600/3600 [==============================] - 4s 1ms/sample - loss: 0.1787 - acc: 0.9311 - val_loss: 0.0711 - val_acc: 0.9744 Epoch 2/20 3600/3600 [==============================] - 2s 469us/sample - loss: 0.0975 - acc: 0.9642 - val_loss: 0.0623 - val_acc: 0.9789 Epoch 3/20 3600/3600 [==============================] - 2s 449us/sample - loss: 0.0786 - acc: 0.9722 - val_loss: 0.0554 - val_acc: 0.9778 Epoch 4/20 3600/3600 [==============================] - 1s 416us/sample - loss: 0.0743 - acc: 0.9753 - val_loss: 0.0441 - val_acc: 0.9856 Epoch 5/20 3600/3600 [==============================] - 2s 431us/sample - loss: 0.0739 - acc: 0.9739 - val_loss: 0.0484 - val_acc: 0.9844 Epoch 6/20 3600/3600 [==============================] - 2s 439us/sample - loss: 0.0643 - acc: 0.9778 - val_loss: 0.0467 - val_acc: 0.9822 Epoch 7/20 3600/3600 [==============================] - 2s 425us/sample - loss: 0.0564 - acc: 0.9814 - val_loss: 0.0499 - val_acc: 0.9822 Epoch 8/20 3600/3600 [==============================] - 2s 570us/sample - loss: 0.0610 - acc: 0.9808 - val_loss: 0.0516 - val_acc: 0.9844 Epoch 9/20 3600/3600 [==============================] - 4s 1ms/sample - loss: 0.0548 - acc: 0.9794 - val_loss: 0.0481 - val_acc: 0.9844 Epoch 10/20 3600/3600 [==============================] - 4s 1ms/sample - loss: 0.0497 - acc: 0.9836 - val_loss: 0.0405 - val_acc: 0.9856 Epoch 11/20 3600/3600 [==============================] - 4s 1ms/sample - loss: 0.0457 - acc: 0.9831 - val_loss: 0.0381 - val_acc: 0.9867 Epoch 12/20 3600/3600 [==============================] - 4s 1ms/sample - loss: 0.0474 - acc: 0.9831 - val_loss: 0.0422 - val_acc: 0.9811 Epoch 13/20 3600/3600 [==============================] - 3s 781us/sample - loss: 0.0471 - acc: 0.9844 - val_loss: 0.0390 - val_acc: 0.9878 Epoch 14/20 3600/3600 [==============================] - 3s 938us/sample - loss: 0.0414 - acc: 0.9839 - val_loss: 0.0465 - val_acc: 0.9833 Epoch 15/20 3600/3600 
[==============================] - 2s 644us/sample - loss: 0.0482 - acc: 0.9836 - val_loss: 0.0363 - val_acc: 0.9856 Epoch 16/20 3600/3600 [==============================] - 2s 431us/sample - loss: 0.0371 - acc: 0.9878 - val_loss: 0.0466 - val_acc: 0.9844 Epoch 17/20 3600/3600 [==============================] - 2s 419us/sample - loss: 0.0403 - acc: 0.9861 - val_loss: 0.0303 - val_acc: 0.9867 Epoch 18/20 3600/3600 [==============================] - 2s 436us/sample - loss: 0.0347 - acc: 0.9875 - val_loss: 0.0456 - val_acc: 0.9867 Epoch 19/20 3600/3600 [==============================] - 2s 495us/sample - loss: 0.0331 - acc: 0.9881 - val_loss: 0.0368 - val_acc: 0.9878 Epoch 20/20 3600/3600 [==============================] - 3s 796us/sample - loss: 0.0303 - acc: 0.9886 - val_loss: 0.0492 - val_acc: 0.9844
# Plot the training and validation loss across epochs
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train_loss','val_loss'], loc = 'upper right')
plt.show()
# Make predictions: one row of class probabilities per test sample
predictions = ANN_model.predict(X_test)
# Predicted class = index of the highest probability in each row. A vectorised
# argmax over axis 1 replaces the per-row Python loop, whose body had also
# lost its indentation (an IndentationError).
predict = np.argmax(predictions, axis = 1)
# Evaluate loss and accuracy on the held-out test set; result[1] is the
# accuracy because 'accuracy' is the only metric passed to compile()
result = ANN_model.evaluate(X_test, y_test)
print("Accuracy : {}".format(result[1]))
500/500 [==============================] - 0s 373us/sample - loss: 0.1368 - acc: 0.9700 Accuracy : 0.9700000286102295
# Recover the true class labels from the one-hot encoded test targets
# (vectorised argmax; the original loop body had lost its indentation)
y_original = np.argmax(y_test, axis = 1)
# Confusion matrix: rows are true classes, columns are predicted classes
confusion_matrix = metrics.confusion_matrix(y_original, predict)
# fmt = 'd' prints the integer counts; heatmap's default '.2g' format would
# render a count such as 452 as '4.5e+02'
sns.heatmap(confusion_matrix, annot = True, fmt = 'd')
<matplotlib.axes._subplots.AxesSubplot at 0x203dffd7cc0>
# Per-class precision, recall and F1-score on the test set
from sklearn.metrics import classification_report
print(classification_report(y_true = y_original, y_pred = predict))
precision recall f1-score support 0 0.98 0.99 0.98 459 1 0.84 0.78 0.81 41 accuracy 0.97 500 macro avg 0.91 0.88 0.90 500 weighted avg 0.97 0.97 0.97 500