add random br trace

This commit is contained in:
yaozhicheng 2024-05-23 15:36:24 +08:00
parent 6ac3fa7ba2
commit 8287fc4e5c
3 changed files with 104 additions and 6 deletions

View File

@ -3,13 +3,16 @@ from .config import *
import os import os
os.sys.path.append(UTILS_PATH) os.sys.path.append(UTILS_PATH)
from BRTParser import BRTParser from BRTParser import BRTParser, RandomBPTTrace
class Executor: class Executor:
"""Get program real execution instruction flow.""" """Get program real execution instruction flow."""
def __init__(self, filename, reset_vector=0x80000000): def __init__(self, filename, reset_vector=0x80000000):
self._executor = BRTParser().fetch(filename) if str(os.getenv("RANDOM_BPT")).lower() in ["1", "true"]:
self._executor = RandomBPTTrace().gen(start_address=reset_vector, pc_range_size=100_000)
else:
self._executor = BRTParser().fetch(filename)
self._current_branch = next(self._executor) self._current_branch = next(self._executor)
self._current_pc = reset_vector self._current_pc = reset_vector

View File

@ -38,7 +38,7 @@ class PredictionStatistician:
summary_str += "[Conditional Branches]\n" summary_str += "[Conditional Branches]\n"
cond_branches_total = sum([record[0] for record in self.cond_branches_list.values()]) cond_branches_total = sum([record[0] for record in self.cond_branches_list.values()])
cond_branches_correct = sum([record[1] for record in self.cond_branches_list.values()]) cond_branches_correct = sum([record[1] for record in self.cond_branches_list.values()])
summary_str += f"Total: {cond_branches_total}, Correct: {cond_branches_correct}, Accuracy: {cond_branches_correct / cond_branches_total}\n" summary_str += f"Total: {cond_branches_total}, Correct: {cond_branches_correct}, Accuracy: {cond_branches_correct / max(1,cond_branches_total)}\n"
for pc, record in self.cond_branches_list.items(): for pc, record in self.cond_branches_list.items():
summary_str += f"PC: {hex(pc)}\tTotal: {record[0]}\tCorrect: {record[1]}\tAccuracy: {record[1] / record[0]}\n" summary_str += f"PC: {hex(pc)}\tTotal: {record[0]}\tCorrect: {record[1]}\tAccuracy: {record[1] / record[0]}\n"
@ -46,14 +46,14 @@ class PredictionStatistician:
summary_str += "[Jump Branches]\n" summary_str += "[Jump Branches]\n"
jmp_branches_total = sum([record[1] for record in self.jmp_branches_list.values()]) jmp_branches_total = sum([record[1] for record in self.jmp_branches_list.values()])
jmp_branches_correct = sum([record[2] for record in self.jmp_branches_list.values()]) jmp_branches_correct = sum([record[2] for record in self.jmp_branches_list.values()])
summary_str += f"Total: {jmp_branches_total}, Correct: {jmp_branches_correct}, Accuracy: {jmp_branches_correct / jmp_branches_total}\n" summary_str += f"Total: {jmp_branches_total}, Correct: {jmp_branches_correct}, Accuracy: {jmp_branches_correct / max(1,jmp_branches_total)}\n"
for pc, record in self.jmp_branches_list.items(): for pc, record in self.jmp_branches_list.items():
summary_str += f"PC: {hex(pc)}\tType: {record[0]}\tTotal: {record[1]}\tCorrect: {record[2]}\tAccuracy: {record[2] / record[1]}\n" summary_str += f"PC: {hex(pc)}\tType: {record[0]}\tTotal: {record[1]}\tCorrect: {record[2]}\tAccuracy: {record[2] / max(1,record[1])}\n"
summary_str += "[All Branches]\n" summary_str += "[All Branches]\n"
total = cond_branches_total + jmp_branches_total total = cond_branches_total + jmp_branches_total
correct = cond_branches_correct + jmp_branches_correct correct = cond_branches_correct + jmp_branches_correct
summary_str += f"Total: {total}, Correct: {correct}, Accuracy: {correct / total}\n" summary_str += f"Total: {total}, Correct: {correct}, Accuracy: {correct /max(1, total)}\n"
info(summary_str) info(summary_str)

View File

@ -3,6 +3,9 @@
import os import os
from .util import * from .util import *
import time import time
import random
import bisect
class BRTParser: class BRTParser:
@ -129,3 +132,95 @@ class BRTParser:
elif ".CALL" in k: elif ".CALL" in k:
all_cal += data["taken"] all_cal += data["taken"]
print("%5d %8s %8d %8d %8d (%d checks, ins.ret - ins.call = %d)\n" % (len(keys), "ALL", count, taken, notaken, taken + notaken, all_ret - all_cal)) print("%5d %8s %8d %8d %8d (%d checks, ins.ret - ins.call = %d)\n" % (len(keys), "ALL", count, taken, notaken, taken + notaken, all_ret - all_cal))
class RandomBPTTrace(object):
    """Generate a synthetic, randomly shaped branch trace.

    The trace is produced lazily by :meth:`gen` as a stream of dict records
    with the keys ``pc``, ``index``, ``target``, ``taken`` and ``type``.
    """

    def __init__(self) -> None:
        # Branch kinds that gen() draws from by default. Entries starting
        # with "C." are treated by gen() as compressed (2-byte) instructions.
        self.branch_type = [
            "C.J",
            "C.JR",
            "C.CALL",
            "C.RET",
            "C.JALR",
            "P.JAL",
            "P.CALL",
            "P.RET",
            "*.CBR",
            "I.JAL",
            "I.JALR",
            "I.CALL",
            "I.RET",
        ]
        # Not used by gen(); kept because code outside this class may read it.
        self.branch_list = []

    def gen(self, start_address=None, pc_range_size=None, br_count=None, max_repeat=100, br_max_count=1000000, max_yield=1e9, seed=None, address_width=39, branch_type=None, min_gap=0x100):
        """Yield randomly generated branch records.

        A set of branch PCs is sampled inside
        ``[start_address, start_address + pc_range_size)``; each gets a type
        and a default target.  The generator then walks those PCs: a taken
        branch moves to the first sampled PC that is >= its target, a
        not-taken branch falls through to the next sampled PC.  The walk
        stops when it runs past the last PC or ``max_yield`` records have
        been produced.

        Args:
            start_address: Lowest address of the PC window; random if None.
            pc_range_size: Size of the PC window; random if None.
            br_count: Number of addresses to sample as branch PCs (the
                result may be smaller after alignment de-duplication);
                random if None.  Clamped to the window size.
            max_repeat: How many times each branch PC may be emitted.
            br_max_count: Upper bound used when ``br_count`` is random.
            max_yield: Maximum number of records to produce.
            seed: If not None, seeds a private RNG so the trace is
                reproducible and the global ``random`` state is untouched.
                If None, the module-level ``random`` state is used.
            address_width: Bit width bounding randomly chosen addresses.
            branch_type: Branch type names to draw from; defaults to
                ``self.branch_type``.
            min_gap: Minimum window size when the window is random.

        Yields:
            dict: ``{"pc", "index", "target", "taken", "type"}`` where
            ``index`` counts emitted records from 0.
        """
        # A private generator keeps seeded traces reproducible even when
        # other code draws from the global RNG between next() calls.
        rng = random.Random(seed) if seed is not None else random

        max_address = 2 ** address_width - 1
        if start_address is None:
            start_address = rng.randint(0, max_address - min_gap)
        if pc_range_size is None:
            pc_range_size = max(min_gap, rng.randint(start_address + min_gap, max_address) - start_address)
        if br_count is None:
            # max(1, ...) guards keep tiny windows / tiny br_max_count from
            # producing an empty randint range or a modulo by zero.
            upper = max(1, int(br_max_count / 2))
            modulus = max(1, int(pc_range_size / 2))
            br_count = max(1, rng.randint(1, upper) % modulus)

        br_types = branch_type if branch_type is not None else self.branch_type
        br_types_cp = [name for name in br_types if name.startswith("C.")]
        br_types_nm = [name for name in br_types if not name.startswith("C.")]
        # Any compressed type in the pool switches alignment to 2 bytes.
        ins_size = 2 if br_types_cp else 4

        def gen_target():
            addr = rng.randint(start_address, start_address + pc_range_size)
            return addr - (addr % ins_size)

        pc_space = range(start_address, start_address + pc_range_size)
        # sample() raises if asked for more items than the population holds.
        br_count = max(0, min(br_count, len(pc_space)))
        aligned_pcs = sorted({pc - (pc % ins_size) for pc in rng.sample(pc_space, br_count)})

        pc_list = []
        tg_list = []
        br_list = []
        for pc in aligned_pcs:
            pc_list.append(pc)
            tg_list.append(gen_target())
            if ins_size != 2:
                br_list.append(rng.choice(br_types_nm))  # must be normal
            elif pc % 4 != 0:
                br_list.append(rng.choice(br_types_cp))  # must be compressed
            else:
                br_list.append(rng.choice(br_types))  # compressed or normal
        rp_list = [max_repeat] * len(pc_list)

        pc_index = 0
        pc_count = len(pc_list)
        emitted = 0
        while pc_index < pc_count and emitted < max_yield:
            if rp_list[pc_index] <= 0:
                # This PC has used up its repeat budget; fall through.
                pc_index += 1
                continue
            rp_list[pc_index] -= 1
            br_t = br_list[pc_index]
            pc = pc_list[pc_index]
            target = tg_list[pc_index]
            # NOTE(review): nesting reconstructed from an indentation-stripped
            # view; confirm the target override is meant for every branch type.
            if rng.randint(0, 100) < rng.randint(0, 100):
                target = gen_target()
            taken = True
            # Only conditional branches may be not-taken.
            if br_t == "*.CBR" and rng.randint(0, 100) < rng.randint(0, 100):
                taken = False
            if taken:
                # pc_list is strictly increasing, so one search over the
                # whole list lands on the same index as searching only the
                # forward or backward half, without copying a slice.
                pc_index = bisect.bisect_left(pc_list, target)
            else:
                pc_index += 1
            yield {"pc": pc, "index": emitted, "target": target, "taken": taken, "type": br_t}
            emitted += 1