diff --git a/mautrix_telegram/abstract_user.py b/mautrix_telegram/abstract_user.py index b9d99d15..2eb96260 100644 --- a/mautrix_telegram/abstract_user.py +++ b/mautrix_telegram/abstract_user.py @@ -40,6 +40,7 @@ PeerUser, PhoneCallRequested, TypeUpdate, + UpdateBotMessageReaction, UpdateChannel, UpdateChannelUserTyping, UpdateChatDefaultBannedRights, @@ -363,6 +364,8 @@ async def _update(self, update: TypeUpdate) -> None: await self.update_phone_call(update) elif isinstance(update, UpdateMessageReactions): await self.update_reactions(update) + elif isinstance(update, UpdateBotMessageReaction): + await self.update_bot_reactions(update) elif isinstance(update, (UpdateChatUserTyping, UpdateChannelUserTyping, UpdateUserTyping)): await self.update_typing(update) elif isinstance(update, UpdateUserStatus): @@ -636,6 +639,12 @@ async def update_reactions(self, update: UpdateMessageReactions) -> None: return await portal.handle_telegram_reactions(self, TelegramID(update.msg_id), update.reactions) + async def update_bot_reactions(self, update: UpdateBotMessageReaction) -> None: + portal = await po.Portal.get_by_entity(update.peer, tg_receiver=self.tgid) + if not portal or not portal.mxid or not portal.allow_bridging: + return + await portal.handle_telegram_bot_reactions(self, update) + async def update_phone_call(self, update: UpdatePhoneCall) -> None: self.log.debug("Phone call update %s", update) if not isinstance(update.phone_call, PhoneCallRequested): diff --git a/mautrix_telegram/portal.py b/mautrix_telegram/portal.py index 7801f5cb..c3ffa186 100644 --- a/mautrix_telegram/portal.py +++ b/mautrix_telegram/portal.py @@ -23,6 +23,7 @@ Callable, List, Literal, + NamedTuple, Union, cast, ) @@ -160,6 +161,7 @@ TypeUser, TypeUserFull, TypeUserProfilePhoto, + UpdateBotMessageReaction, UpdateChannelUserTyping, UpdateChatUserTyping, UpdateMessageReactions, @@ -270,6 +272,11 @@ class IgnoredMessageError(Exception): pass +class WrappedReaction(NamedTuple): + reaction: ReactionEmoji | ReactionCustomEmoji + date: datetime | None + + class Portal(DBPortal, BasePortal): bot: "Bot" config: Config @@ -3251,12 +3258,40 @@ async def handle_telegram_reactions( recent_reactions = resp.reactions async with self.reaction_lock(dbm.mxid): - await self._handle_telegram_reactions_locked( + await self._handle_telegram_user_reactions_locked( source, dbm, recent_reactions, total_count, timestamp=timestamp ) + async def handle_telegram_bot_reactions( + self, source: au.AbstractUser, update: UpdateBotMessageReaction + ) -> None: + tg_space = self.tgid if self.peer_type == "channel" else source.tgid + dbm = await DBMessage.get_one_by_tgid(TelegramID(update.msg_id), tg_space) + if dbm is None: + return + reactions: dict[TelegramID, list[WrappedReaction]] = {} + custom_emoji_ids: list[int] = [] + if isinstance(update.actor, PeerUser): + user_id = TelegramID(update.actor.user_id) + elif isinstance(update.actor, PeerChannel): + user_id = TelegramID(update.actor.channel_id) + else: + return + for reaction in update.new_reactions: + reactions.setdefault(user_id, []).append(WrappedReaction(reaction=reaction, date=None)) + async with self.reaction_lock(dbm.mxid): + await self._handle_telegram_parsed_reactions_locked( + source, + dbm, + reactions, + custom_emoji_ids, + is_full=True, + only_user_id=user_id, + timestamp=update.date, + ) + @staticmethod - def _reactions_filter(lst: list[MessagePeerReaction], existing: DBReaction) -> bool: + def _reactions_filter(lst: list[WrappedReaction], existing: DBReaction) -> bool: if not lst: return False for wrapped_reaction in lst: @@ -3279,7 +3314,7 @@ async def _get_reaction_limit(source: au.AbstractUser, sender: TelegramID) -> in return await source.get_max_reactions(is_premium) return 3 if is_premium else 1 - async def _handle_telegram_reactions_locked( + async def _handle_telegram_user_reactions_locked( self, source: au.AbstractUser, msg: DBMessage, @@ -3287,17 +3322,38 @@ async def _handle_telegram_reactions_locked( total_count: int, timestamp: datetime | None = None, ) -> None: - reactions: dict[TelegramID, list[MessagePeerReaction]] = {} + reactions: dict[TelegramID, list[WrappedReaction]] = {} custom_emoji_ids: list[int] = [] for reaction in reaction_list: if isinstance(reaction.peer_id, (PeerUser, PeerChannel)) and isinstance( reaction.reaction, (ReactionEmoji, ReactionCustomEmoji) ): sender_user_id = p.Puppet.get_id_from_peer(reaction.peer_id) - reactions.setdefault(sender_user_id, []).append(reaction) + reactions.setdefault(sender_user_id, []).append( + WrappedReaction(reaction.reaction, reaction.date) + ) if isinstance(reaction.reaction, ReactionCustomEmoji): custom_emoji_ids.append(reaction.reaction.document_id) is_full = len(reaction_list) == total_count + await self._handle_telegram_parsed_reactions_locked( + source, + msg, + reactions, + custom_emoji_ids, + is_full=is_full, + timestamp=timestamp, + ) + + async def _handle_telegram_parsed_reactions_locked( + self, + source: au.AbstractUser, + msg: DBMessage, + reactions: dict[TelegramID, list[WrappedReaction]], + custom_emoji_ids: list[int], + is_full: bool, + only_user_id: TelegramID | None = None, + timestamp: datetime | None = None, + ) -> None: custom_emojis = await util.transfer_custom_emojis_to_matrix(source, custom_emoji_ids) existing_reactions = await DBReaction.get_all_by_message(msg.mxid, msg.mx_room) @@ -3305,6 +3361,8 @@ async def _handle_telegram_reactions_locked( removed: list[DBReaction] = [] for existing_reaction in existing_reactions: sender_id = existing_reaction.tg_sender + if only_user_id is not None and sender_id != only_user_id: + continue new_reactions = reactions.get(sender_id) if self._reactions_filter(new_reactions, existing_reaction): if new_reactions is not None and len(new_reactions) == 0: