o
    q::jE                     @  sh   d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m
Z
 				ddddZG dd deZdS )    )annotations)AnyN)Base)expectmeanx
np.ndarraytargetweightnp.ndarray | None	reductionstrignore_index
int | Noneget_log_probbool | Nonereturnr   c                 C  s  | j }t|dkrtd|j }|d }|d }	tj| ddd}
t| |
 }|tj|ddd }t|}d }|du rAt|}d }|d urhtj	|tj
|tjddd}|d urgt||kd|jtjd}n|d urzt||kddjtjd}t|d	kr|||	d
f}||d
f}|j d }tj||ftjd}t|D ]"}t|D ]}|| | |kr|| || |  |  || |< qq|}t|d	kr||}|d ur|| }|dkr| |  }|du r||fS |S |dkrt|}n
|dkrt|}|r||fS |S )N   zUnsupported shaper   T)axiskeepdimsdtypeclip)mode      r   sum)shapelenRuntimeErrornpmaxexpr   logcopytakearrayint32whereastypefloat32reshapezerosranger   )r   r	   r
   r   r   r   input_shapetarget_shapeNCmax_xexp_xpinplog_probgather_weightDneg_gather_element_inputidloss r>   v/home/nk/hobo-godmode/plappi-mvp/.venv/lib/python3.10/site-packages/onnx/backend/test/case/node/softmaxcrossentropy.pysoftmaxcrossentropy   sf   


"


r@   c                   @  s  e Zd ZedHddZedHddZedHddZedHd	d
ZedHddZedHddZ	edHddZ
edHddZedHddZedHddZedHddZedHddZedHddZedHddZedHdd ZedHd!d"ZedHd#d$ZedHd%d&ZedHd'd(ZedHd)d*ZedHd+d,ZedHd-d.ZedHd/d0ZedHd1d2ZedHd3d4ZedHd5d6ZedHd7d8ZedHd9d:ZedHd;d<Z edHd=d>Z!edHd?d@Z"edHdAdBZ#edHdCdDZ$edHdEdFZ%dGS )ISoftmaxCrossEntropyLossr   Nonec                  C     d} t jjdddgdg| d}tjd tjdd	tj}tjj	dd	d
dtj
}t||dd}t|||g|gdd d S )NnonerA   r   yzinputsoutputsr   r   r      r   highsizer   test_sce_nonerH   rI   nameonnxhelper	make_noder!   randomseedrandr*   r+   randintint64r@   r   r   noder   labelsscer>   r>   r?   export_softmaxcrossentropy_noned      z7SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_nonec                  C     d} t jjdddgddg| d}tjd tjd	d
tj}tjj	dd
ddtj
}t||ddd\}}t|||g||gdd d S )NrD   rA   r   rE   rF   r7   rG   r   r   rJ   rK   rL   Tr   r   test_sce_none_log_probrQ   rS   r   r]   r   r^   r=   r7   r>   r>   r?   (export_softmaxcrossentropy_none_log_prob|   &   

z@SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_none_log_probc                  C  s   d} t jjdg ddg| d}tjd tjddtj}tjj	ddd	d
tj
}tjg dtjd}t|||dd}t||||g|gdd d S )NrD   rA   r   rE   wrF   rG   r   r   rJ   rK   rL   ?gffffff?g?rk   rk   r   r
   r   test_sce_none_weightsrQ   rT   rU   rV   r!   rW   rX   rY   r*   r+   rZ   r[   r'   r@   r   r   r]   r   r^   weightsr_   r>   r>   r?   'export_softmaxcrossentropy_none_weights   s$   
z?SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_none_weightsc                  C  s   d} t jjdg dddg| d}tjd tjdd	tj}tjj	dd	d
dtj
}tjg dtjd}t|||ddd\}}t||||g||gdd d S )NrD   rA   rh   rF   r7   rG   r   r   rJ   rK   rL   rj   r   Tr
   r   r   test_sce_none_weights_log_probrQ   rn   r   r]   r   r^   rp   r=   r7   r>   r>   r?   0export_softmaxcrossentropy_none_weights_log_prob   s(   


zHSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_none_weights_log_probc                  C  rC   )Nr   rA   r   rE   rF   rG   r   r   rJ   rK   rL   rO   test_sce_sumrQ   rS   r\   r>   r>   r?   export_softmaxcrossentropy_sum   ra   z6SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_sumc                  C  rb   )Nr   rA   r   rE   rF   r7   rG   r   r   rJ   rK   rL   Trc   test_sce_sum_log_probrQ   rS   re   r>   r>   r?   'export_softmaxcrossentropy_sum_log_prob   rg   z?SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_sum_log_probc                  C  s~   d} t jjdddgdg| d}tjd tjdd	tj}tjj	dd	d
dtj
}t||}t|||g|gdd d S )Nr   rA   r   rE   rF   rG   r   r   rJ   rK   rL   test_sce_meanrQ   rS   r\   r>   r>   r?   export_softmaxcrossentropy_mean  s   
z7SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_meanc                  C  s   d} t jjdddgddg| d}tjd tjd	d
tj}tjj	dd
ddtj
}t||dd\}}t|||g||gdd d S )Nr   rA   r   rE   rF   r7   rG   r   r   rJ   rK   rL   Tr   test_sce_mean_log_probrQ   rS   re   r>   r>   r?   (export_softmaxcrossentropy_mean_log_prob(  s"   
z@SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_log_probc                  C  s   d} t jjdddgdg| d}tjd tjdd	d
tj}tjj	dd	ddtj
}t||}t|||g|gdd d S )Nr   rA   r   rE   rF   rG   r   r   rJ   r   r   r   rL   test_sce_mean_3drQ   rS   )r   r]   r   rE   r_   r>   r>   r?   "export_softmaxcrossentropy_mean_3dE  s   
z:SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_3dc                  C  s   d} t jjdddgddg| d}tjd tjd	d
dtj}tjj	dd
ddtj
}t||dd\}}t|||g||gdd d S )Nr   rA   r   rE   rF   r7   rG   r   r   rJ   r   r   rL   Tr|   test_sce_mean_3d_log_probrQ   rS   )r   r]   r   rE   r=   r7   r>   r>   r?   +export_softmaxcrossentropy_mean_3d_log_prob]  s"   
zCSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_3d_log_probc                  C  s   d} t jjdg ddg| d}tjd tjddtj}tjj	ddd	d
tj
}tjg dtjd}t|||d}t||||g|gdd d S )Nr   rA   rh   rF   rG   r   r   rJ   rK   rL   rj   r   )r
   test_sce_mean_weightrQ   rn   ro   r>   r>   r?   'export_softmaxcrossentropy_mean_weightsz  s$   
z?SoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weightsc                  C  s   d} t jjdg dddg| d}tjd tjdd	tj}tjj	dd	d
dtj
}tjg dtjd}t|||dd\}}t||||g||gdd d S )Nr   rA   rh   rF   r7   rG   r   r   rJ   rK   rL   rj   r   T)r
   r   test_sce_mean_weight_log_probrQ   rn   rt   r>   r>   r?   0export_softmaxcrossentropy_mean_weights_log_prob  s(   

zHSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_log_probc                  C  s   d} t d}tjjdg ddg| |d}t jd t jddt j	}t jj
ddd	d
t j}t d|d< t jg dt j	d}t||||d}t||||g|gdd d S )Nr   r   rA   rh   rF   rH   rI   r   r   r   rJ   rK   rL   rj   r   r
   r   test_sce_mean_weight_iirQ   r!   r[   rT   rU   rV   rW   rX   rY   r*   r+   rZ   r'   r@   r   r   r   r]   r   r^   rp   r_   r>   r>   r?   *export_softmaxcrossentropy_mean_weights_ii  s*   
	
zBSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_iic                  C  s   d} t d}tjjdg dddg| |d}t jd t jdd	t j	}t jj
dd	d
dt j}t d|d< t jg dt j	d}t||||dd\}}t||||g||gdd d S )Nr   r   rA   rh   rF   r7   r   r   rJ   rK   rL   rj   r   Tr
   r   r    test_sce_mean_weight_ii_log_probrQ   r   r   r   r]   r   r^   rp   r=   r7   r>   r>   r?   3export_softmaxcrossentropy_mean_weights_ii_log_prob  s.   
	


zKSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_ii_log_probc                  C  s   d} t d}tjjdddgdg| |d}t jd t jd	d
t j	}t jj
dd
ddt j}t d|d< t|||d}t|||g|gdd d S )Nr   r   rA   r   rE   rF   r   r   r   rJ   rK   rL   r   test_sce_mean_no_weight_iirQ   r!   r[   rT   rU   rV   rW   rX   rY   r*   r+   rZ   r@   r   r   r   r]   r   r^   r_   r>   r>   r?   -export_softmaxcrossentropy_mean_no_weights_ii  s"   
	
zESoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_iic                  C  s   d} t d}tjjdddgddg| |d}t jd	 t jd
dt j	}t jj
d	dddt j}t d|d	< t|||dd\}}t|||g||gdd d S )Nr   r   rA   r   rE   rF   r7   r   r   r   rJ   rK   rL   Tr   r   #test_sce_mean_no_weight_ii_log_probrQ   r   r   r   r]   r   r^   r=   r7   r>   r>   r?   6export_softmaxcrossentropy_mean_no_weights_ii_log_prob  s,   
	

zNSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_ii_log_probc                  C  s   d} t d}tjjdg ddg| |d}t jd t jdd	d
t j	}t jj
dd	ddt j}t d|d d< t jg dt j	d}t||||d}t||||g|gdd d S )Nr   r   rA   rh   rF   r   r   r   rJ   r   r   rL   g?g333333?g333333?g?g      ?r   r   test_sce_mean_weight_ii_3drQ   r   r   r>   r>   r?   -export_softmaxcrossentropy_mean_weights_ii_3d;  s*   
	
zESoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_ii_3dc                  C  s   d} t d}tjjdg dddg| |d}t jd t jd	d
dt j	}t jj
dd
ddt j}t d|d d< t jg dt j	d}t||||dd\}}t||||g||gdd d S )Nr   r   rA   rh   rF   r7   r   r   r   rJ   r   r   rL   r   r   Tr   #test_sce_mean_weight_ii_3d_log_probrQ   r   r   r>   r>   r?   6export_softmaxcrossentropy_mean_weights_ii_3d_log_prob\  s.   
	


zNSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_ii_3d_log_probc                  C  s   d} t d}tjjdddgdg| |d}t jd t jd	d
dt j	}t jj
dd
ddt j}t d|d d< t|||d}t|||g|gdd d S )Nr   r   rA   r   rE   rF   r   r   r   rJ   r   rL   r   test_sce_mean_no_weight_ii_3drQ   r   r   r>   r>   r?   0export_softmaxcrossentropy_mean_no_weights_ii_3d  s(   
	
zHSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_ii_3dc                  C  s   d} t d}tjjdddgddg| |d}t jd	 t jd
ddt j	}t jj
d	dddt j}t d|d	 d	< t|||dd\}}t|||g||gdd d S )Nr   r   rA   r   rE   rF   r7   r   r   r   rJ   r   rL   Tr   &test_sce_mean_no_weight_ii_3d_log_probrQ   r   r   r>   r>   r?   9export_softmaxcrossentropy_mean_no_weights_ii_3d_log_prob  s,   
	

zQSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_ii_3d_log_probc                  C  s   d} t d}tjjdg ddg| |d}t jd t jdd	dd
t j	}t jj
dd	ddt j}t d|d d d< t jg dt j	d}t||| ||d}t||||g|gdd d S )Nr   r   rA   rh   rF   r   r   r   rJ      r   r   r   rL   r   r   )r   r
   r   test_sce_mean_weight_ii_4drQ   r   r   r>   r>   r?   -export_softmaxcrossentropy_mean_weights_ii_4d  s.   
	

zESoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_ii_4dc                  C  s   d} t d}tjjdg dddg| |d}t jd t jd	d
ddt j	}t jj
dd
ddt j}t d|d d d< t jg dt j	d}t||| ||dd\}}t||||g||gdd d S )Nr   r   rA   rh   rF   r7   r   r   r   rJ   r   r   rL   r   r   T)r   r
   r   r   #test_sce_mean_weight_ii_4d_log_probrQ   r   r   r>   r>   r?   6export_softmaxcrossentropy_mean_weights_ii_4d_log_prob  s8   
	


zNSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_weights_ii_4d_log_probc                  C  s   d} t d}tjjdddgdg| |d}t jd t jd	d
ddt j	}t jj
dd
ddt j}t d|d d d< t||| |d}t|||g|gdd d S )Nr   r   rA   r   rE   rF   r   r   r   rJ   r   r   rL   r   r   test_sce_mean_no_weight_ii_4drQ   r   r   r>   r>   r?   0export_softmaxcrossentropy_mean_no_weights_ii_4d  s,   
	
zHSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_ii_4dc                  C  s   d} t d}tjjdddgddg| |d}t jd	 t jd
dddt j	}t jj
d	dddt j}t d|d	 d	 d	< t||| |dd\}}t|||g||gdd d S )Nr   r   rA   r   rE   rF   r7   r   r   r   rJ   r   r   rL   Tr   r   r   &test_sce_mean_no_weight_ii_4d_log_probrQ   r   r   r>   r>   r?   9export_softmaxcrossentropy_mean_no_weights_ii_4d_log_prob.  s,   
	


zQSoftmaxCrossEntropyLoss.export_softmaxcrossentropy_mean_no_weights_ii_4d_log_probc               	   C  s   d} t jjdg ddg| d}d\}}}}}}}tjd tj|||||||tj}	tjj	d|||||||fdtj
}
tj|tj}t|	|
|| d	}t||	|
|g|gd
d d S )Nr   rA   rh   rF   rG   r   rJ      r   rJ   r      r   rL   rl   !test_sce_NCd1d2d3d4d5_mean_weightrQ   rS   )r   r]   r1   r2   dim1dim2dim3dim4dim5r   r^   r
   r_   r>   r>   r?   .export_input_shape_is_NCd1d2d3d4d5_mean_weightP  s.    
zFSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3d4d5_mean_weightc               	   C  s   d} t jjdg dddg| d}d\}}}}}}}tjd tj|||||||tj}	tjj	d|||||||fd	tj
}
tj|tj}t|	|
|| d
d\}}t||	|
|g||gdd d S )Nr   rA   rh   rF   r7   rG   r   r   rL   Trr   *test_sce_NCd1d2d3d4d5_mean_weight_log_probrQ   rS   )r   r]   r1   r2   r   r   r   r   r   r   r^   r
   r=   r7   r>   r>   r?   7export_input_shape_is_NCd1d2d3d4d5_mean_weight_log_probl  s2    


zOSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3d4d5_mean_weight_log_probc               	   C  s   d} t jjdddgdg| d}d\}}}}}}}tjd tj|||||||tj}	tjj	d|||||||fd	tj
}
t|	|
| d
}t||	|
g|gdd d S )NrD   rA   r   rE   rF   rG   r   r   rL   rO   $test_sce_NCd1d2d3d4d5_none_no_weightrQ   rS   )r   r]   r1   r2   r   r   r   r   r   r   r^   r_   r>   r>   r?   1export_input_shape_is_NCd1d2d3d4d5_none_no_weight  s,    
zISoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3d4d5_none_no_weightc               	   C  s   d} t jjdddgddg| d}d\}}}}}}}tjd	 tj|||||||tj}	tjj	d	|||||||fd
tj
}
t|	|
| dd\}}t||	|
g||gdd d S )NrD   rA   r   rE   rF   r7   rG   r   r   rL   Trc   -test_sce_NCd1d2d3d4d5_none_no_weight_log_probrQ   rS   )r   r]   r1   r2   r   r   r   r   r   r   r^   r=   r7   r>   r>   r?   :export_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob  s0    

zRSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_probc            
      C  s   d} t d}tjjdg ddg| |d}d\}}}t jd t j|||t j	}t jj
d|||fd	t j}d|d d< t j|t j	}t|||| |d
}	t||||g|	gdd d S )Nr   r   rA   rh   rF   r   r   rJ   r   r   rL   r
   r   r   %test_sce_NCd1_mean_weight_negative_iirQ   r   )
r   r   r]   r1   r2   r   r   r^   r
   r_   r>   r>   r?   2export_input_shape_is_NCd1_mean_weight_negative_ii  s0   



zJSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1_mean_weight_negative_iic                  C  s   d} t d}tjjdg dddg| |d}d\}}}t jd	 t j|||t j	}t jj
d	|||fd
t j}d|d	 d	< t j|t j	}t|||| |dd\}	}
t||||g|	|
gdd d S )Nr   r   rA   rh   rF   r7   r   r   r   rL   Tr
   r   r   r   .test_sce_NCd1_mean_weight_negative_ii_log_probrQ   r   )r   r   r]   r1   r2   r   r   r^   r
   r=   r7   r>   r>   r?   ;export_input_shape_is_NCd1_mean_weight_negative_ii_log_prob  s:   


	
zSSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1_mean_weight_negative_ii_log_probc                  C  s   d} t d}tjjdddgdg| |d}d\}}}}}t jd	 t j|||||t j	}t jj
d	|||||fd
t j}	d|	d	 d	 d	 d	< t||	| |d}
t|||	g|
gdd d S )NrD   rA   r   rE   rF   r   r   rJ   r   r   rJ   r   rL   r   ,test_sce_NCd1d2d3_none_no_weight_negative_iirQ   r   )r   r   r]   r1   r2   r   r   r   r   r^   r_   r>   r>   r?   9export_input_shape_is_NCd1d2d3_none_no_weight_negative_ii  s2   

zQSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3_none_no_weight_negative_iic                  C  s   d} t d}tjjdddgddg| |d}d	\}}}}}t jd
 t j|||||t j	}t jj
d
|||||fdt j}	d|	d
 d
 d
 d
< t||	| |dd\}
}t|||	g|
|gdd d S )NrD   r   rA   r   rE   rF   r7   r   r   r   rL   Tr   5test_sce_NCd1d2d3_none_no_weight_negative_ii_log_probrQ   r   )r   r   r]   r1   r2   r   r   r   r   r^   r=   r7   r>   r>   r?   Bexport_input_shape_is_NCd1d2d3_none_no_weight_negative_ii_log_prob%  s2   



zZSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3_none_no_weight_negative_ii_log_probc            	      C  s   d} t d}tjjdg ddg| |d}d\}}t jd t j||t j	}t jj
d||d	t j}d|d< t j|t j	}t|||| |d
}t||||g|gdd d S )Nr   
   rA   rh   rF   r   r   rJ   r   rL   r   $test_sce_NCd1d2d3_sum_weight_high_iirQ   r   )	r   r   r]   r1   r2   r   r^   r
   r_   r>   r>   r?   1export_input_shape_is_NCd1d2d3_sum_weight_high_iiE  s0   


zISoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3_sum_weight_high_iic            
      C  s   d} t d}tjjdg dddg| |d}d\}}t jd	 t j||t j	}t jj
d	||d
t j}d|d	< t j|t j	}t|||| |dd\}}	t||||g||	gdd d S )Nr   r   rA   rh   rF   r7   r   r   r   rL   Tr   -test_sce_NCd1d2d3_sum_weight_high_ii_log_probrQ   r   )
r   r   r]   r1   r2   r   r^   r
   r=   r7   r>   r>   r?   :export_input_shape_is_NCd1d2d3_sum_weight_high_ii_log_probd  s:   

	
zRSoftmaxCrossEntropyLoss.export_input_shape_is_NCd1d2d3_sum_weight_high_ii_log_probN)r   rB   )&__name__
__module____qualname__staticmethodr`   rf   rq   ru   rw   ry   r{   r~   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r>   r>   r>   r?   rA   c   s     "! "!"'!!#rA   )Nr   NN)r   r   r	   r   r
   r   r   r   r   r   r   r   r   r   )
__future__r   typingr   numpyr!   rT   onnx.backend.test.case.baser   onnx.backend.test.case.noder   r@   rA   r>   r>   r>   r?   <module>   s   T