Skip to content

Commit 0e2b4b7

Browse files
committed
Add parameter in sample.py to write vector files for each anomaly type
1 parent 305ec02 commit 0e2b4b7

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

sample.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
parser.add_argument("--train_ratio", default=0.01, help="fraction of normal data used for training", type=float)
1010
parser.add_argument("--time_window", default=None, help="size of the fixed time window in seconds (setting this parameter replaces session-based with window-based grouping)", type=float)
1111
parser.add_argument("--sample_ratio", default=1.0, help="fraction of data sampled from normal and anomalous events", type=float)
12-
parser.add_argument("--sorting", default="random", help="sorting mode", type=str, choices=['random', 'chronological'])
12+
parser.add_argument("--sorting", default="random", help="sorting mode: pick sequences randomly (random) or only pick the first ones (chronological)", type=str, choices=['random', 'chronological'])
13+
parser.add_argument("--anomaly_types", default="False", help="set to True to additionally create sequence files for each anomaly type (files are named <dataset>_test_abnormal_<anomaly>", type=str, choices=['True', 'False'])
1314

1415
params = vars(parser.parse_args())
1516
source = params["data_dir"]
1617
train_ratio = params["train_ratio"]
1718
tw = params["time_window"]
1819
sample_ratio = params["sample_ratio"]
1920
sorting = params["sorting"]
21+
output_anomaly_types = params["anomaly_types"]
2022

2123
if source in ['adfa_verazuo', 'hdfs_xu', 'hdfs_loghub', 'openstack_loghub', 'openstack_parisakalaki', 'hadoop_loghub', 'awsctd_djpasco'] and tw is not None:
2224
# Only BGL and Thunderbird should be used with time-window based grouping
@@ -124,7 +126,7 @@ def do_sample(source, train_ratio, sorting="random", tw=None):
124126
train.write(str(seq_id) + ',' + ' '.join([str(event) for event in event_list]) + '\n')
125127
else:
126128
test_norm.write(str(seq_id) + ',' + ' '.join([str(event) for event in event_list]) + '\n')
127-
elif label == "Anomaly":
129+
elif label == "Anomaly" or output_anomaly_types == "False":
128130
for seq_id, event_list in seq_id_dict.items():
129131
test_abnormal.write(str(seq_id) + ',' + ' '.join([str(event) for event in event_list]) + '\n')
130132
else:

0 commit comments

Comments
 (0)