Refactor database management and schema initialization

- Removed the old npc_memory.db file.
- Updated time.txt with a new timestamp.
- Refactored transaction recording in bank_functions.py to use parameterized queries.
- Enhanced DatabaseManager in sql_commands.py to support singleton pattern and improved table creation logic.
- Added methods for sanitizing SQL identifiers and parsing insert columns for upsert operations.
- Improved error handling and connection management in execute_query, fetch_one, fetch_all, and fetch_as_dataframe methods.
- Introduced a new bootstrap_database.py script for initializing the database schema.
- Updated app.py to use the new initialize_database function for database management.
This commit is contained in:
2026-05-31 11:16:44 +00:00
parent a6094c2d0c
commit be89cc3acd
11 changed files with 2175 additions and 53 deletions
+6
View File
@@ -0,0 +1,6 @@
from utils.sql_commands import initialize_database
if __name__ == "__main__":
initialize_database()
print("Database initialization complete.")
+2 -1
View File
@@ -4,6 +4,7 @@ from discord.ext import commands
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from web.app import app from web.app import app
from utils.sql_commands import initialize_database
import threading import threading
import itertools import itertools
@@ -50,9 +51,9 @@ class Client(commands.Bot):
def main(): def main():
load_dotenv() load_dotenv()
initialize_database()
client = Client() client = Client()
token = os.getenv("TOKEN") token = os.getenv("TOKEN")
print(token)
if token is not None: if token is not None:
threading.Thread(target=run_web, daemon=True).start() threading.Thread(target=run_web, daemon=True).start()
client.run(token) client.run(token)
+2
View File
@@ -4,6 +4,7 @@ from discord.ext import commands
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from web.app import app from web.app import app
from utils.sql_commands import initialize_database
import threading import threading
@@ -58,6 +59,7 @@ class Client(commands.Bot):
def main(): def main():
load_dotenv() load_dotenv()
initialize_database()
client = Client() client = Client()
token = os.getenv("TOKEN") token = os.getenv("TOKEN")
if token is not None: if token is not None:
+1 -1
View File
@@ -63,7 +63,7 @@ class CustomCommandsCog(commands.Cog):
) )
self.invalidate_cache(guild_id) # Invalidate cache after change self.invalidate_cache(guild_id) # Invalidate cache after change
if deleted_rows is not None and len(deleted_rows) > 0: if deleted_rows and deleted_rows > 0:
await ctx.send(f"Custom command `{command_name}` has been deleted!") await ctx.send(f"Custom command `{command_name}` has been deleted!")
else: else:
await ctx.send(f"Custom command `{command_name}` not found.") await ctx.send(f"Custom command `{command_name}` not found.")
+1974
View File
File diff suppressed because it is too large Load Diff
BIN
View File
Binary file not shown.
+1 -1
View File
@@ -1 +1 @@
1759391917.148779 1775047922.8093011
+4 -1
View File
@@ -40,7 +40,10 @@ async def record_transaction(
transaction_type: The type of transaction performed. transaction_type: The type of transaction performed.
amount: The amount of the transaction. amount: The amount of the transaction.
""" """
db.insert("transactions", (user_id, transaction_type, amount)) db.execute_query(
"INSERT INTO transactions (USERID, TYPE, AMOUNT) VALUES (%s, %s, %s)",
(user_id, transaction_type, amount),
)
async def update_money( async def update_money(
+182 -46
View File
@@ -3,11 +3,11 @@ from mysql.connector import pooling
from dotenv import load_dotenv from dotenv import load_dotenv
import os import os
import random import random
import re
import time import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
import logging import logging
import pandas as pd import pandas as pd
from copy import copy
# Configure logging # Configure logging
@@ -20,16 +20,33 @@ logger = logging.getLogger(__name__)
class DatabaseManager: class DatabaseManager:
_instances = {}
def __new__(cls, env="development"):
instance_key = env or "default"
if instance_key in cls._instances:
return cls._instances[instance_key]
instance = super().__new__(cls)
cls._instances[instance_key] = instance
return instance
def __init__(self, env="development"): def __init__(self, env="development"):
# Load environment variables based on environment if getattr(self, "_initialized", False):
self.load_env(env) return
self._initialized = True
env_file = f".env.{env}" if env else ".env"
if not os.path.exists(env_file):
env_file = ".env"
self.load_env(env_file)
self.config = { self.config = {
"host": os.getenv("SQLHOST", "localhost"), "host": os.getenv("SQLHOST", "localhost"),
"user": os.getenv("SQLUSER", "root"), "user": os.getenv("SQLUSER", "root"),
"password": os.getenv("SQLPASS", ""), "password": os.getenv("SQLPASS", ""),
"database": os.getenv("SQLDB", "testdb"), "database": os.getenv("SQLDB", "testdb"),
"pool_reset_session": bool(os.getenv("POOL_RESET_SESSION", False)), "pool_reset_session": os.getenv("POOL_RESET_SESSION", "false").lower()
in ("true", "1", "yes"),
} }
self.pool = pooling.MySQLConnectionPool( self.pool = pooling.MySQLConnectionPool(
@@ -124,26 +141,122 @@ class DatabaseManager:
inactivity INT NOT NULL inactivity INT NOT NULL
""", """,
) )
self.create_table_if_not_exists(
"economy",
"""
ID BIGINT PRIMARY KEY,
WALLET BIGINT NOT NULL DEFAULT 0,
BANK BIGINT NOT NULL DEFAULT 0,
DAILY DOUBLE DEFAULT 0
""",
)
self.create_table_if_not_exists(
"transactions",
"""
ID INT AUTO_INCREMENT PRIMARY KEY,
USERID BIGINT NOT NULL,
TYPE VARCHAR(50),
AMOUNT DECIMAL(18,2),
TIME DATETIME DEFAULT CURRENT_TIMESTAMP
""",
)
self.create_table_if_not_exists(
"custom_commands",
"""
ID INT AUTO_INCREMENT PRIMARY KEY,
GUILDID VARCHAR(32) NOT NULL,
COMMANDNAME VARCHAR(100) NOT NULL,
RESPONSE TEXT NOT NULL,
MATCHTYPE VARCHAR(20) NOT NULL DEFAULT 'exact'
""",
)
self.create_table_if_not_exists(
"guilds",
"""
GUILD BIGINT PRIMARY KEY,
WELCOME BIGINT DEFAULT NULL,
RULES BIGINT DEFAULT NULL,
GUIDE BIGINT DEFAULT NULL,
INTRODUCTIONS BIGINT DEFAULT NULL,
EVENTS BIGINT DEFAULT NULL,
MEMBERCOUNT BIGINT DEFAULT NULL,
LOGGING BIGINT DEFAULT NULL,
TICKETING BIGINT DEFAULT NULL
""",
)
self.create_table_if_not_exists(
"rewards",
"""
ID INT AUTO_INCREMENT PRIMARY KEY,
type VARCHAR(50) NOT NULL,
amount INT NOT NULL DEFAULT 0,
description TEXT DEFAULT NULL
""",
)
self.create_table_if_not_exists(
"logs",
"""
ID INT AUTO_INCREMENT PRIMARY KEY,
guild_id BIGINT NOT NULL,
user_id BIGINT NOT NULL,
type VARCHAR(50) NOT NULL,
message TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
""",
)
self.create_table_if_not_exists(
"gamble_limits",
"""
USERID BIGINT PRIMARY KEY,
DAILY_LIMIT BIGINT DEFAULT NULL,
EXCLUDED_UNTIL DATETIME DEFAULT NULL
""",
)
self.create_table_if_not_exists(
"users",
"""
ID BIGINT PRIMARY KEY,
XP INT DEFAULT 0,
LEVEL INT DEFAULT 0,
birthday VARCHAR(10) DEFAULT NULL
""",
)
def load_env(self, env): def load_env(self, env_file):
env_file = f".env"
load_dotenv(env_file) load_dotenv(env_file)
logger.info(f"Loaded environment variables from {env_file}") logger.info(f"Loaded environment variables from {env_file}")
def get_connection(self): def get_connection(self):
return self.pool.get_connection() return self.pool.get_connection()
def _sanitize_identifier(self, identifier: str) -> str:
if not re.match(r"^[A-Za-z0-9_]+$", identifier):
raise ValueError(f"Invalid SQL identifier: {identifier}")
return identifier
def _parse_insert_columns(self, query: str) -> list[str]:
match = re.search(
r"INSERT\s+INTO\s+\S+\s*\(([^)]+)\)\s*VALUES",
query,
re.IGNORECASE,
)
if not match:
raise ValueError(
"Insert query must contain a column list for overwrite upsert support."
)
return [col.strip() for col in match.group(1).split(",") if col.strip()]
def execute_query(self, query, params=None, retries=3, delay=1): def execute_query(self, query, params=None, retries=3, delay=1):
cursor = None
connection = None connection = None
cursor = None
for attempt in range(retries): for attempt in range(retries):
try: try:
connection = self.get_connection() connection = self.get_connection()
cursor = connection.cursor(dictionary=True, buffered=True) cursor = connection.cursor(dictionary=True, buffered=True)
cursor.execute(query, params or ()) cursor.execute(query, params or ())
logger.info(f"Executed query: {query} with params: {params}")
connection.commit() connection.commit()
return copy(cursor) logger.info(f"Executed query: {query} with params: {params}")
return cursor.rowcount
except mysql.connector.Error as err: except mysql.connector.Error as err:
logger.warning(f"Attempt {attempt + 1} failed: {err}") logger.warning(f"Attempt {attempt + 1} failed: {err}")
time.sleep(delay * (2**attempt)) time.sleep(delay * (2**attempt))
@@ -163,7 +276,7 @@ class DatabaseManager:
Args: Args:
query (str): The SQL query to execute. query (str): The SQL query to execute.
params (tuple): The parameters to pass into the query. params (tuple): The parameters to pass into the query.
overwrite (bool, optional): Whether to perform an upsert operation. Defaults to False. overwrite (bool, optional): Whether to perform an upsert operation. Defaults to True.
Raises: Raises:
ValueError: If no parameters are provided. ValueError: If no parameters are provided.
@@ -171,37 +284,29 @@ class DatabaseManager:
if not params: if not params:
raise ValueError("Params must be provided for the insert operation.") raise ValueError("Params must be provided for the insert operation.")
try: if overwrite:
if overwrite: columns = self._parse_insert_columns(query)
columns = [ update_set = ", ".join(f"{col} = VALUES({col})" for col in columns)
col.split("=")[0].strip() query = f"{query} ON DUPLICATE KEY UPDATE {update_set}"
for col in query.split("VALUES")[0]
.split("(")[1]
.split(")")[0]
.split(",")
]
update_set = ", ".join(f"{col} = VALUES({col})" for col in columns)
query = f"{query} ON DUPLICATE KEY UPDATE {update_set}"
cursor = self.execute_query(query, params) rowcount = self.execute_query(query, params)
if cursor: if rowcount is None:
logger.info(f"Insert completed with query: {query}.") logger.error(f"Insert failed with query: {query}.")
except mysql.connector.Error as err: else:
logger.error(f"Insert failed with query: {query}. Error: {err}") logger.info(f"Insert completed with query: {query}.")
def bulk_insert(self, query, params=None): def bulk_insert(self, query, params=None):
if not params: if not params:
logger.warning("No data provided for bulk insert.") logger.warning("No data provided for bulk insert.")
return return
# Assuming params is a list of dictionaries
if not isinstance(params, list) or not all(isinstance(d, dict) for d in params): if not isinstance(params, list) or not all(isinstance(d, dict) for d in params):
raise ValueError("Params must be a list of dictionaries for bulk insert.") raise ValueError("Params must be a list of dictionaries for bulk insert.")
keys = params[0].keys() keys = list(params[0].keys())
placeholders = ", ".join(["%s"] * len(keys)) placeholders = ", ".join(["%s"] * len(keys))
query = f"{query} ({', '.join(keys)}) VALUES ({placeholders})" query = f"{query} ({', '.join(keys)}) VALUES ({placeholders})"
values = [tuple(data.values()) for data in params] values = [tuple(data[key] for key in keys) for data in params]
connection = None connection = None
cursor = None cursor = None
@@ -216,7 +321,7 @@ class DatabaseManager:
except mysql.connector.Error as err: except mysql.connector.Error as err:
logger.error(f"Bulk insert failed: {err}") logger.error(f"Bulk insert failed: {err}")
if connection: if connection:
connection.rollback() # Roll back on error connection.rollback()
finally: finally:
if cursor: if cursor:
cursor.close() cursor.close()
@@ -225,34 +330,60 @@ class DatabaseManager:
def delete(self, table_name: str, condition: dict) -> None: def delete(self, table_name: str, condition: dict) -> None:
"""Deletes a record from the specified table based on the condition provided.""" """Deletes a record from the specified table based on the condition provided."""
table_name = self._sanitize_identifier(table_name)
condition_column, condition_value = next(iter(condition.items())) condition_column, condition_value = next(iter(condition.items()))
condition_column = self._sanitize_identifier(condition_column)
query = f"DELETE FROM {table_name} WHERE {condition_column} = %s" query = f"DELETE FROM {table_name} WHERE {condition_column} = %s"
self.execute_query(query, (condition_value,)) self.execute_query(query, (condition_value,))
def fetch_one(self, query, params=None): def fetch_one(self, query, params=None):
cursor = self.execute_query(query, params) connection = None
return cursor.fetchone() if cursor else {} cursor = None
try:
connection = self.get_connection()
cursor = connection.cursor(dictionary=True, buffered=True)
cursor.execute(query, params or ())
return cursor.fetchone()
finally:
if cursor:
cursor.close()
if connection:
connection.close()
def fetch_all(self, query, params=None): def fetch_all(self, query, params=None):
cursor = self.execute_query(query, params) connection = None
return cursor.fetchall() if cursor else [] cursor = None
try:
connection = self.get_connection()
cursor = connection.cursor(dictionary=True, buffered=True)
cursor.execute(query, params or ())
return cursor.fetchall()
finally:
if cursor:
cursor.close()
if connection:
connection.close()
def fetch_as_dataframe(self, query, params=None): def fetch_as_dataframe(self, query, params=None):
cursor = self.execute_query(query, params) connection = None
if cursor: cursor = None
try: try:
# Ensure cursor has a result to fetch connection = self.get_connection()
if cursor.with_rows: cursor = connection.cursor(dictionary=True, buffered=True)
results = cursor.fetchall() cursor.execute(query, params or ())
return pd.DataFrame(results) if results else pd.DataFrame() if cursor.with_rows:
else: results = cursor.fetchall()
logger.warning("No result set to fetch from.") return pd.DataFrame(results) if results else pd.DataFrame()
return pd.DataFrame() logger.warning("No result set to fetch from.")
finally: return pd.DataFrame()
finally:
if cursor:
cursor.close() cursor.close()
return pd.DataFrame() if connection:
connection.close()
def create_table_if_not_exists(self, table_name, schema): def create_table_if_not_exists(self, table_name, schema):
table_name = self._sanitize_identifier(table_name)
query = f"CREATE TABLE IF NOT EXISTS {table_name} ({schema})" query = f"CREATE TABLE IF NOT EXISTS {table_name} ({schema})"
self.execute_query(query) self.execute_query(query)
logger.info(f"Ensured table {table_name} exists with schema: {schema}") logger.info(f"Ensured table {table_name} exists with schema: {schema}")
@@ -275,6 +406,11 @@ class DatabaseManager:
return f"{base_query} WHERE {conditions}", list(filters.values()) return f"{base_query} WHERE {conditions}", list(filters.values())
def initialize_database(env="development"):
"""Initialize the database schema and return a shared DatabaseManager."""
return DatabaseManager(env)
# SQL scripts to create tables # SQL scripts to create tables
create_feedback_table = """ create_feedback_table = """
CREATE TABLE IF NOT EXISTS feedback ( CREATE TABLE IF NOT EXISTS feedback (
Binary file not shown.
+3 -3
View File
@@ -4,10 +4,10 @@ import os
if __package__ is None or __package__ == "": if __package__ is None or __package__ == "":
# Running as __main__ (e.g. python web/app.py) # Running as __main__ (e.g. python web/app.py)
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils.sql_commands import DatabaseManager from utils.sql_commands import DatabaseManager, initialize_database
else: else:
# Imported as a module (e.g. from bot.py) # Imported as a module (e.g. from bot.py)
from utils.sql_commands import DatabaseManager from utils.sql_commands import DatabaseManager, initialize_database
import os import os
import requests import requests
@@ -21,7 +21,7 @@ load_dotenv()
app = Flask(__name__) app = Flask(__name__)
app.secret_key = os.getenv("SECRET_KEY") app.secret_key = os.getenv("SECRET_KEY")
db = DatabaseManager() db = initialize_database()
# Ensure required environment variables are loaded # Ensure required environment variables are loaded
DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID") DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID")