{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4a8cf18e-c15f-44a0-bd39-36aac17aa524",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0487dad4-eb54-46e1-91b3-daa4f4d2064e",
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_csv('~/diabetes_binary.csv')\n",
"pd.set_option('display.max_columns', 50) # this is to show all the columns"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cfcb7969-4e9a-4ccf-ad0f-5fc8fc50901e",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Diabetes_binary \n",
" HighBP \n",
" HighChol \n",
" CholCheck \n",
" BMI \n",
" Smoker \n",
" Stroke \n",
" HeartDiseaseorAttack \n",
" PhysActivity \n",
" Fruits \n",
" Veggies \n",
" HvyAlcoholConsump \n",
" AnyHealthcare \n",
" NoDocbcCost \n",
" GenHlth \n",
" MentHlth \n",
" PhysHlth \n",
" DiffWalk \n",
" Sex \n",
" Age \n",
" Education \n",
" Income \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 26.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 3.0 \n",
" 5.0 \n",
" 30.0 \n",
" 0.0 \n",
" 1.0 \n",
" 4.0 \n",
" 6.0 \n",
" 8.0 \n",
" \n",
" \n",
" 1 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 26.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 3.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 12.0 \n",
" 6.0 \n",
" 8.0 \n",
" \n",
" \n",
" 2 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 26.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 10.0 \n",
" 0.0 \n",
" 1.0 \n",
" 13.0 \n",
" 6.0 \n",
" 8.0 \n",
" \n",
" \n",
" 3 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 28.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 3.0 \n",
" 0.0 \n",
" 3.0 \n",
" 0.0 \n",
" 1.0 \n",
" 11.0 \n",
" 6.0 \n",
" 8.0 \n",
" \n",
" \n",
" 4 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 29.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 2.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 8.0 \n",
" 5.0 \n",
" 8.0 \n",
" \n",
" \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" \n",
" \n",
" 70687 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 37.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 4.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 6.0 \n",
" 4.0 \n",
" 1.0 \n",
" \n",
" \n",
" 70688 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 29.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 2.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 10.0 \n",
" 3.0 \n",
" 6.0 \n",
" \n",
" \n",
" 70689 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 25.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 5.0 \n",
" 15.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 13.0 \n",
" 6.0 \n",
" 4.0 \n",
" \n",
" \n",
" 70690 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 18.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 4.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 11.0 \n",
" 2.0 \n",
" 4.0 \n",
" \n",
" \n",
" 70691 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 25.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 2.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 9.0 \n",
" 6.0 \n",
" 2.0 \n",
" \n",
" \n",
"
\n",
"
70692 rows × 22 columns
\n",
"
"
],
"text/plain": [
" Diabetes_binary HighBP HighChol CholCheck BMI Smoker Stroke \\\n",
"0 0.0 1.0 0.0 1.0 26.0 0.0 0.0 \n",
"1 0.0 1.0 1.0 1.0 26.0 1.0 1.0 \n",
"2 0.0 0.0 0.0 1.0 26.0 0.0 0.0 \n",
"3 0.0 1.0 1.0 1.0 28.0 1.0 0.0 \n",
"4 0.0 0.0 0.0 1.0 29.0 1.0 0.0 \n",
"... ... ... ... ... ... ... ... \n",
"70687 1.0 0.0 1.0 1.0 37.0 0.0 0.0 \n",
"70688 1.0 0.0 1.0 1.0 29.0 1.0 0.0 \n",
"70689 1.0 1.0 1.0 1.0 25.0 0.0 0.0 \n",
"70690 1.0 1.0 1.0 1.0 18.0 0.0 0.0 \n",
"70691 1.0 1.0 1.0 1.0 25.0 0.0 0.0 \n",
"\n",
" HeartDiseaseorAttack PhysActivity Fruits Veggies HvyAlcoholConsump \\\n",
"0 0.0 1.0 0.0 1.0 0.0 \n",
"1 0.0 0.0 1.0 0.0 0.0 \n",
"2 0.0 1.0 1.0 1.0 0.0 \n",
"3 0.0 1.0 1.0 1.0 0.0 \n",
"4 0.0 1.0 1.0 1.0 0.0 \n",
"... ... ... ... ... ... \n",
"70687 0.0 0.0 0.0 1.0 0.0 \n",
"70688 1.0 0.0 1.0 1.0 0.0 \n",
"70689 1.0 0.0 1.0 0.0 0.0 \n",
"70690 0.0 0.0 0.0 0.0 0.0 \n",
"70691 1.0 1.0 1.0 0.0 0.0 \n",
"\n",
" AnyHealthcare NoDocbcCost GenHlth MentHlth PhysHlth DiffWalk Sex \\\n",
"0 1.0 0.0 3.0 5.0 30.0 0.0 1.0 \n",
"1 1.0 0.0 3.0 0.0 0.0 0.0 1.0 \n",
"2 1.0 0.0 1.0 0.0 10.0 0.0 1.0 \n",
"3 1.0 0.0 3.0 0.0 3.0 0.0 1.0 \n",
"4 1.0 0.0 2.0 0.0 0.0 0.0 0.0 \n",
"... ... ... ... ... ... ... ... \n",
"70687 1.0 0.0 4.0 0.0 0.0 0.0 0.0 \n",
"70688 1.0 0.0 2.0 0.0 0.0 1.0 1.0 \n",
"70689 1.0 0.0 5.0 15.0 0.0 1.0 0.0 \n",
"70690 1.0 0.0 4.0 0.0 0.0 1.0 0.0 \n",
"70691 1.0 0.0 2.0 0.0 0.0 0.0 0.0 \n",
"\n",
" Age Education Income \n",
"0 4.0 6.0 8.0 \n",
"1 12.0 6.0 8.0 \n",
"2 13.0 6.0 8.0 \n",
"3 11.0 6.0 8.0 \n",
"4 8.0 5.0 8.0 \n",
"... ... ... ... \n",
"70687 6.0 4.0 1.0 \n",
"70688 10.0 3.0 6.0 \n",
"70689 13.0 6.0 4.0 \n",
"70690 11.0 2.0 4.0 \n",
"70691 9.0 6.0 2.0 \n",
"\n",
"[70692 rows x 22 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c3b7f023-db0e-42c5-a9bb-94d1872bd912",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70692, 22)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "069930e8-d2d9-4c9a-9d9e-66cb2e44c9f7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"df.plot.scatter(x='BMI', y='Diabetes_binary', title='Scatterplot of BMI and Diabetes') # Binary response"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "efb7daab-6b7e-45e1-8189-9ecd5ccd422c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sns.regplot(x='BMI', y='Diabetes_binary', data=df, logistic=True, ci=None) # Logistic curve fit to BMI"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b4d6beb3-5864-4dda-ae3d-54027e47d09f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
""
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sns.regplot(x='Age', y='Diabetes_binary', data=df, logistic=True, ci=None) # Logistic curve fit to Age"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "287e7f96-d94b-4b45-b932-c3e8e41310f3",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Diabetes_binary \n",
" HighBP \n",
" HighChol \n",
" CholCheck \n",
" BMI \n",
" Smoker \n",
" Stroke \n",
" HeartDiseaseorAttack \n",
" PhysActivity \n",
" Fruits \n",
" Veggies \n",
" HvyAlcoholConsump \n",
" AnyHealthcare \n",
" NoDocbcCost \n",
" GenHlth \n",
" MentHlth \n",
" PhysHlth \n",
" DiffWalk \n",
" Sex \n",
" Age \n",
" Education \n",
" Income \n",
" \n",
" \n",
" \n",
" \n",
" count \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" 70692.000000 \n",
" \n",
" \n",
" mean \n",
" 0.500000 \n",
" 0.563458 \n",
" 0.525703 \n",
" 0.975259 \n",
" 29.856985 \n",
" 0.475273 \n",
" 0.062171 \n",
" 0.147810 \n",
" 0.703036 \n",
" 0.611795 \n",
" 0.788774 \n",
" 0.042721 \n",
" 0.954960 \n",
" 0.093914 \n",
" 2.837082 \n",
" 3.752037 \n",
" 5.810417 \n",
" 0.252730 \n",
" 0.456997 \n",
" 8.584055 \n",
" 4.920953 \n",
" 5.698311 \n",
" \n",
" \n",
" std \n",
" 0.500004 \n",
" 0.495960 \n",
" 0.499342 \n",
" 0.155336 \n",
" 7.113954 \n",
" 0.499392 \n",
" 0.241468 \n",
" 0.354914 \n",
" 0.456924 \n",
" 0.487345 \n",
" 0.408181 \n",
" 0.202228 \n",
" 0.207394 \n",
" 0.291712 \n",
" 1.113565 \n",
" 8.155627 \n",
" 10.062261 \n",
" 0.434581 \n",
" 0.498151 \n",
" 2.852153 \n",
" 1.029081 \n",
" 2.175196 \n",
" \n",
" \n",
" min \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 12.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" \n",
" \n",
" 25% \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 25.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 0.000000 \n",
" 2.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 7.000000 \n",
" 4.000000 \n",
" 4.000000 \n",
" \n",
" \n",
" 50% \n",
" 0.500000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 29.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 0.000000 \n",
" 3.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 9.000000 \n",
" 5.000000 \n",
" 6.000000 \n",
" \n",
" \n",
" 75% \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 33.000000 \n",
" 1.000000 \n",
" 0.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 0.000000 \n",
" 1.000000 \n",
" 0.000000 \n",
" 4.000000 \n",
" 2.000000 \n",
" 6.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 11.000000 \n",
" 6.000000 \n",
" 8.000000 \n",
" \n",
" \n",
" max \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 98.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 5.000000 \n",
" 30.000000 \n",
" 30.000000 \n",
" 1.000000 \n",
" 1.000000 \n",
" 13.000000 \n",
" 6.000000 \n",
" 8.000000 \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Diabetes_binary HighBP HighChol CholCheck \\\n",
"count 70692.000000 70692.000000 70692.000000 70692.000000 \n",
"mean 0.500000 0.563458 0.525703 0.975259 \n",
"std 0.500004 0.495960 0.499342 0.155336 \n",
"min 0.000000 0.000000 0.000000 0.000000 \n",
"25% 0.000000 0.000000 0.000000 1.000000 \n",
"50% 0.500000 1.000000 1.000000 1.000000 \n",
"75% 1.000000 1.000000 1.000000 1.000000 \n",
"max 1.000000 1.000000 1.000000 1.000000 \n",
"\n",
" BMI Smoker Stroke HeartDiseaseorAttack \\\n",
"count 70692.000000 70692.000000 70692.000000 70692.000000 \n",
"mean 29.856985 0.475273 0.062171 0.147810 \n",
"std 7.113954 0.499392 0.241468 0.354914 \n",
"min 12.000000 0.000000 0.000000 0.000000 \n",
"25% 25.000000 0.000000 0.000000 0.000000 \n",
"50% 29.000000 0.000000 0.000000 0.000000 \n",
"75% 33.000000 1.000000 0.000000 0.000000 \n",
"max 98.000000 1.000000 1.000000 1.000000 \n",
"\n",
" PhysActivity Fruits Veggies HvyAlcoholConsump \\\n",
"count 70692.000000 70692.000000 70692.000000 70692.000000 \n",
"mean 0.703036 0.611795 0.788774 0.042721 \n",
"std 0.456924 0.487345 0.408181 0.202228 \n",
"min 0.000000 0.000000 0.000000 0.000000 \n",
"25% 0.000000 0.000000 1.000000 0.000000 \n",
"50% 1.000000 1.000000 1.000000 0.000000 \n",
"75% 1.000000 1.000000 1.000000 0.000000 \n",
"max 1.000000 1.000000 1.000000 1.000000 \n",
"\n",
" AnyHealthcare NoDocbcCost GenHlth MentHlth PhysHlth \\\n",
"count 70692.000000 70692.000000 70692.000000 70692.000000 70692.000000 \n",
"mean 0.954960 0.093914 2.837082 3.752037 5.810417 \n",
"std 0.207394 0.291712 1.113565 8.155627 10.062261 \n",
"min 0.000000 0.000000 1.000000 0.000000 0.000000 \n",
"25% 1.000000 0.000000 2.000000 0.000000 0.000000 \n",
"50% 1.000000 0.000000 3.000000 0.000000 0.000000 \n",
"75% 1.000000 0.000000 4.000000 2.000000 6.000000 \n",
"max 1.000000 1.000000 5.000000 30.000000 30.000000 \n",
"\n",
" DiffWalk Sex Age Education Income \n",
"count 70692.000000 70692.000000 70692.000000 70692.000000 70692.000000 \n",
"mean 0.252730 0.456997 8.584055 4.920953 5.698311 \n",
"std 0.434581 0.498151 2.852153 1.029081 2.175196 \n",
"min 0.000000 0.000000 1.000000 1.000000 1.000000 \n",
"25% 0.000000 0.000000 7.000000 4.000000 4.000000 \n",
"50% 0.000000 0.000000 9.000000 5.000000 6.000000 \n",
"75% 1.000000 1.000000 11.000000 6.000000 8.000000 \n",
"max 1.000000 1.000000 13.000000 6.000000 8.000000 "
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.describe()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "68f1790f-b48f-4cbf-b063-be7917d9ab4f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Diabetes_binary 0\n",
"HighBP 0\n",
"HighChol 0\n",
"CholCheck 0\n",
"BMI 0\n",
"Smoker 0\n",
"Stroke 0\n",
"HeartDiseaseorAttack 0\n",
"PhysActivity 0\n",
"Fruits 0\n",
"Veggies 0\n",
"HvyAlcoholConsump 0\n",
"AnyHealthcare 0\n",
"NoDocbcCost 0\n",
"GenHlth 0\n",
"MentHlth 0\n",
"PhysHlth 0\n",
"DiffWalk 0\n",
"Sex 0\n",
"Age 0\n",
"Education 0\n",
"Income 0\n",
"dtype: int64"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# we check the number of null values in the dataset\n",
"\n",
"df.isna().sum() # no null values here"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "bc6a2628-103a-478a-8c6a-0cdcb6583975",
"metadata": {},
"outputs": [],
"source": [
"X = df.drop('Diabetes_binary', axis = 1)\n",
"y = df['Diabetes_binary']"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9858bc19-cdea-48aa-9841-da58298f14a1",
"metadata": {},
"outputs": [],
"source": [
"# We will use the package statsmodel because it takes a more statistical approach, with detailed outputs\n",
"# The same models can be produced with packages like scikit-learn\n",
"\n",
"import statsmodels.api as sm"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "6c60c285-77b9-4f44-844f-f51e1836a43f",
"metadata": {},
"outputs": [],
"source": [
"X = sm.add_constant(X) # this step is unique to the statsmodels package, other packages like the notorious scikit-learn do this by themselves"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "33ea2794-0086-4225-9ec2-e6e2268d45b8",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"train_X, test_X, train_y, test_y = train_test_split(X, y, test_size = 0.20) # Notice here I used a lesser portion of the dataset"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "65abdf81-0f2d-4661-acba-ed1c2808689a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Optimization terminated successfully.\n",
" Current function value: 0.511743\n",
" Iterations 6\n"
]
}
],
"source": [
"model = sm.Logit(train_y, train_X).fit()"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "8da35cbe-cdeb-4280-b2a1-2bbedf87a6ca",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"Logit Regression Results \n",
"\n",
" Dep. Variable: Diabetes_binary No. Observations: 56553 \n",
" \n",
"\n",
" Model: Logit Df Residuals: 56531 \n",
" \n",
"\n",
" Method: MLE Df Model: 21 \n",
" \n",
"\n",
" Date: Mon, 03 Apr 2023 Pseudo R-squ.: 0.2617 \n",
" \n",
"\n",
" Time: 12:16:01 Log-Likelihood: -28941. \n",
" \n",
"\n",
" converged: True LL-Null: -39199. \n",
" \n",
"\n",
" Covariance Type: nonrobust LLR p-value: 0.000 \n",
" \n",
"
\n",
"\n",
"\n",
" coef std err z P>|z| [0.025 0.975] \n",
" \n",
"\n",
" const -6.9389 0.139 -49.830 0.000 -7.212 -6.666 \n",
" \n",
"\n",
" HighBP 0.7223 0.022 32.766 0.000 0.679 0.766 \n",
" \n",
"\n",
" HighChol 0.5941 0.021 28.170 0.000 0.553 0.635 \n",
" \n",
"\n",
" CholCheck 1.3655 0.090 15.112 0.000 1.188 1.543 \n",
" \n",
"\n",
" BMI 0.0771 0.002 43.691 0.000 0.074 0.081 \n",
" \n",
"\n",
" Smoker 0.0056 0.021 0.267 0.790 -0.036 0.047 \n",
" \n",
"\n",
" Stroke 0.1593 0.046 3.478 0.001 0.070 0.249 \n",
" \n",
"\n",
" HeartDiseaseorAttack 0.2735 0.032 8.584 0.000 0.211 0.336 \n",
" \n",
"\n",
" PhysActivity -0.0447 0.024 -1.877 0.061 -0.091 0.002 \n",
" \n",
"\n",
" Fruits -0.0434 0.022 -1.981 0.048 -0.086 -0.000 \n",
" \n",
"\n",
" Veggies -0.0570 0.026 -2.183 0.029 -0.108 -0.006 \n",
" \n",
"\n",
" HvyAlcoholConsump -0.7460 0.054 -13.688 0.000 -0.853 -0.639 \n",
" \n",
"\n",
" AnyHealthcare 0.1016 0.053 1.928 0.054 -0.002 0.205 \n",
" \n",
"\n",
" NoDocbcCost 0.0445 0.038 1.166 0.243 -0.030 0.119 \n",
" \n",
"\n",
" GenHlth 0.5794 0.013 45.165 0.000 0.554 0.605 \n",
" \n",
"\n",
" MentHlth -0.0038 0.001 -2.611 0.009 -0.007 -0.001 \n",
" \n",
"\n",
" PhysHlth -0.0082 0.001 -6.165 0.000 -0.011 -0.006 \n",
" \n",
"\n",
" DiffWalk 0.0943 0.029 3.267 0.001 0.038 0.151 \n",
" \n",
"\n",
" Sex 0.2740 0.021 12.781 0.000 0.232 0.316 \n",
" \n",
"\n",
" Age 0.1538 0.004 35.146 0.000 0.145 0.162 \n",
" \n",
"\n",
" Education -0.0396 0.011 -3.467 0.001 -0.062 -0.017 \n",
" \n",
"\n",
" Income -0.0587 0.006 -10.107 0.000 -0.070 -0.047 \n",
" \n",
"
"
],
"text/plain": [
"\n",
"\"\"\"\n",
" Logit Regression Results \n",
"==============================================================================\n",
"Dep. Variable: Diabetes_binary No. Observations: 56553\n",
"Model: Logit Df Residuals: 56531\n",
"Method: MLE Df Model: 21\n",
"Date: Mon, 03 Apr 2023 Pseudo R-squ.: 0.2617\n",
"Time: 12:16:01 Log-Likelihood: -28941.\n",
"converged: True LL-Null: -39199.\n",
"Covariance Type: nonrobust LLR p-value: 0.000\n",
"========================================================================================\n",
" coef std err z P>|z| [0.025 0.975]\n",
"----------------------------------------------------------------------------------------\n",
"const -6.9389 0.139 -49.830 0.000 -7.212 -6.666\n",
"HighBP 0.7223 0.022 32.766 0.000 0.679 0.766\n",
"HighChol 0.5941 0.021 28.170 0.000 0.553 0.635\n",
"CholCheck 1.3655 0.090 15.112 0.000 1.188 1.543\n",
"BMI 0.0771 0.002 43.691 0.000 0.074 0.081\n",
"Smoker 0.0056 0.021 0.267 0.790 -0.036 0.047\n",
"Stroke 0.1593 0.046 3.478 0.001 0.070 0.249\n",
"HeartDiseaseorAttack 0.2735 0.032 8.584 0.000 0.211 0.336\n",
"PhysActivity -0.0447 0.024 -1.877 0.061 -0.091 0.002\n",
"Fruits -0.0434 0.022 -1.981 0.048 -0.086 -0.000\n",
"Veggies -0.0570 0.026 -2.183 0.029 -0.108 -0.006\n",
"HvyAlcoholConsump -0.7460 0.054 -13.688 0.000 -0.853 -0.639\n",
"AnyHealthcare 0.1016 0.053 1.928 0.054 -0.002 0.205\n",
"NoDocbcCost 0.0445 0.038 1.166 0.243 -0.030 0.119\n",
"GenHlth 0.5794 0.013 45.165 0.000 0.554 0.605\n",
"MentHlth -0.0038 0.001 -2.611 0.009 -0.007 -0.001\n",
"PhysHlth -0.0082 0.001 -6.165 0.000 -0.011 -0.006\n",
"DiffWalk 0.0943 0.029 3.267 0.001 0.038 0.151\n",
"Sex 0.2740 0.021 12.781 0.000 0.232 0.316\n",
"Age 0.1538 0.004 35.146 0.000 0.145 0.162\n",
"Education -0.0396 0.011 -3.467 0.001 -0.062 -0.017\n",
"Income -0.0587 0.006 -10.107 0.000 -0.070 -0.047\n",
"========================================================================================\n",
"\"\"\""
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "markdown",
"id": "36ab0a85-6156-4f45-a356-f3affb7bd72d",
"metadata": {},
"source": [
"The output structure is similar to the one of linear regression, with a few meaningful differences:\n",
"\n",
"- We have a pseudo-R-squared, which can be thought of as the substitute to the R-squared value for a linear regression model. It is calculated as the ratio of the maximized log-likelihood function of the null model to the one of the full model. This value can range from 0 to 1, with higher values indicating a better model fit. In this example, the pseudo R-squared value is 0.2646, which isn't much. This tells us that the model isn't a very good fit.\n",
"- The coefficients interpretation, that we'll see below in practice, is different than in linear regression, They are now expressed as logarithm of the odds-ratio, so we need to take the exponential to get the odds ratio. The coefficient will then represent the increase (or decrease) in the odds of having diabetes for a one-unit increase of the independent variable the coefficient refers to, while holding all the other variables constant. If we are dealing with a dummy variable, of course, this would mean the increase (decrease) of the odds if the category is present (e.g if you are a smoker)\n",
"- The rest you can interpret as you did with linear regression. There is the standard error, the z value, the P>|z| which represent the result of the test (we need to look at this value to understand if the coefficient is significatively different from 0) and the 95% confidence interval for the coefficient estimate."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "722099b8-ba9b-4079-a532-9236f32e231d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.313112141132595"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Let's take a few examples\n",
"\n",
"# First, let's see dummy variables\n",
"\n",
"np.exp(0.2724) # the odds of diabetes are 31.3% higher if you are a male, according to the model"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "2531d035-03ae-4498-8060-e6670af4e714",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9185122844014574"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.exp(-0.0850) # the odds of diabetes are 8.15% lower if you are a person who eats vegetables regularly"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "cbd6a4c2-be60-462c-8b2e-57d7950e98c0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.781222631241982"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.exp(0.5773) # the odds of diabetes are 78.1% higher if you have high cholesterol level"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "a99ec496-0e11-4a20-9714-239ac3dc669c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.078315390787216"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Now we see some continuous / ordinal variables\n",
"\n",
"np.exp(0.0754) # on average, for a 1-unit increase of BMI, the odds of diabetes increase of 7.8%"
]
},
{
"cell_type": "markdown",
"id": "655e0b4a-0613-4745-aa6e-c0335eea141f",
"metadata": {},
"source": [
"If we want to see how much the odds increase going e.g. from BMI 24 to BMI 27, we need to raise the log-odds to a power corresponding to the number of steps of explanatory variable we are taking. In this case, 3"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "f4aa04d7-f072-415d-9d1a-2979f75410c3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.2538264054844275"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# How the odds of diabetes increase going from BMI 24 to BMI 27\n",
"\n",
"np.exp(0.0754)**3 # 25.38% increase in the odds, while holding all other variables constant"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "5cdd83d6-3721-4423-bbc1-de7fbdce0941",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9916351814230984"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The same can be done with other variables with more than 1 level (or continuous variables)\n",
"\n",
"np.exp(-0.0084) # Physical health, which has 30 levels"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "24e14adb-10ce-4f42-aa16-6a5a3a466716",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.7763126037756957"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.exp(-0.0422)**6 # Education, which has 6 levels"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "300bd0e6-38c2-4464-861e-448e928c1735",
"metadata": {},
"outputs": [],
"source": [
"from statsmodels.stats.outliers_influence import variance_inflation_factor\n",
"vif = pd.DataFrame()\n",
"vif['VIF'] = [variance_inflation_factor(X.values,i) for i in range(X.shape[1])]\n",
"vif['features'] = X.columns"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "fe3442e6-c525-4b4b-b4e2-3d5daee85002",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" VIF \n",
" features \n",
" \n",
" \n",
" \n",
" \n",
" 0 \n",
" 126.858651 \n",
" const \n",
" \n",
" \n",
" 1 \n",
" 1.357474 \n",
" HighBP \n",
" \n",
" \n",
" 2 \n",
" 1.179017 \n",
" HighChol \n",
" \n",
" \n",
" 3 \n",
" 1.032894 \n",
" CholCheck \n",
" \n",
" \n",
" 4 \n",
" 1.180737 \n",
" BMI \n",
" \n",
" \n",
" 5 \n",
" 1.082180 \n",
" Smoker \n",
" \n",
" \n",
" 6 \n",
" 1.094027 \n",
" Stroke \n",
" \n",
" \n",
" 7 \n",
" 1.193935 \n",
" HeartDiseaseorAttack \n",
" \n",
" \n",
" 8 \n",
" 1.167430 \n",
" PhysActivity \n",
" \n",
" \n",
" 9 \n",
" 1.101333 \n",
" Fruits \n",
" \n",
" \n",
" 10 \n",
" 1.103377 \n",
" Veggies \n",
" \n",
" \n",
" 11 \n",
" 1.022839 \n",
" HvyAlcoholConsump \n",
" \n",
" \n",
" 12 \n",
" 1.095602 \n",
" AnyHealthcare \n",
" \n",
" \n",
" 13 \n",
" 1.143393 \n",
" NoDocbcCost \n",
" \n",
" \n",
" 14 \n",
" 1.890908 \n",
" GenHlth \n",
" \n",
" \n",
" 15 \n",
" 1.267156 \n",
" MentHlth \n",
" \n",
" \n",
" 16 \n",
" 1.700399 \n",
" PhysHlth \n",
" \n",
" \n",
" 17 \n",
" 1.591225 \n",
" DiffWalk \n",
" \n",
" \n",
" 18 \n",
" 1.088941 \n",
" Sex \n",
" \n",
" \n",
" 19 \n",
" 1.337148 \n",
" Age \n",
" \n",
" \n",
" 20 \n",
" 1.331728 \n",
" Education \n",
" \n",
" \n",
" 21 \n",
" 1.544598 \n",
" Income \n",
" \n",
" \n",
"
\n",
"
"
],
"text/plain": [
" VIF features\n",
"0 126.858651 const\n",
"1 1.357474 HighBP\n",
"2 1.179017 HighChol\n",
"3 1.032894 CholCheck\n",
"4 1.180737 BMI\n",
"5 1.082180 Smoker\n",
"6 1.094027 Stroke\n",
"7 1.193935 HeartDiseaseorAttack\n",
"8 1.167430 PhysActivity\n",
"9 1.101333 Fruits\n",
"10 1.103377 Veggies\n",
"11 1.022839 HvyAlcoholConsump\n",
"12 1.095602 AnyHealthcare\n",
"13 1.143393 NoDocbcCost\n",
"14 1.890908 GenHlth\n",
"15 1.267156 MentHlth\n",
"16 1.700399 PhysHlth\n",
"17 1.591225 DiffWalk\n",
"18 1.088941 Sex\n",
"19 1.337148 Age\n",
"20 1.331728 Education\n",
"21 1.544598 Income"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vif # No multicollinearity here"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "f5d00c03-6ecd-48a6-8dc4-dffda9919f78",
"metadata": {},
"outputs": [],
"source": [
"# We obtain the predictions on the test set\n",
"\n",
"predicted_probs = model.predict(test_X) # with the fitted model, the method .predict() outputs predictions for test data."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "2bee9500-1b5f-4d60-b2b4-3af73cf322df",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" 0 \n",
" Diabetes_binary \n",
" \n",
" \n",
" \n",
" \n",
" 18157 \n",
" 0.290604 \n",
" 0.0 \n",
" \n",
" \n",
" 36808 \n",
" 0.858143 \n",
" 1.0 \n",
" \n",
" \n",
" 8627 \n",
" 0.041202 \n",
" 0.0 \n",
" \n",
" \n",
" 37717 \n",
" 0.888710 \n",
" 1.0 \n",
" \n",
" \n",
" 32252 \n",
" 0.478378 \n",
" 0.0 \n",
" \n",
" \n",
" ... \n",
" ... \n",
" ... \n",
" \n",
" \n",
" 64399 \n",
" 0.821172 \n",
" 1.0 \n",
" \n",
" \n",
" 64005 \n",
" 0.175046 \n",
" 1.0 \n",
" \n",
" \n",
" 47158 \n",
" 0.856522 \n",
" 1.0 \n",
" \n",
" \n",
" 15484 \n",
" 0.655594 \n",
" 0.0 \n",
" \n",
" \n",
" 53889 \n",
" 0.477339 \n",
" 1.0 \n",
" \n",
" \n",
"
\n",
"
14139 rows × 2 columns
\n",
"
"
],
"text/plain": [
" 0 Diabetes_binary\n",
"18157 0.290604 0.0\n",
"36808 0.858143 1.0\n",
"8627 0.041202 0.0\n",
"37717 0.888710 1.0\n",
"32252 0.478378 0.0\n",
"... ... ...\n",
"64399 0.821172 1.0\n",
"64005 0.175046 1.0\n",
"47158 0.856522 1.0\n",
"15484 0.655594 0.0\n",
"53889 0.477339 1.0\n",
"\n",
"[14139 rows x 2 columns]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"check_predictions = pd.concat([predicted_probs,test_y],axis=1)\n",
"check_predictions # The output is in probability of being class 1 "
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "d255a319-4801-4363-867f-dc2fe484ea66",
"metadata": {},
"outputs": [],
"source": [
"threshold = 0.8\n",
"prediction = model.predict(test_X) # We need to convert to discrete class by applying a threshold\n",
"predicted_choice = (prediction > threshold).astype(int)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "1169582c-d3bb-4224-a060-44d0143e3f19",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" 0 \n",
" Diabetes_binary \n",
" \n",
" \n",
" \n",
" \n",
" 18157 \n",
" 0 \n",
" 0.0 \n",
" \n",
" \n",
" 36808 \n",
" 1 \n",
" 1.0 \n",
" \n",
" \n",
" 8627 \n",
" 0 \n",
" 0.0 \n",
" \n",
" \n",
" 37717 \n",
" 1 \n",
" 1.0 \n",
" \n",
" \n",
" 32252 \n",
" 0 \n",
" 0.0 \n",
" \n",
" \n",
" ... \n",
" ... \n",
" ... \n",
" \n",
" \n",
" 64399 \n",
" 1 \n",
" 1.0 \n",
" \n",
" \n",
" 64005 \n",
" 0 \n",
" 1.0 \n",
" \n",
" \n",
" 47158 \n",
" 1 \n",
" 1.0 \n",
" \n",
" \n",
" 15484 \n",
" 0 \n",
" 0.0 \n",
" \n",
" \n",
" 53889 \n",
" 0 \n",
" 1.0 \n",
" \n",
" \n",
"
\n",
"
14139 rows × 2 columns
\n",
"
"
],
"text/plain": [
" 0 Diabetes_binary\n",
"18157 0 0.0\n",
"36808 1 1.0\n",
"8627 0 0.0\n",
"37717 1 1.0\n",
"32252 0 0.0\n",
"... .. ...\n",
"64399 1 1.0\n",
"64005 0 1.0\n",
"47158 1 1.0\n",
"15484 0 0.0\n",
"53889 0 1.0\n",
"\n",
"[14139 rows x 2 columns]"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"check_discrete_predictions = pd.concat([predicted_choice,test_y],axis=1)\n",
"check_discrete_predictions"
]
},
{
"cell_type": "markdown",
"id": "e40cc78c-de03-4bc6-abb9-5d3fec422806",
"metadata": {},
"source": [
"Now we see some of the classification scores, which employ discrete classes (actual and predicted)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "62ffe1d3-b257-4418-b4be-499a70d91dc4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.6266355470683924"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import accuracy_score # Proportion of correct prediction (TP+TN) on total predition\n",
"accuracy_score(test_y, predicted_choice, normalize=True, sample_weight=None)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "8da66a23-8bdd-463a-bbe9-23c8a58dd0bb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8439024390243902"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import precision_score\n",
"\n",
"# Proportion of correctly classified positives on the total of instances classified as positives\n",
"\n",
"precision_score(test_y, predicted_choice)\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "e456b2a0-181a-4e35-9ce8-8cd1788cbcad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.31622609673790775"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import recall_score\n",
"\n",
"# Proportion of correctly classified positives on the total of actual positive instances\n",
"\n",
"recall_score(test_y, predicted_choice)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "17455902-5fe5-41f5-b6bd-564afdf1aa55",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9407997723068166"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Proportion of correctly classified negatives on the total of actual negative instances\n",
"\n",
"\n",
"recall_score(test_y, predicted_choice, pos_label = 0)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "5e58bea5-73dd-4262-9b85-02d64524e21f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# The ROC curve shows the trade-off between true positive rate and false positive rate. \n",
"# The less false positives you have to admit to have more true positives, the better your model.\n",
"\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.metrics import roc_curve\n",
"\n",
"fpr, tpr, _ = roc_curve(test_y, predicted_probs)\n",
"\n",
"#create ROC curve\n",
"plt.axline((0,0),(1,1), linestyle='dotted', color = 'red')\n",
"plt.plot(fpr,tpr)\n",
"plt.ylabel('True Positive Rate')\n",
"plt.xlabel('False Positive Rate')\n",
"plt.title('ROC curve')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "f458924d-73a6-421b-9acb-6968a4064882",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8246096668274371"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import roc_auc_score\n",
"# This computes the Area Under the ROC curve (AUROC, or AUC): it goes from 0.5 to 1, with 0.5 being random chance\n",
"roc_auc_score(test_y, predicted_probs)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "70ae4724-4986-47b5-b7c5-4e53cf59013b",
"metadata": {},
"outputs": [],
"source": [
"# Now, as an experiment, we try the same pipeline but with unbalanced dataset. We remove 90% of diabetes cases at random.\n",
"\n",
"no_diabetes = df.loc[df['Diabetes_binary'] == 0] # Here only class-0 observations are selected\n",
"diabetes = df.loc[df['Diabetes_binary'] == 1] # Here only class-1 observations are selected"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "d78802d3-525c-449f-b013-5872404ba9bc",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" \n",
" Diabetes_binary \n",
" HighBP \n",
" HighChol \n",
" CholCheck \n",
" BMI \n",
" Smoker \n",
" Stroke \n",
" HeartDiseaseorAttack \n",
" PhysActivity \n",
" Fruits \n",
" Veggies \n",
" HvyAlcoholConsump \n",
" AnyHealthcare \n",
" NoDocbcCost \n",
" GenHlth \n",
" MentHlth \n",
" PhysHlth \n",
" DiffWalk \n",
" Sex \n",
" Age \n",
" Education \n",
" Income \n",
" \n",
" \n",
" \n",
" \n",
" 52325 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 33.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 4.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 9.0 \n",
" 4.0 \n",
" 3.0 \n",
" \n",
" \n",
" 48393 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 23.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 4.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 13.0 \n",
" 4.0 \n",
" 1.0 \n",
" \n",
" \n",
" 37397 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 34.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 5.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 10.0 \n",
" 5.0 \n",
" 7.0 \n",
" \n",
" \n",
" 43839 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 25.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 3.0 \n",
" 10.0 \n",
" 2.0 \n",
" 0.0 \n",
" 0.0 \n",
" 12.0 \n",
" 4.0 \n",
" 3.0 \n",
" \n",
" \n",
" 50465 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 43.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 3.0 \n",
" 21.0 \n",
" 30.0 \n",
" 1.0 \n",
" 0.0 \n",
" 9.0 \n",
" 4.0 \n",
" 2.0 \n",
" \n",
" \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" ... \n",
" \n",
" \n",
" 64590 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 39.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 4.0 \n",
" 3.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 4.0 \n",
" 5.0 \n",
" 3.0 \n",
" \n",
" \n",
" 40081 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 29.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 3.0 \n",
" 15.0 \n",
" 15.0 \n",
" 1.0 \n",
" 0.0 \n",
" 6.0 \n",
" 5.0 \n",
" 5.0 \n",
" \n",
" \n",
" 49312 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 29.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 4.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 11.0 \n",
" 4.0 \n",
" 4.0 \n",
" \n",
" \n",
" 48587 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 26.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 2.0 \n",
" 0.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 9.0 \n",
" 6.0 \n",
" 5.0 \n",
" \n",
" \n",
" 65184 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 25.0 \n",
" 1.0 \n",
" 0.0 \n",
" 0.0 \n",
" 1.0 \n",
" 0.0 \n",
" 1.0 \n",
" 1.0 \n",
" 1.0 \n",
" 0.0 \n",
" 2.0 \n",
" 0.0 \n",
" 5.0 \n",
" 1.0 \n",
" 0.0 \n",
" 11.0 \n",
" 4.0 \n",
" 6.0 \n",
" \n",
" \n",
"
\n",
"
3500 rows × 22 columns
\n",
"
"
],
"text/plain": [
" Diabetes_binary HighBP HighChol CholCheck BMI Smoker Stroke \\\n",
"52325 1.0 1.0 0.0 1.0 33.0 0.0 0.0 \n",
"48393 1.0 1.0 1.0 1.0 23.0 0.0 0.0 \n",
"37397 1.0 1.0 1.0 1.0 34.0 1.0 0.0 \n",
"43839 1.0 1.0 0.0 1.0 25.0 0.0 0.0 \n",
"50465 1.0 1.0 1.0 1.0 43.0 0.0 0.0 \n",
"... ... ... ... ... ... ... ... \n",
"64590 1.0 1.0 1.0 1.0 39.0 0.0 0.0 \n",
"40081 1.0 0.0 1.0 1.0 29.0 1.0 0.0 \n",
"49312 1.0 1.0 1.0 1.0 29.0 1.0 0.0 \n",
"48587 1.0 0.0 1.0 1.0 26.0 1.0 0.0 \n",
"65184 1.0 0.0 1.0 1.0 25.0 1.0 0.0 \n",
"\n",
" HeartDiseaseorAttack PhysActivity Fruits Veggies HvyAlcoholConsump \\\n",
"52325 0.0 0.0 0.0 0.0 0.0 \n",
"48393 0.0 0.0 1.0 1.0 1.0 \n",
"37397 1.0 0.0 0.0 1.0 0.0 \n",
"43839 0.0 1.0 0.0 1.0 0.0 \n",
"50465 0.0 1.0 0.0 0.0 0.0 \n",
"... ... ... ... ... ... \n",
"64590 0.0 1.0 1.0 1.0 0.0 \n",
"40081 0.0 1.0 0.0 1.0 0.0 \n",
"49312 0.0 1.0 1.0 1.0 0.0 \n",
"48587 0.0 1.0 1.0 1.0 0.0 \n",
"65184 0.0 1.0 0.0 1.0 1.0 \n",
"\n",
" AnyHealthcare NoDocbcCost GenHlth MentHlth PhysHlth DiffWalk Sex \\\n",
"52325 1.0 0.0 4.0 0.0 0.0 1.0 1.0 \n",
"48393 1.0 1.0 4.0 0.0 0.0 0.0 0.0 \n",
"37397 1.0 0.0 5.0 0.0 0.0 1.0 1.0 \n",
"43839 1.0 0.0 3.0 10.0 2.0 0.0 0.0 \n",
"50465 1.0 1.0 3.0 21.0 30.0 1.0 0.0 \n",
"... ... ... ... ... ... ... ... \n",
"64590 1.0 1.0 4.0 3.0 1.0 1.0 0.0 \n",
"40081 1.0 1.0 3.0 15.0 15.0 1.0 0.0 \n",
"49312 1.0 0.0 4.0 0.0 0.0 0.0 0.0 \n",
"48587 1.0 0.0 2.0 0.0 0.0 0.0 1.0 \n",
"65184 1.0 0.0 2.0 0.0 5.0 1.0 0.0 \n",
"\n",
" Age Education Income \n",
"52325 9.0 4.0 3.0 \n",
"48393 13.0 4.0 1.0 \n",
"37397 10.0 5.0 7.0 \n",
"43839 12.0 4.0 3.0 \n",
"50465 9.0 4.0 2.0 \n",
"... ... ... ... \n",
"64590 4.0 5.0 3.0 \n",
"40081 6.0 5.0 5.0 \n",
"49312 11.0 4.0 4.0 \n",
"48587 9.0 6.0 5.0 \n",
"65184 11.0 4.0 6.0 \n",
"\n",
"[3500 rows x 22 columns]"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"diabetes.sample(3500) # We sample just 3500 observation with diagnosis diabetes, around 10% of the total"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "d32c40ff-4831-4ef7-8125-72992f7ea4d7",
"metadata": {},
"outputs": [],
"source": [
"diabetes_reduced = diabetes.sample(3500)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "d6a83dce-1998-4e9a-9f3e-a8c16c1d6735",
"metadata": {},
"outputs": [],
"source": [
"unbalanced_df = pd.concat([no_diabetes,diabetes_reduced], axis = 0) # we merge the sample with original, class 0 data\n",
"unbalanced_df = unbalanced_df.sample(frac = 1) # we shuffle rows to ensure randomness"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "3a8c142a-fbcd-4fa7-a72e-c5cfd262003e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.09902110564137384"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"unbalanced_df.loc[unbalanced_df['Diabetes_binary'] == 1].shape[0]/len(unbalanced_df.loc[unbalanced_df['Diabetes_binary'] == 0])\n",
"\n",
"# The proportion of class 1 now is 10%"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "39e5d9bc-80da-40e4-aa68-47087b3df1be",
"metadata": {},
"outputs": [],
"source": [
"# Now we split the new dataset as we did before\n",
"\n",
"X = unbalanced_df.drop('Diabetes_binary', axis = 1)\n",
"y = unbalanced_df['Diabetes_binary']\n",
"\n",
"X = sm.add_constant(X)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "ac7ae7a1-c925-4c42-885b-c0d2e39abde3",
"metadata": {},
"outputs": [],
"source": [
"train_X, test_X, train_y, test_y = train_test_split(X, y, test_size = 0.20 ) # We preserve the proportion of response classes in the splits"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "ab99ba3d-8394-4326-8714-1b3312b6399b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(7770, 22)"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_X.shape"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "8912d99d-5c72-4494-beb1-ee9f98136215",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Optimization terminated successfully.\n",
" Current function value: 0.244102\n",
" Iterations 8\n"
]
}
],
"source": [
"model = sm.Logit(train_y, train_X).fit()"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "87fdf335-ee80-4303-a6fb-74196a4ee962",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"Logit Regression Results \n",
"\n",
" Dep. Variable: Diabetes_binary No. Observations: 31076 \n",
" \n",
"\n",
" Model: Logit Df Residuals: 31054 \n",
" \n",
"\n",
" Method: MLE Df Model: 21 \n",
" \n",
"\n",
" Date: Mon, 03 Apr 2023 Pseudo R-squ.: 0.1910 \n",
" \n",
"\n",
" Time: 14:35:04 Log-Likelihood: -7585.7 \n",
" \n",
"\n",
" converged: True LL-Null: -9376.6 \n",
" \n",
"\n",
" Covariance Type: nonrobust LLR p-value: 0.000 \n",
" \n",
"
\n",
"\n",
"\n",
" coef std err z P>|z| [0.025 0.975] \n",
" \n",
"\n",
" const -8.3860 0.318 -26.346 0.000 -9.010 -7.762 \n",
" \n",
"\n",
" HighBP 0.7517 0.051 14.862 0.000 0.653 0.851 \n",
" \n",
"\n",
" HighChol 0.5900 0.046 12.730 0.000 0.499 0.681 \n",
" \n",
"\n",
" CholCheck 1.2239 0.230 5.316 0.000 0.773 1.675 \n",
" \n",
"\n",
" BMI 0.0576 0.003 19.975 0.000 0.052 0.063 \n",
" \n",
"\n",
" Smoker -0.0380 0.045 -0.851 0.395 -0.126 0.050 \n",
" \n",
"\n",
" Stroke 0.0740 0.084 0.884 0.377 -0.090 0.238 \n",
" \n",
"\n",
" HeartDiseaseorAttack 0.2613 0.059 4.455 0.000 0.146 0.376 \n",
" \n",
"\n",
" PhysActivity -0.0559 0.049 -1.152 0.249 -0.151 0.039 \n",
" \n",
"\n",
" Fruits -0.0224 0.046 -0.486 0.627 -0.113 0.068 \n",
" \n",
"\n",
" Veggies 0.0156 0.054 0.288 0.773 -0.090 0.121 \n",
" \n",
"\n",
" HvyAlcoholConsump -0.7854 0.133 -5.892 0.000 -1.047 -0.524 \n",
" \n",
"\n",
" AnyHealthcare 0.1871 0.119 1.571 0.116 -0.046 0.421 \n",
" \n",
"\n",
" NoDocbcCost -0.0838 0.079 -1.064 0.287 -0.238 0.071 \n",
" \n",
"\n",
" GenHlth 0.5191 0.028 18.779 0.000 0.465 0.573 \n",
" \n",
"\n",
" MentHlth -0.0049 0.003 -1.679 0.093 -0.011 0.001 \n",
" \n",
"\n",
" PhysHlth -0.0049 0.003 -1.871 0.061 -0.010 0.000 \n",
" \n",
"\n",
" DiffWalk 0.1807 0.057 3.183 0.001 0.069 0.292 \n",
" \n",
"\n",
" Sex 0.2738 0.045 6.039 0.000 0.185 0.363 \n",
" \n",
"\n",
" Age 0.1218 0.010 12.790 0.000 0.103 0.140 \n",
" \n",
"\n",
" Education -0.0351 0.024 -1.488 0.137 -0.081 0.011 \n",
" \n",
"\n",
" Income -0.0332 0.012 -2.730 0.006 -0.057 -0.009 \n",
" \n",
"
"
],
"text/plain": [
"\n",
"\"\"\"\n",
" Logit Regression Results \n",
"==============================================================================\n",
"Dep. Variable: Diabetes_binary No. Observations: 31076\n",
"Model: Logit Df Residuals: 31054\n",
"Method: MLE Df Model: 21\n",
"Date: Mon, 03 Apr 2023 Pseudo R-squ.: 0.1910\n",
"Time: 14:35:04 Log-Likelihood: -7585.7\n",
"converged: True LL-Null: -9376.6\n",
"Covariance Type: nonrobust LLR p-value: 0.000\n",
"========================================================================================\n",
" coef std err z P>|z| [0.025 0.975]\n",
"----------------------------------------------------------------------------------------\n",
"const -8.3860 0.318 -26.346 0.000 -9.010 -7.762\n",
"HighBP 0.7517 0.051 14.862 0.000 0.653 0.851\n",
"HighChol 0.5900 0.046 12.730 0.000 0.499 0.681\n",
"CholCheck 1.2239 0.230 5.316 0.000 0.773 1.675\n",
"BMI 0.0576 0.003 19.975 0.000 0.052 0.063\n",
"Smoker -0.0380 0.045 -0.851 0.395 -0.126 0.050\n",
"Stroke 0.0740 0.084 0.884 0.377 -0.090 0.238\n",
"HeartDiseaseorAttack 0.2613 0.059 4.455 0.000 0.146 0.376\n",
"PhysActivity -0.0559 0.049 -1.152 0.249 -0.151 0.039\n",
"Fruits -0.0224 0.046 -0.486 0.627 -0.113 0.068\n",
"Veggies 0.0156 0.054 0.288 0.773 -0.090 0.121\n",
"HvyAlcoholConsump -0.7854 0.133 -5.892 0.000 -1.047 -0.524\n",
"AnyHealthcare 0.1871 0.119 1.571 0.116 -0.046 0.421\n",
"NoDocbcCost -0.0838 0.079 -1.064 0.287 -0.238 0.071\n",
"GenHlth 0.5191 0.028 18.779 0.000 0.465 0.573\n",
"MentHlth -0.0049 0.003 -1.679 0.093 -0.011 0.001\n",
"PhysHlth -0.0049 0.003 -1.871 0.061 -0.010 0.000\n",
"DiffWalk 0.1807 0.057 3.183 0.001 0.069 0.292\n",
"Sex 0.2738 0.045 6.039 0.000 0.185 0.363\n",
"Age 0.1218 0.010 12.790 0.000 0.103 0.140\n",
"Education -0.0351 0.024 -1.488 0.137 -0.081 0.011\n",
"Income -0.0332 0.012 -2.730 0.006 -0.057 -0.009\n",
"========================================================================================\n",
"\"\"\""
]
},
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.summary() # Same as before, note however how less coefficients are now significant, due to the fact that we have less observations"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "134357d2-7f70-49a4-910c-c9f164c235d6",
"metadata": {},
"outputs": [],
"source": [
"predicted_probs = model.predict(test_X)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "ef93bfa8-19a3-43f4-b182-03de73daa17b",
"metadata": {},
"outputs": [],
"source": [
"threshold = 0.5\n",
"prediction = model.predict(test_X)\n",
"predicted_choice = (prediction > threshold).astype(int)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "467a7d93-ab7e-4095-9b5a-355e74866a05",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9073359073359073"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_score(test_y, predicted_choice) # Accuracy appears to be higher; in reality, though, this is due to class imbalancement"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "230cd7c7-4783-4fe2-ba00-d08ae044bb4d",
"metadata": {},
"outputs": [],
"source": [
"dumb_predictions_zero = np.zeros(test_y.shape) # I create a vector of just zeros"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "a5d122ef-df9d-4aca-87ca-28169fd5c8ad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9081081081081082"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_score(test_y, dumb_predictions_zero)"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "725cc099-d376-4cf9-8c81-90e37467ab55",
"metadata": {},
"outputs": [],
"source": [
"dumb_predictions_one = np.ones(test_y.shape) # I create a vector of just ones"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "6a7fa3f5-434c-476a-8f63-1926d3df6fa1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0918918918918919"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"accuracy_score(test_y, dumb_predictions_one)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "8cbbc1ca-2c4f-44ce-9222-fbe2ef7ed5b1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0918918918918919"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.sum(test_y == 1)/len(test_y)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "4928ee93-d640-4173-82ab-28cf802d899a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.46511627906976744 0.056022408963585436 0.9934807256235828\n"
]
}
],
"source": [
"# These are precision, sensitivity and specificity with the unbalanced dataset (it mainly predicts 0)\n",
"\n",
"print(precision_score(test_y, predicted_choice),\n",
"recall_score(test_y, predicted_choice),\n",
"recall_score(test_y, predicted_choice, pos_label = 0))"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "31f4b8f7-9390-4c27-b184-371c20bdea33",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0 0.056022408963585436 0.9934807256235828\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/alex/anaconda3/lib/python3.9/site-packages/sklearn/metrics/_classification.py:1327: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.\n",
" _warn_prf(average, modifier, msg_start, len(result))\n"
]
}
],
"source": [
"# These are precision, sensitivity and specificity with 'dumb' predictions zero\n",
"\n",
"print(precision_score(test_y, dumb_predictions_zero),\n",
"recall_score(test_y, predicted_choice),\n",
"recall_score(test_y, predicted_choice, pos_label = 0))"
]
},
{
"cell_type": "code",
"execution_count": 91,
"id": "768df0fb-1a4f-4a95-88eb-cdd975c5fc78",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0918918918918919 0.056022408963585436 0.9934807256235828\n"
]
}
],
"source": [
"# These are precision, sensitivity and specificity with 'dumb' predictions one\n",
"\n",
"print(precision_score(test_y, dumb_predictions_one),\n",
"recall_score(test_y, predicted_choice),\n",
"recall_score(test_y, predicted_choice, pos_label = 0))"
]
},
{
"cell_type": "code",
"execution_count": 92,
"id": "73103303-3d41-4932-a6bb-cb0cfa1c734f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.8250714968527095 0.5\n"
]
}
],
"source": [
"print(roc_auc_score(test_y, predicted_probs),roc_auc_score(test_y, dumb_predictions))"
]
},
{
"cell_type": "code",
"execution_count": 93,
"id": "6a511e73-fb48-49ff-9bc6-2f7845b175f6",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"fpr, tpr, _ = roc_curve(test_y, predicted_probs)\n",
"\n",
"#create ROC curve\n",
"\n",
"plt.axline((0,0),(1,1), linestyle='dotted', color = 'red')\n",
"plt.plot(fpr,tpr)\n",
"plt.ylabel('True Positive Rate')\n",
"plt.xlabel('False Positive Rate')\n",
"plt.title('Unbalanced dataset ROC curve')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78e801f1-42ad-4015-81c8-67d67a2e4d53",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}