11import logging
22import sys
3+
34if sys .version_info >= (3 , 0 , 0 ):
45 from urllib .parse import urlparse , uses_netloc
56else :
1314from aws_xray_sdk .ext .util import unwrap
1415
1516
16- def _sql_meta (instance , args ):
17+ def _sql_meta (engine_instance , args ):
1718 try :
1819 metadata = {}
19- url = urlparse (str (instance .engine .url ))
20+ url = urlparse (str (engine_instance .engine .url ))
2021 # Add Scheme to uses_netloc or // will be missing from url.
2122 uses_netloc .append (url .scheme )
2223 if url .password is None :
@@ -29,17 +30,20 @@ def _sql_meta(instance, args):
2930 metadata ['url' ] = parts .geturl ()
3031 name = host_info
3132 metadata ['user' ] = url .username
32- metadata ['database_type' ] = instance .engine .name
33+ metadata ['database_type' ] = engine_instance .engine .name
3334 try :
34- version = getattr (instance .dialect , '{}_version' .format (instance .engine .driver ))
35+ version = getattr (engine_instance .dialect , '{}_version' .format (engine_instance .engine .driver ))
3536 version_str = '.' .join (map (str , version ))
36- metadata ['driver_version' ] = "{}-{}" .format (instance .engine .driver , version_str )
37+ metadata ['driver_version' ] = "{}-{}" .format (engine_instance .engine .driver , version_str )
3738 except AttributeError :
38- metadata ['driver_version' ] = instance .engine .driver
39- if instance .dialect .server_version_info is not None :
40- metadata ['database_version' ] = '.' .join (map (str , instance .dialect .server_version_info ))
39+ metadata ['driver_version' ] = engine_instance .engine .driver
40+ if engine_instance .dialect .server_version_info is not None :
41+ metadata ['database_version' ] = '.' .join (map (str , engine_instance .dialect .server_version_info ))
4142 if xray_recorder .stream_sql :
42- metadata ['sanitized_query' ] = str (args [0 ])
43+ try :
44+ metadata ['sanitized_query' ] = str (args [0 ])
45+ except Exception :
46+ logging .getLogger (__name__ ).exception ('Error getting the sanitized query' )
4347 except Exception :
4448 metadata = None
4549 name = None
@@ -48,7 +52,15 @@ def _sql_meta(instance, args):
4852
4953
5054def _xray_traced_sqlalchemy_execute (wrapped , instance , args , kwargs ):
51- name , sql = _sql_meta (instance , args )
55+ return _process_request (wrapped , instance , args , kwargs )
56+
57+
58+ def _xray_traced_sqlalchemy_session (wrapped , instance , args , kwargs ):
59+ return _process_request (wrapped , instance .bind , args , kwargs )
60+
61+
62+ def _process_request (wrapped , engine_instance , args , kwargs ):
63+ name , sql = _sql_meta (engine_instance , args )
5264 if sql is not None :
5365 subsegment = xray_recorder .begin_subsegment (name , namespace = 'remote' )
5466 else :
@@ -75,6 +87,12 @@ def patch():
7587 _xray_traced_sqlalchemy_execute
7688 )
7789
90+ wrapt .wrap_function_wrapper (
91+ 'sqlalchemy.orm.session' ,
92+ 'Session.execute' ,
93+ _xray_traced_sqlalchemy_session
94+ )
95+
7896
7997def unpatch ():
8098 """
@@ -84,3 +102,4 @@ def unpatch():
84102 _PATCHED_MODULES .discard ('sqlalchemy_core' )
85103 import sqlalchemy
86104 unwrap (sqlalchemy .engine .base .Connection , 'execute' )
105+ unwrap (sqlalchemy .orm .session .Session , 'execute' )
0 commit comments