mirror of
https://github.com/deadcxap/YandexMusicDiscordBot.git
synced 2026-01-10 00:21:45 +03:00
221 lines
7.9 KiB
Python
221 lines
7.9 KiB
Python
from random import randint, shuffle
from typing import Any, Literal

from yandex_music import Track
from pymongo import UpdateOne, ReturnDocument
from pymongo.errors import DuplicateKeyError

from MusicBot.database import BaseGuildsDatabase, guilds
|
|
|
|
class VoiceGuildsDatabase(BaseGuildsDatabase):
    """Queue and playback-state helpers stored on guild documents.

    Every method addresses a single guild document by its ``_id`` (``gid``)
    and works with the ``next_tracks``, ``previous_tracks``, ``current_track``
    and ``current_menu`` fields of that document.
    """

    @staticmethod
    def _tracks_field(list_type: str, param: str = 'list_type') -> str:
        """Return the document field that backs a track list.

        Args:
            list_type: Either ``'next'`` or ``'previous'``.
            param: Name of the caller's argument, used in the error message.

        Raises:
            ValueError: If ``list_type`` is not one of the two known lists.
        """
        if list_type not in ('next', 'previous'):
            raise ValueError(f"{param} must be either 'next' or 'previous'")
        return f'{list_type}_tracks'

    async def get_tracks_list(self, gid: int, list_type: Literal['next', 'previous']) -> list[dict[str, Any]]:
        """Return the stored track list, or an empty list if the field is absent.

        Raises:
            ValueError: If ``list_type`` is not ``'next'`` or ``'previous'``.
        """
        field = self._tracks_field(list_type)
        guild = await self.get_guild(gid, projection={field: 1})
        return guild.get(field, [])

    async def get_track(self, gid: int, list_type: Literal['next', 'previous', 'current']) -> dict[str, Any] | None:
        """Fetch the current track, or pop the head of the next/previous list.

        For ``'current'`` the stored ``current_track`` is returned unchanged.
        For ``'next'`` and ``'previous'`` the first element of the list is
        removed and returned. When a track is taken from ``previous_tracks``,
        the track that was current at that moment is put back at the front of
        ``next_tracks``.

        Returns:
            The track dict, or ``None`` when the list is empty or the guild
            document does not exist.

        Raises:
            ValueError: If ``list_type`` is not a supported value.
        """
        if list_type not in ('next', 'previous', 'current'):
            raise ValueError("list_type must be one of 'next', 'previous' or 'current'")

        # Only project the list field when there actually is one to read.
        projection: dict[str, int] = {'current_track': 1}
        if list_type != 'current':
            projection[f'{list_type}_tracks'] = 1
        guild = await self.get_guild(gid, projection=projection)

        if list_type == 'current':
            return guild['current_track']

        field = f'{list_type}_tracks'
        # ReturnDocument.BEFORE hands back the list as it was before the $pop,
        # so its first element is exactly the track that was removed.
        before = await guilds.find_one_and_update(
            {'_id': gid},
            {'$pop': {field: -1}},
            projection={field: 1},
            return_document=ReturnDocument.BEFORE,
        )
        tracks = before.get(field, []) if before else None
        if not tracks:
            return None

        if list_type == 'previous':
            current = guild.get('current_track')
            # With nothing playing there is nothing to re-queue; pushing None
            # would leave a null entry at the head of next_tracks.
            if current is not None:
                await guilds.update_one(
                    {'_id': gid},
                    {'$push': {'next_tracks': {'$each': [current], '$position': 0}}},
                )

        return tracks[0]

    async def modify_track(
        self,
        gid: int,
        track: Track | dict[str, Any] | list[dict[str, Any]] | list[Track],
        list_type: Literal['next', 'previous'],
        operation: Literal['insert', 'append', 'extend', 'pop_start', 'pop_end']
    ) -> dict[str, Any] | None:
        """Add tracks to, or pop a track from, one of the track lists.

        Args:
            gid: Guild document id.
            track: Track(s) to add. Ignored by the pop operations.
            list_type: Which list to modify.
            operation: ``'insert'`` puts the tracks at the front, ``'append'``
                and ``'extend'`` put them at the end, ``'pop_start'`` and
                ``'pop_end'`` remove one track from the respective end.

        Returns:
            The removed track for pop operations (``None`` if the list was
            empty or the guild is missing); always ``None`` for additions.

        Raises:
            ValueError: If ``list_type`` is invalid.
            KeyError: If ``operation`` is not a supported value.
            DuplicateKeyError: If the write still conflicts after one cleanup
                and retry.
        """
        field = self._tracks_field(list_type)

        if operation in ('pop_start', 'pop_end'):
            from_front = operation == 'pop_start'
            # Read the pre-update document: looking at the list after the $pop
            # would yield the neighbour of the removed track, not the track.
            before = await guilds.find_one_and_update(
                {'_id': gid},
                {'$pop': {field: -1 if from_front else 1}},
                projection={field: 1},
                return_document=ReturnDocument.BEFORE,
            )
            tracks = before.get(field) if before else None
            if not tracks:
                return None
            return tracks[0] if from_front else tracks[-1]

        track_data = self._normalize_track_data(track)
        push_specs = {
            'insert': {'$each': track_data, '$position': 0},
            'append': {'$each': track_data},
            'extend': {'$each': track_data},  # Same as append, mirrors list.extend
        }
        update = {'$push': {field: push_specs[operation]}}

        try:
            await guilds.update_one({'_id': gid}, update)
        except DuplicateKeyError:
            # Clean the stored list and retry exactly once. A second failure
            # propagates instead of retrying forever.
            await self._handle_duplicate_error(gid, field)
            await guilds.update_one({'_id': gid}, update)
        return None

    def _normalize_track_data(self, track: Track | dict | list) -> list[dict]:
        """Return ``track`` as a list of dicts, converting ``Track`` objects via ``to_dict()``."""
        items = track if isinstance(track, list) else [track]
        return [t.to_dict() if isinstance(t, Track) else t for t in items]

    async def pop_random_track(self, gid: int, field: Literal['next', 'previous']) -> dict[str, Any] | None:
        """Remove and return a random track from the given list.

        Despite its name, ``field`` is the list type (``'next'`` or
        ``'previous'``), not a document field name.

        The list is read, modified in memory and written back through the
        inherited ``update`` method, so this is not a single atomic operation.

        Returns:
            The removed track, or ``None`` if the list was empty.

        Raises:
            ValueError: If ``field`` is not ``'next'`` or ``'previous'``.
        """
        tracks = await self.get_tracks_list(gid, field)
        track = tracks.pop(randint(0, len(tracks) - 1)) if tracks else None
        await self.update(gid, {f"{field}_tracks": tracks})
        return track

    async def get_current_menu(self, gid: int) -> int | None:
        """Return the stored ``current_menu`` value of the guild."""
        guild = await self.get_guild(gid, projection={'current_menu': 1})
        return guild['current_menu']

    async def _get_popped_track(self, gid: int, field: str, operation: str) -> dict[str, Any] | None:
        """Peek at the track a pop operation would select from the stored list.

        This only reads the list as it is currently stored; it removes
        nothing. Called after a pop has already been applied, it therefore
        reports the element that is now at that position.

        Returns:
            The first, last or a random element for ``'pop_start'``,
            ``'pop_end'`` and ``'pop_random'`` respectively; ``None`` for any
            other operation or when the list is empty.
        """
        if operation not in ('pop_start', 'pop_end', 'pop_random'):
            return None

        guild = await self.get_guild(gid, projection={field: 1})
        tracks = guild.get(field, [])
        if not tracks:
            return None

        if operation == 'pop_start':
            return tracks[0]
        if operation == 'pop_end':
            return tracks[-1]
        return tracks[randint(0, len(tracks) - 1)]

    async def _handle_duplicate_error(self, gid: int, field: str) -> None:
        """Handle duplicate key errors by removing repeated tracks from the array.

        Tracks are compared by their ``'id'`` key; the first occurrence of
        each id is kept and the original order is preserved.
        """
        guild = await self.get_guild(gid, projection={field: 1})
        tracks = guild.get(field, [])
        if not tracks:
            return

        unique_tracks = []
        seen = set()
        for track in tracks:
            track_id = track.get('id')
            if track_id not in seen:
                seen.add(track_id)
                unique_tracks.append(track)

        # Nothing was dropped, so rewriting the array would change nothing.
        if len(unique_tracks) == len(tracks):
            return

        await guilds.update_one(
            {'_id': gid},
            {'$set': {field: unique_tracks}},
        )

    async def set_current_track(self, gid: int, track: Track | dict[str, Any]) -> None:
        """Store ``track`` as the guild's current track.

        ``Track`` objects are converted with ``to_dict()`` first. No other
        field of the document is touched.
        """
        if isinstance(track, Track):
            track = track.to_dict()

        await guilds.update_one(
            {'_id': gid},
            {'$set': {'current_track': track}},
        )

    async def clear_tracks(self, gid: int, list_type: Literal['next', 'previous']) -> None:
        """Replace the specified tracks list with an empty list.

        Raises:
            ValueError: If ``list_type`` is not ``'next'`` or ``'previous'``.
        """
        field = self._tracks_field(list_type)
        await guilds.update_one(
            {'_id': gid},
            {'$set': {field: []}},
        )

    async def shuffle_tracks(self, gid: int, list_type: Literal['next', 'previous']) -> None:
        """Shuffle the specified tracks list and store the new order.

        Lists with fewer than two tracks are left untouched because there is
        nothing to reorder.

        Raises:
            ValueError: If ``list_type`` is not ``'next'`` or ``'previous'``.
        """
        field = self._tracks_field(list_type)
        guild = await self.get_guild(gid, projection={field: 1})
        tracks = guild.get(field, [])
        if not tracks or len(tracks) < 2:
            return

        # Shuffle a copy so the fetched document itself is not mutated.
        shuffled_tracks = list(tracks)
        shuffle(shuffled_tracks)

        await guilds.update_one(
            {'_id': gid},
            {'$set': {field: shuffled_tracks}},
        )

    async def move_track(
        self,
        gid: int,
        from_list: Literal['next', 'previous'],
        to_list: Literal['next', 'previous'],
        track_index: int
    ) -> bool:
        """Move a track from one list to the front of another.

        Args:
            gid: Guild document id.
            from_list: List the track is taken from.
            to_list: List the track is inserted into (at position 0).
            track_index: Index of the track in ``from_list``. Negative values
                count from the end, as with regular list indexing.

        Returns:
            ``True`` if the track was moved; ``False`` if the guild does not
            exist, the source list is empty or the index is out of range.

        Raises:
            ValueError: If ``from_list`` or ``to_list`` is invalid.
        """
        from_field = self._tracks_field(from_list, 'from_list')
        to_field = self._tracks_field(to_list, 'to_list')

        guild = await guilds.find_one(
            {'_id': gid},
            projection={from_field: 1, to_field: 1},
        )
        source = guild.get(from_field) if guild else None
        # Check both bounds: a large negative index would otherwise reach
        # list.pop() and raise IndexError.
        if not source or not -len(source) <= track_index < len(source):
            return False

        track = source.pop(track_index)
        # NOTE(review): the source list is rewritten from the snapshot read
        # above, so a concurrent change to it between the read and this write
        # would be overwritten.
        # bulk_write runs these in order by default, which matters when both
        # lists are the same field.
        await guilds.bulk_write([
            UpdateOne(
                {'_id': gid},
                {'$set': {from_field: source}},
            ),
            UpdateOne(
                {'_id': gid},
                {'$push': {to_field: {'$each': [track], '$position': 0}}},
            ),
        ])
        return True

    async def get_track_count(self, gid: int, list_type: Literal['next', 'previous']) -> int:
        """Get the count of tracks in the specified list.

        Raises:
            ValueError: If ``list_type`` is not ``'next'`` or ``'previous'``.
        """
        field = self._tracks_field(list_type)
        guild = await self.get_guild(gid, projection={field: 1})
        return len(guild.get(field, []))

    async def set_current_menu(self, gid: int, menu_id: int | None) -> None:
        """Set the current menu message ID."""
        await guilds.update_one(
            {'_id': gid},
            {'$set': {'current_menu': menu_id}},
        )