# Source code for cedne.core.io

"""Contains I/O helpers for loading pickles and worms."""

__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"

import importlib
import json
import os
import pickle
import random
import string
# import py2neo

# Restricting unpickling
# ALLOWED_MODULES = [
#     "cedne",
#     "networkx"
#     "numpy",
#     "numpy.core",
#     "numpy.core.multiarray",
#     "numpy.core.numeric"
# ]

# ALLOWED_CLASSES = {
#     # Safe NumPy internals needed for unpickling arrays
#     ("numpy.core.multiarray", "_reconstruct"),
#     ("numpy.core.numeric", "_frombuffer"),
#     ("numpy", "dtype"),
#     ("numpy", "ndarray"),
#     ("numpy", "float64"),
#     ("numpy", "int64"),
#     ("builtins", "set"),
#     ("builtins", "frozenset"),
#     ("builtins", "slice"),
# }

# Exact (module, name) globals that may be loaded.
ALLOWED_CLASSES = {
    # NumPy array/scalar reconstruction helpers. NumPy < 2.0 records them
    # under ``numpy.core``; NumPy >= 2.0 records them under ``numpy._core``.
    ("numpy.core.multiarray", "_reconstruct"),
    ("numpy.core.multiarray", "scalar"),
    ("numpy.core.numeric", "_frombuffer"),
    ("numpy._core.multiarray", "_reconstruct"),
    ("numpy._core.multiarray", "scalar"),
    ("numpy._core.numeric", "_frombuffer"),
    # NumPy types
    ("numpy", "dtype"),
    ("numpy", "ndarray"),
    ("numpy", "float64"),
    ("numpy", "int64"),
    # builtins
    ("builtins", "set"),
    ("builtins", "frozenset"),
    ("builtins", "slice"),
}

# Packages (and their submodules) whose own globals may all be loaded.
ALLOWED_MODULE_PREFIXES = [
    "cedne",
    "networkx",
]


def _in_allowed_package(module):
    """Return True if *module* is an allowed package or a submodule of one."""
    return isinstance(module, str) and any(
        module == prefix or module.startswith(prefix + ".")
        for prefix in ALLOWED_MODULE_PREFIXES
    )


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves allow-listed globals.

    A pickle can call any global it names, so ``find_class`` is the security
    boundary. A global is resolved only if it is listed exactly in
    ``ALLOWED_CLASSES``, or if it is *defined* in a package listed in
    ``ALLOWED_MODULE_PREFIXES``.

    NOTE: allowing whole packages still exposes every callable defined in
    them; do not rely on this alone for pickles from untrusted sources.
    """

    def find_class(self, module, name):
        """Resolve the global ``module.name``.

        Raises:
            pickle.UnpicklingError: If the global is not allowed.
        """
        # Allow fully qualified safe globals.
        if (module, name) in ALLOWED_CLASSES:
            return getattr(importlib.import_module(module), name)

        # Allow globals from whitelisted packages (like cedne.*), but only
        # those actually defined there. Without this check, anything an
        # allowed module imports (e.g. ``from os import system``) would be
        # reachable through it. Pickle records functions and classes under
        # their ``__module__``, so legitimate pickles pass this check.
        if _in_allowed_package(module):
            obj = getattr(importlib.import_module(module), name)
            if _in_allowed_package(getattr(obj, "__module__", None)):
                return obj

        # Otherwise, reject.
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")
    
# class RestrictedUnpickler(pickle.Unpickler):
#     """
#     A custom unpickler that restricts the loading of certain modules and classes.
#     """
#     def find_class(self, module, name):
#         # Allow all functions and classes from the allowed modules and their submodules
#         if any(module == allowed_module or module.startswith(allowed_module + ".") for allowed_module in ALLOWED_MODULES):
#             return getattr(__import__(module, fromlist=[name]), name)
#         if module == "builtins" and name in ["set", "frozenset"]:
#             return getattr(__import__(module), name)
#         raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")

# class RestrictedUnpickler(pickle.Unpickler):
#     def find_class(self, module, name):
#         if (module, name) in ALLOWED_CLASSES:
#             return getattr(__import__(module, fromlist=[name]), name)
#         raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")

def load_pickle(file):
    """Read one pickled object from *file* through ``RestrictedUnpickler``.

    Args:
        file: Binary file-like object; reading starts at its current position.

    Returns:
        The unpickled object.

    Raises:
        pickle.UnpicklingError: Among others, when the pickle references a
            global that ``RestrictedUnpickler`` forbids.
    """
    unpickler = RestrictedUnpickler(file)
    return unpickler.load()
def load_worm(file_path):
    """
    Load a Worm object from a pickle file.

    The file is read with ``load_pickle``, i.e. through the restricted
    unpickler, not plain ``pickle.load``.

    Args:
        file_path (str): The path to the pickle file.

    Returns:
        Worm: The loaded Worm object.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
        pickle.UnpicklingError: If the file cannot be unpickled, including
            when it is empty/truncated or references a forbidden global.
            The original error is chained as ``__cause__``.
        TypeError: If the unpickled object is not a ``Worm``.
    """
    # Imported inside the function as in the original; presumably to avoid
    # an import cycle with .animal -- TODO confirm.
    from .animal import Worm

    try:
        with open(file_path, 'rb') as pickle_file:
            w = load_pickle(pickle_file)
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"File {file_path} not found.") from exc
    except (pickle.UnpicklingError, EOFError) as exc:
        # pickle raises EOFError for an empty or truncated file; report it
        # like any other unpickling failure and keep the reason in the text.
        raise pickle.UnpicklingError(
            f"Failed to unpickle {file_path}: {exc}"
        ) from exc

    if not isinstance(w, Worm):
        raise TypeError(f"Expected Worm object, got {type(w)}")
    return w
def generate_random_string(length: int = 8) -> str:
    """Return *length* characters chosen with ``random.choice`` from ASCII letters and digits.

    Uses the module-level ``random`` generator, so results are reproducible
    after ``random.seed``. Non-positive lengths yield an empty string.
    """
    alphabet = string.ascii_letters + string.digits
    picks = [random.choice(alphabet) for _ in range(length)]
    return "".join(picks)
class NetworkWriter:
    """Export a worm's networks as a JSON metadata file.

    Only JSON export is implemented; the Neo4j connection is a placeholder.
    """

    def __init__(self):
        """Initialize the writer. Placeholder for future functionality."""
        # Planned Neo4j export (not implemented; py2neo is not imported).
        # These settings were only ever assigned to unused locals, so they
        # are kept here for reference:
        #   NEO4J_URI = "bolt://localhost:7687"
        #   NEO4J_USER = "neo4j"
        #   NEO4J_PASS = "password"
        #   OUTPUT_JSON = "generated_metadata.json"
        #   graph_db = py2neo.Graph(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS))

    def write(self, w, directory="model"):
        """Write ``generate_json(w)`` to ``<directory>/<w.name>.json``.

        Args:
            w: Object accepted by :meth:`generate_json`; ``w.name`` (a str)
                names the output file.
            directory (str): Existing directory to write into. Defaults to
                ``"model"``.
        """
        json_data = self.generate_json(w)
        output_filename = os.path.join(directory, w.name + ".json")
        with open(output_filename, "w", encoding="utf-8") as f:
            json.dump(json_data, f, indent=4)

    def generate_json(self, w):
        """Build a JSON-serializable summary of ``w``.

        Args:
            w: Object with ``name``, ``version``, ``author`` and an iterable
                ``networks``. Each network must provide an iterable
                ``neurons`` and ``edges(data=True)`` yielding ``(u, v, data)``
                triples, where ``u``/``v`` have a ``name`` and ``data`` is a
                dict. Presumably each network is a networkx graph -- confirm.

        Returns:
            dict: Model metadata, a ``neurons`` mapping from neuron name to a
            (currently empty) property dict, and a ``connections`` list of
            ``{"source", "target", "weight"}`` entries. ``weight`` defaults
            to 1.0 when an edge has none.
        """
        json_data = {
            "model_name": w.name,
            "version": w.version,
            "created_by": w.author,
            # NOTE(review): hard-coded, so every export reports this date --
            # confirm whether it should come from ``w``.
            "date_created": "2025-02-10",
            # NOTE(review): never populated below.
            "networks": {},
            "neurons": {},
            "connections": [],
            "neo4j_query": "MATCH (n:Neuron)-[r:SYNAPSE]->(m:Neuron) WHERE n.type = 'Sensory' RETURN n, r, m",
        }
        for nn in w.networks:
            # Store neurons and their properties. ``neurons`` may yield names
            # or neuron objects; key by ``.name`` when present, matching the
            # edge endpoints below and keeping keys JSON-serializable.
            for n in nn.neurons:
                json_data["neurons"][getattr(n, "name", n)] = {
                    # "data": ...
                }
            # Store edges (synapses).
            for u, v, data in nn.edges(data=True):
                json_data["connections"].append({
                    "source": u.name,
                    "target": v.name,
                    "weight": data.get("weight", 1.0),
                    # "neurotransmitters": data.get("neurotransmitters", [])
                })
        return json_data