Files
SecureCheck/securecheck/storage.py
2026-04-05 18:56:26 +02:00

75 lines
2.6 KiB
Python

from __future__ import annotations
import json
from pathlib import Path
from .models import Scenario
class ScenarioStore:
    """Combine built-in scenarios with user-defined scenarios persisted as JSON.

    User scenarios are stored in ``scenario_file`` as ``{"scenarios": [...]}``,
    each entry holding ``name``, ``description`` and ``task_keys``. A user
    scenario whose name matches a built-in one shadows it in ``list_all`` and
    ``get``.
    """

    def __init__(self, scenario_file: Path, builtin_scenarios: list[Scenario]) -> None:
        """Create a store.

        Args:
            scenario_file: JSON file holding user scenarios. It need not exist
                yet; it (and its parent directory) is created by ``save``.
            builtin_scenarios: Scenarios shipped with the application. They are
                indexed by name and never written to ``scenario_file``.
        """
        self._scenario_file = scenario_file
        self._builtin = {scenario.name: scenario for scenario in builtin_scenarios}

    def _load_user_scenarios(self) -> dict[str, Scenario]:
        """Read user scenarios from disk, keyed by name.

        Returns an empty dict when the file does not exist. Every loaded entry
        is marked ``builtin=False``; a later entry with a duplicate name
        replaces an earlier one.
        """
        if not self._scenario_file.exists():
            return {}
        raw = json.loads(self._scenario_file.read_text(encoding="utf-8"))
        scenarios: dict[str, Scenario] = {}
        for item in raw.get("scenarios", []):
            scenario = Scenario(
                name=item["name"],
                description=item.get("description", ""),
                task_keys=list(item.get("task_keys", [])),
                builtin=False,
            )
            scenarios[scenario.name] = scenario
        return scenarios

    def _write_user_scenarios(self, scenarios: dict[str, Scenario]) -> None:
        """Persist *scenarios* to the scenario file, sorted case-insensitively by name.

        The JSON is first written to a sibling ``<name>.tmp`` file and then
        moved over the real file. A failure part-way through the write
        therefore leaves the previous file intact instead of truncated.
        """
        payload = {
            "scenarios": [
                {
                    "name": item.name,
                    "description": item.description,
                    "task_keys": item.task_keys,
                }
                for item in sorted(scenarios.values(), key=lambda s: s.name.lower())
            ]
        }
        text = json.dumps(payload, indent=2, ensure_ascii=False) + "\n"
        target = self._scenario_file
        # The file may live in a directory that has never been created yet.
        target.parent.mkdir(parents=True, exist_ok=True)
        tmp = target.with_name(target.name + ".tmp")
        tmp.write_text(text, encoding="utf-8")
        # Path.replace overwrites an existing target (os.replace semantics).
        tmp.replace(target)

    def list_all(self) -> list[Scenario]:
        """Return all scenarios: built-ins first, then user scenarios.

        Each group is sorted case-insensitively by name. A user scenario with
        the same name as a built-in one replaces it.
        """
        merged = {**self._builtin, **self._load_user_scenarios()}
        return sorted(
            merged.values(),
            key=lambda scenario: (scenario.builtin is False, scenario.name.lower()),
        )

    def get(self, name: str) -> Scenario | None:
        """Return the scenario called *name*, or ``None`` if there is none.

        User scenarios take precedence over built-ins, matching ``list_all``.
        The lookup is exact (case-sensitive).
        """
        user_scenarios = self._load_user_scenarios()
        if name in user_scenarios:
            return user_scenarios[name]
        return self._builtin.get(name)

    def save(self, scenario: Scenario) -> None:
        """Create or replace the user scenario named ``scenario.name``.

        The stored copy is always marked ``builtin=False``, so saving under a
        built-in's name shadows that built-in rather than modifying it.
        """
        scenarios = self._load_user_scenarios()
        scenarios[scenario.name] = Scenario(
            name=scenario.name,
            description=scenario.description,
            task_keys=scenario.task_keys,
            builtin=False,
        )
        self._write_user_scenarios(scenarios)

    def delete(self, name: str) -> bool:
        """Remove the user scenario called *name*.

        Only user scenarios are considered; built-ins cannot be deleted. If a
        deleted user scenario was shadowing a built-in, the built-in becomes
        visible again.

        Returns:
            ``True`` if a user scenario was removed and the file rewritten,
            ``False`` if no user scenario had that name.
        """
        scenarios = self._load_user_scenarios()
        if name not in scenarios:
            return False
        del scenarios[name]
        self._write_user_scenarios(scenarios)
        return True