Implement interpretable neural models in PyTorch! | by Pietro Barbiero | May, 2023

To showcase the power of PyTorch Explain, let’s dive into our first tutorial!

A primer on concept bottleneck models

In this introductory session, we’ll dive into concept bottleneck models. These models, introduced in a paper [1] presented at the International Conference on Machine Learning in 2020, are designed to first learn and predict a set of concepts, such as “colour” or “shape,” and then utilize these concepts to solve a downstream classification task:

Concept Bottleneck Models learn tasks (Y) as a function of concepts (C). Image by the authors.
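In symbols, and using our own notation rather than the paper's, a concept bottleneck model composes a concept encoder g with a task predictor f. The task prediction depends on the input x only through the k predicted concepts, here taken to be probabilities:

```latex
\hat{c} = g(x) \in [0, 1]^{k}, \qquad \hat{y} = f(\hat{c}) = f\big(g(x)\big)
```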

By following this approach, we can trace predictions back to concepts, providing explanations such as “The input object is an {apple} because it is {spherical} and {red}.”
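To make the idea concrete, here is a minimal sketch of how concept activations could be turned into a sentence like the one above. The helper `explain_prediction`, the 0.5 threshold, and the concept names are our own illustrative assumptions, not part of PyTorch Explain.

```python
import torch


def explain_prediction(label: str, concept_names: list[str],
                       concept_probs: torch.Tensor, threshold: float = 0.5) -> str:
    """Hypothetical helper: cite the concepts predicted as active for one sample."""
    # A concept counts as "present" when its predicted probability exceeds the threshold
    active = [name for name, p in zip(concept_names, concept_probs.tolist()) if p > threshold]
    if not active:
        return f"The input object is an {{{label}}}, but no concept is active."
    reasons = " and ".join(f"{{{name}}}" for name in active)
    return f"The input object is an {{{label}}} because it is {reasons}."


# Concept probabilities as a (hypothetical) concept encoder might output them
probs = torch.tensor([0.97, 0.12, 0.91])
print(explain_prediction("apple", ["spherical", "elongated", "red"], probs))
# The input object is an {apple} because it is {spherical} and {red}.
```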


Hands-on concept bottlenecks

To illustrate concept bottleneck models, we will revisit the well-known XOR problem, but with a twist. Our input will consist of two continuous features. To capture the essence of these features, we will employ a concept encoder that maps them into two meaningful concepts, denoted as “A” and “B”. The objective of our task is to predict the exclusive OR (XOR) of “A” and “B”. By working through this example, you’ll gain a better understanding of how concept bottlenecks can be applied in practice and witness their effectiveness in tackling a concrete problem.
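The exact generator behind `datasets.xor` (loaded below) is not shown in this tutorial. As an assumption for illustration only, one natural way to build such data is to sample two features uniformly in [0, 1], define each concept as "feature above 0.5", and label a sample positive when exactly one concept is active. A minimal sketch, where `make_xor_like` is a hypothetical name:

```python
import torch


def make_xor_like(n_samples: int, seed: int = 42):
    """Hypothetical XOR-style data: continuous features -> binary concepts -> XOR label."""
    generator = torch.Generator().manual_seed(seed)
    # Two continuous input features, uniform in [0, 1)
    x = torch.rand(n_samples, 2, generator=generator)
    # Concepts "A" and "B": is each feature above 0.5? (assumed threshold)
    c = (x > 0.5).float()
    # Task: exclusive OR of the two concepts, kept as a column vector of shape (n, 1)
    y = (c[:, 0] != c[:, 1]).float().unsqueeze(-1)
    return x, c, y


x, c, y = make_xor_like(500)
print(x.shape, c.shape, y.shape)  # torch.Size([500, 2]) torch.Size([500, 2]) torch.Size([500, 1])
```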

We can start by importing the necessary libraries and loading this simple dataset:

import torch
import torch_explain as te
from torch_explain import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

x, c, y = datasets.xor(500)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

Next, we instantiate a concept encoder to map the input features to the concept space and a task predictor to map concepts to task predictions:

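concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], 10),
    torch.nn.LeakyReLU(),  # non-linearities are needed to solve XOR
    torch.nn.Linear(10, 8),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(8, c.shape[1]),
    torch.nn.Sigmoid(),  # concept probabilities in [0, 1], as required by BCELoss
)
task_predictor = torch.nn.Sequential(
    torch.nn.Linear(c.shape[1], 8),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(8, 1),  # raw logit, as expected by BCEWithLogitsLoss
)
model = torch.nn.Sequential(concept_encoder, task_predictor)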

We then train the network by jointly optimizing the binary cross-entropy loss on both concepts and tasks, with the task term weighted by 0.2:

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form_c = torch.nn.BCELoss()            # concepts: probabilities in [0, 1]
loss_form_y = torch.nn.BCEWithLogitsLoss()  # task: raw logits
model.train()
for epoch in range(2001):
    optimizer.zero_grad()

    # generate concept and task predictions
    c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_pred)

    # compute the joint loss
    concept_loss = loss_form_c(c_pred, c_train)
    task_loss = loss_form_y(y_pred, y_train)
    loss = concept_loss + 0.2*task_loss

    # backpropagate and update the weights
    loss.backward()
    optimizer.step()
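Why two different loss classes? `BCELoss` expects probabilities, so the concept encoder's outputs must already lie in [0, 1], whereas `BCEWithLogitsLoss` takes raw logits and applies the sigmoid internally. The sketch below checks that the two agree once the sigmoid is accounted for, and spells out the weighted joint objective used in the loop above. The helper name `joint_cbm_loss` and the random tensors are illustrative assumptions.

```python
import torch

bce = torch.nn.BCELoss()                   # expects probabilities in [0, 1]
bce_logits = torch.nn.BCEWithLogitsLoss()  # expects raw logits, applies sigmoid internally


def joint_cbm_loss(c_prob, c_true, y_logit, y_true, task_weight=0.2):
    """Hypothetical helper mirroring the training loop: concept loss + weighted task loss."""
    return bce(c_prob, c_true) + task_weight * bce_logits(y_logit, y_true)


torch.manual_seed(0)
logits = torch.randn(8, 1)
targets = torch.randint(0, 2, (8, 1)).float()

# Same value: BCEWithLogitsLoss(z, y) == BCELoss(sigmoid(z), y), up to float error
print(torch.allclose(bce_logits(logits, targets), bce(torch.sigmoid(logits), targets), atol=1e-6))  # True
```

In practice, keeping the task head on logits and using `BCEWithLogitsLoss` is the numerically safer choice, while the concept head needs explicit probabilities because they are fed to the task predictor.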


After training the model, we evaluate its performance on the test set:

model.eval()
with torch.no_grad():
    c_pred = concept_encoder(x_test)
    y_pred = task_predictor(c_pred)

concept_accuracy = accuracy_score(c_test, c_pred > 0.5)  # concepts are probabilities
task_accuracy = accuracy_score(y_test, y_pred > 0)       # task outputs are logits
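Two details of this evaluation are easy to miss. First, the thresholds differ because the outputs differ: concept predictions are probabilities (cut at 0.5), while task predictions are logits, and sigmoid(z) > 0.5 exactly when z > 0. Second, when scikit-learn's `accuracy_score` receives 2-D binary arrays such as the concept matrix, it computes subset accuracy: a sample counts as correct only if all of its concepts are right. The toy arrays below are made up for illustration.

```python
import numpy as np
import torch
from sklearn.metrics import accuracy_score

# sigmoid(z) > 0.5 exactly when z > 0, so both thresholds pick the same class
z = torch.linspace(-3, 3, 7)
print(torch.equal(torch.sigmoid(z) > 0.5, z > 0))  # True

# Toy concept labels and predictions: 4 samples x 2 concepts (made-up values)
c_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
c_hat = np.array([[1, 0], [0, 0], [1, 1], [0, 1]])

subset_acc = accuracy_score(c_true, c_hat)  # every concept of a sample must match
per_concept_acc = (c_true == c_hat).mean()  # fraction of individual concept predictions that match
print(subset_acc, per_concept_acc)  # 0.5 0.75
```

So a reported concept accuracy from `accuracy_score` on the full concept matrix is the stricter, per-sample figure.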

After training, we can observe that both the concept and the task accuracy on the test set are high (~98%)!

Thanks to this architecture, we can explain a model prediction by looking at how the task predictor responds to its input concepts, as follows:

c_different = torch.FloatTensor([0, 1])
print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}")

c_equal = torch.FloatTensor([1, 1])
print(f"f({c_equal}) = {int(task_predictor(c_equal).item() > 0)}")

which yields, e.g., f([0, 1]) = 1 and f([1, 1]) = 0, as expected. This helps us understand the model's behaviour and check that it responds as expected to every relevant combination of concepts: for example, for the mutually exclusive input concepts [0, 1] or [1, 0] it returns a prediction of y = 1.
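These manual checks generalize: with only k binary concepts there are 2^k possible concept vectors, so for small k we can query the task predictor on every one of them and read off its full decision table. A minimal sketch: `concept_truth_table` is a hypothetical helper, and the hand-set network standing in for a trained `task_predictor` is our own assumption so that the example runs without training.

```python
import itertools

import torch


def concept_truth_table(task_predictor, n_concepts):
    """Hypothetical helper: evaluate a concept-to-task model on every binary concept vector."""
    rows = []
    with torch.no_grad():
        for combo in itertools.product([0.0, 1.0], repeat=n_concepts):
            logit = task_predictor(torch.tensor(combo)).item()
            # Positive logit -> class 1, matching the "> 0" rule used above
            rows.append((tuple(int(v) for v in combo), int(logit > 0)))
    return rows


# Stand-in for a trained task predictor: a hand-set ReLU network computing XOR on {0, 1}^2
xor_net = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.ReLU(), torch.nn.Linear(2, 1))
with torch.no_grad():
    xor_net[0].weight.copy_(torch.tensor([[1.0, 1.0], [1.0, 1.0]]))
    xor_net[0].bias.copy_(torch.tensor([0.0, -1.0]))   # hidden units: relu(a+b), relu(a+b-1)
    xor_net[2].weight.copy_(torch.tensor([[1.0, -2.0]]))
    xor_net[2].bias.copy_(torch.tensor([-0.5]))

for concepts, prediction in concept_truth_table(xor_net, n_concepts=2):
    print(f"f({list(concepts)}) = {prediction}")
# f([0, 0]) = 0
# f([0, 1]) = 1
# f([1, 0]) = 1
# f([1, 1]) = 0
```

With a trained model, passing `task_predictor` instead of `xor_net` would list its learned decision for every concept combination.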

Concept bottleneck models provide intuitive explanations by tracing predictions back to concepts.

Drowning in the accuracy-explainability trade-off

One of the key advantages of concept bottleneck models is their ability to explain their predictions by revealing concept-prediction patterns, allowing humans to assess whether the model’s reasoning aligns with their expectations.

However, the main issue with standard concept bottleneck models is that they struggle to solve complex problems! More generally, they suffer from a well-known issue in explainable AI, referred to as the accuracy-explainability trade-off. In practice, we want models that not only achieve high task performance but also offer high-quality explanations. Unfortunately, in many cases, as we strive for higher accuracy, the explanations provided by the models tend to deteriorate in quality and faithfulness, and vice versa.

Visually, this trade-off can be represented as follows:

Visual representation of the accuracy-explainability trade-off. The picture shows the difference between interpretable and “black-box” (non-interpretable) models in terms of two axes: task performance and explanation quality. Image by the authors.

where interpretable models excel at providing high-quality explanations but struggle with solving challenging tasks, while black-box models achieve high task accuracy at the expense of providing brittle and poor explanations.

To illustrate this trade-off in a concrete setting, let’s consider a concept bottleneck model applied to a slightly more demanding benchmark, the “trigonometry” dataset:

x, c, y = datasets.trigonometry(500)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

Upon training the same network architecture on this dataset, we observe significantly diminished task accuracy, reaching only around 80%.
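One way to see how a hard bottleneck can cap task accuracy: if the binary concepts do not fully determine the label, then any task predictor that sees only the concepts must give the same answer to all samples sharing a concept vector, so the best it can do is a majority vote within each group. The sketch below computes that ceiling on synthetic data. The generator (concepts are the signs of three Gaussian latents, the label is whether their sum exceeds 1) is our own illustrative assumption, not necessarily how `datasets.trigonometry` is built, and `concept_only_ceiling` is a hypothetical helper.

```python
from collections import Counter, defaultdict

import torch


def concept_only_ceiling(c: torch.Tensor, y: torch.Tensor) -> float:
    """Best accuracy any predictor using only binary concepts c could reach on labels y."""
    groups = defaultdict(Counter)
    for concepts, label in zip(c.int().tolist(), y.int().flatten().tolist()):
        groups[tuple(concepts)][label] += 1
    # Within each concept vector, the best constant answer is the majority label
    correct = sum(counter.most_common(1)[0][1] for counter in groups.values())
    return correct / len(y)


generator = torch.Generator().manual_seed(0)
h = 2 * torch.randn(5000, 3, generator=generator)  # hypothetical continuous latents
c_syn = (h > 0).float()                            # concepts: sign of each latent
y_syn = (h.sum(dim=1) > 1).float()                 # label also depends on magnitudes
print(f"concept-only ceiling: {concept_only_ceiling(c_syn, y_syn):.3f}")
```

When the label is a deterministic function of the concepts (as in XOR), the same computation returns 1.0; when it is not, no amount of training of the task predictor alone can close the gap.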

Concept bottleneck models fail to strike a balance between task accuracy and explanation quality.

This raises the question: are we perpetually forced to choose between accuracy and the quality of explanations, or is there a way to strike a better balance?
