Skip to content
Snippets Groups Projects
Commit ba5e3b48 authored by samuel.m's avatar samuel.m
Browse files

First backup commit

parent a07c33dd
No related branches found
No related tags found
No related merge requests found
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
/.idea
/dataVis
/data_storage
__pycache__
\ No newline at end of file
import sqlite3
import pandas as pd
import os
from rich.console import Console
from rich.table import Table
class SQL_tool():
    """Mirror every CSV found under a directory tree as a SQLite database.

    Each ``<name>.csv`` becomes ``<name>.db`` in the same folder, holding a
    single table called ``data``. Paths of all known databases are collected
    in ``self.db_locations``.
    """

    def __init__(self, data_storage_location: str):
        """Scan ``data_storage_location`` and build the databases.

        Args:
            data_storage_location: Root folder to search for CSV files.
        """
        # Filesystem paths of every .db this instance created or found.
        self.db_locations = []
        self.init_db(data_storage_location)

    def init_db(self, data_storage_location: str, verbose: bool = None):
        """
        Recursive function that walks a folder and makes a db of any csv
        found, even if nested. Successfully handled databases are appended
        to ``self.db_locations``.

        Args:
            data_storage_location: Directory to scan.
            verbose: Print progress messages when truthy.
        """
        # Only descend into directories that actually have contents.
        if not (os.path.isdir(data_storage_location) and os.listdir(data_storage_location)):
            return
        for file_name in os.listdir(data_storage_location):
            file_path = os.path.join(data_storage_location, file_name)
            if os.path.isfile(file_path) and file_name.endswith('.csv'):
                db_name = os.path.splitext(file_name)[0] + ".db"
                db_path = os.path.join(data_storage_location, db_name)
                try:
                    if verbose: print(f"Processing {file_name}")
                    # header=1: these CSVs carry their column names on the
                    # second row -- adjust if the data layout changes.
                    df = pd.read_csv(file_path, header=1)
                    conn = sqlite3.connect(db_path)
                    try:
                        df.to_sql("data", conn, if_exists="replace", index=False)
                        if verbose: print(f"{db_name} Created at {db_path}")
                    except (sqlite3.Error, ValueError):
                        # Keep the db listed even when the write fails (e.g.
                        # locked file) so callers can still see it.
                        if verbose: print(f"Data base {db_name} at location {db_path} exists... Doing nothing....")
                    finally:
                        # Bug fix: the old `with` block committed but never
                        # closed the connection.
                        conn.close()
                    self.db_locations.append(db_path)
                except (OSError, ValueError, pd.errors.ParserError):
                    # Narrowed from a bare `except:` -- only I/O and parse
                    # failures are expected here.
                    if verbose: print(f"Failed to process file {file_name}")
            elif os.path.isdir(file_path):
                # Bug fix: propagate `verbose` through the recursion (it was
                # silently dropped before).
                self.init_db(file_path, verbose)

    def list_db(self, verbose: bool = False):
        """Return the known database paths; pretty-print them when verbose.

        Args:
            verbose: Render a rich table of index / name / path / tables /
                columns before returning.

        Returns:
            list: ``self.db_locations``.
        """
        if verbose:
            console = Console()
            table = Table(title="Database List", show_lines=True)
            table.add_column("Index", style="yellow", justify="left")
            table.add_column("Database Name", style="cyan", justify="left")
            table.add_column("Path", style="magenta", justify="left")
            table.add_column("Tables", style="purple", justify="left")
            # Typo fix: header previously read "Colunms".
            table.add_column("Columns", style="green", justify="left")
            for counter, db_path in enumerate(self.db_locations):
                db_name = os.path.basename(db_path)
                db_tables = self.get_tables(db_path=db_path)
                columns = self.peak_db(db_path=db_path, verbose=False)
                table.add_row(str(counter), db_name, db_path, ",".join(db_tables), ",".join(columns))
            console.print(table)
        return self.db_locations

    def peak_db(self, db_path: str, verbose: bool = True):
        """
        Displays the column headers and schema from a database.

        Args:
            db_path: Path of the SQLite file to inspect.
            verbose: Print the headers as they are discovered.

        Returns:
            list: Column header names of the LAST table inspected (these
            databases normally hold a single ``data`` table).
        """
        if verbose: print(f"Peeking into: {db_path}")
        tables = self.get_tables(db_path=db_path)
        headers = []
        conn = sqlite3.connect(db_path)
        try:
            cur = conn.cursor()
            for table in tables:
                # LIMIT 1 to minimize query impact; cursor.description gives
                # the column names without fetching data.
                cur.execute(f"SELECT * FROM {table} LIMIT 1")
                headers = [description[0] for description in cur.description]
                if verbose:
                    print("Column Headers:")
                    print(headers)
        finally:
            conn.close()
        return headers

    def get_column(self, db_path: str, column_num: int = None, column_name: str = None):
        """
        Returns the list of items in a specified column, selected either by
        index or by name. ``column_num`` takes precedence when both are given.

        Args:
            db_path: Path of the SQLite file to query.
            column_num: Zero-based index into the header list.
            column_name: Column name in the ``data`` table.

        Returns:
            list | None: The column's values, or None on error / no selector.
        """
        if column_num is not None:
            # Resolve the index to a real header name via the db schema.
            column_name = self.peak_db(db_path=db_path)[column_num]
        if column_name is None:
            return None
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            # SQLite cannot parameterize identifiers, so quote the column
            # name (doubling embedded quotes) to keep the query safe.
            quoted = column_name.replace('"', '""')
            cursor.execute(f'SELECT "{quoted}" FROM data')
            return [row[0] for row in cursor.fetchall()]
        except sqlite3.Error:
            # Bug fix: the index branch previously reported the wrong (None)
            # name here; also narrowed from a bare `except:`.
            print(f"Error getting column {column_name}")
        finally:
            conn.close()

    def get_tables(self, db_path, verbose: bool = False):
        """Return the names of all tables in the database at ``db_path``."""
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            # sqlite_master is the catalog of every object in the file.
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()
        finally:
            conn.close()
        table_lst = []
        if tables:
            for table in tables:
                table_lst.append(table[0])
                if verbose: print(f"DB_NAME : {os.path.basename(db_path)} -- Tables : {table[0]}")
        else:
            print("No tables found in the database.")
        return table_lst
\ No newline at end of file
import matplotlib.pyplot as plt
import numpy as np
class Graph_tool():
    """Small matplotlib helpers for line graphs of (possibly large) series."""

    def __init__(self):
        pass

    def gen_graph(self, x_data: list, x_name: str, y_data: list, y_name: str, log_scale: bool = False):
        """
        Generates a line graph with appropriate scaling for large datasets.

        Args:
            x_data (list): Data for the x-axis.
            x_name (str): Label for the x-axis.
            y_data (list): Data for the y-axis.
            y_name (str): Label for the y-axis.
            log_scale (bool): Apply logarithmic scaling to axes if True. Defaults to False.
        """
        # Handle large datasets by downsampling for clarity.
        max_points = 1000
        longest = max(len(x_data), len(y_data))
        if longest > max_points:
            # Bug fix: the old `len(x_data) // max_points` could be 0 when
            # only y_data exceeded the cap, crashing on a zero slice step.
            step = max(1, longest // max_points)
            x_data = x_data[::step]
            y_data = y_data[::step]
        # Create the plot.
        plt.figure(figsize=(10, 6))  # Larger figure for readability
        plt.plot(x_data, y_data, marker='o', markersize=4, label='Data Line', alpha=0.8)
        # Add labels and title.
        plt.xlabel(x_name)
        plt.ylabel(y_name)
        plt.title('Line Graph Example')
        # Add grid and legend.
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.legend()
        # Adjust axis limits with a 5% margin. Bug fix: guard against empty
        # series, where max()/min() would raise ValueError.
        if x_data and y_data:
            x_margin = (max(x_data) - min(x_data)) * 0.05
            y_margin = (max(y_data) - min(y_data)) * 0.05
            plt.xlim(min(x_data) - x_margin, max(x_data) + x_margin)
            plt.ylim(min(y_data) - y_margin, max(y_data) + y_margin)
        # Apply logarithmic scaling if enabled.
        if log_scale:
            plt.xscale('log')
            plt.yscale('log')
        plt.show()

    def graph_yearly_data(self, formatted_data, title="Yearly Data Trends"):
        """
        Plots yearly data for multiple countries.

        Args:
            formatted_data (list): A list of lists where each sublist contains
                (year, stat) tuples for a country.
            title (str): The title of the graph.
        """
        plt.figure(figsize=(10, 6))
        for country_index, country_data in enumerate(formatted_data):
            # Unzip the (year, stat) tuples for this country.
            years, stats = zip(*country_data)
            plt.plot(years, stats, marker='o', label=f"Country {country_index + 1}")
        # Add labels, title, and grid.
        plt.xlabel("Year")
        plt.ylabel("Stat")
        plt.title(title)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.legend(title="Countries")
        plt.tight_layout()
        plt.show()
main.py 0 → 100644
from data_code.sql_tool import SQL_tool
from graph_code.graph_tool import Graph_tool
from stats_code.stats_tool import Stat_tool

if __name__ == "__main__":
    # Build one SQLite db per CSV under data_storage, then graph a selection.
    sql_t = SQL_tool(data_storage_location="data_storage")  # Initialises all the databases for sqlite
    g_tool = Graph_tool()
    s_tool = Stat_tool()

    db_paths = sql_t.list_db(verbose=True)

    # Ask the user which db/column pair supplies each axis.
    raw = input("x_table x_column y_table y_column(e.g 2 Year 4 Value)\n").split(" ")
    x_tab = int(raw[0])
    x_col = str(raw[1])
    y_tab = int(raw[2])
    y_col = str(raw[3])

    # Fetch the x column, then the y column.
    column_data_x = sql_t.get_column(db_path=db_paths[x_tab], column_name=x_col)
    column_data_y = sql_t.get_column(db_path=db_paths[y_tab], column_name=y_col)

    # Group into per-country yearly series, keep the top 5, and plot.
    yearly = s_tool.process_yearly_data(year_data=column_data_x, stats=column_data_y)
    top_countries = s_tool.filter_top_countries(formatted_data=yearly, top_n=5)
    g_tool.graph_yearly_data(formatted_data=top_countries, title="Yearly usage change")
    # Alternative single-series view:
    #g_tool.gen_graph(x_data=column_data_x, x_name=x_col, y_data=column_data_y, y_name=y_col)
pandas
rich
matplotlib
\ No newline at end of file
class Stat_tool():
    """Helpers for reshaping and filtering per-country yearly statistics."""

    def __init__(self):
        pass

    def process_yearly_data(self, year_data, stats):
        """
        Processes yearly data and groups stats for each country by year.
        Assumes ``year_data`` repeats the same block of years once per country
        (e.g. 2000..2010, 2000..2010, ...).

        Args:
            year_data (list): List of years (repeated for each country).
            stats (list): List of statistics corresponding to the years.

        Returns:
            list: One list of (year, stat) tuples per country.
        """
        # Bug fix: empty input previously raised IndexError on year_data[0].
        if not year_data or not stats:
            return []
        # The years-per-country count is the distance until the first year
        # value repeats.
        start_year = year_data[0]
        looping_constant = 0
        for year in year_data:
            if year == start_year and looping_constant > 0:
                break
            looping_constant += 1
        # Group stats by country: each slice of `looping_constant` stats is
        # one country's series, paired with the shared year block.
        formatted_data = []
        for i in range(0, len(stats), looping_constant):
            country_data = [
                (year_data[j], stats[i + j])
                for j in range(looping_constant)
                if i + j < len(stats)  # final country may be short
            ]
            formatted_data.append(country_data)
        return formatted_data

    def filter_top_countries(self, formatted_data, top_n=5, metric_index=1):
        """
        Filters top N countries based on the sum of a metric (e.g., stats).

        Args:
            formatted_data (list): List of lists with (year, stat) tuples.
            top_n (int): Number of countries to include.
            metric_index (int): Index of the stat in the tuple.

        Returns:
            list: Filtered data for top N countries, highest total first.
        """
        # Total the chosen metric per country, remembering the original index.
        country_totals = [
            (country_index, sum(entry[metric_index] for entry in country_data))
            for country_index, country_data in enumerate(formatted_data)
        ]
        # Largest totals first, truncated to top_n.
        top_countries = sorted(country_totals, key=lambda pair: pair[1], reverse=True)[:top_n]
        # Map the winning indices back to their full series.
        return [formatted_data[index] for index, _total in top_countries]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment