We probe feature-binding information in frozen vision backbones. The training pipeline builds a binding dataset, caches frozen model activations, trains one or more probes, and reports probe loss, accuracy, binding information, and binding ratio.
Create or activate a Python environment, then install the Python dependencies:
```bash
pip install -r requirements.txt
```

PyTorch installation may need to be adjusted for your CUDA version. If pip install torch does not select the right build, install PyTorch following the official PyTorch instructions first, then install the rest of the requirements.
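To check which build pip selected, you can inspect the installed torch package. The following is a minimal sketch. The helper name describe_torch_build is hypothetical and is not part of this repository. torch.version.cuda is None for CPU-only builds.

```python
import importlib.util


# Hypothetical helper (not part of this repository): report the installed PyTorch build.
def describe_torch_build() -> dict:
    """Return the PyTorch version, the CUDA version it was built with, and GPU visibility."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}
    import torch

    return {
        "installed": True,
        "version": torch.__version__,
        # None for CPU-only builds; otherwise the CUDA version PyTorch was compiled against.
        "cuda_build": torch.version.cuda,
        # True only if a compatible GPU and driver are visible at runtime.
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    print(describe_torch_build())
```

If cuda_build is None or cuda_available is False on a GPU machine, reinstall PyTorch for your CUDA version before installing the remaining requirements.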
```
src/
  main.py             # entry point for all runs
  trainer.py          # training, binding evaluation
  activation.py       # obtain model activations
  probes.py           # probe architectures
  cfgs/
    config.yaml       # main experiment config
    dataset/          # dataset configs
  dataset/
    ColorShape.py
    BalancedColorShape.py
    CLEVR.py
    OcclusionClevr.py
    VGColor.py
    VGTopAttr.py
  script/
    occlusionclevr_blender_generation_script.py
```
The balancedcolorshape dataset is used in Sections 3.1 and 3.2 of the paper. All other datasets (colorshape, occlusionclevr, clevr, vgcolor, and vgtopattr) are used in Section 3.3. The paper, including its appendix, describes these setups in more detail. Each dataset has its own config file under src/cfgs/dataset. Datasets obtained from external sources or generators are described below:
The occlusionclevr dataset generates or loads CLEVR-style images using Blender and a local clone of facebookresearch/clevr-dataset-gen (see that repository for instructions on installing Blender):
```bash
git clone https://github.com/facebookresearch/clevr-dataset-gen.git
export CLEVR_GEN_DIR=/path/to/clevr-dataset-gen
export BLENDER_PATH=/path/to/blender
```

The default location for generated images is:

```
data/generated/occlusionclevr
```
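A quick check of the two environment variables can prevent a failed Blender run. The following is a minimal sketch. It assumes only that CLEVR_GEN_DIR should be an existing directory and BLENDER_PATH an existing path. check_occlusionclevr_env is a hypothetical helper, not part of the repository.

```python
import os
from pathlib import Path


# Hypothetical helper (not part of this repository): validate the generator setup.
def check_occlusionclevr_env(env=None) -> list:
    """Return a list of setup problems; an empty list means both variables look usable."""
    env = os.environ if env is None else env
    problems = []

    gen_dir = env.get("CLEVR_GEN_DIR")
    if not gen_dir:
        problems.append("CLEVR_GEN_DIR is not set")
    elif not Path(gen_dir).is_dir():
        problems.append(f"CLEVR_GEN_DIR is not a directory: {gen_dir}")

    blender = env.get("BLENDER_PATH")
    if not blender:
        problems.append("BLENDER_PATH is not set")
    elif not Path(blender).exists():
        problems.append(f"BLENDER_PATH does not exist: {blender}")

    return problems


if __name__ == "__main__":
    for problem in check_occlusionclevr_env():
        print("setup problem:", problem)
```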
The clevr loader expects a derived 6-object CLEVR layout:
```
${CLEVR_DATA_DIR}/
  images_6obj/
    CLEVR_6obj_000000.png
  scenes_6obj/
    CLEVR_6obj_000000.json
```
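The following minimal sketch checks that a directory matches this layout. Based on the example file names, it assumes that each image and its scene file share a stem (for example, CLEVR_6obj_000000.png and CLEVR_6obj_000000.json). check_clevr_6obj_layout is a hypothetical helper, not the repository's loader.

```python
from pathlib import Path


# Hypothetical helper (not part of this repository): pair images with scene files.
def check_clevr_6obj_layout(root) -> dict:
    """Match images_6obj/*.png to scenes_6obj/*.json by file stem.

    Assumes an image and its scene share a stem, as in the example names above.
    """
    root = Path(root)
    image_stems = {p.stem for p in (root / "images_6obj").glob("*.png")}
    scene_stems = {p.stem for p in (root / "scenes_6obj").glob("*.json")}
    return {
        "paired": sorted(image_stems & scene_stems),
        "images_without_scene": sorted(image_stems - scene_stems),
        "scenes_without_image": sorted(scene_stems - image_stems),
    }
```

For example, check_clevr_6obj_layout(os.environ["CLEVR_DATA_DIR"]) should report no unpaired files for a complete derived dataset.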
Set:
```bash
export CLEVR_DATA_DIR=/path/to/derived/clevr
```

The vgcolor and vgtopattr loaders expect Visual Genome images plus project-specific mined metadata:
```
${VG_DATA_DIR}/
  VG_100K/
  VG_100K_2/
  deprecated/
  meta_filtered_color_coco/
    attributes_filtered_color.json
    vg_stats_filtered_color.json
  meta_filtered_topattr_coco/
    attributes_filtered_topattr.json
    vg_stats_filtered_topattr.json
```
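The following minimal sketch checks whether these paths exist. It assumes the two meta_filtered_* folders sit directly under VG_DATA_DIR as siblings of VG_100K. This nesting is an assumption, so adjust the list if your copy differs. VG_REQUIRED and missing_vg_paths are hypothetical names.

```python
from pathlib import Path

# Hypothetical list (not part of this repository): paths relative to VG_DATA_DIR,
# assuming the meta_filtered_* folders are top-level siblings of VG_100K.
VG_REQUIRED = [
    "VG_100K",
    "VG_100K_2",
    "deprecated",
    "meta_filtered_color_coco/attributes_filtered_color.json",
    "meta_filtered_color_coco/vg_stats_filtered_color.json",
    "meta_filtered_topattr_coco/attributes_filtered_topattr.json",
    "meta_filtered_topattr_coco/vg_stats_filtered_topattr.json",
]


# Hypothetical helper: list required paths that are missing under root.
def missing_vg_paths(root, required=VG_REQUIRED) -> list:
    root = Path(root)
    return [rel for rel in required if not (root / rel).exists()]
```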
Set:
```bash
export VG_DATA_DIR=/path/to/visual_genome
export VG_IMG_ROOT=/path/to/visual_genome
```

Backbones, activations, and probes are configured in src/cfgs/config.yaml.
Supported activation modes:
| Mode | Meaning |
|---|---|
| cls | The CLS token. |
| mean_spatial | The mean-pooled representation of all spatial tokens. |
| all | All spatial token representations. |
| cls_mean | The concatenation of the CLS token and the mean-pooled spatial tokens. |
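To make the four modes concrete, the following sketch reduces one image's token sequence to each mode in plain Python. It assumes the backbone returns the CLS token first, followed only by spatial (patch) tokens. Backbones with extra tokens, such as registers, would need different indexing. select_activation is a hypothetical function and does not describe how activation.py is written.

```python
from typing import List, Union

Vector = List[float]


def mean(vectors: List[Vector]) -> Vector:
    """Element-wise mean of a non-empty list of equal-length vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


# Hypothetical function (not part of this repository).
def select_activation(tokens: List[Vector], mode: str) -> Union[Vector, List[Vector]]:
    """Reduce one image's tokens according to an activation mode.

    Assumes tokens[0] is the CLS token and tokens[1:] are the spatial tokens.
    """
    cls_token, spatial = tokens[0], tokens[1:]
    if mode == "cls":
        return cls_token
    if mode == "mean_spatial":
        return mean(spatial)
    if mode == "all":
        return spatial  # one D-dim vector per spatial token
    if mode == "cls_mean":
        return cls_token + mean(spatial)  # list concatenation: length 2 * D
    raise ValueError(f"unknown activation mode: {mode}")
```

With D-dimensional tokens, cls and mean_spatial yield D features, cls_mean yields 2D features, and all yields one D-dimensional vector per spatial token.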
Supported probe types:
| Probe type | Brief explanation | Required setup (config.yaml) |
|---|---|---|
| linear | Linear readout. | use_conditioned_query |
| dnn_concat / dnn | MLP. | dnn_hidden_dim; dnn_num_layers; dnn_dropout; use_conditioned_query |
| quadratic_concat | Low-rank quadratic readout. | rank; use_conditioned_query |
| quadratic_concat_reuse | Quadratic probe with shared parameters for the same feature type. Not available as a feature probe. | rank; use_conditioned_query |
| multilinear_concat_reuse | Generalized reused probe for multiple feature types. Not available as a feature probe. | rank; feature_dims; use_conditioned_query |
| attention_quadratic | Attention over tokens, followed by a quadratic readout. | rank; use_conditioned_query |
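The table does not spell out the quadratic parameterization. One common way to build a low-rank quadratic readout replaces a full D x D matrix with rank-R factors: x^T (sum_r u_r v_r^T) x = sum_r (u_r . x)(v_r . x). The following plain-Python sketch shows only that identity plus an optional linear term. The names U, V, w, and b are illustrative. How probes.py actually structures its quadratic probes, and how the query is conditioned or concatenated, may differ.

```python
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


# Hypothetical illustration (not the repository's probe): one output score.
def low_rank_quadratic_score(x: Vector, U: List[Vector], V: List[Vector], w: Vector, b: float) -> float:
    """Compute sum_r (U[r] . x) * (V[r] . x) + w . x + b.

    Equals x^T M x + w . x + b with M = sum_r U[r] V[r]^T, but never forms the
    D x D matrix M; the 'rank' setting would correspond to len(U) == len(V).
    """
    quadratic = sum(dot(u, x) * dot(v, x) for u, v in zip(U, V))
    return quadratic + dot(w, x) + b
```

With rank R, this uses 2RD parameters per output instead of D^2 for a full quadratic form.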
Attention probes require all spatial tokens, while all other probes can use different activation modes. For attention probes, set:

```yaml
model:
  output_mode: all # required for attention probes
```

Adjust dataset-specific settings and experiment settings (dataset type, probe type) in the corresponding .yaml config files (see above).
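The attention-probe constraint is easy to violate when switching probe types, so a run script could validate it before caching activations. The following is a minimal sketch. check_probe_mode is hypothetical, and because the config key that holds the probe type is not shown here, the probe type is passed in as a plain string.

```python
# Hypothetical validation (not part of this repository).
ACTIVATION_MODES = {"cls", "mean_spatial", "all", "cls_mean"}


def check_probe_mode(probe_type: str, output_mode: str) -> None:
    """Raise ValueError if the probe type and model.output_mode are incompatible.

    Encodes the stated rule: attention probes need all spatial tokens (output_mode: all).
    """
    if output_mode not in ACTIVATION_MODES:
        raise ValueError(f"unknown output_mode: {output_mode}")
    if probe_type.startswith("attention") and output_mode != "all":
        raise ValueError(f"{probe_type} requires model.output_mode: all, got {output_mode}")
```

For example, check_probe_mode("attention_quadratic", "cls") raises, while check_probe_mode("linear", "cls") passes.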
Run from the repository root:
```bash
python src/main.py
```

Outputs are written under:

```
data/outputs/${dataset.name}/...
```

Frozen backbone activations are cached under:

```
data/cache/activations/
```
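Because the backbone is frozen, its activations for a given backbone, dataset, and activation mode never change, so they can be computed once and reused across probe runs. The following sketch shows a keyed on-disk cache that illustrates this idea. The file naming, key fields, and pickle format are all assumptions and do not describe the format activation.py actually uses. cache_path and get_or_compute are hypothetical helpers.

```python
import hashlib
import json
import pickle
from pathlib import Path


# Hypothetical helper (not part of this repository): deterministic cache file path.
def cache_path(cache_dir, **key_fields) -> Path:
    """Map settings such as backbone, dataset, and output_mode to a stable file name."""
    key = json.dumps(key_fields, sort_keys=True)  # sorted keys: same settings, same key
    digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
    return Path(cache_dir) / f"{digest}.pkl"


# Hypothetical helper: load cached activations, or compute and store them.
def get_or_compute(cache_dir, compute_fn, **key_fields):
    path = cache_path(cache_dir, **key_fields)
    if path.exists():
        with path.open("rb") as f:
            return pickle.load(f)
    result = compute_fn()
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("wb") as f:
        pickle.dump(result, f)
    return result
```

Any setting that changes the activations, such as the output mode, must be part of the key. Otherwise a stale cache entry would be silently reused.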
Happy binding!