diff --git a/tests/uFTB-with-ftq/env/executor.py b/tests/uFTB-with-ftq/env/executor.py index 6ca14c3..821b1d0 100644 --- a/tests/uFTB-with-ftq/env/executor.py +++ b/tests/uFTB-with-ftq/env/executor.py @@ -3,13 +3,16 @@ from .config import * import os os.sys.path.append(UTILS_PATH) -from BRTParser import BRTParser +from BRTParser import BRTParser, RandomBPTTrace class Executor: """Get program real execution instruction flow.""" def __init__(self, filename, reset_vector=0x80000000): - self._executor = BRTParser().fetch(filename) + if str(os.getenv("RANDOM_BPT")).lower() in ["1", "true"]: + self._executor = RandomBPTTrace().gen(start_address=reset_vector, pc_range_size=100_000) + else: + self._executor = BRTParser().fetch(filename) self._current_branch = next(self._executor) self._current_pc = reset_vector diff --git a/tests/uFTB-with-ftq/env/ftq.py b/tests/uFTB-with-ftq/env/ftq.py index 3e7d2b1..b452e55 100644 --- a/tests/uFTB-with-ftq/env/ftq.py +++ b/tests/uFTB-with-ftq/env/ftq.py @@ -38,7 +38,7 @@ class PredictionStatistician: summary_str += "[Conditional Branches]\n" cond_branches_total = sum([record[0] for record in self.cond_branches_list.values()]) cond_branches_correct = sum([record[1] for record in self.cond_branches_list.values()]) - summary_str += f"Total: {cond_branches_total}, Correct: {cond_branches_correct}, Accuracy: {cond_branches_correct / cond_branches_total}\n" + summary_str += f"Total: {cond_branches_total}, Correct: {cond_branches_correct}, Accuracy: {cond_branches_correct / max(1,cond_branches_total)}\n" for pc, record in self.cond_branches_list.items(): summary_str += f"PC: {hex(pc)}\tTotal: {record[0]}\tCorrect: {record[1]}\tAccuracy: {record[1] / record[0]}\n" @@ -46,14 +46,14 @@ class PredictionStatistician: summary_str += "[Jump Branches]\n" jmp_branches_total = sum([record[1] for record in self.jmp_branches_list.values()]) jmp_branches_correct = sum([record[2] for record in self.jmp_branches_list.values()]) - summary_str += f"Total: 
class RandomBPTTrace(object):
    """Generate a synthetic, randomly built branch trace.

    The trace is a stream of dict records with the keys ``pc``, ``index``,
    ``target``, ``taken`` and ``type``, produced lazily by :meth:`gen`.
    """

    def __init__(self) -> None:
        # Branch type labels that gen() draws from when the caller does not
        # pass its own ``branch_type`` list.  Labels starting with "C." are
        # treated by gen() as compressed (2-byte aligned) instructions.
        self.branch_type = [
            "C.J",
            "C.JR",
            "C.CALL",
            "C.RET",
            "C.JALR",
            "P.JAL",
            "P.CALL",
            "P.RET",
            "*.CBR",
            "I.JAL",
            "I.JALR",
            "I.CALL",
            "I.RET",
        ]
        self.branch_list = []

    def gen(self, start_address=None, pc_range_size=None, br_count=None, max_repeat=100, br_max_count=1000000, max_yield=1e9, seed=None, address_width=39, branch_type=None, min_gap=0x100):
        """Lazily yield a random branch trace.

        A set of branch sites is sampled inside
        ``[start_address, start_address + pc_range_size)``; each site gets a
        type and a default target.  The generator then walks the sorted
        sites: after a taken branch it continues at the first site whose pc
        is >= the target, after a not-taken branch at the next site.

        Args:
            start_address: Lowest pc of the range; random when ``None``.
            pc_range_size: Size of the pc range; random when ``None``.
            br_count: Number of addresses to sample as branch sites (sites
                that collide after alignment are merged); random when
                ``None``.
            max_repeat: How many times a single site may be emitted.
            br_max_count: Bounds the random choice of ``br_count``.
            max_yield: Upper bound on the number of records produced.
            seed: When given, the trace is drawn from a private
                ``random.Random(seed)`` so that it is reproducible even if
                the consumer uses the global ``random`` module between
                ``next()`` calls.  When ``None`` the global RNG is used.
            address_width: Address bits; bounds the random start/range.
            branch_type: Optional list of type labels replacing
                ``self.branch_type``.
            min_gap: Minimum size of a randomly chosen pc range.

        Yields:
            dict: ``{"pc", "index", "target", "taken", "type"}`` where
            ``index`` counts the records from 0.  Only ``"*.CBR"`` records
            can have ``taken`` set to ``False``.
        """
        max_address = 2 ** address_width - 1
        # A private generator keeps a seeded trace independent of any other
        # user of the global ``random`` state.
        rng = random if seed is None else random.Random(seed)
        if start_address is None:
            start_address = rng.randint(0, max_address - min_gap)
        if pc_range_size is None:
            pc_range_size = max(min_gap, rng.randint(start_address + min_gap, max_address) - start_address)
        if br_count is None:
            # max(1, ...) on the divisor avoids a modulo by zero for a
            # one-address range.
            br_count = max(1, rng.randint(1, int(br_max_count / 2)) % max(1, pc_range_size // 2))

        br_types = list(self.branch_type if branch_type is None else branch_type)
        br_types_cp = [br for br in br_types if br.startswith("C.")]
        br_types_nm = [br for br in br_types if not br.startswith("C.")]
        # Any compressed type in the mix switches the alignment to 2 bytes.
        ins_size = 2 if br_types_cp else 4

        def gen_target():
            addr = rng.randint(start_address, start_address + pc_range_size)
            return addr - (addr % ins_size)

        sampled = rng.sample(range(start_address, start_address + pc_range_size), br_count)
        pc_list = sorted({pc - (pc % ins_size) for pc in sampled})
        rp_list = [max_repeat] * len(pc_list)
        tg_list = []
        br_list = []
        for pc in pc_list:
            tg_list.append(gen_target())
            if ins_size == 2:
                if pc % 4 != 0:
                    br_list.append(rng.choice(br_types_cp))  # must be compressed
                else:
                    br_list.append(rng.choice(br_types))  # compressed or normal
            else:
                br_list.append(rng.choice(br_types_nm))  # must be normal

        pc_index = 0
        pc_index_max = len(pc_list)
        rt_yield = 0
        # Checking the budget before yielding caps the output at exactly
        # ``max_yield`` records (and at none for ``max_yield <= 0``).
        while pc_index < pc_index_max and rt_yield < max_yield:
            if rp_list[pc_index] <= 0:
                # This site used up its repeats; fall through to the next one.
                pc_index += 1
                continue
            rp_list[pc_index] -= 1
            br_t = br_list[pc_index]
            pc = pc_list[pc_index]
            taken = True
            target = tg_list[pc_index]
            # Occasionally replace the site's default target by a fresh one.
            if rng.randint(0, 100) < rng.randint(0, 100):
                target = gen_target()
            if br_t == "*.CBR":
                if rng.randint(0, 100) < rng.randint(0, 100):
                    taken = False
            if taken:
                # Continue at the first site at or after the target; lo/hi
                # bounds give the same index as slicing without the copy.
                if target > pc:
                    pc_index = bisect.bisect_left(pc_list, target, pc_index)
                else:
                    pc_index = bisect.bisect_left(pc_list, target, 0, pc_index)
            else:
                pc_index += 1
            yield {"pc": pc, "index": rt_yield, "target": target, "taken": taken, "type": br_t}
            rt_yield += 1