Skip to content

Commit 253d65a

Browse files
authored
Revert "Refactor TorchAOBaseTensor for better BC support" (#2854)
Revert "Refactor TorchAOBaseTensor for better BC support (#2793)" This reverts commit a9ffa50.
1 parent 07fbc89 commit 253d65a

File tree

4 files changed

+72
-215
lines changed

4 files changed

+72
-215
lines changed

test/test_utils.py

Lines changed: 11 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -186,103 +186,60 @@ class MyTensor(TorchAOBaseTensor):
186186
tensor_data_names = ["qdata"]
187187
tensor_attribute_names = ["attr", "device"]
188188

189-
def __new__(cls, qdata, attr, device):
189+
def __new__(cls, qdata, attr, device=None):
190190
shape = qdata.shape
191191
if device is None:
192192
device = qdata.device
193193
kwargs = {"device": device}
194194
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
195195

196-
def __init__(self, qdata, attr, device):
196+
def __init__(self, qdata, attr, device=None):
197197
self.qdata = qdata
198198
self.attr = attr
199199

200200
l = torch.nn.Linear(2, 3)
201-
l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr", None))
201+
l.weight = torch.nn.Parameter(MyTensor(l.weight, "attr"))
202202
lp_tensor = l.weight
203203

204204
another_tensor = torch.nn.Linear(2, 3).weight
205205
# attribute has to be the same
206-
lp_tensor_for_copy = MyTensor(another_tensor, "attr", None)
206+
lp_tensor_for_copy = MyTensor(another_tensor, "attr")
207207
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
208208

209209
@skip_if_no_cuda()
210210
def test_default_impls_with_optional_data(self):
211211
class MyTensorWithOptionalData(TorchAOBaseTensor):
212212
tensor_data_names = ["qdata"]
213-
tensor_attribute_names = ["attr", "device"]
214213
optional_tensor_data_names = ["zero_point"]
215-
216-
def __new__(cls, qdata, attr, device, zero_point=None):
217-
shape = qdata.shape
218-
if device is None:
219-
device = qdata.device
220-
kwargs = {"device": device}
221-
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
222-
223-
def __init__(self, qdata, attr, device, zero_point=None):
224-
self.qdata = qdata
225-
self.attr = attr
226-
self.zero_point = zero_point
227-
228-
# test both the optional Tensor is None
229-
# and not None
230-
l = torch.nn.Linear(2, 3)
231-
lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, None)
232-
l = torch.nn.Linear(2, 3)
233-
lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, "attr", None, None)
234-
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
235-
236-
l = torch.nn.Linear(2, 3)
237-
lp_tensor = MyTensorWithOptionalData(
238-
l.weight, "attr", None, torch.zeros_like(l.weight)
239-
)
240-
l = torch.nn.Linear(2, 3)
241-
lp_tensor_for_copy = MyTensorWithOptionalData(
242-
l.weight, "attr", None, torch.zeros_like(l.weight)
243-
)
244-
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
245-
246-
@skip_if_no_cuda()
247-
def test_default_impls_with_optional_attr(self):
248-
class MyTensorWithOptionalData(TorchAOBaseTensor):
249-
tensor_data_names = ["qdata"]
250214
tensor_attribute_names = ["attr", "device"]
251-
optional_tensor_data_names = ["zero_point"]
252-
optional_tensor_attribute_names = ["optional_attr"]
253215

254-
def __new__(cls, qdata, attr, device, zero_point=None, optional_attr=None):
216+
def __new__(cls, qdata, zero_point=None, attr=1.0, device=None):
255217
shape = qdata.shape
256218
if device is None:
257219
device = qdata.device
258220
kwargs = {"device": device}
259221
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
260222

261-
def __init__(
262-
self, qdata, attr, device, zero_point=None, optional_attr=None
263-
):
223+
def __init__(self, qdata, zero_point=None, attr=1.0, device=None):
264224
self.qdata = qdata
265-
self.attr = attr
266225
self.zero_point = zero_point
267-
self.optional_attr = optional_attr
226+
self.attr = attr
268227

269228
# test both the optional Tensor is None
270229
# and not None
271230
l = torch.nn.Linear(2, 3)
272-
lp_tensor = MyTensorWithOptionalData(l.weight, "attr", None, zero_point=None)
231+
lp_tensor = MyTensorWithOptionalData(l.weight, None, "attr")
273232
l = torch.nn.Linear(2, 3)
274-
lp_tensor_for_copy = MyTensorWithOptionalData(
275-
l.weight, "attr", None, zero_point=None
276-
)
233+
lp_tensor_for_copy = MyTensorWithOptionalData(l.weight, None, "attr")
277234
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
278235

279236
l = torch.nn.Linear(2, 3)
280237
lp_tensor = MyTensorWithOptionalData(
281-
l.weight, "attr", None, zero_point=None, optional_attr="value"
238+
l.weight, torch.zeros_like(l.weight), "attr"
282239
)
283240
l = torch.nn.Linear(2, 3)
284241
lp_tensor_for_copy = MyTensorWithOptionalData(
285-
l.weight, "attr", None, zero_point=None, optional_attr="value"
242+
l.weight, torch.zeros_like(l.weight), "attr"
286243
)
287244
self._test_default_impls_helper(lp_tensor, lp_tensor_for_copy)
288245

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,7 @@ class Float8Tensor(TorchAOBaseTensor):
9494
"""
9595

9696
tensor_data_names = ["qdata", "scale"]
97-
tensor_attribute_names = []
98-
optional_tensor_attribute_names = [
97+
tensor_attribute_names = [
9998
"block_size",
10099
"mm_config",
101100
"hp_value_lb",
@@ -107,15 +106,15 @@ class Float8Tensor(TorchAOBaseTensor):
107106

108107
def __new__(
109108
cls,
110-
qdata: torch.Tensor,
111-
scale: torch.Tensor,
112-
block_size: Optional[List[int]] = None,
113-
mm_config: Optional[Float8MMConfig] = None,
114-
hp_value_lb: Optional[float] = None,
115-
hp_value_ub: Optional[float] = None,
116-
act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
117-
kernel_preference: KernelPreference = KernelPreference.AUTO,
118-
dtype: Optional[torch.dtype] = None,
109+
qdata,
110+
scale,
111+
block_size,
112+
mm_config,
113+
hp_value_lb,
114+
hp_value_ub,
115+
act_quant_kwargs,
116+
kernel_preference,
117+
dtype,
119118
):
120119
shape = qdata.shape
121120
kwargs = {}

torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,17 @@ class Int4PreshuffledTensor(TorchAOBaseTensor):
7575
"""
7676

7777
tensor_data_names = ["qdata", "group_scale"]
78-
tensor_attribute_names = ["block_size", "shape"]
7978
optional_tensor_data_names = ["group_zero", "row_scale"]
79+
tensor_attribute_names = ["block_size", "shape"]
8080

8181
def __new__(
8282
cls,
83-
qdata: torch.Tensor,
84-
group_scale: torch.Tensor,
85-
block_size: List[int],
86-
shape: List[int],
87-
group_zero: Optional[torch.Tensor] = None,
88-
row_scale: Optional[torch.Tensor] = None,
83+
qdata,
84+
group_scale,
85+
group_zero,
86+
row_scale,
87+
block_size,
88+
shape,
8989
):
9090
kwargs = {}
9191
kwargs["device"] = qdata.device
@@ -97,19 +97,19 @@ def __init__(
9797
self,
9898
qdata: torch.Tensor,
9999
group_scale: torch.Tensor,
100+
group_zero: Optional[torch.Tensor],
101+
row_scale: Optional[torch.Tensor],
100102
block_size: List[int],
101103
shape: List[int],
102-
group_zero: Optional[torch.Tensor] = None,
103-
row_scale: Optional[torch.Tensor] = None,
104104
):
105105
# one and only one of group_scale and group_zero should be None
106106
assert group_zero is None or row_scale is None
107107
assert not (group_zero is not None and row_scale is not None)
108108
self.qdata = qdata
109-
self.row_scale = row_scale
110-
self.block_size = block_size
111109
self.group_scale = group_scale
112110
self.group_zero = group_zero
111+
self.row_scale = row_scale
112+
self.block_size = block_size
113113

114114
def _quantization_type(self):
115115
return f"shape={self.shape}, block_size={self.block_size}, device={self.device}"
@@ -178,10 +178,10 @@ def from_hp(
178178
return Int4PreshuffledTensor(
179179
qdata=wq,
180180
group_scale=group_scale,
181-
block_size=block_size,
182-
shape=original_shape,
183181
group_zero=group_zero,
184182
row_scale=row_scale,
183+
block_size=block_size,
184+
shape=original_shape,
185185
)
186186

187187

0 commit comments

Comments
 (0)