IT3: Idempotent Test-Time Training

1Computer Vision Laboratory, EPFL 2NVIDIA
3NeuraVision Lab, Bilkent University 4UC Berkeley
arXiv 2024

Idempotent Test-Time Training (IT3) approach. During training (left), the model \( f_{\theta} \) is trained to predict the label \( y \) both with and without \( y \) given to it as an additional input. At test time (right), when given a corrupted input, the model is applied sequentially and then briefly trained, using only the current test input, with the objective of making \( f_{\theta}(\mathbf{x}, \cdot) \) idempotent.

Abstract

This paper introduces Idempotent Test-Time Training (IT3), a novel approach to addressing the challenge of distribution shift. While supervised-learning methods assume matching train and test distributions, this is rarely the case for machine learning systems deployed in the real world. Test-Time Training (TTT) approaches address this by adapting models during inference, but they are limited by a domain-specific auxiliary task. IT3 is based on the universal property of idempotence. An idempotent operator is one that can be applied sequentially without changing the result beyond the initial application, that is \( f(f(\mathbf{x})) = f(\mathbf{x}) \). At training, the model receives an input \( \mathbf{x} \) along with another signal that can either be the ground truth label \( \mathbf{y} \) or a neutral "don't know" signal \( \mathbf{0} \). At test time, the additional signal can only be \( \mathbf{0} \). When sequentially applying the model, first predicting \( \mathbf{y}_0 = f(\mathbf{x}, \mathbf{0}) \) and then \( \mathbf{y}_1 = f(\mathbf{x}, \mathbf{y}_0) \), the distance between \( \mathbf{y}_0 \) and \( \mathbf{y}_1 \) measures certainty, and a large value indicates an out-of-distribution input \( \mathbf{x} \). We use this distance, which can be expressed as \( ||f(\mathbf{x}, f(\mathbf{x}, \mathbf{0})) - f(\mathbf{x}, \mathbf{0})|| \), as our TTT loss during inference. By carefully optimizing this objective, we effectively train \( f(\mathbf{x}, \cdot) \) to be idempotent, projecting the internal representation of the input onto the training distribution. We demonstrate the versatility of our approach across various tasks, including corrupted image classification, aerodynamic predictions, tabular data with missing information, age prediction from face images, and large-scale aerial photo segmentation. Moreover, these tasks span different architectures such as MLPs, CNNs, and GNNs.

TL;DR

The paper introduces Idempotent Test-Time Training (IT3), a method that adapts models to distribution shifts and corruptions by enforcing prediction stability over repeated applications (idempotence property). Unlike other approaches, IT3 does not rely on domain-specific tasks, making it versatile across data types. It performs well on various OOD scenarios, showing promise for general-purpose test-time adaptation.

Idempotence-based Training

Our method employs the ZigZag inference paradigm, previously shown to be effective for out-of-distribution (OOD) detection and uncertainty estimation. Integrating ZigZag into standard models is straightforward, requiring only minimal modifications to the first layer to accept an additional input. This simplicity allows the model to efficiently make two predictions: first without, and then with, its own previous output as input.
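A minimal sketch of this two-pass inference, using a random linear map as a stand-in for a trained ZigZag network (the dimensions, weights, and names here are illustrative, not the paper's code):

```python
import numpy as np

rng = np.random.default_rng(0)
d_x, d_y = 4, 2  # toy input and label dimensions

# Stand-in for f_theta(x, y): one linear map over the concatenation of the
# input x and the auxiliary signal y. A real model would be a trained
# MLP/CNN/GNN whose first layer is widened to accept the extra input.
W = rng.normal(size=(d_y, d_x + d_y))

def f(x, y):
    return W @ np.concatenate([x, y])

x = rng.normal(size=d_x)
y0 = f(x, np.zeros(d_y))       # first pass: neutral "don't know" signal 0
y1 = f(x, y0)                  # second pass: feed the model its own prediction
gap = np.linalg.norm(y1 - y0)  # large gap suggests x is out of distribution
```

The gap \( ||\mathbf{y}_1 - \mathbf{y}_0|| \) is exactly the quantity IT3 later minimizes at test time.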


Original Model


Modified Model

Improved Performance

Idempotent Test-Time Training (IT³) enables the model to improve predictions on corrupted or unfamiliar data by optimizing itself during inference. In the example below, the model refines its output closer to the Ground Truth after applying IT³, compared to the Not Optimized version. Our approach results in more accurate and robust predictions in real-world scenarios where data distribution may shift unexpectedly.
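The inference-time optimization can be sketched as follows. This toy keeps the linear stand-in for the network and uses a finite-difference gradient in place of backpropagation, so every name and hyperparameter is illustrative rather than the paper's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
d_x, d_y = 4, 2
W = rng.normal(size=(d_y, d_x + d_y))  # placeholder "pretrained" weights

def f(W, x, y):
    return W @ np.concatenate([x, y])

def it3_loss(W, x):
    """||f(x, f(x, 0)) - f(x, 0)||^2: how far f(x, .) is from idempotent."""
    y0 = f(W, x, np.zeros(d_y))
    return float(np.sum((f(W, x, y0) - y0) ** 2))

def tta_step(W, x, lr=5e-3, eps=1e-5):
    # Finite-difference gradient over all weights; a real implementation
    # would take one or a few autograd SGD steps on the same objective.
    grad = np.zeros_like(W)
    for idx in np.ndindex(*W.shape):
        Wp, Wm = W.copy(), W.copy()
        Wp[idx] += eps
        Wm[idx] -= eps
        grad[idx] = (it3_loss(Wp, x) - it3_loss(Wm, x)) / (2 * eps)
    return W - lr * grad

x = rng.normal(size=d_x)       # a single (possibly corrupted) test input
before = it3_loss(W, x)
for _ in range(50):            # brief test-time training on x alone
    W = tta_step(W, x)
after = it3_loss(W, x)         # idempotence gap shrinks after adaptation
```

Note that only the current test input is used: the model adapts to each example (or batch) on the fly, then predicts with the updated weights.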


Input Image


Not Optimized


Optimized


Ground Truth

Improved Generalization

Idempotent Test-Time Training (IT³) enhances the model’s ability to generalize to out-of-distribution (OOD) data. By applying optimization during inference, IT³ adjusts predictions for data that differs significantly from the training set, resulting in lower error rates across different OOD levels.


Age results on OOD images.


Airfoil results on OOD shapes.


Car results on OOD shapes.

BibTeX

@article{durasov20243,
  title = {IT$^3$: Idempotent Test-Time Training},
  author = {Durasov, Nikita and Shocher, Assaf and Oner, Doruk and Chechik, Gal and Efros, Alexei A and Fua, Pascal}, 
  journal = {arXiv preprint arXiv:2410.04201},
  year = {2024}
}