Commit 728d629

[low-bit optim] Change 8-bit and FP8 optim block size from 2048 to 256 to match new bnb v0.44 (pytorch#927)

1 parent 26e790d

File tree: 5 files changed, +12 -8 lines
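The user-visible effect of this commit: the default quantization block size for the 8-bit and FP8 optimizers drops from 2048 to 256, matching the new bitsandbytes v0.44 default. A minimal sketch of the change, assuming the optimizer names documented in the README below (the old value stays available through the keyword-only `block_size` argument):

```python
from torch import nn
from torchao.prototype.low_bit_optim import Adam8bit

model = nn.Linear(32, 32)
optim_new = Adam8bit(model.parameters())                   # default is now block_size=256
optim_old = Adam8bit(model.parameters(), block_size=2048)  # previous default, still selectable
```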

test/prototype/test_low_bit_optim.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -3,6 +3,7 @@
 
 import pytest
 import torch
+from packaging.version import Version
 from torch import nn
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -105,8 +106,11 @@ def test_optim_8bit_correctness(self, optim_name):
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
         model2 = copy.deepcopy(model1)
 
+        # https://github.com/bitsandbytes-foundation/bitsandbytes/releases/tag/v0.44.0
+        block_size = 256 if Version(bnb.__version__) >= Version("0.44.0") else 2048
+
         optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
-        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters())
+        optim2 = getattr(low_bit_optim, optim_name)(model2.parameters(), block_size=block_size)
 
         for _ in range(2):
             x = torch.randn(4, 32, device=device)
```
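The test gates the expected block size on `packaging.version.Version` rather than comparing version strings directly. A standalone illustration of why that matters (only `packaging` is assumed installed):

```python
from packaging.version import Version

# Version-aware comparison handles multi-digit components correctly.
assert Version("0.45.1") >= Version("0.44.0")
assert not Version("0.43.3") >= Version("0.44.0")

# Naive string comparison is lexicographic and gets this wrong:
assert "0.9.0" > "0.44.0"  # "9" > "4", even though 0.9.0 is the older release
```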

torchao/prototype/low_bit_optim/README.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -19,7 +19,7 @@ model = ...
 optim = Adam8bit(model.parameters())
 ```
 
-To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.
+To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 256 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.
 
 **Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.
 
````
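A hedged usage sketch of the block-size knob described in the README paragraph above, using the API exactly as the README shows it (layer sizes are illustrative):

```python
from torch import nn
from torchao.prototype.low_bit_optim import Adam4bit, AdamFp8

model = nn.Linear(128, 128)
optim_4bit = Adam4bit(model.parameters())                # 4-bit default: block_size=128
optim_fp8 = AdamFp8(model.parameters(), block_size=256)  # explicit; matches the new default
```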

torchao/prototype/low_bit_optim/adam.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -161,7 +161,7 @@ def __init__(
         weight_decay=0,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False)
 
@@ -199,7 +199,7 @@ def __init__(
         weight_decay=0,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=False)
 
@@ -218,7 +218,7 @@ def __init__(
         weight_decay=1e-2,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True)
 
@@ -256,7 +256,7 @@ def __init__(
         weight_decay=1e-2,
         amsgrad=False,
         *,
-        block_size=2048,
+        block_size=256,
     ) -> None:
         super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size, is_adamw=True)
 
```
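All four constructors declare `block_size` after a bare `*`, so it is keyword-only; a quick sketch of the calling convention (class name as listed in the README):

```python
from torch import nn
from torchao.prototype.low_bit_optim import AdamW8bit

model = nn.Linear(64, 64)
optim = AdamW8bit(model.parameters(), lr=1e-3, block_size=256)  # OK: passed by keyword

# Passing it positionally, e.g.
#   AdamW8bit(model.parameters(), 1e-3, (0.9, 0.999), 1e-8, 1e-2, False, 256)
# raises TypeError, because block_size sits after the bare `*`.
```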

torchao/prototype/low_bit_optim/subclass_8bit.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -53,7 +53,7 @@ def dequantize(self, output_dtype=None):
         return dequant_with_qmap(self.codes, self.qmap, self.scale).to(dtype)
 
     @classmethod
-    def zeros(cls, shape, signed: bool = True, block_size: int = 2048, device=None):
+    def zeros(cls, shape, signed: bool = True, block_size: int = 256, device=None):
         codes = torch.zeros(shape, dtype=torch.uint8, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
         qmap = torch.tensor(QMAP_SIGNED if signed else QMAP_UNSIGNED, device=device)
```
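The `zeros()` body above also shows where the block size bites: one scale value is allocated per block (`codes.numel() // block_size`), so shrinking blocks from 2048 to 256 means 8x more scales and finer-grained quantization. A small arithmetic sketch:

```python
import torch

numel = 1024 * 128  # e.g. the 1024x128 linear layer from the test above
for block_size in (2048, 256):
    scale = torch.zeros(numel // block_size)
    print(block_size, scale.numel())  # 2048 -> 64 scales; 256 -> 512 scales
```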

torchao/prototype/low_bit_optim/subclass_fp8.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -60,7 +60,7 @@ def dequantize(self, output_dtype=None):
         return float_data.view(self.codes.shape).to(dtype)
 
     @classmethod
-    def zeros(cls, shape, block_size: int = 2048, device=None):
+    def zeros(cls, shape, block_size: int = 256, device=None):
         codes = torch.zeros(shape, dtype=DTYPE, device=device)
         scale = torch.zeros(codes.numel() // block_size, device=device)
         return cls(codes, scale)
```
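The FP8 state uses the same per-block layout, just with FP8 codes instead of uint8 ones. A minimal sketch, assuming `DTYPE` is an FP8 torch dtype such as `torch.float8_e4m3fn` (the diff does not show its actual value):

```python
import torch

DTYPE = torch.float8_e4m3fn  # assumption: the module's DTYPE is an FP8 dtype
codes = torch.zeros(1024, dtype=DTYPE)
scale = torch.zeros(codes.numel() // 256)  # one float32 scale per 256 codes
print(codes.dtype, scale.numel())          # torch.float8_e4m3fn 4
```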
