-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassifier.py
More file actions
146 lines (122 loc) · 5.76 KB
/
classifier.py
File metadata and controls
146 lines (122 loc) · 5.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
classifier.py — BERT-based multi-modal forensic artifact classifier.
Case-type profiles re-weight suspicion scores across artifact streams.
"""
from __future__ import annotations
import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Tuple
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.linear_model import LogisticRegression
from config import cfg
from artifact_extractors import Artifact
# ── Case-type weight profiles ─────────────────────────────────────────────────
# Outer keys are case types; inner keys are artifact_type values and each inner
# value is the multiplier applied to that artifact type's suspicion score.
CASE_PROFILES: Dict[str, Dict[str, float]] = dict(
    financial_fraud=dict(
        browser=2.5,   # financial site visits, cookies
        metadata=1.8,  # spreadsheets, email archives
        registry=1.0,
        network=1.5,
    ),
    data_theft=dict(
        metadata=2.0,  # large file transfers
        registry=2.5,  # USB mount events
        network=2.5,   # cloud upload anomalies
        browser=1.0,
    ),
    malware=dict(
        registry=2.5,  # persistence keys
        metadata=2.0,  # dropped executables
        network=2.0,   # C2 beaconing
        browser=1.0,
    ),
    # Neutral profile: every artifact type keeps its unscaled score.
    default=dict(
        metadata=1.0,
        registry=1.0,
        browser=1.0,
        network=1.0,
    ),
)
# High-value file extensions, keyed by case type. Extensions are stored
# lowercase with a leading dot; there is no entry for the "default" case type.
CASE_EXT_BOOST: Dict[str, List[str]] = dict(
    financial_fraud=[".xlsx", ".xls", ".csv", ".pst", ".ost", ".mbox", ".pdf"],
    data_theft=[".zip", ".rar", ".7z", ".tar", ".gz", ".db", ".sql"],
    malware=[".exe", ".dll", ".bat", ".ps1", ".vbs", ".js", ".lnk"],
)
@dataclass
class ClassificationResult:
    """Outcome of classifying a single artifact."""

    # The artifact that was scored.
    artifact: Artifact
    # Suspicion score, 0–1.
    suspicion_score: float
    # One of "benign", "suspicious", or "malicious".
    label: str
    # Feature vector computed for the artifact.
    feature_vector: np.ndarray
class ForensicClassifier:
    """
    Two-stage classifier:
    1. BERT encodes the artifact's text summary → embedding.
    2. Logistic regression head maps [embedding + numeric features] → suspicion score.
    Case-type multipliers re-weight the final score.

    Until ``fit`` has been called, a rule-based heuristic replaces stage 2.
    """

    _MODEL_NAME = cfg.bert_model

    def __init__(self, case_type: str = "default"):
        """Load the BERT encoder and select the weight profile for *case_type*.

        Args:
            case_type: Key into ``CASE_PROFILES``. Unknown values fall back to
                ``"default"``.
        """
        self.case_type = case_type if case_type in CASE_PROFILES else "default"
        self.weights = CASE_PROFILES[self.case_type]
        self.ext_boost = set(CASE_EXT_BOOST.get(self.case_type, []))
        self.tokenizer = AutoTokenizer.from_pretrained(self._MODEL_NAME)
        self.bert = AutoModel.from_pretrained(self._MODEL_NAME)
        self.bert.eval()
        # Lightweight head — in production, fine-tune on labelled forensic datasets
        self._head = LogisticRegression(max_iter=1000)
        self._head_fitted = False

    # ── Public API ────────────────────────────────────────────────────────────
    def fit(self, artifacts: List[Artifact], labels: List[int]) -> None:
        """Fit the logistic head on labelled artifacts (0=benign, 1=suspicious).

        Args:
            artifacts: Training artifacts.
            labels: One label per artifact, in the same order.

        Raises:
            ValueError: If *artifacts* is empty or its length differs from
                *labels*.
        """
        artifacts = list(artifacts)
        # Validate before featurizing: each artifact costs a BERT forward pass,
        # so a bad call should fail before that work is done.
        if not artifacts:
            raise ValueError("fit() requires at least one labelled artifact")
        if len(artifacts) != len(labels):
            raise ValueError(
                f"fit() got {len(artifacts)} artifacts but {len(labels)} labels"
            )
        X = np.vstack([self._featurize(a) for a in artifacts])
        self._head.fit(X, labels)
        self._head_fitted = True

    def classify(self, artifact: Artifact) -> ClassificationResult:
        """Score one artifact and assign it a label.

        The base probability comes from the fitted logistic head, or from the
        rule-based heuristic when ``fit`` has not been called. It is multiplied
        by the case-type weight for the artifact's type (and by a further 1.5
        when its extension is high-value for this case type), then capped at 1.0.

        Args:
            artifact: The artifact to classify.

        Returns:
            A ``ClassificationResult`` holding the artifact, the final score,
            the label and the feature vector.
        """
        fv = self._featurize(artifact)
        if self._head_fitted:
            # Column 1 is the probability of the second class in the head's
            # class ordering (label 1 when trained on 0/1 labels).
            prob = self._head.predict_proba(fv.reshape(1, -1))[0, 1]
        else:
            prob = self._heuristic_score(artifact)

        # Apply case-type weight
        weight = self.weights.get(artifact.artifact_type, 1.0)
        ext = artifact.features.get("ext")
        # The boost table is lowercase, so compare case-insensitively; otherwise
        # ".EXE" or ".XLSX" would never be boosted.
        if isinstance(ext, str) and ext.lower() in self.ext_boost:
            weight *= 1.5

        # Cast to a built-in float so a NumPy scalar from predict_proba does
        # not leak into the result.
        score = float(min(prob * weight, 1.0))

        if score > cfg.threshold_malicious:
            label = "malicious"
        elif score > cfg.threshold_suspicious:
            label = "suspicious"
        else:
            label = "benign"
        return ClassificationResult(artifact, score, label, fv)

    def classify_batch(self, artifacts: List[Artifact]) -> List[ClassificationResult]:
        """Classify each artifact in turn and return the results in input order."""
        return [self.classify(a) for a in artifacts]

    # ── Internal ──────────────────────────────────────────────────────────────
    def _bert_embed(self, text: str) -> np.ndarray:
        """Return the BERT embedding at sequence position 0 for *text*.

        The input is truncated to ``cfg.bert_max_length`` tokens and the model
        is run with gradients disabled.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=cfg.bert_max_length,
        )
        with torch.no_grad():
            out = self.bert(**inputs)
        return out.last_hidden_state[:, 0, :].squeeze().numpy()  # [CLS] token

    def _featurize(self, a: Artifact) -> np.ndarray:
        """Build the feature vector: BERT embedding followed by six numeric features.

        Missing feature keys default to 0 / False.
        """
        # Only the first 512 characters of the raw text are embedded.
        embed = self._bert_embed(a.raw[:512])
        numeric = np.array([
            float(a.features.get("size", 0)) / 1e6,  # scaled by 1e6
            float(bool(a.features.get("is_deleted", False))),
            float(bool(a.features.get("anomaly", False))),
            float(bool(a.features.get("is_financial", False))),
            float(bool(a.features.get("possible_timestomp", False))),
            float(a.features.get("visit_count", 0)) / 100,  # scaled by 100
        ], dtype=np.float32)
        return np.concatenate([embed, numeric])

    def _heuristic_score(self, a: Artifact) -> float:
        """Rule-based fallback when no labelled data is available.

        Starts at 0.1 and adds a fixed increment for each indicator present in
        the artifact's features; the total is capped at 1.0.
        """
        score = 0.1
        f = a.features
        if f.get("is_deleted"):
            score += 0.2
        if f.get("anomaly"):
            score += 0.35
        if f.get("possible_timestomp"):
            score += 0.3
        if f.get("is_financial"):
            score += 0.15
        if f.get("type") == "persistence":
            score += 0.4
        if f.get("type") == "usb_mount":
            score += 0.2
        return min(score, 1.0)