File 0001-Optionally-use-hipblaslt.patch of Package python-torch

From d77e05d90df006322cda021f1a8affdcc2c7eaef Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com>
Date: Fri, 23 Feb 2024 08:27:30 -0500
Subject: [PATCH] Optionally use hipblaslt

The hipblaslt package is not available on Fedora.
Instead of requiring the package, make it optional.
If it is found, define the preprocessor variable HIPBLASLT
Convert the checks for ROCM_VERSION >= 507000 to HIPBLASLT checks

Signed-off-by: Tom Rix <trix@redhat.com>
---
 aten/src/ATen/cuda/CUDABlas.cpp          |  7 ++++---
 aten/src/ATen/cuda/CUDABlas.h            |  2 +-
 aten/src/ATen/cuda/CUDAContextLight.h    |  4 ++--
 aten/src/ATen/cuda/CublasHandlePool.cpp  |  4 ++--
 aten/src/ATen/cuda/tunable/TunableGemm.h |  6 +++---
 aten/src/ATen/native/cuda/Blas.cpp       | 14 ++++++++------
 cmake/Dependencies.cmake                 |  3 +++
 cmake/public/LoadHIP.cmake               |  4 ++--
 8 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index d534ec5a178..e815463f630 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -14,7 +14,7 @@
 #include <c10/util/irange.h>
 
 #ifdef USE_ROCM
-#if ROCM_VERSION >= 60000
+#ifdef HIPBLASLT
 #include <hipblaslt/hipblaslt-ext.hpp>
 #endif
 // until hipblas has an API to accept flags, we must use rocblas here
@@ -781,7 +781,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
   }
 }
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 
 #if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
 // only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -912,6 +912,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
 };
 } // namespace
 
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 template <typename Dtype>
 void gemm_and_bias(
     bool transpose_mat1,
@@ -1124,7 +1125,7 @@ template void gemm_and_bias(
     at::BFloat16* result_ptr,
     int64_t result_ld,
     GEMMAndBiasActivationEpilogue activation);
-
+#endif
 void scaled_gemm(
     char transa,
     char transb,
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index eb12bb350c5..068607467dd 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
 template <>
 void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 enum GEMMAndBiasActivationEpilogue {
   None,
   RELU,
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
index 4ec35f59a21..e28dc42034f 100644
--- a/aten/src/ATen/cuda/CUDAContextLight.h
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
@@ -9,7 +9,7 @@
 
 // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
 // added bf16 support
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 #include <cublasLt.h>
 #endif
 
@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
 /* Handles */
 TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
 TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
 #endif
 
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index 6913d2cd95e..3d4276be372 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -29,7 +29,7 @@ namespace at::cuda {
 
 namespace {
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
 void createCublasLtHandle(cublasLtHandle_t *handle) {
   TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
 }
@@ -190,7 +190,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   return handle;
 }
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 cublasLtHandle_t getCurrentCUDABlasLtHandle() {
 #ifdef USE_ROCM
   c10::DeviceIndex device = 0;
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
index 3ba0d761277..dde1870cfbf 100644
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -11,7 +11,7 @@
 
 #include <ATen/cuda/tunable/GemmCommon.h>
 #ifdef USE_ROCM
-#if ROCM_VERSION >= 50700
+#ifdef HIPBLASLT
 #include <ATen/cuda/tunable/GemmHipblaslt.h>
 #endif
 #include <ATen/cuda/tunable/GemmRocblas.h>
@@ -166,7 +166,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
     }
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
     static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
     if (env == nullptr || strcmp(env, "1") == 0) {
       // disallow tuning of hipblaslt with c10::complex
@@ -240,7 +240,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
     }
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
     static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
     if (env == nullptr || strcmp(env, "1") == 0) {
       // disallow tuning of hipblaslt with c10::complex
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 29e5c5e3cf1..df56f3d7f1d 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -155,7 +155,7 @@ enum class Activation {
   GELU,
 };
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
   switch (a) {
     case Activation::None:
@@ -193,6 +193,7 @@ static bool getDisableAddmmCudaLt() {
 
 #ifdef USE_ROCM
 static bool isSupportedHipLtROCmArch(int index) {
+#if defined(HIPBLASLT)
     hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
     std::string device_arch = prop->gcnArchName;
     static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
@@ -203,6 +204,7 @@ static bool isSupportedHipLtROCmArch(int index) {
         }
     }
     TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
+#endif
     return false;
 }
 #endif
@@ -228,7 +230,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   at::ScalarType scalar_type = self.scalar_type();
   c10::MaybeOwned<Tensor> self_;
   if (&result != &self) {
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && defined(HIPBLASLT)
     // Strangely, if mat2 has only 1 row or column, we get
     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -271,7 +273,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
     }
     self__sizes = self_->sizes();
   } else {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
     useLtInterface = !disable_addmm_cuda_lt &&
         result.dim() == 2 && result.is_contiguous() &&
         isSupportedHipLtROCmArch(self.device().index()) &&
@@ -322,7 +324,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
   if (useLtInterface) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half,
@@ -876,7 +878,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
   at::native::resize_output(amax, {});
 
-#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
+#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && defined(HIPBLASLT))
   cublasCommonArgs args(mat1, mat2, out);
   const auto out_dtype_ = args.result->scalar_type();
   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@@ -906,7 +908,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
+#if defined(USE_ROCM) && defined(HIPBLASLT)
   // rocm's hipblaslt does not yet support amax, so calculate separately
   auto out_float32 = out.to(kFloat);
   out_float32.abs_();
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index b7ffbeb07dc..2b6c3678984 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1273,6 +1273,9 @@ if(USE_ROCM)
     if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
       list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
     endif()
+    if(hipblast_FOUND)
+      list(APPEND HIP_CXX_FLAGS -DHIPBLASLT)
+    endif()
     if(HIPBLASLT_CUSTOM_DATA_TYPE)
       list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
     endif()
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index f6ca263c5e5..53eb0b63c1a 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -156,7 +156,7 @@ if(HIP_FOUND)
   find_package_and_print_version(rocblas REQUIRED)
   find_package_and_print_version(hipblas REQUIRED)
   if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
-    find_package_and_print_version(hipblaslt REQUIRED)
+    find_package_and_print_version(hipblaslt)
   endif()
   find_package_and_print_version(miopen REQUIRED)
   if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
@@ -191,7 +191,7 @@ if(HIP_FOUND)
   # roctx is part of roctracer
   find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
 
-  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+  if(hipblastlt_FOUND)
     # check whether hipblaslt is using its own datatype
     set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
     file(WRITE ${file} ""
-- 
2.43.2

openSUSE Build Service is sponsored by