Refactor database management and schema initialization
- Removed the old npc_memory.db file. - Updated time.txt with a new timestamp. - Refactored transaction recording in bank_functions.py to use parameterized queries. - Enhanced DatabaseManager in sql_commands.py to support singleton pattern and improved table creation logic. - Added methods for sanitizing SQL identifiers and parsing insert columns for upsert operations. - Improved error handling and connection management in execute_query, fetch_one, fetch_all, and fetch_as_dataframe methods. - Introduced a new bootstrap_database.py script for initializing the database schema. - Updated app.py to use the new initialize_database function for database management.
This commit is contained in:
@@ -0,0 +1,6 @@
|
|||||||
|
from utils.sql_commands import initialize_database
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
initialize_database()
|
||||||
|
print("Database initialization complete.")
|
||||||
@@ -4,6 +4,7 @@ from discord.ext import commands
|
|||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from web.app import app
|
from web.app import app
|
||||||
|
from utils.sql_commands import initialize_database
|
||||||
import threading
|
import threading
|
||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
@@ -50,9 +51,9 @@ class Client(commands.Bot):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
initialize_database()
|
||||||
client = Client()
|
client = Client()
|
||||||
token = os.getenv("TOKEN")
|
token = os.getenv("TOKEN")
|
||||||
print(token)
|
|
||||||
if token is not None:
|
if token is not None:
|
||||||
threading.Thread(target=run_web, daemon=True).start()
|
threading.Thread(target=run_web, daemon=True).start()
|
||||||
client.run(token)
|
client.run(token)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from discord.ext import commands
|
|||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from web.app import app
|
from web.app import app
|
||||||
|
from utils.sql_commands import initialize_database
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
|
||||||
@@ -58,6 +59,7 @@ class Client(commands.Bot):
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
initialize_database()
|
||||||
client = Client()
|
client = Client()
|
||||||
token = os.getenv("TOKEN")
|
token = os.getenv("TOKEN")
|
||||||
if token is not None:
|
if token is not None:
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ class CustomCommandsCog(commands.Cog):
|
|||||||
)
|
)
|
||||||
self.invalidate_cache(guild_id) # Invalidate cache after change
|
self.invalidate_cache(guild_id) # Invalidate cache after change
|
||||||
|
|
||||||
if deleted_rows is not None and len(deleted_rows) > 0:
|
if deleted_rows and deleted_rows > 0:
|
||||||
await ctx.send(f"Custom command `{command_name}` has been deleted!")
|
await ctx.send(f"Custom command `{command_name}` has been deleted!")
|
||||||
else:
|
else:
|
||||||
await ctx.send(f"Custom command `{command_name}` not found.")
|
await ctx.send(f"Custom command `{command_name}` not found.")
|
||||||
|
|||||||
+1974
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@@ -40,7 +40,10 @@ async def record_transaction(
|
|||||||
transaction_type: The type of transaction performed.
|
transaction_type: The type of transaction performed.
|
||||||
amount: The amount of the transaction.
|
amount: The amount of the transaction.
|
||||||
"""
|
"""
|
||||||
db.insert("transactions", (user_id, transaction_type, amount))
|
db.execute_query(
|
||||||
|
"INSERT INTO transactions (USERID, TYPE, AMOUNT) VALUES (%s, %s, %s)",
|
||||||
|
(user_id, transaction_type, amount),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def update_money(
|
async def update_money(
|
||||||
|
|||||||
+171
-35
@@ -3,11 +3,11 @@ from mysql.connector import pooling
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import logging
|
import logging
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from copy import copy
|
|
||||||
|
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
@@ -20,16 +20,33 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class DatabaseManager:
|
class DatabaseManager:
|
||||||
|
_instances = {}
|
||||||
|
|
||||||
|
def __new__(cls, env="development"):
|
||||||
|
instance_key = env or "default"
|
||||||
|
if instance_key in cls._instances:
|
||||||
|
return cls._instances[instance_key]
|
||||||
|
instance = super().__new__(cls)
|
||||||
|
cls._instances[instance_key] = instance
|
||||||
|
return instance
|
||||||
|
|
||||||
def __init__(self, env="development"):
|
def __init__(self, env="development"):
|
||||||
# Load environment variables based on environment
|
if getattr(self, "_initialized", False):
|
||||||
self.load_env(env)
|
return
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
env_file = f".env.{env}" if env else ".env"
|
||||||
|
if not os.path.exists(env_file):
|
||||||
|
env_file = ".env"
|
||||||
|
self.load_env(env_file)
|
||||||
|
|
||||||
self.config = {
|
self.config = {
|
||||||
"host": os.getenv("SQLHOST", "localhost"),
|
"host": os.getenv("SQLHOST", "localhost"),
|
||||||
"user": os.getenv("SQLUSER", "root"),
|
"user": os.getenv("SQLUSER", "root"),
|
||||||
"password": os.getenv("SQLPASS", ""),
|
"password": os.getenv("SQLPASS", ""),
|
||||||
"database": os.getenv("SQLDB", "testdb"),
|
"database": os.getenv("SQLDB", "testdb"),
|
||||||
"pool_reset_session": bool(os.getenv("POOL_RESET_SESSION", False)),
|
"pool_reset_session": os.getenv("POOL_RESET_SESSION", "false").lower()
|
||||||
|
in ("true", "1", "yes"),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.pool = pooling.MySQLConnectionPool(
|
self.pool = pooling.MySQLConnectionPool(
|
||||||
@@ -124,26 +141,122 @@ class DatabaseManager:
|
|||||||
inactivity INT NOT NULL
|
inactivity INT NOT NULL
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
self.create_table_if_not_exists(
|
||||||
|
"economy",
|
||||||
|
"""
|
||||||
|
ID BIGINT PRIMARY KEY,
|
||||||
|
WALLET BIGINT NOT NULL DEFAULT 0,
|
||||||
|
BANK BIGINT NOT NULL DEFAULT 0,
|
||||||
|
DAILY DOUBLE DEFAULT 0
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
self.create_table_if_not_exists(
|
||||||
|
"transactions",
|
||||||
|
"""
|
||||||
|
ID INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
USERID BIGINT NOT NULL,
|
||||||
|
TYPE VARCHAR(50),
|
||||||
|
AMOUNT DECIMAL(18,2),
|
||||||
|
TIME DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
self.create_table_if_not_exists(
|
||||||
|
"custom_commands",
|
||||||
|
"""
|
||||||
|
ID INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
GUILDID VARCHAR(32) NOT NULL,
|
||||||
|
COMMANDNAME VARCHAR(100) NOT NULL,
|
||||||
|
RESPONSE TEXT NOT NULL,
|
||||||
|
MATCHTYPE VARCHAR(20) NOT NULL DEFAULT 'exact'
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
self.create_table_if_not_exists(
|
||||||
|
"guilds",
|
||||||
|
"""
|
||||||
|
GUILD BIGINT PRIMARY KEY,
|
||||||
|
WELCOME BIGINT DEFAULT NULL,
|
||||||
|
RULES BIGINT DEFAULT NULL,
|
||||||
|
GUIDE BIGINT DEFAULT NULL,
|
||||||
|
INTRODUCTIONS BIGINT DEFAULT NULL,
|
||||||
|
EVENTS BIGINT DEFAULT NULL,
|
||||||
|
MEMBERCOUNT BIGINT DEFAULT NULL,
|
||||||
|
LOGGING BIGINT DEFAULT NULL,
|
||||||
|
TICKETING BIGINT DEFAULT NULL
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
self.create_table_if_not_exists(
|
||||||
|
"rewards",
|
||||||
|
"""
|
||||||
|
ID INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
type VARCHAR(50) NOT NULL,
|
||||||
|
amount INT NOT NULL DEFAULT 0,
|
||||||
|
description TEXT DEFAULT NULL
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
self.create_table_if_not_exists(
|
||||||
|
"logs",
|
||||||
|
"""
|
||||||
|
ID INT AUTO_INCREMENT PRIMARY KEY,
|
||||||
|
guild_id BIGINT NOT NULL,
|
||||||
|
user_id BIGINT NOT NULL,
|
||||||
|
type VARCHAR(50) NOT NULL,
|
||||||
|
message TEXT NOT NULL,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
self.create_table_if_not_exists(
|
||||||
|
"gamble_limits",
|
||||||
|
"""
|
||||||
|
USERID BIGINT PRIMARY KEY,
|
||||||
|
DAILY_LIMIT BIGINT DEFAULT NULL,
|
||||||
|
EXCLUDED_UNTIL DATETIME DEFAULT NULL
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
self.create_table_if_not_exists(
|
||||||
|
"users",
|
||||||
|
"""
|
||||||
|
ID BIGINT PRIMARY KEY,
|
||||||
|
XP INT DEFAULT 0,
|
||||||
|
LEVEL INT DEFAULT 0,
|
||||||
|
birthday VARCHAR(10) DEFAULT NULL
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
def load_env(self, env):
|
def load_env(self, env_file):
|
||||||
env_file = f".env"
|
|
||||||
load_dotenv(env_file)
|
load_dotenv(env_file)
|
||||||
logger.info(f"Loaded environment variables from {env_file}")
|
logger.info(f"Loaded environment variables from {env_file}")
|
||||||
|
|
||||||
def get_connection(self):
|
def get_connection(self):
|
||||||
return self.pool.get_connection()
|
return self.pool.get_connection()
|
||||||
|
|
||||||
|
def _sanitize_identifier(self, identifier: str) -> str:
|
||||||
|
if not re.match(r"^[A-Za-z0-9_]+$", identifier):
|
||||||
|
raise ValueError(f"Invalid SQL identifier: {identifier}")
|
||||||
|
return identifier
|
||||||
|
|
||||||
|
def _parse_insert_columns(self, query: str) -> list[str]:
|
||||||
|
match = re.search(
|
||||||
|
r"INSERT\s+INTO\s+\S+\s*\(([^)]+)\)\s*VALUES",
|
||||||
|
query,
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
if not match:
|
||||||
|
raise ValueError(
|
||||||
|
"Insert query must contain a column list for overwrite upsert support."
|
||||||
|
)
|
||||||
|
return [col.strip() for col in match.group(1).split(",") if col.strip()]
|
||||||
|
|
||||||
def execute_query(self, query, params=None, retries=3, delay=1):
|
def execute_query(self, query, params=None, retries=3, delay=1):
|
||||||
cursor = None
|
|
||||||
connection = None
|
connection = None
|
||||||
|
cursor = None
|
||||||
for attempt in range(retries):
|
for attempt in range(retries):
|
||||||
try:
|
try:
|
||||||
connection = self.get_connection()
|
connection = self.get_connection()
|
||||||
cursor = connection.cursor(dictionary=True, buffered=True)
|
cursor = connection.cursor(dictionary=True, buffered=True)
|
||||||
cursor.execute(query, params or ())
|
cursor.execute(query, params or ())
|
||||||
logger.info(f"Executed query: {query} with params: {params}")
|
|
||||||
connection.commit()
|
connection.commit()
|
||||||
return copy(cursor)
|
logger.info(f"Executed query: {query} with params: {params}")
|
||||||
|
return cursor.rowcount
|
||||||
except mysql.connector.Error as err:
|
except mysql.connector.Error as err:
|
||||||
logger.warning(f"Attempt {attempt + 1} failed: {err}")
|
logger.warning(f"Attempt {attempt + 1} failed: {err}")
|
||||||
time.sleep(delay * (2**attempt))
|
time.sleep(delay * (2**attempt))
|
||||||
@@ -163,7 +276,7 @@ class DatabaseManager:
|
|||||||
Args:
|
Args:
|
||||||
query (str): The SQL query to execute.
|
query (str): The SQL query to execute.
|
||||||
params (tuple): The parameters to pass into the query.
|
params (tuple): The parameters to pass into the query.
|
||||||
overwrite (bool, optional): Whether to perform an upsert operation. Defaults to False.
|
overwrite (bool, optional): Whether to perform an upsert operation. Defaults to True.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If no parameters are provided.
|
ValueError: If no parameters are provided.
|
||||||
@@ -171,37 +284,29 @@ class DatabaseManager:
|
|||||||
if not params:
|
if not params:
|
||||||
raise ValueError("Params must be provided for the insert operation.")
|
raise ValueError("Params must be provided for the insert operation.")
|
||||||
|
|
||||||
try:
|
|
||||||
if overwrite:
|
if overwrite:
|
||||||
columns = [
|
columns = self._parse_insert_columns(query)
|
||||||
col.split("=")[0].strip()
|
|
||||||
for col in query.split("VALUES")[0]
|
|
||||||
.split("(")[1]
|
|
||||||
.split(")")[0]
|
|
||||||
.split(",")
|
|
||||||
]
|
|
||||||
update_set = ", ".join(f"{col} = VALUES({col})" for col in columns)
|
update_set = ", ".join(f"{col} = VALUES({col})" for col in columns)
|
||||||
query = f"{query} ON DUPLICATE KEY UPDATE {update_set}"
|
query = f"{query} ON DUPLICATE KEY UPDATE {update_set}"
|
||||||
|
|
||||||
cursor = self.execute_query(query, params)
|
rowcount = self.execute_query(query, params)
|
||||||
if cursor:
|
if rowcount is None:
|
||||||
|
logger.error(f"Insert failed with query: {query}.")
|
||||||
|
else:
|
||||||
logger.info(f"Insert completed with query: {query}.")
|
logger.info(f"Insert completed with query: {query}.")
|
||||||
except mysql.connector.Error as err:
|
|
||||||
logger.error(f"Insert failed with query: {query}. Error: {err}")
|
|
||||||
|
|
||||||
def bulk_insert(self, query, params=None):
|
def bulk_insert(self, query, params=None):
|
||||||
if not params:
|
if not params:
|
||||||
logger.warning("No data provided for bulk insert.")
|
logger.warning("No data provided for bulk insert.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Assuming params is a list of dictionaries
|
|
||||||
if not isinstance(params, list) or not all(isinstance(d, dict) for d in params):
|
if not isinstance(params, list) or not all(isinstance(d, dict) for d in params):
|
||||||
raise ValueError("Params must be a list of dictionaries for bulk insert.")
|
raise ValueError("Params must be a list of dictionaries for bulk insert.")
|
||||||
|
|
||||||
keys = params[0].keys()
|
keys = list(params[0].keys())
|
||||||
placeholders = ", ".join(["%s"] * len(keys))
|
placeholders = ", ".join(["%s"] * len(keys))
|
||||||
query = f"{query} ({', '.join(keys)}) VALUES ({placeholders})"
|
query = f"{query} ({', '.join(keys)}) VALUES ({placeholders})"
|
||||||
values = [tuple(data.values()) for data in params]
|
values = [tuple(data[key] for key in keys) for data in params]
|
||||||
|
|
||||||
connection = None
|
connection = None
|
||||||
cursor = None
|
cursor = None
|
||||||
@@ -216,7 +321,7 @@ class DatabaseManager:
|
|||||||
except mysql.connector.Error as err:
|
except mysql.connector.Error as err:
|
||||||
logger.error(f"Bulk insert failed: {err}")
|
logger.error(f"Bulk insert failed: {err}")
|
||||||
if connection:
|
if connection:
|
||||||
connection.rollback() # Roll back on error
|
connection.rollback()
|
||||||
finally:
|
finally:
|
||||||
if cursor:
|
if cursor:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
@@ -225,34 +330,60 @@ class DatabaseManager:
|
|||||||
|
|
||||||
def delete(self, table_name: str, condition: dict) -> None:
|
def delete(self, table_name: str, condition: dict) -> None:
|
||||||
"""Deletes a record from the specified table based on the condition provided."""
|
"""Deletes a record from the specified table based on the condition provided."""
|
||||||
|
table_name = self._sanitize_identifier(table_name)
|
||||||
condition_column, condition_value = next(iter(condition.items()))
|
condition_column, condition_value = next(iter(condition.items()))
|
||||||
|
condition_column = self._sanitize_identifier(condition_column)
|
||||||
query = f"DELETE FROM {table_name} WHERE {condition_column} = %s"
|
query = f"DELETE FROM {table_name} WHERE {condition_column} = %s"
|
||||||
self.execute_query(query, (condition_value,))
|
self.execute_query(query, (condition_value,))
|
||||||
|
|
||||||
def fetch_one(self, query, params=None):
|
def fetch_one(self, query, params=None):
|
||||||
cursor = self.execute_query(query, params)
|
connection = None
|
||||||
return cursor.fetchone() if cursor else {}
|
cursor = None
|
||||||
|
try:
|
||||||
|
connection = self.get_connection()
|
||||||
|
cursor = connection.cursor(dictionary=True, buffered=True)
|
||||||
|
cursor.execute(query, params or ())
|
||||||
|
return cursor.fetchone()
|
||||||
|
finally:
|
||||||
|
if cursor:
|
||||||
|
cursor.close()
|
||||||
|
if connection:
|
||||||
|
connection.close()
|
||||||
|
|
||||||
def fetch_all(self, query, params=None):
|
def fetch_all(self, query, params=None):
|
||||||
cursor = self.execute_query(query, params)
|
connection = None
|
||||||
return cursor.fetchall() if cursor else []
|
cursor = None
|
||||||
|
try:
|
||||||
|
connection = self.get_connection()
|
||||||
|
cursor = connection.cursor(dictionary=True, buffered=True)
|
||||||
|
cursor.execute(query, params or ())
|
||||||
|
return cursor.fetchall()
|
||||||
|
finally:
|
||||||
|
if cursor:
|
||||||
|
cursor.close()
|
||||||
|
if connection:
|
||||||
|
connection.close()
|
||||||
|
|
||||||
def fetch_as_dataframe(self, query, params=None):
|
def fetch_as_dataframe(self, query, params=None):
|
||||||
cursor = self.execute_query(query, params)
|
connection = None
|
||||||
if cursor:
|
cursor = None
|
||||||
try:
|
try:
|
||||||
# Ensure cursor has a result to fetch
|
connection = self.get_connection()
|
||||||
|
cursor = connection.cursor(dictionary=True, buffered=True)
|
||||||
|
cursor.execute(query, params or ())
|
||||||
if cursor.with_rows:
|
if cursor.with_rows:
|
||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
return pd.DataFrame(results) if results else pd.DataFrame()
|
return pd.DataFrame(results) if results else pd.DataFrame()
|
||||||
else:
|
|
||||||
logger.warning("No result set to fetch from.")
|
logger.warning("No result set to fetch from.")
|
||||||
return pd.DataFrame()
|
return pd.DataFrame()
|
||||||
finally:
|
finally:
|
||||||
|
if cursor:
|
||||||
cursor.close()
|
cursor.close()
|
||||||
return pd.DataFrame()
|
if connection:
|
||||||
|
connection.close()
|
||||||
|
|
||||||
def create_table_if_not_exists(self, table_name, schema):
|
def create_table_if_not_exists(self, table_name, schema):
|
||||||
|
table_name = self._sanitize_identifier(table_name)
|
||||||
query = f"CREATE TABLE IF NOT EXISTS {table_name} ({schema})"
|
query = f"CREATE TABLE IF NOT EXISTS {table_name} ({schema})"
|
||||||
self.execute_query(query)
|
self.execute_query(query)
|
||||||
logger.info(f"Ensured table {table_name} exists with schema: {schema}")
|
logger.info(f"Ensured table {table_name} exists with schema: {schema}")
|
||||||
@@ -275,6 +406,11 @@ class DatabaseManager:
|
|||||||
return f"{base_query} WHERE {conditions}", list(filters.values())
|
return f"{base_query} WHERE {conditions}", list(filters.values())
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_database(env="development"):
|
||||||
|
"""Initialize the database schema and return a shared DatabaseManager."""
|
||||||
|
return DatabaseManager(env)
|
||||||
|
|
||||||
|
|
||||||
# SQL scripts to create tables
|
# SQL scripts to create tables
|
||||||
create_feedback_table = """
|
create_feedback_table = """
|
||||||
CREATE TABLE IF NOT EXISTS feedback (
|
CREATE TABLE IF NOT EXISTS feedback (
|
||||||
|
|||||||
Binary file not shown.
+3
-3
@@ -4,10 +4,10 @@ import os
|
|||||||
if __package__ is None or __package__ == "":
|
if __package__ is None or __package__ == "":
|
||||||
# Running as __main__ (e.g. python web/app.py)
|
# Running as __main__ (e.g. python web/app.py)
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
from utils.sql_commands import DatabaseManager
|
from utils.sql_commands import DatabaseManager, initialize_database
|
||||||
else:
|
else:
|
||||||
# Imported as a module (e.g. from bot.py)
|
# Imported as a module (e.g. from bot.py)
|
||||||
from utils.sql_commands import DatabaseManager
|
from utils.sql_commands import DatabaseManager, initialize_database
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import requests
|
import requests
|
||||||
@@ -21,7 +21,7 @@ load_dotenv()
|
|||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.secret_key = os.getenv("SECRET_KEY")
|
app.secret_key = os.getenv("SECRET_KEY")
|
||||||
db = DatabaseManager()
|
db = initialize_database()
|
||||||
|
|
||||||
# Ensure required environment variables are loaded
|
# Ensure required environment variables are loaded
|
||||||
DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID")
|
DISCORD_CLIENT_ID = os.getenv("DISCORD_CLIENT_ID")
|
||||||
|
|||||||
Reference in New Issue
Block a user