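Sparse autoencoders (SAEs) are a type of deep learning model used in interpretability (explainability) research. They attempt to reconstruct the activations of intermediate layers of transformer models. Unlike typical autoencoders (such as VAEs), which compress their input through a narrower bottleneck, SAEs usually have a hidden (dictionary) dimension larger than the dimension of their input and output representations. To stop this overcomplete hidden layer from simply copying its input, training adds a constraint that enforces sparsity, so that only a few hidden features are active for any given input. One well-known example of this being used is Anthropic's “Golden Gate Claude”, a version of Claude 3 Sonnet in which an SAE feature corresponding to the Golden Gate Bridge was amplified.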

Details

In the context of protein language models, the features that sparse autoencoders extract from intermediate layers do not outperform, on property-prediction tasks, the embeddings from which they are derived.

Figures

Reference: Adam Karvonen's blog, https://adamkarvonen.github.io/ (the code below follows his explainer on sparse autoencoders).

See also

Code

import torch
import torch.nn as nn
# D = d_model, F = dictionary_size
# e.g. if d_model = 12288 and dictionary_size = 49152
# then model_activations_D.shape = (12288,) and encoder_DF.weight.shape = (49152, 12288)
# (nn.Linear stores its weight as (out_features, in_features)).
class SparseAutoEncoder(nn.Module):
    """A one-layer sparse autoencoder over model activations.

    Encodes a ``D``-dimensional activation vector into ``F`` non-negative
    feature activations (linear layer + ReLU), then linearly decodes them
    back to ``D`` dimensions.

    Tensor-name suffixes: ``D`` = ``activation_dim`` (d_model),
    ``F`` = ``dict_size``.
    """

    def __init__(self, activation_dim: int, dict_size: int):
        """
        Args:
            activation_dim: Width ``D`` of the activations being reconstructed.
            dict_size: Number of dictionary features ``F``.
        """
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size
        # nn.Linear stores weight as (out_features, in_features), so
        # encoder_DF.weight is (F, D) and decoder_FD.weight is (D, F).
        self.encoder_DF = nn.Linear(activation_dim, dict_size, bias=True)
        self.decoder_FD = nn.Linear(dict_size, activation_dim, bias=True)

    def encode(self, model_activations_D: torch.Tensor) -> torch.Tensor:
        """Return the non-negative feature activations, shape ``(..., F)``."""
        # torch.relu avoids constructing a throwaway nn.ReLU module per call.
        return torch.relu(self.encoder_DF(model_activations_D))

    def decode(self, encoded_representation_F: torch.Tensor) -> torch.Tensor:
        """Map feature activations back to activation space, shape ``(..., D)``."""
        return self.decoder_FD(encoded_representation_F)

    def forward_pass(
        self, model_activations_D: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode then decode the given activations.

        Returns:
            ``(reconstruction, features)`` with shapes ``(..., D)`` and ``(..., F)``.
        """
        encoded_representation_F = self.encode(model_activations_D)
        reconstructed_model_activations_D = self.decode(encoded_representation_F)
        return reconstructed_model_activations_D, encoded_representation_F

    def forward(
        self, model_activations_D: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Standard ``nn.Module`` entry point; identical to ``forward_pass``.

        Defining ``forward`` lets the module be called as ``sae(x)`` (which
        also runs registered hooks); without it ``nn.Module.__call__`` raises
        ``NotImplementedError``.
        """
        return self.forward_pass(model_activations_D)
 
# B = batch size, D = d_model, F = dictionary_size
def calculate_loss(
    autoencoder: "SparseAutoEncoder",
    model_activations_BD: torch.Tensor,
    l1_coeffient: float,
) -> torch.Tensor:
    """Return the SAE training loss: reconstruction error plus an L1 sparsity penalty.

    Both terms are computed per example and averaged over the batch, so
    ``l1_coeffient`` sets the same sparsity/reconstruction trade-off
    regardless of batch size.

    Args:
        autoencoder: Object whose ``forward_pass(x)`` returns
            ``(reconstruction, features)``, e.g. a ``SparseAutoEncoder``.
        model_activations_BD: Activations to reconstruct, shape ``(B, D)``
            (extra leading dimensions also work; the last axis is ``D``).
        l1_coeffient: Weight of the L1 penalty. The misspelled parameter
            name is kept so existing keyword callers keep working.

    Returns:
        Scalar loss tensor.
    """
    reconstructed_model_activations_BD, encoded_representation_BF = autoencoder.forward_pass(
        model_activations_BD
    )
    # Squared L2 reconstruction error summed over D for each example, then
    # averaged over the batch.
    reconstruction_error_BD = (reconstructed_model_activations_BD - model_activations_BD).pow(2)
    reconstruction_error_B = reconstruction_error_BD.sum(dim=-1)
    l2_loss = reconstruction_error_B.mean()
    # L1 norm over F for each example, averaged over the batch like l2_loss so
    # the penalty's scale does not grow with B. abs() is a no-op for ReLU
    # features but keeps this a true L1 norm for any encoder.
    l1_loss = l1_coeffient * encoded_representation_BF.abs().sum(dim=-1).mean()
    return l2_loss + l1_loss