diff --git a/mvt/common/encoding.py b/mvt/common/encoding.py new file mode 100644 index 0000000..c96d56f --- /dev/null +++ b/mvt/common/encoding.py @@ -0,0 +1,27 @@ +import json + + +class CustomJSONEncoder(json.JSONEncoder): + """ + Custom JSON encoder to handle non-standard types. + + Some modules are storing non-UTF-8 bytes in their results dictionaries. + This causes exceptions when the results are being encoded as JSON. + + Of course this means that when MVT is run via `check-iocs` with existing + results, the encoded version will be loaded back into the dictionary. + Modules should ensure they encode anything that needs to be compared + against an indicator in a JSON-friendly type. + """ + + def default(self, o): + if isinstance(o, bytes): + # If it's utf-8, try and use that first + try: + return o.decode("utf-8") + except UnicodeError: + # Otherwise use a hex representation for any byte type + return "0x" + o.hex() + + # For all other types try to use the string representation. + return str(o) diff --git a/mvt/common/module.py b/mvt/common/module.py index b7a63fc..55d54f3 100644 --- a/mvt/common/module.py +++ b/mvt/common/module.py @@ -4,13 +4,13 @@ # https://license.mvt.re/1.1/ import csv +import json import logging import os import re from typing import Any, Dict, List, Optional, Union -import simplejson as json - +from .encoding import CustomJSONEncoder from .utils import exec_or_profile @@ -103,7 +103,7 @@ class MVTModule: results_json_path = os.path.join(self.results_path, results_file_name) with open(results_json_path, "w", encoding="utf-8") as handle: try: - json.dump(self.results, handle, indent=4, default=str) + json.dump(self.results, handle, indent=4, cls=CustomJSONEncoder) except Exception as exc: self.log.error( "Unable to store results of module %s to file %s: %s", @@ -116,7 +116,7 @@ class MVTModule: detected_file_name = f"{name}_detected.json" detected_json_path = os.path.join(self.results_path, detected_file_name) with open(detected_json_path, "w", encoding="utf-8") as handle: - json.dump(self.detected, handle, indent=4, default=str) + json.dump(self.detected, handle, indent=4, cls=CustomJSONEncoder) def serialize(self, record: dict) -> Union[dict, list, None]: raise NotImplementedError diff --git a/tests/common/test_encoding.py b/tests/common/test_encoding.py new file mode 100644 index 0000000..74c9274 --- /dev/null +++ b/tests/common/test_encoding.py @@ -0,0 +1,30 @@ +import json +from datetime import datetime + +from mvt.common.encoding import CustomJSONEncoder + + +class TestCustomJSONEncoder: + def test__normal_input(self): + assert json.dumps({"a": "b"}, cls=CustomJSONEncoder) == '{"a": "b"}' + + def test__datetime_object(self): + assert ( + json.dumps( + {"timestamp": datetime(2023, 11, 13, 12, 21, 49, 727467)}, + cls=CustomJSONEncoder, + ) + == '{"timestamp": "2023-11-13 12:21:49.727467"}' + ) + + def test__bytes_non_utf_8(self): + assert ( + json.dumps({"identifier": b"\xa8\xa9"}, cls=CustomJSONEncoder) + == '{"identifier": "0xa8a9"}' + ) + + def test__bytes_valid_utf_8(self): + assert ( + json.dumps({"name": "家".encode()}, cls=CustomJSONEncoder) + == '{"name": "\\u5bb6"}' + )