From 33d48f71db7530f00dbd8cff281b65aa8b355b2a Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com>
Date: Tue, 19 Mar 2024 11:32:37 -0400
Subject: [PATCH] disable use of aotriton

---
aten/src/ATen/native/transformers/cuda/sdp_utils.cpp | 6 ++++++
1 file changed, 6 insertions(+)

diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 96b839820efd..2d3dd0cb4b0f 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -21,9 +21,11 @@
#include <cmath>
#include <functional>

+#ifdef USE_FLASH_ATTENTION
#if USE_ROCM
#include <aotriton/flash.h>
#endif
+#endif

/**
* Note [SDPA Runtime Dispatch]
@@ -183,6 +185,7 @@ bool check_sm_version(cudaDeviceProp * dprops) {
}

bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
+#ifdef USE_FLASH_ATTENTION
// Check that the gpu is capable of running flash attention
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
@@ -211,6 +214,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
}
#endif
return true;
+#else
+ return false;
+#endif
}

bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
--
2.44.0
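
Effect of the patch: when PyTorch is built without USE_FLASH_ATTENTION (for example to avoid pulling in the aotriton dependency on ROCm), the aotriton header is never included and check_flash_attention_hardware_support() reports the backend as unavailable instead of probing the GPU. The following is a minimal standalone sketch of that preprocessor-guard pattern, not the actual PyTorch sources: the real function takes sdp_params const& and a debug flag and checks the device's SM version, and the main() driver here is purely illustrative.

// Sketch of the USE_FLASH_ATTENTION guard introduced by this patch.
// Build with -DUSE_FLASH_ATTENTION to simulate a build that compiles
// the flash attention backend in; omit it to simulate the
// configuration this patch targets.
#include <iostream>

bool check_flash_attention_hardware_support() {
#ifdef USE_FLASH_ATTENTION
  // In the real function: verify the device falls in the supported
  // sm80..sm90 range (or a supported ROCm target via aotriton).
  return true;
#else
  // Built without flash attention: report the backend as unsupported
  // so the SDPA runtime dispatch never reaches the aotriton code paths.
  return false;
#endif
}

int main() {
  std::cout << "flash attention usable: "
            << (check_flash_attention_hardware_support() ? "yes" : "no")
            << '\n';
  return 0;
}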