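# base-discord-bot/database.py
#
# SQLite persistence for the Discord bot: servers, channels, users,
# presence/activity changes, song requests and song plays.
# (Listing metadata: 555 lines, 20 KiB, Python.)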

import ast
import contextlib
from datetime import datetime, timedelta, timezone
import logging
import random
import sqlite3
import typing

import discord
import openai

from cogs import music_player
# Shared logger for every database operation in this module.
logger = logging.getLogger(name="database")
class Database:
    """
    SQLite-backed store for the bot's Discord entities and usage history.

    Discord snowflake IDs are mapped to small internal row IDs through the
    'server', 'channel' and 'user' lookup tables; the 'activity_change',
    'song_request' and 'song_play' tables reference those row IDs.

    Every database access opens a short-lived connection via ``_connect``,
    which commits (or rolls back) and then closes it.

    Stored timestamps come from SQLite's ``CURRENT_TIMESTAMP``, which is
    UTC. Datetimes passed in for filtering are therefore treated as UTC:
    timezone-aware values are converted to UTC, and naive values are assumed
    to already be UTC.
    """

    # Lookup tables mapping a Discord ID to a row ID. This tuple is also the
    # whitelist for table names that get interpolated into SQL below.
    _ID_TABLES = ("server", "channel", "user")

    # System prompt used when no past song is left to suggest.
    _RECOMMEND_PROMPT = (
        "I'm going to give you a list of songs and artists "
        "formatted as a Python list of dicts where the "
        "song title is the 'title' key and the artist is "
        "the 'artist' key. I want you to return a song "
        "title and artist that you would recommend based "
        "on the given songs. Don't be afraid to branch out "
        "and vary songs; the same artist should not be "
        "repeated more than twice. You should give me only a bare text "
        "string formatted as a Python dict where the "
        "'title' key is the song title, and the 'artist' "
        "key is the song's artist. Don't add anything other "
        "than this dict."
    )

    def __init__(self, path: str):
        """
        Opens the database at ``path``, creating any missing tables.

        Args:
            path (str): Filesystem path of the SQLite database file.
        """
        self.path = path
        self._ensure_db()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @contextlib.contextmanager
    def _connect(self) -> typing.Iterator[sqlite3.Connection]:
        """
        Yields a connection that is committed (or rolled back on error) and
        then closed.

        ``with sqlite3.connect(...)`` on its own only manages the transaction
        and leaves the connection open until garbage collection; this helper
        also closes it.
        """
        conn = sqlite3.connect(self.path)
        try:
            with conn:  # commit on success, rollback on exception
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _utcnow() -> datetime:
        """Returns the current UTC time as a naive datetime (SQLite's convention)."""
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _to_db_timestamp(moment: datetime) -> str:
        """
        Formats ``moment`` for comparison against stored timestamps.

        Aware datetimes are converted to UTC; naive ones are assumed to be
        UTC already. The ``"YYYY-MM-DD HH:MM:SS[.ffffff]"`` form matches what
        the (deprecated since Python 3.12) default sqlite3 datetime adapter
        produced, and compares correctly as text against
        ``CURRENT_TIMESTAMP`` values.
        """
        if moment.tzinfo is not None:
            moment = moment.astimezone(timezone.utc).replace(tzinfo=None)
        return moment.isoformat(" ")

    @staticmethod
    def _placeholders(values: typing.Sequence) -> str:
        """Returns ``"?,?,..."`` with one placeholder per item (empty for no items)."""
        return ",".join("?" for _ in values)

    def _check_table(self, table: str):
        """Raises ValueError unless ``table`` is one of ``_ID_TABLES``."""
        if table not in self._ID_TABLES:
            raise ValueError(f"Unknown lookup table: {table!r}")

    def _ensure_db(self):
        """Creates every table used by this class if it does not exist yet."""
        with self._connect() as conn:
            # Lookup tables: Discord ID -> internal row ID
            for table in self._ID_TABLES:
                conn.execute(f"""
                    CREATE TABLE IF NOT EXISTS {table} (
                        id INTEGER PRIMARY KEY,
                        discord_id INTEGER NOT NULL UNIQUE
                    )
                """)
            # Presence changes; status is always known, activity may be absent
            conn.execute("""
                CREATE TABLE IF NOT EXISTS activity_change (
                    id INTEGER PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    before_activity_type TEXT,
                    before_activity_name TEXT,
                    before_activity_status TEXT NOT NULL,
                    after_activity_type TEXT,
                    after_activity_name TEXT,
                    after_activity_status TEXT NOT NULL,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
                )
            """)
            # Songs users asked for
            conn.execute("""
                CREATE TABLE IF NOT EXISTS song_request (
                    id INTEGER PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    channel_id INTEGER NOT NULL,
                    search_term TEXT NOT NULL,
                    song_title TEXT,
                    song_artist TEXT,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
                )
            """)
            # Songs that actually got played; user_id is NULL when there was
            # no requester
            conn.execute("""
                CREATE TABLE IF NOT EXISTS song_play (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER,
                    channel_id INTEGER NOT NULL,
                    search_term TEXT NOT NULL,
                    song_title TEXT,
                    song_artist TEXT,
                    finished BOOL DEFAULT 0,
                    timestamp DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
                )
            """)

    def _insert_discord_id(self, table: str, discord_id: typing.Optional[int]) -> int:
        """
        Inserts a Discord ID into a lookup table and returns its row ID.

        The row ID is returned whether the ID was newly inserted or was
        already present.

        Args:
            table (str): One of ``_ID_TABLES``.
            discord_id (int): The Discord snowflake ID.

        Returns:
            int: The ``id`` column of the matching row.

        Raises:
            ValueError: If ``table`` is not a known lookup table.
            sqlite3.IntegrityError: If ``discord_id`` is None (the column is
                NOT NULL).
        """
        self._check_table(table)
        with self._connect() as conn:
            cursor = conn.cursor()
            # Table names cannot be bound parameters; interpolation is safe
            # because the name was checked against the whitelist above.
            cursor.execute(
                f"INSERT INTO {table} (discord_id) VALUES (?) "
                "ON CONFLICT(discord_id) DO NOTHING RETURNING id",
                (discord_id,),
            )
            row = cursor.fetchone()
            if row is None:
                # DO NOTHING returns no row when the ID already exists
                cursor.execute(
                    f"SELECT id FROM {table} WHERE discord_id = ?",
                    (discord_id,),
                )
                row = cursor.fetchone()
        return row[0]

    def _lookup_row_ids(
            self,
            conn: sqlite3.Connection,
            table: str,
            discord_ids: typing.Sequence[int]) -> list[int]:
        """
        Returns row IDs in ``table`` for the Discord IDs already stored there.

        Unknown IDs are skipped; nothing is inserted.
        """
        self._check_table(table)
        rows = conn.execute(
            f"SELECT id FROM {table} "
            f"WHERE discord_id IN ({self._placeholders(discord_ids)})",
            tuple(discord_ids),
        ).fetchall()
        return [row[0] for row in rows]

    @staticmethod
    def _describe_activity(activity) -> tuple:
        """Returns ``(type name, name)`` of a member's activity, or ``(None, None)``."""
        if not activity:
            return None, None
        return activity.type.name, activity.name

    # ------------------------------------------------------------------
    # Lookup-table inserts
    # ------------------------------------------------------------------

    def _insert_server(self, discord_id: typing.Optional[int] = None) -> int:
        """
        Inserts a Discord server ID into the 'server' table.

        Existing IDs are left alone; the row ID is returned either way.

        Args:
            discord_id (int): The ID used to identify the server in Discord.

        Returns:
            int: The ID of the server in the server table.

        Examples:
            >>> db = Database("path.db")
            >>> db._insert_server(850610922256442889)  # doctest: +SKIP
            12
        """
        return self._insert_discord_id("server", discord_id)

    def _insert_channel(self, discord_id: typing.Optional[int] = None) -> int:
        """
        Inserts a Discord channel ID into the 'channel' table.

        Existing IDs are left alone; the row ID is returned either way.

        Args:
            discord_id (int): The ID used to identify the channel in Discord.

        Returns:
            int: The ID of the channel in the channel table.
        """
        return self._insert_discord_id("channel", discord_id)

    def _insert_user(self, discord_id: typing.Optional[int] = None) -> int:
        """
        Inserts a Discord user ID into the 'user' table.

        Existing IDs are left alone; the row ID is returned either way.

        Args:
            discord_id (int): The ID used to identify the user in Discord.

        Returns:
            int: The ID of the user in the user table.
        """
        return self._insert_discord_id("user", discord_id)

    # ------------------------------------------------------------------
    # Recording events
    # ------------------------------------------------------------------

    def insert_activity_change(
            self,
            before: "discord.Member",
            after: "discord.Member"):
        """
        Records a member's presence change in the 'activity_change' table.

        The activity type/name are stored as NULL when the member had no
        activity; the status name is always stored.

        Args:
            before (discord.Member): The member before the change.
            after (discord.Member): The member after the change.

        Raises:
            ValueError: If ``before`` and ``after`` are different users.

        Examples:
            >>> @commands.Cog.listener()
            >>> async def on_presence_update(self, before, after):
            ...     Database("path.db").insert_activity_change(before, after)
        """
        if before.id != after.id:
            raise ValueError("User IDs do not match.")
        user_id = self._insert_user(before.id)
        before_type, before_name = self._describe_activity(before.activity)
        after_type, after_name = self._describe_activity(after.activity)
        with self._connect() as conn:
            conn.execute("""
                INSERT INTO activity_change (
                    user_id,
                    before_activity_type,
                    before_activity_name,
                    before_activity_status,
                    after_activity_type,
                    after_activity_name,
                    after_activity_status
                ) VALUES (?, ?, ?, ?, ?, ?, ?)
            """, (
                user_id,
                before_type,
                before_name,
                before.status.name,
                after_type,
                after_name,
                after.status.name,
            ))

    def insert_song_request(
            self,
            message: "discord.Message",
            source: "music_player.YTDLSource"):
        """
        Records a song request in the 'song_request' table.

        Args:
            message (discord.Message): The Discord message requesting the
                song; its author and channel are recorded.
            source (music_player.YTDLSource): The resolved audio source; its
                ``search_term``, ``song_title`` and ``artist`` are recorded.
        """
        # Resolve row IDs before opening this method's own connection
        user_id = self._insert_user(message.author.id)
        channel_id = self._insert_channel(message.channel.id)
        with self._connect() as conn:
            conn.execute("""
                INSERT INTO song_request (
                    user_id,
                    channel_id,
                    search_term,
                    song_title,
                    song_artist
                ) VALUES (?, ?, ?, ?, ?)
            """, (
                user_id,
                channel_id,
                source.search_term,
                source.song_title,
                source.artist,
            ))

    def insert_song_play(
            self,
            channel_id: int,
            source: "music_player.YTDLSource"):
        """
        Records a song starting to play in the 'song_play' table.

        Args:
            channel_id (int): Discord ID of the channel the song plays in.
            source (music_player.YTDLSource): The audio source. Its
                ``requester`` may be None, in which case user_id is NULL.

        Returns:
            int: Row ID of the new entry; pass it to ``update_song_play`` to
            set its 'finished' value.
        """
        user_id = self._insert_user(source.requester.id) if source.requester else None
        channel_row_id = self._insert_channel(channel_id)
        with self._connect() as conn:
            cursor = conn.execute("""
                INSERT INTO song_play (
                    user_id,
                    channel_id,
                    search_term,
                    song_title,
                    song_artist
                ) VALUES (?, ?, ?, ?, ?)
            """, (
                user_id,
                channel_row_id,
                source.search_term,
                source.song_title,
                source.artist,
            ))
            song_play_id = cursor.lastrowid
        return song_play_id

    def update_song_play(self, song_play_id: int, finished: bool):
        """
        Sets whether a recorded song play was finished.

        An unfinished play suggests the listeners skipped it or the bot
        picked the wrong song for the search term.

        Args:
            song_play_id (int): Row ID returned by ``insert_song_play``.
            finished (bool): Whether the song played to completion.
        """
        with self._connect() as conn:
            conn.execute(
                "UPDATE song_play SET finished = ? WHERE id = ?",
                (finished, song_play_id),
            )

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    def get_activity_stats(
            self,
            member: "typing.Union[discord.Member, int]",
            start: typing.Optional[datetime] = None
    ) -> dict[str, timedelta]:
        """
        Totals the time a member spent in each activity.

        Activity changes after ``start`` are walked in chronological order.
        A stretch is counted when one change enters an activity and the next
        change leaves that same activity.

        Args:
            member (discord.Member | int): The member, or their Discord ID.
                The user is added to the 'user' table if unknown.
            start (datetime): Only changes strictly after this moment are
                used. Defaults to 30 days before the time of the call (UTC).

        Returns:
            dict[str, timedelta]: Activity name -> total time spent in it.
        """
        if start is None:
            # Computed per call; a default in the signature would be frozen
            # at import time.
            start = self._utcnow() - timedelta(days=30)
        discord_id = member if isinstance(member, int) else member.id
        user_id = self._insert_user(discord_id)
        with self._connect() as conn:
            rows = conn.execute("""
                SELECT before_activity_name, after_activity_name, timestamp
                FROM activity_change
                WHERE user_id = ? AND timestamp > ?
                ORDER BY timestamp, id
            """, (user_id, self._to_db_timestamp(start))).fetchall()

        stats: dict[str, timedelta] = {}
        for (_, entered, entered_at), (left, _, left_at) in zip(rows, rows[1:]):
            # Skip "no activity" stretches and pairs that do not describe
            # one continuous stretch of the same activity.
            if entered is None or entered != left:
                continue
            elapsed = (datetime.fromisoformat(left_at)
                       - datetime.fromisoformat(entered_at))
            stats[entered] = stats.get(entered, timedelta()) + elapsed
        return stats

    def get_next_song(
            self,
            users: list[int],
            channels: list[int],
            limit: int = 100,
            cutoff: typing.Optional[datetime] = None) -> dict[str, str]:
        """
        Chooses a song to play next based on past listening.

        Candidates are the ``limit`` (title, artist) pairs most often
        finished by any of ``users`` in any of ``channels`` before
        ``cutoff``. Anything played in those channels at or after ``cutoff``
        is removed, and one remaining candidate is picked uniformly at
        random. If none remain, OpenAI (gpt-4o-mini) is asked for a
        recommendation seeded with the five most recently finished songs in
        those channels.

        Args:
            users (list[int]): Discord IDs of the listeners; IDs unknown to
                the database are ignored.
            channels (list[int]): Discord IDs of the channels to draw
                history from; unknown IDs are ignored.
            limit (int): Maximum number of candidate songs considered.
            cutoff (datetime): Boundary between "old" and "recent" plays.
                Defaults to one hour before the call (UTC).

        Returns:
            dict[str, str]: ``{"title": ..., "artist": ...}``.

        Raises:
            ValueError: If the OpenAI reply cannot be parsed into a song.
        """
        if cutoff is None:
            cutoff = self._utcnow() - timedelta(hours=1)
        cutoff_ts = self._to_db_timestamp(cutoff)
        logger.debug("users: %s", users)
        logger.debug("channels: %s", channels)

        with self._connect() as conn:
            # Convert Discord IDs to row IDs
            user_ids = self._lookup_row_ids(conn, "user", users)
            channel_ids = self._lookup_row_ids(conn, "channel", channels)
            user_marks = self._placeholders(user_ids)
            channel_marks = self._placeholders(channel_ids)

            logger.info("Getting past song plays")
            old_song_plays = conn.execute(f"""
                SELECT song_title, song_artist, COUNT(*) AS count
                FROM song_play
                WHERE user_id IN ({user_marks})
                  AND channel_id IN ({channel_marks})
                  AND finished = 1
                  AND timestamp < ?
                GROUP BY song_title, song_artist
                ORDER BY count DESC
                LIMIT ?
            """, (*user_ids, *channel_ids, cutoff_ts, limit)).fetchall()
            candidates = [
                {"title": title, "artist": artist, "plays": plays}
                for title, artist, plays in old_song_plays
            ]
            logger.debug("candidates: %s", candidates)

            logger.info("Getting recent song plays")
            recent_song_plays = set(conn.execute(f"""
                SELECT DISTINCT song_title, song_artist
                FROM song_play
                WHERE channel_id IN ({channel_marks})
                  AND timestamp >= ?
            """, (*channel_ids, cutoff_ts)).fetchall())
            logger.debug("recent: %s", recent_song_plays)

            # Drop anything heard recently so the bot does not repeat itself
            candidates = [
                candidate for candidate in candidates
                if (candidate["title"], candidate["artist"]) not in recent_song_plays
            ]
            logger.debug("filtered candidates: %s", candidates)
            if candidates:
                pick = random.choice(candidates)
                return {"title": pick["title"], "artist": pick["artist"]}

            # Nothing left: seed a recommendation with the latest finished
            # songs. MAX(timestamp) orders each group by its most recent
            # play; a bare column here would come from an arbitrary row.
            seeds = conn.execute(f"""
                SELECT song_title, song_artist
                FROM song_play
                WHERE channel_id IN ({channel_marks})
                  AND finished = 1
                GROUP BY song_title, song_artist
                ORDER BY MAX(timestamp) DESC
                LIMIT 5
            """, tuple(channel_ids)).fetchall()

        # Dict form matches the format the prompt describes to the model
        last_five = [{"title": title, "artist": artist} for title, artist in seeds]
        logger.debug("last five song plays: %s", last_five)
        # The network call happens after the DB connection is closed
        return self._recommend_song(last_five)

    # ------------------------------------------------------------------
    # OpenAI recommendation
    # ------------------------------------------------------------------

    def _recommend_song(self, seeds: list[dict[str, str]]) -> dict[str, str]:
        """Asks gpt-4o-mini for one song related to ``seeds`` and parses the reply."""
        completion = openai.OpenAI().chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": self._RECOMMEND_PROMPT},
                {"role": "user", "content": str(seeds)},
            ],
        )
        return self._parse_recommendation(completion.choices[0].message.content)

    @staticmethod
    def _parse_recommendation(text: typing.Optional[str]) -> dict[str, str]:
        """
        Parses a model reply into ``{"title": ..., "artist": ...}``.

        The reply is untrusted, so it is parsed with ``ast.literal_eval``
        (Python literals only) instead of ``eval``. A surrounding Markdown
        code fence is stripped first.

        Raises:
            ValueError: If the reply is not a dict literal with string
                'title' and 'artist' values.
        """
        if not isinstance(text, str):
            raise ValueError(f"Empty song recommendation: {text!r}")
        body = text.strip()
        if body.startswith("```"):
            # Drop the opening fence line (e.g. "```python") and closing fence
            body = body.split("\n", 1)[1] if "\n" in body else ""
            body = body.rstrip()
            if body.endswith("```"):
                body = body[:-3]
        try:
            parsed = ast.literal_eval(body.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as err:
            raise ValueError(f"Unparseable song recommendation: {text!r}") from err
        if not (isinstance(parsed, dict)
                and isinstance(parsed.get("title"), str)
                and isinstance(parsed.get("artist"), str)):
            raise ValueError(f"Malformed song recommendation: {text!r}")
        return {"title": parsed["title"], "artist": parsed["artist"]}