|
2 | 2 | from collections import defaultdict
|
3 | 3 | from typing import DefaultDict, Optional, Tuple, cast
|
4 | 4 |
|
5 |
| -from telegram.ext import BasePersistence |
6 |
| -from telegram.ext.utils.types import BD, CD, UD, CDCData, ConversationDict |
7 |
| - |
8 |
| -from .models import BotData, CallbackData, ChatData, ConversationData, UserData |
| 5 | +from models import BotData, CallbackData, ChatData, ConversationData, UserData |
| 6 | +from telegram.ext import BasePersistence, PersistenceInput |
| 7 | +from telegram.ext._utils.types import BD, CD, UD, CDCData, ConversationDict |
9 | 8 |
|
10 | 9 |
|
11 | 10 | class DjangoPersistence(BasePersistence[UD, CD, BD]):
|
12 | 11 | def __init__(
|
13 | 12 | self,
|
14 | 13 | namespace: str = "",
|
15 |
| - store_user_data: bool = True, |
16 |
| - store_chat_data: bool = True, |
17 |
| - store_bot_data: bool = True, |
18 |
| - store_callback_data: bool = False, |
| 14 | + store_data: Optional[PersistenceInput] = None, |
| 15 | + update_interval: float = 60, |
19 | 16 | ):
|
20 |
| - super().__init__( |
21 |
| - store_user_data=store_user_data, |
22 |
| - store_chat_data=store_chat_data, |
23 |
| - store_bot_data=store_bot_data, |
24 |
| - store_callback_data=store_callback_data, |
25 |
| - ) |
| 17 | + super().__init__(store_data=store_data, update_interval=update_interval) |
26 | 18 | self._namespace = namespace
|
27 | 19 |
|
28 |
| - def get_bot_data(self) -> BD: |
| 20 | + async def get_bot_data(self) -> BD: |
29 | 21 | try:
|
30 |
| - return BotData.objects.get(namespace=self._namespace).data |
| 22 | + return (await BotData.objects.aget(namespace=self._namespace)).data |
31 | 23 | except BotData.DoesNotExist:
|
32 | 24 | return {}
|
33 | 25 |
|
34 |
| - def update_bot_data(self, data: BD) -> None: |
35 |
| - BotData.objects.update_or_create(namespace=self._namespace, defaults={"data": data}) |
| 26 | + async def update_bot_data(self, data: BD) -> None: |
| 27 | + await BotData.objects.aupdate_or_create(namespace=self._namespace, defaults={"data": data}) |
36 | 28 |
|
37 |
| - def refresh_bot_data(self, bot_data: BD) -> None: |
| 29 | + async def refresh_bot_data(self, bot_data: BD) -> None: |
38 | 30 | if isinstance(bot_data, dict):
|
39 | 31 | orig_keys = set(bot_data.keys())
|
40 |
| - bot_data.update(self.get_bot_data()) |
| 32 | + bot_data.update(await self.get_bot_data()) |
41 | 33 | for key in orig_keys - set(bot_data.keys()):
|
42 | 34 | bot_data.pop(key)
|
43 | 35 |
|
44 |
| - def get_chat_data(self) -> DefaultDict[int, CD]: |
45 |
| - return defaultdict( |
46 |
| - dict, {data.chat_id: data.data for data in ChatData.objects.filter(namespace=self._namespace)} |
47 |
| - ) |
| 36 | + async def get_chat_data(self) -> DefaultDict[int, CD]: |
| 37 | + chat_data = {} |
| 38 | + async for data in ChatData.objects.filter(namespace=self._namespace): |
| 39 | + chat_data[data.chat_id] = data.data |
| 40 | + return defaultdict(dict, chat_data) |
48 | 41 |
|
49 |
| - def update_chat_data(self, chat_id: int, data: CD) -> None: |
50 |
| - ChatData.objects.update_or_create(namespace=self._namespace, chat_id=chat_id, defaults={"data": data}) |
| 42 | + async def update_chat_data(self, chat_id: int, data: CD) -> None: |
| 43 | + await ChatData.objects.aupdate_or_create(namespace=self._namespace, chat_id=chat_id, defaults={"data": data}) |
51 | 44 |
|
52 |
| - def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: |
| 45 | + async def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None: |
53 | 46 | try:
|
54 | 47 | if isinstance(chat_data, dict):
|
55 | 48 | orig_keys = set(chat_data.keys())
|
56 |
| - chat_data.update(ChatData.objects.get(namespace=self._namespace, chat_id=chat_id).data) |
| 49 | + chat_data.update((await ChatData.objects.aget(namespace=self._namespace, chat_id=chat_id)).data) |
57 | 50 | for key in orig_keys - set(chat_data.keys()):
|
58 | 51 | chat_data.pop(key)
|
59 | 52 | except ChatData.DoesNotExist:
|
60 | 53 | pass
|
61 | 54 |
|
62 |
| - def get_user_data(self) -> DefaultDict[int, UD]: |
63 |
| - return defaultdict( |
64 |
| - dict, {data.user_id: data.data for data in UserData.objects.filter(namespace=self._namespace)} |
65 |
| - ) |
| 55 | + async def get_user_data(self) -> DefaultDict[int, UD]: |
| 56 | + user_data = {} |
| 57 | + async for data in UserData.objects.filter(namespace=self._namespace): |
| 58 | + user_data[data.user_id] = data.data |
| 59 | + return defaultdict(dict, user_data) |
66 | 60 |
|
67 |
| - def update_user_data(self, user_id: int, data: UD) -> None: |
68 |
| - UserData.objects.update_or_create(namespace=self._namespace, user_id=user_id, defaults={"data": data}) |
| 61 | + async def update_user_data(self, user_id: int, data: UD) -> None: |
| 62 | + await UserData.objects.aupdate_or_create(namespace=self._namespace, user_id=user_id, defaults={"data": data}) |
69 | 63 |
|
70 |
| - def refresh_user_data(self, user_id: int, user_data: UD) -> None: |
| 64 | + async def refresh_user_data(self, user_id: int, user_data: UD) -> None: |
71 | 65 | try:
|
72 | 66 | if isinstance(user_data, dict):
|
73 | 67 | orig_keys = set(user_data.keys())
|
74 |
| - user_data.update(UserData.objects.get(namespace=self._namespace, user_id=user_id).data) |
| 68 | + user_data.update((await UserData.objects.aget(namespace=self._namespace, user_id=user_id)).data) |
75 | 69 | for key in orig_keys - set(user_data.keys()):
|
76 | 70 | user_data.pop(key)
|
77 | 71 | except UserData.DoesNotExist:
|
78 | 72 | pass
|
79 | 73 |
|
80 |
| - def get_callback_data(self) -> Optional[CDCData]: |
| 74 | + async def get_callback_data(self) -> Optional[CDCData]: |
81 | 75 | try:
|
82 |
| - cdcdata_json = CallbackData.objects.get(namespace=self._namespace).data |
| 76 | + cdcdata_json = (await CallbackData.objects.aget(namespace=self._namespace)).data |
83 | 77 | # Before asking me wtf is this, just check DictPersistence
|
84 | 78 | return cast(CDCData, ([(one, float(two), three) for one, two, three in cdcdata_json[0]], cdcdata_json[1]))
|
85 | 79 | except CallbackData.DoesNotExist:
|
86 | 80 | return None
|
87 | 81 |
|
88 |
| - def update_callback_data(self, data: CDCData) -> None: |
89 |
| - CallbackData.objects.update_or_create(namespace=self._namespace, defaults={"data": data}) |
| 82 | + async def update_callback_data(self, data: CDCData) -> None: |
| 83 | + await CallbackData.objects.aupdate_or_create(namespace=self._namespace, defaults=data) |
90 | 84 |
|
91 |
| - def get_conversations(self, name: str) -> ConversationDict: |
92 |
| - return { |
93 |
| - tuple(json.loads(data.key)): data.state |
94 |
| - for data in ConversationData.objects.filter(namespace=self._namespace, name=name) |
95 |
| - } |
| 85 | + async def get_conversations(self, name: str) -> ConversationDict: |
| 86 | + result = {} |
| 87 | + async for data in ConversationData.objects.filter(namespace=self._namespace, name=name): |
| 88 | + result[tuple(json.loads(data.key))] = data.state |
96 | 89 |
|
97 |
| - def update_conversation(self, name: str, key: Tuple[int, ...], new_state: Optional[object]) -> None: |
98 |
| - ConversationData.objects.update_or_create( |
| 90 | + async def update_conversation(self, name: str, key: Tuple[int, ...], new_state: Optional[object]) -> None: |
| 91 | + await ConversationData.objects.aupdate_or_create( |
99 | 92 | namespace=self._namespace,
|
100 | 93 | name=name,
|
101 | 94 | key=json.dumps(key, sort_keys=True),
|
102 | 95 | defaults={"state": new_state},
|
103 | 96 | )
|
104 | 97 |
|
105 |
| - def flush(self) -> None: |
| 98 | + async def flush(self) -> None: |
106 | 99 | pass
|
| 100 | + |
| 101 | + async def drop_chat_data(self, chat_id: int) -> None: |
| 102 | + await ChatData.objects.filter(namespace=self._namespace, chat_id=chat_id).adelete() |
| 103 | + |
| 104 | + async def drop_user_data(self, user_id: int) -> None: |
| 105 | + await UserData.objects.filter(namespace=self._namespace, user_id=user_id).adelete() |
0 commit comments