#!python

##########################################################################
#
#   ABCGAN asset downloader
#
#   The purpose of this program is to allow end users to directly 
#   download components of the ABCGAN repo, such as the tutorials.
#
#   2022-01-20  Todd Valentic
#               Initial implementation
#
#   2022-04-22  Todd Valentic
#               Fix repo url 
#
#   2023-07-29  Todd Valentic
#               Add support for downloading models and data from LFS
#
##########################################################################

import sys
import io 
import fnmatch
import os
import pathlib
import argparse
import shutil
import tarfile

from abcgan import persist
import requests
from tqdm import tqdm

#-------------------------------------------------------------------------
# Command Processors 
#-------------------------------------------------------------------------

def log(args, msg):
    """Print *msg* to stdout unless quiet mode is enabled.

    args -- parsed options; only the boolean ``quiet`` attribute is read
    msg  -- text to display
    """

    if args.quiet:
        return

    print(msg)

def filter_members(tar, match_path='*', strip_components=None):
    """Yield the members of *tar* whose path matches *match_path*.

    tar              -- an open tarfile.TarFile
    match_path       -- fnmatch-style pattern tested against each member path
    strip_components -- number of leading path components to remove from
                        each yielded member (None or 0 leaves paths as is)

    The matching members are modified in place (their path is rewritten)
    and yielded, so the generator can be passed directly as the
    ``members`` argument of ``TarFile.extractall``.
    """

    for member in tar.getmembers():

        if not fnmatch.fnmatch(member.path, match_path):
            continue

        if strip_components:
            # Names inside a tar archive always use forward slashes, so
            # parse them as POSIX paths regardless of the host platform.
            path = pathlib.PurePosixPath(member.path)
            path = path.relative_to(*path.parts[:strip_components])

            # The member name must remain a plain string: tarfile (and in
            # particular its extraction filters) calls str methods on it.
            member.path = path.as_posix()

        yield member

def download_file(url, filename, timeout=None):
    """Stream *url* into the file *filename*, showing a progress bar.

    url      -- address to fetch
    filename -- pathlib.Path of the destination; missing parent
                directories are created
    timeout  -- passed through to requests.get (None means no timeout)

    Raises requests.HTTPError when the server answers with an error status.
    """

    filename.parent.mkdir(parents=True, exist_ok=True)

    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        file_size = int(r.headers.get('Content-Length', 0))

        desc = "(Unknown file size)" if file_size == 0 else ""

        # The raw urllib3 stream is not content-decoded by default, so a
        # gzip/deflate encoded response would be written to disk still
        # compressed. Ask urllib3 to decode while reading.
        r.raw.decode_content = True

        with tqdm.wrapattr(r.raw, "read", total=file_size, desc=desc) as r_raw:
            with filename.open('wb') as f:
                shutil.copyfileobj(r_raw, f)

def download_files(args, url, srcpath, destpath):
    """Download the real content for each Git LFS pointer file under *srcpath*.

    args     -- parsed options; ``repo`` fills the URL template, ``quiet``
                is used by log(), and ``timeout`` (when present) is applied
                to each download
    url      -- URL template; may contain the ``{repo}``, ``{root}`` and
                ``{name}`` placeholders
    srcpath  -- directory tree to scan for LFS pointer files
    destpath -- pathlib.Path of the directory receiving the downloads;
                files are stored directly in it under their base name

    A file is treated as an LFS pointer when its first line contains the
    LFS spec signature. Files that cannot be decoded as UTF-8 text are
    skipped.
    """

    lfs_signature = "version https://git-lfs.github.com/spec/v1"

    # Use the same timeout as the rest of the program when it is available.
    timeout = getattr(args, 'timeout', None)

    for root, _dirs, files in os.walk(srcpath):
        for name in files:
            filename = pathlib.Path(root, name)
            destname = destpath / name

            with filename.open('r', encoding='utf-8') as f:
                try:
                    line = f.readline()
                except UnicodeDecodeError:
                    # Not a text file, so it cannot be an LFS pointer.
                    continue

            if lfs_signature not in line:
                continue

            kw = {
                'repo': args.repo,
                # URLs need forward slashes even when os.walk reports
                # native (backslash) separators.
                'root': pathlib.Path(root).as_posix(),
                'name': name,
            }

            # Format into a new variable: overwriting the template would
            # make every later file reuse the first file's URL.
            file_url = url.format(**kw)

            log(args, f"Download {file_url} -> {destname}")
            download_file(file_url, destname, timeout=timeout)

def download_model_files(args):
    """Fetch the model files behind the LFS pointers in persist.dir_path.

    The directory given by ``persist.dir_path`` is scanned and the
    downloads are stored in ``tutorials/models``.
    """

    log(args, "Downloading model files")

    url_template = 'https://github.com/{repo}/raw/main/src/abcgan/models/{name}'
    target_dir = pathlib.Path('tutorials') / 'models'

    download_files(args, url_template, persist.dir_path, target_dir)

def download_data_files(args):
    """Fetch the data files behind the LFS pointers in the tutorials tree.

    The local ``tutorials`` directory is scanned and the downloads are
    stored in that same directory.
    """

    log(args, "Downloading data files")

    url_template = 'https://github.com/{repo}/raw/main/{root}/{name}'
    tutorials_dir = 'tutorials'

    download_files(args, url_template, tutorials_dir, pathlib.Path(tutorials_dir))


def ProcessDownload(args):
    """Handle the ``download`` command.

    Fetches the tar.gz archive of the repo's ``main`` branch, extracts the
    requested asset directory into the current directory (dropping the
    archive's top-level directory), and, for the tutorials asset with
    ``--models`` set, downloads the model and data files as well.

    args -- parsed options; reads ``repo``, ``timeout``, ``asset``,
            ``models`` and (through log) ``quiet``

    Returns True on completion. Raises requests.HTTPError when the archive
    request is answered with an error status.
    """

    log(args, f'Downloading files from repo {args.repo}')

    url = f'https://github.com/{args.repo}/archive/refs/heads/main.tar.gz'

    result = requests.get(url, timeout=args.timeout)

    # Stop on a 4xx/5xx answer instead of feeding an error page to tarfile,
    # which would fail later with an unrelated looking read error.
    result.raise_for_status()

    byte_stream = io.BytesIO(result.content)

    log(args, f'Extracting {args.asset}')

    match = f'*/{args.asset}/*'

    with tarfile.open(fileobj=byte_stream) as tar:
        # strip_components=1 removes the archive's single top-level
        # directory so the asset lands directly in the current directory.
        members = filter_members(tar, match_path=match, strip_components=1)
        tar.extractall(members=members)

    if args.asset == 'tutorials' and args.models:
        download_model_files(args)
        download_data_files(args)

    log(args, 'Finished')

    return True

#-------------------------------------------------------------------------
# Command Maps
#-------------------------------------------------------------------------

# Dispatch table: sub-command name -> function called with the parsed args.
Commands = dict(download=ProcessDownload)

if __name__ == '__main__':

    #-- Command line ------------------------------------------------------

    parser = argparse.ArgumentParser(description='ABCGAN Asset Downloader')

    # The sub-command name is stored in args.command and is mandatory.
    subparsers = parser.add_subparsers(dest='command', required=True)

    download_parser = subparsers.add_parser('download')
    download_parser.add_argument('asset', choices=['tutorials', 'docs'])

    # Global options (the destination names follow from the long option).
    parser.add_argument(
        '-r', '--repo',
        default='sri-geospace/atmosense-abcgan',
        help='Github repo (default sri-geospace/atmosense-abcgan)',
    )

    parser.add_argument(
        '-t', '--timeout',
        type=int,
        default=5,
        metavar='SECS',
        help='Timeout in seconds (default 5)',
    )

    parser.add_argument(
        '-q', '--quiet',
        action='store_true',
        help='Do not display output messages',
    )

    parser.add_argument(
        '-m', '--models',
        action='store_true',
        help='Download models and data for tutorials',
    )

    #-- Process -----------------------------------------------------------

    args = parser.parse_args()

    # Run the handler for the chosen command; a false result becomes a
    # non-zero exit status.
    handler = Commands[args.command]

    sys.exit(0 if handler(args) else 1)

