Source code for algomancy_data.transformer
"""Transformation primitives for ETL pipelines.
Defines the abstract ``Transformer`` contract and a few simple concrete
transformers, as well as a ``TransformationSequence`` to compose multiple
transformers into a single pipeline step.
"""
from abc import ABC, abstractmethod
import pandas as pd
from typing import Dict, List, Optional, Sequence, Tuple, Type
from algomancy_utils import Logger
from copy import deepcopy
from .relations import Relation, merge_relations, resolve_relations_from_schemas
from .schema import Schema
from .validator import ValidationMessage, ValidationSeverity
[docs]
class Transformer(ABC):
    """Abstract contract for one transformation step over tabular data.

    Concrete subclasses provide ``transform``; they may change the supplied
    table mapping in place or, where applicable, hand back a new mapping.

    Attributes:
        name: Display name of the transformer.
        messages: ValidationMessages produced by this transformer during its
            most recent ``transform`` invocation. The ETL pipeline collects
            these from each transformer in the sequence and folds them into
            the run's ``ValidationResult`` so they surface via
            ``ETLResult.messages``.
    """

    def __init__(self, name: str = "Abstract Transformer", logger=None) -> None:
        # Start with no messages; subclasses fill this list in ``transform``.
        self.messages: List[ValidationMessage] = list()
        self._logger = logger
        self.name = name

    @abstractmethod
    def transform(self, data: dict[str, pd.DataFrame]) -> None:
        """Apply the transformation to ``data``.

        Args:
            data: Mapping from table name to pandas DataFrame. Implementations
                may mutate this mapping in place or create/replace entries.
        """
[docs]
def fill_empty(data: pd.DataFrame) -> pd.DataFrame:
    """Propagate the last seen value rightwards within each row.

    Args:
        data: DataFrame to fill.

    Returns:
        A DataFrame in which missing cells take the value of the nearest
        non-missing cell to their left (forward fill along the column axis).
    """
    filled = data.ffill(axis="columns")
    return filled
[docs]
def drop_empty(data: pd.DataFrame) -> pd.DataFrame:
    """Return ``data`` without the rows that hold at least one NA value.

    Args:
        data: Input DataFrame.

    Returns:
        A DataFrame containing only the fully populated rows.
    """
    complete_rows = data.dropna(axis=0)
    return complete_rows
[docs]
class NoopTransformer(Transformer):
    """Pass-through step that hands the input mapping back untouched."""

    def __init__(self, logger=None) -> None:
        super().__init__(name="No Operation Transformer", logger=logger)

    def transform(self, data: dict[str, pd.DataFrame]) -> dict[str, pd.DataFrame]:
        """Log the call (when a logger is set) and return ``data`` as-is.

        Args:
            data: Mapping from table name to pandas DataFrame.

        Returns:
            The very same mapping object that was passed in.
        """
        logger = self._logger
        if logger:
            logger.log("No operation transformer called")
        return data
[docs]
class CleanTransformer(Transformer):
    """Basic cleanup: drop NA rows and normalize column names to lowercase.

    Each table in the mapping is replaced by a cleaned copy: rows containing
    any NA value are removed and string column labels are lower-cased and
    stripped of surrounding whitespace.
    """

    def __init__(self, logger=None) -> None:
        super().__init__(name="Standard Transformer", logger=logger)

    def transform(self, data: dict[str, pd.DataFrame]) -> None:
        """Clean every DataFrame in ``data`` and store the result in place.

        Args:
            data: Mapping from table name to pandas DataFrame. Each entry is
                reassigned to its cleaned counterpart; the original DataFrame
                objects are not modified.
        """
        if self._logger:
            self._logger.log("Cleaning dataframes (dropna, lowercase columns)")
        # Iterate over a snapshot so entries can be reassigned while looping.
        for name, df in list(data.items()):
            # dropna() returns a new frame; it must be written back to the
            # mapping, otherwise the cleanup is silently discarded.
            cleaned = df.dropna()
            # Only string labels can be normalised; leave other labels
            # (e.g. integer positions) untouched instead of raising.
            cleaned.columns = [
                c.lower().strip() if isinstance(c, str) else c
                for c in cleaned.columns
            ]
            data[name] = cleaned
[docs]
class JoinTransformer(Transformer):
    """Merge two tables of the mapping and store the result under a new key.

    Attributes:
        left: Name of the left table to join.
        right: Name of the right table to join.
        on: Column name to join on.
        output: Key under which the merged table is stored.
    """

    def __init__(
        self, left: str, right: str, on: str, output: str, logger=None
    ) -> None:
        super().__init__(name="Join transformer", logger=logger)
        self.left, self.right = left, right
        self.on, self.output = on, output

    def transform(self, data: dict[str, pd.DataFrame]) -> None:
        """Join ``data[left]`` with ``data[right]`` and write ``data[output]``.

        Args:
            data: Mapping from table name to pandas DataFrame; must contain
                both the ``left`` and the ``right`` table.
        """
        if self._logger:
            announcement = (
                f"Joining '{self.left}' and '{self.right}' on '{self.on}' into '{self.output}'"
            )
            self._logger.log(announcement)
        left_df = data[self.left]
        right_df = data[self.right]
        data[self.output] = left_df.merge(right_df, on=self.on)
[docs]
class CascadeDropTransformer(Transformer):
    """Remove rows that violate the declared foreign-key relations.

    Relations come from the supplied schemas (default source of truth), with
    ``extra_relations`` merged on top. The rules below are applied repeatedly
    until a full pass drops nothing:

    1. **Orphan-child drop** (always on) — drop child rows whose FK tuple is
       not in the parent's referenced column set. Rows with an NA in any FK
       column are kept.
    2. **Required-child parent drop** — for relations with
       ``parent_requires_child=True``: drop parent rows whose key does not
       appear in any child's FK columns.
    3. **Partial-loss parent drop** — only when a ``snapshot`` is attached,
       see :meth:`_apply_partial_loss`.

    Aggregated ``ValidationMessage``s are emitted with
    :class:`ValidationSeverity.ERROR` — one per ``(table, rule, relation)``
    with the dropped row count.

    Args:
        schemas: Schemas whose ``Column.foreign_key`` declarations supply the
            default relation set.
        extra_relations: Additional or override relations; override wins on
            matching ``(child_table, child_cols)``.
        snapshot: Optional :class:`CascadeSnapshot` paired transformer. Used
            for partial-loss detection (see :class:`CascadeSnapshot`).
        name: Override the transformer's display name.
        logger: Optional logger.
    """

    def __init__(
        self,
        schemas: Optional[Sequence[Type[Schema]]] = None,
        extra_relations: Optional[Sequence[Relation]] = None,
        snapshot: Optional["CascadeSnapshot"] = None,
        name: str = "Cascade drop transformer",
        logger=None,
    ) -> None:
        super().__init__(name=name, logger=logger)
        declared: List[Relation] = []
        if schemas:
            declared = list(resolve_relations_from_schemas(schemas))
        overrides = list(extra_relations or [])
        self.relations: List[Relation] = merge_relations(declared, overrides)
        self.snapshot = snapshot

    def transform(self, data: dict[str, pd.DataFrame]) -> None:
        """Run the cascade rules on ``data`` until nothing more is dropped.

        Tables in ``data`` are replaced by their filtered versions (with a
        fresh index) whenever rows are removed; ``self.messages`` is rebuilt
        from scratch on every call.

        Args:
            data: Mapping from table name to pandas DataFrame.
        """
        self.messages = []
        # Running totals: {(table, code, relation label): dropped row count}
        tally: Dict[Tuple[str, str, str], int] = {}

        changed = True
        while changed:
            changed = False
            for relation in self.relations:
                both_present = (
                    relation.child_table in data and relation.parent_table in data
                )
                if not both_present:
                    continue

                removed = self._drop_orphan_children(data, relation)
                if removed > 0:
                    self._record(
                        tally,
                        relation.child_table,
                        "CASCADE_ORPHAN_DROP",
                        relation,
                        removed,
                    )
                    changed = True

                if relation.parent_requires_child:
                    removed = self._drop_childless_parents(data, relation)
                    if removed > 0:
                        self._record(
                            tally,
                            relation.parent_table,
                            "CASCADE_REQUIRED_CHILD_DROP",
                            relation,
                            removed,
                        )
                        changed = True

            # Partial-loss rule runs once per pass and only with a snapshot.
            if self.snapshot is not None:
                for key, count in self._apply_partial_loss(data).items():
                    tally[key] = tally.get(key, 0) + count
                    changed = True

        # One aggregated message (and optional log line) per tally entry.
        for (table, code, fk_label), dropped_count in tally.items():
            text = (
                f"{dropped_count} row(s) dropped from '{table}' "
                f"by {code} on relation {fk_label}"
            )
            self.messages.append(
                ValidationMessage(
                    ValidationSeverity.ERROR,
                    text,
                    table=table,
                    code=code,
                )
            )
            if self._logger:
                self._logger.log(
                    f"[CascadeDrop] {dropped_count} row(s) dropped from "
                    f"'{table}' by {code} on {fk_label}"
                )

    @staticmethod
    def _record(
        tally: Dict[Tuple[str, str, str], int],
        table: str,
        code: str,
        relation: Relation,
        removed: int,
    ) -> None:
        """Add ``removed`` to the tally entry for ``(table, code, relation)``."""
        key = (table, code, _relation_label(relation))
        tally[key] = tally.get(key, 0) + removed

    @staticmethod
    def _drop_orphan_children(
        data: dict[str, pd.DataFrame], relation: Relation
    ) -> int:
        """Drop child rows whose complete FK tuple matches no parent row.

        Returns the number of removed rows; ``data`` is only touched when
        that number is positive.
        """
        children = data[relation.child_table]
        parents = data[relation.parent_table]
        if children.empty:
            return 0
        if any(c not in children.columns for c in relation.child_cols):
            return 0
        if any(p not in parents.columns for p in relation.parent_cols):
            return 0

        known_parents = set(_row_tuples(parents, relation.parent_cols))
        child_fks = _row_tuples(children, relation.child_cols)
        # A row with an NA in any FK column counts as "no reference" and stays.
        has_fk = children[list(relation.child_cols)].notna().all(axis=1)
        resolves = pd.Series(
            [fk in known_parents for fk in child_fks],
            index=children.index,
        )
        keep = resolves | ~has_fk
        removed = int((~keep).sum())
        if removed > 0:
            data[relation.child_table] = children[keep].reset_index(drop=True)
        return removed

    @staticmethod
    def _drop_childless_parents(
        data: dict[str, pd.DataFrame], relation: Relation
    ) -> int:
        """Drop parent rows that no child row references.

        Returns the number of removed rows; ``data`` is only touched when
        that number is positive.
        """
        children = data[relation.child_table]
        parents = data[relation.parent_table]
        if parents.empty or any(
            p not in parents.columns for p in relation.parent_cols
        ):
            return 0

        # With no usable child table the referenced set stays empty, so every
        # parent row counts as unreferenced.
        referenced: set = set()
        if not children.empty and all(
            c in children.columns for c in relation.child_cols
        ):
            has_fk = children[list(relation.child_cols)].notna().all(axis=1)
            referenced = set(_row_tuples(children[has_fk], relation.child_cols))

        keep = pd.Series(
            [pk in referenced for pk in _row_tuples(parents, relation.parent_cols)],
            index=parents.index,
        )
        removed = int((~keep).sum())
        if removed > 0:
            data[relation.parent_table] = parents[keep].reset_index(drop=True)
        return removed

    @staticmethod
    def _lost_some_children(before: int, now: int) -> bool:
        """Tell whether a parent kept some, but not all, baseline children."""
        return before > 0 and 0 < now < before

    def _apply_partial_loss(
        self, data: dict[str, pd.DataFrame]
    ) -> Dict[Tuple[str, str, str], int]:
        """Drop parents whose child-count fell below the snapshot baseline.

        Only relations with ``track_partial_loss=True`` are considered. A
        parent is dropped when its current referenced-child count is lower
        than the count captured in the snapshot but still > 0; a current
        count of 0 never triggers this rule.

        Returns:
            ``{(parent_table, "CASCADE_PARTIAL_LOSS_DROP", label): count}``
            for every relation that lost at least one parent row.
        """
        assert self.snapshot is not None
        outcome: Dict[Tuple[str, str, str], int] = {}
        for relation in self.relations:
            if not relation.track_partial_loss:
                continue
            baseline = self.snapshot.counts_for(relation)
            if baseline is None:
                continue
            if relation.parent_table not in data:
                continue
            parents = data[relation.parent_table]
            if parents.empty or any(
                p not in parents.columns for p in relation.parent_cols
            ):
                continue

            # Current number of referencing children per FK tuple; stays
            # empty when the child table is missing, empty or lacks columns.
            live: Dict[Tuple, int] = {}
            children = data.get(relation.child_table)
            if (
                children is not None
                and not children.empty
                and all(c in children.columns for c in relation.child_cols)
            ):
                fk_cols = list(relation.child_cols)
                complete = children[children[fk_cols].notna().all(axis=1)]
                live = complete.groupby(fk_cols).size().to_dict()
                # Single-column groupby yields scalar keys; wrap as 1-tuples.
                if len(fk_cols) == 1:
                    live = {(value,): n for value, n in live.items()}

            keep = pd.Series(
                [
                    not self._lost_some_children(
                        baseline.get(pk, 0), live.get(pk, 0)
                    )
                    for pk in _row_tuples(parents, relation.parent_cols)
                ],
                index=parents.index,
            )
            removed = int((~keep).sum())
            if removed > 0:
                data[relation.parent_table] = parents[keep].reset_index(drop=True)
                entry = (
                    relation.parent_table,
                    "CASCADE_PARTIAL_LOSS_DROP",
                    _relation_label(relation),
                )
                outcome[entry] = outcome.get(entry, 0) + removed
        return outcome
[docs]
class CascadeSnapshot(Transformer):
    """Record how many children reference each parent, per tracked relation.

    A read-only transformer that, for every relation flagged
    ``track_partial_loss=True``, stores the number of referencing children
    per parent row. Pair it with :class:`CascadeDropTransformer` (via that
    class's ``snapshot=`` argument) to enable the partial-loss drop rule.

    Place this transformer **before** any drop-capable transformer so it
    captures the pre-cleanup baseline.

    Args:
        schemas: Schemas whose foreign-key declarations supply the relation
            set. Only relations with ``track_partial_loss=True`` are tracked.
        extra_relations: Additional or override relations.
        logger: Optional logger.
    """

    def __init__(
        self,
        schemas: Optional[Sequence[Type[Schema]]] = None,
        extra_relations: Optional[Sequence[Relation]] = None,
        logger=None,
    ) -> None:
        super().__init__(name="Cascade snapshot transformer", logger=logger)
        declared: List[Relation] = []
        if schemas:
            declared = list(resolve_relations_from_schemas(schemas))
        merged = merge_relations(declared, list(extra_relations or []))
        self.relations: List[Relation] = [
            rel for rel in merged if rel.track_partial_loss
        ]
        # {relation.key: {parent key tuple: referencing child count}}
        self._counts: Dict[Tuple[str, Tuple[str, ...]], Dict[Tuple, int]] = {}

    def transform(self, data: dict[str, pd.DataFrame]) -> None:
        """Capture the per-parent child counts; ``data`` is only read.

        Relations whose tables are missing, whose child table is empty, or
        whose child FK columns are absent are skipped (no baseline stored).

        Args:
            data: Mapping from table name to pandas DataFrame.
        """
        self.messages = []
        self._counts = {}
        for relation in self.relations:
            if not (relation.child_table in data and relation.parent_table in data):
                continue
            children = data[relation.child_table]
            parents = data[relation.parent_table]
            if children.empty or any(
                c not in children.columns for c in relation.child_cols
            ):
                continue

            fk_cols = list(relation.child_cols)
            # Rows with an NA in any FK column reference nothing.
            complete = children[children[fk_cols].notna().all(axis=1)]
            tallies = complete.groupby(fk_cols).size().to_dict()
            # Single-column groupby yields scalar keys; wrap as 1-tuples.
            if len(fk_cols) == 1:
                tallies = {(value,): n for value, n in tallies.items()}

            if all(p in parents.columns for p in relation.parent_cols):
                parent_keys = _row_tuples(parents, relation.parent_cols)
            else:
                parent_keys = []
            # Parents without any child start at 0; observed counts override.
            self._counts[relation.key] = {**dict.fromkeys(parent_keys, 0), **tallies}

    def counts_for(self, relation: Relation) -> Optional[Dict[Tuple, int]]:
        """Return captured ``{parent_key_tuple: count}`` for the relation, or None."""
        return self._counts.get(relation.key)
def _row_tuples(df: pd.DataFrame, cols: Tuple[str, ...]) -> List[Tuple]:
    """Build one tuple of column values per row of ``df``.

    Args:
        df: DataFrame to read from.
        cols: Names of the columns that make up each tuple, in order.

    Returns:
        A list with one tuple per row; empty when ``cols`` is empty.
    """
    if not cols:
        return []
    if len(cols) == 1:
        # zip over a single iterable wraps every value in a 1-tuple.
        return list(zip(df[cols[0]].tolist()))
    # name=None makes itertuples yield plain tuples.
    return list(df[list(cols)].itertuples(index=False, name=None))
def _relation_label(relation: Relation) -> str:
    """Render a relation as ``child_table.cols->parent_table.cols``.

    Column names on each side are joined with commas; the result is used in
    message text.
    """
    sides = (
        (relation.child_table, relation.child_cols),
        (relation.parent_table, relation.parent_cols),
    )
    return "->".join(f"{table}.{','.join(cols)}" for table, cols in sides)
[docs]
class OptionalColumnGuard(Transformer):
    """Materialise missing optional columns using each ``Column.default``.

    Injects missing optional columns into the corresponding DataFrame in-place,
    using ``Column.default`` and coercing to the declared dtype. Downstream
    code can then assume the full schema is present.

    Attributes:
        _schemas: Schemas whose optional columns may be injected.
    """

    def __init__(self, schemas: List, logger=None) -> None:
        super().__init__(name="OptionalColumnGuard", logger=logger)
        self._schemas = schemas

    def transform(self, data: dict[str, pd.DataFrame]) -> None:
        """Add every absent optional column to the tables present in ``data``.

        Tables that are not in ``data`` and columns that already exist are
        left alone. A failed dtype coercion is tolerated: the column keeps
        the dtype pandas inferred from the default.

        Args:
            data: Mapping from table name to pandas DataFrame; the frames are
                modified in place.
        """
        # Imported at call time, as in the original implementation.
        from .validator import _schema_table_map

        table_map = _schema_table_map(self._schemas)
        for table_name, schema in table_map.items():
            if table_name not in data:
                continue
            df = data[table_name]
            cols = schema.columns()
            for col_name, col in cols.items():
                if not col.optional or col_name in df.columns:
                    continue
                df[col_name] = col.default
                try:
                    df[col_name] = df[col_name].astype(col.dtype)
                # Parenthesised tuple: the bare ``except A, B:`` spelling is a
                # SyntaxError on Python versions before 3.14.
                except (ValueError, TypeError):
                    # Default may not be coercible (e.g. None for non-nullable
                    # numerics); leave dtype as-is and let SchemaValidator flag.
                    pass
                if self._logger:
                    self._logger.log(
                        f"Injected optional column '{col_name}' with default into {table_name}."
                    )
[docs]
class TransformationSequence:
    """A sequence of transformers executed in order.

    Args:
        transformers: Initial transformers; the sequence keeps its own copy
            of the list.
        logger: Optional logger, stored on the instance.
    """

    def __init__(
        self,
        transformers: Optional[List[Transformer]] = None,
        logger: Optional[Logger] = None,
    ) -> None:
        self._logger = logger
        # Copy the list so add_transformer() never mutates the caller's list.
        self._transformers: List[Transformer] = (
            list(transformers) if transformers else []
        )
        # NOTE(review): never updated inside this class; presumably read or
        # set by other code — confirm before relying on it.
        self._completed = False

    def add_transformer(self, transformer: Transformer) -> None:
        """Append a single transformer to the sequence."""
        self._transformers.append(transformer)

    def add_transformers(self, transformers: List[Transformer]) -> None:
        """Append multiple transformers to the sequence."""
        for transformer in transformers:
            self.add_transformer(transformer)

    def run_transformation(
        self, data: dict[str, pd.DataFrame]
    ) -> dict[str, pd.DataFrame]:
        """Run all transformers sequentially on a deepcopy of ``data``.

        A transformer may either mutate the mapping it receives (returning
        ``None``) or return a new mapping; a returned ``dict`` replaces the
        working mapping for the following transformers.

        Args:
            data: Mapping of tables to DataFrames. It is never modified.

        Returns:
            dict[str, pd.DataFrame]: Transformed copy of the input mapping.
        """
        transformed_data = deepcopy(data)
        for transformer in self._transformers:
            result = transformer.transform(transformed_data)
            # Honour transformers that return a mapping instead of (or in
            # addition to) mutating in place; ignore any other return value.
            if isinstance(result, dict):
                transformed_data = result
        return transformed_data

    def collect_messages(self) -> List[ValidationMessage]:
        """Aggregate ``ValidationMessage``s produced by all transformers.

        Returns:
            The current ``messages`` of every transformer, concatenated in
            transformer order. Transformers without a ``messages`` attribute
            (or with an empty/None one) contribute nothing.
        """
        out: List[ValidationMessage] = []
        for t in self._transformers:
            out.extend(getattr(t, "messages", []) or [])
        return out