openai-sql/gui/main.py
2023-05-08 14:50:29 +02:00

451 lines
15 KiB
Python

import json
import os
import sys
from functools import partial

from PyQt5.QtCore import QObject, QThread, pyqtSignal
from PyQt5.QtWidgets import QApplication, QMainWindow, QDialog, QTableWidgetItem, QTableWidget

import modules.aisql as aisql
from gui.apikey import Ui_ApiKey as ApiKeyForm
from gui.connection import Ui_Connection as ConnectionForm
from gui.gui.Window import Ui_MainWindow
# Settings file location; a relative path, so it resolves against the
# current working directory.
configfolder = "config"
configfile = "config.json"

# Database types offered in the connection dialog, keyed by combo-box index.
# That index is what the dialog stores as "dbtype" in the config file.
dbtypes = dict(enumerate(("PostgreSQL", "MySQL", "SQLite")))
# Load DB config from file, creating the file with defaults if it does not exist.
def load_config() -> tuple:
    """Return the saved settings, writing a default config file on first run.

    Returns:
        ``(dbtype, ip, port, user, password, database, apikey)`` read from the
        file at ``configfolder/configfile``. Keys missing from the file fall
        back to the defaults that are written on first run (dbtype 0, empty
        strings elsewhere).
    """
    defaults = {
        "database": {
            "dbtype": 0,
            "ip": "",
            "port": "",
            "user": "",
            "password": "",
            "database": "",
        },
        "apikey": "",
    }
    path = os.path.join(configfolder, configfile)
    if not os.path.exists(path):
        # exist_ok: the folder may already exist even though the file does not.
        os.makedirs(configfolder, exist_ok=True)
        with open(path, "w") as f:
            json.dump(defaults, f, indent=4)
        # Fall through and return the defaults we just wrote.
        config = defaults
    else:
        with open(path, "r") as f:
            config = json.load(f)
    database = {**defaults["database"], **config.get("database", {})}
    return (
        database["dbtype"],
        database["ip"],
        database["port"],
        database["user"],
        database["password"],
        database["database"],
        config.get("apikey", defaults["apikey"]),
    )
def save_config(dbtype, dbip, dbport, dbuser, dbpass, dbname, apikey) -> None:
    """Overwrite the config file with the given database settings and API key.

    The whole file is replaced, so keys other than ``database`` and ``apikey``
    are not kept. The config folder is created first if it is missing.
    """
    config = {
        "database": {
            "dbtype": dbtype,
            "ip": dbip,
            "port": dbport,
            "user": dbuser,
            "password": dbpass,
            "database": dbname,
        },
        "apikey": apikey,
    }
    # Without this, saving before the folder exists raises FileNotFoundError.
    os.makedirs(configfolder, exist_ok=True)
    with open(os.path.join(configfolder, configfile), "w") as f:
        json.dump(config, f, indent=4)
def connect_db() -> tuple:
    """Connect to the database described by the saved config.

    Only PostgreSQL (dbtype 0) is implemented; any other type is reported as
    "No Database Connection".

    Returns:
        ``(db, tableschema)``: the ``database.postgresql.Postgres`` object and
        the result of its ``get_schema()``. When connecting fails ``db`` is
        None (or the connection object, if only ``get_schema()`` failed) and
        ``tableschema`` is ``[]``.
    """
    db = None
    tableschema = []
    try:
        # Inside the try so a broken/missing config is reported, not raised.
        dbtype, dbip, dbport, dbuser, dbpass, dbname, _apikey = load_config()
        if dbtype != 0:
            print(f"Unsupported database type: {dbtype}")
            print('No Database Connection')
            return db, tableschema
        import database.postgresql as pg
        db = pg.Postgres(dbip, dbport, dbuser, dbpass, dbname)
        tableschema = db.get_schema()
    except Exception as e:
        print('No Database Connection')
        print(e)
        return db, tableschema
    print('Database Connected')
    return db, tableschema
def test_connection(dbtype, dbip, dbport, dbuser, dbpass, dbname) -> bool:
db = None
try:
if dbtype == 0:
import database.postgresql as pg
db = pg.Postgres(dbip, dbport, dbuser, dbpass, dbname)
if db.test_connection():
return True
except Exception as e:
print(e)
return False
class Worker(QObject):
    """Runs one slow job off the GUI thread (connection checks, translation, SQL).

    Usage: create it, ``moveToThread()`` it, and trigger one job method. The
    job stores its outcome on the instance (``test`` for the connection
    checks, ``result`` for ``translate``/``run_sql``) and then emits
    ``finished``. ``finished`` is emitted even when the job fails, so the
    owning thread can quit and the GUI can re-enable its controls. Errors are
    printed rather than raised, because PyQt5 (>= 5.5) aborts the application
    on exceptions that escape a slot.
    """

    finished = pyqtSignal()

    def __init__(self):
        super().__init__()
        # Per-instance state (as class attributes the list default could be
        # shared between Worker instances).
        self.test = False  # set True when a connection check succeeds
        self.db = None  # connection from connect_db(), or None
        self.tableschema = []  # schema from connect_db()
        self.result = ""  # output of translate() / run_sql()

    def test_db_connection(self, dbtype, dbip, dbport, dbuser, dbpass, dbname):
        """Check database settings; sets ``self.test`` to True if they work."""
        try:
            if test_connection(dbtype, dbip, dbport, dbuser, dbpass, dbname):
                self.test = True
        except Exception as e:
            print(e)
        finally:
            self.finished.emit()

    def test_api_connection(self, apikey):
        """Check an API key; sets ``self.test`` to True if ``test_key()`` passes."""
        try:
            ai = aisql.AI(apikey)
            if ai.test_key():
                self.test = True
        except Exception as e:
            print(e)
        finally:
            self.finished.emit()

    def connect_db_worker(self):
        """Connect with the saved settings and keep ``db``/``tableschema``."""
        self.db, self.tableschema = connect_db()

    def translate(self, text):
        """Convert natural-language ``text`` to SQL and store it in ``self.result``."""
        try:
            self.connect_db_worker()
            apikey = load_config()[6]
            ai = aisql.AI(apikey)
            self.result = ai.humantosql(text, str(self.db), self.tableschema)
        except Exception as e:
            print(e)
        finally:
            self.finished.emit()

    def run_sql(self, sql):
        """Run ``sql`` on the configured database; store the return in ``self.result``.

        ``aisql.AI.decide`` picks the call: an answer containing "fetchall"
        uses ``db.fetchall``, "fetchone" uses ``db.fetchone``, "execute" uses
        ``db.execute`` (checked in that order, case-insensitively).
        """
        print("Running SQL...")
        try:
            db, _tableschema = connect_db()
            if db is None:
                print("No database connection; SQL not run")
                return
            ai = aisql.AI(api_key=load_config()[6])
            choice = ai.decide(sql)
            print(choice)
            decision = choice.casefold()
            if "fetchall" in decision:
                self.result = db.fetchall(sql)
            elif "fetchone" in decision:
                self.result = db.fetchone(sql)
            elif "execute" in decision:
                self.result = db.execute(sql)
            else:
                print("error")
            print(self.result)
        except Exception as e:
            print(e)
        finally:
            self.finished.emit()
class MainWindow(QMainWindow):
    """Main window: converts natural-language text to SQL and executes SQL.

    Both jobs run on a ``Worker`` moved to its own ``QThread`` so the GUI
    stays responsive. Conversion and execution keep separate worker/thread
    attributes, so one job starting cannot discard the other's running thread.
    """

    def __init__(self):
        super().__init__()
        # Most recent background job per button (None until first used).
        self.execute_worker = None
        self.execute_thread = None
        self.convert_worker = None
        self.convert_thread = None
        self.setupUi()

    def setupUi(self):
        """Build the generated UI and connect menu actions and buttons."""
        self.ui = Ui_MainWindow()
        self.ui.setupUi(self)
        self.tableschema = []
        # The status label is only shown once a job has been started.
        self.ui.outputLabel.hide()
        # Menu actions open the settings dialogs.
        self.ui.actionConnect_DB.triggered.connect(self.open_connection)
        self.ui.actionConnect_API_Key.triggered.connect(self.open_apikey)
        # Convert: textInput (natural language) -> statementOutput (SQL).
        self.ui.convertButton.clicked.connect(self.convert)
        # Paste: copy the generated SQL into shellInput.
        self.ui.pasteButton.clicked.connect(self.on_paste_button_clicked)
        # Execute: run the SQL in shellInput.
        # TODO: show the result in outputTable (execute_finished only stores
        # it in self.data so far).
        self.ui.executeButton.clicked.connect(self.execute)
        self.ui.outputTable.setColumnCount(2)
        self.ui.outputTable.setHorizontalHeaderLabels(["Column", "Value"])

    def try_to_connect(self):
        """Connect synchronously with the saved settings (blocks the GUI).

        Unpacks ``connect_db()``'s ``(db, tableschema)`` into ``self.db``
        and ``self.tableschema``.
        """
        # TODO: Rewrite to use a Worker
        try:
            self.db, self.tableschema = connect_db()
        except Exception as e:
            self.db, self.tableschema = None, []
            print('No Database Connection')
            print(e)

    def open_connection(self):
        """Show the database connection dialog."""
        self.connection_window = ConnectionWindow(self)
        self.connection_window.show()

    def on_paste_button_clicked(self):
        """Copy the generated SQL into the shell input."""
        self.ui.shellInput.setText(self.ui.statementOutput.toPlainText())

    def open_apikey(self):
        """Show the API key dialog."""
        self.apikey_window = ApiKeyWindow(self)
        self.apikey_window.show()

    def _run_in_thread(self, worker, job, on_finished):
        """Start ``job`` (a ``partial`` of a ``worker`` method) on a new QThread.

        ``on_finished`` is connected to ``worker.finished``; the worker and
        the thread are ``deleteLater``-ed once done. Returns the QThread.
        """
        # Parented to the window so dropping our Python reference cannot
        # destroy a QThread that is still running.
        thread = QThread(self)
        worker.moveToThread(thread)
        # Connect a functools.partial of the worker's bound method, never a
        # lambda: PyQt5 resolves the receiver through a partial and runs the
        # call in the worker's thread, whereas a lambda has no QObject receiver
        # and runs in the GUI thread (freezing the UI).
        thread.started.connect(job)
        worker.finished.connect(thread.quit)
        worker.finished.connect(worker.deleteLater)
        thread.finished.connect(thread.deleteLater)
        worker.finished.connect(on_finished)
        thread.start()
        return thread

    def execute(self):
        """Run the SQL in ``shellInput`` on a background thread."""
        self.ui.outputLabel.setText("Executing...")
        self.ui.outputLabel.show()
        self.ui.executeButton.setEnabled(False)
        # Read the widget here, in the GUI thread; widgets must not be
        # accessed from the worker thread.
        sql = self.ui.shellInput.toPlainText()
        self.execute_worker = Worker()
        self.execute_thread = self._run_in_thread(
            self.execute_worker,
            partial(self.execute_worker.run_sql, sql),
            self.execute_finished,
        )

    def execute_finished(self):
        """Store the execution result in ``self.data`` and re-enable Execute."""
        self.data = self.execute_worker.result
        self.ui.outputLabel.setText("Finished!")
        self.ui.outputLabel.show()
        self.ui.executeButton.setEnabled(True)

    def convert(self):
        """Translate the text in ``textInput`` to SQL on a background thread."""
        self.ui.outputLabel.setText("Converting...")
        self.ui.outputLabel.show()
        self.ui.convertButton.setEnabled(False)
        # Read the widget here, in the GUI thread (see execute()).
        text = self.ui.textInput.text()
        self.convert_worker = Worker()
        self.convert_thread = self._run_in_thread(
            self.convert_worker,
            partial(self.convert_worker.translate, text),
            self.convert_finished,
        )

    def convert_finished(self):
        """Show the generated SQL and re-enable Convert."""
        self.ui.statementOutput.setText(self.convert_worker.result)
        self.ui.outputLabel.setText("Finished!")
        self.ui.outputLabel.show()
        self.ui.convertButton.setEnabled(True)
### Connection Window ###
class ConnectionWindow(QDialog):
    """Dialog for editing, testing and saving the database connection settings."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.ui = ConnectionForm()
        self.ui.setupUi(self)
        self.ui.returnLabel.setText("")
        # True while a background connection test is running.
        self._testing = False
        # Load DB config from file (blank defaults if it does not exist yet).
        self.config = self._read_config()
        database = self.config["database"]
        self.dbtype = database["dbtype"]
        self.dbip = database["ip"]
        self.dbport = database["port"]
        self.dbuser = database["user"]
        self.dbpass = database["password"]
        self.dbname = database["database"]
        self.ui.dbtypeCombo.addItems(dbtypes.values())
        self.ui.dbtypeCombo.setCurrentIndex(self.dbtype)
        # str(): a hand-edited config may hold numbers (e.g. a numeric port).
        self.ui.ipInput.setText(str(self.dbip))
        self.ui.portInput.setText(str(self.dbport))
        self.ui.usernameInput.setText(str(self.dbuser))
        self.ui.passwordInput.setText(str(self.dbpass))
        self.ui.databaseInput.setText(str(self.dbname))
        # Save/Test need ip, port and database to be non-empty: apply that
        # rule to the loaded values now, then again on every edit.
        self.on_text_changed()
        self.ui.ipInput.textChanged.connect(self.on_text_changed)
        self.ui.portInput.textChanged.connect(self.on_text_changed)
        self.ui.databaseInput.textChanged.connect(self.on_text_changed)
        self.ui.saveButton.clicked.connect(self.on_save_button_clicked)
        self.ui.testButton.clicked.connect(self.start_db_test_thread)

    @staticmethod
    def _read_config():
        """Return the saved config dict, or blank defaults if no file exists yet."""
        try:
            with open(os.path.join(configfolder, configfile), "r") as f:
                return json.load(f)
        except FileNotFoundError:
            return {
                "database": {
                    "dbtype": 0,
                    "ip": "",
                    "port": "",
                    "user": "",
                    "password": "",
                    "database": "",
                },
                "apikey": "",
            }

    def on_text_changed(self):
        """Enable Save/Test only when ip, port and database are all filled in."""
        required = (
            self.ui.ipInput.text(),
            self.ui.portInput.text(),
            self.ui.databaseInput.text(),
        )
        enabled = all(value != "" for value in required)
        self.ui.saveButton.setEnabled(enabled)
        self.ui.testButton.setEnabled(enabled)

    def on_save_button_clicked(self):
        """Write the form values into the config file and close the dialog."""
        self.dbtype = self.ui.dbtypeCombo.currentIndex()
        self.dbip = self.ui.ipInput.text()
        self.dbport = self.ui.portInput.text()
        self.dbuser = self.ui.usernameInput.text()
        self.dbpass = self.ui.passwordInput.text()
        self.dbname = self.ui.databaseInput.text()
        database = self.config["database"]
        database["dbtype"] = self.dbtype
        database["ip"] = self.dbip
        database["port"] = self.dbport
        database["user"] = self.dbuser
        database["password"] = self.dbpass
        database["database"] = self.dbname
        # The folder may not exist yet on first run.
        os.makedirs(configfolder, exist_ok=True)
        with open(os.path.join(configfolder, configfile), "w") as f:
            json.dump(self.config, f, indent=4)
        self.close()

    def start_db_test_thread(self):
        """Test the entered settings on a background thread."""
        if self._testing:
            # Starting a second test would replace the still-running thread.
            return
        self._testing = True
        self.ui.returnLabel.setText("Testing...")
        self.ui.returnLabel.setStyleSheet("color: black;")
        self.ui.testButton.setEnabled(False)
        settings = (
            self.ui.dbtypeCombo.currentIndex(),
            self.ui.ipInput.text(),
            self.ui.portInput.text(),
            self.ui.usernameInput.text(),
            self.ui.passwordInput.text(),
            self.ui.databaseInput.text(),
        )
        self.test_thread = QThread()
        self.test_worker = Worker()
        self.test_worker.moveToThread(self.test_thread)
        # functools.partial (not a lambda) so PyQt5 runs the test in the
        # worker's thread instead of the GUI thread.
        self.test_thread.started.connect(
            partial(self.test_worker.test_db_connection, *settings)
        )
        self.test_worker.finished.connect(self.test_thread.quit)
        self.test_thread.finished.connect(self.thread_complete)
        self.test_thread.start()

    def thread_complete(self):
        """Show the test outcome, re-enable the buttons and release the worker."""
        self._testing = False
        if self.test_worker.test:
            self.ui.returnLabel.setText("Connection Success!")
            self.ui.returnLabel.setStyleSheet("color: green;")
        else:
            self.ui.returnLabel.setText("Connection Failed!")
            self.ui.returnLabel.setStyleSheet("color: red;")
        # Re-enable Test (only if the required fields are still filled in).
        self.on_text_changed()
        self.test_worker.deleteLater()
        self.test_thread.deleteLater()
### Api Key Window ###
class ApiKeyWindow(QDialog):
    """Dialog for entering, testing and saving the API key passed to ``aisql.AI``."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.ui = ApiKeyForm()
        self.ui.setupUi(self)
        self.ui.outputLabel.setText("")
        # True while a background key test is running.
        self._testing = False
        # Load API key from file (blank defaults if it does not exist yet).
        self.config = self._read_config()
        self.apikey = self.config.get("apikey", "")
        self.ui.apikeyInput.setText(self.apikey)
        # Save/Test need a non-empty key: apply that rule now and on every edit.
        self.on_text_changed()
        self.ui.apikeyInput.textChanged.connect(self.on_text_changed)
        self.ui.saveButton.clicked.connect(self.on_save_button_clicked)
        self.ui.testButton.clicked.connect(self.start_api_test_thread)

    @staticmethod
    def _read_config():
        """Return the saved config dict, or blank defaults if no file exists yet."""
        try:
            with open(os.path.join(configfolder, configfile), "r") as f:
                return json.load(f)
        except FileNotFoundError:
            return {
                "database": {
                    "dbtype": 0,
                    "ip": "",
                    "port": "",
                    "user": "",
                    "password": "",
                    "database": "",
                },
                "apikey": "",
            }

    def on_text_changed(self):
        """Enable Save/Test only when a key has been entered."""
        enabled = self.ui.apikeyInput.text() != ""
        self.ui.saveButton.setEnabled(enabled)
        self.ui.testButton.setEnabled(enabled)

    def on_save_button_clicked(self):
        """Write the entered key into the config file and close the dialog."""
        self.apikey = self.ui.apikeyInput.text()
        self.config["apikey"] = self.apikey
        # The folder may not exist yet on first run.
        os.makedirs(configfolder, exist_ok=True)
        with open(os.path.join(configfolder, configfile), "w") as f:
            # indent=4 matches the other config writers in this file.
            json.dump(self.config, f, indent=4)
        self.close()

    def start_api_test_thread(self):
        """Test the entered API key on a background thread."""
        if self._testing:
            # Starting a second test would replace the still-running thread.
            return
        self._testing = True
        self.ui.outputLabel.setText("Testing...")
        self.ui.outputLabel.setStyleSheet("color: black;")
        self.ui.testButton.setEnabled(False)
        apikey = self.ui.apikeyInput.text()
        self.test_thread = QThread()
        self.test_worker = Worker()
        self.test_worker.moveToThread(self.test_thread)
        # functools.partial (not a lambda) so PyQt5 runs the test in the
        # worker's thread instead of the GUI thread.
        self.test_thread.started.connect(
            partial(self.test_worker.test_api_connection, apikey)
        )
        self.test_worker.finished.connect(self.test_thread.quit)
        self.test_thread.finished.connect(self.thread_complete)
        self.test_thread.start()

    def thread_complete(self):
        """Show the test outcome, re-enable the Test button and release the worker."""
        self._testing = False
        if self.test_worker.test:
            self.ui.outputLabel.setText("Connection Success!")
            self.ui.outputLabel.setStyleSheet("color: green;")
        else:
            self.ui.outputLabel.setText("Connection Failed!")
            self.ui.outputLabel.setStyleSheet("color: red;")
        # Re-enable the Test button (disabled when the test started), as long
        # as a key is still entered.
        self.on_text_changed()
        self.test_worker.deleteLater()
        self.test_thread.deleteLater()
if __name__ == "__main__":
    # Build the Qt application and main window, then hand the event loop's
    # exit code to the process.
    qt_app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)