Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks

ICML (2019), cited over 300 times

1. Introduction

Many ML tasks are defined on sets of instances.

→ These tasks do not depend on the order of elements

→ We need permutation invariant models for these tasks

2. Major contribution

Set Transformer

3. Background

Set-input problems?

Definition: A set of instances is given as an input and the corresponding target is a label for the entire set.

e.g.) 3D shape recognition, few-shot image classification

Requirements for a model for set-input problems

  1. It should handle input sets of any size.
  2. Its output should not change when the input elements are reordered (permutation invariance).

c.f.) Ordinary MLPs (fixed input size, order-sensitive) or RNNs (hidden state depends on processing order) violate these requirements.
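A minimal sketch of the second violation, using NumPy with hypothetical random weights (not from the paper): flattening a set into a vector for an MLP fixes an element order, so reordering the same set generally changes the output.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical weights: an MLP layer over a *flattened* 3-element set of 2-d inputs.
# Note the input size is also fixed at 3 * 2 = 6, so variable-size sets are impossible.
W = rng.normal(size=(6, 2))

def mlp(X):
    return np.tanh(X.reshape(-1) @ W)  # flattening imposes an element order

X = rng.normal(size=(3, 2))
out1 = mlp(X)
out2 = mlp(X[[1, 0, 2]])  # same set, different order
assert not np.allclose(out1, out2)  # outputs differ in general
```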

Recent works

[Edwards & Storkey (2017)], [Zaheer et al. (2017)] proposed set pooling methods.


  1. Each element in a set is independently fed into a feed-forward neural network.
  2. The resulting embeddings are then aggregated by a pooling operation (mean, max, sum, ...).

→ This framework is provably a universal approximator for any set function.

→ However, because elements are encoded independently before pooling, it struggles to learn complex mappings and interactions between elements in a set

e.g.) amortized clustering problem

Amortized clustering: training a model that maps a dataset directly to its cluster parameters, so that clustering a new dataset costs one forward pass instead of a fresh iterative optimization (e.g., EM).

Pooling architecture for sets

Universal representation of permutation invariant functions

$$ \text{net}(\{x_1,...,x_n\}) = \rho (\text{pool}(\{\phi(x_1),...,\phi(x_n)\})) $$

$\phi: \text{encoder; } \rho: \text{decoder}$
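The formula above can be sketched in NumPy. This is a toy instance with hypothetical single-layer $\phi$ and $\rho$ and mean pooling (dimensions and weights are illustrative, not from the paper); the assertion checks permutation invariance.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy dimensions.
d_in, d_hid, d_out = 4, 8, 3
W_phi = rng.normal(size=(d_in, d_hid))   # encoder phi: one-layer per-element network
W_rho = rng.normal(size=(d_hid, d_out))  # decoder rho: applied after pooling

def phi(x):
    return np.tanh(x @ W_phi)

def rho(z):
    return z @ W_rho

def net(X):
    """X: (n, d_in) array of set elements; returns one set-level output."""
    return rho(phi(X).mean(axis=0))  # mean pooling over elements

X = rng.normal(size=(5, d_in))
perm = rng.permutation(5)
assert np.allclose(net(X), net(X[perm]))  # same output for any element order
```

Because the pool is a symmetric function of the per-element embeddings, any reordering of the rows of `X` yields the same output.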

The model remains permutation-invariant even if the "encoder" $\phi$ is a stack of permutation-equivariant layers (layers whose outputs are permuted in the same way as their inputs).

e.g.) a permutation-equivariant layer: permuting the input elements permutes the outputs accordingly

$$ f_i(x_i;\{x_1, ..., x_n\}) = \sigma(\lambda x_i+ \gamma \text{pool}(\{x_1,...,x_n\})) $$

$\lambda, \gamma : \text{learnable scalar variables; } \sigma(\cdot): \text{non-linear activation function}$
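A NumPy sketch of this layer, with $\lambda$, $\gamma$ fixed to arbitrary values for illustration (in practice they are learned) and mean as the pool; the assertion checks equivariance: permuting the inputs permutes the outputs the same way.

```python
import numpy as np

rng = np.random.default_rng(1)

lam, gam = 0.9, -0.4  # learnable scalars lambda, gamma (fixed here for illustration)
sigma = np.tanh       # elementwise non-linear activation

def equivariant_layer(X):
    """X: (n, d) set; row i of the output is sigma(lam * x_i + gam * pool(X))."""
    return sigma(lam * X + gam * X.mean(axis=0, keepdims=True))  # mean as pool

X = rng.normal(size=(5, 3))
perm = rng.permutation(5)
# Equivariance: permute-then-apply equals apply-then-permute.
assert np.allclose(equivariant_layer(X[perm]), equivariant_layer(X)[perm])
```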