Skip to content

Commit 55c00fd

Browse files
authored
Merge pull request #153 from DCC-Lab/feature/postgresql-database
Add PostgresqlDatabase class
2 parents ea08827 + 56c297d commit 55c00fd

File tree

3 files changed

+443
-1
lines changed

3 files changed

+443
-1
lines changed

dcclab/database/database.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ class Column(NamedTuple):
8686
class Engine(Enum):
8787
mysql = "mysql"
8888
sqlite3 = "sqlite3"
89+
postgresql = "postgresql"
8990

9091
class Database:
9192
def __init__(self, databaseURL, usePassword=True, writePermission=False):
@@ -785,6 +786,260 @@ def insert(self, table: str, entry: dict):
785786
table, keys, values)
786787
self.execute(statement)
787788

789+
class PostgresqlDatabase:
790+
def __init__(self, databaseURL, usePassword=True, writePermission=False):
791+
self.writePermission = writePermission
792+
self.databaseURL = databaseURL
793+
self.connection = None
794+
self.cursor = None
795+
self.databaseEngine = None
796+
797+
self.database = None
798+
self.port = None
799+
self.pgPassword = None
800+
self.usePassword = usePassword
801+
self.server = None
802+
803+
self.databaseEngine, self.sshUser, self.sshHost, self.pgHost, self.port, self.pgUser, self.pgPassword, self.database = self.parseURL(databaseURL)
804+
805+
self.connect()
806+
807+
def showDatabaseInfo(self):
808+
pass
809+
810+
def parseURL(self, url):
811+
# Standard URL formats:
812+
# postgresql://user[:password]@host[:port]/database
813+
# postgresql+ssh://sshuser@sshhost/pguser[:password]@pghost[:port]/database
814+
parsed = parse.urlparse(url)
815+
816+
if parsed.scheme in ('postgresql', 'postgres'):
817+
engine = Engine.postgresql
818+
pgUser = parsed.username
819+
pgPassword = parsed.password
820+
pgHost = parsed.hostname
821+
pgPort = parsed.port or 5432
822+
database = parsed.path.lstrip('/')
823+
if not pgUser or not pgHost or not database:
824+
raise ValueError("Incomplete postgresql URL: {0}. Use postgresql://user@host[:port]/database".format(url))
825+
return (engine, None, None, pgHost, pgPort, pgUser, pgPassword, database)
826+
827+
if parsed.scheme in ('postgresql+ssh', 'postgres+ssh'):
828+
engine = Engine.postgresql
829+
sshUser = parsed.username
830+
sshHost = parsed.hostname
831+
# Path contains: /pguser[:password]@pghost[:port]/database
832+
match = re.match(r'^/([^:@]+)(?::([^@]*))?@([^:/]+)(?::(\d+))?/(.+)$', parsed.path)
833+
if match is None:
834+
raise ValueError("Incomplete postgresql+ssh URL: {0}. Use postgresql+ssh://sshuser@sshhost/pguser@pghost/database".format(url))
835+
pgUser = match.group(1)
836+
pgPassword = match.group(2)
837+
pgHost = match.group(3)
838+
pgPort = int(match.group(4)) if match.group(4) else 5432
839+
database = match.group(5)
840+
return (engine, sshUser, sshHost, pgHost, pgPort, pgUser, pgPassword, database)
841+
842+
raise ValueError("Unrecognized URL scheme '{0}' in: {1}. Use postgresql://user@host[:port]/database or postgresql+ssh://sshuser@sshhost/pguser@pghost/database".format(parsed.scheme, url))
843+
844+
def __enter__(self):
845+
return self
846+
847+
def __exit__(self, exc_type, exc_val, exc_tb):
848+
self.disconnect()
849+
850+
def connect(self):
851+
try:
852+
if not self.isConnected:
853+
import psycopg2
854+
import psycopg2.extras
855+
856+
if self.pgPassword is not None:
857+
pwd = self.pgPassword
858+
elif self.usePassword is True:
859+
import keyring
860+
861+
if self.sshHost is not None:
862+
serviceName = "postgresql-{0}-ssh-{1}".format(self.pgHost, self.sshHost)
863+
else:
864+
serviceName = "postgresql-{0}".format(self.pgHost)
865+
866+
pwd = keyring.get_password(serviceName, self.pgUser)
867+
if pwd is None and self.sshHost is not None:
868+
pwd = keyring.get_password("postgresql-{0}".format(self.pgHost), self.pgUser)
869+
if pwd is None:
870+
raise Exception(""" Set the password in the system password manager on the command line with:
871+
{2} -m keyring set {0} {1}
872+
Or provide the password in the URL: postgresql://user:password@host/database""".format(serviceName, self.pgUser, sys.executable))
873+
else:
874+
pwd = None
875+
876+
actualPgHost = self.pgHost
877+
if self.sshHost is not None:
878+
from dcclab.utils import Cafeine
879+
self.server = Cafeine(username=self.sshUser)
880+
self.port = self.server.startTunnel(ssh_host=self.sshHost, remote_bind_address=self.pgHost, remote_port=self.port)
881+
actualPgHost = "127.0.0.1"
882+
883+
self.connection = psycopg2.connect(
884+
host=actualPgHost,
885+
port=self.port,
886+
dbname=self.database,
887+
user=self.pgUser,
888+
password=pwd
889+
)
890+
891+
self.cursor = self.connection.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
892+
893+
except Exception as err:
894+
if self.connection is not None:
895+
self.connection.close()
896+
self.cursor = None
897+
if self.server is not None:
898+
self.server.stopMySQLTunnel()
899+
raise(err)
900+
901+
def disconnect(self):
902+
if self.isConnected:
903+
self.commit()
904+
self.connection.close()
905+
self.connection = None
906+
self.cursor = None
907+
908+
def enforceForeignKeys(self):
909+
pass # Foreign keys are always enforced in PostgreSQL
910+
911+
def disableForeignKeys(self):
912+
self.execute("SET session_replication_role = 'replica'")
913+
914+
def beginTransaction(self):
915+
if self.isConnected:
916+
self.execute('BEGIN')
917+
918+
def endTransaction(self):
919+
if self.isConnected:
920+
self.execute('COMMIT')
921+
922+
def rollbackTransaction(self):
923+
if self.isConnected:
924+
self.execute('ROLLBACK')
925+
926+
@property
927+
def isConnected(self):
928+
return self.connection is not None
929+
930+
def commit(self):
931+
if self.isConnected:
932+
self.connection.commit()
933+
934+
def rollback(self):
935+
if self.isConnected:
936+
self.connection.rollback()
937+
938+
def execute(self, statement, bindings=None):
939+
if self.isConnected:
940+
try:
941+
if bindings is None:
942+
self.cursor.execute(statement)
943+
else:
944+
self.cursor.execute(statement, bindings)
945+
946+
return statement
947+
948+
except Exception as err:
949+
raise(err)
950+
951+
return None
952+
953+
def executeSelectOne(self, statement, bindings=None):
954+
self.execute(statement, bindings)
955+
singleRecord = self.fetchOne()
956+
keys = list(singleRecord.keys())
957+
if len(keys) == 1:
958+
return singleRecord[keys[0]]
959+
else:
960+
return None
961+
962+
def executeSelectFetchInt(self, statement, bindings=None):
963+
return int(self.executeSelectOne(statement, bindings))
964+
965+
def executeSelectFetchOneRow(self, statement, bindings = None):
966+
self.execute(statement, bindings)
967+
return dict(self.fetchOne())
968+
969+
def executeSelectFetchOneField(self, statement, bindings = None):
970+
self.execute(statement, bindings)
971+
rows = self.fetchAll()
972+
values = []
973+
974+
for row in rows:
975+
value = list(row.values())[0]
976+
values.append(value)
977+
978+
return values
979+
980+
def fetchAll(self):
981+
if self.isConnected:
982+
return self.cursor.fetchall()
983+
984+
def fetchOne(self):
985+
if self.isConnected:
986+
return self.cursor.fetchone()
987+
988+
@property
989+
def tables(self) -> list:
990+
self.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name")
991+
rows = self.fetchAll()
992+
return [row['table_name'] for row in rows]
993+
994+
def columns(self, table) -> list:
995+
self.execute('SELECT * FROM "{}" LIMIT 0'.format(table))
996+
columns = [description[0] for description in self.cursor.description]
997+
return columns
998+
999+
def select(self, table, columns='*', condition=None):
1000+
if condition is None:
1001+
self.execute("SELECT {0} FROM {1}".format(columns, table))
1002+
else:
1003+
self.execute("SELECT {0} FROM {1} WHERE {2}".format(
1004+
columns, table, condition))
1005+
return self.fetchAll()
1006+
1007+
def createTable(self, metadata: dict):
1008+
if self.isConnected:
1009+
for table, keys in metadata.items():
1010+
statement = 'CREATE TABLE IF NOT EXISTS "{}" ('.format(table)
1011+
attributes = []
1012+
for key, keyType in keys.items():
1013+
attributes.append('{} {}'.format(key, keyType))
1014+
statement += ",".join(attributes) + ")"
1015+
self.execute(statement)
1016+
1017+
def createSimpleTable(self, name, columns):
1018+
if self.isConnected:
1019+
statement = f'CREATE TABLE IF NOT EXISTS "{name}" '
1020+
1021+
colStatements = []
1022+
for c in columns:
1023+
colStatements.append(f"{c.name} {c.type.value} {c.constraint.value}")
1024+
1025+
statement += '(' + ','.join(colStatements) + ')'
1026+
self.execute(statement)
1027+
1028+
def dropTable(self, table: str):
1029+
if self.isConnected:
1030+
statement = 'DROP TABLE IF EXISTS "{}"'.format(table)
1031+
self.execute(statement)
1032+
1033+
def insert(self, table: str, entry: dict):
1034+
if self.isConnected:
1035+
keys = ', '.join('"{}"'.format(k) for k in entry.keys())
1036+
placeholders = ', '.join(['%s'] * len(entry))
1037+
values = list(entry.values())
1038+
statement = 'INSERT INTO "{}" ({}) VALUES ({})'.format(
1039+
table, keys, placeholders)
1040+
self.execute(statement, values)
1041+
1042+
7881043
if __name__ == "__main__":
7891044
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
7901045
db = Database("mysql+ssh://dcclab@cafeine3.crulrg.ulaval.ca/dcclab@cafeine3.crulrg.ulaval.ca/labdata")

0 commit comments

Comments
 (0)