470 lines
16 KiB
Python
470 lines
16 KiB
Python
import contextlib
import sqlite3
import typing
from datetime import datetime, timedelta, timezone

import discord

from cogs import music_player
|
|
|
|
class Database:
    """
    SQLite-backed storage for Discord activity and music history.

    Every method opens its own short-lived connection to ``path`` and closes
    it before returning, so an instance holds no open handles between calls.
    """

    # Tables that map a Discord ID to a local row ID. Their names are
    # interpolated into SQL by _insert_discord_id, so this must remain a
    # fixed internal whitelist and never be derived from user input.
    _ID_TABLES = ("server", "channel", "user")

    def __init__(self, path: str = "[REDACTED]_bot.db"):
        """
        Remembers the database location and creates any missing tables.

        Args:
            path (str): Path of the SQLite database file.
        """
        self.path = path
        self._ensure_db()

    @contextlib.contextmanager
    def _connect(self):
        """
        Yields a connection that commits on success and is always closed.

        Using ``sqlite3.connect(...)`` directly as a context manager only
        commits or rolls back the transaction; it does not close the
        connection. This wrapper adds the close.

        Yields:
            sqlite3.Connection: An open connection to ``self.path``.
        """
        conn = sqlite3.connect(self.path)
        try:
            # Commits when the body succeeds, rolls back when it raises.
            with conn:
                yield conn
        finally:
            conn.close()

    def _ensure_db(self):
        """
        Creates every table used by this class if it does not exist yet.

        Existing tables are left untouched (``CREATE TABLE IF NOT EXISTS``).
        """
        # NOTE: a 'song_play' table created before the 'finished' column was
        # introduced is not altered here; such a database needs a manual
        # "ALTER TABLE song_play ADD COLUMN finished BOOL".
        schema = (
            # Table for keeping track of servers
            """
            CREATE TABLE IF NOT EXISTS server (
                id INTEGER PRIMARY KEY,
                discord_id INTEGER NOT NULL UNIQUE
            )
            """,
            # Table for keeping track of channels
            """
            CREATE TABLE IF NOT EXISTS channel (
                id INTEGER PRIMARY KEY,
                discord_id INTEGER NOT NULL UNIQUE
            )
            """,
            # Table for keeping track of users
            """
            CREATE TABLE IF NOT EXISTS user (
                id INTEGER PRIMARY KEY,
                discord_id INTEGER NOT NULL UNIQUE
            )
            """,
            # Table for presence/activity changes
            """
            CREATE TABLE IF NOT EXISTS activity_change (
                id INTEGER PRIMARY KEY,
                user_id INTEGER NOT NULL,
                before_activity_type TEXT,
                before_activity_name TEXT,
                before_activity_status TEXT NOT NULL,
                after_activity_type TEXT,
                after_activity_name TEXT,
                after_activity_status TEXT NOT NULL,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
            )
            """,
            # Table for songs that were requested
            """
            CREATE TABLE IF NOT EXISTS song_request (
                id INTEGER PRIMARY KEY,
                user_id INTEGER NOT NULL,
                channel_id INTEGER NOT NULL,
                search_term TEXT NOT NULL,
                song_title TEXT,
                song_artist TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
            )
            """,
            # Table for songs that actually get played
            """
            CREATE TABLE IF NOT EXISTS song_play (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                user_id INTEGER,
                channel_id INTEGER NOT NULL,
                search_term TEXT NOT NULL,
                song_title TEXT,
                song_artist TEXT,
                finished BOOL DEFAULT 0,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
            )
            """,
        )
        with self._connect() as conn:
            for statement in schema:
                conn.execute(statement)

    def _insert_discord_id(self, table: str, discord_id: int) -> int:
        """
        Ensures a Discord ID is present in a lookup table.

        Args:
            table (str): One of the names in ``_ID_TABLES``.
            discord_id (int): The ID used by Discord for the entity.

        Returns:
            int: The row ID of ``discord_id`` within ``table``.

        Raises:
            ValueError: If ``table`` is not a known lookup table.
        """
        if table not in self._ID_TABLES:
            raise ValueError(f"Unknown lookup table: {table!r}")
        with self._connect() as conn:
            # Insert it; an already-present ID is silently left alone
            conn.execute(
                f"INSERT INTO {table} (discord_id) VALUES (?) "
                "ON CONFLICT(discord_id) DO NOTHING",
                (discord_id,))
            # The row exists now whether or not this call inserted it
            row = conn.execute(
                f"SELECT id FROM {table} WHERE discord_id = ?",
                (discord_id,)).fetchone()
        return row[0]

    def _insert_server(self, discord_id: int = None) -> int:
        """
        Inserts Discord server ID into the 'server' table.

        This method takes an ID for a server used in Discord, and inserts it
        into the database. It ignores the case where the server ID is already
        present. It then returns the row ID regardless.

        Args:
            discord_id (int): The ID used to identify the server in Discord.

        Returns:
            int: The ID of the server in the server table.

        Examples:
            >>> db = Database("path.db")
            >>> db._insert_server(850610922256442889)
            12
        """
        return self._insert_discord_id("server", discord_id)

    def _insert_channel(self, discord_id: int = None) -> int:
        """
        Inserts Discord channel ID into the 'channel' table.

        This method takes an ID for a channel used in Discord, and inserts it
        into the database. It ignores the case where the channel ID is already
        present. It then returns the row ID regardless.

        Args:
            discord_id (int): The ID used to identify the channel in Discord.

        Returns:
            int: The ID of the channel in the channel table.

        Examples:
            >>> db = Database("path.db")
            >>> db._insert_channel(8506109222564428891)
            12
        """
        return self._insert_discord_id("channel", discord_id)

    def _insert_user(self, discord_id: int = None) -> int:
        """
        Inserts Discord user ID into the 'user' table.

        This method takes an ID for a user used in Discord, and inserts it
        into the database. It ignores the case where the user ID is already
        present. It then returns the row ID regardless.

        Args:
            discord_id (int): The ID used to identify the user in Discord.

        Returns:
            int: The ID of the user in the user table.

        Examples:
            >>> db = Database("path.db")
            >>> db._insert_user(850610922256442889)
            12
        """
        return self._insert_discord_id("user", discord_id)

    def insert_activity_change(
            self,
            before: "discord.Member",
            after: "discord.Member"):
        """
        Inserts an activity change into the database.

        This method takes two discord.Member objects, and records the change
        in activity into the 'activity_change' table.

        Args:
            before (discord.Member): The previous user status.
            after (discord.Member): The current user status.

        Raises:
            ValueError: If the before and after activity do not refer to the
                same user.

        Examples:
            >>> @commands.Cog.listener()
            >>> async def on_presence_update(
            ...         self,
            ...         before: discord.Member,
            ...         after: discord.Member):
            ...     db = Database("path.db")
            ...     db.insert_activity_change(before, after)
            >>>
        """
        # Ensure the users are the same
        if before.id != after.id:
            raise ValueError("User IDs do not match.")
        user_id = self._insert_user(before.id)

        # A member without an activity is stored with NULL type and name
        before_type = before.activity.type.name if before.activity else None
        before_name = before.activity.name if before.activity else None
        after_type = after.activity.type.name if after.activity else None
        after_name = after.activity.name if after.activity else None

        # Insert the activity change
        with self._connect() as conn:
            conn.execute("""
                INSERT INTO activity_change (
                    user_id,
                    before_activity_type,
                    before_activity_name,
                    before_activity_status,
                    after_activity_type,
                    after_activity_name,
                    after_activity_status
                ) VALUES (
                    ?, ?, ?, ?, ?, ?, ?
                )
            """, (
                user_id,
                before_type,
                before_name,
                before.status.name,
                after_type,
                after_name,
                after.status.name
            ))

    def insert_song_request(
            self,
            message: "discord.Message",
            source: "music_player.YTDLSource"):
        """
        Inserts a song request into the database.

        This method takes a message and its derived music source and inserts
        the relevant information into the 'song_request' table.

        Args:
            message (discord.Message): The Discord message requesting the song.
            source (music_player.YTDLSource): The audio source.
        """
        # Resolve the row IDs first so no other connection is opened while
        # the insert's connection is in use
        user_id = self._insert_user(message.author.id)
        channel_id = self._insert_channel(message.channel.id)

        # Insert the information
        with self._connect() as conn:
            conn.execute("""
                INSERT INTO song_request (
                    user_id,
                    channel_id,
                    search_term,
                    song_title,
                    song_artist
                ) VALUES (
                    ?, ?, ?, ?, ?
                )
            """, (
                user_id,
                channel_id,
                source.search_term,
                source.song_title,
                source.artist
            ))

    def insert_song_play(
            self,
            channel_id: int,
            source: "music_player.YTDLSource"):
        """
        Inserts a song play into the database.

        This method takes a channel and the song being played and inserts the
        relevant information into the 'song_play' table.

        Args:
            channel_id (int): The Discord ID of the channel the song is being
                played in.
            source (music_player.YTDLSource): The audio source.

        Returns:
            int: The row ID of the entered song. Used to update the
                'finished' value via update_song_play.
        """
        # A song without a requester is stored with a NULL user
        user_id = None
        if source.requester:
            user_id = self._insert_user(source.requester.id)
        # The channel belongs in the 'channel' table, not the 'user' table
        channel_row_id = self._insert_channel(channel_id)

        # Insert the information
        with self._connect() as conn:
            cursor = conn.execute("""
                INSERT INTO song_play (
                    user_id,
                    channel_id,
                    search_term,
                    song_title,
                    song_artist
                ) VALUES (
                    ?, ?, ?, ?, ?
                )
            """, (
                user_id,
                channel_row_id,
                source.search_term,
                source.song_title,
                source.artist
            ))
            return cursor.lastrowid

    def update_song_play(self, song_play_id: int, finished: bool):
        """
        Updates a song_play entry on whether or not it was finished.

        When a song plays, we want to know if it was finished or not. An
        unfinished song implies that either a user didn't want to hear it
        anymore, or that the bot chose the wrong song from the search term.

        Args:
            song_play_id (int): The row ID within the database for the song
                play.
            finished (bool): Whether or not the song was completed.
        """
        with self._connect() as conn:
            conn.execute("""
                UPDATE
                    song_play
                SET
                    finished = ?
                WHERE
                    id = ?
            """, (finished, song_play_id))

    def get_activity_stats(
            self,
            member: typing.Union["discord.Member", int],
            start: typing.Optional[datetime] = None
            ) -> dict[str, timedelta]:
        """
        Gets stats on the activities of the given member.

        This method searches the database for activity changes by the given
        user and computes the amount of time spent in each activity. Time is
        only attributed to an activity when one change ends in it and the
        next recorded change starts from it.

        Args:
            member (discord.Member | int): The Discord member, or their
                Discord ID, to get stats for.
            start (datetime): The earliest activity change to get. Defaults
                to 30 days before the call. Timezone-aware values are
                converted to UTC; naive values are compared as-is against
                the UTC timestamps written by SQLite.

        Returns:
            dict[str, timedelta]: A dictionary of activity names and the
                total time spent in each.
        """
        # The default is computed per call (a default evaluated in the
        # signature would be frozen at import time) and in UTC, because
        # CURRENT_TIMESTAMP stores UTC.
        if start is None:
            start = datetime.now(timezone.utc) - timedelta(days=30)
        if start.tzinfo is not None:
            start = start.astimezone(timezone.utc).replace(tzinfo=None)
        # Same "YYYY-MM-DD HH:MM:SS" text form SQLite uses, so the string
        # comparison in the query is chronological
        start_text = start.isoformat(sep=" ", timespec="seconds")

        # Get member Discord ID and convert to DB ID
        discord_id = member if isinstance(member, int) else member.id
        user_id = self._insert_user(discord_id)

        # Pull all activities for this user, oldest first; the pairing below
        # relies on chronological order
        with self._connect() as conn:
            changes = conn.execute("""
                SELECT
                    before_activity_name,
                    after_activity_name,
                    timestamp
                FROM
                    activity_change
                WHERE
                    user_id = (?) AND
                    timestamp > (?)
                ORDER BY
                    timestamp, id
            """, (user_id, start_text)).fetchall()

        # Collect activities from each pair of consecutive changes
        activity_stats = {}
        for earlier, later in zip(changes, changes[1:]):
            _, activity_name, began = earlier
            resumed_from, _, ended = later
            # Skip "no activity" spans and pairs that do not line up
            if activity_name is None or activity_name != resumed_from:
                continue
            elapsed = datetime.fromisoformat(ended) - \
                datetime.fromisoformat(began)
            activity_stats[activity_name] = \
                activity_stats.get(activity_name, timedelta()) + elapsed
        return activity_stats