Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks
ICML (2019), cited over 300 times
Many ML tasks are defined on sets of instances.
→ These tasks do not depend on the order of elements
→ We need permutation invariant models for these tasks
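e.g.) a quick sanity check in PyTorch (a toy sketch, not from the paper): sum pooling is invariant to reordering the set, while a linear layer on the flattened set is not.

```python
import torch

# Sum pooling: same output for any ordering of the set.
x = torch.randn(5, 3)          # a "set" of 5 elements, each 3-dim
perm = torch.randperm(5)
print(torch.allclose(x.sum(dim=0), x[perm].sum(dim=0), atol=1e-6))  # True

# A plain linear layer on the flattened set is order-sensitive.
lin = torch.nn.Linear(15, 3)
print(torch.allclose(lin(x.reshape(-1)), lin(x[perm].reshape(-1))))  # False (almost surely)
```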
Set Transformer
Used self-attention to encode every element in a set.
→ Able to encode pairwise- or higher-order interactions.
The authors introduced an efficient attention scheme inspired by inducing point methods from the sparse Gaussian process literature.
→ Reduced the $\mathcal{O}(n^2)$ cost of self-attention to $\mathcal{O}(nm)$, where $m$ is the number of learnable inducing points
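A minimal PyTorch sketch of the inducing-point idea (illustrative only: it omits the residual connections, feed-forward layers, and layer norms of the paper's actual ISAB block, and `InducedAttention` / `num_inducing` are hypothetical names):

```python
import torch
import torch.nn as nn

class InducedAttention(nn.Module):
    # m learnable inducing points attend to the n set elements, then the
    # elements attend back to the m summaries: two O(nm) attentions
    # instead of one O(n^2) self-attention.
    def __init__(self, dim, num_heads=4, num_inducing=16):
        super().__init__()
        self.inducing = nn.Parameter(torch.randn(num_inducing, dim))
        self.summarize = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.broadcast = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x):                              # x: (batch, n, dim)
        i = self.inducing.expand(x.size(0), -1, -1)    # (batch, m, dim)
        h, _ = self.summarize(i, x, x)                 # inducing points read the set
        out, _ = self.broadcast(x, h, h)               # elements read the summaries
        return out                                     # (batch, n, dim)

x = torch.randn(2, 100, 64)                # 2 sets of 100 elements
print(InducedAttention(64)(x).shape)       # torch.Size([2, 100, 64])
```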
Used attention-based pooling to aggregate features
→ Beneficial when the problem requires multiple outputs that depend on each other
e.g.) meta-clustering
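A minimal sketch of this attention-based aggregation (PMA-style): $k$ learnable seed vectors attend over the encoded set, so the $k$ outputs are produced jointly and can depend on each other. Names and sizes here are illustrative:

```python
import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    # k learnable seed vectors attend over the encoded set;
    # each seed becomes one output vector, computed jointly with the others.
    def __init__(self, dim, num_heads=4, num_seeds=1):
        super().__init__()
        self.seeds = nn.Parameter(torch.randn(num_seeds, dim))
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x):                          # x: (batch, n, dim) encoded set
        s = self.seeds.expand(x.size(0), -1, -1)   # (batch, k, dim)
        out, _ = self.attn(s, x, x)                # each seed aggregates the whole set
        return out                                 # (batch, k, dim)

x = torch.randn(2, 100, 64)
print(AttentionPooling(64, num_seeds=4)(x).shape)  # torch.Size([2, 4, 64])
```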
Definition (set-input problem): A set of instances is given as input and the corresponding target is a label for the entire set.
e.g.) 3D shape recognition, few shot image classification
Requirements for a model for set-input problems
→ It should be permutation invariant: the output must not change under any reordering of the input elements
→ It should be able to process input sets of any size
c.f.) Ordinary MLPs (fixed-size input) or RNNs (order-sensitive output) violate these requirements.
Recent works
[Edwards & Storkey (2017)], [Zaheer et al. (2017)] proposed set pooling methods.
Framework
→ This set-pooling framework is proven to be a universal approximator for any set function.
→ However, in practice it can fail to learn complex mappings & interactions between the elements of a set
e.g.) amortized clustering problem
Amortized clustering: reusing previous inference (clustering) results to accelerate inference (clustering) on a new dataset.
Universal representation of permutation invariant functions
$$ \text{net}(\{x_1,...,x_n\}) = \rho (\text{pool}(\{\phi(x_1),...,\phi(x_n)\})) $$
$\phi: \text{encoder; } \rho: \text{decoder}$
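A minimal sketch of this framework, assuming small MLPs for $\phi$ and $\rho$ and mean pooling (dimensions are illustrative):

```python
import torch
import torch.nn as nn

class SetPoolingNet(nn.Module):
    def __init__(self, in_dim=3, hidden=64, out_dim=10):
        super().__init__()
        self.phi = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, hidden))   # per-element encoder
        self.rho = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(),
                                 nn.Linear(hidden, out_dim))  # decoder on the pooled feature

    def forward(self, x):                          # x: (batch, n, in_dim)
        return self.rho(self.phi(x).mean(dim=1))   # pool over the set axis

net = SetPoolingNet()
x = torch.randn(2, 7, 3)
perm = torch.randperm(7)
# Invariance check (True up to floating-point error):
print(torch.allclose(net(x), net(x[:, perm]), atol=1e-5))
```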
The model remains permutation-invariant even if the "encoder" $\phi$ is a stack of permutation-equivariant layers (permuting the input permutes the output in the same way)
e.g.) a permutation-equivariant layer:
$$ f_i(x;\{x_1, ..., x_n\}) = \sigma_i(\lambda x+ \gamma \text{pool}(\{x_1,...,x_n\})) $$
$\lambda, \gamma : \text{learnable scalar variables; } \sigma(\cdot): \text{non-linear activation function}$
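A minimal sketch of this layer, assuming mean as $\text{pool}$ and ReLU as $\sigma$:

```python
import torch
import torch.nn as nn

class EquivariantLayer(nn.Module):
    # f_i(x) = sigma(lambda * x_i + gamma * pool({x_1, ..., x_n}))
    def __init__(self):
        super().__init__()
        self.lam = nn.Parameter(torch.tensor(1.0))   # lambda
        self.gam = nn.Parameter(torch.tensor(0.5))   # gamma

    def forward(self, x):                            # x: (batch, n, dim)
        pooled = x.mean(dim=1, keepdim=True)         # shared set context
        return torch.relu(self.lam * x + self.gam * pooled)

layer = EquivariantLayer()
x = torch.randn(2, 5, 3)
perm = torch.randperm(5)
# Equivariance: permuting the input permutes the output the same way.
print(torch.allclose(layer(x)[:, perm], layer(x[:, perm]), atol=1e-6))  # True
```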