# openai-sql/gui/main.py
# crennis 07410e66ff Fixed Test Connection button,
# Implemented Threading
# 2023-05-05 11:18:17 +02:00
#
# 326 lines
# 11 KiB
# Python

import json
import os
import sys
from PyQt5.QtCore import QObject, QThread, pyqtSignal
from PyQt5.QtWidgets import QApplication, QMainWindow, QDialog
import modules.aisql as aisql
from gui.apikey import Ui_ApiKey as ApiKeyForm
from gui.connection import Ui_Connection as ConnectionForm
from gui.gui.Window import Ui_MainWindow
# Location of the persisted settings file, as a path relative to the
# current working directory.
configfolder = "config"
configfile = "config.json"

# Maps the database-type index (the combo-box position, which is also the
# value stored under "dbtype" in the config file) to its display name.
dbtypes = {
    0: "PostgreSQL",
    1: "MySQL",
    2: "SQLite",
}
# Load DB Config from file but check if file exists and create if not
def load_config():
    """Load the stored settings, creating a default config file if none exists.

    When the config file is missing, the config folder is created if needed
    and a file with empty defaults is written; the values are then read back
    from that file, so the function returns a full tuple on first run too.

    Returns:
        tuple: ``(dbtype, ip, port, user, password, database, apikey)`` as
        stored in the config file.
    """
    config_path = os.path.join(configfolder, configfile)

    if not os.path.exists(config_path):
        # exist_ok: the folder may already be there even if the file is not.
        os.makedirs(configfolder, exist_ok=True)
        with open(config_path, "w") as f:
            json.dump({
                "database": {
                    "dbtype": 0,
                    "ip": "",
                    "port": "",
                    "user": "",
                    "password": "",
                    "database": ""
                },
                "apikey": ""
            }, f, indent=4)

    # Always read from disk so both paths (fresh and existing file) return
    # the same shape of result.
    with open(config_path, "r") as f:
        config = json.load(f)

    database = config["database"]
    return (
        database["dbtype"],
        database["ip"],
        database["port"],
        database["user"],
        database["password"],
        database["database"],
        config["apikey"],
    )
class Worker(QObject):
    """Background object that checks whether a database connection works.

    The outcome is left in ``test`` (True only after a successful check) and
    ``finished`` is emitted once the check is over, whatever the outcome.
    """

    # Emitted at the end of test_db_connection(), on success and on failure.
    finished = pyqtSignal()
    # Result flag; stays False unless the connection test succeeds.
    test = False

    def __init__(self):
        super().__init__()

    def test_db_connection(self, dbtype, dbip, dbport, dbuser, dbpass, dbname):
        """Try to connect with the given settings and record the result.

        Only database type 0 (PostgreSQL) is handled; any other type leaves
        ``test`` untouched. Exceptions are printed rather than propagated.
        """
        try:
            if dbtype == 0:
                # Imported lazily, only when the PostgreSQL backend is needed.
                import database.postgresql as pg

                self.db = pg.Postgres(dbip, dbport, dbuser, dbpass, dbname)
                connection_ok = self.db.test_connection()
                if connection_ok:
                    self.test = True
        except Exception as error:
            print(error)

        self.finished.emit()
class MainWindow(QMainWindow):
    """Main application window.

    Converts a natural-language prompt into SQL through ``aisql.AI`` and
    executes statements against the database described in the config file.
    """

    def __init__(self):
        super().__init__()
        self.ui = Ui_MainWindow()
        self.ui.setupUi(self)
        self.tableschema = []
        # Set by try_to_connect() / _refresh_config(). Defined up front so the
        # button handlers never hit a missing attribute.
        self.db = None
        self.dbtype = 0
        self.apikey = ""
        # Hide outputLabel
        self.ui.outputLabel.hide()
        # Open Connection Window
        self.ui.actionConnect_DB.triggered.connect(self.open_connection)
        # Open API Key Window
        self.ui.actionConnect_API_Key.triggered.connect(self.open_apikey)
        # Pressed Convert Button
        self.ui.convertButton.clicked.connect(self.on_convert_button_clicked)
        # Pressed Paste Button
        self.ui.pasteButton.clicked.connect(self.on_paste_button_clicked)
        # Pressed Execute Button
        self.ui.executeButton.clicked.connect(self.on_execute_button_clicked)

    def _refresh_config(self):
        """Re-read the config file and cache ``dbtype`` and ``apikey`` on self.

        Returns:
            tuple: ``(dbtype, ip, port, user, password, database)``.
        """
        dbtype, dbip, dbport, dbuser, dbpass, dbname, apikey = load_config()
        self.dbtype = dbtype
        self.apikey = apikey
        return dbtype, dbip, dbport, dbuser, dbpass, dbname

    @staticmethod
    def _parse_size(decision):
        """Return the integer found after the first ``=`` in *decision*.

        A trailing ``]`` and surrounding whitespace are removed before the
        conversion. Raises IndexError/ValueError if no such number exists.
        """
        return int(decision.split("=")[1].strip().strip("]"))

    def try_to_connect(self):
        """Connect to the configured database and load its schema.

        On success ``self.db`` and ``self.tableschema`` are set; on failure,
        or for a database type other than 0 (PostgreSQL), ``self.db`` is None.
        """
        dbtype, dbip, dbport, dbuser, dbpass, dbname = self._refresh_config()
        # Drop any earlier handle so a failed attempt cannot leave a stale one.
        self.db = None
        try:
            if dbtype == 0:
                import database.postgresql as pg
                self.db = pg.Postgres(dbip, dbport, dbuser, dbpass, dbname)
                self.tableschema = self.db.get_schema()
                print('Database Connected')
        except Exception as e:
            self.db = None
            print('No Database Connection')
            print(e)

    def open_connection(self):
        """Show the database connection dialog."""
        self.connection_window = ConnectionWindow(self)
        self.connection_window.show()

    def open_apikey(self):
        """Show the API key dialog."""
        self.apikey_window = ApiKeyWindow(self)
        self.apikey_window.show()

    def on_convert_button_clicked(self):
        """Translate the prompt in ``textInput`` to SQL and display it."""
        self.ui.outputLabel.show()
        self.ui.outputLabel.setText("Converting...")
        # Pick up the current API key and database type from the config file.
        self._refresh_config()
        print("Convert Button Clicked")
        ai = aisql.AI(self.apikey)
        prompt = self.ui.textInput.text()
        sql = ai.humantosql(prompt, dbtypes[self.dbtype], self.tableschema)
        self.ui.statementOutput.setText(sql)
        self.ui.outputLabel.setText("Converted!")

    def on_paste_button_clicked(self):
        """Copy the generated statement into the shell input."""
        self.ui.shellInput.setText(self.ui.statementOutput.toPlainText())

    def on_execute_button_clicked(self):
        """Run the statement in ``shellInput`` with the method the AI selects."""
        print("Execute Button Clicked")
        # The label starts hidden and is otherwise only shown by Convert.
        self.ui.outputLabel.show()
        # Also refreshes self.apikey / self.dbtype from the config file.
        self.try_to_connect()
        if self.db is None:
            self.ui.outputLabel.setText("No Database Connection")
            return
        ai = aisql.AI(self.apikey)
        self.ui.outputLabel.setText("Executing...")
        sql = self.ui.shellInput.toPlainText()
        # Check what method to use
        decision = ai.decide(sql).casefold()
        if "fetchall" in decision:
            fetch = self.db.fetchall(sql)
            print(fetch)
        elif "fetchone" in decision:
            fetch = self.db.fetchone(sql)
            print(fetch)
        elif "fetchmany" in decision:
            fetch = self.db.fetchmany(sql, self._parse_size(decision))
            print(fetch)
        # "executemany" must be tested before "execute": the latter is a
        # substring of the former and would otherwise always win.
        elif "executemany" in decision:
            self.db.executemany(sql, self._parse_size(decision))
        elif "execute" in decision:
            self.db.execute(sql)
        self.ui.outputLabel.setText("Executed!")
### Connection Window ###
class ConnectionWindow(QDialog):
    """Dialog for editing, saving and testing the database connection settings."""

    # Carries the connection parameters to the Worker. It is connected to a
    # bound method of the worker after the worker has been moved to its
    # thread, so emitting it queues the call into that thread instead of
    # running the connection test on the GUI thread.
    _test_requested = pyqtSignal(int, str, str, str, str, str)

    def __init__(self, parent=None):
        super().__init__(parent)
        self.ui = ConnectionForm()
        self.ui.setupUi(self)
        self.ui.returnLabel.setText("")
        # True while a connection test is in flight (see on_text_changed).
        self._test_running = False
        # Load DB Config from file
        self.config = self._read_config()
        self.dbtype = self.config["database"]["dbtype"]
        self.dbip = self.config["database"]["ip"]
        self.dbport = self.config["database"]["port"]
        self.dbuser = self.config["database"]["user"]
        self.dbpass = self.config["database"]["password"]
        self.dbname = self.config["database"]["database"]
        self.ui.dbtypeCombo.addItems(dbtypes.values())
        self.ui.dbtypeCombo.setCurrentIndex(self.dbtype)
        self.ui.ipInput.setText(self.dbip)
        self.ui.portInput.setText(self.dbport)
        self.ui.usernameInput.setText(self.dbuser)
        self.ui.passwordInput.setText(self.dbpass)
        self.ui.databaseInput.setText(self.dbname)
        # Unlock Buttons if ip, port and database is not empty
        self.ui.ipInput.textChanged.connect(self.on_text_changed)
        self.ui.portInput.textChanged.connect(self.on_text_changed)
        self.ui.databaseInput.textChanged.connect(self.on_text_changed)
        # Pressed Save Button
        self.ui.saveButton.clicked.connect(self.on_save_button_clicked)
        # Pressed Test Button
        self.ui.testButton.clicked.connect(self.start_db_test_thread)
        # The fields were filled before the signals were connected, so sync
        # the button state with the pre-filled values explicitly.
        self.on_text_changed()

    @staticmethod
    def _read_config():
        """Return the stored config, or empty defaults if no file exists yet."""
        try:
            with open(os.path.join(configfolder, configfile), "r") as f:
                return json.load(f)
        except FileNotFoundError:
            return {
                "database": {
                    "dbtype": 0,
                    "ip": "",
                    "port": "",
                    "user": "",
                    "password": "",
                    "database": "",
                },
                "apikey": "",
            }

    def on_text_changed(self):
        """Enable Save/Test only when ip, port and database are all filled in."""
        required = (self.ui.ipInput, self.ui.portInput, self.ui.databaseInput)
        filled = all(field.text() != "" for field in required)
        self.ui.saveButton.setEnabled(filled)
        # Keep Test locked during a running test so two tests cannot overlap.
        self.ui.testButton.setEnabled(filled and not self._test_running)

    def on_save_button_clicked(self):
        """Write the values from the form to the config file and close."""
        self.dbtype = self.ui.dbtypeCombo.currentIndex()
        self.dbip = self.ui.ipInput.text()
        self.dbport = self.ui.portInput.text()
        self.dbuser = self.ui.usernameInput.text()
        self.dbpass = self.ui.passwordInput.text()
        self.dbname = self.ui.databaseInput.text()
        self.config["database"]["dbtype"] = self.dbtype
        self.config["database"]["ip"] = self.dbip
        self.config["database"]["port"] = self.dbport
        self.config["database"]["user"] = self.dbuser
        self.config["database"]["password"] = self.dbpass
        self.config["database"]["database"] = self.dbname
        # The folder may not exist yet if no config was ever written.
        os.makedirs(configfolder, exist_ok=True)
        with open(os.path.join(configfolder, configfile), "w") as f:
            json.dump(self.config, f, indent=4)
        self.close()

    def start_db_test_thread(self):
        """Test the connection described by the form in a background thread."""
        self.ui.returnLabel.setText("Testing...")
        self.ui.returnLabel.setStyleSheet("color: black;")
        self._test_running = True
        self.ui.testButton.setEnabled(False)
        dbtype = self.ui.dbtypeCombo.currentIndex()
        ip = self.ui.ipInput.text()
        port = self.ui.portInput.text()
        user = self.ui.usernameInput.text()
        password = self.ui.passwordInput.text()
        database = self.ui.databaseInput.text()
        # Not stored as ``self.thread``: that name would shadow QObject.thread().
        self._test_thread = QThread()
        self.worker = Worker()
        # Move first, connect second: the connection must see the worker's
        # final thread for the call to be delivered there.
        self.worker.moveToThread(self._test_thread)
        self._test_requested.connect(self.worker.test_db_connection)
        self.worker.finished.connect(self._test_thread.quit)
        self._test_thread.finished.connect(self.thread_complete)
        self._test_thread.start()
        self._test_requested.emit(dbtype, ip, port, user, password, database)

    def thread_complete(self):
        """Show the test result and release the worker and its thread."""
        # Detach the finished worker so the next test only reaches the new one.
        self._test_requested.disconnect(self.worker.test_db_connection)
        if self.worker.test:
            self.ui.returnLabel.setText("Connection Success!")
            self.ui.returnLabel.setStyleSheet("color: green;")
        else:
            self.ui.returnLabel.setText("Connection Failed!")
            self.ui.returnLabel.setStyleSheet("color: red;")
        self._test_running = False
        # Re-enable Test only if the required fields are still filled in.
        self.on_text_changed()
        self.worker.deleteLater()
        self._test_thread.deleteLater()
### Api Key Window ###
class ApiKeyWindow(QDialog):
    """Dialog for entering, saving and testing the API key."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.ui = ApiKeyForm()
        self.ui.setupUi(self)
        self.ui.outputLabel.setText("")
        # Load API Key from file
        self.config = self._read_config()
        self.apikey = self.config["apikey"]
        self.ui.apikeyInput.setText(self.apikey)
        # Unlock Buttons if text is entered
        self.ui.apikeyInput.textChanged.connect(self.on_text_changed)
        # Pressed Save Button
        self.ui.saveButton.clicked.connect(self.on_save_button_clicked)
        # Pressed Test Button
        self.ui.testButton.clicked.connect(self.on_test_button_clicked)
        # Set the initial button state from the loaded key.
        self.on_text_changed()

    @staticmethod
    def _read_config():
        """Return the stored config, or empty defaults if no file exists yet."""
        try:
            with open(os.path.join(configfolder, configfile), "r") as f:
                return json.load(f)
        except FileNotFoundError:
            return {
                "database": {
                    "dbtype": 0,
                    "ip": "",
                    "port": "",
                    "user": "",
                    "password": "",
                    "database": "",
                },
                "apikey": "",
            }

    def on_text_changed(self):
        """Enable Save/Test only while the API key field is not empty."""
        has_key = self.ui.apikeyInput.text() != ""
        self.ui.saveButton.setEnabled(has_key)
        self.ui.testButton.setEnabled(has_key)

    def on_save_button_clicked(self):
        """Store the entered API key in the config file and close the dialog."""
        self.apikey = self.ui.apikeyInput.text()
        # Update the data before opening the file: "w" truncates on open.
        self.config["apikey"] = self.apikey
        # The folder may not exist yet if no config was ever written.
        os.makedirs(configfolder, exist_ok=True)
        with open(os.path.join(configfolder, configfile), "w") as f:
            # indent=4 matches the other places that write this file.
            json.dump(self.config, f, indent=4)
        # Close Window
        self.close()

    def on_test_button_clicked(self):
        """Check the entered key with ``aisql.AI.test_key`` and show the result."""
        test_key = self.ui.apikeyInput.text()
        ai = aisql.AI(test_key)
        if ai.test_key():
            self.ui.outputLabel.setText("API Key is valid")
            self.ui.outputLabel.setStyleSheet("color: green")
        else:
            self.ui.outputLabel.setText("API Key is invalid")
            self.ui.outputLabel.setStyleSheet("color: red")
if __name__ == "__main__":
    # Script entry point: build the Qt application, show the main window and
    # hand the event loop's return code back to the shell.
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)