#!/usr/bin/env python
# utils/rth - Resilience test helper
#
# This source file is part of the Swift.org open source project
#
# Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
# Licensed under Apache License v2.0 with Runtime Library Exception
#
# See https://swift.org/LICENSE.txt for license information
# See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors

from __future__ import print_function

import argparse
import glob
import os
import pipes
import shlex
import shutil
import subprocess
import sys

# When true, verbose_print_command() prints every command it is handed.
VERBOSE = True


def verbose_print_command(command):
    """Print *command* as a single shell-quoted line when VERBOSE is set.

    command: iterable of argument strings. Each argument is shell-quoted and
    the results are joined with single spaces. stdout is flushed afterwards.
    Does nothing when VERBOSE is false.
    """
    if not VERBOSE:
        return

    # Prefer shlex.quote (Python 3.3+): pipes.quote is only an alias for it,
    # living in a module that is deprecated since 3.11 and removed in 3.13.
    # Python 2's shlex has no quote(), so fall back to pipes.quote there.
    quote = getattr(shlex, 'quote', None) or pipes.quote
    print(" ".join(quote(arg) for arg in command))
    sys.stdout.flush()


class ResilienceTest(object):
    """Build and run one resilience test against two builds of a library.

    The library source is compiled twice, once with ``-D BEFORE`` into
    ``<tmp_dir>/before`` and once with ``-D AFTER`` into ``<tmp_dir>/after``.
    The test source is compiled against those builds, and one executable is
    linked per (config1, config2) pair: it is linked against the config2
    build and given an rpath pointing at the config1 build. Every executable
    is then run. Any failing step raises AssertionError.
    """

    def __init__(self, target_build_swift, target_run, target_codesign,
                 target_nm, tmp_dir, test_dir, test_src, lib_prefix,
                 lib_suffix, additional_compile_flags,
                 backward_deployment, no_symbol_diff):
        """Store the test configuration and derive the paths used later.

        target_build_swift, target_run, target_codesign, target_nm:
            command lines given as single strings; each is split into an
            argument list with shlex.split.
        tmp_dir: scratch directory; set_up() deletes and recreates it.
        test_dir: directory whose 'Inputs' subdirectory holds the library
            source.
        test_src: path of the test (main) source file.
        lib_prefix, lib_suffix: placed around the library name to form the
            library file name.
        additional_compile_flags: extra flags as one string (shlex-split),
            added to the compile commands but not to the link command.
        backward_deployment: when true, no executable is built against the
            BEFORE library.
        no_symbol_diff: when true, the symbol dump and comparison are
            skipped.
        """
        self.target_build_swift = shlex.split(target_build_swift)
        self.target_run = shlex.split(target_run)
        self.target_codesign = shlex.split(target_codesign)
        self.target_nm = shlex.split(target_nm)
        self.tmp_dir = tmp_dir
        self.test_dir = test_dir
        self.test_src = test_src
        self.lib_prefix = lib_prefix
        self.lib_suffix = lib_suffix
        self.additional_compile_flags = shlex.split(additional_compile_flags)

        self.before_dir = os.path.join(self.tmp_dir, 'before')
        self.after_dir = os.path.join(self.tmp_dir, 'after')
        self.config_dir_map = {'BEFORE': self.before_dir,
                               'AFTER': self.after_dir}

        # The library source name is the test file name minus its first five
        # characters, and the library name drops the last six of that
        # (presumably a 'test_' prefix and a '.swift' suffix -- confirm
        # against the test naming convention).
        self.lib_src_name = os.path.basename(self.test_src)[5:]
        self.lib_name = self.lib_src_name[:-6]
        self.lib_src = os.path.join(self.test_dir, 'Inputs', self.lib_src_name)

        self.backward_deployment = backward_deployment
        self.no_symbol_diff = no_symbol_diff

    def run(self):
        """Run every stage in order and return 0 once all have succeeded."""
        self.set_up()
        self.compile_library()
        self.compare_symbols()
        self.compile_main()
        self.link()
        self.execute()
        return 0

    def set_up(self):
        """Recreate tmp_dir with empty 'after' and 'before' subdirectories."""
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        os.makedirs(self.after_dir)
        os.makedirs(self.before_dir)

    def is_apple_platform(self):
        """Return True if any build-command argument contains '-apple-'."""
        return any('-apple-' in arg for arg in self.target_build_swift)

    def _run_checked(self, command, **call_kwargs):
        """Echo *command*, run it, and raise AssertionError if it fails.

        Extra keyword arguments are forwarded to subprocess.call. The error
        is raised explicitly rather than via an ``assert`` statement so the
        check still happens when Python runs with -O.
        """
        verbose_print_command(command)
        returncode = subprocess.call(command, **call_kwargs)
        if returncode != 0:
            raise AssertionError(str(command))

    def _symbol_file(self, config):
        """Return the path of the symbol dump for the *config* library."""
        return os.path.join(self.config_dir_map[config],
                            self.lib_name + ".symbols")

    def _executable_path(self, config1, config2):
        """Return the path of the executable for the given config pair."""
        return os.path.join(self.tmp_dir,
                            config1.lower() + '_' + config2.lower())

    def compile_library(self):
        """Build, codesign and (optionally) dump symbols of each library."""
        lib_file = self.lib_prefix + self.lib_name + self.lib_suffix
        on_apple = self.is_apple_platform()

        for config in self.config_dir_map:
            output_dylib = os.path.join(self.config_dir_map[config],
                                        lib_file)
            compiler_flags = ['-emit-library', '-emit-module',
                              '-swift-version', '4',
                              '-enable-library-evolution',
                              '-D', config,
                              self.lib_src,
                              '-o', output_dylib]
            if on_apple:
                # Give the library an @rpath-relative install name so the
                # rpath chosen in link() decides which build gets loaded.
                compiler_flags += ['-Xlinker',
                                   '-install_name',
                                   '-Xlinker',
                                   os.path.join('@rpath', lib_file)]

            self._run_checked(self.target_build_swift +
                              self.additional_compile_flags +
                              compiler_flags)
            self._run_checked(self.target_codesign + [output_dylib])

            if not self.no_symbol_diff:
                # Global symbols only, don't display address or type, don't
                # display undefined symbols
                nm_cmd = self.target_nm + ["-gjU", output_dylib]
                with open(self._symbol_file(config), "w") as handle:
                    self._run_checked(nm_cmd, stdout=handle)

    def compare_symbols(self):
        """Raise AssertionError if diff reports BEFORE-side symbol lines.

        Does nothing when no_symbol_diff is set. The diff group formats make
        diff print only lines taken from the BEFORE symbol file in groups
        that differ, and nothing for unchanged groups.
        """
        if self.no_symbol_diff:
            return

        diff_cmd = ["diff",
                    "--changed-group-format=%<",
                    "--unchanged-group-format=",
                    self._symbol_file("BEFORE"),
                    self._symbol_file("AFTER")]
        try:
            # universal_newlines=True makes the captured output text on
            # Python 3 as well; with bytes, building the message below would
            # raise TypeError instead of reporting the removed symbols.
            subprocess.check_output(diff_cmd, universal_newlines=True)
        except subprocess.CalledProcessError as e:
            # diff exits non-zero whenever the files differ; only output
            # (lines coming from the BEFORE file) counts as a failure.
            if e.output:
                raise AssertionError("Removed symbols:\n" + e.output)

    def compile_main(self):
        """Compile the test source to main.o against each library build."""
        for config in self.config_dir_map:
            # If we're testing backward deployment, we only want to build
            # the app against the new version of the library.
            if self.backward_deployment and config == "BEFORE":
                continue

            config_dir = self.config_dir_map[config]
            output_obj = os.path.join(config_dir, 'main.o')
            compiler_flags = ['-D', config, '-c', self.test_src,
                              '-I', config_dir,
                              '-o', output_obj]
            self._run_checked(self.target_build_swift +
                              self.additional_compile_flags +
                              compiler_flags)

    def configs(self):
        """Yield each (config1, config2) pair to link and execute.

        config2 selects the build the executable is compiled and linked
        against; config1 selects the build its rpath points at. With
        backward_deployment set, pairs whose config2 is "BEFORE" are skipped.
        """
        for config1 in self.config_dir_map:
            for config2 in self.config_dir_map:
                if self.backward_deployment and config2 == "BEFORE":
                    continue

                yield (config1, config2)

    def link(self):
        """Link and codesign one executable per pair from configs()."""
        on_apple = self.is_apple_platform()
        rpath_origin = '@executable_path' if on_apple else '$ORIGIN'

        for config1, config2 in self.configs():
            output_obj = self._executable_path(config1, config2)
            link_dir = self.config_dir_map[config2]
            # The executable is written directly into tmp_dir, so the rpath
            # is the config1 directory relative to tmp_dir.
            runtime_dir = os.path.relpath(self.config_dir_map[config1],
                                          self.tmp_dir)

            compiler_flags = [
                '-L', link_dir,
                '-l' + self.lib_name,
                os.path.join(link_dir, 'main.o'),
                '-Xlinker', '-rpath', '-Xlinker',
                os.path.join(rpath_origin, runtime_dir),
                '-o', output_obj
            ]

            if on_apple:
                # NOTE(review): presumably makes dyld bind all symbols at
                # load time so a missing symbol fails at launch -- confirm.
                compiler_flags += ['-Xlinker', '-bind_at_load']

            self._run_checked(self.target_build_swift + compiler_flags)
            self._run_checked(self.target_codesign + [output_obj])

    def execute(self):
        """Run every linked executable through the target_run command.

        The files in the config1 library directory are appended to the
        command line after the executable.
        """
        for config1, config2 in self.configs():
            output_obj = self._executable_path(config1, config2)
            # config_dir_map entries already include tmp_dir; joining tmp_dir
            # again would double a relative tmp_dir and match nothing.
            lib_dir_contents = glob.glob(
                os.path.join(self.config_dir_map[config1], '*'))
            self._run_checked(self.target_run + [output_obj] +
                              lib_dir_contents)


def main():
    """Command-line entry point: parse options and run one ResilienceTest.

    Returns whatever ResilienceTest.run() returns, so the caller can use it
    as the process exit status.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    # --t, --S and --s are passed on as tmp_dir, test_dir and test_src.
    option_table = [
        ('--target-build-swift', {'required': True}),
        ('--target-run', {'required': True}),
        ('--target-codesign', {'default': 'echo'}),
        ('--target-nm', {'default': 'nm'}),
        ('--t', {'required': True}),
        ('--S', {'required': True}),
        ('--s', {'required': True}),
        ('--lib-prefix', {'required': True}),
        ('--lib-suffix', {'required': True}),
        ('--additional-compile-flags', {'default': ''}),
        ('--backward-deployment',
         {'default': False, 'action': 'store_true'}),
        ('--no-symbol-diff',
         {'default': False, 'action': 'store_true'}),
    ]

    parser = argparse.ArgumentParser(description='Resilience test helper')
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    options = parser.parse_args()

    test = ResilienceTest(
        target_build_swift=options.target_build_swift,
        target_run=options.target_run,
        target_codesign=options.target_codesign,
        target_nm=options.target_nm,
        tmp_dir=options.t,
        test_dir=options.S,
        test_src=options.s,
        lib_prefix=options.lib_prefix,
        lib_suffix=options.lib_suffix,
        additional_compile_flags=options.additional_compile_flags,
        backward_deployment=options.backward_deployment,
        no_symbol_diff=options.no_symbol_diff)
    return test.run()


if __name__ == '__main__':
    # Use sys.exit rather than the bare exit(): the latter is injected by the
    # site module for interactive use and is missing under `python -S`.
    sys.exit(main())
