Resumen del libro de Molnar
Estos métodos se usan para sacar explicaciones de un modelo \(M\) ya entrenado independientemente del algoritmo utilizado. Es decir, \(M\) es un modelo (random forest, KNN etc.) ya entrenado con un dataset \(D\) que funciona como una caja negra haciendo predicciones sobre datos, y queremos encontrar métodos genéricos para sacar explicaciones de cómo \(M\) llega a esas predicciones, o bien ser capaces de predecirlas.
Estos métodos pueden ser:
Por lo que he visto, existen 3 grandes paradigmas para poder explicar un modelo:
Entender el efecto marginal de una variable en la predicción de un modelo. Es decir, visualizar cómo cambia la predicción del modelo si cambiamos el valor de únicamente una variable.
Feature Importance
For Dummies
La idea es dibujar una función. El eje \(x\) será todos los posibles valores para un atributo \(S\), e \(y\) el valor promedio predicho por el modelo si a cada valor del dataset le cambiamos el atributo \(S\) por el valor \(x\).
Queremos calcular cómo cambia la predicción (\(\hat{f}_s\)) variando solo la variable \(S\).
\[\hat{f}_S(x_S) = E_{X_C}[\hat{f}(x_S, X_C)] = \int \hat{f}(x_S, x_C) dP(x_C)\]
En la práctica, se calcula ajustando el valor de la variable \(S\) a \(x_S\) y dejando el resto de las variables constantes a lo largo del dataset. Calculamos el valor promedio de la predicción para cada valor de \(x_S\) y pintamos la función.
\[\hat{f}_S(x_S) = \frac{1}{n} \sum_{i=1}^{n} \hat{f}(x_S, x_{C_i})\]
Se puede hacer un PDP con 2 atributos, lo cual es capaz de mostrar la relación entre atributos. Más dimensiones no serían visualizables.
Esto se lleva a cabo de la misma forma. Para cada par de posibles valores para los 2 atributos \(S_1\) y \(S_2\), se modifican los valores de cada uno de los datos en el dataset, y se promedia el resultado de la predicción. Esto se dibuja en un mapa de colores.
Esto permite ver la relación entre los atributos.
https://christophm.github.io/interpretable-ml-book/images/pdp-bike-1.jpeg
https://christophm.github.io/interpretable-ml-book/images/pdp-cervical-2d-1.jpeg
Alternativa rápida de PDP que reduce la necesidad de independencia entre variables y calcula un valor acumulado.
Feature Importance
En PDP cambiábamos el atributo \(S\) de cada dato del dataset por cada \(x_S\). En este caso solo lo haremos para datos cuya \(S\) original esté dentro de un rango de \(x_S\), reduciendo datos imposibles.
Partimos el rango de la variable \(S\) en \(N\) vecindades, y calculamos PDP acumulado para cada una de ellas.
\[\hat{f}_{j,ALE}(x_S) = \sum_{k=1}^{k_j(x)} \frac{1}{n_j(k)} \sum_{i:x_j^{(i)} \in N_j(k)} [ \hat{f}(z_{k,j},x_{-j}^{(i)}) - \hat{f}(z_{k-1,j},x_{-j}^{(i)}) ]\]
https://christophm.github.io/interpretable-ml-book/images/ale-bike-1.jpeg
https://christophm.github.io/interpretable-ml-book/images/ale-bike-2d-1.jpeg
La interpretación para 2 atributos puede ser complicada. Un dato alto de ALE no implica que el valor real sea más alto, si no que cuando los 2 atributos concurren en la misma zona el valor es más alto que si estuvieran separados. Pero no quiere decir que el valor sea más alto per se, porque los atributos por separado pueden hacer tender los datos hacia abajo.
Los atributos pueden intearctuar entre sí haciendo que el efecto de que ambos ocurran a la vez sea diferente al efecto de que ocurran por separado. Queremos estudiar la fuerza de la interacción entre pares de atributos, y de cada atributo con el resto.
Important
No vamos a estudiar las correlaciones o dependencias, si no las aportaciones de las interacciones a la predicción.
Feature Importance
For Dummies
Vamos a medir la fuerza de la interacción entre 2 atributos.
Si dos atributos no interactúan, su PDfunction es:
\[PD_{jk}(x_j, x_k) = PD_j(x_j) + PD_k(x_k)\]
Si un atributo no tiene ninguna interacción con ningún otro atributo, la función de predicción se puede escribir como:
\[\hat{f}(x) = PD_j(x_j) + PD_{-j}(x_{-j})\]
Siendo \(PD_{-j}(x_{-j})\) la función de predicción para el resto de atributos distintos a \(j\).
Medimos la fuerza de la interacción entre 2 atributos \(j\) y \(k\) como:
\[H^2_{jk} = \frac{\sum_{i=1}^n [ PD_{jk}(x_j^(i), x_k^(i)) - PD_j(x_j^(i)) - PD_k(x_k^(i)) ]^2}{\sum_{i=1}^n PD^2_{jk}(x_j^(i), x_k^(i))}\]
Medimos la fuerza de la interacción del atributo \(j\) con el resto de atributos como:
\[H^2_{j} = \frac{\sum_{i=1}^n [ \hat{f}(x^(i)) - PD_j(x_j^(i)) - PD_{-j}(x_{-j}^(i)) ]^2}{\sum_{i=1}^n \hat{f}^2(x^(i))}\]
https://christophm.github.io/interpretable-ml-book/images/interaction2-cervical-age-1.png
https://christophm.github.io/interpretable-ml-book/images/interaction-cervical-1.png
Si entendemos el modelo \(M\) como una función de un espacio vectorial a otro, podemos descomponerla en una suma de funciones. Cada una de estas funciones representará, o bien un valor constante, o bien un valor en función de cada uno de los atributos, o bien un valor en función de un subgrupo de atributos.
Surrogate Model
For Dummies
Intentamos descomponer la función de predicción en una suma de funciones con permutaciones de subgrupos de atributos.
Queremos encontrar cada una de las sub-funciones que componen la función de predicción.
Para un ejemplo con 2 atributos sería:
\[\hat{f}(x_1, x_2) = \hat{f}_0 + \hat{f}_1(x_1) + \hat{f}_2(x_2) + \hat{f}_{1,2}(x_1, x_2)\]
Definiendo la función de predicción como:
\[\hat{f} : \mathbb{R}^p \rightarrow \mathbb{R}\]
Definimos la descomposición en funciones como:
\[\hat{f}(x) = \sum_{S \subseteq \{1, \ldots, p\}} \hat{f}_S(x_S)\]
¿Cómo encontramos las funciones \(\hat{f}_S\)? Cada una se calcularía de la siguiente forma:
\[\hat{f}_S(x_S) = \int_{x_{-S}} \hat{f}(x) - \sum_{V \subset S} \hat{f}_V(x) dx_{-S}\]
Comparar el cambio de precisión (o la métrica de error elegida) de un modelo al permutar (aleatoriamente) los valores de un atributo.
Esto permite medir cuánto afecta un atributo al valor predicho por \(M\). Y, por lo tanto, cuánta importancia tiene dicho atributo internamente en el modelo.
Feature Importance
For Dummies
Cambiar de forma aleatoria el valor de un atributo para ver cómo afecta a la precisión de \(M\).
\[FI_{j,importance} = e_{perm}/e_{orig}\]
\[FI_{j,diff} = e_{perm} - e_{orig}\]
https://christophm.github.io/interpretable-ml-book/images/importance-bike-1.jpeg
Existe la duda de si se debe calcular sobre el dataset de entrenamiento o de test.
Usar el dataset de entrenamiento tiene el problema de que el error original en el dataset de entrenamiento no es fiable, por lo tanto puede no ser tamploco el \(FI\). Sin embargo, tiene la ventaja de que nos muestra más fielmente la importancia de los atributos para el modelo.
Tip
Se aconseja usar el dataset de test.
Entrenar un modelo ML inherentemente interpretable a través de los datos predichos por \(M\)$, para extrapolar la interpretación del nuevo modelo más simple al modelo \(M\).
Surrogate Model
For Dummies
Entrenar una regresión lineal o un árbol de decisión sobre el dataset con la predicción de \(M\).
Se entrena un árbol de decisión a través de las predicciones de un SVM.
https://christophm.github.io/interpretable-ml-book/images/surrogate-bike-1.jpeg
Buscaremos datos que sean representativos del conjunto del dataset. Estos son los prototipos.
Una crítica es un dato que no está bien representado por el conjunto de los prototipos.
Counterfactual Examples
For Dummies
La idea es encontrar un número pequeño de datos que permitan representar a la mayoría del dataset, conociendo también aquellos datos o regiones que no se verían representados.
\[MMD^2 = \frac{1}{m^2} \sum_{i,j=1}^m k(z_i, z_j) \\ - \frac{2}{mn} \sum_{i,j=1}^{m,n} k(z_i, x_j) + \frac{1}{n^2} \sum_{i,j=1}^n k(x_i, x_j)\]
El objetivo es minimizar \(MMD2\).
Esta función mide cómo de bien se ajusta un dato \(x\) a los prototipos \(z\).
\[witness(x) = |\frac{1}{n} \sum_{i=1}^n k(x, x_i) - \frac{1}{m} \sum_{j=1}^m k(x, z_j)|\]
El objetivo es usar como críticas aquellos datos con mayor witness.
Por ejemplo, se puede crear un modelo que siga la siguiente regla. El modelo predecirá el dato como la predicción del prototipo más cercano.
\[\hat{f}(x) = argmax_{i \in S} k(x, z_i)\]
SHAP se considera un método local. Sin embargo, obteniendo los valores SHAP para cada uno de los datos se pueden obtener plots globales muy útiles.
https://christophm.github.io/interpretable-ml-book/images/shap-importance.png
https://christophm.github.io/interpretable-ml-book/images/shap-importance-extended.png
https://christophm.github.io/interpretable-ml-book/images/shap-dependence.png
https://christophm.github.io/interpretable-ml-book/images/shap-dependence-interaction.png
PhD xAI