#!/usr/bin/env python
# blockifyasm ----- Split disassembly into basic blocks ---------*- python -*-
#
# This source file is part of the Swift.org open source project
#
# Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
# Licensed under Apache License v2.0 with Runtime Library Exception
#
# See https://swift.org/LICENSE.txt for license information
# See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
#
# ----------------------------------------------------------------------------
#
# Splits a disassembled function from lldb into basic blocks.
# 
# Useful to show the control flow graph of a disassembled function.
# The control flow graph can then be viewed with the viewcfg utility:
#
# (lldb) disassemble 
#    <copy-paste output to file.s>
# $ blockifyasm < file.s | viewcfg
#
# ----------------------------------------------------------------------------

from __future__ import print_function

import re
import sys
from collections import defaultdict


def help():
    """Write the blockifyasm command-line usage summary to standard output."""
    usage = """\
Usage:

blockifyasm [-<n>] < file

-<n>: only match <n> significant digits of relative branch addresses
"""
    print(usage)


def main():
    """Split lldb disassembly read from stdin into basic blocks.

    Instruction lines are buffered until a non-instruction line (such as the
    function-name header lldb prints) or the end of input. The buffered
    function is then printed between '{' and '}'. Each basic block starts
    with a 'bbN:' label and the block's predecessor list. Each branch's
    '; <+offset>' comment is replaced by the target block label.
    Non-instruction lines are passed through to stdout.

    An optional first argument '-<n>' compares only the last <n> hex digits
    of addresses. Any other first argument prints the usage and returns.
    """
    # Number of trailing hex digits used when comparing addresses. It only
    # ever shrinks: if a branch target is printed with fewer digits, every
    # address is compared on that shorter suffix.
    addr_len = 16
    if len(sys.argv) >= 2:
        m = re.match(r'^-([0-9]+)$', sys.argv[1])
        if m:
            addr_len = int(m.group(1))
        else:
            help()
            return

    # Instruction lines of the function currently being read.
    lines = []
    # Truncated address -> block number. Branch targets are registered with
    # -1 while reading; print_function() numbers those found in the function.
    block_starts = {}

    # An instruction whose operand is an address followed by a
    # '; <+offset>' / '; <-offset>' comment, i.e. a relative branch.
    branch_re1 = re.compile(r'^\s[-\s>]*0x.*:\s.* 0x([0-9a-f]+)\s*;\s*<[+-]')
    # A mnemonic starting with 'tb' (e.g. arm64 tbz/tbnz); here the trailing
    # comment is optional.
    branch_re2 = re.compile(r'^\s[-\s>]*0x.*:\s+tb.* 0x([0-9a-f]+)\s*(;.*)?')
    # Any instruction line; captures the address and the mnemonic.
    inst_re = re.compile(
        r'^\s[-\s>]*0x([0-9a-f]+)[\s<>0-9+-]*:\s+([a-z0-9.]+)\s')
    # The branch comment that gets replaced by the target block label.
    branch_comment_re = re.compile(r';\s<[+-].*')
    # Mnemonics after which execution never reaches the next instruction.
    non_fall_through_insts = frozenset(
        ['b', 'ret', 'brk', 'jmp', 'retq', 'ud2'])

    def get_branch_addr(line):
        """Return the branch target of line truncated to addr_len, or None."""
        for branch_re in (branch_re1, branch_re2):
            bm = branch_re.match(line)
            if bm:
                return bm.group(1)[-addr_len:]
        return None

    def print_function():
        """Number the basic blocks of the buffered function and print it."""
        if not lines:
            return
        predecessors = defaultdict(list)
        block_num = -1
        next_is_block = True
        prev_is_fallthrough = False

        # Pass 1: assign block numbers and collect predecessors.
        for line in lines:
            m = inst_re.match(line)
            assert m, "non instruction line in function"
            addr = m.group(1)[-addr_len:]
            inst = m.group(2)
            if next_is_block or addr in block_starts:
                if prev_is_fallthrough:
                    predecessors[addr].append(block_num)
                block_num += 1
                block_starts[addr] = block_num
                next_is_block = False

            br_addr = get_branch_addr(line)
            if br_addr:
                predecessors[br_addr].append(block_num)
            # A block ends after any branch and after any instruction that
            # does not fall through (e.g. ret), so whatever follows starts a
            # new block even when nothing branches to it.
            is_terminator = inst in non_fall_through_insts
            if br_addr or is_terminator:
                next_is_block = True
            prev_is_fallthrough = not is_terminator

        # Pass 2: print the function with basic block labels.
        print('{')
        for line in lines:
            m = inst_re.match(line)
            if m:
                addr = m.group(1)[-addr_len:]
                if addr in block_starts:
                    blockstr = 'bb' + str(block_starts[addr]) + ':'
                    if predecessors[addr]:
                        # Pad so the predecessor list starts at a fixed column.
                        print(blockstr + ' ' * (55 - len(blockstr)) + '; preds = ',
                              end='')
                        print(', '.join('bb' + str(pred)
                                        for pred in predecessors[addr]))
                    else:
                        print(blockstr)

            br_addr = get_branch_addr(line)
            # Targets outside this function keep -1 and are left untouched.
            if br_addr and block_starts.get(br_addr, -1) >= 0:
                line = branch_comment_re.sub(
                    '; bb' + str(block_starts[br_addr]), line)

            print(line, end='')
        print('}')

    # Read disassembly code from stdin
    for line in sys.stdin:
        # let the line with the instruction pointer begin with a space
        line = re.sub(r'^-> ', ' ->', line)

        if inst_re.match(line):
            lines.append(line)
            br_addr = get_branch_addr(line)
            if br_addr:
                if len(br_addr) < addr_len:
                    addr_len = len(br_addr)
                    # Re-truncate keys recorded with the longer suffix so all
                    # keys are compared on the same number of digits.
                    block_starts = {k[-addr_len:]: v
                                    for k, v in block_starts.items()}
                block_starts[br_addr] = -1
        else:
            print_function()
            lines = []
            block_starts = {}
            print(line, end='')

    print_function()

if __name__ == "__main__":
    # Run the filter only when executed as a script, not on import.
    main()
