86 changes: 86 additions & 0 deletions configs/_base_/datasets/flickr30k_caption.py
@@ -0,0 +1,86 @@
# data settings

data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='CleanCaption', keys='gt_caption'),
    dict(
        type='PackInputs',
        algorithm_keys=['gt_caption'],
        meta_keys=['image_id'],
    ),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(384, 384),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='PackInputs', meta_keys=['image_id']),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type='Flickr30kCaption',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
    drop_last=True,
)

val_dataloader = dict(
    batch_size=16,
    num_workers=5,
    dataset=dict(
        type='Flickr30kCaption',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='val',
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)

val_evaluator = dict(
    type='COCOCaption',
    ann_file='data/coco/annotations/coco_karpathy_val_gt.json',
)

# If you want standard test, please manually configure the test dataset
test_dataloader = dict(
    batch_size=16,
    num_workers=5,
    dataset=dict(
        type='Flickr30kCaption',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='test',
        pipeline=test_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
test_evaluator = val_evaluator
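For readers less familiar with MMEngine-style pipelines, the sketch below shows roughly how the `train_pipeline` above is built and applied to a single sample through mmpretrain's transform registry. It is not part of the PR; the demo image path, caption, and `image_id` are placeholders, and it assumes mmpretrain is installed and the snippet runs from the repository root.

```python
from mmengine.dataset import Compose
from mmpretrain.registry import TRANSFORMS

# Same transform configs as in flickr30k_caption.py above.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        interpolation='bicubic',
        backend='pillow'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='CleanCaption', keys='gt_caption'),
    dict(
        type='PackInputs',
        algorithm_keys=['gt_caption'],
        meta_keys=['image_id'],
    ),
]

# Each dict is built into a transform object, then chained.
pipeline = Compose([TRANSFORMS.build(cfg) for cfg in train_pipeline])

# demo/demo.JPEG ships with the repo; caption and image_id are made up.
sample = pipeline(dict(
    img_path='demo/demo.JPEG',
    gt_caption='a woman reading a book in a park',
    image_id=0,
))
print(sample['inputs'].shape)   # packed image tensor (C, H, W)
print(sample['data_samples'])   # DataSample carrying gt_caption and image_id
```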
112 changes: 112 additions & 0 deletions configs/_base_/datasets/flickr30k_retrieval.py
@@ -0,0 +1,112 @@
# data settings
data_preprocessor = dict(
    type='MultiModalDataPreprocessor',
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

rand_increasing_policies = [
    dict(type='AutoContrast'),
    dict(type='Equalize'),
    dict(type='Rotate', magnitude_key='angle', magnitude_range=(0, 30)),
    dict(
        type='Brightness', magnitude_key='magnitude',
        magnitude_range=(0, 0.0)),
    dict(type='Sharpness', magnitude_key='magnitude', magnitude_range=(0, 0)),
    dict(
        type='Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        direction='horizontal'),
    dict(
        type='Shear',
        magnitude_key='magnitude',
        magnitude_range=(0, 0.3),
        direction='vertical'),
]

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        crop_ratio_range=(0.5, 1.0),
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies=rand_increasing_policies,
        num_policies=2,
        magnitude_level=5),
    dict(type='CleanCaption', keys='text'),
    dict(
        type='PackInputs',
        algorithm_keys=['text', 'is_matched'],
        meta_keys=['image_id']),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(384, 384),
        interpolation='bicubic',
        backend='pillow'),
    dict(type='CleanCaption', keys='text'),
    dict(
        type='PackInputs',
        algorithm_keys=['text', 'gt_text_id', 'gt_image_id'],
        meta_keys=['image_id']),
]

train_dataloader = dict(
    batch_size=32,
    num_workers=16,
    dataset=dict(
        type='Flickr30kRetrieval',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
    drop_last=True,
)

val_dataloader = dict(
    batch_size=64,
    num_workers=16,
    dataset=dict(
        type='Flickr30kRetrieval',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='val',
        pipeline=test_pipeline,
        test_mode=True,  # This is required for evaluation
    ),
    sampler=dict(type='SequentialSampler', subsample_type='sequential'),
    persistent_workers=True,
)

val_evaluator = dict(type='RetrievalRecall', topk=(1, 5, 10))

# If you want standard test, please manually configure the test dataset
test_dataloader = dict(
    batch_size=64,
    num_workers=16,
    dataset=dict(
        type='Flickr30kRetrieval',
        data_root='data/flickr30k',
        ann_file='annotations/dataset_flickr30k.json',
        data_prefix='images',
        split='test',
        pipeline=test_pipeline,
        test_mode=True,  # This is required for evaluation
    ),
    sampler=dict(type='SequentialSampler', subsample_type='sequential'),
    persistent_workers=True,
)
test_evaluator = val_evaluator
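As a rough sketch of what the dataloaders above expect on disk, the snippet below builds the validation split directly through the dataset registry. The directory layout in the comment is inferred from `data_root`/`ann_file`/`data_prefix`, not documented by the PR, and the exact fields of a raw sample may differ.

```python
# Assumed layout (inferred from the config, not documented here):
#   data/flickr30k/
#   ├── annotations/dataset_flickr30k.json   # Karpathy-split annotation file
#   └── images/                              # Flickr30k images
from mmpretrain.registry import DATASETS

val_set = DATASETS.build(dict(
    type='Flickr30kRetrieval',
    data_root='data/flickr30k',
    ann_file='annotations/dataset_flickr30k.json',
    data_prefix='images',
    split='val',
    pipeline=[],      # skip transforms to inspect the raw annotations
    test_mode=True,   # required for evaluation, as the config notes
))
print(len(val_set))
print(val_set[0])     # e.g. img_path, text, gt_text_id / gt_image_id
```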
12 changes: 12 additions & 0 deletions configs/blip/README.md
@@ -82,6 +82,18 @@ python tools/test.py configs/blip/blip-base_8xb32_caption.py https://download.op
| :------------------------------- | :--------: | :------: | :------: | :--------------------------------------: | :----------------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_retrieval`\* | 447.49 | 64.82 | 86.28 | [config](./blip-base_8xb32_retrieval.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |

### Image-To-Text Retrieval on Flickr30k

| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
| :------------------------------- | :--------: | :------: | :------: | :------------------------------------------------: | :------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_retrieval`\* | 447.49 | 95.10# | 99.60# | [config](./blip-base_8xb32_retrieval_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |

### Text-To-Image Retrieval on Flickr30k

| Model | Params (M) | Recall@1 | Recall@5 | Config | Download |
| :------------------------------- | :--------: | :------: | :------: | :------------------------------------------------: | :------------------------------------------------------------------------------------------: |
| `blip-base_3rdparty_retrieval`\* | 447.49 | 85.26# | 96.58# | [config](./blip-base_8xb32_retrieval_flickr30k.py) | [model](https://download.openmmlab.com/mmclassification/v1/blip/blip-base_3rdparty_coco-retrieval_20230419-a1804d2c.pth) |

### NLVR on NLVR2

| Model | Params (M) | Top-1 (%) | Config | Download |
59 changes: 59 additions & 0 deletions configs/blip/blip-base_8xb32_caption_flickr30k.py
@@ -0,0 +1,59 @@
_base_ = [
    '../_base_/datasets/flickr30k_caption.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='BlipCaption',
    vision_encoder=dict(
        type='VisionTransformer',
        arch='b',
        img_size=384,
        patch_size=16,
        out_type='raw',
    ),
    tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
    decoder_head=dict(
        type='SeqGenerationHead',
        decoder=dict(
            type='XBertLMHeadDecoder',
            med_config=dict(
                architectures=['BertModel'],
                attention_probs_dropout_prob=0.1,
                hidden_act='gelu',
                hidden_dropout_prob=0.1,
                hidden_size=768,
                initializer_range=0.02,
                intermediate_size=3072,
                layer_norm_eps=1e-12,
                max_position_embeddings=512,
                model_type='bert',
                num_attention_heads=12,
                num_hidden_layers=12,
                pad_token_id=0,
                add_type_embeddings=False,
                vocab_size=30524,
                encoder_width=768,
                add_cross_attention=True),
        ),
    ),
    prompt='a picture of ',
    max_txt_len=20,
)

# schedule settings
optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))

param_scheduler = [
    dict(
        type='CosineAnnealingLR',
        by_epoch=True,
        begin=0,
        end=10,
    )
]

train_cfg = dict(max_epochs=10)
val_cfg = dict()
test_cfg = dict()
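Once a checkpoint has been trained from this config, caption inference could look like the sketch below; the checkpoint path under `work_dirs/` and the demo image are placeholders, not files produced by this PR.

```python
from mmpretrain.apis import ImageCaptionInferencer

# Placeholder paths: the config from this PR plus a locally trained checkpoint.
inferencer = ImageCaptionInferencer(
    model='configs/blip/blip-base_8xb32_caption_flickr30k.py',
    pretrained='work_dirs/blip-base_8xb32_caption_flickr30k/epoch_10.pth')

result = inferencer('demo/cat-dog.png')[0]
print(result['pred_caption'])
```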
83 changes: 83 additions & 0 deletions configs/blip/blip-base_8xb32_retrieval_flickr30k.py
@@ -0,0 +1,83 @@
_base_ = [
    '../_base_/datasets/flickr30k_retrieval.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='BlipRetrieval',
    tokenizer=dict(type='BlipTokenizer', name_or_path='bert-base-uncased'),
    vision_backbone=dict(
        type='VisionTransformer',
        arch='b',
        img_size=384,
        patch_size=16,
        out_type='raw',
    ),
    text_backbone=dict(
        type='XBertEncoder',
        med_config=dict(
            architectures=['BertModel'],
            attention_probs_dropout_prob=0.1,
            hidden_act='gelu',
            hidden_dropout_prob=0.1,
            hidden_size=768,
            initializer_range=0.02,
            intermediate_size=3072,
            layer_norm_eps=1e-12,
            max_position_embeddings=512,
            model_type='bert',
            num_attention_heads=12,
            num_hidden_layers=12,
            pad_token_id=0,
            add_type_embeddings=False,
            vocab_size=30524,
            encoder_width=768,
            add_cross_attention=True),
    ),
    vision_neck=dict(
        type='Linear',
        in_features=768,
        out_features=256,
    ),
    text_neck=dict(
        type='Linear',
        in_features=768,
        out_features=256,
    ),
    head=dict(
        type='ITCHead',
        embed_dim=256,
    ),
    multimodal_head=dict(
        type='ITMHead',
        hidden_size=768,
        with_pooler=False,
    ),
    topk=256,
    max_txt_len=35,
)

# optimizer
optimizer = dict(type='AdamW', lr=2e-5, weight_decay=0.04)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer)

# learning rate scheduler
param_scheduler = [dict(type='CosineAnnealingLR', by_epoch=True)]

# runtime settings
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=6)
val_cfg = dict(type='RetrievalValLoop')
test_cfg = dict(type='RetrievalTestLoop')

randomness = dict(seed=42)

default_hooks = dict(logger=dict(interval=1))

custom_hooks = [
    dict(
        type='WarmupParamHook',
        param_name='alpha',
        module_name='head',
        warmup_epochs=2)
]
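For completeness, a minimal way to launch this config programmatically with MMEngine's `Runner` is sketched below (single process; the `8xb32` in the file name assumes 8 GPUs via `tools/dist_train.sh`, and the `work_dir` is an arbitrary choice). It assumes the Flickr30k data has been prepared under `data/flickr30k` as in the dataset config.

```python
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/blip/blip-base_8xb32_retrieval_flickr30k.py')
cfg.work_dir = 'work_dirs/blip-base_8xb32_retrieval_flickr30k'  # assumed output dir

runner = Runner.from_cfg(cfg)
runner.train()   # validation uses the RetrievalValLoop configured above
```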
7 changes: 4 additions & 3 deletions mmpretrain/datasets/__init__.py
@@ -38,6 +38,8 @@
from .coco_retrieval import COCORetrieval
from .coco_vqa import COCOVQA
from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
from .flickr30k_caption import Flickr30kCaption
from .flickr30k_retrieval import Flickr30kRetrieval
from .gqa_dataset import GQA
from .nocaps import NoCaps
from .refcoco import RefCOCO
@@ -47,7 +49,6 @@

__all__.extend([
    'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
-   'FlamingoEvalCOCOVQA', 'RefCOCO', 'VisualGenomeQA', 'ScienceQA',
-   'NoCaps'
-   'GQA', 'TextVQA'
+   'FlamingoEvalCOCOVQA', 'Flickr30kCaption', 'Flickr30kRetrieval',
+   'RefCOCO', 'VisualGenomeQA', 'ScienceQA', 'NoCaps', 'GQA', 'TextVQA'
])