MODULE_TYPE_TO_WRAPPING_FUNC module-attribute

MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
    [
        ("QKVParallelLinear", partition_qkv_parallel_linear),
        ("ColumnParallelLinear", partition_column_parallel_linear),
        ("RowParallelLinear", partition_row_parallel_linear),
    ]
)
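The keys of this mapping are vLLM linear-layer class names and the values are the partitioning functions applied to matching submodules. As a rough sketch of how such a mapping can be consumed (illustrative only, not the actual vLLM implementation; `wrap_module_for_sharding` and the stub partition functions are made-up names):

```python
from collections import OrderedDict

import torch

# Illustrative stubs: the real partition_* functions in
# vllm/distributed/tpu_distributed_utils.py apply XLA SPMD sharding
# annotations to the corresponding layer's weights (details omitted here).
def partition_qkv_parallel_linear(module, mesh):
    return module

def partition_column_parallel_linear(module, mesh):
    return module

def partition_row_parallel_linear(module, mesh):
    return module

MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
    [
        ("QKVParallelLinear", partition_qkv_parallel_linear),
        ("ColumnParallelLinear", partition_column_parallel_linear),
        ("RowParallelLinear", partition_row_parallel_linear),
    ]
)

def wrap_module_for_sharding(module: torch.nn.Module, mesh) -> torch.nn.Module:
    # Apply the first entry whose key matches the module's class name.
    # The wrapping function may shard the weights in place or return a
    # replacement module; unmatched modules are returned unchanged.
    for type_name, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
        if type(module).__name__ == type_name:
            return wrapping_func(module, mesh)
    return module
```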
 
  Bases: Module
Source code in vllm/distributed/tpu_distributed_utils.py
 _load_weights_from_qkv_linear(qkv_linear: Module)
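No description accompanies this helper; from the name, it loads weights out of a vLLM `QKVParallelLinear` layer. The sketch below shows one plausible shape of such a step, splitting the fused QKV weight into per-projection tensors; the `output_sizes` attribute, the split axis, and the function name are assumptions, not the confirmed implementation.

```python
import torch

def _load_weights_from_qkv_linear_sketch(qkv_linear: torch.nn.Module):
    # Assumption: the layer stores a single fused weight of shape
    # [q_out + k_out + v_out, hidden_size] and exposes the per-projection
    # output sizes via an `output_sizes` attribute.
    fused = qkv_linear.weight.data
    q_out, k_out, v_out = qkv_linear.output_sizes
    q_w, k_w, v_w = torch.split(fused, [q_out, k_out, v_out], dim=0)
    return q_w, k_w, v_w
```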
 shard_model(model: Module, mesh: Mesh) -> None
Recursively traverse a PyTorch model and apply the appropriate sharding function to each submodule according to the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
Parameters:
| Name | Type | Description | Default | 
|---|---|---|---|
| model | Module | torch.nn.Module to process | required | 
| mesh | Mesh | An XLA SPMD mesh object used for sharding | required |
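For orientation, a minimal usage sketch follows. The mesh construction (device ordering, shape, axis names) is an assumption based on common torch_xla SPMD setup, not the exact configuration vLLM uses on TPU; `model` stands for an already-constructed torch.nn.Module.

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

from vllm.distributed.tpu_distributed_utils import shard_model

# Build a 1 x N SPMD mesh over all visible TPU devices (axis names assumed).
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (1, num_devices), ("data", "model"))

# shard_model walks the model's submodules and applies the partition
# functions registered in MODULE_TYPE_TO_WRAPPING_FUNC to matching layers.
shard_model(model, mesh)
```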