Refactor database management and schema initialization
- Removed the old npc_memory.db file. - Updated time.txt with a new timestamp. - Refactored transaction recording in bank_functions.py to use parameterized queries. - Enhanced DatabaseManager in sql_commands.py to support singleton pattern and improved table creation logic. - Added methods for sanitizing SQL identifiers and parsing insert columns for upsert operations. - Improved error handling and connection management in execute_query, fetch_one, fetch_all, and fetch_as_dataframe methods. - Introduced a new bootstrap_database.py script for initializing the database schema. - Updated app.py to use the new initialize_database function for database management.
This commit is contained in:
+182
-46
@@ -3,11 +3,11 @@ from mysql.connector import pooling
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
import pandas as pd
|
||||
from copy import copy
|
||||
|
||||
|
||||
# Configure logging
|
||||
@@ -20,16 +20,33 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
_instances = {}
|
||||
|
||||
def __new__(cls, env="development"):
|
||||
instance_key = env or "default"
|
||||
if instance_key in cls._instances:
|
||||
return cls._instances[instance_key]
|
||||
instance = super().__new__(cls)
|
||||
cls._instances[instance_key] = instance
|
||||
return instance
|
||||
|
||||
def __init__(self, env="development"):
|
||||
# Load environment variables based on environment
|
||||
self.load_env(env)
|
||||
if getattr(self, "_initialized", False):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
env_file = f".env.{env}" if env else ".env"
|
||||
if not os.path.exists(env_file):
|
||||
env_file = ".env"
|
||||
self.load_env(env_file)
|
||||
|
||||
self.config = {
|
||||
"host": os.getenv("SQLHOST", "localhost"),
|
||||
"user": os.getenv("SQLUSER", "root"),
|
||||
"password": os.getenv("SQLPASS", ""),
|
||||
"database": os.getenv("SQLDB", "testdb"),
|
||||
"pool_reset_session": bool(os.getenv("POOL_RESET_SESSION", False)),
|
||||
"pool_reset_session": os.getenv("POOL_RESET_SESSION", "false").lower()
|
||||
in ("true", "1", "yes"),
|
||||
}
|
||||
|
||||
self.pool = pooling.MySQLConnectionPool(
|
||||
@@ -124,26 +141,122 @@ class DatabaseManager:
|
||||
inactivity INT NOT NULL
|
||||
""",
|
||||
)
|
||||
self.create_table_if_not_exists(
|
||||
"economy",
|
||||
"""
|
||||
ID BIGINT PRIMARY KEY,
|
||||
WALLET BIGINT NOT NULL DEFAULT 0,
|
||||
BANK BIGINT NOT NULL DEFAULT 0,
|
||||
DAILY DOUBLE DEFAULT 0
|
||||
""",
|
||||
)
|
||||
self.create_table_if_not_exists(
|
||||
"transactions",
|
||||
"""
|
||||
ID INT AUTO_INCREMENT PRIMARY KEY,
|
||||
USERID BIGINT NOT NULL,
|
||||
TYPE VARCHAR(50),
|
||||
AMOUNT DECIMAL(18,2),
|
||||
TIME DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
""",
|
||||
)
|
||||
self.create_table_if_not_exists(
|
||||
"custom_commands",
|
||||
"""
|
||||
ID INT AUTO_INCREMENT PRIMARY KEY,
|
||||
GUILDID VARCHAR(32) NOT NULL,
|
||||
COMMANDNAME VARCHAR(100) NOT NULL,
|
||||
RESPONSE TEXT NOT NULL,
|
||||
MATCHTYPE VARCHAR(20) NOT NULL DEFAULT 'exact'
|
||||
""",
|
||||
)
|
||||
self.create_table_if_not_exists(
|
||||
"guilds",
|
||||
"""
|
||||
GUILD BIGINT PRIMARY KEY,
|
||||
WELCOME BIGINT DEFAULT NULL,
|
||||
RULES BIGINT DEFAULT NULL,
|
||||
GUIDE BIGINT DEFAULT NULL,
|
||||
INTRODUCTIONS BIGINT DEFAULT NULL,
|
||||
EVENTS BIGINT DEFAULT NULL,
|
||||
MEMBERCOUNT BIGINT DEFAULT NULL,
|
||||
LOGGING BIGINT DEFAULT NULL,
|
||||
TICKETING BIGINT DEFAULT NULL
|
||||
""",
|
||||
)
|
||||
self.create_table_if_not_exists(
|
||||
"rewards",
|
||||
"""
|
||||
ID INT AUTO_INCREMENT PRIMARY KEY,
|
||||
type VARCHAR(50) NOT NULL,
|
||||
amount INT NOT NULL DEFAULT 0,
|
||||
description TEXT DEFAULT NULL
|
||||
""",
|
||||
)
|
||||
self.create_table_if_not_exists(
|
||||
"logs",
|
||||
"""
|
||||
ID INT AUTO_INCREMENT PRIMARY KEY,
|
||||
guild_id BIGINT NOT NULL,
|
||||
user_id BIGINT NOT NULL,
|
||||
type VARCHAR(50) NOT NULL,
|
||||
message TEXT NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
""",
|
||||
)
|
||||
self.create_table_if_not_exists(
|
||||
"gamble_limits",
|
||||
"""
|
||||
USERID BIGINT PRIMARY KEY,
|
||||
DAILY_LIMIT BIGINT DEFAULT NULL,
|
||||
EXCLUDED_UNTIL DATETIME DEFAULT NULL
|
||||
""",
|
||||
)
|
||||
self.create_table_if_not_exists(
|
||||
"users",
|
||||
"""
|
||||
ID BIGINT PRIMARY KEY,
|
||||
XP INT DEFAULT 0,
|
||||
LEVEL INT DEFAULT 0,
|
||||
birthday VARCHAR(10) DEFAULT NULL
|
||||
""",
|
||||
)
|
||||
|
||||
def load_env(self, env):
|
||||
env_file = f".env"
|
||||
def load_env(self, env_file):
|
||||
load_dotenv(env_file)
|
||||
logger.info(f"Loaded environment variables from {env_file}")
|
||||
|
||||
def get_connection(self):
|
||||
return self.pool.get_connection()
|
||||
|
||||
def _sanitize_identifier(self, identifier: str) -> str:
|
||||
if not re.match(r"^[A-Za-z0-9_]+$", identifier):
|
||||
raise ValueError(f"Invalid SQL identifier: {identifier}")
|
||||
return identifier
|
||||
|
||||
def _parse_insert_columns(self, query: str) -> list[str]:
|
||||
match = re.search(
|
||||
r"INSERT\s+INTO\s+\S+\s*\(([^)]+)\)\s*VALUES",
|
||||
query,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
"Insert query must contain a column list for overwrite upsert support."
|
||||
)
|
||||
return [col.strip() for col in match.group(1).split(",") if col.strip()]
|
||||
|
||||
def execute_query(self, query, params=None, retries=3, delay=1):
|
||||
cursor = None
|
||||
connection = None
|
||||
cursor = None
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
connection = self.get_connection()
|
||||
cursor = connection.cursor(dictionary=True, buffered=True)
|
||||
cursor.execute(query, params or ())
|
||||
logger.info(f"Executed query: {query} with params: {params}")
|
||||
connection.commit()
|
||||
return copy(cursor)
|
||||
logger.info(f"Executed query: {query} with params: {params}")
|
||||
return cursor.rowcount
|
||||
except mysql.connector.Error as err:
|
||||
logger.warning(f"Attempt {attempt + 1} failed: {err}")
|
||||
time.sleep(delay * (2**attempt))
|
||||
@@ -163,7 +276,7 @@ class DatabaseManager:
|
||||
Args:
|
||||
query (str): The SQL query to execute.
|
||||
params (tuple): The parameters to pass into the query.
|
||||
overwrite (bool, optional): Whether to perform an upsert operation. Defaults to False.
|
||||
overwrite (bool, optional): Whether to perform an upsert operation. Defaults to True.
|
||||
|
||||
Raises:
|
||||
ValueError: If no parameters are provided.
|
||||
@@ -171,37 +284,29 @@ class DatabaseManager:
|
||||
if not params:
|
||||
raise ValueError("Params must be provided for the insert operation.")
|
||||
|
||||
try:
|
||||
if overwrite:
|
||||
columns = [
|
||||
col.split("=")[0].strip()
|
||||
for col in query.split("VALUES")[0]
|
||||
.split("(")[1]
|
||||
.split(")")[0]
|
||||
.split(",")
|
||||
]
|
||||
update_set = ", ".join(f"{col} = VALUES({col})" for col in columns)
|
||||
query = f"{query} ON DUPLICATE KEY UPDATE {update_set}"
|
||||
if overwrite:
|
||||
columns = self._parse_insert_columns(query)
|
||||
update_set = ", ".join(f"{col} = VALUES({col})" for col in columns)
|
||||
query = f"{query} ON DUPLICATE KEY UPDATE {update_set}"
|
||||
|
||||
cursor = self.execute_query(query, params)
|
||||
if cursor:
|
||||
logger.info(f"Insert completed with query: {query}.")
|
||||
except mysql.connector.Error as err:
|
||||
logger.error(f"Insert failed with query: {query}. Error: {err}")
|
||||
rowcount = self.execute_query(query, params)
|
||||
if rowcount is None:
|
||||
logger.error(f"Insert failed with query: {query}.")
|
||||
else:
|
||||
logger.info(f"Insert completed with query: {query}.")
|
||||
|
||||
def bulk_insert(self, query, params=None):
|
||||
if not params:
|
||||
logger.warning("No data provided for bulk insert.")
|
||||
return
|
||||
|
||||
# Assuming params is a list of dictionaries
|
||||
if not isinstance(params, list) or not all(isinstance(d, dict) for d in params):
|
||||
raise ValueError("Params must be a list of dictionaries for bulk insert.")
|
||||
|
||||
keys = params[0].keys()
|
||||
keys = list(params[0].keys())
|
||||
placeholders = ", ".join(["%s"] * len(keys))
|
||||
query = f"{query} ({', '.join(keys)}) VALUES ({placeholders})"
|
||||
values = [tuple(data.values()) for data in params]
|
||||
values = [tuple(data[key] for key in keys) for data in params]
|
||||
|
||||
connection = None
|
||||
cursor = None
|
||||
@@ -216,7 +321,7 @@ class DatabaseManager:
|
||||
except mysql.connector.Error as err:
|
||||
logger.error(f"Bulk insert failed: {err}")
|
||||
if connection:
|
||||
connection.rollback() # Roll back on error
|
||||
connection.rollback()
|
||||
finally:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
@@ -225,34 +330,60 @@ class DatabaseManager:
|
||||
|
||||
def delete(self, table_name: str, condition: dict) -> None:
|
||||
"""Deletes a record from the specified table based on the condition provided."""
|
||||
table_name = self._sanitize_identifier(table_name)
|
||||
condition_column, condition_value = next(iter(condition.items()))
|
||||
condition_column = self._sanitize_identifier(condition_column)
|
||||
query = f"DELETE FROM {table_name} WHERE {condition_column} = %s"
|
||||
self.execute_query(query, (condition_value,))
|
||||
|
||||
def fetch_one(self, query, params=None):
|
||||
cursor = self.execute_query(query, params)
|
||||
return cursor.fetchone() if cursor else {}
|
||||
connection = None
|
||||
cursor = None
|
||||
try:
|
||||
connection = self.get_connection()
|
||||
cursor = connection.cursor(dictionary=True, buffered=True)
|
||||
cursor.execute(query, params or ())
|
||||
return cursor.fetchone()
|
||||
finally:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
if connection:
|
||||
connection.close()
|
||||
|
||||
def fetch_all(self, query, params=None):
|
||||
cursor = self.execute_query(query, params)
|
||||
return cursor.fetchall() if cursor else []
|
||||
connection = None
|
||||
cursor = None
|
||||
try:
|
||||
connection = self.get_connection()
|
||||
cursor = connection.cursor(dictionary=True, buffered=True)
|
||||
cursor.execute(query, params or ())
|
||||
return cursor.fetchall()
|
||||
finally:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
if connection:
|
||||
connection.close()
|
||||
|
||||
def fetch_as_dataframe(self, query, params=None):
|
||||
cursor = self.execute_query(query, params)
|
||||
if cursor:
|
||||
try:
|
||||
# Ensure cursor has a result to fetch
|
||||
if cursor.with_rows:
|
||||
results = cursor.fetchall()
|
||||
return pd.DataFrame(results) if results else pd.DataFrame()
|
||||
else:
|
||||
logger.warning("No result set to fetch from.")
|
||||
return pd.DataFrame()
|
||||
finally:
|
||||
connection = None
|
||||
cursor = None
|
||||
try:
|
||||
connection = self.get_connection()
|
||||
cursor = connection.cursor(dictionary=True, buffered=True)
|
||||
cursor.execute(query, params or ())
|
||||
if cursor.with_rows:
|
||||
results = cursor.fetchall()
|
||||
return pd.DataFrame(results) if results else pd.DataFrame()
|
||||
logger.warning("No result set to fetch from.")
|
||||
return pd.DataFrame()
|
||||
finally:
|
||||
if cursor:
|
||||
cursor.close()
|
||||
return pd.DataFrame()
|
||||
if connection:
|
||||
connection.close()
|
||||
|
||||
def create_table_if_not_exists(self, table_name, schema):
|
||||
table_name = self._sanitize_identifier(table_name)
|
||||
query = f"CREATE TABLE IF NOT EXISTS {table_name} ({schema})"
|
||||
self.execute_query(query)
|
||||
logger.info(f"Ensured table {table_name} exists with schema: {schema}")
|
||||
@@ -275,6 +406,11 @@ class DatabaseManager:
|
||||
return f"{base_query} WHERE {conditions}", list(filters.values())
|
||||
|
||||
|
||||
def initialize_database(env="development"):
|
||||
"""Initialize the database schema and return a shared DatabaseManager."""
|
||||
return DatabaseManager(env)
|
||||
|
||||
|
||||
# SQL scripts to create tables
|
||||
create_feedback_table = """
|
||||
CREATE TABLE IF NOT EXISTS feedback (
|
||||
|
||||
Reference in New Issue
Block a user