Skip to content

Commit f077d7c

Browse files
authored
Merge branch 'main' into raft_model_arch
2 parents 9ae9e38 + 9b57de6 commit f077d7c

File tree

4 files changed

+44
-49
lines changed

4 files changed

+44
-49
lines changed

references/video_classification/train.py

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,13 @@
1212
from torch.utils.data.dataloader import default_collate
1313
from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
1414

15-
try:
16-
from apex import amp
17-
except ImportError:
18-
amp = None
19-
20-
2115
try:
2216
from torchvision.prototype import models as PM
2317
except ImportError:
2418
PM = None
2519

2620

27-
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
21+
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
2822
model.train()
2923
metric_logger = utils.MetricLogger(delimiter=" ")
3024
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
3428
for video, target in metric_logger.log_every(data_loader, print_freq, header):
3529
start_time = time.time()
3630
video, target = video.to(device), target.to(device)
37-
output = model(video)
38-
loss = criterion(output, target)
31+
with torch.cuda.amp.autocast(enabled=scaler is not None):
32+
output = model(video)
33+
loss = criterion(output, target)
3934

4035
optimizer.zero_grad()
41-
if apex:
42-
with amp.scale_loss(loss, optimizer) as scaled_loss:
43-
scaled_loss.backward()
36+
37+
if scaler is not None:
38+
scaler.scale(loss).backward()
39+
scaler.step(optimizer)
40+
scaler.update()
4441
else:
4542
loss.backward()
46-
optimizer.step()
43+
optimizer.step()
4744

4845
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
4946
batch_size = video.shape[0]
@@ -101,11 +98,6 @@ def collate_fn(batch):
10198
def main(args):
10299
if args.weights and PM is None:
103100
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
104-
if args.apex and amp is None:
105-
raise RuntimeError(
106-
"Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
107-
"to enable mixed-precision training."
108-
)
109101

110102
if args.output_dir:
111103
utils.mkdir(args.output_dir)
@@ -224,9 +216,7 @@ def main(args):
224216

225217
lr = args.lr * args.world_size
226218
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
227-
228-
if args.apex:
229-
model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
219+
scaler = torch.cuda.amp.GradScaler() if args.amp else None
230220

231221
# convert scheduler to be per iteration, not per epoch, for warmup that lasts
232222
# between different epochs
@@ -267,6 +257,8 @@ def main(args):
267257
optimizer.load_state_dict(checkpoint["optimizer"])
268258
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
269259
args.start_epoch = checkpoint["epoch"] + 1
260+
if args.amp:
261+
scaler.load_state_dict(checkpoint["scaler"])
270262

271263
if args.test_only:
272264
evaluate(model, criterion, data_loader_test, device=device)
@@ -277,9 +269,7 @@ def main(args):
277269
for epoch in range(args.start_epoch, args.epochs):
278270
if args.distributed:
279271
train_sampler.set_epoch(epoch)
280-
train_one_epoch(
281-
model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
282-
)
272+
train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
283273
evaluate(model, criterion, data_loader_test, device=device)
284274
if args.output_dir:
285275
checkpoint = {
@@ -289,6 +279,8 @@ def main(args):
289279
"epoch": epoch,
290280
"args": args,
291281
}
282+
if args.amp:
283+
checkpoint["scaler"] = scaler.state_dict()
292284
utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
293285
utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
294286

@@ -363,24 +355,16 @@ def parse_args():
363355
action="store_true",
364356
)
365357

366-
# Mixed precision training parameters
367-
parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
368-
parser.add_argument(
369-
"--apex-opt-level",
370-
default="O1",
371-
type=str,
372-
help="For apex mixed precision training"
373-
"O0 for FP32 training, O1 for mixed precision training."
374-
"For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
375-
)
376-
377358
# distributed training parameters
378359
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
379360
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
380361

381362
# Prototype models only
382363
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
383364

365+
# Mixed precision training parameters
366+
parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
367+
384368
args = parser.parse_args()
385369

386370
return args

test/test_onnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def test_roi_align(self):
141141
model = ops.RoIAlign((5, 5), 1, -1)
142142
self.run_model(model, [(x, single_roi)])
143143

144+
@pytest.mark.skip(reason="ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16.")
144145
def test_roi_align_aligned(self):
145146
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
146147
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)

torchvision/io/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0) -> No
110110
raise RuntimeError(
111111
"Not compiled with video_reader support, "
112112
+ "to enable video_reader support, please install "
113-
+ "ffmpeg (version 4.2 is currently supported) and"
113+
+ "ffmpeg (version 4.2 is currently supported) and "
114114
+ "build torchvision from source."
115115
)
116116
self._c = torch.classes.torchvision.Video(path, stream, num_threads)

torchvision/prototype/datasets/utils/_internal.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os.path
99
import pathlib
1010
import pickle
11+
import platform
1112
from typing import BinaryIO
1213
from typing import (
1314
Sequence,
@@ -260,6 +261,11 @@ def _make_sharded_datapipe(root: str, dataset_size: int) -> IterDataPipe:
260261
return dp
261262

262263

264+
def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray:
265+
# A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
266+
return bytearray(file.read(-1 if count == -1 else count * item_size))
267+
268+
263269
def fromfile(
264270
file: BinaryIO,
265271
*,
@@ -293,20 +299,24 @@ def fromfile(
293299
item_size = (torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits // 8
294300
np_dtype = byte_order + char + str(item_size)
295301

296-
# PyTorch does not support tensors with underlying read-only memory. In case
297-
# - the file has a .fileno(),
298-
# - the file was opened for updating, i.e. 'r+b' or 'w+b',
299-
# - the file is seekable
300-
# we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it to
301-
# a mutable location afterwards.
302302
buffer: Union[memoryview, bytearray]
303-
try:
304-
buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
305-
# Reading from the memoryview does not advance the file cursor, so we have to do it manually.
306-
file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
307-
except (PermissionError, io.UnsupportedOperation):
308-
# A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable
309-
buffer = bytearray(file.read(-1 if count == -1 else count * item_size))
303+
if platform.system() != "Windows":
304+
# PyTorch does not support tensors with underlying read-only memory. In case
305+
# - the file has a .fileno(),
306+
# - the file was opened for updating, i.e. 'r+b' or 'w+b',
307+
# - the file is seekable
308+
# we can avoid copying the data for performance. Otherwise we fall back to simply .read() the data and copy it
309+
# to a mutable location afterwards.
310+
try:
311+
buffer = memoryview(mmap.mmap(file.fileno(), 0))[file.tell() :]
312+
# Reading from the memoryview does not advance the file cursor, so we have to do it manually.
313+
file.seek(*(0, io.SEEK_END) if count == -1 else (count * item_size, io.SEEK_CUR))
314+
except (PermissionError, io.UnsupportedOperation):
315+
buffer = _read_mutable_buffer_fallback(file, count, item_size)
316+
else:
317+
# On Windows just trying to call mmap.mmap() on a file that does not support it, may corrupt the internal state
318+
# so no data can be read afterwards. Thus, we simply ignore the possible speed-up.
319+
buffer = _read_mutable_buffer_fallback(file, count, item_size)
310320

311321
# We cannot use torch.frombuffer() directly, since it only supports the native byte order of the system. Thus, we
312322
# read the data with np.frombuffer() with the correct byte order and convert it to the native one with the

0 commit comments

Comments
 (0)