# base-discord-bot/database.py
#
# 444 lines
# 16 KiB
# Python

import contextlib
import sqlite3
import typing
from datetime import datetime, timedelta, timezone

import discord
from tqdm import tqdm

from cogs import music_player
class Database:
    """SQLite-backed store for Discord activity and music-bot history.

    Each public method opens its own short-lived connection to the file at
    ``path``. The connection is committed on success, rolled back on error,
    and always closed afterwards.
    """

    # Tables that map a Discord snowflake ID to an internal row ID. Table
    # names cannot be bound as SQL parameters, so they are checked against
    # this whitelist before being formatted into a query.
    _DISCORD_ID_TABLES = frozenset({"server", "channel", "user"})

    # Legacy ``user_activity`` activity types translated to the activity type
    # names written by ``insert_activity_change`` (``activity.type.name``).
    # Legacy types not listed here are stored as NULL.
    _LEGACY_ACTIVITY_TYPES = {
        "activity": "playing",
        "spotify": "listening",
        "game": "playing",
    }

    def __init__(self, path: str = "boywife_bot.db"):
        """Open the database at ``path``, creating any missing tables.

        Args:
            path (str): Filesystem path of the SQLite database file.
        """
        self.path = path
        self._ensure_db()

    @contextlib.contextmanager
    def _connect(self) -> typing.Iterator[sqlite3.Connection]:
        """Yield a transactional connection to ``self.path`` and close it.

        ``sqlite3.Connection`` used as a context manager only commits or
        rolls back the transaction; it does not close the connection. This
        helper adds the close so file handles are released deterministically.
        """
        conn = sqlite3.connect(self.path)
        try:
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    def _ensure_db(self):
        """Create every table this class uses, if it does not exist yet."""
        with self._connect() as conn:
            # Table for keeping track of servers
            conn.execute("""
                CREATE TABLE IF NOT EXISTS server (
                    id INTEGER PRIMARY KEY,
                    discord_id INTEGER NOT NULL UNIQUE
                )
            """)
            # Table for keeping track of channels
            conn.execute("""
                CREATE TABLE IF NOT EXISTS channel (
                    id INTEGER PRIMARY KEY,
                    discord_id INTEGER NOT NULL UNIQUE
                )
            """)
            # Table for keeping track of users
            conn.execute("""
                CREATE TABLE IF NOT EXISTS user (
                    id INTEGER PRIMARY KEY,
                    discord_id INTEGER NOT NULL UNIQUE
                )
            """)
            # Create the activity table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS activity_change (
                    id INTEGER PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    before_activity_type TEXT,
                    before_activity_name TEXT,
                    before_activity_status TEXT NOT NULL,
                    after_activity_type TEXT,
                    after_activity_name TEXT,
                    after_activity_status TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
                )
            """)
            # Create the song request table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS song_request (
                    id INTEGER PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    channel_id INTEGER NOT NULL,
                    search_term TEXT NOT NULL,
                    song_title TEXT,
                    song_artist TEXT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
                )
            """)
            # Table for songs that actually get played
            conn.execute("""
                CREATE TABLE IF NOT EXISTS song_play (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL,
                    channel_id INTEGER NOT NULL,
                    search_term TEXT NOT NULL,
                    song_title TEXT,
                    song_artist TEXT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
                )
            """)

    def _migrate_legacy_ingest(self, ingest_path: str = "ingest.db") -> None:
        """Copy rows from a legacy ``user_activity`` table into ``activity_change``.

        This is a one-off migration from the previous database schema. It is
        NOT run automatically; call it explicitly against the old database.
        All rows are written in a single transaction, so a failure part-way
        through rolls back every row this call inserted.

        Legacy activity types are translated through
        ``_LEGACY_ACTIVITY_TYPES``, and the placeholder name ``"other"``
        becomes NULL. No status is read from the legacy table, so both
        statuses are recorded as ``"unknown"``.

        Args:
            ingest_path (str): Path of the legacy SQLite database.
        """
        with contextlib.closing(sqlite3.connect(ingest_path)) as legacy:
            activities = legacy.execute("""
                SELECT
                    user_id,
                    before_activity_type,
                    before_activity_name,
                    after_activity_type,
                    after_activity_name,
                    timestamp
                FROM
                    user_activity
            """).fetchall()
        with self._connect() as conn:
            for (discord_id, before_type, before_name,
                 after_type, after_name, timestamp) in tqdm(activities):
                conn.execute("""
                    INSERT INTO activity_change (
                        user_id,
                        before_activity_type,
                        before_activity_name,
                        before_activity_status,
                        after_activity_type,
                        after_activity_name,
                        after_activity_status,
                        timestamp
                    ) VALUES (
                        ?, ?, ?, ?, ?, ?, ?, ?
                    )
                """, (
                    # The legacy user_id column holds the Discord user ID.
                    self._upsert_discord_id(conn, "user", discord_id),
                    self._LEGACY_ACTIVITY_TYPES.get(before_type),
                    None if before_name == "other" else before_name,
                    "unknown",
                    self._LEGACY_ACTIVITY_TYPES.get(after_type),
                    None if after_name == "other" else after_name,
                    "unknown",
                    # Same text form sqlite3's default datetime adapter wrote.
                    datetime.fromisoformat(timestamp).isoformat(" "),
                ))

    def _upsert_discord_id(
            self,
            conn: sqlite3.Connection,
            table: str,
            discord_id: typing.Optional[int]) -> int:
        """Ensure ``discord_id`` has a row in ``table`` and return its row ID.

        Runs inside the caller's connection/transaction, so it can be
        combined atomically with other writes.

        Args:
            conn (sqlite3.Connection): Open connection to use.
            table (str): One of ``_DISCORD_ID_TABLES``.
            discord_id (int): The Discord snowflake ID.

        Returns:
            int: The ``id`` of the (possibly pre-existing) row.

        Raises:
            ValueError: If ``table`` is not a known Discord-ID table.
            sqlite3.IntegrityError: If ``discord_id`` is None (the column is
                NOT NULL; only the UNIQUE conflict is ignored).
        """
        if table not in self._DISCORD_ID_TABLES:
            raise ValueError(f"Unknown Discord ID table: {table!r}")
        # Only the UNIQUE conflict is ignored; other constraint violations
        # (e.g. NOT NULL) still raise.
        conn.execute(
            f"INSERT INTO {table} (discord_id) VALUES (?) "
            f"ON CONFLICT(discord_id) DO NOTHING",
            (discord_id,),
        )
        row = conn.execute(
            f"SELECT id FROM {table} WHERE discord_id = ?", (discord_id,)
        ).fetchone()
        return row[0]

    def _insert_server(self, discord_id: int = None) -> int:
        """Inserts Discord server ID into the 'server' table.

        This method takes an ID for a server used in Discord, and inserts it
        into the database. It ignores the case where the server ID is already
        present. It then returns the row ID regardless.

        Args:
            discord_id (int): The ID used to identify the server in Discord.

        Returns:
            int: The ID of the server in the server table.

        Examples:
            >>> db = Database("path.db")
            >>> db._insert_server(850610922256442889)
            12
        """
        with self._connect() as conn:
            return self._upsert_discord_id(conn, "server", discord_id)

    def _insert_channel(self, discord_id: int = None) -> int:
        """Inserts Discord channel ID into the 'channel' table.

        This method takes an ID for a channel used in Discord, and inserts it
        into the database. It ignores the case where the channel ID is already
        present. It then returns the row ID regardless.

        Args:
            discord_id (int): The ID used to identify the channel in Discord.

        Returns:
            int: The ID of the channel in the channel table.

        Examples:
            >>> db = Database("path.db")
            >>> db._insert_channel(8506109222564428891)
            12
        """
        with self._connect() as conn:
            return self._upsert_discord_id(conn, "channel", discord_id)

    def _insert_user(self, discord_id: int = None) -> int:
        """Inserts Discord user ID into the 'user' table.

        This method takes an ID for a user used in Discord, and inserts it
        into the database. It ignores the case where the user ID is already
        present. It then returns the row ID regardless.

        Args:
            discord_id (int): The ID used to identify the user in Discord.

        Returns:
            int: The ID of the user in the user table.

        Examples:
            >>> db = Database("path.db")
            >>> db._insert_user(850610922256442889)
            12
        """
        with self._connect() as conn:
            return self._upsert_discord_id(conn, "user", discord_id)

    @staticmethod
    def _describe_activity(
            member: "discord.Member",
    ) -> tuple[typing.Optional[str], typing.Optional[str]]:
        """Return ``(type name, activity name)`` of a member's activity.

        Both values are None when the member has no (truthy) activity.
        """
        activity = member.activity
        if not activity:
            return None, None
        return activity.type.name, activity.name

    def insert_activity_change(
            self,
            before: "discord.Member",
            after: "discord.Member"):
        """Inserts an activity change into the database.

        This method takes two discord.Member objects, and records the change
        in activity into the 'activity_change' table. The user row and the
        activity row are written in one transaction.

        Args:
            before (discord.Member): The previous user status.
            after (discord.Member): The current user status.

        Raises:
            ValueError: If the before and after activity do not refer to the
                same user.

        Examples:
            >>> @commands.Cog.listener()
            >>> async def on_presence_update(
            ...         self,
            ...         before: discord.Member,
            ...         after: discord.Member):
            ...     db = Database("path.db")
            ...     db.insert_activity_change(before, after)
            >>>
        """
        # Ensure the users are the same
        if before.id != after.id:
            raise ValueError("User IDs do not match.")
        before_type, before_name = self._describe_activity(before)
        after_type, after_name = self._describe_activity(after)
        with self._connect() as conn:
            user_id = self._upsert_discord_id(conn, "user", before.id)
            conn.execute("""
                INSERT INTO activity_change (
                    user_id,
                    before_activity_type,
                    before_activity_name,
                    before_activity_status,
                    after_activity_type,
                    after_activity_name,
                    after_activity_status
                ) VALUES (
                    ?, ?, ?, ?, ?, ?, ?
                )
            """, (
                user_id,
                before_type,
                before_name,
                before.status.name,
                after_type,
                after_name,
                after.status.name,
            ))

    def insert_song_request(
            self,
            message: "discord.Message",
            source: "music_player.YTDLSource"):
        """Inserts a song request into the database.

        This method takes a message and its derived music source and inserts
        the relevant information into the 'song_request' table. The user,
        channel and request rows are written in one transaction.

        Args:
            message (discord.Message): The Discord message requesting the song.
            source (music_player.YTDLSource): The audio source; its
                ``search_term``, ``song_title`` and ``artist`` are stored.
        """
        with self._connect() as conn:
            conn.execute("""
                INSERT INTO song_request (
                    user_id,
                    channel_id,
                    search_term,
                    song_title,
                    song_artist
                ) VALUES (
                    ?, ?, ?, ?, ?
                )
            """, (
                self._upsert_discord_id(conn, "user", message.author.id),
                self._upsert_discord_id(conn, "channel", message.channel.id),
                source.search_term,
                source.song_title,
                source.artist,
            ))

    def insert_song_play(
            self,
            channel_id: int,
            source: "music_player.YTDLSource"):
        """Inserts a song play into the database.

        This method takes a channel and the song being played and inserts the
        relevant information into the 'song_play' table. The channel is
        recorded in the 'channel' table and the requester in the 'user'
        table, all in one transaction.

        Args:
            channel_id (int): The Discord ID of the channel the song is being
                played in. An object with an ``id`` attribute (e.g. a
                channel) is also accepted.
            source (music_player.YTDLSource): The audio source; its
                ``requester.id``, ``search_term``, ``song_title`` and
                ``artist`` are stored.
        """
        # Accept either a raw Discord ID or a channel-like object.
        discord_channel_id = getattr(channel_id, "id", channel_id)
        with self._connect() as conn:
            conn.execute("""
                INSERT INTO song_play (
                    user_id,
                    channel_id,
                    search_term,
                    song_title,
                    song_artist
                ) VALUES (
                    ?, ?, ?, ?, ?
                )
            """, (
                self._upsert_discord_id(conn, "user", source.requester.id),
                self._upsert_discord_id(conn, "channel", discord_channel_id),
                source.search_term,
                source.song_title,
                source.artist,
            ))

    def get_activity_stats(
            self,
            member: "typing.Union[discord.Member, int]",
            start: typing.Optional[datetime] = None,
    ) -> dict[str, timedelta]:
        """Gets stats on the activities of the given member.

        This method reads the member's activity changes recorded after
        ``start``, in timestamp order. For each pair of consecutive rows, if
        the first row's after-activity has the same name as the second row's
        before-activity, the time between the two rows is credited to that
        activity. Periods with no activity name are not reported.

        Note: this registers the member in the 'user' table if absent.

        Args:
            member (discord.Member | int): The member, or their Discord ID.
                Any object with an ``id`` attribute is accepted.
            start (datetime, optional): Only changes strictly after this time
                are used. Defaults to 30 days before the call. Stored
                timestamps come from SQLite's CURRENT_TIMESTAMP (UTC).
                Timezone-aware values are therefore converted to UTC, and
                naive values are treated as UTC.

        Returns:
            dict[str, timedelta]: Activity names mapped to the total time
                spent in each.
        """
        # Computed per call: a default in the signature would be frozen at
        # class-definition time and drift in a long-running process.
        if start is None:
            start = datetime.now(timezone.utc) - timedelta(days=30)
        if start.tzinfo is not None:
            start = start.astimezone(timezone.utc).replace(tzinfo=None)
        discord_id = getattr(member, "id", member)

        with self._connect() as conn:
            user_id = self._upsert_discord_id(conn, "user", discord_id)
            # ORDER BY is required: rows are paired with their successor, so
            # they must be chronological, not in arbitrary scan order.
            activities = conn.execute("""
                SELECT
                    before_activity_name,
                    after_activity_name,
                    timestamp
                FROM
                    activity_change
                WHERE
                    user_id = ? AND
                    timestamp > ?
                ORDER BY
                    timestamp, id
            """, (user_id, start.isoformat(" "))).fetchall()

        activity_stats: dict[str, timedelta] = {}
        for (_, opened_name, opened_at), (closed_name, _, closed_at) in zip(
                activities, activities[1:]):
            if opened_name is None or opened_name != closed_name:
                continue
            elapsed = (datetime.fromisoformat(closed_at)
                       - datetime.fromisoformat(opened_at))
            activity_stats[opened_name] = (
                activity_stats.get(opened_name, timedelta()) + elapsed)
        return activity_stats