import os
import shutil
import warnings
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TypeVar
import pandas as pd
from algomancy_utils import Logger
from .datasource import DataClassification, BASEDATASOURCE
from .etl import ETLFactory, ETLConstructionError, ETLResult
from .schema import Schema, FileExtension
from .validator import ValidationSequence
from .file import File, CSVFile, JSONFile, XLSXFile
# Type variable for the concrete ETLFactory subclass handed to a DataManager;
# it appears in the ``etl_factory: type[E]`` annotation of ``__init__``.
E = TypeVar("E", bound=ETLFactory)
# Filenames the SessionManager / ScenarioManager write into a session folder
# that are not DataSource payloads. The data scanner must skip these or
# ``from_json`` will fail on every restart.
_RESERVED_SESSION_FILENAMES = frozenset({"meta.json", "scenarios.json"})
[docs]
class DataManager(ABC):
"""
Handles all data-related operations: loading, deriving, deleting, and storing datasets.
"""
def __init__(
    self,
    etl_factory: type[E],
    schemas: List[Schema],
    save_type: str,
    data_object_type: type[BASEDATASOURCE],
    logger: Logger | None = None,
) -> None:
    """Set up the state shared by all data managers.

    Args:
        etl_factory: ETL factory class; it is instantiated here with
            ``schemas`` and ``logger``.
        schemas: Schemas passed to the factory and kept for extension
            lookups.
        save_type: Name of the persistence format, stored as given.
        data_object_type: Class used for data source objects.
        logger: Optional logger; ``None`` disables logging.
    """
    # The logger is assigned first because the factory receives it.
    self.logger = logger
    self._etl_factory = etl_factory(schemas, self.logger)
    self._save_type = save_type
    self._data_object_type: type[BASEDATASOURCE] = data_object_type
    self._schemas = schemas
    # In-memory store of data sources, keyed by data key.
    self._data: Dict[str, BASEDATASOURCE] = {}
@property
def data_object_type(self):
    """The data source class this manager was configured with."""
    configured_type = self._data_object_type
    return configured_type
[docs]
@abstractmethod
def startup(self):
    """Perform start-up work for the manager; concrete subclasses must implement this."""
    raise NotImplementedError
# Utility
[docs]
def log(self, message: str):
    """Forward ``message`` to the configured logger; do nothing without one."""
    if not self.logger:
        return
    self.logger.log(message)
# Accessors
[docs]
def get_data_keys(self) -> List[str]:
    """Return the keys of all data sources currently held in memory."""
    # Iterating a dict yields its keys in insertion order.
    return list(self._data)
[docs]
def get_data(self, data_key: str) -> BASEDATASOURCE | None:
    """Return the data source stored under ``data_key``, or ``None`` if absent."""
    try:
        return self._data[data_key]
    except KeyError:
        return None
[docs]
def set_data(self, data_key: str, data: BASEDATASOURCE) -> None:
    """Store ``data`` under ``data_key``, replacing any existing entry."""
    self._data.update({data_key: data})
# Derive/Delete
[docs]
def derive_data(self, existing_key: str, derived_key: str) -> None:
    """Create a new data source derived from an existing one.

    Args:
        existing_key: Key of the data source to derive from.
        derived_key: Key under which the derived data source is stored;
            it is also passed to the source's ``derive`` method.

    Raises:
        AssertionError: If ``existing_key`` is unknown or ``derived_key``
            is already in use.
    """
    known_keys = self.get_data_keys()
    # Raise explicitly rather than via ``assert`` so the checks survive
    # ``python -O``; without them an existing entry could be silently
    # overwritten. AssertionError is kept for callers that catch it.
    if existing_key not in known_keys:
        raise AssertionError(f"Data '{existing_key}' not found.")
    if derived_key in known_keys:
        raise AssertionError(f"Data '{derived_key}' already exists.")
    self._data[derived_key] = self.get_data(existing_key).derive(derived_key)
    self.log(f"Derived data '{derived_key}' derived from '{existing_key}'.")
[docs]
def add_data_source(self, data_source: BASEDATASOURCE) -> None:
    """Register ``data_source`` in memory under its stringified name and log it."""
    storage_key = str(data_source.name)
    self._data[storage_key] = data_source
    self.log(f"Loaded DataSource '{data_source.name}' from {self._save_type} file.")
[docs]
@abstractmethod
def delete_data(
    self, data_key: str, prevent_masterdata_removal: bool = False
) -> None:
    """Remove the data source stored under ``data_key``; subclasses must implement this.

    Args:
        data_key: Key of the data source to remove.
        prevent_masterdata_removal: NOTE(review): neither implementation
            in this file reads this flag -- confirm the intended semantics.
    """
    raise NotImplementedError
[docs]
@staticmethod
def check_existence_of_files(file_name_to_path: List[Tuple[str, str]]) -> None:
    """Verify that every listed path exists on disk.

    Args:
        file_name_to_path: ``(name, path)`` pairs; only the path is checked.

    Raises:
        ETLConstructionError: For the first path that does not exist.
    """
    for _, path in file_name_to_path:
        if os.path.exists(path):
            continue
        raise ETLConstructionError(f"File at path '{path}' does not exist.")
[docs]
def prepare_files(
    self,
    file_items_with_content: List[Tuple[str, str, str]] = None,
    file_items_with_path: List[Tuple[str, str]] = None,
) -> Dict[str, File]:
    """Build ``File`` objects from in-memory content or from paths on disk.

    Content items take precedence: path items are only used when no
    (non-empty) content list is given.

    Args:
        file_items_with_content: ``(name, extension, content)`` triples.
        file_items_with_path: ``(name, path)`` pairs.

    Returns:
        Mapping of logical file name to the prepared ``File``.

    Raises:
        ETLConstructionError: If both arguments are missing or empty.
    """
    if file_items_with_content:
        return self._prepare_files_from_content(file_items_with_content)
    if file_items_with_path:
        return self._prepare_files_from_path(file_items_with_path)
    raise ETLConstructionError("No file data provided.")
@staticmethod
def _add_to_files(files, name, extension, content=None, path=None) -> None:
    """Create the ``File`` subclass matching ``extension`` and store it in ``files``.

    Args:
        files: Mapping that receives the new file object under ``name``.
        name: Logical file name.
        extension: Extension compared against the lower-cased
            ``FileExtension`` members for CSV, JSON and XLSX.
        content: In-memory content, if any.
        path: Path on disk, if any.

    Raises:
        AssertionError: If neither ``path`` nor ``content`` is truthy.
        ETLConstructionError: If ``extension`` matches no supported type.
    """
    assert path or content, "Either path or content must be provided."
    # Checked in this order; the first matching extension wins.
    supported = (
        (FileExtension.CSV, CSVFile),
        (FileExtension.JSON, JSONFile),
        (FileExtension.XLSX, XLSXFile),
    )
    for known_extension, file_cls in supported:
        if extension == known_extension.lower():
            files[name] = file_cls(name=name, content=content, path=path)
            return
    raise ETLConstructionError(f"Unsupported file type: '{extension}'.")
def _schema_extension(self, name: str) -> Optional[str]:
    """Look up the extension declared by the schema whose file name is ``name``.

    Returns:
        The first matching schema's extension, stringified and lower-cased,
        or ``None`` when no schema matches (callers handle that case).
    """
    matching = next(
        (schema for schema in self._schemas if schema.file_name() == name),
        None,
    )
    if matching is None:
        return None
    return str(matching.extension()).lower()
def _prepare_files_from_content(
    self,
    file_items: List[Tuple[str, str, str]] = None,
) -> Dict[str, File]:
    """Build ``File`` objects from in-memory content.

    Args:
        file_items: ``(name, extension, content)`` triples.

    Returns:
        Mapping of logical file name to the prepared ``File``.

    Raises:
        ETLConstructionError: Propagated from ``_add_to_files`` when the
            resolved extension is not supported.
    """
    files: Dict[str, File] = {}
    for name, extension, content in file_items:
        # Prefer the schema-declared extension when one is available, so
        # callers do not have to derive the extension from the filename.
        ext = self._schema_extension(name)
        if not ext:
            # Lower-case the caller-supplied fallback, as the schema lookup
            # and the path-based variant do; otherwise "CSV" would be
            # rejected by the case-sensitive dispatch in ``_add_to_files``.
            ext = extension.lower() if isinstance(extension, str) else extension
        self._add_to_files(files, name, ext, content=content)
    return files
def _prepare_files_from_path(
    self, file_items: List[Tuple[str, str]]
) -> Dict[str, File]:
    """Build ``File`` objects for files that live on disk.

    Args:
        file_items: ``(name, path)`` pairs.

    Returns:
        Mapping of logical file name to the prepared ``File``.

    Raises:
        ETLConstructionError: If a path does not exist, if no extension can
            be determined for an item, or (from ``_add_to_files``) if the
            extension is not supported.
    """
    self.check_existence_of_files(file_items)
    files: Dict[str, File] = {}
    for name, path in file_items:
        # Dispatch by schema-declared extension when available; only fall
        # back to the path suffix when no schema matches the logical name.
        ext = self._schema_extension(name)
        if ext is None:
            # Only the final path component may contribute the suffix; a
            # dot in a directory name ("data.v1/raw") is not an extension.
            base_name = os.path.basename(path)
            ext = base_name.rsplit(".", 1)[-1].lower() if "." in base_name else ""
        if not ext:
            raise ETLConstructionError(
                f"Cannot determine extension for '{name}'. "
                "Declare _EXTENSION on its Schema or rename the file."
            )
        self._add_to_files(files, name, ext, path=path)
    return files
[docs]
def etl_data(self, files: Dict[str, File], dataset_name: str) -> ETLResult:
    """Build and run the ETL pipeline for ``dataset_name``.

    On success the resulting data source is stored in memory under
    ``dataset_name``; on failure nothing is stored. Either outcome is
    reported to the logger when one is configured.

    Args:
        files: Mapping of logical file names to ``File`` objects.
        dataset_name: Logical name for the resulting dataset.

    Returns:
        ETLResult: the structured outcome of the run. Check
        ``result.status`` for success or failure and
        ``result.validation_result.messages`` for details.

    Raises:
        ETLConstructionError: If pipeline construction fails.
        Exception: Errors raised by user-supplied components propagate
            unchanged.
    """
    etl = self._etl_factory.build_pipeline(dataset_name, files, self.logger)
    self.log(f"ETL pipeline for dataset '{dataset_name}' created.")
    result = etl.run()

    if result.is_success:
        self._data[dataset_name] = result.datasource
        if self.logger:
            self.logger.success(
                f"ETL pipeline for dataset '{dataset_name}' completed."
            )
        return result

    # Failure path: report the severity counts when a validation result exists.
    if self.logger:
        self.logger.error(
            f"ETL pipeline for dataset '{dataset_name}' failed: "
            f"{result.validation_result.counts_by_severity if result.validation_result else 'unknown'}"
        )
    return result
[docs]
def create_validation_sequence(self) -> ValidationSequence:
    """Return a validation sequence created by the configured ETL factory."""
    factory = self._etl_factory
    return factory.create_validation_sequence()
[docs]
class StatelessDataManager(DataManager):
def __init__(
    self,
    etl_factory: type[ETLFactory],
    schemas: List[Schema],
    save_type: str,
    data_object_type: type[BASEDATASOURCE],
    logger: Logger | None = None,
):
    """Create an in-memory-only data manager.

    All arguments are forwarded unchanged to ``DataManager.__init__``.
    """
    super().__init__(
        etl_factory=etl_factory,
        schemas=schemas,
        save_type=save_type,
        data_object_type=data_object_type,
        logger=logger,
    )
    # Start from an empty in-memory store.
    self._data: Dict[str, BASEDATASOURCE] = {}
[docs]
def startup(self):
    """Do nothing: a stateless manager has no start-up actions to perform."""
    return None
[docs]
def delete_data(
    self, data_key: str, prevent_masterdata_removal: bool = False
) -> None:
    """Remove the data source stored under ``data_key`` from memory.

    Args:
        data_key: Key of the data source to remove.
        prevent_masterdata_removal: Accepted for interface compatibility;
            this implementation does not read it.

    Raises:
        AssertionError: If ``data_key`` is not a known key.
    """
    # Raise explicitly rather than via ``assert`` so the check (and its
    # exception type) is the same under ``python -O``.
    if data_key not in self.get_data_keys():
        raise AssertionError(f"Data '{data_key}' not found.")
    # note: responsibility for checking scenario usage resides in callers
    del self._data[data_key]
    self.log(f"Data '{data_key}' deleted.")
[docs]
class StatefulDataManager(DataManager):
def __init__(
    self,
    etl_factory: type[ETLFactory],
    schemas: List[Schema],
    data_folder: str,
    save_type: str,
    data_object_type: type[BASEDATASOURCE],
    logger: Logger | None = None,
):
    """Create a (deprecated) data manager backed by a folder on disk.

    Emits a ``DeprecationWarning`` on every construction.

    Args:
        etl_factory: Forwarded to ``DataManager.__init__``.
        schemas: Forwarded to ``DataManager.__init__``.
        data_folder: Folder that holds the persisted data.
        save_type: Forwarded to ``DataManager.__init__``.
        data_object_type: Forwarded to ``DataManager.__init__``.
        logger: Forwarded to ``DataManager.__init__``.
    """
    deprecation_message = (
        "StatefulDataManager is deprecated and will be removed in a future release. "
        "Use DatabaseDataManager (algomancy-data[database]) for persistent storage, "
        "or StatelessDataManager for in-memory-only usage."
    )
    # stacklevel=2 attributes the warning to the code constructing the manager.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    super().__init__(
        etl_factory=etl_factory,
        schemas=schemas,
        save_type=save_type,
        data_object_type=data_object_type,
        logger=logger,
    )
    self._data_folder = data_folder
    self._data: Dict[str, BASEDATASOURCE] = {}  # Loading
    # (path, exception) pairs for items that failed to load.
    self.startup_errors: List[Tuple[str, Exception]] = []
[docs]
def startup(self) -> None:
    """Load the persisted data sources found in the data folder.

    ``self.startup_errors`` is reset on every call. Items are loaded one by
    one by ``_load_data_from_data_folder``; a failing item is logged and
    skipped there, its partial in-memory state is rolled back, and the
    remaining items still load. Callers can inspect
    ``self.startup_errors`` afterwards.

    Raises:
        Exception: Any exception escaping the folder scan itself is
            recorded, logged and re-raised.
    """
    collected: List[Tuple[str, Exception]] = []
    self.startup_errors = collected
    try:
        self._load_data_from_data_folder()
        self.log(f"Data folder '{self._data_folder}' loaded.")
    except Exception as exc:
        # Reaching this branch points at a defect in the loader itself;
        # per-item failures are dealt with inside the scan loop.
        self.startup_errors.append((self._data_folder, exc))
        active_logger = self.logger
        if active_logger:
            active_logger.error(f"Data load on startup failed: {exc}")
            active_logger.log_traceback(exc)
        raise
[docs]
def load_data_from_file(self, file_name: str, root: str | None = None) -> None:
    """Deserialize one persisted data source file and register it in memory.

    Args:
        file_name: Name of the file inside ``root``.
        root: Folder containing the file; defaults to the data folder.

    Raises:
        Exception: If the configured save type is not ``"json"``.
    """
    base_dir = self._data_folder if root is None else root
    file_path = os.path.join(base_dir, file_name)
    # Only the JSON save type is supported here.
    if self._save_type != "json":
        raise Exception(f"Unsupported save type: {self._save_type}")
    with open(file_path, "r", encoding="utf-8") as handle:
        json_string = handle.read()
    data_source = self.data_object_type.from_json(json_string)
    self.add_data_source(data_source)
def _load_data_from_data_folder(self) -> None:
    """Load every persisted item found directly inside the data folder.

    Regular files named in ``_RESERVED_SESSION_FILENAMES`` are skipped
    silently; other files whose name does not end in ``.<save_type>`` are
    skipped with a warning; the remaining files are loaded through
    ``load_data_from_file``. Directories are loaded through
    ``load_data_from_dir``.

    A failing item is logged and recorded in ``self.startup_errors``, the
    keys it added to ``self._data`` are removed again, and loading
    continues with the next item. A missing data folder is only reported
    as a warning.
    """
    # Check if the folder exists
    if not os.path.exists(self._data_folder):
        # The logger is optional, so it is guarded here like everywhere
        # else; an unguarded call would raise AttributeError on ``None``.
        if self.logger:
            self.logger.warning(f"Data folder '{self._data_folder}' does not exist.")
        return
    # List all files in the data folder
    for item in os.listdir(self._data_folder):
        item_path = os.path.join(self._data_folder, item)
        # If it's a file, try to load it as a file of the appropriate type
        if os.path.isfile(item_path):
            # Framework-owned metadata (session id, scenario state)
            # lives alongside DataSource files; skip it silently so a
            # restart doesn't error out on every meta.json.
            if item in _RESERVED_SESSION_FILENAMES:
                continue
            # Verify that item is of the appropriate data format
            if not item.endswith(f".{self._save_type}"):
                if self.logger:
                    self.logger.warning(
                        f"Skipping file '{item_path}' because it is not a {self._save_type} file."
                    )
                continue
            pre_keys = set(self._data.keys())
            try:
                self.load_data_from_file(item)
            except Exception as exc:
                # Roll back any partial keys that this load added so the
                # manager is never left in an undefined state. Keys that
                # were overwritten rather than added are not restored.
                for added in set(self._data.keys()) - pre_keys:
                    del self._data[added]
                if self.logger:
                    self.logger.error(
                        f"Failed to load file '{item_path}' as a DataSource"
                    )
                    self.logger.log_traceback(exc)
                self.startup_errors.append((item_path, exc))
        # If it's a directory, run ETL
        elif os.path.isdir(item_path):
            pre_keys = set(self._data.keys())
            try:
                self.load_data_from_dir(item)
            except Exception as exc:
                # Same rollback as for single files.
                for added in set(self._data.keys()) - pre_keys:
                    del self._data[added]
                if self.logger:
                    self.logger.error(
                        f"Failed to load directory '{item_path}' as a DataSource: {exc}"
                    )
                    # ETLConstructionError is logged without a traceback;
                    # any other exception type gets one.
                    if not isinstance(exc, ETLConstructionError):
                        self.logger.log_traceback(exc)
                self.startup_errors.append((item_path, exc))
[docs]
def load_data_from_dir(self, directory: str, root: str | None = None) -> None:
    """Run the ETL pipeline over the files of one dataset directory.

    The directory name becomes the dataset name, and each entry's logical
    file name is the part of its name before the first dot.

    Args:
        directory: Name of the dataset directory inside ``root``.
        root: Folder containing the directory; defaults to the data folder.

    Raises:
        ETLConstructionError: If file preparation fails or the ETL result
            reports failure.
    """
    base_dir = self._data_folder if root is None else root
    dataset_name = directory
    dataset_path = os.path.join(base_dir, directory)

    # Pair every directory entry with its logical name and full path.
    file_items_with_path = []
    for entry in os.listdir(dataset_path):
        logical_name = entry.split(".")[0]
        file_items_with_path.append((logical_name, os.path.join(dataset_path, entry)))

    prepared_files = self.prepare_files(file_items_with_path=file_items_with_path)
    result = self.etl_data(prepared_files, dataset_name)
    if not result.is_failure:
        return
    # Surface failure to the outer startup loop so it can roll back.
    raise ETLConstructionError(
        f"ETL for dataset '{dataset_name}' failed: "
        f"{result.validation_result.counts_by_severity if result.validation_result else 'unknown'}"
    )
[docs]
def delete_data(
    self, data_key: str, prevent_masterdata_removal: bool = False
) -> None:
    """Remove a data source from memory and, for master data, from disk.

    For master data the entry ``<data_folder>/<data_key>`` is removed when
    it exists, whether it is a directory or a regular file. Other data
    sources are only dropped from memory.

    Args:
        data_key: Key of the data source to remove.
        prevent_masterdata_removal: Accepted for interface compatibility.
            NOTE(review): this flag is not read here -- confirm whether
            master data should be protected when it is set.

    Raises:
        AssertionError: If ``data_key`` is not a known key.
    """
    # Raise explicitly rather than via ``assert`` so the check (and its
    # exception type) is the same under ``python -O``.
    if data_key not in self.get_data_keys():
        raise AssertionError(f"Data '{data_key}' not found.")
    # note: responsibility for checking scenario usage resides in callers
    # Delete files if applicable
    if self._data[data_key].is_master_data():
        # NOTE(review): the path carries no extension, so a file written by
        # ``store_data_source_as_json`` ("<key>.json") is not matched here
        # -- confirm whether that is intended.
        target = os.path.join(self._data_folder, data_key)
        if os.path.isdir(target):
            shutil.rmtree(target)
        elif os.path.isfile(target):
            os.remove(target)
    del self._data[data_key]
    self.log(f"Data '{data_key}' deleted.")
# Store new dataset to data folder (as CSVs) and keep in memory
[docs]
def store_data(
    self, dataset_name: str, data: Dict[str, pd.DataFrame], USE_OLD_VERSION=True
) -> None:
    """Persist a dataset as CSV files and register it in memory.

    A new directory ``<data_folder>/<dataset_name>`` is created and every
    DataFrame is written to ``<key>.csv`` inside it (``;``-separated, no
    index). The tables are then added to a new derived data source that is
    stored under ``dataset_name``.

    Args:
        dataset_name: Name of the dataset and of its directory.
        data: Mapping of table key to DataFrame.
        USE_OLD_VERSION: Must stay truthy; the alternative is not
            implemented.

    Raises:
        NotImplementedError: If ``USE_OLD_VERSION`` is falsy.
        Exception: If the target directory already exists.
    """
    if not USE_OLD_VERSION:
        raise NotImplementedError

    target_dir = os.path.join(self._data_folder, dataset_name)
    if os.path.exists(target_dir):
        raise Exception(
            f"Directory '{dataset_name}' already exists in '{self._data_folder}'"
        )
    os.makedirs(target_dir)

    # Write each DataFrame to a CSV named after its key
    for key, df in data.items():
        df.to_csv(os.path.join(target_dir, f"{key}.csv"), index=False, sep=";")

    # Also keep in memory
    ds = self.data_object_type(
        name=dataset_name, ds_type=DataClassification.DERIVED_DATA
    )
    for key, df in data.items():
        ds.add_table(key, df)
    self._data[dataset_name] = ds
    self.log(f"Stored dataset '{dataset_name}' to disk and memory.")
[docs]
def store_data_source_as_json(
    self, dataset_name: str, allow_overwrite: bool = False
):
    """Serialize a held data source to ``<data_folder>/<dataset_name>.json``.

    Args:
        dataset_name: Key of the data source; also the file's base name.
        allow_overwrite: When true, an existing file is replaced.

    Raises:
        Exception: If the target file exists and ``allow_overwrite`` is false.
        AssertionError: If no data source is stored under ``dataset_name``.
    """
    file_name = os.path.join(self._data_folder, f"{dataset_name}.json")
    if os.path.exists(file_name) and not allow_overwrite:
        # The target is a JSON file, not a directory; say so in the message.
        raise Exception(
            f"File '{dataset_name}.json' already exists in '{self._data_folder}'"
        )
    # Retrieve datasource
    data_source = self.get_data(dataset_name)
    # Raise explicitly rather than via ``assert`` so the check survives
    # ``python -O`` with the same exception type.
    if data_source is None:
        raise AssertionError(f"Data source '{dataset_name}' not found.")
    # Serialize before opening the file so a failing ``to_json`` cannot
    # truncate an existing file.
    json_content = data_source.to_json()
    # Write the UTF-8 encoded JSON text as bytes.
    with open(file_name, "wb") as f:
        f.write(json_content.encode("utf-8"))