File pytorch-optionally-use-hipblaslt-2-3-1.patch of Package python-torch-oldstable

diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index d534ec5a1..e815463f6 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -14,7 +14,7 @@
 #include <c10/util/irange.h>
 
 #ifdef USE_ROCM
-#if ROCM_VERSION >= 60000
+#ifdef HIPBLASLT
 #include <hipblaslt/hipblaslt-ext.hpp>
 #endif
 // until hipblas has an API to accept flags, we must use rocblas here
@@ -781,7 +781,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
   }
 }
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 
 #if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
 // only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -912,6 +912,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
 };
 } // namespace
 
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 template <typename Dtype>
 void gemm_and_bias(
     bool transpose_mat1,
@@ -1124,7 +1125,7 @@ template void gemm_and_bias(
     at::BFloat16* result_ptr,
     int64_t result_ld,
     GEMMAndBiasActivationEpilogue activation);
-
+#endif
 void scaled_gemm(
     char transa,
     char transb,
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index eb12bb350..068607467 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
 template <>
 void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 enum GEMMAndBiasActivationEpilogue {
   None,
   RELU,
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
index 4ec35f59a..e28dc4203 100644
--- a/aten/src/ATen/cuda/CUDAContextLight.h
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
@@ -9,7 +9,7 @@
 
 // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
 // added bf16 support
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 #include <cublasLt.h>
 #endif
 
@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
 /* Handles */
 TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
 TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
 #endif
 
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index 6913d2cd9..3d4276be3 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -29,7 +29,7 @@ namespace at::cuda {
 
 namespace {
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
 void createCublasLtHandle(cublasLtHandle_t *handle) {
   TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
 }
@@ -190,7 +190,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   return handle;
 }
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 cublasLtHandle_t getCurrentCUDABlasLtHandle() {
 #ifdef USE_ROCM
   c10::DeviceIndex device = 0;
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
index 3ba0d7612..dde1870cf 100644
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -11,7 +11,7 @@
 
 #include <ATen/cuda/tunable/GemmCommon.h>
 #ifdef USE_ROCM
-#if ROCM_VERSION >= 50700
+#ifdef HIPBLASLT
 #include <ATen/cuda/tunable/GemmHipblaslt.h>
 #endif
 #include <ATen/cuda/tunable/GemmRocblas.h>
@@ -166,7 +166,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
     }
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
     static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
     if (env == nullptr || strcmp(env, "1") == 0) {
       // disallow tuning of hipblaslt with c10::complex
@@ -240,7 +240,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
     }
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
     static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
     if (env == nullptr || strcmp(env, "1") == 0) {
       // disallow tuning of hipblaslt with c10::complex
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 29e5c5e3c..df56f3d7f 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -155,7 +155,7 @@ enum class Activation {
   GELU,
 };
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
 cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
   switch (a) {
     case Activation::None:
@@ -193,6 +193,7 @@ static bool getDisableAddmmCudaLt() {
 
 #ifdef USE_ROCM
 static bool isSupportedHipLtROCmArch(int index) {
+#if defined(HIPBLASLT)
     hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
     std::string device_arch = prop->gcnArchName;
     static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
@@ -203,6 +204,7 @@ static bool isSupportedHipLtROCmArch(int index) {
         }
     }
     TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
+#endif
     return false;
 }
 #endif
@@ -228,7 +230,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   at::ScalarType scalar_type = self.scalar_type();
   c10::MaybeOwned<Tensor> self_;
   if (&result != &self) {
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && defined(HIPBLASLT)
     // Strangely, if mat2 has only 1 row or column, we get
     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -271,7 +273,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
     }
     self__sizes = self_->sizes();
   } else {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
     useLtInterface = !disable_addmm_cuda_lt &&
         result.dim() == 2 && result.is_contiguous() &&
         isSupportedHipLtROCmArch(self.device().index()) &&
@@ -322,7 +324,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
 
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
 
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
   if (useLtInterface) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half,
@@ -876,7 +878,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
   at::native::resize_output(amax, {});
 
-#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
+#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && defined(HIPBLASLT))
   cublasCommonArgs args(mat1, mat2, out);
   const auto out_dtype_ = args.result->scalar_type();
   TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@@ -906,7 +908,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
 #endif
 
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
+#if defined(USE_ROCM) && defined(HIPBLASLT)
   // rocm's hipblaslt does not yet support amax, so calculate separately
   auto out_float32 = out.to(kFloat);
   out_float32.abs_();
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index a96075245..d6149ef77 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1282,6 +1282,9 @@ if(USE_ROCM)
     if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
       list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
     endif()
+    if(hipblaslt_FOUND)
+      list(APPEND HIP_CXX_FLAGS -DHIPBLASLT)
+    endif()
     if(HIPBLASLT_CUSTOM_DATA_TYPE)
       list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
     endif()
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index f6ca263c5..53eb0b63c 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -156,7 +156,7 @@ if(HIP_FOUND)
   find_package_and_print_version(rocblas REQUIRED)
   find_package_and_print_version(hipblas REQUIRED)
   if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
-    find_package_and_print_version(hipblaslt REQUIRED)
+    find_package_and_print_version(hipblaslt)
   endif()
   find_package_and_print_version(miopen REQUIRED)
   if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
@@ -191,7 +191,7 @@ if(HIP_FOUND)
   # roctx is part of roctracer
   find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
 
-  if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+  if(hipblastlt_FOUND)
     # check whether hipblaslt is using its own datatype
     set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
     file(WRITE ${file} ""
openSUSE Build Service is sponsored by