# Source code for codinglab.encoders.prefix_code_tree

"""
Prefix code tree data structure for the coding experiments library.

This module implements a prefix code tree (also known as a decoding tree
or code tree) used for efficient decoding of prefix codes. The tree
enables unique decoding of variable-length codes without separators
between codewords.
"""

# Module metadata
__author__ = "Mikhail Mikhailov"
__license__ = "MIT"
__version__ = "0.1.0"
__all__ = ["TreeNode", "PrefixCodeTree"]

from typing import Dict, Generic, Optional, List, Tuple
from graphviz import Digraph
from dataclasses import dataclass, field
from ..types import SourceChar, ChannelChar


@dataclass(kw_only=True)
class TreeNode(Generic[ChannelChar, SourceChar]):
    """
    Node in a prefix code tree.

    A node plays one of two roles:

    - Internal node (``value`` is None): its ``children`` map each possible
      next channel symbol of a code sequence to the following node.
    - Leaf node (``value`` is not None): marks the end of a complete code
      and stores the source symbol that the code represents.

    Both fields are keyword-only (``kw_only=True``) and optional, so
    ``TreeNode()`` creates an empty node without a value or children.
    """

    value: Optional[SourceChar] = None
    """Source symbol at this node, or None for internal nodes."""

    children: Dict[ChannelChar, "TreeNode[ChannelChar, SourceChar]"] = field(
        default_factory=dict
    )
    """Dictionary mapping channel symbols to child nodes."""

    def leaf(self) -> bool:
        """
        Check if this node is a leaf node.

        Only ``children`` is inspected; ``value`` is not. A freshly created
        node without children therefore reports True even though its
        ``value`` is still None.

        Returns:
            True if this node has no children (is a leaf),
            False otherwise (is an internal node)
        """
        # An empty dict is falsy, so this is the same test as "no children".
        return not self.children
class PrefixCodeTree(Generic[ChannelChar, SourceChar]):
    """
    Prefix code tree for efficient decoding.

    This tree data structure enables decoding of prefix codes by traversing
    from the root to leaves based on the input sequence. Each leaf
    corresponds to a source symbol, and the path from root to leaf defines
    the code for that symbol.
    """

    def __init__(
        self, root: Optional[TreeNode[ChannelChar, SourceChar]] = None
    ) -> None:
        """
        Initialize a prefix code tree.

        Args:
            root: Optional root node for the tree. If None, creates a new
                empty root node.
        """
        self.root: TreeNode[ChannelChar, SourceChar] = (
            TreeNode() if root is None else root
        )
        """Root node of the prefix code tree."""

    def insert_code(self, code: List[ChannelChar], symbol: SourceChar) -> None:
        """
        Insert a code sequence into the tree.

        Walks from the root along ``code``, creating missing nodes on the
        way, and stores ``symbol`` on the node where the path ends. An empty
        ``code`` stores the symbol on the root itself.

        ``None`` marks a node without a symbol, so ``symbol`` should not be
        None: such a code would be invisible to the conflict checks below.

        Args:
            code: Sequence of channel symbols representing the code
            symbol: Source symbol that this code represents

        Raises:
            ValueError: If the code conflicts with existing codes (violates
                the prefix property) or the exact code is already assigned.
                The tree is left unchanged in that case.
        """
        node = self.root

        for char in code:
            # A symbol stored on a node we still have to descend from means
            # an already inserted code is a proper prefix of the new one.
            # Checking before the descent also covers the root, i.e. a
            # previously inserted empty code.
            if node.value is not None:
                raise ValueError(
                    f"Code prefix conflict: existing code for {node.value} "
                    f"is a prefix of new code {code}"
                )

            child = node.children.get(char)
            if child is None:
                # Path does not exist yet: extend it with a new empty node.
                # From here on every node is new, so no later check can fail
                # and a rejected insert never leaves partial paths behind.
                child = TreeNode()
                node.children[char] = child
            node = child

        # The path ends on a node that already carries a symbol: this exact
        # code was inserted before.
        if node.value is not None:
            raise ValueError(
                f"Code prefix conflict: code {code} is already assigned "
                f"to {node.value}"
            )

        # The path ends on a node with children: the new code is a proper
        # prefix of at least one existing code.
        if node.children:
            raise ValueError(
                f"Code prefix conflict: new code {code} for {symbol} "
                f"is a prefix of an existing code"
            )

        node.value = symbol

    def decode(
        self, sequence: List[ChannelChar], position: int = 0
    ) -> Tuple[Optional[SourceChar], int]:
        """
        Decode a symbol from a sequence starting at the given position.

        Traverses the tree from the root, consuming one symbol of
        ``sequence`` per step, until a node without children is reached.

        If the root itself has no children (for example an empty tree),
        nothing is consumed: the root's value (None for an empty tree) is
        returned together with the unchanged ``position``.

        Args:
            sequence: List of channel symbols to decode
            position: Starting position in the sequence (default: 0)

        Returns:
            Tuple of (decoded_symbol, new_position) where:
            - decoded_symbol: Value stored on the reached node; None if that
              node carries no symbol
            - new_position: Position in sequence after consuming the code

        Raises:
            ValueError: If the sequence is incomplete or contains symbols
                not in the tree
        """
        node = self.root

        while not node.leaf():
            if position >= len(sequence):
                raise ValueError(
                    f"Cannot decode sequence {sequence} at position {position}: sequence incomplete"
                )

            current_symbol = sequence[position]
            # Children are always TreeNode instances, so None can only mean
            # that the symbol has no outgoing edge here.
            child = node.children.get(current_symbol)
            if child is None:
                raise ValueError(
                    f"Cannot decode sequence {sequence} at position {position}: symbol '{current_symbol}' not in tree"
                )

            node = child
            position += 1

        return node.value, position

    def vizualize(self) -> Digraph:
        """
        Build a Graphviz graph of the tree.

        Every tree node becomes a graph node labelled with the string form
        of its value (empty label for nodes without a value). Every
        parent/child link becomes an edge labelled with the channel symbol.
        Node identifiers are derived from the path: ``"r"`` for the root,
        then ``"<parent id>_<channel symbol>"`` for each child.

        Returns:
            The populated ``Digraph``; rendering is left to the caller.
        """
        dot = Digraph()

        def add(n: TreeNode[ChannelChar, SourceChar], idx: str) -> None:
            # Compare against None explicitly so that falsy symbols such as
            # 0 or "" still get their label.
            dot.node(idx, str(n.value) if n.value is not None else "")
            for char, child in n.children.items():
                cidx = f"{idx}_{char}"
                dot.edge(idx, cidx, label=str(char))
                add(child, cidx)

        add(self.root, "r")
        return dot