Skip to content

Commit 155adfb

Browse files
authored
Merge pull request #68 from rhutkovich/feature/async-django-impl
Made methods async, same as the most recent BasePersistence class
2 parents 1d06bba + e7d88bd commit 155adfb

File tree

1 file changed

+47
-48
lines changed

1 file changed

+47
-48
lines changed

python_telegram_bot_django_persistence/persistence.py

Lines changed: 47 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,105 +2,104 @@
22
from collections import defaultdict
33
from typing import DefaultDict, Optional, Tuple, cast
44

5-
from telegram.ext import BasePersistence
6-
from telegram.ext.utils.types import BD, CD, UD, CDCData, ConversationDict
7-
8-
from .models import BotData, CallbackData, ChatData, ConversationData, UserData
5+
from models import BotData, CallbackData, ChatData, ConversationData, UserData
6+
from telegram.ext import BasePersistence, PersistenceInput
7+
from telegram.ext._utils.types import BD, CD, UD, CDCData, ConversationDict
98

109

1110
class DjangoPersistence(BasePersistence[UD, CD, BD]):
1211
def __init__(
1312
self,
1413
namespace: str = "",
15-
store_user_data: bool = True,
16-
store_chat_data: bool = True,
17-
store_bot_data: bool = True,
18-
store_callback_data: bool = False,
14+
store_data: Optional[PersistenceInput] = None,
15+
update_interval: float = 60,
1916
):
20-
super().__init__(
21-
store_user_data=store_user_data,
22-
store_chat_data=store_chat_data,
23-
store_bot_data=store_bot_data,
24-
store_callback_data=store_callback_data,
25-
)
17+
super().__init__(store_data=store_data, update_interval=update_interval)
2618
self._namespace = namespace
2719

28-
def get_bot_data(self) -> BD:
20+
async def get_bot_data(self) -> BD:
2921
try:
30-
return BotData.objects.get(namespace=self._namespace).data
22+
return (await BotData.objects.aget(namespace=self._namespace)).data
3123
except BotData.DoesNotExist:
3224
return {}
3325

34-
def update_bot_data(self, data: BD) -> None:
35-
BotData.objects.update_or_create(namespace=self._namespace, defaults={"data": data})
26+
async def update_bot_data(self, data: BD) -> None:
27+
await BotData.objects.aupdate_or_create(namespace=self._namespace, defaults={"data": data})
3628

37-
def refresh_bot_data(self, bot_data: BD) -> None:
29+
async def refresh_bot_data(self, bot_data: BD) -> None:
3830
if isinstance(bot_data, dict):
3931
orig_keys = set(bot_data.keys())
40-
bot_data.update(self.get_bot_data())
32+
bot_data.update(await self.get_bot_data())
4133
for key in orig_keys - set(bot_data.keys()):
4234
bot_data.pop(key)
4335

44-
def get_chat_data(self) -> DefaultDict[int, CD]:
45-
return defaultdict(
46-
dict, {data.chat_id: data.data for data in ChatData.objects.filter(namespace=self._namespace)}
47-
)
36+
async def get_chat_data(self) -> DefaultDict[int, CD]:
37+
chat_data = {}
38+
async for data in ChatData.objects.filter(namespace=self._namespace):
39+
chat_data[data.chat_id] = data.data
40+
return defaultdict(dict, chat_data)
4841

49-
def update_chat_data(self, chat_id: int, data: CD) -> None:
50-
ChatData.objects.update_or_create(namespace=self._namespace, chat_id=chat_id, defaults={"data": data})
42+
async def update_chat_data(self, chat_id: int, data: CD) -> None:
43+
await ChatData.objects.aupdate_or_create(namespace=self._namespace, chat_id=chat_id, defaults={"data": data})
5144

52-
def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None:
45+
async def refresh_chat_data(self, chat_id: int, chat_data: CD) -> None:
5346
try:
5447
if isinstance(chat_data, dict):
5548
orig_keys = set(chat_data.keys())
56-
chat_data.update(ChatData.objects.get(namespace=self._namespace, chat_id=chat_id).data)
49+
chat_data.update((await ChatData.objects.aget(namespace=self._namespace, chat_id=chat_id)).data)
5750
for key in orig_keys - set(chat_data.keys()):
5851
chat_data.pop(key)
5952
except ChatData.DoesNotExist:
6053
pass
6154

62-
def get_user_data(self) -> DefaultDict[int, UD]:
63-
return defaultdict(
64-
dict, {data.user_id: data.data for data in UserData.objects.filter(namespace=self._namespace)}
65-
)
55+
async def get_user_data(self) -> DefaultDict[int, UD]:
56+
user_data = {}
57+
async for data in UserData.objects.filter(namespace=self._namespace):
58+
user_data[data.user_id] = data.data
59+
return defaultdict(dict, user_data)
6660

67-
def update_user_data(self, user_id: int, data: UD) -> None:
68-
UserData.objects.update_or_create(namespace=self._namespace, user_id=user_id, defaults={"data": data})
61+
async def update_user_data(self, user_id: int, data: UD) -> None:
62+
await UserData.objects.aupdate_or_create(namespace=self._namespace, user_id=user_id, defaults={"data": data})
6963

70-
def refresh_user_data(self, user_id: int, user_data: UD) -> None:
64+
async def refresh_user_data(self, user_id: int, user_data: UD) -> None:
7165
try:
7266
if isinstance(user_data, dict):
7367
orig_keys = set(user_data.keys())
74-
user_data.update(UserData.objects.get(namespace=self._namespace, user_id=user_id).data)
68+
user_data.update((await UserData.objects.aget(namespace=self._namespace, user_id=user_id)).data)
7569
for key in orig_keys - set(user_data.keys()):
7670
user_data.pop(key)
7771
except UserData.DoesNotExist:
7872
pass
7973

80-
def get_callback_data(self) -> Optional[CDCData]:
74+
async def get_callback_data(self) -> Optional[CDCData]:
8175
try:
82-
cdcdata_json = CallbackData.objects.get(namespace=self._namespace).data
76+
cdcdata_json = (await CallbackData.objects.aget(namespace=self._namespace)).data
8377
# Before asking me wtf is this, just check DictPersistence
8478
return cast(CDCData, ([(one, float(two), three) for one, two, three in cdcdata_json[0]], cdcdata_json[1]))
8579
except CallbackData.DoesNotExist:
8680
return None
8781

88-
def update_callback_data(self, data: CDCData) -> None:
89-
CallbackData.objects.update_or_create(namespace=self._namespace, defaults={"data": data})
82+
async def update_callback_data(self, data: CDCData) -> None:
83+
await CallbackData.objects.aupdate_or_create(namespace=self._namespace, defaults=data)
9084

91-
def get_conversations(self, name: str) -> ConversationDict:
92-
return {
93-
tuple(json.loads(data.key)): data.state
94-
for data in ConversationData.objects.filter(namespace=self._namespace, name=name)
95-
}
85+
async def get_conversations(self, name: str) -> ConversationDict:
86+
result = {}
87+
async for data in ConversationData.objects.filter(namespace=self._namespace, name=name):
88+
result[tuple(json.loads(data.key))] = data.state
9689

97-
def update_conversation(self, name: str, key: Tuple[int, ...], new_state: Optional[object]) -> None:
98-
ConversationData.objects.update_or_create(
90+
async def update_conversation(self, name: str, key: Tuple[int, ...], new_state: Optional[object]) -> None:
91+
await ConversationData.objects.aupdate_or_create(
9992
namespace=self._namespace,
10093
name=name,
10194
key=json.dumps(key, sort_keys=True),
10295
defaults={"state": new_state},
10396
)
10497

105-
def flush(self) -> None:
98+
async def flush(self) -> None:
10699
pass
100+
101+
async def drop_chat_data(self, chat_id: int) -> None:
102+
await ChatData.objects.filter(namespace=self._namespace, chat_id=chat_id).adelete()
103+
104+
async def drop_user_data(self, user_id: int) -> None:
105+
await UserData.objects.filter(namespace=self._namespace, user_id=user_id).adelete()

0 commit comments

Comments
 (0)