{ "cells": [ { "cell_type": "code", "execution_count": 88, "id": "4213b546-ae18-4c2b-8500-d964c4fe6f3a", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "df_diabetes = pd.read_csv('~/diabetes_binary.csv')" ] }, { "cell_type": "code", "execution_count": 89, "id": "87929a3c-de6c-45f8-b458-28c2528bcef0", "metadata": {}, "outputs": [], "source": [ "X = df_diabetes.drop('Diabetes_binary', axis = 1)\n", "y = df_diabetes['Diabetes_binary']\n", "\n", "from sklearn.model_selection import train_test_split\n", "\n", "train_X, test_X, train_y, test_y = train_test_split(X, y, test_size = 0.20)" ] }, { "cell_type": "code", "execution_count": 90, "id": "7d185cbf-cb2f-4a41-ba41-ec29599b985c", "metadata": {}, "outputs": [], "source": [ "\n", "from sklearn.linear_model import LogisticRegression\n", "\n", "# instantiate the model (using the default parameters)\n", "logreg = LogisticRegression(random_state=16, max_iter = 500)\n", "\n", "# fit the model with data\n", "logreg.fit(train_X, train_y)\n", "\n", "y_pred = logreg.predict_proba(test_X)" ] }, { "cell_type": "code", "execution_count": 91, "id": "c8d24bb1-8579-45e8-a3ef-ee3ca5e91eb4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0.32650398, 0.67349602],\n", " [0.348918 , 0.651082 ],\n", " [0.49500037, 0.50499963],\n", " ...,\n", " [0.22088228, 0.77911772],\n", " [0.21823599, 0.78176401],\n", " [0.58515343, 0.41484657]])" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_pred" ] }, { "cell_type": "code", "execution_count": 92, "id": "ac3ccfa0-c011-4916-9bea-1cdaec466cd3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8273611694540879" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.metrics import roc_auc_score\n", "\n", "roc_auc_score(test_y, y_pred[:,1])" ] }, { "cell_type": "code", "execution_count": 93, "id": "073cec42-fbc7-4fde-94d8-6d97f4db4b27", "metadata": {}, "outputs": [], "source": [ "def gini(x):\n", " total = 0\n", " for i, xi in enumerate(x[:-1], 1):\n", " total += np.sum(np.abs(xi - x[i:]))\n", " return total / (len(x)**2 * np.mean(x))" ] }, { "cell_type": "code", "execution_count": 94, "id": "4d7f8331-5162-491b-bf21-af611adfe80c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.322536007721004" ] }, "execution_count": 94, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gini(y_pred[:,1]) # Lorenz Zonoid of full model" ] }, { "cell_type": "code", "execution_count": 95, "id": "7dcdfb93-30e4-4a31-bebd-b771f81558b6", "metadata": {}, "outputs": [], "source": [ "lorenz_zonoids = {}\n", "for i in train_X.columns:\n", " logreg.fit(train_X[[i]], train_y)\n", " lorenz_zonoids[i] = (round(gini(logreg.predict_proba(test_X[[i]])[:,1]),4))" ] }, { "cell_type": "code", "execution_count": 96, "id": "4556e8c1-5ad1-4924-872e-2a77079e0f76", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'HighBP': 0.1898,\n", " 'HighChol': 0.1438,\n", " 'CholCheck': 0.0187,\n", " 'BMI': 0.1653,\n", " 'Smoker': 0.0433,\n", " 'Stroke': 0.0319,\n", " 'HeartDiseaseorAttack': 0.076,\n", " 'PhysActivity': 0.0737,\n", " 'Fruits': 0.0268,\n", " 'Veggies': 0.0342,\n", " 'HvyAlcoholConsump': 0.0184,\n", " 'AnyHealthcare': 0.0054,\n", " 'NoDocbcCost': 0.0112,\n", " 'GenHlth': 0.2269,\n", " 'MentHlth': 0.0333,\n", " 'PhysHlth': 0.0977,\n", " 'DiffWalk': 0.1186,\n", " 'Sex': 0.0223,\n", " 'Age': 0.157,\n", " 'Education': 0.0931,\n", " 'Income': 0.1252}" ] }, "execution_count": 96, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lorenz_zonoids" ] }, { "cell_type": "code", "execution_count": 97, "id": "bec41479-8067-40c3-8513-5e6fc0e02fe8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'GenHlth': 0.2269,\n", " 'HighBP': 0.1898,\n", " 'BMI': 0.1653,\n", " 'Age': 0.157,\n", " 'HighChol': 0.1438,\n", " 'Income': 0.1252,\n", " 'DiffWalk': 0.1186,\n", " 'PhysHlth': 0.0977,\n", " 'Education': 0.0931,\n", " 'HeartDiseaseorAttack': 0.076,\n", " 'PhysActivity': 0.0737,\n", " 'Smoker': 0.0433,\n", " 'Veggies': 0.0342,\n", " 'MentHlth': 0.0333,\n", " 'Stroke': 0.0319,\n", " 'Fruits': 0.0268,\n", " 'Sex': 0.0223,\n", " 'CholCheck': 0.0187,\n", " 'HvyAlcoholConsump': 0.0184,\n", " 'NoDocbcCost': 0.0112,\n", " 'AnyHealthcare': 0.0054}" ] }, "execution_count": 97, "metadata": {}, "output_type": "execute_result" } ], "source": [ "{k: v for k, v in sorted(lorenz_zonoids.items(), key=lambda item: item[1], reverse = True)}" ] }, { "cell_type": "markdown", "id": "72d0b782-648b-44f4-a338-fb5c4179138b", "metadata": {}, "source": [ "You can see the ordered set is almost the same we got with shapley values; this shows the methodology is robust" ] }, { "cell_type": "code", "execution_count": 98, "id": "682616bc-340d-4280-a7a0-3a7fa478e5b3", "metadata": {}, "outputs": [], "source": [ "# Model Selection" ] }, { "cell_type": "code", "execution_count": 99, "id": "d371b16c-4a62-4a63-bf42-47bcc8a9050a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.2774" ] }, "execution_count": 99, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model with first two higher marginal lorenz contribution\n", "\n", "logreg.fit(train_X[['GenHlth','HighBP']], train_y)\n", "round(gini(logreg.predict_proba(test_X[['GenHlth','HighBP']])[:,1]),4)" ] }, { "cell_type": "code", "execution_count": 100, "id": "0b8e1bf2-9d68-4f22-8481-d778aee5e0c5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7777009173684001" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "roc_auc_score(test_y, logreg.predict_proba(test_X[['GenHlth','HighBP']])[:,1])" ] }, { "cell_type": "code", "execution_count": 101, "id": "fe05a5bd-94d7-4afc-b062-aebf1caabd0c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.2938" ] }, "execution_count": 101, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model with first three higher marginal lorenz contribution\n", "\n", "logreg.fit(train_X[['GenHlth','HighBP','BMI']], train_y)\n", "round(gini(logreg.predict_proba(test_X[['GenHlth','HighBP','BMI']])[:,1]),4)" ] }, { "cell_type": "code", "execution_count": 102, "id": "23335f4d-4e4d-489b-9c25-723273ea04fe", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7977581520787433" ] }, "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], "source": [ "roc_auc_score(test_y, logreg.predict_proba(test_X[['GenHlth','HighBP','BMI']])[:,1])" ] }, { "cell_type": "code", "execution_count": 103, "id": "4ccb3754-6920-4881-94f3-70950ab62eca", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.3099" ] }, "execution_count": 103, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model with first four higher marginal lorenz contribution\n", "\n", "logreg.fit(train_X[['GenHlth','HighBP','BMI', 'Age']], train_y)\n", "round(gini(logreg.predict_proba(test_X[['GenHlth','HighBP','BMI','Age']])[:,1]),4)" ] }, { "cell_type": "code", "execution_count": 104, "id": "34c6c9d3-872d-4295-8490-76c34ca31d74", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8146446931061582" ] }, "execution_count": 104, "metadata": {}, "output_type": "execute_result" } ], "source": [ "roc_auc_score(test_y, logreg.predict_proba(test_X[['GenHlth','HighBP','BMI','Age']])[:,1])" ] }, { "cell_type": "code", "execution_count": 105, "id": "beda9e17-b4b4-4dec-b2e9-d5b801e93a1a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.3155" ] }, "execution_count": 105, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model with first five higher marginal lorenz contribution\n", "\n", "logreg.fit(train_X[['GenHlth','HighBP','BMI', 'Age','HighChol']], train_y)\n", "round(gini(logreg.predict_proba(test_X[['GenHlth','HighBP','BMI','Age','HighChol']])[:,1]),4)" ] }, { "cell_type": "code", "execution_count": 106, "id": "5a17b793-e586-422e-bc0e-d6c4511b5332", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8206829060159185" ] }, "execution_count": 106, "metadata": {}, "output_type": "execute_result" } ], "source": [ "roc_auc_score(test_y, logreg.predict_proba(test_X[['GenHlth','HighBP','BMI','Age','HighChol']])[:,1])" ] }, { "cell_type": "code", "execution_count": 107, "id": "17d59295-62be-45d7-8dba-ace43f99b051", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.3164" ] }, "execution_count": 107, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model with first six higher marginal lorenz contribution\n", "\n", "logreg.fit(train_X[['GenHlth','HighBP','BMI', 'Age','HighChol','Income']], train_y)\n", "round(gini(logreg.predict_proba(test_X[['GenHlth','HighBP','BMI','Age','HighChol', 'Income']])[:,1]),4)\n" ] }, { "cell_type": "code", "execution_count": 108, "id": "00f304f5-4b80-4438-9a1a-352a4a8b5da6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8218839763824517" ] }, "execution_count": 108, "metadata": {}, "output_type": "execute_result" } ], "source": [ "roc_auc_score(test_y, logreg.predict_proba(test_X[['GenHlth','HighBP','BMI','Age','HighChol', 'Income']])[:,1])" ] }, { "cell_type": "code", "execution_count": 109, "id": "f578d5fc-004b-4e7f-913c-50ed2f779b44", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.3165" ] }, "execution_count": 109, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Model with first seven higher marginal lorenz contribution\n", "\n", "logreg.fit(train_X[['GenHlth','HighBP','BMI', 'Age','HighChol','Income','DiffWalk']], train_y)\n", "round(gini(logreg.predict_proba(test_X[['GenHlth','HighBP','BMI','Age','HighChol', 'Income','DiffWalk']])[:,1]),4)\n" ] }, { "cell_type": "code", "execution_count": 110, "id": "8e404a17-12d6-49d1-81eb-eee27a32992f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8220086963778568" ] }, "execution_count": 110, "metadata": {}, "output_type": "execute_result" } ], "source": [ "roc_auc_score(test_y, logreg.predict_proba(test_X[['GenHlth','HighBP','BMI','Age','HighChol', 'Income','DiffWalk']])[:,1])" ] }, { "cell_type": "code", "execution_count": 111, "id": "3886d2f8-5cff-4532-8447-36b9c6753b45", "metadata": {}, "outputs": [], "source": [ "# pip install shapley_lz" ] }, { "cell_type": "code", "execution_count": null, "id": "58255824-54e6-4044-b583-bde3bc5b1e23", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/alex/anaconda3/lib/python3.9/site-packages/shapley_lz/explainer/shapley_lz.py:37: UserWarning: A background dataset larger than 50, may cause prohibitive long runtime. Consider using a sample of maximally 100 observations.\n", " warnings.warn('A background dataset larger than 50, may cause prohibitive long runtime. Consider using a sample of maximally 100 observations.')\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "Enter 's' to sample 50 observations and enter 'c', to continue with current background dataset s\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 50%|██████████████████████▌ | 2/4 [01:30<01:30, 45.06s/it]" ] } ], "source": [ "import numpy as np\n", "from sklearn.ensemble import RandomForestClassifier as rf_class\n", "from sklearn.datasets import make_classification as gen_data\n", "from shapley_lz.explainer.shapley_lz import ShapleyLorenzShare as slShare\n", "\n", "# Simple example w/o train-test splitting thus same covariance matrix used and only first 100 observations explained\n", "# Generate data\n", "N = 1000 # number of observations\n", "p = 4 # number of features\n", "X, y = gen_data(n_samples = N, n_features = 4, n_informative = 4, n_redundant = 0)\n", "\n", "# Train model\n", "model = rf_class()\n", "model.fit(X,y)\n", "\n", "# Compute Shapley Lorenz Zonoid shares\n", "slz = slShare(model.predict_proba, X, y)\n", "slz_values = slz.shapleyLorenz_val(X, y, class_prob = True, pred_out = 'predict_proba')\n", "\n", "# Plot\n", "# (Bar chart automatically plots in increasing order of SLZ value)\n", "slz.slz_plots(slz_values[0])" ] }, { "cell_type": "code", "execution_count": null, "id": "3f41fe8d-04fd-40b5-9104-c13de08e5270", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "527300e1-0c5b-41a7-852b-fb9e62b579ce", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "ece68421-93a7-45ea-8ef5-27d00f543d2a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.15" } }, "nbformat": 4, "nbformat_minor": 5 }