오지's blog

airflow에서 sftp로 파일 여러개 다운받으려고 할때 본문

개발노트/airflow

airflow에서 sftp로 파일 여러개 다운받으려고 할때

오지구영ojjy90 2024. 4. 19. 14:09
728x90
반응형

sftp_multiple_files_download_operator.py

import os
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
# from airflow.contrib.hooks import SSHHook
from airflow.providers.ssh.hooks.ssh import SSHHook
from typing import Any

class SFTPMultipleFilesDownloadOperator(BaseOperator):
    """Download every remote file whose name matches a pattern over SFTP.

    Unlike the stock ``SFTPOperator``, which transfers one file per task,
    this operator lists ``remote_filepath`` on the server and downloads all
    files whose names contain ``remote_filename_pattern`` (optionally
    filtered by a ``filetype`` suffix) into ``local_directory``.

    :param ssh_hook: predefined :class:`SSHHook`; takes precedence over
        ``ssh_conn_id`` when both are given.
    :param ssh_conn_id: Airflow connection id used to build an ``SSHHook``
        when ``ssh_hook`` is not supplied.
    :param remote_host: optional host override for the hook/connection.
    :param local_directory: local directory the files are written into.
    :param remote_filepath: remote directory to list on the SFTP server.
    :param remote_filename_pattern: substring a remote file name must
        contain to be downloaded; ``None`` matches every listed file.
    :param filetype: optional file-name suffix filter (e.g. ``'csv'``).
    :param confirm: kept for API parity with ``SFTPOperator``; unused here.
    :param create_intermediate_dirs: kept for API parity; intermediate
        directory creation is not currently implemented.
    """

    template_fields = ('local_directory', 'remote_filename_pattern', 'remote_host')

    def __init__(
            self,
            *,
            ssh_hook=None,
            ssh_conn_id=None,
            remote_host=None,
            local_directory=None,
            remote_filepath=None,
            remote_filename_pattern=None,
            filetype=None,
            confirm=True,
            create_intermediate_dirs=False,
            **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.ssh_hook = ssh_hook
        self.ssh_conn_id = ssh_conn_id
        self.remote_host = remote_host
        self.local_directory = local_directory
        self.filetype = filetype
        self.remote_filepath = remote_filepath
        self.remote_filename_pattern = remote_filename_pattern
        self.confirm = confirm
        # Store the argument instead of silently dropping it (the assignment
        # was commented out before, making the parameter a no-op surprise).
        self.create_intermediate_dirs = create_intermediate_dirs

    def execute(self, context: Any) -> str:
        """List the remote directory, download matching files, return the local dir.

        :raises AirflowException: when neither hook nor connection id is set,
            or when any listing/transfer step fails (original error chained).
        """
        file_msg = None
        try:
            if self.ssh_conn_id:
                if self.ssh_hook and isinstance(self.ssh_hook, SSHHook):
                    self.log.info("ssh_conn_id is ignored when ssh_hook is provided.")
                else:
                    self.log.info(
                        "ssh_hook is not provided or invalid. Trying ssh_conn_id to create SSHHook."
                    )
                    self.ssh_hook = SSHHook(ssh_conn_id=self.ssh_conn_id)

            if not self.ssh_hook:
                raise AirflowException("Cannot operate without ssh_hook or ssh_conn_id.")

            if self.remote_host is not None:
                self.log.info(
                    "remote_host is provided explicitly. "
                    "It will replace the remote_host which was defined "
                    "in ssh_hook or predefined in connection of ssh_conn_id."
                )
                self.ssh_hook.remote_host = self.remote_host

            with self.ssh_hook.get_conn() as ssh_client:
                sftp_client = ssh_client.open_sftp()
                all_files = sftp_client.listdir(path=self.remote_filepath)
                self.log.info('Found %d files on server', len(all_files))

                filename_pattern = self.remote_filename_pattern
                if filename_pattern:
                    # Keep only files whose names contain the pattern substring.
                    matching_files = [f for f in all_files if filename_pattern in f]
                else:
                    # No pattern supplied: consider every listed file.
                    # (Previously `f.find(None)` raised TypeError here.)
                    matching_files = list(all_files)

                # If a file type is given, additionally filter by suffix.
                if self.filetype is not None:
                    matching_files = [filename for filename in matching_files
                                      if filename.endswith(self.filetype)]
                self.log.info('Found %d files with name including %s',
                              len(matching_files), filename_pattern)

                # NOTE: a debug-only `matching_files[:2]` truncation was removed;
                # the operator now transfers every matching file, as its name says.
                for f in matching_files:
                    # Track the current file so a failure message names it.
                    file_msg = f
                    self.log.info(f"Starting to transfer from /{f} to {self.local_directory}/{f}")
                    sftp_client.get(f'/{os.path.join(self.remote_filepath, f)}',
                                    f'{self.local_directory}/{f}')

        except Exception as e:
            raise AirflowException(f"Error while transferring {file_msg}, error: {str(e)}") from e

        return self.local_directory


#
# def _make_intermediate_dirs(sftp_client, remote_directory) -> None:
#     """
#     Create all the intermediate directories in a remote host
#
#     :param sftp_client: A Paramiko SFTP client.
#     :param remote_directory: Absolute Path of the directory containing the file
#     :return:
#     """
#     if remote_directory == '/':
#         sftp_client.chdir('/')
#         return
#     if remote_directory == '':
#         return
#     try:
#         sftp_client.chdir(remote_directory)
#     except OSError:
#         dirname, basename = os.path.split(remote_directory.rstrip('/'))
#         _make_intermediate_dirs(sftp_client, dirname)
#         sftp_client.mkdir(basename)
#         sftp_client.chdir(basename)
#         return

 

 

 

dag.py

# Download all CSV files from the remote "./data_init" directory into the
# local 2024-04-19 data folder, using the 'SFTP_CONN_ID_DATAHUB' connection.
# NOTE(review): no remote_filename_pattern is passed, so it stays None —
# confirm the operator's matching step tolerates a None pattern.
t4 = SFTPMultipleFilesDownloadOperator(
    task_id='sftp_multiple_download',
    ssh_conn_id='SFTP_CONN_ID_DATAHUB',
    local_directory='/home/ec2-user/data/2024-04-19/',
    remote_filepath=f'{os.path.join(".", "data_init")}',
    filetype='csv'
)

스택오버플로우내 소스가 있길래 참고하여 조금수정해서 개발하였다. 점점 airflow의 전문가가 되어가고 있다!

파일 여러개 다운 받을때 쓰레드를 이용하여 동시에 다운 받는 방법에 대해 공부해보고 싶어졌다.

 

 

reference.

https://stackoverflow.com/questions/67327170/airflow-customise-sftpoperator-to-download-multiple-files

 

AIRFLOW : Customise SFTPOperator to download multiple files

I'm trying to customise the SFTPOperator to download multiple files from a server. I know that the original SFTPOperator only allows one file at a time. I copied the same code from source and I twer...

stackoverflow.com

 

'개발노트 > airflow' 카테고리의 다른 글

airflow에서 timezone 변경  (0) 2023.01.27
Comments