import logging
from ace.config import settings
import pyodbc
from typing import Union
# Database settings, copied out of the project ``settings`` object once at
# import time. Later rebinding of attributes on ``settings`` is not reflected
# in these module-level names.

# Which server / database to talk to.
DBMS_TYPE = settings.DBMS_TYPE
DATABASE_DRIVER = settings.DATABASE_DRIVER
DATABASE_HOST = settings.DATABASE_HOST
DATABASE_PORT = settings.DATABASE_PORT
DATABASE_NAME = settings.DATABASE_NAME

# How to authenticate.
DATABASE_AUTH_METHOD = settings.DATABASE_AUTH_METHOD
DATABASE_USER = settings.DATABASE_USER
DATABASE_PASSWORD = settings.DATABASE_PASSWORD

# Applied as each pyodbc connection's ``timeout`` by the connection classes.
DATABASE_REQUEST_TIMEOUT = settings.DATABASE_REQUEST_TIMEOUT
class ConnectionManager:
    """
    Base class owning one pyodbc connection plus a shared cursor.

    Subclasses must implement :meth:`make_connection`. Construction:

    1. opens the connection via ``make_connection()``;
    2. if both ``settings.USE_DB_CONTEXT`` and ``settings.ACE_SERVICE_USER_ID``
       are truthy, sets the SQL Server session context (:meth:`set_context`);
    3. creates the cursor used by :meth:`execute`, :meth:`fetchone` and
       :meth:`fetchall`.

    If step 2 or 3 fails, the connection opened in step 1 is closed before the
    error propagates. Call :meth:`close` when finished.
    """

    def __init__(self):
        self.conn = self.make_connection()
        try:
            if settings.USE_DB_CONTEXT and settings.ACE_SERVICE_USER_ID:
                self.set_context()
            self.cursor: pyodbc.Cursor = self.conn.cursor()
        except Exception:
            # Setup failed after the connection was opened: release it so it
            # is not leaked, then re-raise the original error.
            self.conn.close()
            raise

    def make_connection(self) -> Union[pyodbc.Connection, None]:
        """
        Open and return a new pyodbc connection.

        Must be implemented by subclasses.

        Raises:
            NotImplementedError: always, when called on the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement make_connection()"
        )

    def fetchone(self):
        """Return ``self.cursor.fetchone()`` for the last executed statement."""
        return self.cursor.fetchone()

    def fetchall(self):
        """Return ``self.cursor.fetchall()`` for the last executed statement."""
        return self.cursor.fetchall()

    def execute(self, sql_query, *params):
        """
        Execute ``sql_query`` on the shared cursor.

        Extra positional arguments are forwarded unchanged to
        ``cursor.execute``, so both pyodbc call styles work::

            manager.execute("SELECT * FROM t WHERE a = ? AND b = ?", 1, 2)
            manager.execute("SELECT * FROM t WHERE a = ? AND b = ?", (1, 2))

        Returns:
            Whatever ``cursor.execute`` returns (pyodbc returns the cursor).
        """
        # Forward as-is rather than repacking into one tuple: repacking turned
        # a single sequence argument into a nested tuple, which pyodbc then
        # bound as one (invalid) parameter value.
        return self.cursor.execute(sql_query, *params)

    def close(self):
        """Close the shared cursor, then the connection (even if the former fails)."""
        try:
            self.cursor.close()
        finally:
            self.conn.close()

    def set_context(self):
        """
        Set session-context key ``'UserId'`` to ``settings.ACE_SERVICE_USER_ID``.

        Runs ``sp_set_session_context`` on a short-lived cursor, which is
        closed afterwards so it does not outlive this call.
        """
        cursor = self.conn.cursor()
        try:
            cursor.execute(
                "EXEC sp_set_session_context 'UserId', ?",
                settings.ACE_SERVICE_USER_ID,
            )
        finally:
            cursor.close()
class SqlServerConnection(ConnectionManager):
    """SQL Server connection using SQL Server authentication (``UID``/``PWD``)."""

    @staticmethod
    def _quote_odbc_value(value) -> str:
        """
        Brace-quote ``value`` for an ODBC connection string, doubling any ``}``.

        Quoting lets credentials contain ``;``, ``{`` or ``}`` without breaking
        the ``key=value;`` connection-string syntax.
        """
        return "{" + str(value).replace("}", "}}") + "}"

    def make_connection(self):
        """
        Open a pyodbc connection to SQL Server from the module-level settings.

        ``DATABASE_USER`` and ``DATABASE_PASSWORD`` are brace-quoted so that
        special characters in them cannot corrupt the connection string.
        ``conn.timeout`` is set to ``DATABASE_REQUEST_TIMEOUT`` when that is a
        positive ``int`` (``bool`` excluded), otherwise to 30.

        Returns:
            The open ``pyodbc.Connection``.
        """
        connection_string = (
            f"DRIVER={DATABASE_DRIVER};"
            f"SERVER={DATABASE_HOST},{DATABASE_PORT};"
            f"DATABASE={DATABASE_NAME};"
            f"UID={self._quote_odbc_value(DATABASE_USER)};"
            f"PWD={self._quote_odbc_value(DATABASE_PASSWORD)}"
        )
        conn = pyodbc.connect(connection_string)
        timeout = DATABASE_REQUEST_TIMEOUT
        # bool is an int subclass; True must not be accepted as "1 second".
        if isinstance(timeout, bool) or not isinstance(timeout, int) or timeout <= 0:
            timeout = 30
        conn.timeout = timeout
        return conn
class SqlServerConnectionADIntegrated(ConnectionManager):
    """SQL Server connection using Active Directory Integrated authentication."""

    def make_connection(self):
        """
        Open a SQL Server connection with ``Authentication=ActiveDirectoryIntegrated``.

        No ``UID``/``PWD`` are sent; the driver authenticates with the
        integrated (Windows/Azure AD) identity instead of SQL Server
        authentication. The connection string also requests ``Encrypt=yes``,
        ``TrustServerCertificate=no`` and ``Connection Timeout=30``.

        ``SERVER`` is built as ``DATABASE_HOST,DATABASE_PORT``. This method does
        not add a ``tcp:`` prefix; to force TCP/IP, include ``tcp:`` in
        ``DATABASE_HOST`` itself.

        ``conn.timeout`` is set to ``DATABASE_REQUEST_TIMEOUT`` when that is a
        positive ``int`` (``bool`` excluded), otherwise to 30.

        Returns:
            The open ``pyodbc.Connection``.
        """
        connection_string = (
            f"DRIVER={DATABASE_DRIVER};"
            f"SERVER={DATABASE_HOST},{DATABASE_PORT};"
            f"DATABASE={DATABASE_NAME};"
            "Authentication=ActiveDirectoryIntegrated;"
            "Encrypt=yes;"
            "TrustServerCertificate=no;"
            "Connection Timeout=30"
        )
        conn = pyodbc.connect(connection_string)
        timeout = DATABASE_REQUEST_TIMEOUT
        # bool is an int subclass; True must not be accepted as "1 second".
        if isinstance(timeout, bool) or not isinstance(timeout, int) or timeout <= 0:
            timeout = 30
        conn.timeout = timeout
        return conn
def get_connection_class():
    """
    Choose the connection class for the configured DBMS and auth method.

    Returns ``SqlServerConnectionADIntegrated`` only when ``DBMS_TYPE`` is
    ``"SQL_SERVER"`` and ``DATABASE_AUTH_METHOD`` is ``"AD_INTEGRATED"``.
    Every other combination, including other ``DBMS_TYPE`` values, yields
    ``SqlServerConnection``.
    """
    wants_ad_integrated = (
        DBMS_TYPE == "SQL_SERVER" and DATABASE_AUTH_METHOD == "AD_INTEGRATED"
    )
    if not wants_ad_integrated:
        return SqlServerConnection
    return SqlServerConnectionADIntegrated


# Class resolved once at import time; instantiating it opens a connection.
ace_connection = get_connection_class()