def load_weights_and_online_quantize(
    model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
) -> set[str]:
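    """Restore a model whose weights were already online-quantized back to its
    original high-precision parameters, reload the checkpoint weights, and
    re-run online quantization on them.

    Currently only supported for torchao quantization. Returns the set of
    weight names reported as loaded by model.load_weights.
    """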
    # Online quantization; right now only enabled for torchao
    # (steps R1, R2, R3, R4 in the Notes)
    # TODO: Add fp8 support
    assert model_config.quantization == "torchao", (
        "online quantization is only enabled for torchao currently"
    )
    # TODO: use create_weights to restore the weights to their original state
    # Step R1: First restore the quantized weights to the original bfloat16
    # weights, with the original metadata (shape, dtype, device)
    # and attributes, so that the bfloat16 weights can be loaded properly
    existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
    named_modules = dict(model.named_modules(remove_duplicate=False))
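    # model_device is inferred from the recorded rebuild keys below and is
    # later passed to process_weights_after_loading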
    model_device = None
    # Step R2: recover the parameters to the state before the first load
    for name, d in model.original_weights_rebuild_keys.items():
        _shape = d["shape"]
        _dtype = d["dtype"]
        _device = d["device"]
        if model_device is not None:
            assert model_device == _device, (
                "Expecting all weights to be on the same device for now, "
                f"got both: {model_device} and {_device}"
            )
        else:
            model_device = _device
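        # Replace the (quantized) parameter with an empty placeholder of the
        # original shape/dtype/device so that the high-precision checkpoint
        # weights can be loaded into it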
        if name in existing_param_names:
            module_name, weight_name = name.rsplit(".", 1)
            module = named_modules[module_name]
            setattr(
                module,
                weight_name,
                torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
            )
    # recorded_weight_attr is
    # {"weight_name": {"weight_attr_key": attr}}
    # e.g.
    # {
    #   "layer.0.weight": {
    #     "weight_loader": weight_loader_function_object,
    #     "input_dim": 0, ...
    #   },
    #   "layer.1.weight": ...,
    # }
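    # Re-attach the recorded per-weight attributes (e.g. weight_loader, used
    # by load_weights below); they are missing on the freshly created
    # placeholder parameters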
    for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
        for attr_name, attr in weight_attr_dict.items():
            module_name, weight_name = full_weight_name.rsplit(".", 1)
            module = named_modules[module_name]
            weight = getattr(module, weight_name)
            if not hasattr(weight, attr_name):
                setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
    # Step I1: reload bfloat16 / high precision weights
    loaded_weights = model.load_weights(
        model_loader.get_all_weights(model_config, model)
    )
    # Step I2: online quantize the weights by manually running
    # process_weights_after_loading
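    # Clear the already-called flag before re-running the post-load processing
    # (which performs the online quantization), then mark it as called again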
    model.process_weights_after_loading_already_called = False
    process_weights_after_loading(model, model_config, model_device)
    model.process_weights_after_loading_already_called = True
    return loaded_weights