File pytorch-rocm-do-not-use-aotriton-if-not-required.patch of Package python-torch-oldstable
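Make AOTriton an optional dependency of the ROCm build: pull in cmake/External/aotriton.cmake only when USE_FLASH_ATTENTION or USE_MEM_EFF_ATTENTION is enabled, let USE_MEM_EFF_ATTENTION default to ON for ROCm as well as CUDA, drop the unconditional include from cmake/Dependencies.cmake, and guard the AOTriton calls in sdp_utils.cpp behind a new USE_AOTRITON macro so that ROCm builds without AOTriton simply report the fused SDPA backends as unavailable.
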
--- a/CMakeLists.txt 2024-05-29 17:15:01.000000000 +0200
+++ b/CMakeLists.txt 2024-08-26 22:03:34.930907505 +0200
@@ -771,7 +771,13 @@
USE_MEM_EFF_ATTENTION
"Enable memory-efficient attention for scaled dot product attention.\
Will be disabled if not supported by the platform" ON
- "USE_CUDA" OFF)
+ "USE_CUDA OR USE_ROCM" OFF)
+
+if(USE_ROCM)
+ if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
+ include(cmake/External/aotriton.cmake)
+ endif()
+endif()
if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
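
The hunk above lets USE_MEM_EFF_ATTENTION default to ON for ROCm builds as well as CUDA builds, and includes aotriton.cmake from the top-level CMakeLists.txt only when at least one of USE_FLASH_ATTENTION or USE_MEM_EFF_ATTENTION is enabled.
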
--- a/cmake/Dependencies.cmake 2024-05-29 17:15:01.000000000 +0200
+++ b/cmake/Dependencies.cmake 2024-08-26 22:02:16.051084133 +0200
@@ -1334,11 +1334,6 @@
else()
message(STATUS "Disabling Kernel Assert for ROCm")
endif()
-
- include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
- if(USE_CUDA)
- caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
- endif()
else()
caffe2_update_option(USE_ROCM OFF)
endif()
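
This removes the unconditional include of aotriton.cmake from cmake/Dependencies.cmake, together with the USE_CUDA special case that switched USE_MEM_EFF_ATTENTION off; aotriton.cmake is now included only through the conditional block added to CMakeLists.txt above.
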
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp 2024-08-25 14:27:15.042669638 +0200
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp 2024-08-26 21:41:25.003719366 +0200
@@ -22,7 +22,10 @@
#include <functional>
#if USE_ROCM
+#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
#include <aotriton/flash.h>
+#define USE_AOTRITON 1
+#endif
#endif
/**
@@ -187,6 +190,7 @@
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
#if USE_ROCM
+#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -197,6 +201,9 @@
return false;
}
#else
+ return false;
+#endif
+#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm90>(dprops)) {
if (debug) {
@@ -217,6 +224,21 @@
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
+#if USE_ROCM
+#if USE_AOTRITON
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
+ if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+ auto dprops = at::cuda::getCurrentDeviceProperties();
+ if (debug) {
+ TORCH_WARN(
+ "Mem Efficient attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
+ }
+ return false;
+ }
+#else
+ return false;
+#endif
+#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm50, sm90>(dprops)) {
if (debug) {
@@ -230,6 +252,7 @@
return false;
}
+#endif
return true;
}
bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
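
Taken together, the sdp_utils.cpp hunks give the hardware checks the shape sketched below. This is a reconstruction for orientation only: the enclosing function name, its signature and the CUDA-side warning text are not visible in the hunks and are approximated here, while the ROCm branch mirrors the added lines. The flash-attention check ends up with the same structure, just with sm80/sm90 bounds.

    // Sketch of the memory-efficient hardware check after the patch is applied
    // (reconstructed; name, signature and the CUDA warning text are approximations).
    bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
      // Mem Efficient attention supports hardware in the range [sm_50, sm_90]
      using sm50 = SMVersion<5, 0>;
      using sm90 = SMVersion<9, 0>;
    #if USE_ROCM
    #if USE_AOTRITON
      // AOTriton was compiled in: ask it whether the current AMD GPU is supported.
      auto stream = at::cuda::getCurrentCUDAStream().stream();
      if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
        auto dprops = at::cuda::getCurrentDeviceProperties();
        if (debug) {
          TORCH_WARN(
              "Mem Efficient attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
        }
        return false;
      }
    #else
      // ROCm build without AOTriton: the fused backend is simply unavailable.
      return false;
    #endif
    #else
      // CUDA build: unchanged SM-version range check.
      auto dprops = at::cuda::getCurrentDeviceProperties();
      if (!check_sm_version<sm50, sm90>(dprops)) {
        if (debug) {
          TORCH_WARN("Mem Efficient attention requires a GPU in the [sm50, sm90] range.");
        }
        return false;
      }
    #endif
      return true;
    }

With this layout a ROCm build configured without flash or memory-efficient attention never references AOTriton symbols, while the CUDA code path is left untouched.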