diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py index 9135b07c9f259a..f86111f330c844 100644 --- a/Lib/dataclasses.py +++ b/Lib/dataclasses.py @@ -6,6 +6,7 @@ import keyword import builtins import functools +import collections import _thread @@ -1024,7 +1025,7 @@ def is_dataclass(obj): return hasattr(cls, _FIELDS) -def asdict(obj, *, dict_factory=dict): +def asdict(obj, *, dict_factory=dict, transformers=None): """Return the fields of a dataclass instance as a new dictionary mapping field names to field values. @@ -1045,48 +1046,58 @@ class C: """ if not _is_dataclass_instance(obj): raise TypeError("asdict() should be called on dataclass instances") - return _asdict_inner(obj, dict_factory) + if transformers is None: + transformers = _DEFAULT_TRANFORMERS + else: + transformers = tuple(transformers) + _DEFAULT_TRANFORMERS + + return _asdict_inner(obj, dict_factory, transformers) + +def _handle_dataclass(obj, asdict_inner): + result = [] + for f in fields(obj): + value = asdict_inner(getattr(obj, f.name)) + result.append((f.name, value)) + return asdict_inner.dict_factory(result) + +def _handle_sequence(obj, asdict_inner): + return type(obj)(map(asdict_inner, obj)) + +def _handle_dict(obj, asdict_inner): + return type(obj)(((asdict_inner(k), asdict_inner(v)) + for k, v in obj.items())) + +def _handle_defaultdict(obj, asdict_inner): + return type(obj)(obj.default_factory, + ((asdict_inner(k), asdict_inner(v)) + for k, v in obj.items())) + +def _handle_namedtuple(obj, asdict_inner): + return type(obj)(*map(asdict_inner, obj)) + +_DEFAULT_TRANFORMERS = ( + (_is_dataclass_instance, _handle_dataclass), + (lambda obj: isinstance(obj, tuple) and hasattr(obj, "_fields"), + _handle_namedtuple), + ((tuple, list), _handle_sequence), + (collections.defaultdict, _handle_defaultdict), + (dict, _handle_dict), +) + +def _asdict_inner(obj, dict_factory, transformers): + def asdict_inner(obj): + for cond, transformer in transformers: + if isinstance(cond, (type, tuple)): + if isinstance(obj, cond): + return transformer(obj, asdict_inner) + elif cond(obj): + return transformer(obj, asdict_inner) + return copy.deepcopy(obj) -def _asdict_inner(obj, dict_factory): - if _is_dataclass_instance(obj): - result = [] - for f in fields(obj): - value = _asdict_inner(getattr(obj, f.name), dict_factory) - result.append((f.name, value)) - return dict_factory(result) - elif isinstance(obj, tuple) and hasattr(obj, '_fields'): - # obj is a namedtuple. Recurse into it, but the returned - # object is another namedtuple of the same type. This is - # similar to how other list- or tuple-derived classes are - # treated (see below), but we just need to create them - # differently because a namedtuple's __init__ needs to be - # called differently (see bpo-34363). + asdict_inner.dict_factory = dict_factory - # I'm not using namedtuple's _asdict() - # method, because: - # - it does not recurse in to the namedtuple fields and - # convert them to dicts (using dict_factory). - # - I don't actually want to return a dict here. The the main - # use case here is json.dumps, and it handles converting - # namedtuples to lists. Admittedly we're losing some - # information here when we produce a json list instead of a - # dict. Note that if we returned dicts here instead of - # namedtuples, we could no longer call asdict() on a data - # structure where a namedtuple was used as a dict key. - - return type(obj)(*[_asdict_inner(v, dict_factory) for v in obj]) - elif isinstance(obj, (list, tuple)): - # Assume we can create an object of this type by passing in a - # generator (which is not true for namedtuples, handled - # above). - return type(obj)(_asdict_inner(v, dict_factory) for v in obj) - elif isinstance(obj, dict): - return type(obj)((_asdict_inner(k, dict_factory), - _asdict_inner(v, dict_factory)) - for k, v in obj.items()) - else: - return copy.deepcopy(obj) + return asdict_inner(obj) def astuple(obj, *, tuple_factory=tuple): diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py index 037bf4c2214279..a59714ddf78c94 100644 --- a/Lib/test/test_dataclasses.py +++ b/Lib/test/test_dataclasses.py @@ -1652,6 +1652,18 @@ class C: t = astuple(c, tuple_factory=list) self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) + def test_helper_asdict_defaultdict(self): + @dataclass + class C: + d: dict + + from collections import defaultdict + c = C(defaultdict(int)) + self.assertEqual(asdict(c), { + 'd': defaultdict(int) + }) + + def test_dynamic_class_creation(self): cls_dict = {'__annotations__': {'x': int, 'y': int}, }