class TablePrinter:
    """Print dataclass instances as a fixed-width, ``|``-separated text table.

    Each dataclass field becomes one column, in field declaration order,
    whose width is taken from ``column_widths``. Strings are left-aligned
    and passed through ``trim_string_back``; ``int`` and ``float`` values
    are right-aligned with two decimals; any other value (including
    ``bool`` and ``None``) is right-aligned via its default formatting.
    Only strings are trimmed, so an over-long non-string value widens its
    cell and shifts the columns after it.
    """

    # Text placed between adjacent cells; its length is also counted in the
    # width of the rule printed by _print_line so the two stay in sync.
    _SEPARATOR = " | "

    def __init__(self, row_cls: type, column_widths: dict[str, int]) -> None:
        """Create a printer for instances of the dataclass ``row_cls``.

        Args:
            row_cls: Dataclass whose fields define the table's columns.
            column_widths: Width of each column, keyed by field name. Must
                contain exactly the field names of ``row_cls``.

        Raises:
            TypeError: If ``row_cls`` is not a dataclass (raised by
                ``dataclasses.fields``).
            ValueError: If the keys of ``column_widths`` do not match the
                field names of ``row_cls``.
        """
        self.row_cls = row_cls
        self.fieldnames = [field.name for field in dataclasses.fields(row_cls)]
        self.column_widths = column_widths
        # Raise instead of assert so the check is not stripped under `python -O`.
        known = set(self.fieldnames)
        missing = [name for name in self.fieldnames if name not in column_widths]
        unexpected = [key for key in column_widths if key not in known]
        if missing or unexpected:
            raise ValueError(
                f"column_widths keys must match the fields of {row_cls!r}: "
                f"missing {missing}, unexpected {unexpected}"
            )

    def print_table(self, rows: list[object]) -> None:
        """Print the header, a ``=`` rule, then one line per row.

        Args:
            rows: Instances of ``row_cls``; any iterable of them works.

        Raises:
            TypeError: If a row is not an instance of ``row_cls``.
        """
        self._print_header()
        self._print_line()
        for row in rows:
            self._print_row(row)

    def _print_header(self) -> None:
        """Print the field names, each trimmed and left-aligned to its column."""
        cells = (
            trim_string_back(name, self.column_widths[name]).ljust(
                self.column_widths[name]
            )
            for name in self.fieldnames
        )
        # join() + a single print always ends the line, even with zero columns.
        print(self._SEPARATOR.join(cells))

    def _print_row(self, row: object) -> None:
        """Print one row, formatting each field value to its column width.

        Raises:
            TypeError: If ``row`` is not an instance of ``row_cls``.
        """
        if not isinstance(row, self.row_cls):
            raise TypeError(
                f"expected an instance of {self.row_cls!r}, "
                f"got {type(row).__name__}"
            )
        cells = (
            self._format_cell(getattr(row, name), self.column_widths[name])
            for name in self.fieldnames
        )
        print(self._SEPARATOR.join(cells))

    @staticmethod
    def _format_cell(value: object, width: int) -> str:
        """Render a single cell value padded to ``width`` characters.

        Strings are trimmed via ``trim_string_back`` and left-aligned;
        numbers become right-aligned two-decimal floats; anything else is
        right-aligned as-is (never trimmed).
        """
        if isinstance(value, str):
            return trim_string_back(value, width).ljust(width)
        # isinstance(..., float) also admits float subclasses (e.g.
        # numpy.float64); `type(...) is int` deliberately keeps bool and
        # int-derived enums out of the numeric branch.
        if isinstance(value, float) or type(value) is int:
            return f"{float(value):.2f}".rjust(width)
        # f-string (format(value, "")) rather than str(): the two can differ
        # for types overriding __format__, e.g. some mixed-in enums.
        return f"{value}".rjust(width)

    def _print_line(self) -> None:
        """Print a ``=`` rule as wide as all columns plus their separators."""
        n_columns = len(self.column_widths)
        # With zero columns this is negative, and "=" * negative == "".
        width = sum(self.column_widths.values()) + len(self._SEPARATOR) * (
            n_columns - 1
        )
        print("=" * width)