#!/usr/bin/env python

'''Experimental function inliner operating on ``parser`` module parse trees.

`find_funcs` records the body of every function definition it can handle,
`inline_funcs` replaces calls to those functions with the recorded
statements, and `inline` ties the two together and compiles the result.

Running this module executes a small benchmark comparing the plain and the
inlined version of a test program.

The code targets Python 2: it relies on ``cmp``, ``map(None, ...)`` and the
``parser`` and ``symbol`` modules.
'''

__docformat__ = 'restructuredtext'
__version__ = '$Id: $'

import copy
import parser
import pprint
import symbol
import token


class Inlineable:
    '''A function whose body has been prepared for inlining.

    :Ivariables:
        `name`
            Name of the function.
        `params`
            List of `InlineableParam`, in declaration order.
        `stmts`
            Statement subtrees of the function body.
    '''

    def __init__(self, name, stmts):
        self.name = name
        self.params = []
        self.stmts = stmts


class InlineableParam:
    '''Placeholder for a formal parameter inside an inlineable body.

    `count` is the number of references to the parameter found in the
    function body (incremented by `replace_func_params`).

    Instances hash and compare by name, so a deep-copied placeholder still
    finds the dictionary entry keyed by the original instance.
    '''

    def __init__(self, name):
        self.name = name
        self.count = 0

    def __hash__(self):
        return hash(self.name)

    def __cmp__(self, other):
        if isinstance(other, InlineableParam):
            return cmp(self.name, other.name)
        else:
            # Never equal to anything that is not a parameter placeholder.
            return -1


def ast_names(tree):
    '''Return a copy of `tree` with numeric node types replaced by names.

    Symbol names are looked up first, then token names; unknown numbers are
    kept as they are.  Anything that is not a sequence starting with an
    int is returned unchanged.
    '''
    if type(tree) in (tuple, list):
        if type(tree[0]) == int:
            return [symbol.sym_name.get(tree[0],
                        token.tok_name.get(tree[0], tree[0]))] + \
                   [ast_names(i) for i in tree[1:]]
    return tree


def ast_print(tree):
    '''Pretty-print `tree` with symbolic node names.

    `tree` may be a sequence or any object providing ``totuple()``.
    '''
    if hasattr(tree, 'totuple'):
        tree = tree.totuple()
    pprint.pprint(ast_names(tree))


def match(pattern, data, vars=None):
    '''Match `data` against `pattern`, collecting variables.

    A list in the pattern is a variable: its first element is the name
    under which the corresponding piece of `data` is stored.  A tuple is
    matched element by element (extra trailing elements in `data` are
    ignored).  Anything else must compare equal to `data`.

    :return: ``(same, vars)`` where `same` is true on a successful match
        and `vars` is the dictionary of collected variables.
    '''
    # from python manual 18.1.6.2
    if vars is None:
        vars = {}
    if type(pattern) == list:
        vars[pattern[0]] = data
        return 1, vars
    if type(pattern) != tuple:
        return (pattern == data), vars
    if len(data) < len(pattern):
        return 0, vars
    # map(None, ...) pads the shorter sequence (the pattern) with None,
    # which ends the loop below.
    for pattern, data in map(None, pattern, data):
        if not pattern:
            break
        same, vars = match(pattern, data, vars)
        if not same:
            break
    return same, vars


def replace(data, replacements):
    '''Recursively substitute elements of `data` in place.

    Every non-list element that is a key of `replacements` is replaced by
    the mapped value; nested lists and tuples are descended into.
    '''
    for i in range(len(data)):
        if type(data[i]) != list and data[i] in replacements:
            data[i] = replacements[data[i]]
        elif type(data[i]) in (list, tuple):
            replace(data[i], replacements)


# Function name -> Inlineable, filled in by find_funcs.
funcs = {}

FUNC_PATTERN = (
    symbol.funcdef,
    (token.NAME, 'def'),
    (token.NAME, ['func_name']),
    ['func_params'],
    (token.COLON, ':'),
    ['func_suite'])

FPDEF_SIMPLE_PATTERN = (
    symbol.fpdef,
    (token.NAME, ['param_name']))

FPDEF_TUPLE_PATTERN = (
    symbol.fpdef,
    (token.LPAR, '('),
    (symbol.fplist, ['fp_list']),
    (token.RPAR, ')'))


def find_funcs(tree):
    '''Record every function definition in `tree` that can be inlined.

    Definitions with tuple parameters or star / double-star parameters are
    reported as too hard and skipped.  For the others, a deep copy of the
    body has its return statements and parameter references rewritten and
    is stored in `funcs` under the function name.
    '''
    if type(tree) not in (tuple, list):
        return

    same, vars = match(FUNC_PATTERN, tree)
    if same:
        func_name = vars['func_name']
        # Drop the node type and the leading two / trailing one suite
        # children, keeping only the statements in between.
        func_stmts = copy.deepcopy(vars['func_suite'][1:][2:-1])
        func_params = vars['func_params'][2]
        toohard = False
        params = []
        param_map = {}
        if func_params[0] == symbol.varargslist:
            for param in func_params[1:]:
                same, vars = match(FPDEF_SIMPLE_PATTERN, param)
                if same:
                    param_name = vars['param_name']
                    params.append(InlineableParam(param_name))
                    param_map[param_name] = params[-1]
                same, vars = match(FPDEF_TUPLE_PATTERN, param)
                if same:
                    print('toohard: formal param is tuple')
                    toohard = True
                    break
                if param[0] in (token.STAR, token.DOUBLESTAR):
                    print('toohard: formal param has star or doublestar')
                    toohard = True
                    break
        if not toohard:
            replace_func_returns(func_stmts)
            replace_func_params(func_stmts, param_map)
            funcs[func_name] = Inlineable(func_name, func_stmts)
            funcs[func_name].params = params

    for i in tree[1:]:
        find_funcs(i)


RETURN_PATTERN = (
    symbol.stmt,
    (symbol.simple_stmt,
     (symbol.small_stmt,
      (symbol.flow_stmt,
       (symbol.return_stmt,
        (token.NAME, 'return'),
        ['value'])))))

# Template for the assignment "___RESULT_NAME = ___RESULT_VALUE" that takes
# the place of a return statement.
RETURN_REPLACE = [
    symbol.stmt,
    [symbol.simple_stmt,
     [symbol.small_stmt,
      [symbol.expr_stmt,
       [symbol.testlist,
        [symbol.test,
         [symbol.and_test,
          [symbol.not_test,
           [symbol.comparison,
            [symbol.expr,
             [symbol.xor_expr,
              [symbol.and_expr,
               [symbol.shift_expr,
                [symbol.arith_expr,
                 [symbol.term,
                  [symbol.factor,
                   [symbol.power,
                    [symbol.atom,
                     [token.NAME, '___RESULT_NAME']]]]]]]]]]]]]]],
       [token.EQUAL, '='],
       '___RESULT_VALUE']],
     [token.NEWLINE, '']]]


def replace_func_returns(stmts):
    '''Rewrite ``return value`` statements in `stmts` as assignments.

    Each matching statement is replaced in place by a copy of
    `RETURN_REPLACE` carrying the returned expression; the
    ``'___RESULT_NAME'`` placeholder is left for `inline_funcs` to fill in.
    '''
    for i in range(len(stmts)):
        if type(stmts[i]) not in (list, tuple):
            continue
        same, vars = match(RETURN_PATTERN, stmts[i])
        if same:
            stmts[i] = copy.deepcopy(RETURN_REPLACE)
            replace(stmts[i], {'___RESULT_VALUE': vars['value']})
        else:
            replace_func_returns(stmts[i])


PARAM_PATTERN = (
    symbol.atom,
    (token.NAME, ['name']))

# Template for a parenthesised expression "(___PARAM_VALUE)".
PARAM_REPLACE = [
    symbol.atom,
    [token.LPAR, '('],
    [symbol.testlist_gexp, '___PARAM_VALUE'],
    [token.RPAR, ')']]


def replace_func_params(stmts, params):
    '''Replace parameter references in `stmts` with placeholders.

    :Parameters:
        `stmts`
            Statement subtrees, modified in place.
        `params`
            Mapping of parameter name to `InlineableParam`.

    Every name atom that refers to a parameter becomes a parenthesised
    atom holding the parameter's `InlineableParam`, whose `count` is
    incremented.
    '''
    for i in range(len(stmts)):
        if type(stmts[i]) not in (list, tuple):
            continue
        same, vars = match(PARAM_PATTERN, stmts[i])
        if same and vars['name'] in params:
            stmts[i] = copy.deepcopy(PARAM_REPLACE)
            name = vars['name']
            replace(stmts[i], {'___PARAM_VALUE': params[name]})
            params[name].count += 1
        else:
            replace_func_params(stmts[i], params)


FUNC_NAME_PATTERN = (symbol.atom, (token.NAME, ['func_name']))


def make_use_result(name):
    '''Return an atom subtree that references the variable `name`.'''
    return (symbol.atom, (token.NAME, name))


_resultnum = 0


def make_result_name():
    '''Return a new, sequentially numbered result variable name.'''
    global _resultnum
    _resultnum += 1
    return '___result_%05d' % _resultnum


def inline_funcs(tree, stmt_container=None, stmt_index=0):
    '''Replace calls to functions recorded in `funcs` throughout `tree`.

    :Parameters:
        `tree`
            Parse tree (nested lists), modified in place.
        `stmt_container`
            Innermost enclosing ``file_input`` or ``suite`` node, into
            which inlined statements are inserted.
        `stmt_index`
            Index in `stmt_container` of the statement being processed.

    A call is recognised as a name atom immediately followed by a trailer
    starting with ``(``.  The pair is replaced by a reference to a fresh
    result variable, and copies of the function's statements (with actual
    arguments and the result name substituted) are inserted before the
    current statement.  Calls using keyword arguments, generator
    arguments, star arguments, a parameter not referenced exactly once in
    the body, or the wrong number of arguments are left untouched.

    :return: `stmt_index` as received if `tree` is itself a statement
        container, otherwise the possibly advanced statement index.
    '''
    if type(tree) not in (tuple, list):
        return stmt_index

    old_index = stmt_index
    if tree[0] in (symbol.file_input, symbol.suite):
        stmt_container = tree
        stmt_index = 1

    i = 1
    func_name = None
    while i < len(tree):
        found, vars = match(FUNC_NAME_PATTERN, tree[i])
        if found:
            # Remember the name; the next sibling decides whether it is
            # being called.
            func_name = vars['func_name']
        elif (func_name and
              func_name in funcs and
              # Check the type of the node itself.  Taking the type of the
              # membership test's result is always true and guards nothing.
              type(tree[i]) in (list, tuple) and
              tree[i][0] == symbol.trailer and
              tree[i][1][0] == token.LPAR):
            inlineable = funcs[func_name]
            func_apply_params = tree[i][2][1:]
            toohard = False
            param_dict = {}
            param_index = 0
            for param in func_apply_params:
                if param[0] == symbol.argument:
                    if len(param) > 2:
                        print('toohard: keyword arg or generator')
                        toohard = True
                        break
                    elif param_index >= len(inlineable.params):
                        # More actual arguments than formal parameters:
                        # leave the call alone instead of failing on the
                        # params lookup.
                        print('toohard: wrong number of arguments')
                        toohard = True
                        break
                    else:
                        formal = inlineable.params[param_index]
                        param_index += 1
                        if formal.count == 1:
                            param_dict[formal] = param[1]
                        else:
                            print('toohard: param used more than once')
                            toohard = True
                            break
                elif param[0] in (token.STAR, token.DOUBLESTAR):
                    print('toohard: star or doublestar in arglist')
                    toohard = True
                    break

            if not toohard and param_index != len(inlineable.params):
                # Too few arguments would leave parameter placeholders
                # unreplaced in the inlined statements.
                print('toohard: wrong number of arguments')
                toohard = True

            if not toohard:
                # Remove the trailer and turn the preceding name atom into
                # a reference to the result variable.
                del tree[i]
                i -= 1
                result_name = make_result_name()
                tree[i] = make_use_result(result_name)

                param_dict['___RESULT_NAME'] = result_name
                for s in inlineable.stmts:
                    s = copy.deepcopy(s)
                    replace(s, param_dict)
                    stmt_container.insert(stmt_index, s)
                    stmt_index += 1
            func_name = None
        else:
            stmt_index = inline_funcs(tree[i], stmt_container, stmt_index)
            func_name = None

        if stmt_container is tree:
            # Children of a container are statements: step past the one
            # just processed, which inserted statements may have shifted.
            stmt_index += 1
            i = stmt_index
        else:
            i += 1

    if stmt_container is tree:
        return old_index
    return stmt_index


def inline(source, debug_tree=False):
    '''Compile `source` with calls to its own functions inlined.

    :Parameters:
        `source`
            Program source text.
        `debug_tree`
            If true, print the transformed parse tree before compiling.

    :return: the compiled code object.
    '''
    ast = parser.suite(source).tolist()
    find_funcs(ast)
    inline_funcs(ast)
    if debug_tree:
        ast_print(ast)
    return parser.sequence2ast(ast).compile()


test1 = '''
def add(p1, p2):
    return p1 + p2

t = 2
for i in range(1000000):
    t = add(t, 2)
print t
'''

import time

# Benchmark: run the test program as written, then with inlining applied.
code = compile(test1, '?', 'exec')
start = time.time()
exec(code)
print('No inlining: %0.2f' % (time.time() - start))

code = inline(test1, debug_tree=False)
start = time.time()
exec(code)
print('Inlined: %0.2f' % (time.time() - start))