A recent paper on OpenReview, “An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale”, uses transformers pretrained at scale for vision tasks. Transformers have been highly successful on language tasks but have seen far less success in vision. In vision, they are usually applied in conjunction with Convolutional Neural Networks (CNNs) or used to replace some components of a CNN. Recently, transformers have shown good results on object detection (End-to-End Object Detection with Transformers). This paper applies a transformer to vision tasks without using a CNN and shows that state-of-the-art results can be obtained that way.
The cost of attention is quadratic in the sequence length. For images, having every pixel attend to every other pixel is therefore very costly. Different methods have been used to work around this, such as local attention (attending to a subset of the input) or global attention. This paper uses global attention, applied to image patches rather than individual pixels.
The architecture follows the original transformer very closely. This is deliberate: the transformer architecture has scaled well for NLP tasks, and optimised implementations can be used out of the box from various libraries. The difference lies in how images are fed to the transformer, as a sequence of patches.
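To get a feel for the quadratic cost, here is a small back-of-the-envelope sketch. The helper name `attention_pairs` and the 224×224 input size are illustrative assumptions, not something taken from the paper's code; it simply counts how many query-key pairs global attention has to score.

```python
def attention_pairs(height, width, patch_size=1):
    """Return (num_tokens, num_query_key_pairs) for global self-attention.

    patch_size=1 treats every pixel as a token; a larger patch_size
    treats every patch_size x patch_size patch as one token.
    """
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by patch_size")
    num_tokens = (height // patch_size) * (width // patch_size)
    # Every token attends to every token, so the cost grows quadratically.
    return num_tokens, num_tokens ** 2


pixel_tokens, pixel_pairs = attention_pairs(224, 224)
patch_tokens, patch_pairs = attention_pairs(224, 224, patch_size=16)

print(pixel_tokens, pixel_pairs)   # 50176 2517630976
print(patch_tokens, patch_pairs)   # 196 38416
print(pixel_pairs // patch_pairs)  # 65536
```

With 16×16 patches the sequence is 256 times shorter than with pixels, so the attention matrix is 256² = 65536 times smaller.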
Patch Embedding
The transformer receives a sequence of 1D embeddings as input. To handle 2D images, the image is divided into a sequence of flattened, fixed-size 2D patches. So an image of size H×W×C becomes a sequence of shape N×(P²·C), where P×P is the patch size and N = HW/P² is the number of patches.
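The reshaping described above can be sketched in a few lines of NumPy. This is only an illustration under assumed sizes (a 224×224×3 image and P = 16); the function name `image_to_patches` is made up for this example and is not from the paper or any library.

```python
import numpy as np


def image_to_patches(image, patch_size):
    """Split an (H, W, C) image into N flattened patches of length P*P*C."""
    h, w, c = image.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("image size must be divisible by patch_size")
    # (H, W, C) -> (H/P, P, W/P, P, C): split rows and columns into blocks.
    grid = image.reshape(h // p, p, w // p, p, c)
    # -> (H/P, W/P, P, P, C): bring the two patch-grid axes to the front.
    grid = grid.transpose(0, 2, 1, 3, 4)
    # -> (N, P*P*C): one flattened patch per row, in row-major patch order.
    return grid.reshape((h // p) * (w // p), p * p * c)


image = np.arange(224 * 224 * 3, dtype=np.float32).reshape(224, 224, 3)
patches = image_to_patches(image, patch_size=16)
print(patches.shape)  # (196, 768)
```

Here N = 224·224 / 16² = 196 and each patch has 16·16·3 = 768 values.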
Before passing the patches to the transformer, the paper puts them through a linear projection to get the patch embeddings. The official JAX implementation uses a conv layer for this (it can also be done with a simple linear layer over the flattened patches, but I found that more costly). My PyTorch implementation does the same.
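To see why a conv layer can stand in for the linear projection, here is a minimal NumPy sketch. It is not the official JAX code or the PyTorch implementation mentioned above; the sizes (P = 16, embedding width D = 64) and variable names are assumptions chosen for illustration. A convolution whose kernel size and stride both equal P touches each patch exactly once, which is written here as an einsum.

```python
import numpy as np

rng = np.random.default_rng(0)
P, C, D = 16, 3, 64  # patch size, channels, embedding width (illustrative)
image = rng.standard_normal((224, 224, C))

# One P x P x C kernel per output dimension, plus a bias.
kernel = rng.standard_normal((D, P, P, C))
bias = rng.standard_normal(D)

# View the image as a grid of patches: (H/P, P, W/P, P, C).
grid = image.reshape(224 // P, P, 224 // P, P, C)

# Route 1: flatten every patch, then apply a linear layer.
patches = grid.transpose(0, 2, 1, 3, 4).reshape(-1, P * P * C)  # (196, 768)
linear_out = patches @ kernel.reshape(D, -1).T + bias           # (196, 64)

# Route 2: convolution with kernel size = stride = P, written as an einsum.
conv_out = np.einsum("iajbc,dabc->ijd", grid, kernel) + bias    # (14, 14, 64)
conv_out = conv_out.reshape(-1, D)                              # (196, 64)

print(linear_out.shape, np.allclose(linear_out, conv_out))
```

Both routes give the same patch embeddings; the conv form just avoids materialising the flattened patches as a separate step.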
As with BERT’s [class] token, a learnable class token is concatenated to the patch embeddings and serves as the class representation.
To retain the positional information of the patches, positional embeddings are added to the patch embeddings. The paper explored a 2D-aware variant as well as standard 1D positional embeddings, but did not see much advantage of one over the other.
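The two steps above, prepending the class token and adding 1D positional embeddings, can be sketched as follows. The shapes (196 patches, width 768) and the initial values are assumptions for illustration; in a real model both the class token and the positional embeddings are learned parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, dim = 196, 768  # illustrative sizes
patch_embeddings = rng.standard_normal((num_patches, dim))

# Stand-ins for learned parameters (zero / small random values here).
class_token = np.zeros((1, dim))
position_embedding = 0.02 * rng.standard_normal((num_patches + 1, dim))

# Prepend the class token: (196, 768) -> (197, 768).
tokens = np.concatenate([class_token, patch_embeddings], axis=0)

# Add one positional embedding per sequence position (class token included).
tokens = tokens + position_embedding

print(tokens.shape)  # (197, 768)
```

The transformer output at position 0 is then used as the image representation for classification.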
Hybrid Architecture
An alternative is to use intermediate feature maps of a ResNet instead of image patches as input to the transformer. The 2D feature map from the earlier layers of the ResNet is flattened, projected to the transformer dimension, and fed to the transformer. The class token and positional embeddings are added as described above.
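As a rough sketch of the hybrid input, the snippet below flattens a feature map and projects it to the transformer width. The feature map here is random data standing in for a ResNet output, and its 14×14×1024 shape and the width of 768 are assumptions for illustration only.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for an intermediate ResNet feature map (H', W', C'); assumed shape.
feature_map = rng.standard_normal((14, 14, 1024))
dim = 768  # transformer width (illustrative)

# Stand-in for the learned projection to the transformer dimension.
projection = 0.02 * rng.standard_normal((1024, dim))

# Flatten the spatial grid into a sequence, then project each position.
sequence = feature_map.reshape(-1, feature_map.shape[-1])  # (196, 1024)
tokens = sequence @ projection                             # (196, 768)

print(tokens.shape)  # (196, 768)
```

From here the sequence is treated exactly like a sequence of patch embeddings.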
The Vision Transformer is pretrained on large datasets such as ImageNet-1k, ImageNet-21k, and JFT-300M, and then fine-tuned on the dataset of the target task. The table below shows the results of fine-tuning a Vision Transformer pretrained on JFT-300M.
You can find my repo for the PyTorch implementation here. I used ImageNet-1k pretrained weights from https://github.com/rwightman/pytorch-image-models/ and updated the checkpoint for my implementation. The checkpoint can be found here.
You can also find a PyTorch Kaggle kernel for fine-tuning the Vision Transformer on TPU here.
Credit: BecomingHuman By: nachiket tanksale