o
    j9:j\                  	   @   s  d dl Z d dlmZ d dlmZmZ d dlZd dlmZ d dl	m
Z
 d dlmZmZ ddlmZ dd	lmZ dd
lmZmZ ddlmZmZmZmZmZ ddlmZ ddlmZ ddlm Z  ddl!m"Z" ddl#m$Z$ erxd dl%m&Z& ddl'm(Z( d dl)m*Z* ej+Z+dedede,fddZ-de.e dede,fddZ/dej0ddddfdd Z1dej0de,fd!d"Z2d#e,dej3fd$d%Z4d#e,dej5fd&d'Z6e*d(d)d*d+e,d,e,ddfd-d.Z7e7j8d+e,d,e,ddfd/d0Z9eej:j;j<j= e*d1d)d*d+e,d,e,ddfd2d3Z>e>j8d+e,d,e,ddfd4d0Z9eej:j;j?j= e*d5d)d*d6e,d7e,ddfd8d9Z@e@j8d6e,d7e,ddfd:d0Z9eej:j;j@j= e*d;d)d*d6e,d7e,ddfd<d=ZAeAj8d6e,d7e,ddfd>d0Z9eej:j;jAj= e*d?d)d*d6e,ddfd@dAZBeBj8d6e,ddfdBd0Z9eej:j;jBj= e*dCd)d*dDeCdEe,ddfdFdGZDeDj8dDeCdEe,ddfdHd0Z9eej:j;jDj= e*dId)d*d7e,ddfdJdKZEeEj8d7e,ddfdLd0Z9eej:j;jEj= e*dMd)d*dNe,dOe,ddfdPdQZFeFj8d6e,d7e,ddfdRd0Z9eej:j;jFj= e*dSd)d*dTe,dUe,dVej+ddfdWdXZGeGj8dTe,dUe,dVej+ddfdYd0Z9eej:j;jGj= e*dZd)d*d[ej+d7e,ddfd\d]ZHeHj8d[ej+d7e,ddfd^d0Z9eej:j;jHj= G d_d` d`ZIG dadb dbe"ZJG dcdd ddeJZKG dedf dfeKZLG dgdh dheZMdS )i    N)Callable)AnyOptional)ConstDictVariable)TupleVariable)has_side_effectProxy   )graph_break_hints)create_call_function)TYPE_CHECKINGunimplemented)CURRENT_STREAM_INDEXget_external_object_by_indexregister_graph_created_objectregister_user_objectreset_user_object_trackingCurrentStreamSource   )VariableTrackerConstantVariable)FxTracebackAnnotateVariable)LazyVariableTracker)InstructionTranslator)	PyCodegen)	custom_opargskwargsreturnc                  O   *   t j| i |}t|ttg ti S N)torchEventr   EventVariable make_construct_in_graph_event_fnr   r   )r   r   event r(   f/home/nk/hobo-godmode/plappi-mvp/.venv/lib/python3.10/site-packages/torch/_dynamo/variables/streams.py	new_event&      r*   c                  O   r!   r"   )r#   Streamr   StreamVariable!make_construct_in_graph_stream_fnr   r   )r   r   streamr(   r(   r)   
new_stream0   r+   r0   devicecgr   c                    s2      fdd  t|   tdd d S )Nc                           tjjjdS Nstash_graph_created_objectload_import_fromr#   _dynamograph_bytecode_inputs__name__r(   r2   r(   r)   <lambda><       z)_codegen_current_stream.<locals>.<lambda>r   F)add_push_nullr   extend_outputr   )r1   r2   r(   r;   r)   _codegen_current_stream:   s
   
r@   c                    s   t j }t| fddS )Nc                    s
   t  |S r"   )r@   )_r2   r1   r(   r)   r<   H   s   
 z$get_current_stream.<locals>.<lambda>)r#   acceleratorcurrent_streamr   )r1   r/   r(   rB   r)   get_current_streamE   s   rE   indexc                 C   &   t | }t|tjsJ d|  |S )Nz3Fork/join stream expected a stream object at index )r   
isinstancer#   r,   )rF   r/   r(   r(   r)   _get_stream_by_indexL   
   rI   c                 C   rG   )Nz4Record/wait event expected an event object at index )r   rH   r#   r$   )rF   r'   r(   r(   r)   _get_event_by_indexT   rJ   rK   zstreams::forkr(   )mutates_args
from_indexto_indexc                 C      t jt| d S r"   r#   rC   
set_streamrI   rM   rN   r(   r(   r)   fork_stream\   s   rS   c                 C      d S r"   r(   rR   r(   r(   r)   rA   d      rA   zstreams::joinc                 C   rO   r"   rP   rR   r(   r(   r)   join_streamo   s   rV   c                 C   rT   r"   r(   rR   r(   r(   r)   rA   t   rU   zstreams::record_eventevent_indexstream_indexc                 C      t | }t|}|| d S r"   )rK   rI   recordrW   rX   r'   r/   r(   r(   r)   record_event      r\   c                 C   rT   r"   r(   rW   rX   r(   r(   r)   rA      rU   zstreams::wait_eventc                 C   rY   r"   )rK   rI   waitr[   r(   r(   r)   
wait_event   r]   r`   c                 C   rT   r"   r(   r^   r(   r(   r)   rA      rU   zstreams::synchronize_eventc                 C      t | }|  d S r"   )rK   synchronize)rW   r'   r(   r(   r)   synchronize_event      rc   c                 C   rT   r"   r(   )rW   r(   r(   r)   rA         zstreams::synchronize_devicedevice_typedevice_indexc                 C   s   t jt | | d S r"   )r#   rC   rb   r1   rf   rg   r(   r(   r)   synchronize_device   s   ri   c                 C   rT   r"   r(   rh   r(   r(   r)   rA      re   zstreams::synchronize_streamc                 C   ra   r"   )rI   rb   )rX   r/   r(   r(   r)   synchronize_stream   rd   rj   c                 C   rT   r"   r(   )rX   r(   r(   r)   rA      re   zstreams::wait_streamwaiting_stream_indexwaited_on_stream_indexc                 C   s   t | }t |}|| d S r"   )rI   wait_stream)rk   rl   waiting	waited_onr(   r(   r)   rm      r]   rm   c                 C   rT   r"   r(   r^   r(   r(   r)   rA      rU   zstreams::sync_deallocwait_event_indexsrc_stream_index
to_deallocc                 C   s   t jjj| | dS )a  An op which waits on an event and moves the last usage of to_dealloc
    after the wait, so that after the sync occurs, the deallocation or
    subsequent reuse of the tensor's memory will be guaranteed to happen
    after a side stream is finished using it.
    See https://docs.pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html#torch.Tensor.record_stream
    for more detailsN)r#   opsstreamsr`   defaultrp   rq   rr   r(   r(   r)   sync_dealloc   s   
rw   c                 C   rT   r"   r(   rv   r(   r(   r)   rA      s   zstreams::record_streamtensorc                 C   s   |  t| d S r"   )record_streamrI   rx   rX   r(   r(   r)   ry      s   ry   c                 C   rT   r"   r(   rz   r(   r(   r)   rA      rU   c                   @   sd   e Zd ZdZdddZddd	Zdd
dZddejdB ddfddZ	de
fddZdefddZdS )SymbolicStreamStatez)Track the currently entered stream if anyr    Nc                 C   s   ddl m} g }tj r:t  tj }||j}t||}|t	ks-J dt	 d| t
j||d}||_|g}t|| _d S )Nr	   r   z+Current stream must be registered at index z, got )source)r|   r   r#   rC   is_availabler   rD   r1   r   r   r   createuser_object_indexcollectionsdequecur_stream_stack)selfr   	cur_stackr/   r|   rF   
stream_varr(   r(   r)   __init__
  s$   





zSymbolicStreamState.__init__r/   r-   c                 C   s   | j | d S r"   )r   appendr   r/   r(   r(   r)   enter_stream&  s   z SymbolicStreamState.enter_streamc                 C   s   | j   d S r"   )r   popr   r(   r(   r)   exit_stream)     zSymbolicStreamState.exit_streamr1   c                 C   s4   |d urt | jD ]}|j|kr|  S q	| jd S )N)reversedr   r1   )r   r1   r/   r(   r(   r)   
cur_stream,  s   

zSymbolicStreamState.cur_streamc                 C   s   t | jdkS )Nr   )lenr   r   r(   r(   r)   in_stream_context4  r   z%SymbolicStreamState.in_stream_contextc                 C   s2   | j d }t|tr| st| S t|jS )zOGet a Python object id for the current stream without realizing lazy variables.r   )r   rH   r   is_realizedid
peek_valuevaluer   r(   r(   r)   cur_stream_id7  s   

z!SymbolicStreamState.cur_stream_id)r    N)r/   r-   r    Nr"   )r:   
__module____qualname____doc__r   r   r   r#   r1   r   boolr   intr   r(   r(   r(   r)   r{     s    


r{   c                	       s   e Zd ZdZedddddeeef dd fdd	Zd
e	d deddf fddZ
dddedef fddZdddedef fddZdefddZdefddZdddZ  ZS )StreamContextVariablez(This represents torch.cuda.StreamContexttxr   stream_to_enterr-   r   r    c                 K   s   t |fi |S r"   )r   )r   r   r   r(   r(   r)   r~   B  s
   zStreamContextVariable.creater/   Nc                    s,   || _ t jdd|  jid d| d S )Nr/   )target_valuesinitial_valuesr(   )r/   superr   
get_streamr   )r   r/   r   	__class__r(   r)   r   M  s   
zStreamContextVariable.__init__r   c                    s   |j |   t |S r"   )symbolic_stream_stater   r   r   enterr   r   r   r   r(   r)   r   U  s   zStreamContextVariable.enterc                    s   |j   t j|g|R  S r"   )r   r   r   exitr   r   r(   r)   r   ]  s   
zStreamContextVariable.exitc                 C      t jjS r"   )r#   cudaStreamContextr   r(   r(   r)   python_typee     z!StreamContextVariable.python_typec                 C      dS )NTr(   r   r(   r(   r)   supports_graph_breaksh     z+StreamContextVariable.supports_graph_breaksc                 C   s   | j sJ d| j S )Nz,Stream context should have a separate stream)r/   r   r(   r(   r)   r   k  s   z StreamContextVariable.get_streamr    r-   )r:   r   r   r   staticmethoddictstrr   r~   r   r   r   r   r   typer   r   r   r   __classcell__r(   r(   r   r)   r   ?  s<    

r   c                       s   e Zd ZdZejZ	d"dedejdedB de	ddf
 fdd	Z
defd
dZdefddZdddedee deeef def
 fddZdefddZdefddZdefddZd#ddZd$ddZedededeedgdf fd d!Z  ZS )%r-   z1Represents the device-agnostic torch.Stream classNproxyr   r   r   r    c                    sX   |d urd|j jv r|j jd |ksJ || _|| _|j| _|| _t jdi | d S )Nexample_valuer"   )nodemetar   r   r1   r   r   r   r   r   r   r   r   r   r(   r)   r   u  s   	zStreamVariable.__init__c                 C      t jS r"   )r#   r,   r   r(   r(   r)   r        zStreamVariable.python_typec                 C      | j S r"   r   r   r(   r(   r)   get_real_python_backed_value  r   z+StreamVariable.get_real_python_backed_valuer   r   namer   c                    s  t | j|sJ d| ddlm}m} ddlm} |dkr?|d }t|ts*J |j	
dtjjj|j| jfi  td S |d	krc|d }	t|	tsNJ |j	
dtjjj| j|	jfi  td S |d
krz|j	
dtjjj| jfi  td S |dkr|t||j	j
d|g|| g| |R  dS |dkrddlm}
 |j	t| j |rt|d tr|d }|j}|j}n| j }t|ttg ti }|j	
dtjjj|| jfi  |
||j	
dt|fi dS ||v r@t|dkr@|s@ddlm }m!} | j"r|| j"#|j$ |d }t|tst%&|t'S |j"r3| j"d us*J || j"#|j$ t%&||| | j|jS t( )||||S )Nzno stream method found named r	   )cmp_name_to_op_mappingproxy_args_kwargsr   wrap_fx_proxy_clsr`   r   call_functionrm   rb   querycall_method
target_clsr   r   r\   )wrap_fx_proxy)r   r   GuardBuilderinstall_guard)*hasattrr   utilsr   r   builderr   rH   r%   outputcreate_proxyr#   rs   rt   r`   r   r   r~   r-   rm   rj   r   'check_event_record_after_input_mutationr   r\   r   r&   r   r   r   r   guardsr   r   r|   
make_guardEQUALS_MATCHr   buildNotImplementedr   r   )r   r   r   r   r   r   r   r   	event_argother_streamr   	event_varr'   rW   r   r   otherr   r(   r)   r     s   





	zStreamVariable.call_methodc                 C   r   r"   r   r   r(   r(   r)   as_proxy  r   zStreamVariable.as_proxyc                 C   r   )Nztorch._Cr(   r   r(   r(   r)   module_name  r   zStreamVariable.module_namec                 C   r   )Nr,   r(   r   r(   r(   r)   fn_name  r   zStreamVariable.fn_namecodegenr   c                    s   | j rJ | jd ur&  fdd   | j  tdd d S d| j } jj	
|| j}  j|dd d S )Nc                      r3   )Nr   r6   r(   r   r(   r)   r<     r=   z,StreamVariable.reconstruct.<locals>.<lambda>r   F_stream_Tadd)r|   r   r>   append_outputcreate_load_constr?   r   r1   r   r   install_global_by_idr   create_load_globalr   r   prefixr   r(   r   r)   reconstruct  s   


zStreamVariable.reconstructc                 C   s   | S r"   r(   r   r(   r(   r)   r     r   zStreamVariable.get_streamc                        dt dddd f fdd}|S )NrF   r   r   r    c                    X      fdd    fdd      tdd  tdd d S )Nc                      r3   r4   r6   r(   r   r(   r)   r<     r=   zNStreamVariable.make_construct_in_graph_stream_fn.<locals>.fn.<locals>.<lambda>c                      r3   )Nbuild_streamr7   r#   r8   r   r:   r(   r   r(   r)   r<         
r	   Fr   r>   r?   r   rF   r   r   r   r   r)   fn     

z<StreamVariable.make_construct_in_graph_stream_fn.<locals>.fnr   r   r   r   r(   r   r)   r.        z0StreamVariable.make_construct_in_graph_stream_fnr"   r   r   r    Nr   )r:   r   r   r   r#   r,   _cpython_typer   r   r   r   r   r   objectr   r   listr   r   r   r   r   r   r   r   r   r   r   r   r.   r   r(   r(   r   r)   r-   p  sR    
e

r-   c                       sD   e Zd ZdZejjZdefddZ	ddde
ddf fd	d
Z  ZS )CudaStreamVariablezMRepresents torch.cuda.Stream, preserving device-specific type and attributes.r    c                 C   r   r"   )r#   r   r,   r   r(   r(   r)   r   /  r   zCudaStreamVariable.python_typer   r   r   r   c                    s~   ddl m} |dkr8ddlm}m} | jr|| j|j t| j	dr+|
| j	jS t| j	dr8|
| j	jS t ||S )Nr   r   cuda_streamr	   r   native_handle) r   r   r   r   r|   r   r   r   r   r~   r  r  r   var_getattr)r   r   r   r   r   r   r   r(   r)   r  2  s   zCudaStreamVariable.var_getattr)r:   r   r   r   r#   r   r,   r   r   r   r   r  r   r(   r(   r   r)   r   *  s
    "r   c                       s   e Zd ZdedejdedB deddf
 fddZde	fd	d
Z
defddZdddedee deeef def
ddZdefddZedddee deeef dedef fddZedededeedgdf fddZdddZ  ZS )r%   r   r   r   Nr   r    c                    sP   |d urd|j jv r|j jd |ksJ t jdi | || _|| _|| _d S )Nr   r(   )r   r   r   r   r   r   r   r   r   r(   r)   r   E  s   
zEventVariable.__init__c                 C   r   r"   )r#   r$   r   r(   r(   r)   r   S  r   zEventVariable.python_typec                 C   r   r"   r   r   r(   r(   r)   r   V  r   z*EventVariable.get_real_python_backed_valuer   r   r   r   c                 C   sZ  ddl m} ddlm} |dkr-t|||\}}|jdtj	j
j| j|fi  td S |dkrWt|||\}	}|jt|	j |jdtj	j
j| j|fi  td S |dkrn|jdtj	j
j| jfi  td S |d	kr|t||jjd
|g|| g| |R  dS t| jj dt| jj d| }
tdt|d|
 dg tjd d S )Nr	   )r   r   r   r_   r   rZ   rb   r   r   r   .zUnsupported event methodz#Dynamo doesn't support tracing the zC method. We currently support wait, record, synchronize, and query.)gb_typecontextexplanationhints)r   r   r   r   r%   _get_stream_argr   r   r#   rs   rt   r`   r   r   r~   r   r   r   r\   rc   r   r   r   r   r   r
   SUPPORTABLE)r   r   r   r   r   r   r   rA   rX   
stream_argmethod_namer(   r(   r)   r   Y  sj   
	
	
"	

zEventVariable.call_methodc                 C   r   r"   r   r   r(   r(   r)   r     r   zEventVariable.as_proxyr-   c                 C   sB   d}|r	|d }n|r| d}|s| j }||jfS ||jfS )a;  Returns (stream_variable, stream_index_for_op).

        The ambient current stream is registered at index 0 in the external
        object registry.  The inductor wrapper updates index 0 at runtime so
        that cudagraph capture sees the capture stream, not the stale
        trace-time default stream.
        Nr   r/   )getr   r   r   )r   r   r   r  r   r(   r(   r)   r
    s   




zEventVariable._get_stream_argr   c                    r   )NrF   r   r   r    c                    r   )Nc                      r3   r4   r6   r(   r   r(   r)   r<     r=   zLEventVariable.make_construct_in_graph_event_fn.<locals>.fn.<locals>.<lambda>c                      r3   )Nbuild_eventr   r(   r   r(   r)   r<     r   r	   Fr   r   r   r   r   r)   r     r   z:EventVariable.make_construct_in_graph_event_fn.<locals>.fnr   r   r(   r   r)   r&     r   z.EventVariable.make_construct_in_graph_event_fnr   c                 C   s8   | j rJ d}|jj|| j}||j|dd d S )N_eventTr   )r|   r   r   r   r   r   r   r   r(   r(   r)   r     s   
zEventVariable.reconstructr   )r:   r   r   r   r#   r$   r   r   r   r   r   r   r   r   r   r   r   r   r   r   tupler
  r   r   r   r&   r   r   r(   r(   r   r)   r%   D  sZ    

A

r%   )Nr   collections.abcr   typingr   r   r#   torch._dynamo.variables.dictsr   torch._dynamo.variables.listsr   torch.fxr   r   r  r
   bytecode_transformationr   excr   r   r9   r   r   r   r   r   r|   r   baser   constantr   ctx_managerr   lazyr   torch._dynamo.symbolic_convertr   r   r   torch._library.custom_opsr   Tensorr   r*   r  r0   r1   r@   rE   r,   rI   r$   rK   rS   register_fakerA   rs   rt   forkru   rV   joinr\   r`   rc   r   ri   rj   rm   rw   ry   r{   r   r-   r   r%   r(   r(   r(   r)   <module>   s   











81 ;