o
    j9:j~                     @  s  U d dl mZ d dlZd dlZd dlmZmZ d dlmZ d dl	m
Z
mZmZmZmZmZ d dlZd dlmZ d dlmZ d dlmZ d d	lmZmZ erYd d
lmZ d dlmZ eeZG dd deZ e!e!e"df edB f Z#de$d< G dd de%Z&edddG dd dZ'edddG dd de'Z(e'e(B Z)de$d< eddG dd dZ*edddG dd dZ+edddG dd  d Z,dxd&d'Z-dyd+d,Z.G d-d. d.Z/G d/d0 d0eZ0d1d2dzd5d6Z1d7d8 Z2	9d{d|d@dAZ3	9d{d}dCdDZ4eddG dEdF dFZ5d~dHdIZ6eddJddPdQZ7eddUdQZ7d1dJddWdQZ7dddXdYZ8e	ddd\d]Z9e	ddd_d]Z9	1dddad]Z9d1d1dbddjdkZ:dd1dbddndoZ;ddvdwZ<dS )    )annotationsN)	dataclassfield)Enum)castLiteraloverloadProtocolTYPE_CHECKING	TypeAlias)fx)_MeshLayout)DTensortree_flattentree_unflatten)
DeviceMesh)	Placementc                   @  s   e Zd ZdZddd	Zd
S )GetMeshCallbackzGCallback to create/retrieve a DeviceMesh from its cache key components.mesh_dim_namestuple[str, ...]mesh_layout_MeshLayout | Nonereturnr   c                 C     d S N )selfr   r   r   r   j/home/nk/hobo-godmode/plappi-mvp/.venv/lib/python3.10/site-packages/torch/distributed/pipelining/_utils.py__call__      zGetMeshCallback.__call__N)r   r   r   r   r   r   )__name__
__module____qualname____doc__r   r   r   r   r   r      s    r   .r   MeshCacheKeyc                   @  s   e Zd ZdZdS )PipeliningMetadataErrorz<Raised on metadata mismatches during pipeline communication.N)r!   r"   r#   r$   r   r   r   r   r&   ,   s    r&   T)frozenslotsc                   @  sT   e Zd ZU dZded< ded< ded< ded	< edddZdddZdddZdS )_TensorMetazTensor metadata for recv buffer allocation and validation.

    For plain tensors, these are the tensor's actual attributes.
    For DTensors, these are LOCAL shard attributes; global attributes
    are stored in :class:`_DTensorMeta`.
    
torch.Sizeshapetuple[int, ...]strideztorch.dtypedtypeboolrequires_gradtensortorch.Tensorr   c                 C  s,   t | tr	tdt| j|  | j| jdS )a  Create metadata from a plain tensor.

        Args:
            tensor: A plain ``torch.Tensor`` (not DTensor).

        Returns:
            Metadata capturing shape, stride, dtype, and requires_grad.

        Raises:
            TypeError: If ``tensor`` is a DTensor.
        zJExpected plain tensor, got DTensor. Use _DTensorMeta.from_dtensor instead.r+   r-   r.   r0   )
isinstancer   r&   r)   r+   r-   r.   r0   r1   r   r   r   from_tensor>   s   
z_TensorMeta.from_tensordevicetorch.device | strc                 C  s   t | |}|| j |S )zReconstruct a tensor on ``device`` from this metadata.

        Args:
            device: Target device for the tensor.

        Returns:
            An empty strided tensor on ``device``.
        )_make_tensor_from_metarequires_grad_r0   )r   r7   tr   r   r   	to_tensorV   s   
	z_TensorMeta.to_tensorother	list[str]c                 C  s   | |krg S g }| j |j kr|d| j  d|j   | j|jkr.|d| j d|j  | j|jkrA|d| j d|j  |S )zReturn field-by-field differences with ``other``.

        Args:
            other: Metadata to compare against.

        Returns:
            List of human-readable difference strings (empty if equal).
        zshape mismatch:  vs zstride mismatch: zdtype mismatch: )r+   appendr-   r.   r   r=   diffsr   r   r   get_diffc   s   	z_TensorMeta.get_diffN)r1   r2   r   r)   )r7   r8   r   r2   r=   r)   r   r>   )	r!   r"   r#   r$   __annotations__staticmethodr6   r<   rC   r   r   r   r   r)   0   s   
 
r)   c                   @  s   e Zd ZU dZedd dZded< eddZd	ed
< eddZded< eddZ	ded< eddZ
ded< ed%ddZed&ddZd'ddZd(d#d$ZdS ))_DTensorMetaa  DTensor metadata extending :class:`_TensorMeta` with distribution info.

    Inherited fields (shape, stride, etc.) are LOCAL shard attributes.
    Additional fields capture global shape and placement information
    needed to reconstruct a :class:`DTensor` via ``DTensor.from_local()``.

    The :class:`DeviceMesh` is **not** stored (not serializable for P2P);
    it is looked up from :class:`_MeshCache` using
    ``(mesh_dim_names, mesh_layout)`` as the key.
    c                   C  s
   t g S r   )torchSizer   r   r   r   <lambda>   s   
 z_DTensorMeta.<lambda>)default_factoryr*   global_shaper   )defaultr,   global_strideztuple[Placement, ...]
placementsr   r   Nr   r   dtensorr   r   c                 C  sJ   | j }t| jj| j | j| j| j|  | jj|j	rt
|j	nd|jd	S )zCreate metadata from a DTensor.

        Args:
            dtensor: The DTensor to extract metadata from.

        Returns:
            Metadata capturing both local and global attributes.
        r   )	r+   r-   r.   r0   rL   rN   rO   r   r   )device_meshrG   _local_tensorr+   r-   r.   r0   _specrO   r   tuple_layout)rP   rQ   r   r   r   from_dtensor   s   
z_DTensorMeta.from_dtensorr%   c                 C  s   | j | jfS )z<Cache key ``(mesh_dim_names, mesh_layout)`` for mesh lookup.)r   r   r   r   r   r   mesh_cache_key   s   z_DTensorMeta.mesh_cache_keyr7   r8   meshr   c              
   C  s4   t | |}tttj||| j| j| jdd| jS )zReconstruct a DTensor on ``device`` with placements.

        Args:
            device: Target device for the local tensor.
            mesh: The ``DeviceMesh`` to attach.

        Returns:
            A DTensor on ``device``.
        F)rQ   rO   r+   r-   	run_check)	r9   r   r   
from_localrO   rL   rN   r:   r0   )r   r7   rY   local_tensorr   r   r   
to_dtensor   s   

z_DTensorMeta.to_dtensorr=   r)   r>   c                 C  s   | |krg S t | |}t|trr| j|jkr$|d| j d|j  | j|jkr7|d| j d|j  | j|jkrJ|d| j d|j  | j|jkr]|d| j d|j  | j	|j	krp|d| j	 d|j	  |S |d |S )zReturn field-by-field differences, including DTensor-specific fields.

        Args:
            other: Metadata to compare against.

        Returns:
            List of human-readable difference strings (empty if equal).
        zglobal_shape mismatch: r?   zglobal_stride mismatch: zplacements mismatch: zmesh_dim_names mismatch: zmesh_layout mismatch: z!type: _DTensorMeta vs _TensorMeta)
r)   rC   r4   rG   rL   r@   rN   rO   r   r   rA   r   r   r   rC      s6   	

z_DTensorMeta.get_diff)rP   r   r   rG   )r   r%   )r7   r8   rY   r   r   r   rD   )r!   r"   r#   r$   r   rL   rE   rN   rO   r   r   rF   rV   propertyrX   r]   rC   r   r   r   r   rG   {   s"   
 
rG   
TensorMeta)r(   c                   @  s`   e Zd ZU dZdZded< dZded< dZded< dZded< dddZ	dddZ
dddZdS )
_StageMetazPConsolidated tensor metadata for a pipeline stage's forward and backward passes.Ntuple[TensorMeta, ...] | Noneinputsoutputs$tuple[TensorMeta | None, ...] | Noneinput_gradsoutput_gradsr   r/   c                 C  s"   t dd | j| j| j| jfD S )z)Check if any metadata field is populated.c                 s  s    | ]}|d uV  qd S r   r   ).0vr   r   r   	<genexpr>  s
    
z%_StageMeta.has_any.<locals>.<genexpr>)anyrb   rc   re   rf   rW   r   r   r   has_any  s   z_StageMeta.has_anyc                 C  s2   | j | jfD ]}|rtdd |D r dS qdS )z3Check if any input/output metadata is DTensor type.c                 s  s    | ]
}|rt |tV  qd S r   )r4   rG   rg   mr   r   r   ri     s    z*_StageMeta.has_dtensors.<locals>.<genexpr>TF)rb   rc   rj   )r   metasr   r   r   has_dtensors  s
   z_StageMeta.has_dtensorsc                 C  s   | j duo	| jduS )z-Check if forward metadata is fully populated.N)rb   rc   rW   r   r   r   is_complete_for_forward  s   z"_StageMeta.is_complete_for_forward)r   r/   )r!   r"   r#   r$   rb   rE   rc   re   rf   rk   ro   rp   r   r   r   r   r`     s   
 

r`   c                   @     e Zd ZU dZded< dS )_StageForwardMetazLForward metadata transmitted from stage *i* to stage *i+1* during inference.tuple[TensorMeta, ...]forward_metasNr!   r"   r#   r$   rE   r   r   r   r   rr   !  s   
 rr   c                   @  rq   )_StageBackwardMetau   Backward metadata transmitted from stage *i* to stage *i-1* during inference.

    Gradient placements may differ from forward activations
    (e.g., ``Replicate`` → ``Partial``).
    tuple[TensorMeta | None, ...]backward_metasNru   r   r   r   r   rv   (  s   
 rv   metar7   r8   r   r2   c                 C  s   t j| j| j| j|dS )zCreate a tensor from metadata.

    Args:
        meta: Metadata with shape, stride, and dtype.
        device: Target device for the tensor.

    Returns:
        Empty tensor preserving the exact memory layout.
    )sizer-   r.   r7   )rH   empty_stridedr+   r-   r.   )ry   r7   r   r   r   r9   5  s   r9   tensor_metasrs   tuple[_TensorMeta | None, ...]c                 C  s   t dd | D S )zDerive gradient metadata from tensor metadata.

    Returns metadata with the same shape/stride/dtype but ``requires_grad=False``.
    Entries where the source has ``requires_grad=False`` become ``None``.
    c                 s  s0    | ]}|j rt|j|j|jd dndV  qdS )Fr3   N)r0   r)   r+   r-   r.   rl   r   r   r   ri   R  s    
z%_derive_grad_metas.<locals>.<genexpr>)rT   )r|   r   r   r   _derive_grad_metasJ  s   r~   c                   @  sN   e Zd ZdZddddZdddZdddZdddZd ddZd!ddZ	dS )"
_MeshCachezCache for :class:`DeviceMesh` objects keyed by ``(mesh_dim_names, mesh_layout)``.

    Assumes all pipeline stages share the same rank tensor (true for
    TorchTitan-style frameworks where meshes derive from a common world).
    Nget_mesh_cbGetMeshCallback | Noner   Nonec                 C  s   i | _ || _d S r   )_cache_get_mesh_cb)r   r   r   r   r   __init__a  s   
z_MeshCache.__init__keyr%   r   c                 C  st   || j v r
| j | S |\}}| jdu rtd| d| d| ||}|du r3td| d| d|| j |< |S )a  Return a cached mesh, or create one via the callback.

        Args:
            key: Cache key ``(mesh_dim_names, mesh_layout)``.

        Returns:
            The ``DeviceMesh``.

        Raises:
            PipeliningMetadataError: If not cached and no callback provided.
        Nz+Mesh not found in cache for mesh_dim_names=z, mesh_layout=z`, and no get_mesh callback provided. Provide a get_mesh callback or use DTensors in static mode.z>Mesh lookup failed: callback returned None for mesh_dim_names=z6. Ensure all stages use meshes from the same universe.)r   r   r&   )r   r   r   r   rY   r   r   r   get_meshe  s(   



z_MeshCache.get_meshrY   c                 C  s   || j |< dS )zAdd a mesh to the cache.Nr   )r   r   rY   r   r   r   put  s   z_MeshCache.puttensorstuple[torch.Tensor | None, ...]c                 C  sT   |D ]%}t |tr'|j}|jrt|jnd}|j}||f}|| jvr'|| j|< qdS )zJExtract and cache meshes from any :class:`DTensor` instances in *tensors*.r   N)r4   r   rQ   r   rT   rU   r   )r   r   r1   rY   	dim_namesr   r   r   r   r   update_from_tensors  s   


z_MeshCache.update_from_tensorsr/   c                 C  s
   || j v S r   r   )r   r   r   r   r   __contains__     
z_MeshCache.__contains__intc                 C  s
   t | jS r   )lenr   rW   r   r   r   __len__  r   z_MeshCache.__len__r   )r   r   r   r   )r   r%   r   r   )r   r%   rY   r   r   r   )r   r   r   r   )r   r%   r   r/   )r   r   )
r!   r"   r#   r$   r   r   r   r   r   r   r   r   r   r   r   Z  s    

"

r   c                   @  s&   e Zd ZdZdZdZedd	d
ZdS )InferenceModea  Pipeline-level metadata inference mode, determined collectively across all PP ranks.

    The mode is set by the schedule (not individual stages) because
    ``has_backward`` is only known at schedule creation time and all
    stages must agree to avoid P2P hangs.

    .. attribute:: STATIC

        All stages have sufficient metadata; runtime inference is skipped.

    .. attribute:: DYNAMIC

        At least one stage requires runtime metadata inference.
    staticdynamicry   r`   stage_has_backwardr/   r   c                 C  s<   |  sdS | sdS |sdS |jdu s|jdu rdS dS )a'  Determine whether dynamic metadata inference is needed for a stage.

        Args:
            meta: Stage metadata from user-provided args.
            stage_has_backward: Whether a backward pass will be performed.

        Returns:
            ``True`` if dynamic inference is needed.
        TFN)rp   ro   re   rf   )clsry   r   r   r   r   needs_dynamic  s   zInferenceMode.needs_dynamicN)ry   r`   r   r/   r   r/   )r!   r"   r#   r$   STATICDYNAMICclassmethodr   r   r   r   r   r     s    r   Fdetachr   r/   c                C  s4   t | \}}|rdd |D }t||}||fS |S )a;  Flatten ``args`` into a list, optionally detaching tensors.

    Args:
        args: Nested arguments to flatten.
        detach: If ``True``, detach tensors while preserving ``requires_grad``.

    Returns:
        ``(new_args, flat_detached_args)`` when ``detach=True``;
        ``flat_args`` list otherwise.
    c                 S  s,   g | ]}t |tjr| |jn|qS r   )r4   rH   Tensorr   r:   r0   )rg   ar   r   r   
<listcomp>  s    
z flatten_args.<locals>.<listcomp>r   )argsr   	flat_argstreespecflat_detachednew_argsr   r   r   flatten_args  s   
r   c                 C  s   t | ddS )zHFlatten and detach. Deprecated: use ``flatten_args(args, detach=True)``.Tr   )r   )r   r   r   r   flatten_args_detach  s   r   looppp_sizer   
num_stagesstylestrdict[int, int]c                 C  s   i }|dkrt |D ]}||  ||< q
|S |dkrS||  dkr*td| d|  dd}t |D ] }|||< |d |  dkr?q0||  d dkrL|d7 }q0|d8 }q0|S td	| d
)z
    Compute the stage id to rank mapping for either a looped or V-style schedule.

    Most commonly num_stages == pp_size * 2, but this function can be used to
    compute the mapping for any number of stages per rank.
    r   rh   r   znum_stages z% must be evenly divisible by pp_size z for V schedules      zStyle z is not supported.)range
ValueError)r   r   r   mappingstage_index
rank_indexr   r   r   generate_stage_to_rank_mapping  s(   	

r   dict[int, list[int]]c                 C  sZ   t | ||}i }| D ]\}}||vrg ||< || | q| D ]}|  q$|S )a  
    Compute the rank to stage id mapping for either a looped or V-style schedule.

    This function inverts the stage_to_rank_mapping to get which stages are assigned to each rank.

    Returns a dictionary mapping rank -> list of stage indices assigned to that rank.
    )r   itemsr@   valuessort)r   r   r   stage_to_rankrank_to_stagesstage_idrankstagesr   r   r   generate_rank_to_stage_mapping  s   

r   c                   @  s*   e Zd ZU dZded< ded< ded< dS )	PipeInfoz>
    Captures information for a pipeline (`Pipe` object).
    zfx.Graphgraphr   r   r/   has_loss_and_backwardNru   r   r   r   r   r   6  s
   
 r   r1   c                 C  s   t | tr
t| S t| S )a  Extract metadata from a tensor.

    Handles both plain Tensor and DTensor correctly: DTensors are
    dispatched to ``_DTensorMeta.from_dtensor`` which captures local
    shard attributes plus global shape/placement info, while plain
    tensors use ``_TensorMeta.from_tensor``.

    Args:
        tensor: A plain tensor or DTensor.

    Returns:
        ``_TensorMeta`` for plain tensors, ``_DTensorMeta`` for DTensors.
    )r4   r   rG   rV   r)   r6   r5   r   r   r   extract_tensor_metaF  s   


r   )
allow_noner   tuple[torch.Tensor, ...] | Noner   Literal[False]ra   c                C  r   r   r   r   r   r   r   r   extract_tensor_metasZ     r   &tuple[torch.Tensor | None, ...] | NoneLiteral[True]rd   c                C  r   r   r   r   r   r   r   r   b  r   Atuple[torch.Tensor | None, ...] | tuple[torch.Tensor, ...] | Nonec                C  s`   | du rdS g }d}| D ]}t |tjr|t| qd}|d q|s,|r,tdt|S )a  Extract metadata from a tuple of tensors.

    Args:
        tensors: Tuple of tensors (may include ``None`` when ``allow_none=True``).
        allow_none: If ``True``, preserve ``None`` elements (for gradients).

    Returns:
        Tuple of ``TensorMeta``, or ``None`` if ``tensors`` is ``None``.

    Raises:
        PipeliningMetadataError: If ``None`` found and ``allow_none=False``.
    NFTz_None values are not allowed in tensor metadata tuples. Use allow_none=True for optional values.)r4   rH   r   r@   r   r&   rT   )r   r   metas_with_nonehas_noner;   r   r   r   r   j  s   c                 C  s&   |r|   n| }t|tr| S |S )u  Convert a DTensor to its local shard, or return a plain tensor as-is.

    When ``detach=True``, the tensor is detached before conversion —
    this applies to both DTensors and plain tensors.

    Args:
        tensor: A tensor that may be a DTensor.
        detach: If ``True``, detach before ``to_local()`` to avoid
            redistribution during backward.

    Returns:
        The local tensor component.
    )r   r4   r   to_local)r1   r   maybe_detached_tensorr   r   r   to_local_if_dtensor  s   
r   r   Ctorch.Tensor | tuple[torch.Tensor, ...] | list[torch.Tensor] | Nonec                 C  r   r   r   r   r   r   r   r   validate_and_normalize_to_tuple  r    r   Qtorch.Tensor | tuple[torch.Tensor | None, ...] | list[torch.Tensor | None] | Nonec                 C  r   r   r   r   r   r   r   r     s   torch.Tensor | tuple[torch.Tensor, ...] | tuple[torch.Tensor | None, ...] | list[torch.Tensor] | list[torch.Tensor | None] | Nonec                 C  s   | du rdS t | tjr| fS t | ttfrMt| D ]'\}}|du r-|s,td| dqt |tjsAtd| dt|j dqt | trKt| S | S tdt| j d)a  Normalize ``args`` to a tuple and validate that all elements are tensors.

    Args:
        args: A single tensor, tuple/list of tensors, or ``None``.
        allow_none: If ``True``, permit ``None`` elements (for gradients).

    Returns:
        Tuple of tensors, or ``None`` if ``args`` is ``None``.

    Raises:
        PipeliningMetadataError: On non-tensor values
            (or ``None`` when ``allow_none=False``).
    Nz
Stage arg[zF] is None. Stage args must be tensors. Use kwargs for optional values.z] has type zC. All stage args must be tensors. Use kwargs for non-tensor inputs.z<Stage args must be a tensor, tuple, or list of tensors, got .)	r4   rH   r   rT   list	enumerater&   typer!   )r   r   iargr   r   r   r     s*   
raise_on_mismatchwarn_on_mismatchdescexpectedactualtorch.Tensor | TensorMetar   r   r>   c                C  s   t |tjrt|}n|}t|t|urEdt|j dt|j g}|r1t|  d|d  |rCtj|  d|d  dt	dd |S |
|}|rm|rZt|  dd	| |rmtj|  d
d	| dt	dd |S )al  
    Compare expected metadata against actual tensor or metadata.

    This is the unified validation/comparison function that uses get_diff() from
    metadata classes. Works with both plain tensors and DTensors.

    For plain tensors: compares shape/stride/dtype/requires_grad.
    For DTensors: compares all properties including global shape and placements.

    Args:
        desc: Description for error/warning messages.
        expected: Expected tensor metadata (_TensorMeta or _DTensorMeta).
        actual: Actual tensor or metadata to compare against.
        raise_on_mismatch: If True, raise PipeliningMetadataError on mismatch.
        warn_on_mismatch: If True, issue a warning on mismatch.

    Returns:
        List of differences (empty if metadata matches).

    Raises:
        PipeliningMetadataError: If raise_on_mismatch=True and differences exist.
    ztype: expected , got : r   z: Metadata type mismatch. z.. Using dynamically inferred metadata instead.r   
stacklevelz; z: Metadata mismatch. )r4   rH   r   r   r   r!   r&   warningswarnUserWarningrC   join)r   r   r   r   r   actual_meta	type_diffrB   r   r   r   validate_metadata  s4   

r   rw   ,tuple[torch.Tensor | TensorMeta | None, ...]c             	   C  s"  t |t |kr)|  dt | dt | }|rt||r&tj|tdd |gS g }tt||ddD ]Z\}\}}	|du rC|	du rCq4|du sK|	du rz|  d| d	|du rWd
nd d|	du r`d
nd }|rjt||rttj|tdd || q4t|  d| d||	||d}
|	|
 q4|S )a2  Validate metadata for a tuple of tensors element-wise.

    Args:
        desc: Description prefix for error/warning messages.
        expected: Tuple of expected metadata (may include ``None`` for grads).
        actual: Tuple of actual tensors or metadata to compare against.
        raise_on_mismatch: If ``True``, raise on the first mismatch.
        warn_on_mismatch: If ``True``, issue warnings for mismatches.

    Returns:
        Aggregated list of difference strings.

    Raises:
        PipeliningMetadataError: If lengths differ or on mismatch.
    z: expected z tensors, got r   r   TstrictN[z]: expected r   metadatar   ]r   )
r   r&   r   r   r   r   zipr@   r   extend)r   r   r   r   r   msg	all_diffsr   expactrB   r   r   r   validate_tensors_metadata-  s>   
r   r   tuple[torch.Tensor, ...]gradsr   is_inputr   c           
      C  sZ  |rdnd}| d}| d}t |t |kr0td|  d| dt | d| dt | d	tt||d
dD ]q\}\}}	|js`|	dur`td|  d| d| d| d| dt|	j d|jr|	du rtjd|  d| d| d| d| dt	dd t
|tr|jr|	durt
|	tstd|  d| d| d| d| dt|	j dq9dS )u/  
    Validate the args↔grads contract for static mode.

    Enforces four rules for each (arg, grad) pair:
      1. len(args) must equal len(grads).
      2. If arg.requires_grad is False, grad must be None.
      3. If arg.requires_grad is True and grad is None, emit a warning
         (this is legal at pipeline boundaries but may indicate a bug).
      4. If arg is a DTensor with requires_grad=True and grad is not None,
         grad must also be a DTensor.

    Args:
        stage_index: The stage index for error messages.
        args: Tuple of forward tensors.
        grads: Tuple of gradient tensors (can include None).
        is_input: True for input_args/input_grads, False for output_args/output_grads.

    Raises:
        PipeliningMetadataError: If any hard rule (1, 2, or 4) is violated.
    inputoutput_args_gradszStage r   z	 length (z) does not match zo). Each forward tensor must have a corresponding gradient entry (use None for tensors that don't require grad).Tr   Nr   z] has requires_grad=False, but z] is not None (zE). Non-differentiable tensors must have None as their gradient entry.z] has requires_grad=True, but zT] is None. This is legal at pipeline boundaries but may indicate a missing gradient.r   r   z,] is a DTensor with requires_grad=True, but z] is za, expected DTensor or None. DTensor gradients may have different placements than forward tensors.)r   r&   r   r   r0   r   r!   r   r   r   r4   r   )
r   r   r   r   kind	args_name
grads_namer   r   gradr   r   r   'validate_static_arg_grad_correspondencef  sd   


r  )ry   r)   r7   r8   r   r2   )r|   rs   r   r}   )r   r/   )r   )r   r   r   r   r   r   r   r   )r   r   r   r   r   r   r   r   )r1   r2   r   r_   )r   r   r   r   r   ra   )r   r   r   r   r   rd   )r   r   r   r/   r   rd   )F)r1   r2   r   r/   r   r2   ).)r   r   r   r   r   r   )r   r   r   r   r   r   )r   r   r   r/   r   r   )r   r   r   r_   r   r   r   r/   r   r/   r   r>   )r   r   r   rw   r   r   r   r/   r   r/   r   r>   )
r   r   r   r   r   r   r   r/   r   r   )=
__future__r   loggingr   dataclassesr   r   enumr   typingr   r   r   r	   r
   r   rH   r   torch.distributed._mesh_layoutr   torch.distributed.tensorr   torch.utils._pytreer   r   torch.distributed.device_meshr   (torch.distributed.tensor.placement_typesr   	getLoggerr!   loggerr   rT   r   r%   rE   RuntimeErrorr&   r)   rG   r_   r`   rr   rv   r9   r~   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r   r   r   r   <module>   s    
 

J 



H8#

$:J9