Skip to content

Commit fce1c3d

Browse files
Added statesv2
1 parent ab2dca8 commit fce1c3d

File tree

4 files changed

+263
-74
lines changed

4 files changed

+263
-74
lines changed

telebot/__init__.py

Lines changed: 99 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import threading
88
import time
99
import traceback
10-
from typing import Any, Callable, List, Optional, Union
10+
from typing import Any, Callable, List, Optional, Union, Dict
1111

1212
# these imports are used to avoid circular import error
1313
import telebot.util
@@ -168,7 +168,8 @@ def __init__(
168168
disable_notification: Optional[bool]=None,
169169
protect_content: Optional[bool]=None,
170170
allow_sending_without_reply: Optional[bool]=None,
171-
colorful_logs: Optional[bool]=False
171+
colorful_logs: Optional[bool]=False,
172+
token_check: Optional[bool]=True
172173
):
173174

174175
# update-related
@@ -186,6 +187,11 @@ def __init__(
186187
self.webhook_listener = None
187188
self._user = None
188189

190+
# token check
191+
if token_check:
192+
self._user = self.get_me()
193+
self.bot_id = self._user.id
194+
189195
# logs-related
190196
if colorful_logs:
191197
try:
@@ -280,6 +286,8 @@ def __init__(
280286
self.threaded = threaded
281287
if self.threaded:
282288
self.worker_pool = util.ThreadPool(self, num_threads=num_threads)
289+
290+
283291

284292
@property
285293
def user(self) -> types.User:
@@ -6572,7 +6580,9 @@ def setup_middleware(self, middleware: BaseMiddleware):
65726580
self.middlewares.append(middleware)
65736581

65746582

6575-
def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None) -> None:
6583+
def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Optional[int]=None,
6584+
business_connection_id: Optional[str]=None, message_thread_id: Optional[int]=None,
6585+
bot_id: Optional[int]=None) -> None:
65766586
"""
65776587
Sets a new state of a user.
65786588
@@ -6591,14 +6601,29 @@ def set_state(self, user_id: int, state: Union[int, str, State], chat_id: Option
65916601
:param chat_id: Chat's identifier
65926602
:type chat_id: :obj:`int`
65936603
6604+
:param bot_id: Bot's identifier
6605+
:type bot_id: :obj:`int`
6606+
6607+
:param business_connection_id: Business identifier
6608+
:type business_connection_id: :obj:`str`
6609+
6610+
:param message_thread_id: Identifier of the message thread
6611+
:type message_thread_id: :obj:`int`
6612+
65946613
:return: None
65956614
"""
65966615
if chat_id is None:
65976616
chat_id = user_id
6598-
self.current_states.set_state(chat_id, user_id, state)
6617+
if bot_id is None:
6618+
bot_id = self.bot_id
6619+
self.current_states.set_state(
6620+
chat_id=chat_id, user_id=user_id, state=state, bot_id=bot_id,
6621+
business_connection_id=business_connection_id, message_thread_id=message_thread_id)
65996622

66006623

6601-
def reset_data(self, user_id: int, chat_id: Optional[int]=None):
6624+
def reset_data(self, user_id: int, chat_id: Optional[int]=None,
6625+
business_connection_id: Optional[str]=None,
6626+
message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None:
66026627
"""
66036628
Reset data for a user in chat.
66046629
@@ -6608,14 +6633,27 @@ def reset_data(self, user_id: int, chat_id: Optional[int]=None):
66086633
:param chat_id: Chat's identifier
66096634
:type chat_id: :obj:`int`
66106635
6636+
:param bot_id: Bot's identifier
6637+
:type bot_id: :obj:`int`
6638+
6639+
:param business_connection_id: Business identifier
6640+
:type business_connection_id: :obj:`str`
6641+
6642+
:param message_thread_id: Identifier of the message thread
6643+
:type message_thread_id: :obj:`int`
6644+
66116645
:return: None
66126646
"""
66136647
if chat_id is None:
66146648
chat_id = user_id
6615-
self.current_states.reset_data(chat_id, user_id)
6649+
if bot_id is None:
6650+
bot_id = self.bot_id
6651+
self.current_states.reset_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
6652+
business_connection_id=business_connection_id, message_thread_id=message_thread_id)
66166653

66176654

6618-
def delete_state(self, user_id: int, chat_id: Optional[int]=None) -> None:
6655+
def delete_state(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None,
6656+
message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> None:
66196657
"""
66206658
Delete the current state of a user.
66216659
@@ -6629,10 +6667,14 @@ def delete_state(self, user_id: int, chat_id: Optional[int]=None) -> None:
66296667
"""
66306668
if chat_id is None:
66316669
chat_id = user_id
6632-
self.current_states.delete_state(chat_id, user_id)
6670+
if bot_id is None:
6671+
bot_id = self.bot_id
6672+
self.current_states.delete_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
6673+
business_connection_id=business_connection_id, message_thread_id=message_thread_id)
66336674

66346675

6635-
def retrieve_data(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Any]:
6676+
def retrieve_data(self, user_id: int, chat_id: Optional[int]=None, business_connection_id: Optional[str]=None,
6677+
message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Optional[Dict[str, Any]]:
66366678
"""
66376679
Returns context manager with data for a user in chat.
66386680
@@ -6642,15 +6684,30 @@ def retrieve_data(self, user_id: int, chat_id: Optional[int]=None) -> Optional[A
66426684
:param chat_id: Chat's unique identifier, defaults to user_id
66436685
:type chat_id: int, optional
66446686
6687+
:param bot_id: Bot's identifier
6688+
:type bot_id: int, optional
6689+
6690+
:param business_connection_id: Business identifier
6691+
:type business_connection_id: str, optional
6692+
6693+
:param message_thread_id: Identifier of the message thread
6694+
:type message_thread_id: int, optional
6695+
66456696
:return: Context manager with data for a user in chat
66466697
:rtype: Optional[Any]
66476698
"""
66486699
if chat_id is None:
66496700
chat_id = user_id
6650-
return self.current_states.get_interactive_data(chat_id, user_id)
6701+
if bot_id is None:
6702+
bot_id = self.bot_id
6703+
return self.current_states.get_interactive_data(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
6704+
business_connection_id=business_connection_id,
6705+
message_thread_id=message_thread_id)
66516706

66526707

6653-
def get_state(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Union[int, str, State]]:
6708+
def get_state(self, user_id: int, chat_id: Optional[int]=None,
6709+
business_connection_id: Optional[str]=None,
6710+
message_thread_id: Optional[int]=None, bot_id: Optional[int]=None) -> Union[int, str]:
66546711
"""
66556712
Gets current state of a user.
66566713
Not recommended to use this method. But it is ok for debugging.
@@ -6661,15 +6718,31 @@ def get_state(self, user_id: int, chat_id: Optional[int]=None) -> Optional[Union
66616718
:param chat_id: Chat's identifier
66626719
:type chat_id: :obj:`int`
66636720
6721+
:param bot_id: Bot's identifier
6722+
:type bot_id: :obj:`int`
6723+
6724+
:param business_connection_id: Business identifier
6725+
:type business_connection_id: :obj:`str`
6726+
6727+
:param message_thread_id: Identifier of the message thread
6728+
:type message_thread_id: :obj:`int`
6729+
66646730
:return: state of a user
66656731
:rtype: :obj:`int` or :obj:`str` or :class:`telebot.types.State`
66666732
"""
66676733
if chat_id is None:
66686734
chat_id = user_id
6669-
return self.current_states.get_state(chat_id, user_id)
6735+
if bot_id is None:
6736+
bot_id = self.bot_id
6737+
return self.current_states.get_state(chat_id=chat_id, user_id=user_id, bot_id=bot_id,
6738+
business_connection_id=business_connection_id, message_thread_id=message_thread_id)
66706739

66716740

6672-
def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs):
6741+
def add_data(self, user_id: int, chat_id: Optional[int]=None,
6742+
business_connection_id: Optional[str]=None,
6743+
message_thread_id: Optional[int]=None,
6744+
bot_id: Optional[int]=None,
6745+
**kwargs) -> None:
66736746
"""
66746747
Add data to states.
66756748
@@ -6679,13 +6752,25 @@ def add_data(self, user_id: int, chat_id: Optional[int]=None, **kwargs):
66796752
:param chat_id: Chat's identifier
66806753
:type chat_id: :obj:`int`
66816754
6755+
:param bot_id: Bot's identifier
6756+
:type bot_id: :obj:`int`
6757+
6758+
:param business_connection_id: Business identifier
6759+
:type business_connection_id: :obj:`str`
6760+
6761+
:param message_thread_id: Identifier of the message thread
6762+
:type message_thread_id: :obj:`int`
6763+
66826764
:param kwargs: Data to add
66836765
:return: None
66846766
"""
66856767
if chat_id is None:
66866768
chat_id = user_id
6769+
if bot_id is None:
6770+
bot_id = self.bot_id
66876771
for key, value in kwargs.items():
6688-
self.current_states.set_data(chat_id, user_id, key, value)
6772+
self.current_states.set_data(chat_id=chat_id, user_id=user_id, key=key, value=value, bot_id=bot_id,
6773+
business_connection_id=business_connection_id, message_thread_id=message_thread_id)
66896774

66906775

66916776
def register_next_step_handler_by_chat_id(

telebot/custom_filters.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88

99

10+
1011
class SimpleCustomFilter(ABC):
1112
"""
1213
Simple Custom Filter base class.
@@ -417,8 +418,6 @@ def check(self, message, text):
417418
user_id = message.from_user.id
418419
message = message.message
419420

420-
421-
422421

423422
if isinstance(text, list):
424423
new_text = []
@@ -430,15 +429,24 @@ def check(self, message, text):
430429
text = text.name
431430

432431
if message.chat.type in ['group', 'supergroup']:
433-
group_state = self.bot.current_states.get_state(chat_id, user_id)
432+
group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id,
433+
message_thread_id=message.message_thread_id)
434+
if group_state is None and not message.is_topic_message: # needed for general topic and group messages
435+
group_state = self.bot.current_states.get_state(chat_id=chat_id, user_id=user_id, business_connection_id=message.business_connection_id, bot_id=self.bot._user.id)
436+
434437
if group_state == text:
435438
return True
436439
elif type(text) is list and group_state in text:
437440
return True
438441

439442

440443
else:
441-
user_state = self.bot.current_states.get_state(chat_id, user_id)
444+
user_state = self.bot.current_states.get_state(
445+
chat_id=chat_id,
446+
user_id=user_id,
447+
business_connection_id=message.business_connection_id,
448+
bot_id=self.bot._user.id
449+
)
442450
if user_state == text:
443451
return True
444452
elif type(text) is list and user_state in text:

telebot/storage/base_storage.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,56 @@ def get_interactive_data(self, chat_id, user_id):
4747

4848
def save(self, chat_id, user_id, data):
4949
raise NotImplementedError
50+
51+
def convert_params_to_key(
52+
self,
53+
chat_id: int,
54+
user_id: int,
55+
prefix: str,
56+
separator: str,
57+
business_connection_id: str=None,
58+
message_thread_id: int=None,
59+
bot_id: int=None
60+
) -> str:
61+
"""
62+
Convert parameters to a key.
63+
"""
64+
params = [prefix]
65+
if bot_id:
66+
params.append(str(bot_id))
67+
if business_connection_id:
68+
params.append(business_connection_id)
69+
if message_thread_id:
70+
params.append(str(message_thread_id))
71+
params.append(str(chat_id))
72+
params.append(str(user_id))
5073

74+
return separator.join(params)
75+
76+
77+
78+
5179

5280

5381
class StateContext:
5482
"""
5583
Class for data.
5684
"""
57-
def __init__(self , obj, chat_id, user_id) -> None:
85+
def __init__(self , obj, chat_id, user_id, business_connection_id=None, message_thread_id=None, bot_id=None, ):
5886
self.obj = obj
59-
self.data = copy.deepcopy(obj.get_data(chat_id, user_id))
87+
res = obj.get_data(chat_id=chat_id, user_id=user_id, business_connection_id=business_connection_id,
88+
message_thread_id=message_thread_id, bot_id=bot_id)
89+
self.data = copy.deepcopy(res)
6090
self.chat_id = chat_id
6191
self.user_id = user_id
92+
self.bot_id = bot_id
93+
self.business_connection_id = business_connection_id
94+
self.message_thread_id = message_thread_id
95+
6296

6397

6498
def __enter__(self):
6599
return self.data
66100

67101
def __exit__(self, exc_type, exc_val, exc_tb):
68-
return self.obj.save(self.chat_id, self.user_id, self.data)
102+
return self.obj.save(self.chat_id, self.user_id, self.data, self.business_connection_id, self.message_thread_id, self.bot_id)

0 commit comments

Comments
 (0)