diff --git a/ntad/__init__.py b/ntad/__init__.py new file mode 100644 index 0000000..1fbc595 --- /dev/null +++ b/ntad/__init__.py @@ -0,0 +1,4 @@ +""" +NTAD - Network Threat Attack Discovery +MITRE ATT&CK 知识图谱查询模块 +""" diff --git a/ntad/answer_generator.py b/ntad/answer_generator.py new file mode 100644 index 0000000..2c0e0f2 --- /dev/null +++ b/ntad/answer_generator.py @@ -0,0 +1,321 @@ +""" +LLM 回答生成模块 +将 BM25 检索结果 + Neo4j 图查询结果整合,用千问 API 生成自然语言回答 +""" + +import os + +from dotenv import load_dotenv +from openai import OpenAI + +load_dotenv() + +# 系统提示词 +SYSTEM_PROMPT = """你是 **NTAD 智能攻防问答系统**,一个基于 MITRE ATT&CK 知识图谱的智能安全问答助手。 + +## 自我介绍(当用户问"你是谁"、"你是什么"、"介绍一下"等问题时): + +我是 NTAD 智能攻防问答系统,基于全球最权威的 MITRE ATT&CK 网络威胁知识库构建。系统将 25,000+ 个安全实体(攻击技术、威胁组织、恶意软件、防御措施等)构建为知识图谱,存储在 Neo4j 图数据库中,支持自然语言提问和智能检索。我能帮您查询攻击技术详情、追踪威胁组织活动、获取防御建议,实现从问题到结构化回答的端到端智能分析。 + +## 回答规则: + +1. **引用来源**:回答中必须引用相关的 attack_id(如 T1059、TA0002) +2. **结构化输出**:使用清晰的层次结构(标题、列表、表格) +3. **专业但易懂**:技术术语要解释,让非安全专业人员也能理解 +4. **实用建议**:在解释攻击技术的同时,提供防御建议 +5. 
def _get_client() -> tuple:
    """Create the Qwen (OpenAI-compatible) API client from environment settings.

    Reads ``QWEN_API_KEY``, ``QWEN_BASE_URL`` and ``QWEN_MODEL`` from the
    process environment.

    Returns:
        A ``(client, model_name)`` pair.

    Raises:
        ValueError: If ``QWEN_API_KEY`` is missing, empty, or still set to the
            template placeholder value.
    """
    key = os.getenv("QWEN_API_KEY")
    # Guard first: a missing key and the untouched placeholder are both unusable.
    if key in (None, "", "your_api_key_here"):
        raise ValueError("请在 .env 文件中配置 QWEN_API_KEY")

    endpoint = os.getenv(
        "QWEN_BASE_URL",
        "https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    model_name = os.getenv("QWEN_MODEL", "qwen-plus")

    client = OpenAI(api_key=key, base_url=endpoint)
    return client, model_name
**{r['attack_id']}** - {r['name']} (置信度: {r['score']:.2f})\n" + f" 类型: {r['type']}\n" + f" 描述: {r['snippet']}\n" + f" 链接: {r['url']}\n" + ) + if not _can_append(entry): + break + context_parts.append(entry) + + # Neo4j 图查询结果 + if graph_data: + # 技术详情 + if graph_data.get("techniques"): + header = "\n## 技术详情\n" + if _can_append(header): + context_parts.append(header) + for tech in graph_data["techniques"][:15]: + entry = ( + f"### {tech['attack_id']} - {tech['name']}\n" + f"- 平台: {', '.join(tech.get('platforms', []))}\n" + f"- 描述: {tech.get('description', '')[:200]}\n" + ) + if not _can_append(entry): + break + context_parts.append(entry) + + # 关联战术 + if graph_data.get("tactics"): + header = "\n## 关联战术\n" + if _can_append(header): + context_parts.append(header) + for tac in graph_data["tactics"][:10]: + entry = f"- **{tac['attack_id']}** - {tac['name']}\n" + if not _can_append(entry): + break + context_parts.append(entry) + + # 关联组织 + if graph_data.get("groups"): + header = "\n## 关联威胁组织\n" + if _can_append(header): + context_parts.append(header) + for grp in graph_data["groups"][:10]: + aliases = ", ".join(grp.get("aliases", [])[:3]) + entry = ( + f"- **{grp['attack_id']}** - {grp['name']}" + f"{f' (别名: {aliases})' if aliases else ''}\n" + ) + if not _can_append(entry): + break + context_parts.append(entry) + + # 关联软件 + if graph_data.get("software"): + header = "\n## 关联软件\n" + if _can_append(header): + context_parts.append(header) + for sw in graph_data["software"][:15]: + sw_type = "恶意软件" if sw.get("type") == "malware" else "工具" + entry = ( + f"- **{sw['attack_id']}** - {sw['name']} [{sw_type}]\n" + f" 描述: {sw.get('description', '')[:150]}\n" + ) + if not _can_append(entry): + break + context_parts.append(entry) + + # 缓解措施 + if graph_data.get("mitigations"): + header = "\n## 缓解措施\n" + if _can_append(header): + context_parts.append(header) + for mit in graph_data["mitigations"][:10]: + entry = ( + f"- **{mit['attack_id']}** - {mit['name']}\n" + f" 
def _build_messages(user_question: str, search_results: list, graph_data: dict = None, history: list = None) -> list:
    """Assemble the chat-completion message list for a single question.

    The list always starts with the system prompt, is followed by at most the
    last six entries of ``history`` (when given), and ends with one user
    message that embeds the question together with the formatted retrieval
    context.

    Args:
        user_question: The question exactly as the user asked it.
        search_results: BM25 hits, passed on to ``_format_search_context``.
        graph_data: Optional graph query result, passed on likewise.
        history: Optional list of earlier chat messages.

    Returns:
        The list of ``{"role": ..., "content": ...}`` message dicts.
    """
    retrieved = _format_search_context(search_results, graph_data)

    # Same text as a triple-quoted template, spelled out line by line.
    prompt = (
        f"用户问题: {user_question}\n"
        "\n"
        "以下是检索到的相关信息:\n"
        "\n"
        f"{retrieved}\n"
        "\n"
        "请基于以上信息回答用户的问题。"
    )

    # Only the most recent turns are replayed to keep the prompt short.
    recent_turns = history[-6:] if history else []

    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        *recent_turns,
        {"role": "user", "content": prompt},
    ]
response.choices[0].message.content.strip() + + except Exception as e: + fallback = f"⚠️ LLM 回答生成失败: {e}\n\n以下是检索结果摘要:\n\n" + for r in search_results[:3]: + fallback += f"- **{r['attack_id']}** - {r['name']} (置信度: {r['score']:.2f})\n" + fallback += f" {r['snippet'][:100]}...\n\n" + return fallback + + +def generate_answer_stream( + user_question: str, + search_results: list, + graph_data: dict = None, + history: list = None, +): + """ + 流式生成回答(逐 token 返回) + + Yields: + 每次产出的文本片段 + """ + client, model = _get_client() + messages = _build_messages(user_question, search_results, graph_data, history) + + try: + stream = client.chat.completions.create( + model=model, + messages=messages, + temperature=0.3, + max_tokens=2000, + stream=True, + ) + for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + + except Exception as e: + yield f"⚠️ LLM 回答生成失败: {e}\n\n以下是检索结果摘要:\n\n" + for r in search_results[:3]: + yield f"- **{r['attack_id']}** - {r['name']} (置信度: {r['score']:.2f})\n" + yield f" {r['snippet'][:100]}...\n\n" + + +def generate_low_confidence_answer(user_question: str, search_results: list) -> str: + """ + 低置信度时的回答 + + 当 BM25 检索置信度 < 0.7 时,返回检索结果并提示信息不足 + """ + if not search_results: + return ( + "抱歉,我没有找到与您问题相关的 MITRE ATT&CK 技术信息。\n\n" + "建议您:\n" + "1. 尝试使用更具体的关键词(如具体的攻击技术名称或 attack_id)\n" + "2. 参考 MITRE ATT&CK 官网: https://attack.mitre.org/\n" + "3. 描述您想了解的攻击场景或技术特点" + ) + + answer = "🔍 检索结果置信度较低,以下是最接近的信息:\n\n" + + for i, r in enumerate(search_results[:5], 1): + answer += f"**{i}. {r['attack_id']}** - {r['name']} (置信度: {r['score']:.2f})\n" + answer += f" {r['snippet'][:150]}...\n\n" + + answer += ( + "\n💡 以上结果可能不完全匹配您的问题。建议:\n" + "1. 提供更具体的关键词\n" + "2. 明确您想了解的攻击技术或战术阶段\n" + "3. 
def rrf_fusion(multi_results: list, k: int = 60) -> list:
    """Merge several ranked result lists with Reciprocal Rank Fusion (RRF).

    Every occurrence of an ``attack_id`` at zero-based rank ``r`` in one list
    contributes ``1 / (k + r + 1)`` to that id's fused score.

    Args:
        multi_results: List of ranked result lists; each result is a dict
            carrying at least an ``"attack_id"`` key.
        k: RRF damping constant; larger values flatten the advantage of
            top-ranked entries.

    Returns:
        One entry per distinct ``attack_id``, ordered by fused score
        (descending, ties keep first-seen order). Each entry is a shallow copy
        of the first result seen for that id, with ``"score"`` replaced by the
        fused score rounded to four decimals.
    """
    totals = {}
    first_seen = {}

    for ranked in multi_results:
        for rank, hit in enumerate(ranked):
            key = hit["attack_id"]
            # Remember the payload of the first occurrence only.
            first_seen.setdefault(key, hit)
            totals[key] = totals.get(key, 0.0) + 1.0 / (k + rank + 1)

    # sorted() is stable, so equal scores stay in first-seen order.
    ranking = sorted(totals.items(), key=lambda pair: pair[1], reverse=True)

    fused = []
    for key, total in ranking:
        merged = first_seen[key].copy()
        merged["score"] = round(total, 4)
        fused.append(merged)
    return fused
class BM25Search:
    """BM25 keyword search engine over the preprocessed ATT&CK documents."""

    def __init__(self):
        # Document dicts produced by _preprocess_objects().
        self.documents: list = []
        # One token list per document, index-aligned with self.documents.
        self.corpus_tokens: list = []
        # Built by build_index() / load_index(); None until then.
        self.bm25: Optional[BM25Okapi] = None

    def build_index(self, data_path: str = STIX_FILE) -> None:
        """Build the BM25 index from a STIX bundle JSON file.

        Args:
            data_path: Path of the JSON file to index.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"数据文件不存在: {data_path}")

        print("正在构建 BM25 索引...")
        with open(data_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Extract the searchable documents.
        self.documents = _preprocess_objects(data)
        print(f" 文档数量: {len(self.documents)}")

        # Tokenize every document's search text.
        print(" 正在分词...")
        self.corpus_tokens = [_tokenize(doc["search_text"]) for doc in self.documents]

        self.bm25 = BM25Okapi(self.corpus_tokens)
        print(" ✅ BM25 索引构建完成")

    def save_index(self, index_path: str = INDEX_FILE) -> None:
        """Persist the documents and their tokens to ``index_path`` (pickle).

        Args:
            index_path: Destination file; its parent directory is created
                when the path has one.
        """
        # A bare filename has an empty dirname, and os.makedirs("") raises,
        # so only create the directory when there is one to create.
        directory = os.path.dirname(index_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(index_path, "wb") as f:
            pickle.dump({
                "documents": self.documents,
                "corpus_tokens": self.corpus_tokens,
            }, f)
        print(f" 索引已保存: {index_path}")

    def load_index(self, index_path: str = INDEX_FILE) -> bool:
        """Load a previously saved index.

        Args:
            index_path: File written by ``save_index``.

        Returns:
            True when the index was loaded and the BM25 model rebuilt; False
            when the file is missing or could not be loaded.
        """
        if not os.path.exists(index_path):
            return False
        try:
            # NOTE(security): pickle.load executes arbitrary code embedded in
            # the file. Only load index files this application wrote itself;
            # never point index_path at untrusted input.
            with open(index_path, "rb") as f:
                data = pickle.load(f)
            self.documents = data["documents"]
            self.corpus_tokens = data["corpus_tokens"]
            self.bm25 = BM25Okapi(self.corpus_tokens)
            print(f"✅ BM25 索引已加载 ({len(self.documents)} 个文档)")
            return True
        except Exception as e:
            # Deliberately best-effort: any failure is reported and signalled
            # through the return value so the caller can rebuild the index.
            print(f"⚠️ 索引加载失败: {e}")
            return False

    def search(self, query: str, top_k: int = 10) -> list:
        """Run a BM25 keyword search.

        Args:
            query: Free-text user query.
            top_k: Maximum number of results to return.

        Returns:
            A list of result dicts with ``attack_id``, ``name``,
            ``description``, ``type``, ``platforms``, ``url``, ``score``
            (raw BM25 score divided by the best score of this query, rounded
            to four decimals) and ``snippet`` (description cut to 200
            characters). Empty when the query yields no tokens or the index
            holds no documents.

        Raises:
            RuntimeError: If no index has been built or loaded yet.
        """
        if self.bm25 is None:
            raise RuntimeError("索引未初始化,请先调用 build_index() 或 load_index()")

        tokenized_query = _tokenize(query)
        if not tokenized_query:
            return []

        scores = self.bm25.get_scores(tokenized_query)
        # An empty corpus yields no scores; max() would raise on it.
        if len(scores) == 0:
            return []

        # Normalize against the best hit (used for ranking display only).
        best = max(scores)
        max_score = best if best > 0 else 1.0
        normalized_scores = scores / max_score

        # Indices of the top_k documents by raw score, best first.
        top_indices = sorted(
            range(len(scores)),
            key=lambda i: scores[i],
            reverse=True
        )[:top_k]

        results = []
        for idx in top_indices:
            if normalized_scores[idx] < 0.01:  # drop near-zero matches
                continue
            doc = self.documents[idx]
            description = doc["description"]
            snippet = (description[:200] + "...") if len(description) > 200 else description
            results.append({
                "attack_id": doc["attack_id"],
                "name": doc["name"],
                "description": description,
                "type": doc["type"],
                "platforms": doc["platforms"],
                "url": doc["url"],
                "score": round(float(normalized_scores[idx]), 4),
                "snippet": snippet,
            })

        return results

    def get_index_stats(self) -> dict:
        """Return the total document count plus a per-type document count."""
        counts = {}
        for d in self.documents:
            counts[d["type"]] = counts.get(d["type"], 0) + 1
        return {"total_documents": len(self.documents), **counts}
def fetch_enterprise_attack() -> dict:
    """Return the Enterprise ATT&CK STIX bundle as a parsed JSON dict.

    A local cache file is used when ``_is_cache_valid()`` accepts it;
    otherwise the bundle is downloaded from ``ENTERPRISE_ATTACK_URL`` and
    written to the cache.

    Returns:
        The parsed JSON document.

    Raises:
        RuntimeError: On request timeout, any other request failure, or when
            the cache file cannot be written for lack of permission.
        ValueError: If the downloaded body is not valid JSON (raised by
            ``response.json()``); nothing is cached in that case.
    """
    # Local import: this module's import block does not provide json.
    import json

    if _is_cache_valid():
        print("使用缓存数据")
        with open(CACHE_FILE, "r", encoding="utf-8") as f:
            return json.load(f)

    print("正在从 MITRE GitHub 下载最新数据...")
    try:
        response = requests.get(ENTERPRISE_ATTACK_URL, timeout=60)
        response.raise_for_status()
    except requests.exceptions.Timeout as e:
        raise RuntimeError("网络请求超时,请检查网络连接后重试") from e
    except requests.exceptions.RequestException as e:
        raise RuntimeError(f"网络请求失败: {e}") from e

    # Parse BEFORE caching: a truncated or non-JSON body must never be written
    # to the cache, or every call during the cache lifetime would fail on it.
    data = response.json()

    os.makedirs(DATA_DIR, exist_ok=True)
    try:
        with open(CACHE_FILE, "w", encoding="utf-8") as f:
            f.write(response.text)
    except PermissionError as e:
        raise RuntimeError(f"没有写入权限: {CACHE_FILE}") from e

    print("下载完成,已缓存到本地")
    return data
def __init__(self, uri: str = DEFAULT_URI, user: str = DEFAULT_USER, password: str = ""):
    """Open and verify a connection to the Neo4j database.

    Args:
        uri: Bolt URI of the Neo4j server.
        user: Neo4j user name.
        password: Neo4j password; when empty, the ``NEO4J_PASSWORD``
            environment variable is used instead.

    Raises:
        ValueError: If no password is available from either source.
        ConnectionError: If the driver cannot be created or the connectivity
            check fails; the original error is attached as ``__cause__``.
    """
    if not password:
        password = os.environ.get("NEO4J_PASSWORD", "")
    if not password:
        raise ValueError("请提供 Neo4j 密码(参数或 NEO4J_PASSWORD 环境变量)")

    driver = None
    try:
        driver = GraphDatabase.driver(uri, auth=(user, password))
        driver.verify_connectivity()
    except Exception as e:
        # The caller never receives this instance when construction fails, so
        # nobody else could close the driver: release it here to avoid a leak.
        if driver is not None:
            try:
                driver.close()
            except Exception:
                # Cleanup is best-effort; the connection error below is the
                # failure worth reporting.
                pass
        raise ConnectionError(f"Neo4j 连接失败: {e}") from e

    # Publish the driver only once it is known to work.
    self.driver = driver
    print(f"✅ 已连接 Neo4j: {uri}")
[] + tools = [] + mitigations = [] + campaigns = [] + relationships = [] + + # stix_id -> attack_id 映射,用于解析 relationship + stix_to_attack = {} + + for obj in data.get("objects", []): + obj_type = obj.get("type", "") + stix_id = obj.get("id", "") + + # relationship 单独收集(不排除 deprecated,但排除 revoked) + if obj_type == "relationship": + if not obj.get("revoked", False): + relationships.append(obj) + continue + + if _is_revoked_or_deprecated(obj): + continue + + attack_id = _extract_external_id(obj) + + if obj_type == "attack-pattern": + techniques.append(obj) + if attack_id: + stix_to_attack[stix_id] = attack_id + elif obj_type == "x-mitre-tactic": + tactics.append(obj) + if attack_id: + stix_to_attack[stix_id] = attack_id + elif obj_type == "intrusion-set": + groups.append(obj) + if attack_id: + stix_to_attack[stix_id] = attack_id + elif obj_type == "malware": + malware_list.append(obj) + if attack_id: + stix_to_attack[stix_id] = attack_id + elif obj_type == "tool": + tools.append(obj) + if attack_id: + stix_to_attack[stix_id] = attack_id + elif obj_type == "course-of-action": + mitigations.append(obj) + if attack_id: + stix_to_attack[stix_id] = attack_id + elif obj_type == "campaign": + campaigns.append(obj) + if attack_id: + stix_to_attack[stix_id] = attack_id + + print(f" 技术: {len(techniques)}") + print(f" 战术: {len(tactics)}") + print(f" 威胁组织: {len(groups)}") + print(f" 恶意软件: {len(malware_list)}") + print(f" 工具: {len(tools)}") + print(f" 缓解措施: {len(mitigations)}") + print(f" 攻击活动: {len(campaigns)}") + print(f" 关系: {len(relationships)}") + + return { + "techniques": techniques, + "tactics": tactics, + "groups": groups, + "malware": malware_list, + "tools": tools, + "mitigations": mitigations, + "campaigns": campaigns, + "relationships": relationships, + "stix_to_attack": stix_to_attack, + } + + def _import_tactics(self, tactics: list) -> int: + """批量导入战术节点""" + batch = [] + for t in tactics: + attack_id = _extract_external_id(t) + if not attack_id: + continue + 
batch.append({ + "attack_id": attack_id, + "name": t.get("name", ""), + "description": t.get("description", "")[:500], + "url": _extract_url(t) or "", + "stix_id": t.get("id", ""), + "shortname": t.get("x_mitre_shortname", ""), + }) + + query = """ + UNWIND $batch AS item + MERGE (t:Tactic {attack_id: item.attack_id}) + SET t.name = item.name, + t.description = item.description, + t.url = item.url, + t.stix_id = item.stix_id, + t.shortname = item.shortname + """ + + with self.driver.session() as session: + session.run(query, batch=batch) + + return len(batch) + + def _import_techniques(self, techniques: list) -> int: + """批量导入技术节点""" + batch = [] + for t in techniques: + attack_id = _extract_external_id(t) + if not attack_id: + continue + batch.append({ + "attack_id": attack_id, + "name": t.get("name", ""), + "description": t.get("description", "")[:500], + "url": _extract_url(t) or "", + "stix_id": t.get("id", ""), + "platforms": t.get("x_mitre_platforms", []), + "is_subtechnique": t.get("x_mitre_is_subtechnique", False), + }) + + query = """ + UNWIND $batch AS item + MERGE (t:Technique {attack_id: item.attack_id}) + SET t.name = item.name, + t.description = item.description, + t.url = item.url, + t.stix_id = item.stix_id, + t.platforms = item.platforms, + t.is_subtechnique = item.is_subtechnique + """ + + with self.driver.session() as session: + session.run(query, batch=batch) + + return len(batch) + + def _import_groups(self, groups: list) -> int: + """批量导入威胁组织节点""" + batch = [] + for g in groups: + attack_id = _extract_external_id(g) + if not attack_id: + continue + batch.append({ + "attack_id": attack_id, + "name": g.get("name", ""), + "description": g.get("description", "")[:500], + "url": _extract_url(g) or "", + "stix_id": g.get("id", ""), + "aliases": g.get("aliases", []), + }) + + query = """ + UNWIND $batch AS item + MERGE (g:Group {attack_id: item.attack_id}) + SET g.name = item.name, + g.description = item.description, + g.url = item.url, + g.stix_id = 
item.stix_id, + g.aliases = item.aliases + """ + + with self.driver.session() as session: + session.run(query, batch=batch) + + return len(batch) + + def _import_malware(self, malware_list: list) -> int: + """批量导入恶意软件节点""" + batch = [] + for m in malware_list: + attack_id = _extract_external_id(m) + if not attack_id: + continue + batch.append({ + "attack_id": attack_id, + "name": m.get("name", ""), + "description": m.get("description", "")[:500], + "url": _extract_url(m) or "", + "stix_id": m.get("id", ""), + "platforms": m.get("x_mitre_platforms", []), + "aliases": m.get("x_mitre_aliases", []), + }) + + query = """ + UNWIND $batch AS item + MERGE (m:Malware {attack_id: item.attack_id}) + SET m.name = item.name, + m.description = item.description, + m.url = item.url, + m.stix_id = item.stix_id, + m.platforms = item.platforms, + m.aliases = item.aliases + """ + + with self.driver.session() as session: + session.run(query, batch=batch) + + return len(batch) + + def _import_tools(self, tools: list) -> int: + """批量导入工具节点""" + batch = [] + for t in tools: + attack_id = _extract_external_id(t) + if not attack_id: + continue + batch.append({ + "attack_id": attack_id, + "name": t.get("name", ""), + "description": t.get("description", "")[:500], + "url": _extract_url(t) or "", + "stix_id": t.get("id", ""), + "platforms": t.get("x_mitre_platforms", []), + "aliases": t.get("x_mitre_aliases", []), + }) + + query = """ + UNWIND $batch AS item + MERGE (t:Tool {attack_id: item.attack_id}) + SET t.name = item.name, + t.description = item.description, + t.url = item.url, + t.stix_id = item.stix_id, + t.platforms = item.platforms, + t.aliases = item.aliases + """ + + with self.driver.session() as session: + session.run(query, batch=batch) + + return len(batch) + + def _import_mitigations(self, mitigations: list) -> int: + """批量导入缓解措施节点""" + batch = [] + for m in mitigations: + attack_id = _extract_external_id(m) + if not attack_id: + continue + batch.append({ + "attack_id": 
attack_id, + "name": m.get("name", ""), + "description": m.get("description", "")[:500], + "url": _extract_url(m) or "", + "stix_id": m.get("id", ""), + }) + + query = """ + UNWIND $batch AS item + MERGE (m:Mitigation {attack_id: item.attack_id}) + SET m.name = item.name, + m.description = item.description, + m.url = item.url, + m.stix_id = item.stix_id + """ + + with self.driver.session() as session: + session.run(query, batch=batch) + + return len(batch) + + def _import_campaigns(self, campaigns: list) -> int: + """批量导入攻击活动节点""" + batch = [] + for c in campaigns: + attack_id = _extract_external_id(c) + if not attack_id: + continue + batch.append({ + "attack_id": attack_id, + "name": c.get("name", ""), + "description": c.get("description", "")[:500], + "url": _extract_url(c) or "", + "stix_id": c.get("id", ""), + "first_seen": c.get("first_seen", ""), + "last_seen": c.get("last_seen", ""), + "aliases": c.get("aliases", []), + }) + + query = """ + UNWIND $batch AS item + MERGE (c:Campaign {attack_id: item.attack_id}) + SET c.name = item.name, + c.description = item.description, + c.url = item.url, + c.stix_id = item.stix_id, + c.first_seen = item.first_seen, + c.last_seen = item.last_seen, + c.aliases = item.aliases + """ + + with self.driver.session() as session: + session.run(query, batch=batch) + + return len(batch) + + def _import_technique_tactic_relations(self, techniques: list) -> int: + """导入技术-战术关系:(Technique)-[:BELONGS_TO]->(Tactic)""" + relations = [] + for t in techniques: + tech_id = _extract_external_id(t) + if not tech_id: + continue + for phase in t.get("kill_chain_phases", []): + if phase.get("kill_chain_name") == "mitre-attack": + tactic_shortname = phase.get("phase_name", "") + relations.append({ + "tech_id": tech_id, + "tactic_shortname": tactic_shortname, + }) + + query = """ + UNWIND $batch AS item + MATCH (tech:Technique {attack_id: item.tech_id}) + MATCH (tac:Tactic {shortname: item.tactic_shortname}) + MERGE (tech)-[:BELONGS_TO]->(tac) + 
""" + + with self.driver.session() as session: + session.run(query, batch=relations) + + return len(relations) + + def _import_stix_relationships(self, relationships: list, stix_to_attack: dict) -> dict: + """ + 导入 STIX 关系对象 + 将 source_ref/target_ref (STIX ID) 映射为 attack_id,创建 Neo4j 关系 + """ + # 要导入的关系类型及其对应的 Neo4j 关系名 + rel_type_map = { + "uses": "USES", + "mitigates": "MITIGATES", + "subtechnique-of": "SUBTECHNIQUE_OF", + "attributed-to": "ATTRIBUTED_TO", + "detects": "DETECTS", + } + + # 按关系类型分组 + grouped = {} + skipped = 0 + for rel in relationships: + rel_type = rel.get("relationship_type", "") + if rel_type not in rel_type_map: + continue + source_aid = stix_to_attack.get(rel.get("source_ref", "")) + target_aid = stix_to_attack.get(rel.get("target_ref", "")) + if not source_aid or not target_aid: + skipped += 1 + continue + neo4j_rel = rel_type_map[rel_type] + grouped.setdefault(neo4j_rel, []).append({ + "source": source_aid, + "target": target_aid, + }) + + # 逐类型导入 + counts = {} + for neo4j_rel, pairs in grouped.items(): + # 分批处理,每批 1000 条 + batch_size = 1000 + total_imported = 0 + for i in range(0, len(pairs), batch_size): + batch = pairs[i:i + batch_size] + # 用 UNWIND + MERGE 导入,source/target 可能是多种节点类型 + # 用通用的 MATCH,不限定节点标签 + query = f""" + UNWIND $batch AS item + MATCH (src {{attack_id: item.source}}) + MATCH (tgt {{attack_id: item.target}}) + MERGE (src)-[r:{neo4j_rel}]->(tgt) + """ + with self.driver.session() as session: + session.run(query, batch=batch) + total_imported += len(batch) + counts[neo4j_rel] = total_imported + print(f" {neo4j_rel}: {total_imported} 条") + + if skipped > 0: + print(f" 跳过 {skipped} 条无法映射的关系") + + return counts + + def build(self): + """ + 主构建流程:解析数据 → 创建约束 → 批量导入节点 → 建立关系 + """ + print("=" * 50) + print("开始构建 MITRE ATT&CK 知识图谱") + print("=" * 50) + + # 1. 解析数据 + data = self._parse_attack_data() + + # 2. 创建约束 + print("\n创建唯一约束...") + self._create_constraints() + + # 3. 
导入所有节点 + print("\n--- 导入节点 ---") + + print("\n导入战术节点...") + imported = self._import_tactics(data["tactics"]) + print(f"已导入 {imported}/{len(data['tactics'])} 个战术") + + print("\n导入技术节点...") + imported = self._import_techniques(data["techniques"]) + print(f"已导入 {imported}/{len(data['techniques'])} 个技术") + + print("\n导入威胁组织节点...") + imported = self._import_groups(data["groups"]) + print(f"已导入 {imported}/{len(data['groups'])} 个威胁组织") + + print("\n导入恶意软件节点...") + imported = self._import_malware(data["malware"]) + print(f"已导入 {imported}/{len(data['malware'])} 个恶意软件") + + print("\n导入工具节点...") + imported = self._import_tools(data["tools"]) + print(f"已导入 {imported}/{len(data['tools'])} 个工具") + + print("\n导入缓解措施节点...") + imported = self._import_mitigations(data["mitigations"]) + print(f"已导入 {imported}/{len(data['mitigations'])} 个缓解措施") + + print("\n导入攻击活动节点...") + imported = self._import_campaigns(data["campaigns"]) + print(f"已导入 {imported}/{len(data['campaigns'])} 个攻击活动") + + # 4. 建立关系 + print("\n--- 建立关系 ---") + + print("\n建立技术-战术关系 (BELONGS_TO)...") + rel_count = self._import_technique_tactic_relations(data["techniques"]) + print(f"已创建 {rel_count} 条 BELONGS_TO 关系") + + print("\n导入 STIX 关系 (USES/MITIGATES/SUBTECHNIQUE_OF/ATTRIBUTED_TO/DETECTS)...") + counts = self._import_stix_relationships( + data["relationships"], data["stix_to_attack"] + ) + + # 5. 
def main():
    """Command-line entry point for building the knowledge graph.

    Parses ``--uri``, ``--user`` and ``--password``, falling back to the
    ``NEO4J_PASSWORD`` environment variable when no password is given, then
    runs ``GraphBuilder.build()``. Connection, missing-file and value errors
    are printed to stderr and end the process with exit status 1. The
    builder is closed in every case once it has been created.
    """
    import argparse

    cli = argparse.ArgumentParser(description="MITRE ATT&CK 知识图谱构建工具")
    option_specs = (
        ("--uri", DEFAULT_URI, "Neo4j 连接地址 (默认: bolt://localhost:7687)"),
        ("--user", DEFAULT_USER, "Neo4j 用户名 (默认: neo4j)"),
        ("--password", "", "Neo4j 密码(或设置 NEO4J_PASSWORD 环境变量)"),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, help=help_text)
    options = cli.parse_args()

    # An explicit --password wins; otherwise read the environment.
    secret = options.password
    if not secret:
        secret = os.environ.get("NEO4J_PASSWORD", "")

    graph_builder = None
    try:
        graph_builder = GraphBuilder(uri=options.uri, user=options.user, password=secret)
        graph_builder.build()
    except (ConnectionError, FileNotFoundError, ValueError) as err:
        print(f"❌ 错误: {err}", file=sys.stderr)
        sys.exit(1)
    finally:
        # The builder is still None when its constructor raised.
        if graph_builder:
            graph_builder.close()
def __init__(self, uri: str = "", user: str = "", password: str = ""):
    """Connect to Neo4j and verify that the server is reachable.

    Each empty argument falls back to an environment variable
    (``NEO4J_URI``, ``NEO4J_USER``, ``NEO4J_PASSWORD``); URI and user
    additionally fall back to ``bolt://localhost:7687`` and ``neo4j``.

    Args:
        uri: Neo4j connection URI.
        user: Neo4j user name.
        password: Neo4j password.

    Raises:
        ValueError: If no password is available from argument or environment.
        ConnectionError: If the driver cannot be created or the
            connectivity check fails; the original error is chained.
    """
    self.uri = uri or os.getenv("NEO4J_URI", "bolt://localhost:7687")
    self.user = user or os.getenv("NEO4J_USER", "neo4j")
    self.password = password or os.getenv("NEO4J_PASSWORD", "")

    if not self.password:
        raise ValueError("请提供 Neo4j 密码(参数或 NEO4J_PASSWORD 环境变量)")

    driver = None
    try:
        driver = GraphDatabase.driver(
            self.uri, auth=(self.user, self.password)
        )
        driver.verify_connectivity()
    except Exception as e:
        # The instance never reaches the caller when __init__ raises, so
        # nobody else could close a driver that was already created.
        if driver is not None:
            driver.close()
        raise ConnectionError(f"Neo4j 连接失败: {e}") from e

    self.driver = driver
    print(f"✅ GraphQuery 已连接 Neo4j: {self.uri}")
t.attack_id AS aid, t.name AS name, t.description AS desc, + t.platforms AS platforms, t.url AS url, 'technique' AS etype, + collect(DISTINCT {id: tac.attack_id, n: tac.name}) AS rel_tactics, + collect(DISTINCT {id: grp.attack_id, n: grp.name}) AS rel_groups, + collect(DISTINCT {id: sw.attack_id, n: sw.name}) AS rel_sw, + collect(DISTINCT {id: mit.attack_id, n: mit.name}) AS rel_mits, + collect(DISTINCT {id: camp.attack_id, n: camp.name}) AS rel_camps, + collect(DISTINCT {id: sub.attack_id, n: sub.name}) AS rel_subs + """ + for r in session.run(query, ids=by_type["technique"]): + techniques.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "platforms": r["platforms"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_tactics"], "BELONGS_TO", tac_ids) + _add_rels(relationships, r["aid"], r["rel_groups"], "USES", group_ids, reverse=True) + _add_rels(relationships, r["aid"], r["rel_sw"], "USES", sw_ids, reverse=True) + _add_rels(relationships, r["aid"], r["rel_mits"], "MITIGATES", mit_ids, reverse=True) + _add_rels(relationships, r["aid"], r["rel_camps"], "USES", camp_ids, reverse=True) + # 子技术关系 + for sub in r["rel_subs"]: + if sub["id"]: + relationships.append({"from": sub["id"], "to": r["aid"], + "type": "SUBTECHNIQUE_OF", "to_name": sub["n"]}) + + # === 查询命中的 Group 及其关联 === + if by_type.get("group"): + query = """ + MATCH (g:Group) WHERE g.attack_id IN $ids + OPTIONAL MATCH (g)-[:USES]->(tech:Technique) + OPTIONAL MATCH (g)-[:USES]->(sw) WHERE sw:Malware OR sw:Tool + OPTIONAL MATCH (camp:Campaign)-[:ATTRIBUTED_TO]->(g) + RETURN g.attack_id AS aid, g.name AS name, g.description AS desc, + g.aliases AS aliases, g.url AS url, 'group' AS etype, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs, + collect(DISTINCT {id: sw.attack_id, n: sw.name}) AS rel_sw, + collect(DISTINCT {id: camp.attack_id, n: camp.name}) AS rel_camps + """ + for r in session.run(query, ids=by_type["group"]): + groups.append({"attack_id": 
r["aid"], "name": r["name"], + "description": r["desc"], "aliases": r["aliases"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_techs"], "USES", None) + for sw in r["rel_sw"]: + if sw["id"]: + relationships.append({"from": r["aid"], "to": sw["id"], "type": "USES", "to_name": sw["n"]}) + sw_ids.add(sw["id"]) + for camp in r["rel_camps"]: + if camp["id"]: + relationships.append({"from": camp["id"], "to": r["aid"], "type": "ATTRIBUTED_TO", "to_name": camp["n"]}) + camp_ids.add(camp["id"]) + + # === 查询命中的 Software 及其关联 === + if by_type.get("software"): + query = """ + MATCH (sw) WHERE sw.attack_id IN $ids AND (sw:Malware OR sw:Tool) + OPTIONAL MATCH (sw)-[:USES]->(tech:Technique) + OPTIONAL MATCH (grp:Group)-[:USES]->(sw) + RETURN sw.attack_id AS aid, sw.name AS name, sw.description AS desc, + sw.platforms AS platforms, sw.aliases AS aliases, sw.url AS url, + labels(sw) AS labels, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs, + collect(DISTINCT {id: grp.attack_id, n: grp.name}) AS rel_grps + """ + for r in session.run(query, ids=by_type["software"]): + sw_type = "malware" if "Malware" in r["labels"] else "tool" + software.append({"attack_id": r["aid"], "name": r["name"], "description": r["desc"], + "platforms": r["platforms"], "aliases": r["aliases"], "url": r["url"], "type": sw_type}) + for tech in r["rel_techs"]: + if tech["id"]: + relationships.append({"from": r["aid"], "to": tech["id"], "type": "USES", "to_name": tech["n"]}) + for grp in r["rel_grps"]: + if grp["id"]: + relationships.append({"from": grp["id"], "to": r["aid"], "type": "USES", "to_name": grp["n"]}) + group_ids.add(grp["id"]) + + # === 查询命中的 Tactic 及其关联 === + if by_type.get("tactic"): + query = """ + MATCH (tac:Tactic) WHERE tac.attack_id IN $ids + OPTIONAL MATCH (tech:Technique)-[:BELONGS_TO]->(tac) + RETURN tac.attack_id AS aid, tac.name AS name, tac.description AS desc, + tac.shortname AS shortname, tac.url AS url, + collect(DISTINCT {id: tech.attack_id, n: 
tech.name}) AS rel_techs + """ + for r in session.run(query, ids=by_type["tactic"]): + tactics.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "shortname": r["shortname"], "url": r["url"]}) + for tech in r["rel_techs"]: + if tech["id"]: + relationships.append({"from": tech["id"], "to": r["aid"], "type": "BELONGS_TO", "to_name": r["name"]}) + + # === 查询命中的 Mitigation 及其关联 === + if by_type.get("mitigation"): + query = """ + MATCH (m:Mitigation) WHERE m.attack_id IN $ids + OPTIONAL MATCH (m)-[:MITIGATES]->(tech:Technique) + RETURN m.attack_id AS aid, m.name AS name, m.description AS desc, m.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs + """ + for r in session.run(query, ids=by_type["mitigation"]): + mitigations.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "url": r["url"]}) + for tech in r["rel_techs"]: + if tech["id"]: + relationships.append({"from": r["aid"], "to": tech["id"], "type": "MITIGATES", "to_name": tech["n"]}) + + # === 查询命中的 Campaign 及其关联 === + if by_type.get("campaign"): + query = """ + MATCH (c:Campaign) WHERE c.attack_id IN $ids + OPTIONAL MATCH (c)-[:USES]->(tech:Technique) + OPTIONAL MATCH (c)-[:USES]->(sw) WHERE sw:Malware OR sw:Tool + OPTIONAL MATCH (c)-[:ATTRIBUTED_TO]->(grp:Group) + RETURN c.attack_id AS aid, c.name AS name, c.description AS desc, + c.aliases AS aliases, c.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs, + collect(DISTINCT {id: sw.attack_id, n: sw.name}) AS rel_sw, + collect(DISTINCT {id: grp.attack_id, n: grp.name}) AS rel_grps + """ + for r in session.run(query, ids=by_type["campaign"]): + campaigns.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "aliases": r["aliases"], "url": r["url"]}) + for tech in r["rel_techs"]: + if tech["id"]: + relationships.append({"from": r["aid"], "to": tech["id"], "type": "USES", "to_name": tech["n"]}) + for sw in r["rel_sw"]: + if sw["id"]: 
+ relationships.append({"from": r["aid"], "to": sw["id"], "type": "USES", "to_name": sw["n"]}) + sw_ids.add(sw["id"]) + for grp in r["rel_grps"]: + if grp["id"]: + relationships.append({"from": r["aid"], "to": grp["id"], "type": "ATTRIBUTED_TO", "to_name": grp["n"]}) + group_ids.add(grp["id"]) + + # === 二跳查询:补充关联实体的详情 === + # 去掉已有实体的 ID + existing_tech_ids = {t["attack_id"] for t in techniques} + existing_group_ids = {g["attack_id"] for g in groups} + existing_sw_ids = {s["attack_id"] for s in software} + existing_mit_ids = {m["attack_id"] for m in mitigations} + existing_camp_ids = {c["attack_id"] for c in campaigns} + + # 补充战术详情 + if tac_ids: + tac_ids -= {t["attack_id"] for t in tactics} + if tac_ids: + for r in session.run("MATCH (t:Tactic) WHERE t.attack_id IN $ids RETURN t.attack_id AS aid, t.name AS name, t.description AS desc, t.shortname AS sn, t.url AS url", ids=list(tac_ids)): + tactics.append({"attack_id": r["aid"], "name": r["name"], "description": r["desc"], "shortname": r["sn"], "url": r["url"]}) + + # 补充组织详情 + new_group_ids = group_ids - existing_group_ids + if new_group_ids: + for r in session.run("MATCH (g:Group) WHERE g.attack_id IN $ids RETURN g.attack_id AS aid, g.name AS name, g.description AS desc, g.aliases AS aliases, g.url AS url", ids=list(new_group_ids)): + groups.append({"attack_id": r["aid"], "name": r["name"], "description": r["desc"], "aliases": r["aliases"], "url": r["url"]}) + + # 补充软件详情 + new_sw_ids = sw_ids - existing_sw_ids + if new_sw_ids: + for r in session.run("MATCH (sw) WHERE sw.attack_id IN $ids AND (sw:Malware OR sw:Tool) RETURN sw.attack_id AS aid, sw.name AS name, sw.description AS desc, sw.platforms AS platforms, sw.aliases AS aliases, sw.url AS url, labels(sw) AS labels", ids=list(new_sw_ids)): + sw_type = "malware" if "Malware" in r["labels"] else "tool" + software.append({"attack_id": r["aid"], "name": r["name"], "description": r["desc"], + "platforms": r["platforms"], "aliases": r["aliases"], "url": r["url"], 
def query_by_intent(self, attack_ids: list, query_focus: str = "general", entities: list = None) -> dict:
    """Run the graph query strategy that matches the user's intent.

    Args:
        attack_ids: attack_id values hit by the BM25 search.
        query_focus: Relationship focus; one of ``detail``, ``uses``,
            ``mitigates``, ``belongs_to``, ``attributed_to`` or
            ``tactics_of``. Any other value selects the general strategy.
        entities: Raw entities extracted by the LLM (not read here).

    Returns:
        Dict with the keys ``techniques``, ``tactics``, ``groups``,
        ``software``, ``mitigations``, ``campaigns`` and ``relationships``;
        every value is a list, all empty when *attack_ids* is empty.
    """
    result = {
        "techniques": [],
        "tactics": [],
        "groups": [],
        "software": [],
        "mitigations": [],
        "campaigns": [],
        "relationships": [],
    }
    if not attack_ids:
        return result

    # Bucket the ids by entity type.
    by_type = {}
    for attack_id in attack_ids:
        by_type.setdefault(_detect_entity_type(attack_id), []).append(attack_id)

    everything = ("techniques", "tactics", "groups", "software",
                  "mitigations", "campaigns", "relationships")
    # (focus, handler method, result lists passed to it in positional order)
    strategies = (
        ("detail", "_query_detail", everything),
        ("uses", "_query_uses",
         ("techniques", "groups", "software", "campaigns", "relationships")),
        ("mitigates", "_query_mitigates",
         ("techniques", "mitigations", "relationships")),
        ("belongs_to", "_query_belongs_to",
         ("techniques", "tactics", "relationships")),
        ("attributed_to", "_query_attributed_to",
         ("campaigns", "groups", "relationships")),
        ("tactics_of", "_query_tactics_of",
         ("tactics", "techniques", "relationships")),
    )

    with self.driver.session() as session:
        handler_name, wanted = "_query_general", everything
        for focus, candidate, keys in strategies:
            if query_focus == focus:
                handler_name, wanted = candidate, keys
                break
        # Handlers fill the lists they are handed in place.
        handler = getattr(self, handler_name)
        handler(session, by_type, *[result[key] for key in wanted])

    return result
"name": r["name"], + "description": r["desc"], "platforms": r["platforms"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_tactics"], "BELONGS_TO", tac_ids) + _add_rels(relationships, r["aid"], r["rel_groups"], "USES", group_ids, reverse=True) + _add_rels(relationships, r["aid"], r["rel_sw"], "USES", sw_ids, reverse=True) + _add_rels(relationships, r["aid"], r["rel_mits"], "MITIGATES", mit_ids, reverse=True) + _add_rels(relationships, r["aid"], r["rel_camps"], "USES", camp_ids, reverse=True) + for sub in r["rel_subs"]: + if sub["id"]: + relationships.append({"from": sub["id"], "to": r["aid"], + "type": "SUBTECHNIQUE_OF", "to_name": sub["n"]}) + self._fill_2hop(session, techniques, tactics, groups, software, mitigations, campaigns, + tac_ids, group_ids, sw_ids, mit_ids, camp_ids) + + if by_type.get("group"): + q = """ + MATCH (g:Group) WHERE g.attack_id IN $ids + OPTIONAL MATCH (g)-[:USES]->(tech:Technique) + OPTIONAL MATCH (g)-[:USES]->(sw) WHERE sw:Malware OR sw:Tool + RETURN g.attack_id AS aid, g.name AS name, g.description AS desc, + g.aliases AS aliases, g.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs, + collect(DISTINCT {id: sw.attack_id, n: sw.name}) AS rel_sw + """ + for r in session.run(q, ids=by_type["group"]): + groups.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "aliases": r["aliases"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_techs"], "USES", None) + _add_rels(relationships, r["aid"], r["rel_sw"], "USES", None) + + if by_type.get("software"): + q = """ + MATCH (sw) WHERE sw.attack_id IN $ids AND (sw:Malware OR sw:Tool) + OPTIONAL MATCH (sw)-[:USES]->(tech:Technique) + OPTIONAL MATCH (grp:Group)-[:USES]->(sw) + RETURN sw.attack_id AS aid, sw.name AS name, sw.description AS desc, + sw.platforms AS platforms, sw.aliases AS aliases, sw.url AS url, + labels(sw) AS labels, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs, + 
collect(DISTINCT {id: grp.attack_id, n: grp.name}) AS rel_grps + """ + for r in session.run(q, ids=by_type["software"]): + sw_type = "malware" if "Malware" in r["labels"] else "tool" + software.append({"attack_id": r["aid"], "name": r["name"], "description": r["desc"], + "platforms": r["platforms"], "aliases": r["aliases"], "url": r["url"], "type": sw_type}) + _add_rels(relationships, r["aid"], r["rel_techs"], "USES", None) + _add_rels(relationships, r["aid"], r["rel_grps"], "USES", None, reverse=True) + + if by_type.get("tactic"): + q = """ + MATCH (tac:Tactic) WHERE tac.attack_id IN $ids + OPTIONAL MATCH (tech:Technique)-[:BELONGS_TO]->(tac) + RETURN tac.attack_id AS aid, tac.name AS name, tac.description AS desc, + tac.shortname AS shortname, tac.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs + """ + for r in session.run(q, ids=by_type["tactic"]): + tactics.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "shortname": r["shortname"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_techs"], "BELONGS_TO", None, reverse=True) + + if by_type.get("mitigation"): + q = """ + MATCH (m:Mitigation) WHERE m.attack_id IN $ids + OPTIONAL MATCH (m)-[:MITIGATES]->(tech:Technique) + RETURN m.attack_id AS aid, m.name AS name, m.description AS desc, m.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs + """ + for r in session.run(q, ids=by_type["mitigation"]): + mitigations.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_techs"], "MITIGATES", None) + + if by_type.get("campaign"): + q = """ + MATCH (c:Campaign) WHERE c.attack_id IN $ids + OPTIONAL MATCH (c)-[:USES]->(tech:Technique) + OPTIONAL MATCH (c)-[:ATTRIBUTED_TO]->(grp:Group) + RETURN c.attack_id AS aid, c.name AS name, c.description AS desc, + c.aliases AS aliases, c.url AS url, + collect(DISTINCT {id: tech.attack_id, n: 
def _query_uses(self, session, by_type, techniques, groups, software, campaigns, relationships):
    """Query only USES relationships for the given entities.

    Entities are added to each output list at most once per ``attack_id``.
    This matters when the requested ids are related to each other, for
    example a group together with a software it uses.

    Args:
        session: Open Neo4j session used to run the Cypher queries.
        by_type: Dict mapping entity type (``technique``, ``group``,
            ``software``, ``campaign``) to a list of attack_id values.
        techniques: Output list of technique dicts, filled in place.
        groups: Output list of group dicts, filled in place.
        software: Output list of software dicts, filled in place.
        campaigns: Output list of campaign dicts, filled in place.
        relationships: Output list of relationship dicts, filled in place.
    """

    def _technique_entry(row):
        """Build the technique dict for one result row."""
        return {"attack_id": row["aid"], "name": row["name"],
                "description": row["desc"], "platforms": row["platforms"], "url": row["url"]}

    def _group_entry(row):
        """Build the group dict for one result row."""
        return {"attack_id": row["aid"], "name": row["name"],
                "description": row["desc"], "aliases": row["aliases"], "url": row["url"]}

    def _software_entry(row):
        """Build the software dict for one result row, typed by node label."""
        return {"attack_id": row["aid"], "name": row["name"], "description": row["desc"],
                "platforms": row["platforms"], "aliases": row["aliases"], "url": row["url"],
                "type": "malware" if "Malware" in row["labels"] else "tool"}

    def _add_unique(target, entry):
        """Append *entry* unless *target* already holds its attack_id."""
        if all(item["attack_id"] != entry["attack_id"] for item in target):
            target.append(entry)

    def _unseen(ids, target):
        """Return the ids that have no entry in *target* yet."""
        seen = {item["attack_id"] for item in target}
        return sorted(i for i in ids if i not in seen)

    if by_type.get("technique"):
        q = """
            MATCH (t:Technique) WHERE t.attack_id IN $ids
            OPTIONAL MATCH (grp:Group)-[:USES]->(t)
            OPTIONAL MATCH (sw)-[:USES]->(t) WHERE sw:Malware OR sw:Tool
            OPTIONAL MATCH (sub:Technique)-[:SUBTECHNIQUE_OF]->(t)
            RETURN t.attack_id AS aid, t.name AS name, t.description AS desc,
                   t.platforms AS platforms, t.url AS url,
                   collect(DISTINCT {id: grp.attack_id, n: grp.name}) AS rel_groups,
                   collect(DISTINCT {id: sw.attack_id, n: sw.name}) AS rel_sw,
                   collect(DISTINCT {id: sub.attack_id, n: sub.name}) AS rel_subs
        """
        group_ids, sw_ids = set(), set()
        for r in session.run(q, ids=by_type["technique"]):
            _add_unique(techniques, _technique_entry(r))
            _add_rels(relationships, r["aid"], r["rel_groups"], "USES", group_ids, reverse=True)
            _add_rels(relationships, r["aid"], r["rel_sw"], "USES", sw_ids, reverse=True)
            for sub in r["rel_subs"]:
                # OPTIONAL MATCH yields a null-id placeholder when nothing matched.
                if sub["id"]:
                    relationships.append({"from": sub["id"], "to": r["aid"],
                                          "type": "SUBTECHNIQUE_OF", "to_name": sub["n"]})
        # Only group and software details are requested from the second hop.
        self._fill_2hop(session, [], [], groups, software, [], [],
                        set(), group_ids, sw_ids, set(), set())

    if by_type.get("group"):
        q = """
            MATCH (g:Group) WHERE g.attack_id IN $ids
            OPTIONAL MATCH (g)-[:USES]->(tech:Technique)
            OPTIONAL MATCH (g)-[:USES]->(sw) WHERE sw:Malware OR sw:Tool
            RETURN g.attack_id AS aid, g.name AS name, g.description AS desc,
                   g.aliases AS aliases, g.url AS url,
                   collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs,
                   collect(DISTINCT {id: sw.attack_id, n: sw.name}) AS rel_sw
        """
        tech_ids, sw_ids = set(), set()
        for r in session.run(q, ids=by_type["group"]):
            _add_unique(groups, _group_entry(r))
            _add_rels(relationships, r["aid"], r["rel_techs"], "USES", tech_ids)
            _add_rels(relationships, r["aid"], r["rel_sw"], "USES", sw_ids)

        # Fetch details only for entities that are not in the output yet.
        new_tech_ids = _unseen(tech_ids, techniques)
        if new_tech_ids:
            for r in session.run(
                    "MATCH (t:Technique) WHERE t.attack_id IN $ids "
                    "RETURN t.attack_id AS aid, t.name AS name, t.description AS desc, "
                    "t.platforms AS platforms, t.url AS url",
                    ids=new_tech_ids):
                _add_unique(techniques, _technique_entry(r))
        new_sw_ids = _unseen(sw_ids, software)
        if new_sw_ids:
            for r in session.run(
                    "MATCH (sw) WHERE sw.attack_id IN $ids AND (sw:Malware OR sw:Tool) "
                    "RETURN sw.attack_id AS aid, sw.name AS name, sw.description AS desc, "
                    "sw.platforms AS platforms, sw.aliases AS aliases, sw.url AS url, labels(sw) AS labels",
                    ids=new_sw_ids):
                _add_unique(software, _software_entry(r))

    if by_type.get("software"):
        q = """
            MATCH (sw) WHERE sw.attack_id IN $ids AND (sw:Malware OR sw:Tool)
            OPTIONAL MATCH (sw)-[:USES]->(tech:Technique)
            OPTIONAL MATCH (grp:Group)-[:USES]->(sw)
            RETURN sw.attack_id AS aid, sw.name AS name, sw.description AS desc,
                   sw.platforms AS platforms, sw.aliases AS aliases, sw.url AS url,
                   labels(sw) AS labels,
                   collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs,
                   collect(DISTINCT {id: grp.attack_id, n: grp.name}) AS rel_grps
        """
        group_ids = set()
        for r in session.run(q, ids=by_type["software"]):
            # The group branch may already have added this software.
            _add_unique(software, _software_entry(r))
            _add_rels(relationships, r["aid"], r["rel_techs"], "USES", None)
            _add_rels(relationships, r["aid"], r["rel_grps"], "USES", group_ids, reverse=True)

        new_group_ids = _unseen(group_ids, groups)
        if new_group_ids:
            for r in session.run(
                    "MATCH (g:Group) WHERE g.attack_id IN $ids "
                    "RETURN g.attack_id AS aid, g.name AS name, g.description AS desc, g.aliases AS aliases, g.url AS url",
                    ids=new_group_ids):
                _add_unique(groups, _group_entry(r))

    if by_type.get("campaign"):
        q = """
            MATCH (c:Campaign) WHERE c.attack_id IN $ids
            OPTIONAL MATCH (c)-[:USES]->(tech:Technique)
            OPTIONAL MATCH (c)-[:USES]->(sw) WHERE sw:Malware OR sw:Tool
            RETURN c.attack_id AS aid, c.name AS name, c.description AS desc,
                   c.aliases AS aliases, c.url AS url,
                   collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs,
                   collect(DISTINCT {id: sw.attack_id, n: sw.name}) AS rel_sw
        """
        for r in session.run(q, ids=by_type["campaign"]):
            _add_unique(campaigns, {"attack_id": r["aid"], "name": r["name"],
                                    "description": r["desc"], "aliases": r["aliases"], "url": r["url"]})
            _add_rels(relationships, r["aid"], r["rel_techs"], "USES", None)
            _add_rels(relationships, r["aid"], r["rel_sw"], "USES", None)
session.run(q, ids=by_type["technique"]): + techniques.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "platforms": r["platforms"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_mits"], "MITIGATES", mit_ids, reverse=True) + if mit_ids: + for r in session.run( + "MATCH (m:Mitigation) WHERE m.attack_id IN $ids " + "RETURN m.attack_id AS aid, m.name AS name, m.description AS desc, m.url AS url", + ids=list(mit_ids)): + mitigations.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "url": r["url"]}) + + if by_type.get("mitigation"): + q = """ + MATCH (m:Mitigation) WHERE m.attack_id IN $ids + OPTIONAL MATCH (m)-[:MITIGATES]->(tech:Technique) + RETURN m.attack_id AS aid, m.name AS name, m.description AS desc, m.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs + """ + for r in session.run(q, ids=by_type["mitigation"]): + mitigations.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_techs"], "MITIGATES", None) + + def _query_belongs_to(self, session, by_type, techniques, tactics, relationships): + """只查 BELONGS_TO 关系(技术→战术)""" + if by_type.get("technique"): + q = """ + MATCH (t:Technique) WHERE t.attack_id IN $ids + OPTIONAL MATCH (t)-[:BELONGS_TO]->(tac:Tactic) + RETURN t.attack_id AS aid, t.name AS name, t.description AS desc, + t.platforms AS platforms, t.url AS url, + collect(DISTINCT {id: tac.attack_id, n: tac.name}) AS rel_tactics + """ + tac_ids = set() + for r in session.run(q, ids=by_type["technique"]): + techniques.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "platforms": r["platforms"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_tactics"], "BELONGS_TO", tac_ids) + if tac_ids: + for r in session.run( + "MATCH (t:Tactic) WHERE t.attack_id IN $ids " + "RETURN t.attack_id AS aid, t.name AS name, t.description AS desc, 
t.shortname AS sn, t.url AS url", + ids=list(tac_ids)): + tactics.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "shortname": r["sn"], "url": r["url"]}) + + def _query_attributed_to(self, session, by_type, campaigns, groups, relationships): + """只查 ATTRIBUTED_TO 关系(活动→组织)""" + if by_type.get("campaign"): + q = """ + MATCH (c:Campaign) WHERE c.attack_id IN $ids + OPTIONAL MATCH (c)-[:ATTRIBUTED_TO]->(grp:Group) + RETURN c.attack_id AS aid, c.name AS name, c.description AS desc, + c.aliases AS aliases, c.url AS url, + collect(DISTINCT {id: grp.attack_id, n: grp.name}) AS rel_grps + """ + group_ids = set() + for r in session.run(q, ids=by_type["campaign"]): + campaigns.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "aliases": r["aliases"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_grps"], "ATTRIBUTED_TO", group_ids) + if group_ids: + for r in session.run( + "MATCH (g:Group) WHERE g.attack_id IN $ids " + "RETURN g.attack_id AS aid, g.name AS name, g.description AS desc, g.aliases AS aliases, g.url AS url", + ids=list(group_ids)): + groups.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "aliases": r["aliases"], "url": r["url"]}) + + def _query_tactics_of(self, session, by_type, tactics, techniques, relationships): + """从战术出发查其下属技术(BELONGS_TO 反向)""" + if by_type.get("tactic"): + q = """ + MATCH (tac:Tactic) WHERE tac.attack_id IN $ids + OPTIONAL MATCH (tech:Technique)-[:BELONGS_TO]->(tac) + RETURN tac.attack_id AS aid, tac.name AS name, tac.description AS desc, + tac.shortname AS shortname, tac.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs + """ + for r in session.run(q, ids=by_type["tactic"]): + tactics.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "shortname": r["shortname"], "url": r["url"]}) + for tech in r["rel_techs"]: + if tech["id"]: + techniques.append({"attack_id": tech["id"], "name": 
tech["n"], + "description": "", "platforms": [], "url": ""}) + relationships.append({"from": tech["id"], "to": r["aid"], + "type": "BELONGS_TO", "to_name": r["name"]}) + + def _query_general(self, session, by_type, techniques, tactics, groups, + software, mitigations, campaigns, relationships): + """通用查询:所有关系,1跳""" + if by_type.get("technique"): + q = """ + MATCH (t:Technique) WHERE t.attack_id IN $ids + OPTIONAL MATCH (t)-[:BELONGS_TO]->(tac:Tactic) + OPTIONAL MATCH (grp:Group)-[:USES]->(t) + OPTIONAL MATCH (sw)-[:USES]->(t) WHERE sw:Malware OR sw:Tool + OPTIONAL MATCH (mit:Mitigation)-[:MITIGATES]->(t) + OPTIONAL MATCH (camp:Campaign)-[:USES]->(t) + OPTIONAL MATCH (sub:Technique)-[:SUBTECHNIQUE_OF]->(t) + RETURN t.attack_id AS aid, t.name AS name, t.description AS desc, + t.platforms AS platforms, t.url AS url, + collect(DISTINCT {id: tac.attack_id, n: tac.name}) AS rel_tactics, + collect(DISTINCT {id: grp.attack_id, n: grp.name}) AS rel_groups, + collect(DISTINCT {id: sw.attack_id, n: sw.name}) AS rel_sw, + collect(DISTINCT {id: mit.attack_id, n: mit.name}) AS rel_mits, + collect(DISTINCT {id: camp.attack_id, n: camp.name}) AS rel_camps, + collect(DISTINCT {id: sub.attack_id, n: sub.name}) AS rel_subs + """ + tac_ids, group_ids, sw_ids, mit_ids, camp_ids = set(), set(), set(), set(), set() + for r in session.run(q, ids=by_type["technique"]): + techniques.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "platforms": r["platforms"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_tactics"], "BELONGS_TO", tac_ids) + _add_rels(relationships, r["aid"], r["rel_groups"], "USES", group_ids, reverse=True) + _add_rels(relationships, r["aid"], r["rel_sw"], "USES", sw_ids, reverse=True) + _add_rels(relationships, r["aid"], r["rel_mits"], "MITIGATES", mit_ids, reverse=True) + _add_rels(relationships, r["aid"], r["rel_camps"], "USES", camp_ids, reverse=True) + for sub in r["rel_subs"]: + if sub["id"]: + relationships.append({"from": 
sub["id"], "to": r["aid"], + "type": "SUBTECHNIQUE_OF", "to_name": sub["n"]}) + self._fill_2hop(session, techniques, tactics, groups, software, mitigations, campaigns, + tac_ids, group_ids, sw_ids, mit_ids, camp_ids) + + if by_type.get("group"): + q = """ + MATCH (g:Group) WHERE g.attack_id IN $ids + OPTIONAL MATCH (g)-[:USES]->(tech:Technique) + OPTIONAL MATCH (g)-[:USES]->(sw) WHERE sw:Malware OR sw:Tool + OPTIONAL MATCH (camp:Campaign)-[:ATTRIBUTED_TO]->(g) + RETURN g.attack_id AS aid, g.name AS name, g.description AS desc, + g.aliases AS aliases, g.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs, + collect(DISTINCT {id: sw.attack_id, n: sw.name}) AS rel_sw, + collect(DISTINCT {id: camp.attack_id, n: camp.name}) AS rel_camps + """ + sw_ids, camp_ids = set(), set() + for r in session.run(q, ids=by_type["group"]): + groups.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "aliases": r["aliases"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_techs"], "USES", None) + _add_rels(relationships, r["aid"], r["rel_sw"], "USES", sw_ids) + _add_rels(relationships, r["aid"], r["rel_camps"], "ATTRIBUTED_TO", camp_ids, reverse=True) + + if by_type.get("software"): + q = """ + MATCH (sw) WHERE sw.attack_id IN $ids AND (sw:Malware OR sw:Tool) + OPTIONAL MATCH (sw)-[:USES]->(tech:Technique) + OPTIONAL MATCH (grp:Group)-[:USES]->(sw) + RETURN sw.attack_id AS aid, sw.name AS name, sw.description AS desc, + sw.platforms AS platforms, sw.aliases AS aliases, sw.url AS url, + labels(sw) AS labels, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs, + collect(DISTINCT {id: grp.attack_id, n: grp.name}) AS rel_grps + """ + for r in session.run(q, ids=by_type["software"]): + sw_type = "malware" if "Malware" in r["labels"] else "tool" + software.append({"attack_id": r["aid"], "name": r["name"], "description": r["desc"], + "platforms": r["platforms"], "aliases": r["aliases"], "url": r["url"], 
"type": sw_type}) + _add_rels(relationships, r["aid"], r["rel_techs"], "USES", None) + _add_rels(relationships, r["aid"], r["rel_grps"], "USES", None, reverse=True) + + if by_type.get("tactic"): + q = """ + MATCH (tac:Tactic) WHERE tac.attack_id IN $ids + OPTIONAL MATCH (tech:Technique)-[:BELONGS_TO]->(tac) + RETURN tac.attack_id AS aid, tac.name AS name, tac.description AS desc, + tac.shortname AS shortname, tac.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs + """ + for r in session.run(q, ids=by_type["tactic"]): + tactics.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "shortname": r["shortname"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_techs"], "BELONGS_TO", None, reverse=True) + + if by_type.get("mitigation"): + q = """ + MATCH (m:Mitigation) WHERE m.attack_id IN $ids + OPTIONAL MATCH (m)-[:MITIGATES]->(tech:Technique) + RETURN m.attack_id AS aid, m.name AS name, m.description AS desc, m.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs + """ + for r in session.run(q, ids=by_type["mitigation"]): + mitigations.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_techs"], "MITIGATES", None) + + if by_type.get("campaign"): + q = """ + MATCH (c:Campaign) WHERE c.attack_id IN $ids + OPTIONAL MATCH (c)-[:USES]->(tech:Technique) + OPTIONAL MATCH (c)-[:USES]->(sw) WHERE sw:Malware OR sw:Tool + OPTIONAL MATCH (c)-[:ATTRIBUTED_TO]->(grp:Group) + RETURN c.attack_id AS aid, c.name AS name, c.description AS desc, + c.aliases AS aliases, c.url AS url, + collect(DISTINCT {id: tech.attack_id, n: tech.name}) AS rel_techs, + collect(DISTINCT {id: sw.attack_id, n: sw.name}) AS rel_sw, + collect(DISTINCT {id: grp.attack_id, n: grp.name}) AS rel_grps + """ + sw_ids, group_ids = set(), set() + for r in session.run(q, ids=by_type["campaign"]): + campaigns.append({"attack_id": r["aid"], 
"name": r["name"], + "description": r["desc"], "aliases": r["aliases"], "url": r["url"]}) + _add_rels(relationships, r["aid"], r["rel_techs"], "USES", None) + _add_rels(relationships, r["aid"], r["rel_sw"], "USES", sw_ids) + _add_rels(relationships, r["aid"], r["rel_grps"], "ATTRIBUTED_TO", group_ids) + + def _fill_2hop(self, session, techniques, tactics, groups, software, + mitigations, campaigns, tac_ids, group_ids, sw_ids, mit_ids, camp_ids): + """补充二跳关联实体的详情""" + existing_tech = {t["attack_id"] for t in techniques} + existing_group = {g["attack_id"] for g in groups} + existing_sw = {s["attack_id"] for s in software} + existing_mit = {m["attack_id"] for m in mitigations} + existing_camp = {c["attack_id"] for c in campaigns} + + tac_ids -= {t["attack_id"] for t in tactics} + if tac_ids: + for r in session.run( + "MATCH (t:Tactic) WHERE t.attack_id IN $ids " + "RETURN t.attack_id AS aid, t.name AS name, t.description AS desc, t.shortname AS sn, t.url AS url", + ids=list(tac_ids)): + tactics.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "shortname": r["sn"], "url": r["url"]}) + + new_group_ids = group_ids - existing_group + if new_group_ids: + for r in session.run( + "MATCH (g:Group) WHERE g.attack_id IN $ids " + "RETURN g.attack_id AS aid, g.name AS name, g.description AS desc, g.aliases AS aliases, g.url AS url", + ids=list(new_group_ids)): + groups.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "aliases": r["aliases"], "url": r["url"]}) + + new_sw_ids = sw_ids - existing_sw + if new_sw_ids: + for r in session.run( + "MATCH (sw) WHERE sw.attack_id IN $ids AND (sw:Malware OR sw:Tool) " + "RETURN sw.attack_id AS aid, sw.name AS name, sw.description AS desc, " + "sw.platforms AS platforms, sw.aliases AS aliases, sw.url AS url, labels(sw) AS labels", + ids=list(new_sw_ids)): + sw_type = "malware" if "Malware" in r["labels"] else "tool" + software.append({"attack_id": r["aid"], "name": r["name"], 
"description": r["desc"], + "platforms": r["platforms"], "aliases": r["aliases"], "url": r["url"], "type": sw_type}) + + new_mit_ids = mit_ids - existing_mit + if new_mit_ids: + for r in session.run( + "MATCH (m:Mitigation) WHERE m.attack_id IN $ids " + "RETURN m.attack_id AS aid, m.name AS name, m.description AS desc, m.url AS url", + ids=list(new_mit_ids)): + mitigations.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "url": r["url"]}) + + new_camp_ids = camp_ids - existing_camp + if new_camp_ids: + for r in session.run( + "MATCH (c:Campaign) WHERE c.attack_id IN $ids " + "RETURN c.attack_id AS aid, c.name AS name, c.description AS desc, c.aliases AS aliases, c.url AS url", + ids=list(new_camp_ids)): + campaigns.append({"attack_id": r["aid"], "name": r["name"], + "description": r["desc"], "aliases": r["aliases"], "url": r["url"]}) + + def get_stats(self) -> dict: + """获取数据库统计信息""" + with self.driver.session() as session: + stats = {} + for label in ["Technique", "Tactic", "Group", "Malware", "Tool", "Mitigation", "Campaign"]: + stats[label.lower() + "s"] = session.run( + f"MATCH (n:{label}) RETURN count(n) AS cnt" + ).single()["cnt"] + + for rel in ["BELONGS_TO", "USES", "MITIGATES", "SUBTECHNIQUE_OF", "ATTRIBUTED_TO"]: + stats[rel.lower()] = session.run( + f"MATCH ()-[r:{rel}]->() RETURN count(r) AS cnt" + ).single()["cnt"] + + return stats + + +def _add_rels(relationships, entity_id, related_list, rel_type, id_set, reverse=False): + """辅助函数:添加关系到列表并收集关联 ID""" + for item in related_list: + if not item["id"]: + continue + if reverse: + relationships.append({"from": item["id"], "to": entity_id, "type": rel_type, "to_name": item["n"]}) + else: + relationships.append({"from": entity_id, "to": item["id"], "type": rel_type, "to_name": item["n"]}) + if id_set is not None: + id_set.add(item["id"]) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Neo4j 图查询工具") + 
parser.add_argument("--password", default="", help="Neo4j 密码") + parser.add_argument("--attack-id", default="", help="查询实体 ID") + parser.add_argument("--stats", action="store_true", help="显示统计信息") + args = parser.parse_args() + + gq = GraphQuery(password=args.password) + + if args.stats: + stats = gq.get_stats() + print("=== 节点统计 ===") + for key, val in stats.items(): + if not any(r in key for r in ["belongs", "uses", "mitigates", "subtechnique", "attributed"]): + print(f" {key}: {val}") + print("\n=== 关系统计 ===") + for key, val in stats.items(): + if any(r in key for r in ["belongs", "uses", "mitigates", "subtechnique", "attributed"]): + print(f" {key}: {val}") + + if args.attack_id: + result = gq.query_related_entities([args.attack_id]) + for key in ["techniques", "tactics", "groups", "software", "mitigations", "campaigns"]: + items = result[key] + if items: + print(f"\n{key}: {len(items)} 个") + for item in items[:5]: + print(f" - {item['attack_id']} {item['name']}") + if result["relationships"]: + print(f"\n关系: {len(result['relationships'])} 条") + for rel in result["relationships"][:10]: + print(f" {rel['from']} → [{rel['type']}] → {rel['to']}") + + gq.close() diff --git a/ntad/qa_engine.py b/ntad/qa_engine.py new file mode 100644 index 0000000..896ba7e --- /dev/null +++ b/ntad/qa_engine.py @@ -0,0 +1,423 @@ +""" +问答引擎(主入口) +串联所有模块,实现完整的问答流程 +""" + +import os +import sys +from typing import Optional + +from dotenv import load_dotenv + +load_dotenv() + +# 将项目根目录加入 Python 路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from utils.bm25_search import BM25Search, STIX_FILE, INDEX_FILE, rrf_fusion +from utils.graph_query import GraphQuery +from utils.query_parser import parse_query +from utils.answer_generator import generate_answer, generate_answer_stream, generate_low_confidence_answer + + +class QAEngine: + """ + 问答引擎 + + 流程: + 1. LLM 意图识别 → 提取关键词和查询类型 + 2. 判断是否需要追问 + 3. BM25 关键词检索 + 4. 置信度判断 (> 0.75 → Neo4j 查询) + 5. 
LLM 整合回答 + """ + + def __init__( + self, + neo4j_password: str = "", + bm25_index_path: str = INDEX_FILE, + stix_data_path: str = STIX_FILE, + ): + """ + 初始化问答引擎 + + Args: + neo4j_password: Neo4j 密码(优先使用参数,其次环境变量) + bm25_index_path: BM25 索引文件路径 + stix_data_path: STIX 数据文件路径 + """ + self.history = [] # 对话历史 + self.stix_data_path = stix_data_path + + # 初始化 BM25 检索引擎 + print("正在初始化 BM25 检索引擎...") + self.bm25 = BM25Search() + if not self.bm25.load_index(bm25_index_path): + if os.path.exists(stix_data_path): + print("索引不存在,正在从原始数据构建...") + self.bm25.build_index(stix_data_path) + self.bm25.save_index(bm25_index_path) + else: + raise FileNotFoundError( + f"BM25 索引文件和原始数据文件均不存在: {bm25_index_path}, {stix_data_path}" + ) + + # 初始化 Neo4j 图查询 + print("正在连接 Neo4j...") + try: + self.graph = GraphQuery(password=neo4j_password) + except Exception as e: + print(f"⚠️ Neo4j 连接失败: {e}") + print("图查询功能将不可用,仅使用 BM25 检索") + self.graph = None + + print("✅ 问答引擎初始化完成\n") + + def _is_self_intro(self, user_input: str) -> bool: + """判断是否为自我介绍类问题""" + keywords = ["你是谁", "你是什么", "介绍一下", "自我介绍", "你叫什么", "你是什么系统", "你是什么平台"] + return any(k in user_input for k in keywords) + + def _self_intro_answer(self) -> str: + """返回自我介绍""" + return ( + "我是 **NTAD 智能攻防问答系统**,基于全球最权威的 MITRE ATT&CK 网络威胁知识库构建。\n\n" + "### 系统能力\n\n" + "| 维度 | 说明 |\n" + "|------|------|\n" + "| 数据源 | MITRE ATT&CK Enterprise(25,842 个 STIX 对象) |\n" + "| 知识图谱 | Neo4j 图数据库,7 类节点、6 类关系、18,000+ 边 |\n" + "| 检索引擎 | BM25Okapi 关键词检索 + RRF 多路融合 |\n" + "| 智能引擎 | 通义千问 qwen-max 意图识别 + 回答生成 |\n" + "| 查询路由 | 7 种意图动态图查询(detail/uses/mitigates 等) |\n\n" + "### 我能帮您\n\n" + "- 查询攻击技术详情(如:T1059 是什么?)\n" + "- 了解战术阶段下的技术(如:初始访问有哪些技术?)\n" + "- 获取防御建议(如:怎么防御 PowerShell 攻击?)\n" + "- 追踪威胁组织活动(如:APT28 用了哪些技术?)\n" + "- 分析恶意软件和工具(如:Cobalt Strike 利用了什么技术?)\n\n" + "请输入您的问题开始探索。" + ) + + def ask(self, user_input: str) -> str: + """ + 处理用户提问 + + Args: + user_input: 用户输入的文本 + + Returns: + 回答文本 + """ + # 自我介绍快捷处理 + if self._is_self_intro(user_input): + answer = 
self._self_intro_answer() + self.history.append({"role": "user", "content": user_input}) + self.history.append({"role": "assistant", "content": answer}) + return answer + + # 1. LLM 意图识别 + print("🔍 正在分析问题...") + try: + parsed = parse_query(user_input, self.history) + except Exception as e: + print(f"⚠️ 意图识别失败: {e},使用简单关键词检索") + parsed = { + "query_type": "general", + "query_focus": "general", + "entities": [], + "search_keywords": user_input.split(), + "search_keywords_variants": [], + "need_clarification": False, + "confidence": 0.5, + } + + print(f" 查询类型: {parsed.get('query_type', 'unknown')}") + print(f" 查询焦点: {parsed.get('query_focus', 'general')}") + print(f" 提取实体: {parsed.get('entities', [])}") + print(f" 检索关键词: {parsed.get('search_keywords', [])}") + + # 2. 判断是否需要追问 + if parsed.get("need_clarification", False): + clarification = parsed.get("clarification_question", "请提供更多信息") + self.history.append({"role": "user", "content": user_input}) + self.history.append({"role": "assistant", "content": clarification}) + return clarification + + # 3. BM25 多路检索 + RRF 融合 + keywords = parsed.get("search_keywords", []) + if not keywords: + keywords = user_input.split() + + # 主查询 + query_text = " ".join(keywords) + all_queries = [query_text] + + # 变体查询 + variants = parsed.get("search_keywords_variants", []) + for v in variants: + if v: + all_queries.append(" ".join(v)) + + print(f"🔍 BM25 检索: {query_text}(共 {len(all_queries)} 路查询)") + + # 多路搜索 + multi_results = [self.bm25.search(q, top_k=10) for q in all_queries] + + # RRF 融合 + if len(multi_results) > 1: + search_results = rrf_fusion(multi_results) + else: + search_results = multi_results[0] if multi_results else [] + + if search_results: + print(f" 找到 {len(search_results)} 条结果") + print(f" 最高置信度: {search_results[0]['score']:.4f}") + else: + print(" 未找到匹配结果") + + # 4. 
置信度判断(使用 LLM 意图识别的 confidence,而非 BM25 分数) + llm_confidence = parsed.get("confidence", 0.5) + has_results = bool(search_results) + + if has_results and llm_confidence >= 0.5: + # 有检索结果且 LLM 置信度足够 → Neo4j 查询 + print(f"✅ BM25 命中 {len(search_results)} 条,LLM 置信度 {llm_confidence:.2f},执行图查询") + + graph_data = None + if self.graph: + # 收集要查询的 attack_id + attack_ids = [r["attack_id"] for r in search_results[:5]] + + # 如果有特定实体,优先查询 + entities = parsed.get("entities", []) + entity_ids = [e for e in entities if any(e.startswith(p) for p in ("T", "TA", "G", "S", "M", "C", "DET"))] + if entity_ids: + attack_ids = entity_ids + attack_ids + + attack_ids = list(dict.fromkeys(attack_ids)) # 去重保序 + + print(f"🔍 Neo4j 查询: {attack_ids}") + graph_data = self.graph.query_by_intent( + attack_ids, + query_focus=parsed.get("query_focus", "general"), + entities=parsed.get("entities", []), + ) + + # LLM 整合回答 + print("🤖 正在生成回答...") + answer = generate_answer( + user_question=user_input, + search_results=search_results, + graph_data=graph_data, + history=self.history, + ) + elif has_results: + # 有结果但 LLM 置信度低 → 返回检索结果 + 提示 + print(f"⚠️ LLM 置信度 {llm_confidence:.2f} < 0.5,返回检索结果") + answer = generate_low_confidence_answer(user_input, search_results) + else: + # 无检索结果 + print("⚠️ BM25 未找到匹配结果") + answer = generate_low_confidence_answer(user_input, []) + + # 更新对话历史 + self.history.append({"role": "user", "content": user_input}) + self.history.append({"role": "assistant", "content": answer}) + + # 限制历史长度 + if len(self.history) > 20: + self.history = self.history[-20:] + + return answer + + def ask_stream(self, user_input: str): + """ + 流式处理用户提问(逐 token 返回) + + Yields: + 每次产出的文本片段 + """ + # 自我介绍快捷处理(模拟逐字流式) + if self._is_self_intro(user_input): + answer = self._self_intro_answer() + self.history.append({"role": "user", "content": user_input}) + self.history.append({"role": "assistant", "content": answer}) + # 按行切分,逐行 yield 实现流式效果 + for line in answer.split("\n"): + yield line + "\n" + return + + # 1. 
LLM 意图识别 + print("🔍 正在分析问题...") + try: + parsed = parse_query(user_input, self.history) + except Exception as e: + print(f"⚠️ 意图识别失败: {e},使用简单关键词检索") + parsed = { + "query_type": "general", + "query_focus": "general", + "entities": [], + "search_keywords": user_input.split(), + "search_keywords_variants": [], + "need_clarification": False, + "confidence": 0.5, + } + + print(f" 查询类型: {parsed.get('query_type', 'unknown')}") + print(f" 查询焦点: {parsed.get('query_focus', 'general')}") + + # 2. 判断是否需要追问 + if parsed.get("need_clarification", False): + clarification = parsed.get("clarification_question", "请提供更多信息") + self.history.append({"role": "user", "content": user_input}) + self.history.append({"role": "assistant", "content": clarification}) + yield clarification + return + + # 3. BM25 多路检索 + RRF 融合 + keywords = parsed.get("search_keywords", []) + if not keywords: + keywords = user_input.split() + + query_text = " ".join(keywords) + all_queries = [query_text] + variants = parsed.get("search_keywords_variants", []) + for v in variants: + if v: + all_queries.append(" ".join(v)) + + print(f"🔍 BM25 检索: {query_text}(共 {len(all_queries)} 路查询)") + multi_results = [self.bm25.search(q, top_k=10) for q in all_queries] + if len(multi_results) > 1: + search_results = rrf_fusion(multi_results) + else: + search_results = multi_results[0] if multi_results else [] + + if search_results: + print(f" 找到 {len(search_results)} 条结果") + else: + print(" 未找到匹配结果") + + # 4. 
置信度判断 + llm_confidence = parsed.get("confidence", 0.5) + has_results = bool(search_results) + + if has_results and llm_confidence >= 0.5: + print(f"✅ BM25 命中 {len(search_results)} 条,LLM 置信度 {llm_confidence:.2f},执行图查询") + + graph_data = None + if self.graph: + attack_ids = [r["attack_id"] for r in search_results[:5]] + entities = parsed.get("entities", []) + entity_ids = [e for e in entities if any(e.startswith(p) for p in ("T", "TA", "G", "S", "M", "C", "DET"))] + if entity_ids: + attack_ids = entity_ids + attack_ids + attack_ids = list(dict.fromkeys(attack_ids)) + + print(f"🔍 Neo4j 查询: {attack_ids}") + graph_data = self.graph.query_by_intent( + attack_ids, + query_focus=parsed.get("query_focus", "general"), + entities=parsed.get("entities", []), + ) + + # 流式生成回答 + print("🤖 正在生成回答...") + full_answer = "" + for chunk in generate_answer_stream( + user_question=user_input, + search_results=search_results, + graph_data=graph_data, + history=self.history, + ): + full_answer += chunk + yield chunk + + elif has_results: + print(f"⚠️ LLM 置信度 {llm_confidence:.2f} < 0.5,返回检索结果") + full_answer = generate_low_confidence_answer(user_input, search_results) + yield full_answer + else: + print("⚠️ BM25 未找到匹配结果") + full_answer = generate_low_confidence_answer(user_input, []) + yield full_answer + + # 更新对话历史 + self.history.append({"role": "user", "content": user_input}) + self.history.append({"role": "assistant", "content": full_answer}) + if len(self.history) > 20: + self.history = self.history[-20:] + + def clear_history(self): + """清空对话历史""" + self.history = [] + print("对话历史已清空") + + def get_stats(self) -> dict: + """获取引擎统计信息""" + stats = self.bm25.get_index_stats() + if self.graph: + try: + graph_stats = self.graph.get_stats() + stats.update(graph_stats) + except Exception: + pass + return stats + + def close(self): + """关闭连接""" + if self.graph: + self.graph.close() + + +def main(): + """命令行交互入口""" + import argparse + + parser = argparse.ArgumentParser(description="MITRE ATT&CK 
知识图谱问答引擎") + parser.add_argument("--password", default="", help="Neo4j 密码") + args = parser.parse_args() + + password = args.password or os.getenv("NEO4J_PASSWORD", "") + + engine = None + try: + engine = QAEngine(neo4j_password=password) + + stats = engine.get_stats() + print(f"📊 引擎统计: {stats}\n") + print("=" * 50) + print("MITRE ATT&CK 知识图谱问答系统") + print("输入 'quit' 或 'exit' 退出") + print("输入 'clear' 清空对话历史") + print("=" * 50) + + while True: + try: + user_input = input("\n👤 您的问题: ").strip() + except (EOFError, KeyboardInterrupt): + break + + if not user_input: + continue + if user_input.lower() in ("quit", "exit"): + break + if user_input.lower() == "clear": + engine.clear_history() + continue + + print() + answer = engine.ask(user_input) + print(f"\n🤖 回答:\n{answer}") + + except Exception as e: + print(f"❌ 错误: {e}", file=sys.stderr) + sys.exit(1) + finally: + if engine: + engine.close() + + +if __name__ == "__main__": + main() diff --git a/ntad/query_parser.py b/ntad/query_parser.py new file mode 100644 index 0000000..0caacc8 --- /dev/null +++ b/ntad/query_parser.py @@ -0,0 +1,206 @@ +""" +LLM 意图识别模块 +使用千问 API 解析用户提问,提取查询意图和关键实体 +""" + +import json +import os +from typing import Optional + +from dotenv import load_dotenv +from openai import OpenAI + +load_dotenv() + +# 系统提示词 +SYSTEM_PROMPT = """你是一个 MITRE ATT&CK 网络安全知识图谱的查询助手。你的任务是解析用户的问题,提取查询意图和关键实体。 + +## 你需要输出一个 JSON 对象,包含以下字段: + +{ + "query_type": "technique | tactic | group | software | mitigation | campaign | general", + "query_focus": "detail | uses | mitigates | belongs_to | attributed_to | tactics_of | general", + "entities": ["提取的实体列表,如 T1059、PowerShell、APT28、Linux 等"], + "search_keywords": ["用于 BM25 检索的关键词列表(主查询)"], + "search_keywords_variants": [ + ["变体1:从不同角度或使用同义词的关键词"], + ["变体2:包含 attack_id 或更具体的术语"] + ], + "need_clarification": true/false, + "clarification_question": "如果需要追问,返回追问内容;否则为空字符串", + "confidence": 0.0~1.0, + "reasoning": "简要说明你的判断逻辑" +} + +## query_type 说明(用户主要关心哪类实体): +- technique: 
查询攻击技术(如 "T1059 是什么"、"PowerShell 攻击") +- tactic: 查询战术阶段(如 "初始访问有哪些技术"、"横向移动") +- group: 查询威胁组织(如 "APT28 的信息"、"哪个组织用了这个技术") +- software: 查询恶意软件/工具(如 "CobaltStrike 是什么"、"用了哪些工具") +- mitigation: 查询防御/缓解措施(如 "怎么防御钓鱼"、"T1059 的缓解方法") +- campaign: 查询攻击活动(如 "最近的攻击活动") +- general: 通用问题(如 "什么是 ATT&CK") + +## query_focus 说明(用户关心什么关系): +- detail: 了解实体本身详情(如 "T1059 是什么"、"APT28 介绍") +- uses: 谁在用/用了什么(如 "APT28 用了哪些技术"、"这个软件利用了什么") +- mitigates: 防御/缓解(如 "怎么防御 PowerShell"、"T1059 的缓解措施") +- belongs_to: 归属关系(如 "T1059 属于哪个战术"、"这个技术在哪个阶段") +- attributed_to: 归属溯源(如 "这个活动是哪个组织干的") +- tactics_of: 战术下的技术列表(如 "初始访问有哪些技术") +- general: 通用/不明确 + +## 注意事项: +1. attack_id 格式:技术 T 开头(T1059、T1059.001),战术 TA 开头(TA0001),组织 G 开头(G0007),软件 S 开头(S0154),缓解 M 开头(M1032),活动 C 开头(C0001) +2. 如果用户问题过于模糊(如 "帮我查一下"),设置 need_clarification=true +3. search_keywords 应该是英文关键词,用于 BM25 检索 +4. 如果用户用中文提问,也要提取对应的英文关键词 +5. search_keywords_variants 是 search_keywords 的变体,用于多路检索提升召回率。生成 1-2 个变体,每个变体从不同角度选词(如同义词、更具体的 attack_id、上位/下位概念) +6. 只输出 JSON,不要输出其他内容 + +## 示例: + +用户: "T1059 是什么技术?" +输出: {"query_type":"technique","query_focus":"detail","entities":["T1059"],"search_keywords":["T1059","command","scripting","interpreter"],"search_keywords_variants":[["command","script","interpreter","execution"],["T1059.001","PowerShell","T1059.004","Unix shell"]],"need_clarification":false,"clarification_question":"","confidence":0.95,"reasoning":"用户询问特定技术 T1059 的详情"} + +用户: "初始访问阶段有哪些攻击技术?" +输出: {"query_type":"tactic","query_focus":"tactics_of","entities":["TA0001"],"search_keywords":["initial","access"],"search_keywords_variants":[["TA0001","phishing","drive-by","exploit"],["spearphishing","watering hole","public-facing"]],"need_clarification":false,"clarification_question":"","confidence":0.9,"reasoning":"用户询问初始访问战术下的技术"} + +用户: "PowerShell 相关的攻击怎么防御?" 
+输出: {"query_type":"mitigation","query_focus":"mitigates","entities":["T1059.001"],"search_keywords":["PowerShell","defense","mitigation"],"search_keywords_variants":[["T1059.001","constrained","language mode"],["scripting","execution","prevention"]],"need_clarification":false,"clarification_question":"","confidence":0.85,"reasoning":"用户询问 PowerShell 攻击的防御方法"} + +用户: "APT28 用了哪些攻击技术?" +输出: {"query_type":"group","query_focus":"uses","entities":["G0007"],"search_keywords":["APT28","attack","techniques"],"search_keywords_variants":[["G0007","Fancy Bear","Sofacy"],["APT28","malware","tools","credential"]],"need_clarification":false,"clarification_question":"","confidence":0.95,"reasoning":"用户询问 APT28 组织使用的攻击技术"} + +用户: "T1059 属于哪个战术阶段?" +输出: {"query_type":"technique","query_focus":"belongs_to","entities":["T1059"],"search_keywords":["T1059","tactic","phase"],"search_keywords_variants":[["command","scripting","interpreter","execution phase"]],"need_clarification":false,"clarification_question":"","confidence":0.95,"reasoning":"用户询问 T1059 归属的战术阶段"} + +用户: "帮我查一下" +输出: {"query_type":"general","query_focus":"general","entities":[],"search_keywords":[],"search_keywords_variants":[],"need_clarification":true,"clarification_question":"请问您想查询什么内容?例如:\n1. 特定攻击技术(如 T1059 PowerShell)\n2. 攻击战术阶段(如初始访问、横向移动)\n3. 威胁组织(如 APT28)\n4. 
防御方法","confidence":0.1,"reasoning":"用户问题过于模糊,需要追问"}""" + + +def _get_client() -> OpenAI: + """获取千问 API 客户端""" + api_key = os.getenv("QWEN_API_KEY") + base_url = os.getenv("QWEN_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + model = os.getenv("QWEN_MODEL", "qwen-plus") + + if not api_key or api_key == "your_api_key_here": + raise ValueError("请在 .env 文件中配置 QWEN_API_KEY") + + return OpenAI(api_key=api_key, base_url=base_url), model + + +def parse_query(user_input: str, history: list = None) -> dict: + """ + 用千问 API 解析用户提问 + + Args: + user_input: 用户输入的文本 + history: 对话历史 [{"role": "user/assistant", "content": "..."}] + + Returns: + { + "query_type": "technique" | "tactic" | "group" | "software" | "mitigation" | "campaign" | "general", + "query_focus": "detail" | "uses" | "mitigates" | "belongs_to" | "attributed_to" | "tactics_of" | "general", + "entities": ["T1059", "PowerShell", ...], + "search_keywords": ["command", "scripting", ...], + "need_clarification": True/False, + "clarification_question": "...", + "confidence": 0.0~1.0, + "reasoning": "..." 
+ } + """ + client, model = _get_client() + + messages = [{"role": "system", "content": SYSTEM_PROMPT}] + + # 添加对话历史 + if history: + messages.extend(history[-6:]) # 最近 3 轮对话 + + messages.append({"role": "user", "content": user_input}) + + try: + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=0.1, + max_tokens=500, + ) + + content = response.choices[0].message.content.strip() + + # 提取 JSON(处理可能的 markdown 代码块) + if content.startswith("```"): + content = content.split("```")[1] + if content.startswith("json"): + content = content[4:] + content = content.strip() + + result = json.loads(content) + + # 验证必要字段 + required_fields = ["query_type", "entities", "search_keywords", "need_clarification", "confidence"] + for field in required_fields: + if field not in result: + raise ValueError(f"缺少字段: {field}") + + # query_focus 默认值 + if "query_focus" not in result: + result["query_focus"] = "general" + + # search_keywords_variants 默认值 + if "search_keywords_variants" not in result: + result["search_keywords_variants"] = [] + + # 确保 confidence 在合理范围 + result["confidence"] = max(0.0, min(1.0, float(result.get("confidence", 0.5)))) + + return result + + except json.JSONDecodeError as e: + return { + "query_type": "general", + "query_focus": "general", + "entities": [], + "search_keywords": user_input.split(), + "search_keywords_variants": [], + "need_clarification": False, + "clarification_question": "", + "confidence": 0.3, + "reasoning": f"LLM 输出解析失败: {e}", + "raw_output": content if 'content' in dir() else "", + } + except Exception as e: + return { + "query_type": "general", + "query_focus": "general", + "entities": [], + "search_keywords": user_input.split(), + "search_keywords_variants": [], + "need_clarification": False, + "clarification_question": "", + "confidence": 0.3, + "reasoning": f"LLM 调用失败: {e}", + } + + +if __name__ == "__main__": + # 测试意图识别 + test_queries = [ + "T1059 是什么技术?", + "初始访问阶段有哪些攻击技术?", + "PowerShell 
相关的攻击怎么防御?", + "帮我查一下", + "横向移动 Linux 平台", + ] + + for query in test_queries: + print(f"\n问题: {query}") + result = parse_query(query) + print(f" 类型: {result['query_type']} | 焦点: {result.get('query_focus', 'N/A')}") + print(f" 实体: {result['entities']}") + print(f" 关键词: {result['search_keywords']}") + print(f" 置信度: {result['confidence']}") + print(f" 需追问: {result['need_clarification']}")