"""Contains I/O helpers for loading pickles and worms."""
__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"
import pickle
import random
import string
import json
# import py2neo
# Restricting unpickling
# ALLOWED_MODULES = [
# "cedne",
# "networkx"
# "numpy",
# "numpy.core",
# "numpy.core.multiarray",
# "numpy.core.numeric"
# ]
# ALLOWED_CLASSES = {
# # Safe NumPy internals needed for unpickling arrays
# ("numpy.core.multiarray", "_reconstruct"),
# ("numpy.core.numeric", "_frombuffer"),
# ("numpy", "dtype"),
# ("numpy", "ndarray"),
# ("numpy", "float64"),
# ("numpy", "int64"),
# ("builtins", "set"),
# ("builtins", "frozenset"),
# ("builtins", "slice"),
# }
# Exact (module, name) pairs RestrictedUnpickler may resolve regardless of the
# module-prefix rules below. Keep this list minimal: every entry is reachable
# from a crafted pickle stream.
ALLOWED_CLASSES = {
    # NumPy internals required to reconstruct pickled arrays and scalars.
    ("numpy.core.multiarray", "_reconstruct"),
    ("numpy.core.numeric", "_frombuffer"),
    ("numpy", "dtype"),
    ("numpy", "ndarray"),
    ("numpy", "float64"),
    ("numpy", "int64"),
    # Plain builtin containers that pickle references by name.
    ("builtins", "set"),
    ("builtins", "frozenset"),
    ("builtins", "slice"),
}
# Module prefixes whose every attribute is resolvable (prefix itself or any
# dotted submodule). Broader than ALLOWED_CLASSES by design — see note below.
ALLOWED_MODULE_PREFIXES = [
    "cedne",
    "networkx",
    # CeDNe Worms commonly carry pandas-backed attribute tables (neurotransmitter,
    # neuropeptide, transcriptome). Pandas's pickle protocol pulls in many internal
    # classes (Series, DataFrame, BlockManager, several Index types) and the set
    # shifts across pandas minor versions, so an explicit class allowlist would be
    # brittle. We keep the wildcard for now but pair it with the DENY rules below
    # to block the known expression-evaluation gadgets (pandas.eval / pandas.query)
    # that would otherwise execute arbitrary Python during unpickling. A full
    # explicit pandas type allowlist is still on the backlog.
    "pandas",
]
# Denies override the allow-prefix above. Used to block callable gadgets that
# could trigger arbitrary code execution if reached via a crafted pickle.
DENIED_MODULE_PREFIXES = (
    "pandas.core.computation",  # pandas.eval lives here; whole subsystem evaluates expressions
    "pandas.io.pickle",  # to_pickle/read_pickle helpers — refuse chaining via unpickle
)
# Individually denied (module, name) pairs, checked before any allow rule.
DENIED_QUALIFIED_NAMES = {
    ("pandas", "eval"),  # top-level re-export of pandas.core.computation.eval.eval
    ("pandas", "query"),  # similar expression-evaluation entry point
}
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals passing an allow/deny policy.

    Resolution order in :meth:`find_class`:

    1. exact pairs in ``DENIED_QUALIFIED_NAMES``  -> refuse
    2. modules under ``DENIED_MODULE_PREFIXES``   -> refuse
    3. exact pairs in ``ALLOWED_CLASSES``         -> resolve
    4. modules under ``ALLOWED_MODULE_PREFIXES``  -> resolve
    5. anything else                              -> refuse

    Denies are checked first so known gadget modules are blocked even when
    they fall under an allowed prefix (e.g. ``pandas.core.computation``).
    """

    @staticmethod
    def _module_under(module, prefixes):
        """Return True if *module* equals a prefix or is a dotted submodule of one."""
        return any(
            module == prefix or module.startswith(prefix + ".")
            for prefix in prefixes
        )

    def find_class(self, module, name):
        """Resolve a global referenced by the pickle stream, if permitted.

        Args:
            module (str): Dotted module path recorded in the pickle.
            name (str): Attribute name looked up on that module.

        Returns:
            The resolved class or callable.

        Raises:
            pickle.UnpicklingError: If the global is denied or not allowed.
        """
        # Deny known dangerous callables first, even if they fall under an allowed prefix.
        if (module, name) in DENIED_QUALIFIED_NAMES:
            raise pickle.UnpicklingError(
                f"global '{module}.{name}' is denied (expression-evaluation gadget)"
            )
        if self._module_under(module, DENIED_MODULE_PREFIXES):
            raise pickle.UnpicklingError(
                f"global '{module}.{name}' is denied (module under {DENIED_MODULE_PREFIXES})"
            )
        # Allow fully qualified safe classes (numpy internals, builtins).
        if (module, name) in ALLOWED_CLASSES:
            return getattr(__import__(module, fromlist=[name]), name)
        # Allow all classes from whitelisted module prefixes (like cedne.*).
        if self._module_under(module, ALLOWED_MODULE_PREFIXES):
            return getattr(__import__(module, fromlist=[name]), name)
        # Otherwise, reject.
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")
# class RestrictedUnpickler(pickle.Unpickler):
# """
# A custom unpickler that restricts the loading of certain modules and classes.
# """
# def find_class(self, module, name):
# # Allow all functions and classes from the allowed modules and their submodules
# if any(module == allowed_module or module.startswith(allowed_module + ".") for allowed_module in ALLOWED_MODULES):
# return getattr(__import__(module, fromlist=[name]), name)
# if module == "builtins" and name in ["set", "frozenset"]:
# return getattr(__import__(module), name)
# raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")
# class RestrictedUnpickler(pickle.Unpickler):
# def find_class(self, module, name):
# if (module, name) in ALLOWED_CLASSES:
# return getattr(__import__(module, fromlist=[name]), name)
# raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")
# (removed stray "[docs]" Sphinx source-link artifact)
def load_pickle(file):
    """Deserialize a pickle stream through the restricted unpickler.

    Args:
        file: A binary file-like object positioned at the start of a
            pickle stream.

    Returns:
        The unpickled object, provided every global it references passes
        the allow/deny checks of ``RestrictedUnpickler.find_class``.
    """
    unpickler = RestrictedUnpickler(file)
    return unpickler.load()
# (removed stray "[docs]" Sphinx source-link artifact)
def load_worm(file_path):
    """
    Load a Worm object from a pickle file.

    Uses the restricted unpickler so only allow-listed globals can be
    resolved while deserializing.

    Args:
        file_path (str): The path to the pickle file.

    Returns:
        Worm: The loaded Worm object.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        pickle.UnpicklingError: If the stream cannot be unpickled
            (including any forbidden global rejected by the restricted
            unpickler).
        TypeError: If the unpickled object is not a Worm.
    """
    # Function-scope import kept from the original; presumably avoids a
    # circular import at module load time — confirm before hoisting.
    from .animal import Worm

    # Keep the try body narrow: only the lines that can raise the
    # exceptions we actually translate.
    try:
        with open(file_path, "rb") as pickle_file:
            w = load_pickle(pickle_file)
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"File {file_path} not found.") from exc
    except pickle.UnpicklingError as exc:
        raise pickle.UnpicklingError(f"Failed to unpickle {file_path}.") from exc
    if not isinstance(w, Worm):
        raise TypeError(f"Expected Worm object, got {type(w)}")
    return w
# (removed stray "[docs]" Sphinx source-link artifact)
def generate_random_string(length: int = 8) -> str:
    """Return a random alphanumeric string.

    Args:
        length (int): Number of characters to produce (default 8).

    Returns:
        str: ``length`` characters drawn uniformly from ASCII letters
        and digits.
    """
    alphabet = string.ascii_letters + string.digits
    picks = [random.choice(alphabet) for _ in range(length)]
    return "".join(picks)
class NetworkWriter:
    """Writes a Worm's networks to disk as JSON.

    Per-network serialization is delegated to each network's own
    ``to_dict()`` so this writer stays in sync with the canonical format.
    """

    def __init__(self):
        """
        Initializes the NetworkWriter object.

        Placeholder for future functionality (e.g. a Neo4j export). The
        previous implementation bound hard-coded connection settings —
        including a plaintext password — to local variables that were
        never used; that dead code has been removed. When the Neo4j path
        is implemented, read credentials from configuration or the
        environment, never from source.
        """
        # Example future wiring:
        # graph_db = py2neo.Graph(uri, auth=(user, password))

    def write(self, w):
        """Serialize *w* and write it to ``model/<w.name>.json``.

        Args:
            w: A Worm-like object with ``name`` and ``networks`` attributes.

        Raises:
            OSError: If the ``model/`` directory is missing or not
                writable (this method does not create it).
        """
        json_data = self.generate_json(w)
        output_filename = "model/" + w.name + ".json"
        with open(output_filename, "w", encoding="utf-8") as f:
            json.dump(json_data, f, indent=4)

    def generate_json(self, w):
        """Generate a JSON-compatible dictionary from a Worm object.

        Delegates per-network serialization to ``NervousSystem.to_dict()``
        so that this writer stays in sync with the canonical serialization.

        Args:
            w: A Worm object with a ``networks`` dict of NervousSystem instances.

        Returns:
            dict: JSON-serializable representation of the worm and its networks.
        """
        json_data = {
            "model_name": w.name,
            # Metadata attributes are optional on the Worm; default to None.
            "version": getattr(w, "version", None),
            "created_by": getattr(w, "author", None),
            "networks": {},
        }
        # Networks lacking a to_dict() are silently skipped, preserving the
        # original best-effort behavior.
        for net_name, network in w.networks.items():
            if hasattr(network, "to_dict"):
                json_data["networks"][net_name] = network.to_dict()
        return json_data