File pytorch-rocm-do-not-use-aotriton-if-not-required.patch of Package python-torch
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index d84d94176..a34e5ef79 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -246,8 +246,8 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
     }
     return false;
   }
-#endif
   return true;
+#endif
 }
 
 bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
@@ -485,6 +485,21 @@ bool check_cudnn_layout(sdp_params const& params, bool debug) {
 }
 
 bool check_cudnn_hardware_support(sdp_params const& params, bool debug) {
+#if USE_ROCM
+#if USE_AOTRITON
+  auto stream = at::cuda::getCurrentCUDAStream().stream();
+  if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
+    auto dprops = at::cuda::getCurrentDeviceProperties();
+    if (debug) {
+      TORCH_WARN(
+ "Mem Efficient attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
+    }
+    return false;
+  }
+#else
+  return false;
+#endif
+#else
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
   auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -500,6 +515,7 @@ bool check_cudnn_hardware_support(sdp_params const& params, bool debug) {
     return false;
   }
   return true;
+#endif
 }
 
 bool check_for_nested_inputs(sdp_params const& params, bool debug) {