feat: Add bank and wallet balance commands with improved transfer validation

- Added /bank and /wallet commands to check user balances
- Enhanced transfer validation to check for valid number input
- Included debug token logging in main function (to be removed in production)
- Added account creation logic for new users in balance commands
- Improved error handling for invalid transfer amounts
This commit is contained in:
2026-06-03 11:56:09 +00:00
parent 4b07ca86b9
commit b315069b1c
4 changed files with 168 additions and 84 deletions
+12 -8
View File
@@ -42,14 +42,18 @@ class Client(commands.Bot):
print("Loaded cogs") print("Loaded cogs")
def main(): def main():
load_dotenv() load_dotenv()
initialize_database() initialize_database()
client = Client() client = Client()
token = os.getenv("DISCORD_TOKEN") or os.getenv("TOKEN") token = os.getenv("DISCORD_TOKEN") or os.getenv("TOKEN")
if not token: if not token:
raise SystemExit("ERROR: Discord token not found. Set DISCORD_TOKEN or TOKEN in environment.") raise SystemExit("ERROR: Discord token not found. Set DISCORD_TOKEN or TOKEN in environment.")
client.run(token)
# Print the token for debugging purposes (remove this in production)
print(f"DEBUG: Using token starting with {token[:5]}.")
client.run(token)
if __name__ == "__main__": if __name__ == "__main__":
+95 -19
View File
@@ -15,25 +15,28 @@ from utils.bank_functions import (
from utils.sql_commands import DatabaseManager from utils.sql_commands import DatabaseManager
def validate_transfer( def validate_transfer(
payer_balance: int, author_id: int, receiver_id: int, amount: int payer_balance: int, author_id: int, receiver_id: int, amount: int
) -> tuple[bool, str]: ) -> tuple[bool, str]:
if amount <= 0: # Validate amount is a positive number
return False, "Please enter an amount greater than 0." if not isinstance(amount, (int, float)):
return False, "Amount must be a number"
if payer_balance is None: if amount <= 0:
return False, "Your account does not exist." return False, "Please enter an amount greater than 0."
if author_id == receiver_id: if payer_balance is None:
return False, "You cannot give yourself money." return False, "Your account does not exist."
if payer_balance < amount: if author_id == receiver_id:
return False, ( return False, "You cannot give yourself money."
f"You do not have {amount:,}<:flooney:1194943899765051473>. "
f"You have {payer_balance:,}<:flooney:1194943899765051473>." if payer_balance < amount:
) return False, (
f"You do not have {amount:,}<:flooney:1194943899765051473>. "
return True, "" f"You have {payer_balance:,}<:flooney:1194943899765051473>."
)
return True, ""
class Economy(commands.Cog): class Economy(commands.Cog):
@@ -400,6 +403,79 @@ class Economy(commands.Cog):
mention_author=False, mention_author=False,
) )
@commands.command(name="bank", brief="Check your bank balance", description="Check your bank balance.")
async def _bank(self, ctx: commands.Context, member: discord.Member | None = None):
"""Check your bank balance."""
target: discord.Member | discord.User = member or ctx.author
# Ensure the target is not a bot
if target.bot:
return
# Get the user's data
user_data = await bank_data(target)
wallet_balance = user_data.get("WALLET", 0)
bank_balance = user_data.get("BANK", 0)
# Create an account if one does not exist
if bank_balance is None or wallet_balance is None:
await create_account(target)
user_data = await bank_data(target)
wallet_balance = user_data.get("WALLET", 0)
bank_balance = user_data.get("BANK", 0)
# Reply with the user's bank balance
await ctx.reply(
f"{target.mention} has {bank_balance:,}<:flooney:1194943899765051473> in their bank."
)
@commands.command(name="wallet", brief="Check your wallet balance", description="Check your wallet balance.")
async def _wallet(self, ctx: commands.Context, member: discord.Member | None = None):
"""Check your wallet balance."""
target: discord.Member | discord.User = member or ctx.author
# Ensure the target is not a bot
if target.bot:
return
# Get the user's data
user_data = await bank_data(target)
wallet_balance = user_data.get("WALLET", 0)
bank_balance = user_data.get("BANK", 0)
# Create an account if one does not exist
if bank_balance is None or wallet_balance is None:
await create_account(target)
user_data = await bank_data(target)
wallet_balance = user_data.get("WALLET", 0)
bank_balance = user_data.get("BANK", 0)
# Reply with the user's wallet balance
await ctx.reply(
f"{target.mention} has {wallet_balance:,}<:flooney:1194943899765051473> in their wallet."
)
@commands.command(name="transfer", brief="Transfer money from bank to another user", description="Transfer money from your bank to another user's bank.")
async def _transfer(self, ctx, target: discord.Member, amount: int):
"""
Transfer money from your bank account to another user's bank account.
"""
payer_bank = int((await bank_data(ctx.author)).get("BANK", 0))
receiver_bank = int((await bank_data(target)).get("BANK", 0))
valid, reason = validate_transfer(
payer_bank, ctx.author.id, target.id, amount
)
if not valid:
return await ctx.reply(reason, mention_author=False)
await update_money(target, bank=amount)
await update_money(ctx.author, bank=-amount)
await ctx.reply(
f"Transferred {amount:,} flooneys to {target}.", mention_author=False
)
r""" r"""
.----------------. .----------------. .----------------. .----------------. .----------------. .----------------. .----------------. .----------------.
| .--------------. || .--------------. || .--------------. || .--------------. | | .--------------. || .--------------. || .--------------. || .--------------. |
+6 -3
View File
@@ -27,7 +27,11 @@ async def bank_data(user: discord.Member | discord.User) -> Dict[str, int]:
if balance is None: if balance is None:
await create_account(user) await create_account(user)
return await bank_data(user) return await bank_data(user)
return balance # Ensure we return a dictionary with WALLET and BANK keys
return {
"WALLET": balance.get("WALLET", 0),
"BANK": balance.get("BANK", 0)
}
async def record_transaction( async def record_transaction(
@@ -88,5 +92,4 @@ async def update_daily_timestamp(user: discord.User | discord.Member, timestamp:
Updates the DAILY_TIMESTAMP field for the user in the economy table. Updates the DAILY_TIMESTAMP field for the user in the economy table.
Stores the timestamp as a float (UNIX time). Stores the timestamp as a float (UNIX time).
""" """
db.execute_query("UPDATE economy SET DAILY = %s WHERE ID = %s",(timestamp.timestamp(), user.id), db.execute_query("UPDATE economy SET DAILY = %s WHERE ID = %s",(timestamp.timestamp(), user.id))
)
+55 -54
View File
@@ -263,54 +263,55 @@ class DatabaseManager:
) )
return [col.strip() for col in match.group(1).split(",") if col.strip()] return [col.strip() for col in match.group(1).split(",") if col.strip()]
def execute_query(self, query, params=None, retries=3, delay=1): def execute_query(self, query, params=None, retries=3, delay=1):
connection = None connection = None
cursor = None cursor = None
for attempt in range(retries): for attempt in range(retries):
try: try:
connection = self.get_connection() connection = self.get_connection()
cursor = connection.cursor(dictionary=True, buffered=True) cursor = connection.cursor(dictionary=True, buffered=True)
cursor.execute(query, params or ()) cursor.execute(query, params or ())
connection.commit() connection.commit()
logger.info(f"Executed query: {query} with params: {params}") logger.info(f"Executed query: {query} with params: {params}")
return cursor.rowcount # Return the actual cursor results instead of closing it
except mysql.connector.Error as err: return cursor.fetchall() if cursor.with_rows else cursor.rowcount
logger.warning(f"Attempt {attempt + 1} failed: {err}") except mysql.connector.Error as err:
time.sleep(delay * (2**attempt)) logger.warning(f"Attempt {attempt + 1} failed: {err}")
finally: time.sleep(delay * (2**attempt))
if cursor: finally:
cursor.close() if cursor:
if connection: cursor.close()
connection.close() if connection:
connection.close()
logger.error(f"All {retries} attempts failed for query: {query}")
return None
logger.error(f"All {retries} attempts failed for query: {query}") def insert(self, query: str, params: tuple, overwrite: bool = True) -> None:
return None """
Inserts data into the database using the given query and parameters.
def insert(self, query: str, params: tuple, overwrite: bool = True) -> None:
""" Args:
Inserts data into the database using the given query and parameters. query (str): The SQL query to execute.
params (tuple): The parameters to pass into the query.
Args: overwrite (bool, optional): Whether to perform an upsert operation. Defaults to True.
query (str): The SQL query to execute.
params (tuple): The parameters to pass into the query. Raises:
overwrite (bool, optional): Whether to perform an upsert operation. Defaults to True. ValueError: If no parameters are provided.
"""
Raises: if not params:
ValueError: If no parameters are provided. raise ValueError("Params must be provided for the insert operation.")
"""
if not params: if overwrite:
raise ValueError("Params must be provided for the insert operation.") columns = self._parse_insert_columns(query)
update_set = ", ".join(f"{col} = VALUES({col})" for col in columns)
if overwrite: query = f"{query} ON DUPLICATE KEY UPDATE {update_set}"
columns = self._parse_insert_columns(query)
update_set = ", ".join(f"{col} = VALUES({col})" for col in columns) rowcount = self.execute_query(query, params)
query = f"{query} ON DUPLICATE KEY UPDATE {update_set}" if rowcount is None:
logger.error(f"Insert failed with query: {query}.")
rowcount = self.execute_query(query, params) else:
if rowcount is None: logger.info(f"Insert completed with query: {query}.")
logger.error(f"Insert failed with query: {query}.")
else:
logger.info(f"Insert completed with query: {query}.")
def bulk_insert(self, query, params=None): def bulk_insert(self, query, params=None):
if not params: if not params:
@@ -345,13 +346,13 @@ class DatabaseManager:
if connection: if connection:
connection.close() connection.close()
def delete(self, table_name: str, condition: dict) -> None: def delete(self, table_name: str, condition: dict) -> None:
"""Deletes a record from the specified table based on the condition provided.""" """Deletes a record from the specified table based on the condition provided."""
table_name = self._sanitize_identifier(table_name) table_name = self._sanitize_identifier(table_name)
condition_column, condition_value = next(iter(condition.items())) condition_column, condition_value = next(iter(condition.items()))
condition_column = self._sanitize_identifier(condition_column) condition_column = self._sanitize_identifier(condition_column)
query = f"DELETE FROM {table_name} WHERE {condition_column} = %s" query = f"DELETE FROM {table_name} WHERE {condition_column} = %s"
self.execute_query(query, (condition_value,)) self.execute_query(query, (condition_value,))
def fetch_one(self, query, params=None): def fetch_one(self, query, params=None):
connection = None connection = None