From cee3212139b79f2615c423152dafbc3c9ad18298 Mon Sep 17 00:00:00 2001
From: Stefano De Rosa
Date: Fri, 15 Jul 2022 21:43:18 +0200
Subject: [PATCH] Add function mapping feature

Introduce a FunctionMapper that records, for each embedded function,
its name and the binary it came from, keyed by the function's
embedding index. The mapper is updated during incremental training,
saved inside the model checkpoint, and restored by load_model(). A
small Binary container (name, hash, functions) is added alongside it.
The new scripts/extract_vectors.py dumps the trained function
embeddings together with this mapping as JSON, so embedding rows can
be traced back to concrete functions.
---
 asm2vec/datatype.py        | 33 +++++++++++++++++++++++++++++++++
 asm2vec/model.py           |  1 +
 asm2vec/utils.py           | 14 +++++++++++---
 scripts/extract_vectors.py | 20 ++++++++++++++++++++
 scripts/train.py           |  7 +++++--
 5 files changed, 70 insertions(+), 5 deletions(-)
 create mode 100644 scripts/extract_vectors.py

diff --git a/asm2vec/datatype.py b/asm2vec/datatype.py
index a3cd39b..2330532 100644
--- a/asm2vec/datatype.py
+++ b/asm2vec/datatype.py
@@ -10,6 +10,39 @@ def __init__(self, name, index):
     def __str__(self):
         return self.name
 
+class FunctionMapper:
+    def __init__(self, functions=None):
+        self.size = 0
+        self.map = {}
+        for fn in functions or []:
+            self.map[self.size] = {'name': fn.meta['name'], 'binary': fn.meta['file']}
+            self.size += 1
+    def add_fn(self, fn):
+        self.map[self.size] = {'name': fn.meta['name'], 'binary': fn.meta['file']}
+        self.size += 1
+    def load_state_dict(self, sd):
+        self.size = sd['size']
+        self.map = sd['map_embds_to_fn']
+    def state_dict(self):
+        return {'size': self.size, 'map_embds_to_fn': self.map}
+    def update(self, functions):
+        for fn in functions:
+            self.map[self.size] = {'name': fn.meta['name'], 'binary': fn.meta['file']}
+            self.size += 1
+
+class Binary:
+    def __init__(self, name, hash_code):
+        self.name = name
+        self.hash_code = hash_code
+        self.functions = []
+    def add_function(self, fn):
+        self.functions.append(fn)
+    def load_state_dict(self, sd):
+        self.name, self.hash_code = sd['name'], sd['hash_code']
+        self.functions = sd['functions']
+    def state_dict(self):
+        return {'name': self.name, 'hash_code': self.hash_code, 'functions': self.functions}
+
 class Tokens:
     def __init__(self, name_to_index=None, tokens=None):
         self.name_to_index = name_to_index or {}
diff --git a/asm2vec/model.py b/asm2vec/model.py
index 301f3be..b382d01 100644
--- a/asm2vec/model.py
+++ b/asm2vec/model.py
@@ -39,5 +39,6 @@ def forward(self, inp, pos, neg):
     def predict(self, inp, pos):
         device, batch_size = inp.device, inp.shape[0]
         v = self.v(inp)
+        print(v)
         probs = torch.bmm(self.embeddings_r(torch.arange(self.embeddings_r.num_embeddings).repeat(batch_size, 1).to(device)), v).squeeze(dim=2)
         return softmax(probs)
diff --git a/asm2vec/utils.py b/asm2vec/utils.py
index 4f9aa25..75531be 100644
--- a/asm2vec/utils.py
+++ b/asm2vec/utils.py
@@ -3,7 +3,7 @@
 import torch
 from torch.utils.data import DataLoader, Dataset
 from pathlib import Path
-from .datatype import Tokens, Function, Instruction
+from .datatype import Tokens, Function, Instruction, FunctionMapper
 from .model import ASM2VEC
 
 class AsmDataset(Dataset):
@@ -32,6 +32,7 @@ def load_data(paths, limit=None):
             break
         with open(filename) as f:
             fn = Function.load(f.read())
+            #print(f'function {fn.meta["name"]} at file -> {fn.meta["file"]}')
             functions.append(fn)
             tokens.add(fn.tokens())
 
@@ -42,13 +43,16 @@ def preprocess(functions, tokens):
     for i, fn in enumerate(functions):
         for seq in fn.random_walk():
             for j in range(1, len(seq) - 1):
+                #print(f"[+] preprocess : fn {i}, metadata = {fn.meta}, seq : {[i] + seq[j].tokens()}")
                 x.append([i] + [tokens[token].index for token in seq[j-1].tokens() + seq[j+1].tokens()])
                 y.append([tokens[token].index for token in seq[j].tokens()])
+    print(torch.tensor(x)[:, 0])
     return torch.tensor(x), torch.tensor(y)
 
 def train(
     functions,
     tokens,
+    fn_mapper,
     model=None,
     embedding_size=100,
     batch_size=1024,
@@ -94,6 +98,7 @@
             callback({
                 'model': model,
                 'tokens': tokens,
+                'fn_mapper': fn_mapper,
                 'epoch': epoch,
                 'time': time.time() - start,
                 'loss': loss_sum / loss_count,
@@ -102,7 +107,7 @@
 
     return model
 
-def save_model(path, model, tokens):
+def save_model(path, model, tokens, fn_mapper):
     torch.save({
         'model_params': (
             model.embeddings.num_embeddings,
@@ -111,16 +116,19 @@
         ),
         'model': model.state_dict(),
         'tokens': tokens.state_dict(),
+        'function_mapper': fn_mapper.state_dict()
     }, path)
 
 def load_model(path, device='cpu'):
     checkpoint = torch.load(path, map_location=device)
     tokens = Tokens()
     tokens.load_state_dict(checkpoint['tokens'])
+    fn_mapper = FunctionMapper()
+    fn_mapper.load_state_dict(checkpoint['function_mapper'])
     model = ASM2VEC(*checkpoint['model_params'])
     model.load_state_dict(checkpoint['model'])
     model = model.to(device)
-    return model, tokens
+    return model, tokens, fn_mapper
 
 def show_probs(x, y, probs, tokens, limit=None, pretty=False):
     if pretty:
diff --git a/scripts/extract_vectors.py b/scripts/extract_vectors.py
new file mode 100644
index 0000000..2436bf7
--- /dev/null
+++ b/scripts/extract_vectors.py
@@ -0,0 +1,20 @@
+import torch
+import click
+import asm2vec
+import json
+
+@click.command()
+@click.option('-m', '--model', 'mpath', help='model path', required=True)
+@click.option('-c', '--device', default='auto', help='hardware device to be used: cpu / cuda / auto', show_default=True)
+@click.option('-o', '--output', 'out', help='output file path that contains vectors', required=True)
+def cli(mpath, device, out):
+    if device == 'auto':
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    # load model, tokens
+    model, tokens, fn_mapper = asm2vec.utils.load_model(mpath, device=device)
+    with open(out, 'w+') as f:
+        json.dump({'embeddings': model.to('cpu').embeddings_f.weight.data.numpy().tolist(), 'function_mapper': fn_mapper.state_dict()}, f)
+
+if __name__ == '__main__':
+    cli()
diff --git a/scripts/train.py b/scripts/train.py
index 98391f4..bcb2636 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -19,24 +19,27 @@ def cli(ipath, opath, mpath, limit, embedding_size, batch_size, epochs, neg_samp
         device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
     if mpath:
-        model, tokens = asm2vec.utils.load_model(mpath, device=device)
+        model, tokens, fn_mapper = asm2vec.utils.load_model(mpath, device=device)
         functions, tokens_new = asm2vec.utils.load_data(ipath, limit=limit)
         tokens.update(tokens_new)
+        fn_mapper.update(functions)
         model.update(len(functions), tokens.size())
     else:
         model = None
         functions, tokens = asm2vec.utils.load_data(ipath, limit=limit)
+        fn_mapper = asm2vec.datatype.FunctionMapper(functions)
 
     def callback(context):
         progress = f'{context["epoch"]} | time = {context["time"]:.2f}, loss = {context["loss"]:.4f}'
         if context["accuracy"]:
            progress += f', accuracy = {context["accuracy"]:.4f}'
         print(progress)
-        asm2vec.utils.save_model(opath, context["model"], context["tokens"])
+        asm2vec.utils.save_model(opath, context["model"], context["tokens"], context["fn_mapper"])
 
     model = asm2vec.utils.train(
         functions,
         tokens,
+        fn_mapper,
         model=model,
         embedding_size=embedding_size,
         batch_size=batch_size,
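
A minimal consumer sketch (not part of the patch itself): assuming a
trained checkpoint was exported with the new script, e.g.

    python scripts/extract_vectors.py -m model.pt -o vectors.json

(model.pt and vectors.json are placeholder paths), the exported
embedding rows can be traced back to functions through the saved
mapper:

    import json

    # vectors.json layout, per scripts/extract_vectors.py:
    #   'embeddings'      -> one row per trained function (the embeddings_f weights)
    #   'function_mapper' -> FunctionMapper.state_dict(), i.e.
    #                        {'size': N, 'map_embds_to_fn': {index: {'name', 'binary'}}}
    with open('vectors.json') as f:  # placeholder path from the command above
        data = json.load(f)

    embeddings = data['embeddings']
    mapping = data['function_mapper']['map_embds_to_fn']

    for idx, meta in mapping.items():
        # JSON serialization turns the integer indices into strings,
        # so convert them back before indexing the embedding list
        vec = embeddings[int(idx)]
        print(f"{meta['binary']}:{meta['name']} -> {len(vec)}-dim vector")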