o
    m9:jT                     @  s:  U d Z ddlmZ ddlZddlZddlmZ ddlmZm	Z	m
Z
 ddlZddlZddlmZmZ e
r?ddlmZ ddlmZ d	d
lmZ d	dlmZmZmZmZmZ d	dlmZmZ dae  Z!e" Z#g a$de%d< e&ej'ddd Z(e a)dZ*de%d< dqddZ+e+ rej'j,Z,ej'j-Z.ej'j/Z0nedZ,drddZ.drdd Z0ed	d!dsd"d#Z1dqd$d%Z2dtdud(d)Z3dqd*d+Z4d,d- Z5dvd/d0Z6dvd1d2Z7dvd3d4Z8G d5d6 d6Z9G d7d dZ:G d8d9 d9e:Z;dwd;d<Z<dxdyd>d?Z=eddxdzdAdBZ>	dxd{dCdDZ?dsdEdFZ@d|dIdJZAd}dLdMZBG dNdO dOZCd~dRdPZDdvdSdTZEddVdWZFdxddXdYZGdxdd[d\ZHdxdwd]d^ZIdd`daZJddbdcZKddedfZL	gdddidjZMdddkdlZNd	dmlOmPZPmQZQmRZRmSZSmTZTmUZUmVZVmWZWmXZXmYZYmZZZm[Z[m\Z\m]Z]m^Z^m_Z_m`Z` d	dnlambZbmcZcmdZdmeZemfZfmgZgmhZhmiZimjZj e	doekelelf Zmg dpZndS )z
This package introduces support for the XPU backend, specifically tailored for
Intel GPU optimization.

This package is lazily initialized, so you can always import it, and use
:func:`is_available()` to determine if your system supports XPU.
    )annotationsN)	lru_cache)AnyNewTypeTYPE_CHECKING)_dummy_type_LazySeedTracker)Callable)Device   )_get_device_index)graphgraph_pool_handleis_current_stream_capturingmake_graphed_callablesXPUGraph)EventStreamFz*list[tuple[Callable[[], None], list[str]]]_queued_calls_xpu_isInBadForkc                   C  s   dS NF r   r   r   Y/home/nk/hobo-godmode/plappi-mvp/.venv/lib/python3.10/site-packages/torch/xpu/__init__.py<lambda>,   s    r   r   ztuple[torch._C.Generator]default_generatorsreturnboolc                   C  s   t jjS )z(Return true if compile with XPU support.)torch_C_has_xpur   r   r   r   _is_compiled1   s   r    _XpuDevicePropertiesdeviceintc                 C     t dNz(PyTorch was compiled without XPU supportNotImplementedErrorr"   r   r   r   _exchange_device>      r)   c                 C  r$   r%   r&   r(   r   r   r   _maybe_exchange_deviceA   r*   r+   )maxsizec                   C  s   t  sdS tj S )z*Return the number of XPU device available.r   )r    r   r   _xpu_getDeviceCountr   r   r   r   device_countE   s   
r.   c                   C  s
   t  dkS )z7Return a bool indicating if XPU is currently available.r   )r.   r   r   r   r   is_availableM   s   
r/   Tincluding_emulationc                 C  s   t  sdS | ptj jS )zKReturn a bool indicating if the current XPU device supports dtype bfloat16.F)r/   r   xpuget_device_propertieshas_bfloat16_conversions)r0   r   r   r   is_bf16_supportedS   s
   
r4   c                   C  s   t  sdS tj jS )zGReturn a bool indicating if the current XPU device supports dtype tf32.F)r/   r   r1   r2   'has_subgroup_matrix_multiply_accumulater   r   r   r   is_tf32_supported]   s   r6   c                   C  s   t ot  S )z8Return whether PyTorch's XPU state has been initialized.)_initialized_is_in_bad_forkr   r   r   r   is_initializedh   s   r9   Nonec                 K  sf   t  r|   d S |ddrt| t  d S |ddr(t| t  d S t| t f d S )Nseed_allFseed)	r9   get_lazy_seed_trackerqueue_seed_all	tracebackformat_stack
queue_seedr   append)callablekwargsr   r   r   
_lazy_callm   s   
rF   c                   C  s
   t   dS )zInitialize PyTorch's XPU state.
    This is a Python API about lazy initialization that avoids initializing
    XPU until the first time it is accessed. Does nothing if the XPU state is
    already initialized.
    N)
_lazy_initr   r   r   r   init{   s   
rH   c                  C  s  t  sttdr
d S tq t  r	 W d    d S t r tdt s'tdtj	
  dt_tdd t D  z1tD ]'\} }z|   W q> tye } zdt| dd	| }t||d }~ww W ttd nttd w daW d    d S 1 sw   Y  d S )
Nis_initializingzuCannot re-initialize XPU in forked subprocess. To use XPU with multiprocessing, you must use the 'spawn' start methodz#Torch not compiled with XPU enabledTc                 s  s    | ]}|r|V  qd S Nr   ).0callsr   r   r   	<genexpr>   s    z_lazy_init.<locals>.<genexpr>z5XPU call failed lazily at initialization with error: z'

XPU call was originally invoked at:

 )r9   hasattr_tls_initialization_lockr8   RuntimeErrorr    AssertionErrorr   r   	_xpu_initrI   r   extendr>   	get_calls	Exceptionstrjoindelattrr7   )queued_callorig_tracebackemsgr   r   r   rG      s>   



"rG   c                   @  s(   e Zd ZdddZdd ZdddZdS )_DeviceGuardindexr#   r   r:   c                 C  s   || _ d| _d S N)idxprev_idx)selfr`   r   r   r   __init__   s   
z_DeviceGuard.__init__c                 C     t j| j| _d S rJ   r   r1   r)   rc   rd   re   r   r   r   	__enter__      z_DeviceGuard.__enter__typer   valuer@   c                 C     t j| j| _dS r   r   r1   r+   rd   rc   re   rl   rm   r@   r   r   r   __exit__      z_DeviceGuard.__exit__N)r`   r#   r   r:   rl   r   rm   r   r@   r   )__name__
__module____qualname__rf   rj   rq   r   r   r   r   r_      s    
r_   c                   @  s,   e Zd ZdZdddZdd ZdddZdS )r"   zContext-manager that changes the selected device.

    Args:
        device (torch.device or int or str): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    r   r   r:   c                 C  s   t |dd| _d| _d S )NToptionalrb   )r   rc   rd   )re   r"   r   r   r   rf      s   
zdevice.__init__c                 C  rg   rJ   rh   ri   r   r   r   rj      rk   zdevice.__enter__rl   rm   r@   c                 C  rn   r   ro   rp   r   r   r   rq      rr   zdevice.__exit__N)r"   r   r   r:   rs   )rt   ru   rv   __doc__rf   rj   rq   r   r   r   r   r"      s
    
c                      s"   e Zd ZdZd fddZ  ZS )	device_ofa  Context-manager that changes the current device to that of given object.

    You can use both tensors and storages as arguments. If a given object is
    not allocated on a XPU, this is a no-op.

    Args:
        obj (Tensor or Storage): object allocated on the selected device.
    r   r:   c                   s"   |j r| nd}t | d S ra   )is_xpu
get_devicesuperrf   )re   objrc   	__class__r   r   rf      s   zdevice_of.__init__r   r:   )rt   ru   rv   ry   rf   __classcell__r   r   r   r   rz      s    	rz   r
   c                 C  s*   t   t| } | dkrtj|  dS dS )zSet the current device.

    Args:
        device (torch.device or int or str): selected device. This function is a
            no-op if this argument is negative.
    r   N)rG   r   r   r   _xpu_setDevicer(   r   r   r   
set_device   s
   r   rX   c                 C  s
   t | jS )a  Get the name of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the name. This function is a no-op if this argument is a
            negative integer. It uses the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).

    Returns:
        str: the name of the device
    )r2   namer(   r   r   r   get_device_name   s   
r   dict[str, Any]c                   s:   t |  tttttdtttf fddt	 D S )a  Get the xpu capability of a device.

    Args:
        device (torch.device or int or str, optional): device for which to
            return the device capability. This function is a no-op if this
            argument is a negative integer. It uses the current device, given by
            :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).

    Returns:
        dict[str, Any]: the xpu capability dictionary of the device
    Nc                   s0   i | ]}| d stt | r|qS )__)
startswith
isinstancegetattr)rK   keypropsserializable_typesrm   r   r   
<dictcomp>	  s    z)get_device_capability.<locals>.<dictcomp>)
r2   r#   floatr   rX   rl   listtupledictdirr(   r   r   r   get_device_capability   s
   r   c                 C  s   t   t| dd} t| S )a  Get the properties of a device. Returns _XpuDeviceProperties containing the following device properties:

    - ``name`` (str): device name.
    - ``platform_name`` (str): SYCL platform name.
    - ``vendor`` (str): device vendor.
    - ``device_id`` (int): device identifier (product ID).
    - ``driver_version`` (str): driver version.
    - ``version`` (str): runtime version.
    - ``max_compute_units`` (int): number of parallel compute units.
    - ``gpu_eu_count`` (int): number of EUs (Execution Unit).
    - ``max_work_group_size``: (int): maximum number of work-items permitted in a work-group.
    - ``max_num_sub_groups`` (int): maximum number of sub-groups supported in a work-group.
    - ``memory_clock_rate`` (int) maximum clock rate of device's global memory in MHz.
    - ``memory_bus_width`` (int) maximum bus width between device and memory in bits.
    - ``sub_group_sizes``: (list[int]): a list of supported sub-group sizes.
    - ``local_mem_size`` (int): device local memory capacity that can be allocated per work-group in bytes.
    - ``has_fp16`` (bool): whether float16 dtype is supported.
    - ``has_fp64`` (bool): whether float64 dtype is supported.
    - ``has_atomic64`` (bool): whether 64-bit atomic operations are supported.
    - ``has_bfloat16_conversions`` (bool): whether bfloat16 conversions are supported.
    - ``has_subgroup_matrix_multiply_accumulate`` (bool): whether DPAS (Dot Product Accumulate Systolic) is supported.
    - ``has_subgroup_matrix_multiply_accumulate_tensor_float32`` (bool): whether DPAS with tf32 inputs is supported.
    - ``has_subgroup_2d_block_io`` (bool): whether 2D block I/O for efficient matrix multiplication is supported.
    - ``total_memory`` (int): device global memory in bytes.
    - ``gpu_subslice_count`` (int): number of subslice.
    - ``architecture`` (int): device architecture identifier (experimental).
    - ``type`` (str): device type, e.g. 'cpu', 'gpu', accelerator', 'host', 'unknown'.
    - ``uuid`` (Any): device UUID (Universal Unique ID), 16 bytes.

    Args:
        device (torch.device or int or str): device for which to return the
            properties of the device.

    Returns:
        _XpuDeviceProperties: the properties of the device
    Trw   )rG   r   _get_device_propertiesr(   r   r   r   r2     s   'r2   c                   C  s   t   tj S )z0Return the index of a currently selected device.)rG   r   r   _xpu_getDevicer   r   r   r   current_device=  s   
r   int | str | torch.devicetorch.devicec                 C  s2   t | trt| } | S t | trtd| } | S )zReturn the torch.device type object from the passed in device.

    Args:
        device (torch.device or int or str): selected device.
    r1   )r   rX   r   r"   r#   r(   r   r   r   _get_deviceC  s   


r   peerc                 C  s,   t   t| dd} t|dd}tj| |S )a/  Query whether a device can access a peer device's memory.

    Args:
        device (torch.device or int or str): selected device.
        peer (torch.device or int or str): peer device to query access to.

    Returns:
        bool: ``True`` if ``device`` can access ``peer``, ``False`` otherwise.
    Trw   )rG   r   r   r   _xpu_canDeviceAccessPeer)r"   r   r   r   r   can_device_access_peerP  s   
r   c                   @  s6   e Zd ZU dZded< dddZd	d
 ZdddZdS )StreamContexta  Context-manager that selects a given stream.

    All XPU kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        Stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device.
    torch.xpu.Stream | None
cur_streamstreamr   r:   c                 C  s*   || _ td d| _| jd u rd| _d S d S )NTrb   )r   r   rc   )re   r   r   r   r   rf   n  s
   

zStreamContext.__init__c                 C  s   | j }|d u s| jdkrd S tjd | _| jj|jkr9t|j tj|j| _W d    n1 s4w   Y  tj| d S ra   )	r   rc   r   r1   current_streamsrc_prev_streamr"   dst_prev_stream
set_stream)re   r   r   r   r   rj   t  s   zStreamContext.__enter__rl   r   rm   r@   c                 C  sJ   | j }|d u s| jdkrd S | jj|jkrtj| j tj| j d S ra   )r   rc   r   r"   r   r1   r   r   )re   rl   rm   r@   r   r   r   r   rq     s   zStreamContext.__exit__N)r   r   r   r:   rs   )rt   ru   rv   ry   __annotations__rf   rj   rq   r   r   r   r   r   `  s   
 
r   r   r   c                 C  s   t | S )zWrap around the Context-manager StreamContext that selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's ``None``.
    )r   r   r   r   r   r     s   c                 C  s   t jj| ||d dS )a  set stream specified by the stream id, device index and device type

    Args: stream_id (int): not visible to the user, used to assigned to the specific stream.
          device_index (int): selected device index.
          device_type (int): selected device type.
    	stream_iddevice_indexdevice_typeN)r   r   _xpu_setStreamr   r   r   r   _set_stream_by_id  s
   
r   r   c                 C  s*   | du rdS t   t| j| j| jd dS )a  Set the current stream. This is a wrapper API to set the stream.
        Usage of this function is discouraged in favor of the ``stream``
        context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    Nr   )rG   r   r   r   r   r   r   r   r   r     s   	
r   c                 C  s4   t   tjt| dd}t|d |d |d dS )aR  Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
            (default).
    Trw   r   r      r   )rG   r   r   _xpu_getCurrentStreamr   r   )r"   
streamdatar   r   r   r     s   	
r   data_ptrc                 C  s6   t   tj| t|dd}t|d |d |d dS )a;  Return a :class:`Stream` from an external SYCL queue.

    This function is used to wrap SYCL queue created in other libraries in order
    to facilitate data exchange and multi-library interactions.

    .. note:: This function doesn't manage the queue life-cycle, it is the user
       responsibility to keep the referenced queue alive while this returned stream is
       being used. The different SYCL queue pointers will result in distinct
       :class:`Stream` objects, even if the SYCL queues they dereference are equivalent.

    Args:
        data_ptr(int): Integer representation of the `sycl::queue*` value passed externally.
        device(torch.device or int, optional): the device where the queue was originally created.
            It is the user responsibility to ensure the device is specified correctly.
    Trw   r   r   r   r   )rG   r   r   _xpu_getStreamFromExternalr   r   )r   r"   r   r   r   r   get_stream_from_external  s   r   c                 C  s   t   t| dd} tj| S )a*  Wait for all kernels in all streams on a XPU device to complete.

    Args:
        device (torch.device or int, optional): device for which to synchronize.
            It uses the current device, given by :func:`~torch.xpu.current_device`,
            if :attr:`device` is ``None`` (default).
    Trw   )rG   r   r   r   _xpu_synchronizer(   r   r   r   synchronize  s   r   	list[str]c                  C  s(   t  sg S tj } | du rg S |  S )z<Return list XPU architectures this library was compiled for.N)r    r   r   _xpu_getArchFlagssplit)
arch_flagsr   r   r   get_arch_list  s   
r   c                  C  s0   t  } t| dkrdS dddd | D  S )zIReturn XPU AOT(ahead-of-time) build flags this library was compiled with.r   rN   z-device ,c                 s  s    | ]}|V  qd S rJ   r   )rK   archr   r   r   rM     s    z$get_gencode_flags.<locals>.<genexpr>)r   lenrY   )	arch_listr   r   r   get_gencode_flags  s   r   torch._C.Generatorc                 C  s    | j }|du r
t }tjj| S )zuReturn the XPU Generator object for the given device.

    Args:
        device (torch.device): selected device.
    N)r`   r   r   r1   r   )r"   rc   r   r   r   _get_generator  s   r   r1   offsetc                   s$   t | d fdd}t| dS )a$  Set the random number generator state offset of the specified GPU.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    r   r:   c                    s   t  } |  d S rJ   )r   
set_offset)default_generatorfinal_devicer   r   r   cb  s   z!_set_rng_state_offset.<locals>.cbNr   )r   rF   )r   r"   r   r   r   r   _set_rng_state_offset  s   
r   c                 C  s   t   t| }t|}| S )aL  Return the random number generator state offset of the specified GPU.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    )rG   r   r   
get_offset)r"   r   r   r   r   r   _get_rng_state_offset  s   
r   )change_current_allocatorempty_cacheget_per_process_memory_fractionmax_memory_allocatedmax_memory_reservedmem_get_infomemory_allocatedmemory_reservedmemory_snapshotmemory_statsmemory_stats_as_nested_dictMemPoolreset_accumulated_memory_statsreset_peak_memory_statsset_per_process_memory_fractionuse_mem_poolXPUPluggableAllocator)	get_rng_stateget_rng_state_allinitial_seedmanual_seedmanual_seed_allr<   r;   set_rng_stateset_rng_state_all_POOL_HANDLE)9r   r   r   r   r   r   r   r   r   r   r"   rz   r.   r   r   r   r   r2   r   r   r   r   r   r   r   rH   r   r/   r4   r   r9   r6   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r<   r;   r   r   r   r   r   r   streamsr   )r   r   )r"   r#   r   r#   )r   r#   )T)r0   r   r   r   r   )r"   r
   r   r:   rJ   )r"   r
   r   rX   )r"   r
   r   r   )r"   r
   r   r!   )r"   r   r   r   )r"   r
   r   r
   r   r   )r   r   r   r   )r   r   r   r:   )r"   r
   r   r   )r   r#   r"   r
   r   r   )r   r   )r   rX   )r"   r   r   r   )r1   )r   r#   r"   r   r   r:   )r"   r   r   r#   )ory   
__future__r   	threadingr@   	functoolsr   typingr   r   r   r   torch._Ctorch._utilsr   r   collections.abcr	   torch.typesr
   _utilsr   graphsr   r   r   r   r   r   r   r   r7   localrP   LockrQ   r   r   r   r   r8   r>   r   r    r!   _xpu_exchangeDevicer)   _xpu_maybeExchangeDevicer+   r.   r/   r4   r6   r9   rF   rH   rG   r_   r"   rz   r   r   r   r2   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   memoryr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   randomr   r   r   r   r   r<   r;   r   r   r   r#   r   __all__r   r   r   r   <module>   s   










	(

,


+
	




L,