Skip to content

Add function mapping feature #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions asm2vec/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,39 @@ def __init__(self, name, index):
def __str__(self):
return self.name

class FunctionMapper:
def __init__(self, functions=[]):
self.size = 0
self.map = {}
for fn in functions:
self.map[self.size] = {'name': fn.meta['name'], 'binary': fn.meta['file']}
self.size += 1
def add_fn(self):
self.map[self.size] = {'name': fn.meta['name'], 'binary': fn.meta['file']}
self.size += 1
def load_state_dict(self, sd):
self.size = sd['size']
self.map = sd['map_embds_to_fn']
def state_dict(self):
return {'size': self.size, 'map_embds_to_fn' : self.map}
def update(self, functions):
for fn in functions:
self.map[self.size] = {'name': fn.meta['name'], 'binary': fn.meta['file']}
self.size += 1

class Binary:
def __init__(self, name, hash_code):
self.name = name
self.hash_code = hash_code
self.functions = []
def add_function(fn):
self.functions.append(fn)
def load_state_dict(self, sd):
self.name = sd['name']
self.functions = sd['functions']
def state_dict(self):
return {'name': self.name, 'functions': self.functions}

class Tokens:
def __init__(self, name_to_index=None, tokens=None):
self.name_to_index = name_to_index or {}
Expand Down
1 change: 1 addition & 0 deletions asm2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@ def forward(self, inp, pos, neg):
def predict(self, inp, pos):
device, batch_size = inp.device, inp.shape[0]
v = self.v(inp)
print(v)
probs = torch.bmm(self.embeddings_r(torch.arange(self.embeddings_r.num_embeddings).repeat(batch_size, 1).to(device)), v).squeeze(dim=2)
return softmax(probs)
14 changes: 11 additions & 3 deletions asm2vec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch.utils.data import DataLoader, Dataset
from pathlib import Path
from .datatype import Tokens, Function, Instruction
from .datatype import Tokens, Function, Instruction, FunctionMapper
from .model import ASM2VEC

class AsmDataset(Dataset):
Expand Down Expand Up @@ -32,6 +32,7 @@ def load_data(paths, limit=None):
break
with open(filename) as f:
fn = Function.load(f.read())
#print(f'function {fn.meta["name"]} at file -> {fn.meta["file"]}')
functions.append(fn)
tokens.add(fn.tokens())

Expand All @@ -42,13 +43,16 @@ def preprocess(functions, tokens):
for i, fn in enumerate(functions):
for seq in fn.random_walk():
for j in range(1, len(seq) - 1):
#print(f"[+] preprocess : fn {i}, metadata = {fn.meta}, seq : {[i] + seq[j].tokens()}")
x.append([i] + [tokens[token].index for token in seq[j-1].tokens() + seq[j+1].tokens()])
y.append([tokens[token].index for token in seq[j].tokens()])
print(torch.tensor(x)[:, 0])
return torch.tensor(x), torch.tensor(y)

def train(
functions,
tokens,
fn_mapper,
model=None,
embedding_size=100,
batch_size=1024,
Expand Down Expand Up @@ -94,6 +98,7 @@ def train(
callback({
'model': model,
'tokens': tokens,
'fn_mapper': fn_mapper,
'epoch': epoch,
'time': time.time() - start,
'loss': loss_sum / loss_count,
Expand All @@ -102,7 +107,7 @@ def train(

return model

def save_model(path, model, tokens):
def save_model(path, model, tokens, fn_mapper):
torch.save({
'model_params': (
model.embeddings.num_embeddings,
Expand All @@ -111,16 +116,19 @@ def save_model(path, model, tokens):
),
'model': model.state_dict(),
'tokens': tokens.state_dict(),
'function_mapper': fn_mapper.state_dict()
}, path)

def load_model(path, device='cpu'):
checkpoint = torch.load(path, map_location=device)
tokens = Tokens()
tokens.load_state_dict(checkpoint['tokens'])
fn_mapper = FunctionMapper()
fn_mapper.load_state_dict(checkpoint['function_mapper'])
model = ASM2VEC(*checkpoint['model_params'])
model.load_state_dict(checkpoint['model'])
model = model.to(device)
return model, tokens
return model, tokens, fn_mapper

def show_probs(x, y, probs, tokens, limit=None, pretty=False):
if pretty:
Expand Down
20 changes: 20 additions & 0 deletions scripts/extract_vectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
import click
import asm2vec
import json

@click.command()
@click.option('-m', '--model', 'mpath', help='model path', required=True)
@click.option('-c', '--device', default='auto', help='hardware device to be used: cpu / cuda / auto', show_default=True)
@click.option('-o', '--output', 'out', help='output file path that contains vectors', required=True)
def cli(mpath, device, out):
if device == 'auto':
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# load model, tokens
model, token, fn_mapper = asm2vec.utils.load_model(mpath, device=device)
with open(out, 'w+') as f:
json.dump({'embeddings': model.to('cpu').embeddings_f.weight.data.numpy().tolist(), 'function_mapper': fn_mapper.state_dict()}, f)

if __name__ == '__main__':
cli()
7 changes: 5 additions & 2 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,27 @@ def cli(ipath, opath, mpath, limit, embedding_size, batch_size, epochs, neg_samp
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if mpath:
model, tokens = asm2vec.utils.load_model(mpath, device=device)
model, tokens, fn_mapper = asm2vec.utils.load_model(mpath, device=device)
functions, tokens_new = asm2vec.utils.load_data(ipath, limit=limit)
tokens.update(tokens_new)
fn_mapper.update(functions)
model.update(len(functions), tokens.size())
else:
model = None
functions, tokens = asm2vec.utils.load_data(ipath, limit=limit)
fn_mapper = asm2vec.datatype.FunctionMapper(functions)

def callback(context):
progress = f'{context["epoch"]} | time = {context["time"]:.2f}, loss = {context["loss"]:.4f}'
if context["accuracy"]:
progress += f', accuracy = {context["accuracy"]:.4f}'
print(progress)
asm2vec.utils.save_model(opath, context["model"], context["tokens"])
asm2vec.utils.save_model(opath, context["model"], context["tokens"], context['fn_mapper'])

model = asm2vec.utils.train(
functions,
tokens,
fn_mapper,
model=model,
embedding_size=embedding_size,
batch_size=batch_size,
Expand Down