@@ -186,103 +186,60 @@ class MyTensor(TorchAOBaseTensor):
186
186
tensor_data_names = ["qdata" ]
187
187
tensor_attribute_names = ["attr" , "device" ]
188
188
189
- def __new__ (cls , qdata , attr , device ):
189
+ def __new__ (cls , qdata , attr , device = None ):
190
190
shape = qdata .shape
191
191
if device is None :
192
192
device = qdata .device
193
193
kwargs = {"device" : device }
194
194
return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
195
195
196
- def __init__ (self , qdata , attr , device ):
196
+ def __init__ (self , qdata , attr , device = None ):
197
197
self .qdata = qdata
198
198
self .attr = attr
199
199
200
200
l = torch .nn .Linear (2 , 3 )
201
- l .weight = torch .nn .Parameter (MyTensor (l .weight , "attr" , None ))
201
+ l .weight = torch .nn .Parameter (MyTensor (l .weight , "attr" ))
202
202
lp_tensor = l .weight
203
203
204
204
another_tensor = torch .nn .Linear (2 , 3 ).weight
205
205
# attribute has to be the same
206
- lp_tensor_for_copy = MyTensor (another_tensor , "attr" , None )
206
+ lp_tensor_for_copy = MyTensor (another_tensor , "attr" )
207
207
self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
208
208
209
209
@skip_if_no_cuda ()
210
210
def test_default_impls_with_optional_data (self ):
211
211
class MyTensorWithOptionalData (TorchAOBaseTensor ):
212
212
tensor_data_names = ["qdata" ]
213
- tensor_attribute_names = ["attr" , "device" ]
214
213
optional_tensor_data_names = ["zero_point" ]
215
-
216
- def __new__ (cls , qdata , attr , device , zero_point = None ):
217
- shape = qdata .shape
218
- if device is None :
219
- device = qdata .device
220
- kwargs = {"device" : device }
221
- return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
222
-
223
- def __init__ (self , qdata , attr , device , zero_point = None ):
224
- self .qdata = qdata
225
- self .attr = attr
226
- self .zero_point = zero_point
227
-
228
- # test both the optional Tensor is None
229
- # and not None
230
- l = torch .nn .Linear (2 , 3 )
231
- lp_tensor = MyTensorWithOptionalData (l .weight , "attr" , None , None )
232
- l = torch .nn .Linear (2 , 3 )
233
- lp_tensor_for_copy = MyTensorWithOptionalData (l .weight , "attr" , None , None )
234
- self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
235
-
236
- l = torch .nn .Linear (2 , 3 )
237
- lp_tensor = MyTensorWithOptionalData (
238
- l .weight , "attr" , None , torch .zeros_like (l .weight )
239
- )
240
- l = torch .nn .Linear (2 , 3 )
241
- lp_tensor_for_copy = MyTensorWithOptionalData (
242
- l .weight , "attr" , None , torch .zeros_like (l .weight )
243
- )
244
- self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
245
-
246
- @skip_if_no_cuda ()
247
- def test_default_impls_with_optional_attr (self ):
248
- class MyTensorWithOptionalData (TorchAOBaseTensor ):
249
- tensor_data_names = ["qdata" ]
250
214
tensor_attribute_names = ["attr" , "device" ]
251
- optional_tensor_data_names = ["zero_point" ]
252
- optional_tensor_attribute_names = ["optional_attr" ]
253
215
254
- def __new__ (cls , qdata , attr , device , zero_point = None , optional_attr = None ):
216
+ def __new__ (cls , qdata , zero_point = None , attr = 1.0 , device = None ):
255
217
shape = qdata .shape
256
218
if device is None :
257
219
device = qdata .device
258
220
kwargs = {"device" : device }
259
221
return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
260
222
261
- def __init__ (
262
- self , qdata , attr , device , zero_point = None , optional_attr = None
263
- ):
223
+ def __init__ (self , qdata , zero_point = None , attr = 1.0 , device = None ):
264
224
self .qdata = qdata
265
- self .attr = attr
266
225
self .zero_point = zero_point
267
- self .optional_attr = optional_attr
226
+ self .attr = attr
268
227
269
228
# test both the optional Tensor is None
270
229
# and not None
271
230
l = torch .nn .Linear (2 , 3 )
272
- lp_tensor = MyTensorWithOptionalData (l .weight , "attr" , None , zero_point = None )
231
+ lp_tensor = MyTensorWithOptionalData (l .weight , None , "attr" )
273
232
l = torch .nn .Linear (2 , 3 )
274
- lp_tensor_for_copy = MyTensorWithOptionalData (
275
- l .weight , "attr" , None , zero_point = None
276
- )
233
+ lp_tensor_for_copy = MyTensorWithOptionalData (l .weight , None , "attr" )
277
234
self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
278
235
279
236
l = torch .nn .Linear (2 , 3 )
280
237
lp_tensor = MyTensorWithOptionalData (
281
- l .weight , "attr" , None , zero_point = None , optional_attr = "value "
238
+ l .weight , torch . zeros_like ( l . weight ), "attr "
282
239
)
283
240
l = torch .nn .Linear (2 , 3 )
284
241
lp_tensor_for_copy = MyTensorWithOptionalData (
285
- l .weight , "attr" , None , zero_point = None , optional_attr = "value "
242
+ l .weight , torch . zeros_like ( l . weight ), "attr "
286
243
)
287
244
self ._test_default_impls_helper (lp_tensor , lp_tensor_for_copy )
288
245
0 commit comments