diff --git a/gui/main.py b/gui/main.py index 52b7fcc..59cb62c 100644 --- a/gui/main.py +++ b/gui/main.py @@ -21,7 +21,7 @@ dbtypes = { } # Load DB Config from file but check if file exists and create if not -def load_config(): +def load_config() -> tuple: if not os.path.exists(os.path.join(configfolder, configfile)): os.makedirs(configfolder) with open(os.path.join(configfolder, configfile), "w") as f: @@ -49,27 +49,91 @@ def load_config(): return type, ip, port, user, password, name, apikey +def save_config(dbtype, dbip, dbport, dbuser, dbpass, dbname, apikey) -> None: + with open(os.path.join(configfolder, configfile), "w") as f: + json.dump({ + "database": { + "dbtype": dbtype, + "ip": dbip, + "port": dbport, + "user": dbuser, + "password": dbpass, + "database": dbname + }, + "apikey": apikey + }, f, indent=4) + +def connect_db() -> tuple: + db = None + tableschema = [] + dbtype, dbip, dbport, dbuser, dbpass, dbname, apikey = load_config() + try: + if dbtype == 0: + import database.postgresql as pg + db = pg.Postgres(dbip, dbport, dbuser, dbpass, dbname) + tableschema = db.get_schema() + + print('Database Connected') + return db, tableschema + except Exception as e: + print('No Database Connection') + print(e) + return db, tableschema + +def test_connection(dbtype, dbip, dbport, dbuser, dbpass, dbname) -> bool: + db = None + try: + if dbtype == 0: + import database.postgresql as pg + db = pg.Postgres(dbip, dbport, dbuser, dbpass, dbname) + if db.test_connection(): + return True + except Exception as e: + print(e) + return False + class Worker(QObject): finished = pyqtSignal() test = False + db = None + tableschema = [] def __init__(self): super().__init__() def test_db_connection(self, dbtype, dbip, dbport, dbuser, dbpass, dbname): + connection = test_connection(dbtype, dbip, dbport, dbuser, dbpass, dbname) + if connection: + self.test = True + self.finished.emit() + def test_api_connection(self, apikey): try: - if dbtype == 0: - import database.postgresql as pg - self.db = pg.Postgres(dbip, dbport, dbuser, dbpass, dbname) - - if self.db.test_connection(): + ai = aisql.AI(apikey) + if ai.test_key(): self.test = True except Exception as e: print(e) + self.finished.emit() + def connect_db_worker(self): + self.db, self.tableschema = connect_db() + + + def testing(self): + self.connect_db_worker() + print('test') + self.finished.emit() + +# TODO: Add a translation Worker and a SQL Worker + def translate(self, text): + print("Translating...") + self.finished.emit() + + def run_sql(self, sql): + print("Running SQL...") self.finished.emit() class MainWindow(QMainWindow): @@ -98,13 +162,9 @@ class MainWindow(QMainWindow): self.ui.executeButton.clicked.connect(self.on_execute_button_clicked) def try_to_connect(self): - dbtype, dbip, dbport, dbuser, dbpass, dbname, apikey = load_config() + # TODO: Rewrite to use a Worker try: - if dbtype == 0: - import database.postgresql as pg - self.db = pg.Postgres(dbip, dbport, dbuser, dbpass, dbname) - self.tableschema = self.db.get_schema() - + self.db = connect_db() print('Database Connected') except Exception as e: print('No Database Connection') @@ -113,50 +173,78 @@ class MainWindow(QMainWindow): self.connection_window = ConnectionWindow(self) self.connection_window.show() + def on_paste_button_clicked(self): + self.ui.shellInput.setText(self.ui.statementOutput.toPlainText()) def open_apikey(self): self.apikey_window = ApiKeyWindow(self) self.apikey_window.show() - def on_convert_button_clicked(self): - self.ui.outputLabel.show() - self.ui.outputLabel.setText("Converting...") - self.load_config() - print("Convert Button Clicked") - ai = aisql.AI(self.apikey) - prompt = self.ui.textInput.text() - sql = ai.humantosql(prompt, dbtypes[self.dbtype], self.tableschema) - self.ui.statementOutput.setText(sql) - self.ui.outputLabel.setText("Converted!") - - def on_paste_button_clicked(self): - self.ui.shellInput.setText(self.ui.statementOutput.toPlainText()) - def on_execute_button_clicked(self): - print("Execute Button Clicked") - self.try_to_connect() - ai = aisql.AI(self.apikey) self.ui.outputLabel.setText("Executing...") - self.load_config() - sql = self.ui.shellInput.toPlainText() - # Check what method to use - decision = ai.decide(sql) - if "fetchall".casefold() in decision.casefold(): - fetch = self.db.fetchall(sql) - print(fetch) - elif "fetchone".casefold() in decision.casefold(): - fetch = self.db.fetchone(sql) - print(fetch) - elif "fetchmany".casefold() in decision.casefold(): - size = decision.split("=")[1].strip("]") - fetch = self.db.fetchmany(sql, int(size)) - print(fetch) - elif "execute".casefold() in decision.casefold(): - self.db.execute(sql) - elif "executemany".casefold() in decision.casefold(): - size = decision.split("=")[1].strip("]") - self.db.executemany(sql, int(size)) + self.ui.outputLabel.show() + self.ui.executeButton.setEnabled(False) + + sql = self.ui.shellInput.toPlainText() + + self.thread = QThread() + self.worker = Worker() + self.worker.moveToThread(self.thread) + + self.thread.started.connect(self.worker.testing) + self.worker.finished.connect(self.execute_finish) + + self.thread.start() + + def execute_finish(self): + self.ui.outputLabel.setText("Finished!") + self.ui.outputLabel.show() + self.ui.executeButton.setEnabled(True) + print("finished") + + self.worker.deleteLater() + self.thread.deleteLater() + print("deleted") + + def on_convert_button_clicked(self): + self.ui.outputLabel.setText("Converting...") + self.ui.outputLabel.show() + self.ui.convertButton.setEnabled(False) + + # def start_db_test_thread(self): + # self.ui.returnLabel.setText("Testing...") + # self.ui.returnLabel.setStyleSheet("color: black;") + # self.ui.testButton.setEnabled(False) + # + # dbtype = self.ui.dbtypeCombo.currentIndex() + # ip = self.ui.ipInput.text() + # port = self.ui.portInput.text() + # user = self.ui.usernameInput.text() + # password = self.ui.passwordInput.text() + # database = self.ui.databaseInput.text() + # + # self.thread = QThread() + # self.worker = Worker() + # self.worker.moveToThread(self.thread) + # + # self.thread.started.connect( + # lambda: self.worker.test_db_connection(dbtype, ip, port, user, password, database)) + # self.worker.finished.connect(self.thread.quit) + # self.thread.finished.connect(self.thread_complete) + # + # self.thread.start() + # + # def thread_complete(self): + # if self.worker.test: + # self.ui.returnLabel.setText("Connection Success!") + # self.ui.returnLabel.setStyleSheet("color: green;") + # else: + # self.ui.returnLabel.setText("Connection Failed!") + # self.ui.returnLabel.setStyleSheet("color: red;") + # + # self.ui.testButton.setEnabled(True) + # self.worker.deleteLater() + # self.thread.deleteLater() - self.ui.outputLabel.setText("Executed!") ### Connection Window ### class ConnectionWindow(QDialog): @@ -288,7 +376,7 @@ class ApiKeyWindow(QDialog): self.ui.saveButton.clicked.connect(self.on_save_button_clicked) # Pressed Test Button - self.ui.testButton.clicked.connect(self.on_test_button_clicked) + self.ui.testButton.clicked.connect(self.start_api_test_thread) def on_text_changed(self): if self.ui.apikeyInput.text() == "": @@ -308,15 +396,36 @@ class ApiKeyWindow(QDialog): # Close Window self.close() - def on_test_button_clicked(self): - test_key = self.ui.apikeyInput.text() - ai = aisql.AI(test_key) - if ai.test_key(): - self.ui.outputLabel.setText("API Key is valid") - self.ui.outputLabel.setStyleSheet("color: green") + def start_api_test_thread(self): + self.ui.outputLabel.setText("Testing...") + self.ui.outputLabel.setStyleSheet("color: black;") + self.ui.testButton.setEnabled(False) + + apikey = self.ui.apikeyInput.text() + + self.thread = QThread() + self.worker = Worker() + self.worker.moveToThread(self.thread) + + self.thread.started.connect(lambda: self.worker.test_api_connection(apikey)) + self.worker.finished.connect(self.thread.quit) + self.thread.finished.connect(self.thread_complete) + + self.thread.start() + + def thread_complete(self): + if self.worker.test: + self.ui.outputLabel.setText("Connection Success!") + self.ui.outputLabel.setStyleSheet("color: green;") else: - self.ui.outputLabel.setText("API Key is invalid") - self.ui.outputLabel.setStyleSheet("color: red") + self.ui.outputLabel.setText("Connection Failed!") + self.ui.outputLabel.setStyleSheet("color: red;") + + self.ui.outputLabel.setEnabled(True) + self.worker.deleteLater() + self.thread.deleteLater() + + if __name__ == "__main__": app = QApplication(sys.argv)