#!/usr/bin/env python
# remote-run - Runs a command on another machine, for testing -----*- python -*-
#
# This source file is part of the Swift.org open source project
#
# Copyright (c) 2018 Apple Inc. and the Swift project authors
# Licensed under Apache License v2.0 with Runtime Library Exception
#
# See https://swift.org/LICENSE.txt for license information
# See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
#
# ----------------------------------------------------------------------------

from __future__ import print_function

import argparse
import os
import posixpath
import subprocess
import sys

def quote(arg):
    """Return the string `arg` quoted as one literal POSIX shell word.

    The result is always wrapped in single quotes, so for ordinary strings
    it is identical to what `repr(arg)` used to produce (e.g. 'abc').
    Unlike `repr`, it never rewrites the contents: backslashes are not
    doubled, control characters are not turned into escape sequences, a
    Python 2 unicode string gets no `u` prefix, and a string containing a
    single quote is not switched to double quotes (inside which the shell
    would still expand `$` and backticks).
    """
    # Inside single quotes the shell takes every character literally, so the
    # only character needing treatment is the single quote itself: close the
    # quoted run, emit a double-quoted "'", and reopen the quoted run.
    # NOTE(review): the result is also embedded in sftp batch lines; sftp's
    # own argument parser is assumed to accept this quoting - confirm for
    # file names that contain quotes or backslashes.
    return "'" + arg.replace("'", "'\"'\"'") + "'"

class CommandRunner(object):
    """Runs commands and transfers files over a subclass-chosen transport.

    Subclasses implement remote_invocation() and sftp_invocation(); the
    methods here build on those two hooks to create directories, run a
    command, and send or fetch files in one sftp batch.
    """

    # Class-level defaults, so a subclass whose __init__ does not chain to
    # this one still has both flags defined.
    verbose = False
    dry_run = False

    def __init__(self):
        self.verbose = False
        self.dry_run = False

    def remote_invocation(self, command):
        """Return the argv list that runs `command` on the target."""
        raise NotImplementedError

    def sftp_invocation(self):
        """Return the argv list that starts a batch-mode sftp client."""
        raise NotImplementedError

    @staticmethod
    def _dirnames(files):
        """Return the sorted, de-duplicated parent directories of `files`.

        A bare file name has an empty dirname (the current directory). It is
        dropped because there is nothing to create and `mkdir -p ''` fails.
        """
        dirs = set(posixpath.dirname(f) for f in files)
        dirs.discard('')
        return sorted(dirs)

    def popen(self, command, **kwargs):
        """Start `command` with subprocess.Popen, honoring the flags.

        Prints the command to stderr when verbose. Returns None without
        starting anything when dry_run is set.
        """
        if self.verbose:
            print(' '.join(command), file=sys.stderr)
        if self.dry_run:
            return None
        return subprocess.Popen(command, **kwargs)

    def send(self, local_to_remote_files):
        """Upload every local path (key) to its remote path (value)."""
        # Prepare the remote directory structure.
        # FIXME: This could be folded into the sftp connection below.
        # .values()/.items() work on both Python 2 and 3, unlike the
        # Python 2-only view* methods.
        dirs_to_make = self._dirnames(local_to_remote_files.values())
        if dirs_to_make:
            self.run_remote(['/bin/mkdir', '-p'] + dirs_to_make)

        # Send the local files. Per sftp(1), the leading '-' keeps a failing
        # command from aborting the rest of the batch.
        sftp_commands = ("-put {0} {1}".format(quote(local_file),
                                               quote(remote_file))
                         for local_file, remote_file
                         in local_to_remote_files.items())
        self.run_sftp(sftp_commands)

    def fetch(self, local_to_remote_files):
        """Download every remote path (value) to its local path (key)."""
        # Prepare the local directory structure.
        dirs_to_make = self._dirnames(local_to_remote_files.keys())
        if dirs_to_make:
            mkdir_command = ['/bin/mkdir', '-p'] + dirs_to_make
            if self.verbose:
                print(' '.join(mkdir_command), file=sys.stderr)
            if not self.dry_run:
                subprocess.check_call(mkdir_command)

        # Fetch the remote files. Per sftp(1), the leading '-' keeps a
        # failing command from aborting the rest of the batch.
        sftp_commands = ("-get {0} {1}".format(quote(remote_file),
                                               quote(local_file))
                         for local_file, remote_file
                         in local_to_remote_files.items())
        self.run_sftp(sftp_commands)

    def run_remote(self, command, remote_env=None):
        """Run `command` on the target with `remote_env` set via env(1).

        `remote_env` maps variable names to values; None means no extra
        variables. Exits this process with the command's return code if
        that code is non-zero.
        """
        # A None default avoids sharing one mutable dict between calls.
        if remote_env is None:
            remote_env = {}
        env_strings = ['{0}={1}'.format(k, v) for k, v in remote_env.items()]
        remote_invocation = self.remote_invocation(
            ['/usr/bin/env'] + env_strings + command)
        remote_proc = self.popen(remote_invocation, stdin=subprocess.PIPE,
                                 stdout=None, stderr=None)
        if self.dry_run:
            return
        # No input is supplied; communicate() closes stdin and waits.
        remote_proc.communicate()
        if remote_proc.returncode:
            # FIXME: We may still want to fetch the output files to see what
            # went wrong.
            sys.exit(remote_proc.returncode)

    def run_sftp(self, commands):
        """Feed `commands` (an iterable of batch lines) to one sftp process.

        Exits this process with sftp's return code if that code is non-zero.
        """
        sftp_proc = self.popen(self.sftp_invocation(), stdin=subprocess.PIPE,
                               stdout=subprocess.PIPE, stderr=None)
        concatenated_commands = '\n'.join(commands)
        if self.verbose:
            print(concatenated_commands, file=sys.stderr)
        if self.dry_run:
            return
        # The pipe is opened in binary mode, so Python 3 requires bytes here;
        # on Python 2 a plain str already is bytes and is passed through.
        if not isinstance(concatenated_commands, bytes):
            concatenated_commands = concatenated_commands.encode('utf-8')
        # sftp's stdout is captured only to keep it quiet; it is discarded.
        sftp_proc.communicate(concatenated_commands)
        if sftp_proc.returncode:
            sys.exit(sftp_proc.returncode)

class RemoteCommandRunner(CommandRunner):
    """Command runner that reaches the target through ssh and sftp."""

    def __init__(self, host, identity_path, ssh_options):
        """Store the connection settings.

        host: '[user@]host[:port]'; the text after the last ':' is the port.
        identity_path: SSH identity file passed with -i, or a false value.
        ssh_options: iterable of option strings, each passed with -o.
        """
        # Chain to the base class so verbose/dry_run are always defined,
        # even if the caller never assigns them.
        super(RemoteCommandRunner, self).__init__()
        if ':' in host:
            (self.remote_host, self.port) = host.rsplit(':', 1)
        else:
            self.remote_host = host
            self.port = None
        self.identity_path = identity_path
        self.ssh_options = ssh_options

    def common_options(self, port_flag):
        """Return the options shared by the ssh and sftp command lines.

        `port_flag` is the flag that introduces the port; the callers in
        this class pass '-p' for ssh and '-P' for sftp.
        """
        port_option = [port_flag, self.port] if self.port else []
        identity_option = (
            ['-i', self.identity_path] if self.identity_path else [])
        # Interleave '-o' with each custom option.
        # From https://stackoverflow.com/a/8168526,
        # with explanatory help from
        # https://spapas.github.io/2016/04/27/python-nested-list-comprehensions/
        extra_options = [arg for option in self.ssh_options
                         for arg in ["-o", option]]
        return port_option + identity_option + extra_options

    def remote_invocation(self, command):
        """Return the ssh argv that runs `command` on the remote host."""
        # Each argument is quoted individually before being handed to ssh.
        return (['/usr/bin/ssh', '-n'] +
                self.common_options(port_flag='-p') +
                [self.remote_host, '--'] +
                [quote(arg) for arg in command])

    def sftp_invocation(self):
        """Return the sftp argv that reads its batch from stdin ('-b -')."""
        return (['/usr/bin/sftp', '-b', '-', '-q'] +
                self.common_options(port_flag='-P') +
                [self.remote_host])

class LocalCommandRunner(CommandRunner):
    """Command runner that executes on this machine, for debugging."""

    def __init__(self, sftp_server_path):
        """Remember the path of the sftp server binary handed to `sftp -D`."""
        # Chain to the base class so verbose/dry_run are always defined,
        # even if the caller never assigns them.
        super(LocalCommandRunner, self).__init__()
        self.sftp_server_path = sftp_server_path

    def remote_invocation(self, command):
        """Return `command` unchanged: it is run directly, without ssh."""
        return command

    def sftp_invocation(self):
        """Return the sftp argv that reads its batch from stdin ('-b -').

        '-D' points sftp at the given server binary; per sftp(1) this
        connects directly to a local sftp server rather than via ssh.
        """
        return ['/usr/bin/sftp', '-b', '-', '-q', '-D', self.sftp_server_path]

def find_transfers(args, source_prefix, dest_prefix):
    """Map each argument under `source_prefix` to its `dest_prefix` path.

    One trailing path separator on `source_prefix` is ignored. Every element
    of `args` that starts with the resulting prefix becomes a key of the
    returned dict; its value is `dest_prefix` followed by whatever came
    after the prefix. Arguments that do not match are left out.
    """
    separator = posixpath.sep
    if source_prefix.endswith(separator):
        source_prefix = source_prefix[:-len(separator)]

    prefix_length = len(source_prefix)
    transfers = {}
    for candidate in args:
        if candidate.startswith(source_prefix):
            transfers[candidate] = dest_prefix + candidate[prefix_length:]
    return transfers

def collect_remote_env(local_env=os.environ, prefix='REMOTE_RUN_CHILD_'):
    """Pick the variables in `local_env` whose names start with `prefix`.

    Returns a new dict holding those variables with `prefix` removed from
    each name; the values are passed through untouched.
    """
    skip = len(prefix)
    forwarded = {}
    for name, value in local_env.items():
        if not name.startswith(prefix):
            continue
        forwarded[name[skip:]] = value
    return forwarded

def _build_parser():
    """Create the argument parser describing the command line."""
    parser = argparse.ArgumentParser()

    parser.add_argument('-v', '--verbose', action='store_true', dest='verbose',
                        help='print commands as they are run')
    parser.add_argument('-n', '--dry-run', action='store_true', dest='dry_run',
                        help="print the commands that would have been run, but "
                             "don't actually run them")

    parser.add_argument('--remote-dir', required=True, metavar='PATH',
                        help='(required) a writable temporary path on the '
                             'remote machine')
    parser.add_argument('--input-prefix',
                        help='arguments matching this prefix will be uploaded')
    parser.add_argument('--output-prefix',
                        help='arguments matching this prefix will be both '
                             'uploaded and downloaded')
    parser.add_argument('--remote-input-prefix', default='input',
                        help='input arguments use this prefix on the remote '
                             'machine')
    parser.add_argument('--remote-output-prefix', default='output',
                        help='output arguments use this prefix on the remote '
                             'machine')

    parser.add_argument('-i', '--identity', dest='identity', metavar='FILE',
                        help='an SSH identity file (private key) to use')
    parser.add_argument('-o', '--ssh-option', action='append', default=[],
                        dest='ssh_options', metavar='OPTION',
                        help='extra SSH config options (man ssh_config)')
    parser.add_argument('--debug-as-local', metavar='/PATH/TO/SFTP-SERVER',
                        help='run commands locally instead of over SSH, for '
                             'debugging purposes. The "host" argument is '
                             'omitted.')

    parser.add_argument('host',
                        help='the host to connect to, in the form '
                             '[user@]host[:port]')
    parser.add_argument('command', nargs=argparse.REMAINDER,
                        help='the command to run', metavar='command...')
    return parser


def main():
    """Upload inputs, run the command on the target, download outputs.

    Command arguments that start with --input-prefix or --output-prefix are
    rewritten to paths under --remote-dir. Invalid option combinations are
    reported through parser.error(), which prints usage and exits.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.debug_as_local:
        runner = LocalCommandRunner(args.debug_as_local)
        # There is no host in local mode, so the first positional argument
        # is really the first word of the command.
        args.command.insert(0, args.host)
        del args.host
    else:
        runner = RemoteCommandRunner(args.host, args.identity, args.ssh_options)
    runner.dry_run = args.dry_run
    runner.verbose = args.verbose or args.dry_run

    # These are real checks rather than asserts: they guard the `rm -rf`
    # below and must not disappear when Python runs with -O.
    if args.remote_dir == '/':
        parser.error("--remote-dir must not be '/'")

    upload_files = {}
    download_files = {}
    remote_test_specific_dir = None
    if args.input_prefix:
        if args.remote_input_prefix.startswith(".."):
            parser.error("--remote-input-prefix must not start with '..'")
        remote_dir = posixpath.join(args.remote_dir, args.remote_input_prefix)
        input_files = find_transfers(args.command, args.input_prefix,
                                     remote_dir)
        # upload_files is still empty here, so nothing can collide yet.
        upload_files.update(input_files)
    if args.output_prefix:
        if args.remote_output_prefix.startswith(".."):
            parser.error("--remote-output-prefix must not start with '..'")
        remote_dir = posixpath.join(args.remote_dir, args.remote_output_prefix)
        test_files = find_transfers(args.command, args.output_prefix,
                                    remote_dir)
        # An argument matching both prefixes would get two different remote
        # paths; refuse it. (`in` replaces the Python 2-only dict.has_key.)
        clashes = sorted(f for f in test_files if f in upload_files)
        if clashes:
            parser.error('arguments match both --input-prefix and '
                         '--output-prefix: ' + ', '.join(clashes))
        upload_files.update(test_files)
        download_files.update(test_files)
        remote_test_specific_dir = remote_dir

    if remote_test_specific_dir:
        # posixpath.join discards --remote-dir when the prefix is absolute,
        # so make sure the directory being removed is still underneath it.
        if not remote_test_specific_dir.startswith(args.remote_dir):
            parser.error('--remote-output-prefix must stay inside '
                         '--remote-dir')
        runner.run_remote(['/bin/rm', '-rf', remote_test_specific_dir])
    if upload_files:
        runner.send(upload_files)

    remote_env = collect_remote_env()

    # Replace every transferred local path with its remote counterpart.
    translated_command = [upload_files.get(arg, download_files.get(arg, arg))
                          for arg in args.command]
    runner.run_remote(translated_command, remote_env)

    if download_files:
        runner.fetch(download_files)

# Run main() only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
