# N-Dimensional Rotary Positional Encodings (RoPE) in PyTorch
Rotary Positional Encodings (RoPE) are a modern method of encoding positional information into Transformer inputs. While RoPE is best known for its use on 1D sequences such as text and time series, recent research has extended it to multidimensional data like 2D images and 3D volumes. This repository contains an implementation of N-dimensional RoPE that builds on and extends prior work with enhanced performance and new features.
Heo et al. of NAVER AI Lab proposed an extension of RoPE to 2D for Vision Transformers. Called RoPE-Mixed, the formulation extends traditional 1D RoPE, in which each query/key channel pair $t$ of token $n$ is rotated as

$$\mathbf{R}(n, t) = e^{i \theta_t n},$$

where $\theta_t$ is a fixed per-channel frequency, to the 2D rotation

$$\mathbf{R}(\mathbf{p}_n, t) = e^{i \left(\theta_t^x p_n^x + \theta_t^y p_n^y\right)},$$

where now $\mathbf{p}_n = (p_n^x, p_n^y)$ is the 2D position of token $n$ and $\theta_t^x, \theta_t^y$ are learnable frequencies.
This repository is the result of research into further generalizing RoPE-Mixed to higher dimensions, with a particular focus on small object detection using DETRs on very large and sparse images.
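As a concrete illustration of the construction above, here is a minimal NumPy sketch (not this repository's implementation; the variable names and random frequencies are purely illustrative) of the key property of RoPE-Mixed-style rotations, written for an arbitrary number of position dimensions: the attention score between a rotated query and key depends only on the *difference* of their positions.

```python
import numpy as np

rng = np.random.default_rng(0)

n_dims, d_half = 3, 8  # 3D positions, 8 complex channel pairs (illustrative sizes)
theta = rng.normal(size=(n_dims, d_half))  # learnable per-axis frequencies

def rotate(x, pos):
    """Rotate complex channel pairs x by the mixed angle theta . pos."""
    angle = pos @ theta  # (d_half,): one rotation angle per channel pair
    return x * np.exp(1j * angle)

# Random query/key channel pairs, represented as complex numbers
q = rng.normal(size=d_half) + 1j * rng.normal(size=d_half)
k = rng.normal(size=d_half) + 1j * rng.normal(size=d_half)
p1 = rng.normal(size=n_dims)
p2 = rng.normal(size=n_dims)

# Score between rotated q and k equals the score with only the relative
# position p1 - p2 applied (np.vdot conjugates its first argument):
score = np.vdot(rotate(k, p2), rotate(q, p1))
relative = np.vdot(k, rotate(q, p1 - p2))
assert np.allclose(score, relative)
```

Because the per-pair angle is a dot product between the position and a frequency vector, the same construction generalizes directly from 2D to any number of position dimensions.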
Compared to the official implementation of RoPE-Mixed, our implementation offers:
- Improved performance (see benchmarks below)
- Generalization to ND spaces with N > 2
- Support for arbitrary, non-grid positions (for representing, e.g., arbitrary object positions)
- Gradient support for position tensors
- Implementation documentation
- Comprehensive unit tests and property-based tests using Hypothesis
- Encoder-decoder attention (a.k.a. cross-attention) support
- Experimental "grouped dimensions" construction for application to network modules beyond ViT backbones, such as detection transformers (DETRs)
- Custom gradient checkpointing logic for memory-constrained contexts (e.g., in DETR-like encoder-decoder setups where the number of key/value embeddings corresponding to pixel features may be much larger than the number of object queries)
## Benchmarks

The benchmarks below compare a single multi-head attention (MHA) layer in three configurations: vanilla MHA with no rotary positional encodings, MHA with the reference RoPE-Mixed implementation applied to the embeddings immediately before the query-key product, and MHA with our ND RoPE implementation in the same place. The test data are random square images of size N×N, with batch size 4, embedding dimension 256, and 8 heads. The benchmarks were run in float32 precision on a single A100 GPU and may be reproduced by running the benchmark notebook.
The image sizes along the x axes of the benchmark results correspond to the actual token counts of the processed embeddings (an N×N input is N² tokens); none of the downsampling/patchification typical of Vision Transformers is applied.
On the forward pass, our implementation achieves significantly better memory scaling than the reference implementation: in training mode, the marginal memory cost drops from apparently high-degree polynomial growth to a constant multiplicative factor over the no-RoPE layer. In inference mode, additional optimizations such as in-place rotation of the embeddings reduce the marginal memory cost to virtually zero. Importantly, our implementation allows token counts beyond a few thousand (e.g., 64×64) to be processed by a RoPE-enabled ViT layer or similar.
Forward-pass runtime is also significantly improved: our implementation adds a small walltime overhead that becomes insignificant past roughly 48×48.
As on the forward pass, our implementation brings marginal memory consumption from apparently high-degree polynomial scaling down to a constant multiplicative factor over the no-RoPE case. The walltime trends are also similar: the reference implementation adds a large walltime increase, while ours adds a small factor that washes out at moderate to high resolutions.
## Installation

`nd-rotary-encodings` has no requirements beyond base PyTorch.
To install, simply clone the repository and use pip:
```bash
git clone https://github.com/mawright/nd-rotary-encodings
cd nd-rotary-encodings
pip install -e .  # editable installation
```
To run the test suite, you'll need to install the optional dependencies (pytest and Hypothesis):
```bash
pip install -e ".[tests]"
```
## Usage

The high-level user-facing interface is the `nn.Module` `RoPEEncodingND`. This layer takes a query embedding tensor and its position tensor, and optionally a key tensor and key position tensor, and returns RoPE-encoded versions of the query (and key) tensors.
A few usage examples:
- Basic 3D RoPE encoding of queries for self-attention:
```python
import torch
from nd_rotary_encodings import RoPEEncodingND

# Architecture parameters
position_dim = 3
embed_dim = 128
n_heads = 4

# Query tensor parameters
batch_size = 4
seq_length = 16

# Create a RoPE layer for 3D positions with embedding dimension 128 and 4 heads
rope = RoPEEncodingND(position_dim, embed_dim, n_heads)

# Create a query tensor and corresponding positions
query = torch.randn(batch_size, seq_length, embed_dim)
query_pos = torch.randn(batch_size, seq_length, position_dim)  # float positions supported

rotated_query = rope(query, query_pos)
assert not torch.allclose(query, rotated_query)
```
- The same layer can be used for encoder-decoder attention with both a query and key tensor:
```python
key_seq_length = 32
key = torch.randn(batch_size, key_seq_length, embed_dim)
key_pos = torch.randn(batch_size, key_seq_length, position_dim)

rotated_query_2, rotated_key = rope(query, query_pos, key, key_pos)
assert torch.equal(rotated_query, rotated_query_2)
assert not torch.allclose(key, rotated_key)
```
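One of the features listed above is gradient support for position tensors, which allows positions themselves to be learned. The mechanism can be sketched in plain PyTorch (a standalone illustration, not the library's actual code; `rope_rotate` and all shapes here are hypothetical): as long as the rotation angles are computed differentiably from the positions, gradients flow back into them.

```python
import torch

def rope_rotate(x, pos, freqs):
    # x: (seq, embed), pos: (seq, n_dims), freqs: (n_dims, embed // 2)
    angles = pos @ freqs                      # per-pair rotation angles
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]       # interleaved channel pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin      # 2D rotation of each pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

x = torch.randn(16, 8)
pos = torch.randn(16, 3, requires_grad=True)  # fractional 3D positions
freqs = torch.randn(3, 4)

rotated = rope_rotate(x, pos, freqs)
rotated.sum().backward()
assert pos.grad is not None                   # gradients flow into the positions
```

Note that the rotation is norm-preserving, so encoding positions this way does not change the magnitude of the embeddings.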
For more information on usage, see the documentation page for `RoPEEncodingND`.
## Related Projects

- `pytorch-sparse-utils`: Low-level utilities for dealing with large, sparse tensors.
- `sparse-transformer-layers`: Implementations of Transformer layers built on this repository's RoPE encoding layer, tailored to sparse tensors, including variants like Multi-scale Deformable Attention.
## Roadmap

- Integration of LieRE and/or other more advanced schemes for RoPE rotation matrices
- Additional benchmarks
- Expanded usage examples for the more advanced and experimental features