Skip to content
Snippets Groups Projects
Commit ba5e3b48 authored by samuel.m's avatar samuel.m
Browse files

First backup commit

parent a07c33dd
Branches
No related tags found
No related merge requests found
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
/.idea
/dataVis
/data_storage
__pycache__
\ No newline at end of file
import sqlite3
import pandas as pd
import os
from rich.console import Console
from rich.table import Table
class SQL_tool():
    """Discovers CSV files under a folder tree and mirrors each into a SQLite
    database (single table named ``data``) stored next to the source CSV.
    """

    def __init__(self, data_storage_location: str):
        # Paths of every .db file successfully created during the scan.
        self.db_locations = []
        self.init_db(data_storage_location)

    def init_db(self, data_storage_location: str, verbose: bool = None, header_row: int = 1):
        """
        Recursively walk ``data_storage_location`` and build a SQLite database
        for every CSV found, however deeply nested.

        Args:
            data_storage_location: Root folder to scan.
            verbose: Print progress messages when truthy.
            header_row: Row index handed to ``pd.read_csv(header=...)``.
                Defaults to 1 because the project's CSVs carry a junk first
                row — callers with normal CSVs can pass 0.
        """
        if not os.path.isdir(data_storage_location):
            return
        for file_name in os.listdir(data_storage_location):
            file_path = os.path.join(data_storage_location, file_name)
            if os.path.isfile(file_path) and file_name.endswith('.csv'):
                db_name = os.path.splitext(file_name)[0] + ".db"
                db_path = os.path.join(data_storage_location, db_name)
                try:
                    if verbose: print(f"Processing {file_name}")
                    df = pd.read_csv(file_path, header=header_row)
                    with sqlite3.connect(db_path) as conn:
                        # ``replace`` makes the scan idempotent: re-running
                        # refreshes the table instead of failing.
                        df.to_sql("data", conn, if_exists="replace", index=False)
                    if verbose: print(f"{db_name} Created at {db_path}")
                    # Only register databases that were actually written
                    # (the original appended the path even on failure).
                    self.db_locations.append(db_path)
                except (pd.errors.ParserError, pd.errors.EmptyDataError,
                        sqlite3.Error, OSError, ValueError) as exc:
                    # Narrowed from a bare ``except:`` so genuine bugs
                    # (e.g. NameError) are no longer silently swallowed.
                    if verbose: print(f"Failed to process file {file_name}: {exc}")
            elif os.path.isdir(file_path):
                # Bug fix: the original recursion dropped the verbose flag.
                self.init_db(file_path, verbose=verbose, header_row=header_row)

    def list_db(self, verbose: bool = False):
        """
        Return the list of known database paths; when ``verbose`` also pretty
        print an index/name/path/tables/columns overview with rich.

        Returns:
            list: Paths of every database discovered by ``init_db``.
        """
        if verbose:
            console = Console()
            table = Table(title="Database List", show_lines=True)
            table.add_column("Index", style="yellow", justify="left")
            table.add_column("Database Name", style="cyan", justify="left")
            table.add_column("Path", style="magenta", justify="left")
            table.add_column("Tables", style="purple", justify="left")
            table.add_column("Columns", style="green", justify="left")  # typo fix: was "Colunms"
            for counter, db_path in enumerate(self.db_locations):
                db_name = os.path.basename(db_path)
                db_tables = self.get_tables(db_path=db_path)
                columns = self.peak_db(db_path=db_path, verbose=False)
                table.add_row(str(counter), db_name, db_path,
                              ",".join(db_tables), ",".join(columns))
            console.print(table)
        return self.db_locations

    def peak_db(self, db_path: str, verbose: bool = True):
        """
        Return the column headers of the tables in a database.

        NOTE: when the database holds several tables, only the headers of the
        LAST table are returned (preserved from the original behaviour; the
        dbs built by ``init_db`` always have exactly one table, ``data``).

        Args:
            db_path: Path of the SQLite database to inspect.
            verbose: Print the headers as they are read.

        Returns:
            list: Column header names.
        """
        if verbose: print(f"Peeking into: {db_path}")
        tables = self.get_tables(db_path=db_path)
        headers = []
        with sqlite3.connect(db_path) as conn:
            cur = conn.cursor()
            for table in tables:
                # LIMIT 1 keeps the probe cheap; we only need cur.description.
                cur.execute(f"SELECT * FROM {table} LIMIT 1")
                headers = [description[0] for description in cur.description]
                if verbose:
                    print("Column Headers:")
                    print(headers)
        return headers

    def get_column(self, db_path: str, column_num: int = None, column_name: str = None):
        """
        Return every value of one column of the ``data`` table, selected by
        positional index or by name (index wins when both are supplied).

        Args:
            db_path: Path of the SQLite database.
            column_num: Zero-based column index, resolved via ``peak_db``.
            column_name: Column header name.

        Returns:
            list | None: The column values, or None when the query failed
            (an error message is printed) or neither selector was given.
        """
        if column_num is not None:
            # Resolve the positional index to a header name first.
            resolved = self.peak_db(db_path=db_path, verbose=False)[column_num]
            return self._select_column(db_path, resolved)
        if column_name is not None:
            return self._select_column(db_path, column_name)

    def _select_column(self, db_path: str, column_name: str):
        """Fetch a single named column from the ``data`` table."""
        with sqlite3.connect(db_path) as conn:
            cursor = conn.cursor()
            # Identifiers cannot be bound as SQL parameters, so quote the
            # name and escape embedded quotes to keep the query well-formed.
            safe_name = column_name.replace('"', '""')
            try:
                cursor.execute(f'SELECT "{safe_name}" FROM data')
                return [row[0] for row in cursor.fetchall()]
            except sqlite3.Error:
                # Bug fix: the index branch used to print ``column_name``
                # which was None there; also fixes the "colunm" typo.
                print(f"Error getting column {column_name}")

    def get_tables(self, db_path, verbose: bool = False):
        """
        Return the names of all tables in the database at ``db_path``.
        Prints a notice when the database contains no tables.
        """
        with sqlite3.connect(db_path) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            table_lst = [row[0] for row in cursor.fetchall()]
        if not table_lst:
            print("No tables found in the database.")
        elif verbose:
            for name in table_lst:
                print(f"DB_NAME : {os.path.basename(db_path)} -- Tables : {name}")
        return table_lst
\ No newline at end of file
import matplotlib.pyplot as plt
import numpy as np
class Graph_tool():
    """Small matplotlib helpers for plotting tabular column data."""

    def __init__(self):
        pass

    def gen_graph(self, x_data: list, x_name: str, y_data: list, y_name: str, log_scale: bool = False):
        """
        Generates a line graph with appropriate scaling for large datasets.

        Args:
            x_data (list): Data for the x-axis.
            x_name (str): Label for the x-axis.
            y_data (list): Data for the y-axis.
            y_name (str): Label for the y-axis.
            log_scale (bool): Apply logarithmic scaling to axes if True.
                Defaults to False. NOTE(review): log scaling will fail for
                non-positive data — confirm inputs are positive before use.
        """
        # Downsample very large inputs so the plot stays readable.
        max_points = 1000  # Maximum points to display for clarity
        longest = max(len(x_data), len(y_data))
        if longest > max_points:
            # Bug fix: the original derived the step from len(x_data) alone,
            # which could be 0 (ValueError in the slice) when only y_data
            # exceeded the cap; max(1, ...) also guards the degenerate case.
            step = max(1, longest // max_points)
            x_data = x_data[::step]
            y_data = y_data[::step]
        # Create the plot
        plt.figure(figsize=(10, 6))  # Larger figure size for readability
        plt.plot(x_data, y_data, marker='o', markersize=4, label='Data Line', alpha=0.8)
        # Labels and title
        plt.xlabel(x_name)
        plt.ylabel(y_name)
        plt.title('Line Graph Example')
        # Grid and legend
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.legend()
        # Pad the axis limits by 5% of the data range on each side.
        x_margin = (max(x_data) - min(x_data)) * 0.05
        y_margin = (max(y_data) - min(y_data)) * 0.05
        plt.xlim(min(x_data) - x_margin, max(x_data) + x_margin)
        plt.ylim(min(y_data) - y_margin, max(y_data) + y_margin)
        # Apply logarithmic scaling if enabled
        if log_scale:
            plt.xscale('log')
            plt.yscale('log')
        plt.show()

    def graph_yearly_data(self, formatted_data, title="Yearly Data Trends"):
        """
        Plots yearly data for multiple countries, one line per country.

        Args:
            formatted_data (list): A list of lists where each sublist contains
                (year, stat) tuples for a country.
            title (str): The title of the graph.
        """
        plt.figure(figsize=(10, 6))
        for country_index, country_data in enumerate(formatted_data):
            # Unzip the (year, stat) tuples for the current country.
            years, stats = zip(*country_data)
            # Countries are labelled by 1-based position; the original
            # country names are not available in this data shape.
            plt.plot(years, stats, marker='o', label=f"Country {country_index + 1}")
        # Labels, title, grid, legend
        plt.xlabel("Year")
        plt.ylabel("Stat")
        plt.title(title)
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.legend(title="Countries")
        plt.tight_layout()
        plt.show()
main.py 0 → 100644
from data_code.sql_tool import SQL_tool
from graph_code.graph_tool import Graph_tool
from stats_code.stats_tool import Stat_tool
if __name__ == "__main__":
    # Build a SQLite database for every CSV under data_storage.
    sql_t = SQL_tool(data_storage_location="data_storage")
    g_tool = Graph_tool()
    s_tool = Stat_tool()

    # Show the user the available databases, then ask which columns to plot.
    table_list = sql_t.list_db(verbose=True)
    inp = input("x_table x_column y_table y_column(e.g 2 Year 4 Value)\n").split(" ")
    x_tab, x_col, y_tab, y_col = (int(inp[0]), str(inp[1]), int(inp[2]), str(inp[3]))

    # Pull the requested x and y columns out of their databases.
    x_values = sql_t.get_column(db_path=table_list[x_tab], column_name=x_col)
    y_values = sql_t.get_column(db_path=table_list[y_tab], column_name=y_col)

    # Regroup the flat columns into per-country (year, stat) series,
    # keep the five largest, and plot them.
    formatted_yearly_data = s_tool.process_yearly_data(year_data=x_values, stats=y_values)
    top_countries = s_tool.filter_top_countries(formatted_data=formatted_yearly_data, top_n=5)
    g_tool.graph_yearly_data(formatted_data=top_countries, title="Yearly usage change")
pandas
rich
matplotlib
\ No newline at end of file
class Stat_tool():
    """Helpers for reshaping flat year/stat columns into per-country series."""

    def __init__(self):
        pass

    def process_yearly_data(self, year_data, stats):
        """
        Group a flat stats column into per-country (year, stat) series.

        Assumes ``year_data`` repeats the same run of years once per country,
        so the point at which the first year recurs marks the run length.

        Args:
            year_data (list): Years, repeated in the same order per country.
            stats (list): Statistics aligned positionally with ``year_data``.

        Returns:
            list: One list of (year, stat) tuples per country.
        """
        # Run length = how many entries pass before the first year recurs.
        first_year = year_data[0]
        run_length = 0
        for year in year_data:
            if run_length > 0 and year == first_year:
                break
            run_length += 1
        # Slice the stats column into runs, pairing each stat with its year;
        # zip naturally truncates a short final run without index checks.
        grouped = []
        for start in range(0, len(stats), run_length):
            chunk = stats[start:start + run_length]
            grouped.append(list(zip(year_data[:len(chunk)], chunk)))
        return grouped

    def filter_top_countries(self, formatted_data, top_n=5, metric_index=1):
        """
        Keep only the ``top_n`` countries ranked by their summed metric.

        Args:
            formatted_data (list): Per-country lists of (year, stat) tuples.
            top_n (int): Number of countries to keep.
            metric_index (int): Tuple position of the metric to sum.

        Returns:
            list: The ``formatted_data`` entries for the highest totals,
            ordered best first.
        """
        # Pair each country's index with its metric total.
        totals = []
        for idx, series in enumerate(formatted_data):
            totals.append((idx, sum(point[metric_index] for point in series)))
        # Rank descending by total (stable for ties) and keep the top N.
        totals.sort(key=lambda pair: pair[1], reverse=True)
        return [formatted_data[idx] for idx, _ in totals[:top_n]]
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment