import os
import pandas as pd
import yaml
import tkinter as tk
from tkinter import ttk
from pandasai import Agent
import pypandoc

# Sample data, written as one (country, sales) record per row
sales_by_country = pd.DataFrame(
    [
        ("United States", 5000),
        ("United Kingdom", 3200),
        ("France", 2900),
        ("Germany", 4100),
        ("Italy", 2300),
        ("Spain", 2100),
        ("Canada", 2500),
        ("Australia", 2600),
        ("Japan", 4500),
        ("China", 7000),
        ("Malta", 2900),
    ],
    columns=["country", "sales"],
)

# Load the API key from config.yaml
# YAML is UTF-8 by specification, so read it as UTF-8 instead of relying on
# the platform's default text encoding.
with open("config.yaml", "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# Fail early with an actionable message: an empty file makes safe_load return
# None, and a missing/blank key would otherwise only surface later as an
# opaque TypeError/KeyError or a rejected API call.
if not isinstance(config, dict) or not config.get("PANDASAI_API_KEY"):
    raise KeyError("config.yaml must define a non-empty PANDASAI_API_KEY entry")

# Set up PandasAI API key
# os.environ only accepts strings; YAML may parse an unquoted key as a number.
os.environ["PANDASAI_API_KEY"] = str(config["PANDASAI_API_KEY"])

# Initialize the Agent with the DataFrame
agent = Agent(dfs=sales_by_country)

# Font shared by the dropdown, the button and every result widget
font_style = ("Arial", 12)

# Questions offered in the dropdown; get_response() matches these strings exactly
questions = [
    "Which are the top 5 countries by sales?",
    "What is the total sales?",
    "Which country has the highest sales?",
    "Which country has the lowest sales?",
    "List countries with sales over 5000.",
    "List countries with sales over 3000.",
    "Show the original DataFrame",
    "Which countries contribute 80% of the sales?",
]

# Main application window
root = tk.Tk()
root.title("PandasAI Question Selector")
root.geometry("600x650")

# Heading at the top of the window
title_label = tk.Label(
    root,
    text="Sales by Country",
    font=("Arial", 16, "bold"),
)
title_label.pack(pady=10)

# Question selector: a wide combobox backed by a StringVar that
# get_response() reads when the button is pressed
selected_question = tk.StringVar()
dropdown = ttk.Combobox(
    root,
    textvariable=selected_question,
    values=questions,
    width=50,
    font=font_style,
)
dropdown.pack(pady=20)
dropdown.set("Select a question")

# Result area: the table/text answer goes in table_frame, and the extra
# totals for the 80% question go in summary_frame directly below it
table_frame = tk.Frame(root)
summary_frame = tk.Frame(root)
table_frame.pack(pady=10)
summary_frame.pack(pady=10)

# Function to clear the previous table and summary
def clear_table():
    """Destroy every child widget of table_frame, then of summary_frame."""
    for container in (table_frame, summary_frame):
        for child in container.winfo_children():
            child.destroy()

# Function to get countries that make up 80% of sales
def countries_representing_80_percent(df=None, share=0.8):
    """Return the top-selling countries that together reach a share of total sales.

    Countries are taken in descending order of sales until their running
    total is at least ``share`` of the overall total.

    Args:
        df: DataFrame with "country" and "sales" columns. Defaults to the
            module-level ``sales_by_country``.
        share: Fraction of total sales to reach. Defaults to 0.8 (80%).

    Returns:
        A 4-tuple ``(countries, cumulative_sales, percentage_of_total,
        total_sales)``: ``countries`` is a list of ``(country, sales)`` pairs
        in descending sales order, ``cumulative_sales`` is the sum of their
        sales, ``percentage_of_total`` is that sum as a percentage of
        ``total_sales`` (0.0 when ``total_sales`` is zero), and
        ``total_sales`` is the sum of the whole "sales" column.
    """
    data = sales_by_country if df is None else df

    # A stable sort keeps countries with equal sales in their original row
    # order, so the selection is deterministic when there are ties.
    ordered = data.sort_values(by="sales", ascending=False, kind="mergesort")

    total_sales = ordered["sales"].sum()
    threshold = share * total_sales

    # Walk the columns directly instead of iterrows(): no per-row Series is
    # built, and tolist() yields plain Python values.
    cumulative_sales = 0
    selected = []
    for country, sales in zip(ordered["country"].tolist(), ordered["sales"].tolist()):
        cumulative_sales += sales
        selected.append((country, sales))
        if cumulative_sales >= threshold:
            break

    # Guard the division: an empty table or all-zero sales has no meaningful
    # share and would otherwise raise ZeroDivisionError or produce NaN.
    percentage_of_total = (cumulative_sales / total_sales) * 100 if total_sales else 0.0

    return selected, cumulative_sales, percentage_of_total, total_sales

# Questions answered by listing countries above a minimum sales value,
# mapped to that value
_SALES_THRESHOLD_QUESTIONS = {
    "List countries with sales over 5000.": 5000,
    "List countries with sales over 3000.": 3000,
}


def _sales_rows(df):
    """Return the (country, sales) pairs of *df* in row order as plain Python values."""
    return list(zip(df["country"].tolist(), df["sales"].tolist()))


def _render_sales_table(parent, rows):
    """Draw a Country/Sales header plus one grid row per (country, sales) pair in *parent*."""
    cell_style = {"font": font_style, "borderwidth": 2, "relief": "groove"}
    tk.Label(parent, text="Country", width=25, **cell_style).grid(row=0, column=0)
    tk.Label(parent, text="Sales", width=10, **cell_style).grid(row=0, column=1)
    for grid_row, (country, sales) in enumerate(rows, start=1):
        tk.Label(parent, text=country, width=25, **cell_style).grid(row=grid_row, column=0)
        tk.Label(parent, text=sales, width=10, **cell_style).grid(row=grid_row, column=1)


def _show_text(message, **label_options):
    """Pack a left-justified, wrapped text label with *message* into table_frame."""
    tk.Label(
        table_frame,
        text=message,
        font=font_style,
        justify="left",
        wraplength=450,
        **label_options,
    ).pack()


# Function to display the response in a table format or wrap text if necessary
def get_response():
    """Answer the question currently selected in the dropdown.

    Clears the previous result, then renders the answer into table_frame
    (and, for the 80% question, extra totals into summary_frame). Only the
    "top 5" question is sent to the PandasAI agent; the others are computed
    locally from sales_by_country. Text that is not one of the known
    questions (including the initial placeholder or an empty box) produces a
    prompt instead of an empty window. Any exception raised while answering
    is shown as a red error message rather than propagated into Tk.
    """
    # The combobox is editable, so tolerate stray surrounding whitespace.
    question = selected_question.get().strip()
    clear_table()  # Clear previous results at the start

    try:
        if question == "Which are the top 5 countries by sales?":
            response = agent.chat(question)
            if isinstance(response, pd.DataFrame):
                # Sort the DataFrame by sales in descending order
                response = response.sort_values(by="sales", ascending=False)
                _render_sales_table(table_frame, _sales_rows(response))
            else:
                # NOTE(review): the agent does not appear to be guaranteed to
                # return a DataFrame; show whatever it answered instead of
                # silently displaying nothing.
                _show_text(str(response))

        elif question == "What is the total sales?":
            # Sum up the sales and list countries in alphabetical order
            total_sales = sales_by_country["sales"].sum()
            country_names = sorted(sales_by_country["country"].tolist())
            _show_text(
                f"Total Sales: {total_sales}\nCountries (Alphabetical Order): {country_names}"
            )

        elif question == "Which country has the highest sales?":
            max_sales_row = sales_by_country.loc[sales_by_country["sales"].idxmax()]
            country_with_max_sales = max_sales_row["country"]
            sales_value = max_sales_row["sales"]
            _show_text(
                f"Country with Highest Sales: {country_with_max_sales}\nSales: {sales_value}"
            )

        elif question == "Which country has the lowest sales?":
            min_sales_row = sales_by_country.loc[sales_by_country["sales"].idxmin()]
            country_with_min_sales = min_sales_row["country"]
            sales_value = min_sales_row["sales"]
            _show_text(
                f"Country with Lowest Sales: {country_with_min_sales}\nSales: {sales_value}"
            )

        elif question in _SALES_THRESHOLD_QUESTIONS:
            minimum = _SALES_THRESHOLD_QUESTIONS[question]
            matching = sales_by_country[sales_by_country["sales"] > minimum]
            matching = matching.sort_values(by="sales", ascending=False)
            _render_sales_table(table_frame, _sales_rows(matching))

        elif question == "Show the original DataFrame":
            # Show the entire original DataFrame with scrollable content
            canvas = tk.Canvas(table_frame)
            scrollbar = tk.Scrollbar(table_frame, orient="vertical", command=canvas.yview)
            scrollable_frame = tk.Frame(canvas)

            # Keep the scroll region in sync with the inner frame's size as
            # rows are added to it.
            scrollable_frame.bind(
                "<Configure>",
                lambda e: canvas.configure(scrollregion=canvas.bbox("all")),
            )

            canvas.create_window((0, 0), window=scrollable_frame, anchor="nw")
            canvas.configure(yscrollcommand=scrollbar.set)

            canvas.pack(side="left", fill="both", expand=True)
            scrollbar.pack(side="right", fill="y")

            _render_sales_table(scrollable_frame, _sales_rows(sales_by_country))

        elif question == "Which countries contribute 80% of the sales?":
            countries_80, sum_sales, percentage, total_sales = countries_representing_80_percent()
            _render_sales_table(table_frame, countries_80)

            # Populate summary information below the table in summary_frame
            summary_lines = (
                f"Sum of Sales for Countries contributing 80% or more of Total Sales: {sum_sales}",
                f"% of Sum of Sales for Countries contributing 80% or more of Total Sales: {percentage:.2f}%",
                f"Total Sales (100%): {total_sales}",
            )
            for line in summary_lines:
                tk.Label(summary_frame, text=line, font=font_style).pack()

        else:
            # Placeholder, empty or free-typed text: say so instead of
            # leaving a blank result area.
            _show_text("Please select one of the listed questions.")

    except Exception as e:
        # Top-level boundary of the button callback: report the failure in
        # the window instead of letting it vanish into Tk's callback handler.
        _show_text(f"Error: {str(e)}", fg="red")

# "Get Answer" button: runs get_response() for the current dropdown value
submit_button = tk.Button(
    root,
    text="Get Answer",
    font=font_style,
    command=get_response,
)
submit_button.pack(pady=10)

# Hand control to Tk's event loop; this call blocks
root.mainloop()