import os
import paramiko
import logging
from typing import Union
from ace.config import settings
class SFTPConnector:
    """Manage a paramiko SSH client used for SFTP transfers.

    Every constructor argument defaults to the matching ``SFTP_*`` value in
    ``ace.config.settings``. Python evaluates these defaults once, when the
    class is defined. ``connect()`` authenticates with either a password or an
    RSA private-key file, depending on ``auth_method``. ``disconnect()`` closes
    the client.
    """

    def __init__(
        self,
        host: str = settings.SFTP_HOST,
        username: str = settings.SFTP_USERNAME,
        password: str = settings.SFTP_PASSWORD,
        private_key: str = settings.SFTP_PRIVATE_KEY,
        port: Union[str, int] = settings.SFTP_PORT,
        auth_method: str = settings.SFTP_AUTH_METHOD,
    ):
        """Store connection settings and validate the chosen authentication method.

        Args:
            host: Server hostname or address.
            username: Login user name.
            password: Password; must be non-empty when ``auth_method`` is
                ``'password'``.
            private_key: Path to an RSA private-key file; must be non-empty
                when ``auth_method`` is ``'private_key'``.
            port: Server port. A string (e.g. from an environment-backed
                setting) is converted with ``int()``; other values are kept.
            auth_method: Either ``'password'`` or ``'private_key'``.

        Raises:
            ValueError: If ``port`` is a non-numeric string, if
                ``auth_method`` is not one of the two supported values, or if
                the credential required by that method is empty.
        """
        self.host = host
        # paramiko documents SSHClient.connect's ``port`` as an int. Normalise
        # numeric strings here so a string such as "22" is not passed through.
        self.port = int(port) if isinstance(port, str) else port
        self.username = username
        self.password = password
        self.private_key = private_key
        self.client: paramiko.client.SSHClient = paramiko.client.SSHClient()
        # NOTE(security): AutoAddPolicy silently trusts and records any unknown
        # host key, so the first connection is open to man-in-the-middle
        # attacks. It is kept for compatibility. Consider loading known_hosts
        # and using RejectPolicy instead.
        self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        if auth_method not in ['password', 'private_key']:
            raise ValueError("Invalid authentication method. Use 'password' or 'private_key'.")
        self.auth_method = auth_method
        if self.auth_method == 'private_key' and not self.private_key:
            raise ValueError("Private key path must be provided for private key authentication.")
        if self.auth_method == 'password' and not self.password:
            raise ValueError("Password must be provided for password authentication.")

    def connect(self) -> None:
        """Open the SSH connection using the configured ``auth_method``.

        ``'password'`` uses ``password_auth()``; any other value uses
        ``private_key_auth()``.

        Raises:
            paramiko.AuthenticationException: If the server rejects the
                credentials (logged, then re-raised by the auth helper).
        """
        if self.auth_method == 'password':
            self.password_auth()
        else:
            self.private_key_auth()

    def disconnect(self):
        """Close the underlying SSH client and log the disconnection."""
        if self.client:
            self.client.close()
            logging.info("Disconnected from SFTP server")

    def password_auth(self) -> None:
        """Connect with username/password, without trying local keys or an agent.

        Raises:
            paramiko.AuthenticationException: If the server rejects the
                credentials. The failure is logged first, then re-raised.
        """
        logging.info("Connecting to SFTP server using password authentication")
        try:
            self.client.connect(hostname=self.host, port=self.port, username=self.username, password=self.password, look_for_keys=False, allow_agent=False)
        except paramiko.AuthenticationException as e:
            logging.error(f"Authentication failed: {e}")
            # Re-raise so callers do not go on to open an SFTP session on a
            # client that never authenticated.
            raise
        logging.info("Connected to SFTP server using password authentication")

    def private_key_auth(self) -> None:
        """Connect using the RSA private-key file at ``self.private_key``.

        Only RSA keys are loaded (``paramiko.RSAKey``). Other key types would
        need a different paramiko key class. ``client.connect`` is called with
        its default ``look_for_keys``/``allow_agent`` settings, unlike
        ``password_auth``, which disables both.

        Raises:
            paramiko.AuthenticationException: If the key is rejected. The
                failure is logged first, then re-raised.
        """
        logging.info("Connecting to SFTP server using private key authentication")
        try:
            pkey = paramiko.RSAKey.from_private_key_file(self.private_key)
            self.client.connect(hostname=self.host, port=self.port, username=self.username, pkey=pkey)
        except paramiko.AuthenticationException as e:
            logging.error(f"Authentication failed: {e}")
            # Re-raise so callers do not continue with an unauthenticated client.
            raise
        logging.info("Connected to SFTP server using private key authentication")
def _remote_file_exists(sftp: paramiko.SFTPClient, remote_path: str) -> bool:
    """Report whether ``sftp.stat`` finds an entry at ``remote_path``.

    Any entry the server can stat counts as present, so it may be something
    other than a regular file. Only ``FileNotFoundError`` is turned into
    ``False``. Every other error from ``stat`` propagates to the caller.
    """
    try:
        sftp.stat(remote_path)
    except FileNotFoundError:
        return False
    else:
        return True
def upload(sftp_connector: SFTPConnector, file_obj, remote_path: str):
    """Upload a file to the SFTP server. Raises FileExistsError if a file already exists at remote_path.

    A connection is opened with ``sftp_connector.connect()``. It is always
    closed with ``sftp_connector.disconnect()``, even when connecting itself
    fails.

    Args:
        sftp_connector: Connector used to open and close the SSH session.
        file_obj: File-like object to send. It is rewound with ``seek(0)``
            first, so its whole content is uploaded.
        remote_path: Destination path on the server.

    Raises:
        FileExistsError: If ``remote_path`` already exists on the server.
        Exceptions from ``sftp_connector.connect()`` also propagate.

    A ``paramiko.SSHException`` raised during the transfer is logged and not
    re-raised. In that case the function simply returns ``None``.
    """
    try:
        sftp_connector.connect()
        try:
            logging.info(f"Uploading file to {remote_path}")
            # The context manager closes the SFTP session on every exit path,
            # including the FileExistsError raised below.
            with sftp_connector.client.open_sftp() as sftp:
                if _remote_file_exists(sftp, remote_path):
                    raise FileExistsError(f"A file with the same name already exists at '{remote_path}'. Use upload_overwrite to replace it.")
                file_obj.seek(0)
                sftp.putfo(file_obj, remote_path)
                logging.info(f"File uploaded successfully to {remote_path}")
        except paramiko.SSHException as e:
            logging.error(f"SSH error: {e}")
    finally:
        # Runs even if connect() raised, so a half-open transport is released.
        sftp_connector.disconnect()
def upload_overwrite(sftp_connector: SFTPConnector, file_obj, remote_path: str):
    """Upload a file to the SFTP server, overwriting any existing file at remote_path.

    A connection is opened with ``sftp_connector.connect()``. It is always
    closed with ``sftp_connector.disconnect()``, even when connecting itself
    fails.

    Args:
        sftp_connector: Connector used to open and close the SSH session.
        file_obj: File-like object to send. It is rewound with ``seek(0)``
            first, so its whole content is uploaded.
        remote_path: Destination path on the server. No existence check is
            made before ``putfo`` writes to it.

    Exceptions from ``sftp_connector.connect()`` propagate. A
    ``paramiko.SSHException`` raised during the transfer is logged and not
    re-raised. In that case the function returns ``None``.
    """
    try:
        sftp_connector.connect()
        try:
            logging.info(f"Uploading file to {remote_path} (overwrite enabled)")
            # The context manager guarantees the SFTP session is closed on error too.
            with sftp_connector.client.open_sftp() as sftp:
                file_obj.seek(0)
                sftp.putfo(file_obj, remote_path)
                logging.info(f"File uploaded successfully to {remote_path} (overwrite)")
        except paramiko.SSHException as e:
            logging.error(f"SSH error: {e}")
    finally:
        # Runs even if connect() raised, so a half-open transport is released.
        sftp_connector.disconnect()
def download(sftp_connector: SFTPConnector, file_name: str):
    """Download a file from the SFTP server and return its content as bytes.

    A connection is opened with ``sftp_connector.connect()``. It is always
    closed afterwards: on success, on the handled errors below, and when
    connecting itself fails.

    Args:
        sftp_connector: Connector used to open and close the SSH session.
        file_name: Path of the remote file to read.

    Returns:
        The data returned by reading the remote file. Returns ``None`` if a
        ``paramiko.SSHException`` or ``FileNotFoundError`` occurred while
        reading; both cases are logged.

    Exceptions from ``sftp_connector.connect()`` propagate.
    """
    try:
        sftp_connector.connect()
        try:
            logging.info(f"Downloading file {file_name} from SFTP server")
            # Both context managers close their handles even if read() fails,
            # so neither the remote file handle nor the SFTP session leaks.
            with sftp_connector.client.open_sftp() as sftp:
                with sftp.open(file_name, 'r') as sftp_file:
                    file_content = sftp_file.read()
            # Log success only after the content has actually been read.
            logging.info(f"File {file_name} downloaded successfully from SFTP server")
            return file_content
        except paramiko.SSHException as e:
            logging.error(f"SSH error: {e}")
            return None
        except FileNotFoundError as e:
            logging.error(f"File not found: {e}")
            return None
    finally:
        # Without this, the connection stayed open on every error path.
        sftp_connector.disconnect()
def move(sftp_connector: SFTPConnector, current_path: str, new_path: str):
    """Move a file on the SFTP server. Raises FileExistsError if a file already exists at new_path.

    A connection is opened with ``sftp_connector.connect()``. It is always
    closed with ``sftp_connector.disconnect()``, even when connecting itself
    fails.

    Args:
        sftp_connector: Connector used to open and close the SSH session.
        current_path: Existing remote path to move.
        new_path: Target remote path. It must not already exist.

    Raises:
        FileExistsError: If ``new_path`` already exists on the server.
        Exceptions from ``sftp_connector.connect()`` also propagate.

    A ``paramiko.SSHException`` or ``FileNotFoundError`` raised during the move
    is logged and not re-raised. In that case the function returns ``None``.
    """
    try:
        sftp_connector.connect()
        try:
            logging.info(f"Moving file from {current_path} to {new_path}")
            # The context manager closes the SFTP session on every exit path,
            # including the FileExistsError raised below.
            with sftp_connector.client.open_sftp() as sftp:
                # NOTE: check-then-rename is not atomic. Another client could
                # create new_path between the check and the rename.
                if _remote_file_exists(sftp, new_path):
                    raise FileExistsError(f"A file with the same name already exists at '{new_path}'. Use move_overwrite to replace it.")
                sftp.rename(current_path, new_path)
                logging.info(f"File moved successfully from {current_path} to {new_path}")
        except paramiko.SSHException as e:
            logging.error(f"SSH error: {e}")
        except FileNotFoundError as e:
            logging.error(f"File not found: {e}")
    finally:
        # Runs even if connect() raised, so a half-open transport is released.
        sftp_connector.disconnect()
def move_overwrite(sftp_connector: SFTPConnector, current_path: str, new_path: str):
    """Move a file on the SFTP server, overwriting any existing file at new_path.

    The source is verified to exist *before* anything at ``new_path`` is
    removed, so a missing source never destroys the destination. If
    ``current_path`` and ``new_path`` are the same string, nothing is removed
    or renamed.

    A connection is opened with ``sftp_connector.connect()``. It is always
    closed with ``sftp_connector.disconnect()``, even when connecting itself
    fails.

    Args:
        sftp_connector: Connector used to open and close the SSH session.
        current_path: Existing remote path to move.
        new_path: Target remote path. Any existing entry there is removed
            before the rename.

    Exceptions from ``sftp_connector.connect()`` propagate. A
    ``paramiko.SSHException`` or ``FileNotFoundError`` raised during the move
    (including a missing ``current_path``) is logged and not re-raised. In
    that case the function returns ``None``.
    """
    try:
        sftp_connector.connect()
        try:
            logging.info(f"Moving file from {current_path} to {new_path} (overwrite enabled)")
            with sftp_connector.client.open_sftp() as sftp:
                # Fail fast (FileNotFoundError) while the destination is still
                # intact. Deleting new_path first would lose it if the source
                # turned out to be missing.
                sftp.stat(current_path)
                if current_path == new_path:
                    # Removing the "existing destination" would delete the
                    # source itself.
                    logging.info(f"Source and destination are the same ('{new_path}'); nothing to move")
                    return
                # NOTE: remove-then-rename is not atomic. If the rename fails,
                # the old destination is already gone. Servers that support the
                # posix-rename extension could use sftp.posix_rename instead.
                if _remote_file_exists(sftp, new_path):
                    sftp.remove(new_path)
                    logging.info(f"Existing file at '{new_path}' removed before move")
                sftp.rename(current_path, new_path)
                logging.info(f"File moved successfully from {current_path} to {new_path} (overwrite)")
        except paramiko.SSHException as e:
            logging.error(f"SSH error: {e}")
        except FileNotFoundError as e:
            logging.error(f"File not found: {e}")
    finally:
        # Runs even if connect() raised, so a half-open transport is released.
        sftp_connector.disconnect()