def rocm_aiter_rotary_emb(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    head_size: int,
    rotary_dim: int,
    is_neox_style: bool,
):
    """Forward ``query``/``key`` to the vLLM AITER Triton rotary-embedding op.

    The function prepares the operands and then calls
    ``torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton``:

    * ``cos_sin_cache`` is split into two chunks along its last dimension;
      the first chunk is used as ``cos`` and the second as ``sin``.
    * ``query`` and ``key`` are viewed as ``(num_tokens, -1, head_size)``,
      where ``num_tokens`` is ``positions.numel()``, and only the first
      ``rotary_dim`` entries of the last dimension are handed to the op.
    * ``positions`` is flattened to one entry per token.

    Args:
        positions: Position tensor; its element count defines the token count.
        query: Query tensor, viewable as ``(num_tokens, -1, head_size)``.
        key: Key tensor, viewable as ``(num_tokens, -1, head_size)``.
        cos_sin_cache: Tensor whose last dimension holds the cos chunk
            followed by the sin chunk.
        head_size: Size of the last dimension after the per-token view.
        rotary_dim: Number of leading entries of the last dimension that are
            passed to the op.
        is_neox_style: Selects the rotate-style code sent to the op
            (``0`` when true, ``1`` otherwise).

    NOTE(review): the op's result is discarded and it receives views that
    share storage with ``query``/``key``, so it presumably writes in place --
    confirm against the op's registration.
    """
    # Capture the incoming layouts before the names are rebound below.
    original_query_shape, original_key_shape = query.shape, key.shape
    token_count = positions.numel()

    # First chunk of the last dimension is cos, second chunk is sin.
    cos, sin = torch.chunk(cos_sin_cache, 2, dim=-1)

    # Integer code expected by the op: 0 for neox style, 1 otherwise.
    if is_neox_style:
        style_code = 0
    else:
        style_code = 1

    # Views only: no data is copied, so the slices alias the caller's tensors.
    per_token_layout = (token_count, -1, head_size)
    query = query.view(per_token_layout)
    key = key.view(per_token_layout)
    query_rot = query[..., :rotary_dim]
    key_rot = key[..., :rotary_dim]

    # One position per row of the reshaped query.
    positions = positions.view(query.shape[0])

    # Argument order is sin before cos, exactly as the op is invoked here.
    # NOTE(review): the meaning of the trailing ``False`` flag is defined by
    # the op's signature, which is not visible in this block.
    torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
        positions, sin, cos, query_rot, key_rot, style_code, False
    )

    # Rebind the local names to views with the original shapes.
    # NOTE(review): nothing in the visible code reads these afterwards.
    query = query.view(original_query_shape)
    key = key.view(original_key_shape)