Refactor database management and schema initialization

- Removed the old npc_memory.db file.
- Updated time.txt with a new timestamp.
- Refactored transaction recording in bank_functions.py to use parameterized queries.
- Enhanced DatabaseManager in sql_commands.py to support singleton pattern and improved table creation logic.
- Added methods for sanitizing SQL identifiers and parsing insert columns for upsert operations.
- Improved error handling and connection management in execute_query, fetch_one, fetch_all, and fetch_as_dataframe methods.
- Introduced a new bootstrap_database.py script for initializing the database schema.
- Updated app.py to use the new initialize_database function for database management.
This commit is contained in:
2026-05-31 11:16:44 +00:00
parent a6094c2d0c
commit be89cc3acd
11 changed files with 2175 additions and 53 deletions
+6
View File
@@ -0,0 +1,6 @@
from utils.sql_commands import initialize_database
if __name__ == "__main__":
initialize_database()
print("Database initialization complete.")
+2 -1
View File
@@ -4,6 +4,7 @@ from discord.ext import commands
import os
from dotenv import load_dotenv
from web.app import app
from utils.sql_commands import initialize_database
import threading
import itertools
@@ -50,9 +51,9 @@ class Client(commands.Bot):
def main():
load_dotenv()
initialize_database()
client = Client()
token = os.getenv("TOKEN")
print(token)
if token is not None:
threading.Thread(target=run_web, daemon=True).start()
client.run(token)
+2
View File
@@ -4,6 +4,7 @@ from discord.ext import commands
import os
from dotenv import load_dotenv
from web.app import app
from utils.sql_commands import initialize_database
import threading
@@ -58,6 +59,7 @@ class Client(commands.Bot):
def main():
load_dotenv()
initialize_database()
client = Client()
token = os.getenv("TOKEN")
if token is not None:
+1 -1
View File
@@ -63,7 +63,7 @@ class CustomCommandsCog(commands.Cog):
)
self.invalidate_cache(guild_id) # Invalidate cache after change
if deleted_rows is not None and len(deleted_rows) > 0:
if deleted_rows and deleted_rows > 0:
await ctx.send(f"Custom command `{command_name}` has been deleted!")
else:
await ctx.send(f"Custom command `{command_name}` not found.")
+1974
View File
File diff suppressed because it is too large Load Diff
BIN
View File
Binary file not shown.
+1 -1
View File
@@ -1 +1 @@
1759391917.148779
1775047922.8093011
+4 -1
View File
@@ -40,7 +40,10 @@ async def record_transaction(
transaction_type: The type of transaction performed.
amount: The amount of the transaction.
"""
db.insert("transactions", (user_id, transaction_type, amount))
db.execute_query(
"INSERT INTO transactions (USERID, TYPE, AMOUNT) VALUES (%s, %s, %s)",
(user_id, transaction_type, amount),
)
async def update_money(
+171 -35
View File
@@ -3,11 +3,11 @@ from mysql.connector import pooling
from dotenv import load_dotenv
import os
import random
import re
import time
from datetime import datetime, timedelta
import logging
import pandas as pd
from copy import copy
# Configure logging
@@ -20,16 +20,33 @@ logger = logging.getLogger(__name__)
class DatabaseManager:
_instances = {}
def __new__(cls, env="development"):
instance_key = env or "default"
if instance_key in cls._instances:
return cls._instances[instance_key]
instance = super().__new__(cls)
cls._instances[instance_key] = instance
return instance
def __init__(self, env="development"):
# Load environment variables based on environment
self.load_env(env)
if getattr(self, "_initialized", False):
return
self._initialized = True
env_file = f".env.{env}" if env else ".env"
if not os.path.exists(env_file):
env_file = ".env"
self.load_env(env_file)
self.config = {
"host": os.getenv("SQLHOST", "localhost"),
"user": os.getenv("SQLUSER", "root"),
"password": os.getenv("SQLPASS", ""),
"database": os.getenv("SQLDB", "testdb"),
"pool_reset_session": bool(os.getenv("POOL_RESET_SESSION", False)),
"pool_reset_session": os.getenv("POOL_RESET_SESSION", "false").lower()
in ("true", "1", "yes"),
}
self.pool = pooling.MySQLConnectionPool(
@@ -124,26 +141,122 @@ class DatabaseManager:
inactivity INT NOT NULL
""",
)
self.create_table_if_not_exists(
"economy",
"""
ID BIGINT PRIMARY KEY,
WALLET BIGINT NOT NULL DEFAULT 0,
BANK BIGINT NOT NULL DEFAULT 0,
DAILY DOUBLE DEFAULT 0
""",
)
self.create_table_if_not_exists(
"transactions",
"""
ID INT AUTO_INCREMENT PRIMARY KEY,
USERID BIGINT NOT NULL,
TYPE VARCHAR(50),
AMOUNT DECIMAL(18,2),
TIME DATETIME DEFAULT CURRENT_TIMESTAMP
""",
)
self.create_table_if_not_exists(
"custom_commands",
"""
ID INT AUTO_INCREMENT PRIMARY KEY,
GUILDID VARCHAR(32) NOT NULL,
COMMANDNAME VARCHAR(100) NOT NULL,
RESPONSE TEXT NOT NULL,
MATCHTYPE VARCHAR(20) NOT NULL DEFAULT 'exact'
""",
)
self.create_table_if_not_exists(
"guilds",
"""
GUILD BIGINT PRIMARY KEY,
WELCOME BIGINT DEFAULT NULL,
RULES BIGINT DEFAULT NULL,
GUIDE BIGINT DEFAULT NULL,
INTRODUCTIONS BIGINT DEFAULT NULL,
EVENTS BIGINT DEFAULT NULL,
MEMBERCOUNT BIGINT DEFAULT NULL,
LOGGING BIGINT DEFAULT NULL,
TICKETING BIGINT DEFAULT NULL
""",
)
self.create_table_if_not_exists(
"rewards",
"""
ID INT AUTO_INCREMENT PRIMARY KEY,
type VARCHAR(50) NOT NULL,
amount INT NOT NULL DEFAULT 0,
description TEXT DEFAULT NULL
""",
)
self.create_table_if_not_exists(
"logs",
"""
ID INT AUTO_INCREMENT PRIMARY KEY,
guild_id BIGINT NOT NULL,
user_id BIGINT NOT NULL,
type VARCHAR(50) NOT NULL,
message TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
""",
)
self.create_table_if_not_exists(
"gamble_limits",
"""
USERID BIGINT PRIMARY KEY,
DAILY_LIMIT BIGINT DEFAULT NULL,
EXCLUDED_UNTIL DATETIME DEFAULT NULL
""",
)
self.create_table_if_not_exists(
"users",
"""
ID BIGINT PRIMARY KEY,
XP INT DEFAULT 0,
LEVEL INT DEFAULT 0,
birthday VARCHAR(10) DEFAULT NULL
""",
)
def load_env(self, env):
env_file = f".env"
def load_env(self, env_file):
load_dotenv(env_file)
logger.info(f"Loaded environment variables from {env_file}")
def get_connection(self):
return self.pool.get_connection()
def _sanitize_identifier(self, identifier: str) -> str:
if not re.match(r"^[A-Za-z0-9_]+$", identifier):
raise ValueError(f"Invalid SQL identifier: {identifier}")
return identifier
def _parse_insert_columns(self, query: str) -> list[str]:
match = re.search(
r"INSERT\s+INTO\s+\S+\s*\(([^)]+)\)\s*VALUES",
query,
re.IGNORECASE,
)
if not match:
raise ValueError(
"Insert query must contain a column list for overwrite upsert support."
)
return [col.strip() for col in match.group(1).split(",") if col.strip()]
def execute_query(self, query, params=None, retries=3, delay=1):
cursor = None
connection = None
cursor = None
for attempt in range(retries):
try:
connection = self.get_connection()
cursor = connection.cursor(dictionary=True, buffered=True)
cursor.execute(query, params or ())
logger.info(f"Executed query: {query} with params: {params}")
connection.commit()
return copy(cursor)
logger.info(f"Executed query: {query} with params: {params}")
return cursor.rowcount
except mysql.connector.Error as err:
logger.warning(f"Attempt {attempt + 1} failed: {err}")
time.sleep(delay * (2**attempt))
@@ -163,7 +276,7 @@ class DatabaseManager:
Args:
query (str): The SQL query to execute.
params (tuple): The parameters to pass into the query.
overwrite (bool, optional): Whether to perform an upsert operation. Defaults to False.
overwrite (bool, optional): Whether to perform an upsert operation. Defaults to True.
Raises:
ValueError: If no parameters are provided.
@@ -171,37 +284,29 @@ class DatabaseManager:
if not params:
raise ValueError("Params must be provided for the insert operation.")
try:
if overwrite:
columns = [
col.split("=")[0].strip()
for col in query.split("VALUES")[0]
.split("(")[1]
.split(")")[0]
.split(",")
]
columns = self._parse_insert_columns(query)
update_set = ", ".join(f"{col} = VALUES({col})" for col in columns)
query = f"{query} ON DUPLICATE KEY UPDATE {update_set}"
cursor = self.execute_query(query, params)
if cursor:
rowcount = self.execute_query(query, params)
if rowcount is None:
logger.error(f"Insert failed with query: {query}.")
else:
logger.info(f"Insert completed with query: {query}.")
except mysql.connector.Error as err:
logger.error(f"Insert failed with query: {query}. Error: {err}")
def bulk_insert(self, query, params=None):
if not params:
logger.warning("No data provided for bulk insert.")
return
# Assuming params is a list of dictionaries
if not isinstance(params, list) or not all(isinstance(d, dict) for d in params):
raise ValueError("Params must be a list of dictionaries for bulk insert.")
keys = params[0].keys()
keys = list(params[0].keys())
placeholders = ", ".join(["%s"] * len(keys))
query = f"{query} ({', '.join(keys)}) VALUES ({placeholders})"
values = [tuple(data.values()) for data in params]
values = [tuple(data[key] for key in keys) for data in params]
connection = None
cursor = None
@@ -216,7 +321,7 @@ class DatabaseManager:
except mysql.connector.Error as err:
logger.error(f"Bulk insert failed: {err}")
if connection:
connection.rollback() # Roll back on error
connection.rollback()
finally:
if cursor:
cursor.close()
@@ -225,34 +330,60 @@ class DatabaseManager:
def delete(self, table_name: str, condition: dict) -> None:
"""Deletes a record from the specified table based on the condition provided."""
table_name = self._sanitize_identifier(table_name)
condition_column, condition_value = next(iter(condition.items()))
condition_column = self._sanitize_identifier(condition_column)
query = f"DELETE FROM {table_name} WHERE {condition_column} = %s"
self.execute_query(query, (condition_value,))
def fetch_one(self, query, params=None):
cursor = self.execute_query(query, params)
return cursor.fetchone() if cursor else {}
connection = None
cursor = None
try:
connection = self.get_connection()
cursor = connection.cursor(dictionary=True, buffered=True)
cursor.execute(query, params or ())
return cursor.fetchone()
finally:
if cursor:
cursor.close()
if connection:
connection.close()
def fetch_all(self, query, params=None):
cursor = self.execute_query(query, params)
return cursor.fetchall() if cursor else []
connection = None
cursor = None
try:
connection = self.get_connection()
cursor = connection.cursor(dictionary=True, buffered=True)
cursor.execute(query, params or ())
return cursor.fetchall()
finally:
if cursor:
cursor.close()
if connection:
connection.close()
def fetch_as_dataframe(self, query, params=None):
cursor = self.execute_query(query, params)
if cursor:
connection = None
cursor = None
try:
# Ensure cursor has a result to fetch
connection = self.get_connection()
cursor = connection.cursor(dictionary=True, buffered=True)
cursor.execute(query, params or ())
if cursor.with_rows:
results = cursor.fetchall()
return pd.DataFrame(results) if results else pd.DataFrame()
else:
logger.warning("No result set to fetch from.")
return pd.DataFrame()
finally:
if cursor:
cursor.close()
return pd.DataFrame()
if connection:
connection.close()
def create_table_if_not_exists(self, table_name, schema):
table_name = self._sanitize_identifier(table_name)
query = f"CREATE TABLE IF NOT EXISTS {table_name} ({schema})"
self.execute_query(query)
logger.info(f"Ensured table {table_name} exists with schema: {schema}")
@@ -275,6 +406,11 @@ class DatabaseManager:
return f"{base_query} WHERE {conditions}", list(filters.values())
def initialize_database(env="development"):
"""Initialize the database schema and return a shared DatabaseManager."""
return DatabaseManager(env)
# SQL scripts to create tables
create_feedback_table = """
CREATE TABLE IF NOT EXISTS feedback (
Binary file not shown.
+3 -3
View File
@@ -4,10 +4,10 @@ import os
if __package__ is None or __package__ == "":
# Running as __main__ (e.g. python web/app.py)
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from utils.sql_commands import DatabaseManager
from utils.sql_commands import DatabaseManager, initialize_database
else:
# Imported as a module (e.g. from bot.py)
from utils.sql_commands import DatabaseManager
from utils.sql_commands import DatabaseManager, initialize_database
import os
import requests
@@ -21,7 +21,7 @@ load_dotenv()
app = Flask(__name__)
app.secret_key = os.getenv("SECRET_KEY")
db = DatabaseManager()
db = initialize_database()
# Ensure required environment variables are loaded
DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID")