@@ -86,6 +86,7 @@ class Column(NamedTuple):
8686class Engine (Enum ):
8787 mysql = "mysql"
8888 sqlite3 = "sqlite3"
89+ postgresql = "postgresql"
8990
9091class Database :
9192 def __init__ (self , databaseURL , usePassword = True , writePermission = False ):
@@ -785,6 +786,260 @@ def insert(self, table: str, entry: dict):
785786 table , keys , values )
786787 self .execute (statement )
787788
789+ class PostgresqlDatabase :
790+ def __init__ (self , databaseURL , usePassword = True , writePermission = False ):
791+ self .writePermission = writePermission
792+ self .databaseURL = databaseURL
793+ self .connection = None
794+ self .cursor = None
795+ self .databaseEngine = None
796+
797+ self .database = None
798+ self .port = None
799+ self .pgPassword = None
800+ self .usePassword = usePassword
801+ self .server = None
802+
803+ self .databaseEngine , self .sshUser , self .sshHost , self .pgHost , self .port , self .pgUser , self .pgPassword , self .database = self .parseURL (databaseURL )
804+
805+ self .connect ()
806+
807+ def showDatabaseInfo (self ):
808+ pass
809+
810+ def parseURL (self , url ):
811+ # Standard URL formats:
812+ # postgresql://user[:password]@host[:port]/database
813+ # postgresql+ssh://sshuser@sshhost/pguser[:password]@pghost[:port]/database
814+ parsed = parse .urlparse (url )
815+
816+ if parsed .scheme in ('postgresql' , 'postgres' ):
817+ engine = Engine .postgresql
818+ pgUser = parsed .username
819+ pgPassword = parsed .password
820+ pgHost = parsed .hostname
821+ pgPort = parsed .port or 5432
822+ database = parsed .path .lstrip ('/' )
823+ if not pgUser or not pgHost or not database :
824+ raise ValueError ("Incomplete postgresql URL: {0}. Use postgresql://user@host[:port]/database" .format (url ))
825+ return (engine , None , None , pgHost , pgPort , pgUser , pgPassword , database )
826+
827+ if parsed .scheme in ('postgresql+ssh' , 'postgres+ssh' ):
828+ engine = Engine .postgresql
829+ sshUser = parsed .username
830+ sshHost = parsed .hostname
831+ # Path contains: /pguser[:password]@pghost[:port]/database
832+ match = re .match (r'^/([^:@]+)(?::([^@]*))?@([^:/]+)(?::(\d+))?/(.+)$' , parsed .path )
833+ if match is None :
834+ raise ValueError ("Incomplete postgresql+ssh URL: {0}. Use postgresql+ssh://sshuser@sshhost/pguser@pghost/database" .format (url ))
835+ pgUser = match .group (1 )
836+ pgPassword = match .group (2 )
837+ pgHost = match .group (3 )
838+ pgPort = int (match .group (4 )) if match .group (4 ) else 5432
839+ database = match .group (5 )
840+ return (engine , sshUser , sshHost , pgHost , pgPort , pgUser , pgPassword , database )
841+
842+ raise ValueError ("Unrecognized URL scheme '{0}' in: {1}. Use postgresql://user@host[:port]/database or postgresql+ssh://sshuser@sshhost/pguser@pghost/database" .format (parsed .scheme , url ))
843+
844+ def __enter__ (self ):
845+ return self
846+
847+ def __exit__ (self , exc_type , exc_val , exc_tb ):
848+ self .disconnect ()
849+
850+ def connect (self ):
851+ try :
852+ if not self .isConnected :
853+ import psycopg2
854+ import psycopg2 .extras
855+
856+ if self .pgPassword is not None :
857+ pwd = self .pgPassword
858+ elif self .usePassword is True :
859+ import keyring
860+
861+ if self .sshHost is not None :
862+ serviceName = "postgresql-{0}-ssh-{1}" .format (self .pgHost , self .sshHost )
863+ else :
864+ serviceName = "postgresql-{0}" .format (self .pgHost )
865+
866+ pwd = keyring .get_password (serviceName , self .pgUser )
867+ if pwd is None and self .sshHost is not None :
868+ pwd = keyring .get_password ("postgresql-{0}" .format (self .pgHost ), self .pgUser )
869+ if pwd is None :
870+ raise Exception (""" Set the password in the system password manager on the command line with:
871+ {2} -m keyring set {0} {1}
872+ Or provide the password in the URL: postgresql://user:password@host/database""" .format (serviceName , self .pgUser , sys .executable ))
873+ else :
874+ pwd = None
875+
876+ actualPgHost = self .pgHost
877+ if self .sshHost is not None :
878+ from dcclab .utils import Cafeine
879+ self .server = Cafeine (username = self .sshUser )
880+ self .port = self .server .startTunnel (ssh_host = self .sshHost , remote_bind_address = self .pgHost , remote_port = self .port )
881+ actualPgHost = "127.0.0.1"
882+
883+ self .connection = psycopg2 .connect (
884+ host = actualPgHost ,
885+ port = self .port ,
886+ dbname = self .database ,
887+ user = self .pgUser ,
888+ password = pwd
889+ )
890+
891+ self .cursor = self .connection .cursor (cursor_factory = psycopg2 .extras .RealDictCursor )
892+
893+ except Exception as err :
894+ if self .connection is not None :
895+ self .connection .close ()
896+ self .cursor = None
897+ if self .server is not None :
898+ self .server .stopMySQLTunnel ()
899+ raise (err )
900+
901+ def disconnect (self ):
902+ if self .isConnected :
903+ self .commit ()
904+ self .connection .close ()
905+ self .connection = None
906+ self .cursor = None
907+
908+ def enforceForeignKeys (self ):
909+ pass # Foreign keys are always enforced in PostgreSQL
910+
911+ def disableForeignKeys (self ):
912+ self .execute ("SET session_replication_role = 'replica'" )
913+
914+ def beginTransaction (self ):
915+ if self .isConnected :
916+ self .execute ('BEGIN' )
917+
918+ def endTransaction (self ):
919+ if self .isConnected :
920+ self .execute ('COMMIT' )
921+
922+ def rollbackTransaction (self ):
923+ if self .isConnected :
924+ self .execute ('ROLLBACK' )
925+
926+ @property
927+ def isConnected (self ):
928+ return self .connection is not None
929+
930+ def commit (self ):
931+ if self .isConnected :
932+ self .connection .commit ()
933+
934+ def rollback (self ):
935+ if self .isConnected :
936+ self .connection .rollback ()
937+
938+ def execute (self , statement , bindings = None ):
939+ if self .isConnected :
940+ try :
941+ if bindings is None :
942+ self .cursor .execute (statement )
943+ else :
944+ self .cursor .execute (statement , bindings )
945+
946+ return statement
947+
948+ except Exception as err :
949+ raise (err )
950+
951+ return None
952+
953+ def executeSelectOne (self , statement , bindings = None ):
954+ self .execute (statement , bindings )
955+ singleRecord = self .fetchOne ()
956+ keys = list (singleRecord .keys ())
957+ if len (keys ) == 1 :
958+ return singleRecord [keys [0 ]]
959+ else :
960+ return None
961+
962+ def executeSelectFetchInt (self , statement , bindings = None ):
963+ return int (self .executeSelectOne (statement , bindings ))
964+
965+ def executeSelectFetchOneRow (self , statement , bindings = None ):
966+ self .execute (statement , bindings )
967+ return dict (self .fetchOne ())
968+
969+ def executeSelectFetchOneField (self , statement , bindings = None ):
970+ self .execute (statement , bindings )
971+ rows = self .fetchAll ()
972+ values = []
973+
974+ for row in rows :
975+ value = list (row .values ())[0 ]
976+ values .append (value )
977+
978+ return values
979+
980+ def fetchAll (self ):
981+ if self .isConnected :
982+ return self .cursor .fetchall ()
983+
984+ def fetchOne (self ):
985+ if self .isConnected :
986+ return self .cursor .fetchone ()
987+
988+ @property
989+ def tables (self ) -> list :
990+ self .execute ("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name" )
991+ rows = self .fetchAll ()
992+ return [row ['table_name' ] for row in rows ]
993+
994+ def columns (self , table ) -> list :
995+ self .execute ('SELECT * FROM "{}" LIMIT 0' .format (table ))
996+ columns = [description [0 ] for description in self .cursor .description ]
997+ return columns
998+
999+ def select (self , table , columns = '*' , condition = None ):
1000+ if condition is None :
1001+ self .execute ("SELECT {0} FROM {1}" .format (columns , table ))
1002+ else :
1003+ self .execute ("SELECT {0} FROM {1} WHERE {2}" .format (
1004+ columns , table , condition ))
1005+ return self .fetchAll ()
1006+
1007+ def createTable (self , metadata : dict ):
1008+ if self .isConnected :
1009+ for table , keys in metadata .items ():
1010+ statement = 'CREATE TABLE IF NOT EXISTS "{}" (' .format (table )
1011+ attributes = []
1012+ for key , keyType in keys .items ():
1013+ attributes .append ('{} {}' .format (key , keyType ))
1014+ statement += "," .join (attributes ) + ")"
1015+ self .execute (statement )
1016+
1017+ def createSimpleTable (self , name , columns ):
1018+ if self .isConnected :
1019+ statement = f'CREATE TABLE IF NOT EXISTS "{ name } " '
1020+
1021+ colStatements = []
1022+ for c in columns :
1023+ colStatements .append (f"{ c .name } { c .type .value } { c .constraint .value } " )
1024+
1025+ statement += '(' + ',' .join (colStatements ) + ')'
1026+ self .execute (statement )
1027+
1028+ def dropTable (self , table : str ):
1029+ if self .isConnected :
1030+ statement = 'DROP TABLE IF EXISTS "{}"' .format (table )
1031+ self .execute (statement )
1032+
1033+ def insert (self , table : str , entry : dict ):
1034+ if self .isConnected :
1035+ keys = ', ' .join ('"{}"' .format (k ) for k in entry .keys ())
1036+ placeholders = ', ' .join (['%s' ] * len (entry ))
1037+ values = list (entry .values ())
1038+ statement = 'INSERT INTO "{}" ({}) VALUES ({})' .format (
1039+ table , keys , placeholders )
1040+ self .execute (statement , values )
1041+
1042+
7881043if __name__ == "__main__" :
7891044 sys .path .insert (0 , os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' , '..' )))
7901045 db = Database ("mysql+ssh://dcclab@cafeine3.crulrg.ulaval.ca/dcclab@cafeine3.crulrg.ulaval.ca/labdata" )
0 commit comments