File pytorch-optionally-use-hipblaslt.patch of Package python-torch

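Make hipBLASLt an optional dependency of ROCm builds. All hipblaslt includes,
the cublasLt handle plumbing, and the cublasLt-backed GEMM paths (bgemm/gemm,
gemm_and_bias, scaled_gemm, int8_gemm, and the TunableOp hipblaslt operators)
are guarded by a USE_HIPBLASLT define; when it is absent, the code falls back
to the rocBLAS/plain cuBLAS paths. CMake only adds -DUSE_HIPBLASLT and links
roc::hipblaslt when the hipblaslt package is found, and LoadHIP.cmake no
longer treats it as REQUIRED.

The guard used throughout the C++ hunks below is, roughly (the longer form in
the patch is logically equivalent):

  #if !defined(USE_ROCM) || defined(USE_HIPBLASLT)
    /* cublasLt / hipBLASLt path */
  #else
    /* rocBLAS / plain cuBLAS fallback */
  #endif
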
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index eea4a9f42..d57942ba0 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -14,7 +14,9 @@
 #include <c10/util/irange.h>
 
 #ifdef USE_ROCM
+#ifdef USE_HIPBLASLT
 #include <hipblaslt/hipblaslt-ext.hpp>
+#endif
 // until hipblas has an API to accept flags, we must use rocblas here
 #include <hipblas/hipblas.h>
 #include <rocblas/rocblas.h>
@@ -182,6 +184,9 @@ uint32_t _getAlignment(uintptr_t address) {
 static size_t _parseChosenWorkspaceSize() {
   const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
 #ifdef USE_ROCM
+#ifndef USE_HIPBLASLT
+  return 0;
+#endif
   if (!val) {
     // accept either env var
     val = getenv("HIPBLASLT_WORKSPACE_SIZE");
@@ -235,6 +240,7 @@ namespace at::cuda::blas {
   } while (0)
 
 
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
 namespace {
 // Following the pattern of CuSparseDescriptor
 // Defined here for now because this is the only place cublas_lt interface is
@@ -318,7 +324,6 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
 };
 } // namespace
 
-
 template <typename Dtype>
 inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
   cudaDataType_t abcType = CUDA_R_32F;
@@ -452,7 +457,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
       " scaleType ",
       scaleType);
 }
-
+#endif
 
 template <typename Dtype>
 inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
@@ -608,10 +613,13 @@ void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
 template <>
 void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float))
 {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     bgemm_internal_cublaslt<float>(CUDABLAS_BGEMM_ARGS(float));
   }
-  else {
+  else
+#endif
+  {
     bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGS(float));
   }
 }
@@ -651,10 +659,13 @@ void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<fl
 template <>
 void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
 {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     bgemm_internal_cublaslt<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
   }
-  else {
+  else
+#endif
+  {
     bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
   }
 }
@@ -662,10 +673,13 @@ void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
 template <>
 void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
 {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
   }
-  else {
+  else
+#endif
+  {
     bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
   }
 }
@@ -781,11 +795,13 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
   }
 }
 
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
 template <typename Dtype>
 inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
   // forward to bgemm implementation but set strides and batches to 0
   bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0);
 }
+#endif
 
 template <typename Dtype>
 inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
@@ -1008,10 +1024,13 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
 template <>
 void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
 {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
   }
-  else {
+  else
+#endif
+  {
     gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
   }
 }
@@ -1051,10 +1070,13 @@ void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<floa
 template <>
 void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
 {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
   }
-  else {
+  else
+#endif
+  {
     gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
   }
 }
@@ -1062,10 +1084,13 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
 template <>
 void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
 {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
   if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
     gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
   }
-  else {
+  else
+#endif
+  {
     gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
   }
 }
@@ -1177,7 +1202,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
   }
 }
 
-
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
 template <typename Dtype>
 void gemm_and_bias(
     bool transpose_mat1,
@@ -1410,7 +1435,7 @@ void scaled_gemm(
     ScalarType result_dtype,
     void* amax_ptr,
     bool use_fast_accum) {
-#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
+#if CUDA_VERSION >= 11080 || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
   const auto computeType = CUBLAS_COMPUTE_32F;
   const auto scaleType = CUDA_R_32F;
   const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
@@ -1681,6 +1706,7 @@ void int8_gemm(
       " scaleType ",
       scaleType);
 }
+#endif
 
 template <>
 void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index 2c6cef95f..dfca27656 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -87,7 +87,7 @@ enum GEMMAndBiasActivationEpilogue {
   RELU,
   GELU,
 };
-
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
 // NOTE: GELU activation is not supported prior to CUDA 11.4 and will
 // do nothing if passed in that case.
 template <typename Dtype>
@@ -142,7 +142,7 @@ void scaled_gemm(
     ScalarType result_dtype,
     void* amax_ptr,
     bool use_fast_accum);
-
+#endif
 #define CUDABLAS_BGEMM_ARGTYPES(Dtype)                                                        \
   char transa, char transb, int64_t m, int64_t n, int64_t k, at::opmath_type<Dtype> alpha,    \
       const Dtype *a, int64_t lda, int64_t stridea,                                           \
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
index dc33cb541..208816e2a 100644
--- a/aten/src/ATen/cuda/CUDAContextLight.h
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
@@ -9,7 +9,9 @@
 
 // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
 // added bf16 support
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
 #include <cublasLt.h>
+#endif
 
 #ifdef CUDART_VERSION
 #include <cusolverDn.h>
@@ -84,7 +86,9 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
 /* Handles */
 TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
 TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
 TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
+#endif
 
 TORCH_CUDA_CPP_API void clearCublasWorkspaces();
 
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index 8eac525b3..abfdf7a23 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -29,7 +29,7 @@ namespace at::cuda {
 
 namespace {
 
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
 void createCublasLtHandle(cublasLtHandle_t *handle) {
   TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
 }
@@ -191,8 +191,9 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   return handle;
 }
 
-cublasLtHandle_t getCurrentCUDABlasLtHandle() {
 #ifdef USE_ROCM
+#if defined(USE_HIPBLASLT)
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
   c10::DeviceIndex device = 0;
   AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
 
@@ -213,9 +214,12 @@ cublasLtHandle_t getCurrentCUDABlasLtHandle() {
 
   auto handle = myPoolWindow->reserve(device);
   return handle;
+}
+#endif
 #else
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
   return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
-#endif
 }
+#endif
 
 } // namespace at::cuda
diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp
index 5053f6693..31b41b275 100644
--- a/aten/src/ATen/cuda/tunable/Tunable.cpp
+++ b/aten/src/ATen/cuda/tunable/Tunable.cpp
@@ -39,9 +39,11 @@
 #include <rocm-core/rocm_version.h>
 #define ROCBLAS_BETA_FEATURES_API
 #include <rocblas/rocblas.h>
+#ifdef USE_HIPBLASLT
 #include <hipblaslt/hipblaslt.h>
 #include <hipblaslt/hipblaslt-ext.hpp>
 #endif
+#endif
 
 namespace at::cuda::tunable {
 
@@ -214,6 +216,7 @@ TuningResultsValidator::TuningResultsValidator() {
         [rocblas_version]() { return rocblas_version; },
         [rocblas_version](auto&& k) { return rocblas_version == k ? OK : FAIL; });
   }
+#ifdef USE_HIPBLASLT
   // hipblaslt
   {
     int version;
@@ -229,6 +232,7 @@ TuningResultsValidator::TuningResultsValidator() {
         [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
   }
 #endif
+#endif
 }
 
 std::unordered_map<std::string, std::string> TuningResultsValidator::GetAllValidators() const {
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
index 50a7344b0..a8da577aa 100644
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -11,7 +11,9 @@
 
 #include <ATen/cuda/tunable/GemmCommon.h>
 #ifdef USE_ROCM
+#ifdef USE_HIPBLASLT
 #include <ATen/cuda/tunable/GemmHipblaslt.h>
+#endif
 #include <ATen/cuda/tunable/GemmRocblas.h>
 #endif
 #include <ATen/cuda/tunable/StreamTimer.h>
@@ -80,6 +82,7 @@ class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
     }
 };
 
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
 template <typename T>
 class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
   public:
@@ -109,6 +112,7 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
       return OK;
     }
 };
+#endif
 
 template <typename T>
 inline bool IsZero(T v) {
@@ -203,7 +207,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
         this->RegisterOp(std::move(name), std::move(op));
       }
     }
-
+#ifdef USE_HIPBLASLT
     static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
     if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
       // disallow tuning of hipblaslt with c10::complex
@@ -215,6 +219,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
         }
       }
     }
+#endif
 #endif
   }
 
@@ -229,7 +234,7 @@ class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer>
   GemmAndBiasTunableOp() {
     this->RegisterOp(std::string("Default"), std::make_unique<DefaultGemmAndBiasOp<T>>());
 
-#ifdef USE_ROCM
+#if (defined(USE_ROCM) && defined(USE_HIPBLASLT))
     static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
     if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
       // disallow tuning of hipblaslt with c10::complex
@@ -262,7 +267,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
         this->RegisterOp(std::move(name), std::move(op));
       }
     }
-
+#ifdef USE_HIPBLASLT
     static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
     if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
       // disallow tuning of hipblaslt with c10::complex
@@ -274,6 +279,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
         }
       }
     }
+#endif
 #endif
   }
 
@@ -282,6 +288,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
   }
 };
 
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
 template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
 class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
  public:
@@ -303,5 +310,6 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
             "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
   }
 };
+#endif
 
 } // namespace at::cuda::tunable
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 991c7d2db..fa8f84cc4 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -178,6 +178,7 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
 }
 
 static bool getDisableAddmmCudaLt() {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
     static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
 #ifdef USE_ROCM
     // allow both CUDA and HIP env var names for ROCm builds
@@ -195,10 +196,14 @@ static bool getDisableAddmmCudaLt() {
     }
     return false;
 #endif
+#else
+    return true;
+#endif
 }
 
 #ifdef USE_ROCM
 static bool isSupportedHipLtROCmArch(int index) {
+#ifdef USE_HIPBLASLT
     hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
     std::string device_arch = prop->gcnArchName;
     static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
@@ -209,6 +214,7 @@ static bool isSupportedHipLtROCmArch(int index) {
         }
     }
     TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
+#endif
     return false;
 }
 #endif
@@ -274,6 +280,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   at::ScalarType scalar_type = self.scalar_type();
   c10::MaybeOwned<Tensor> self_;
   if (&result != &self) {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
 #if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM)
     // Strangely, if mat2 has only 1 row or column, we get
     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
@@ -315,13 +322,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
             scalar_type != at::ScalarType::BFloat16));
 #endif
     }
+#endif
 #endif
     if (!useLtInterface) {
       self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
     }
     self__sizes = self_->sizes();
   } else {
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
     useLtInterface = !disable_addmm_cuda_lt &&
         result.dim() == 2 && result.is_contiguous() &&
         isSupportedHipLtROCmArch(self.device().index()) &&
@@ -373,6 +381,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
 
   if (useLtInterface) {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
 #if defined(USE_ROCM)
     AT_DISPATCH_FLOATING_TYPES_AND2(
         at::ScalarType::Half,
@@ -451,6 +460,9 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
               activation_epilogue
           );
         }});
+#endif
+#else
+    TORCH_CHECK(false, "addmm_out_cuda_impl: cuBLASLt/hipBLASLt interface requested, but this build was compiled without hipBLASLt support");
 #endif
   } else
   {
@@ -824,7 +836,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)
 
   TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
 
-#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || defined(USE_ROCM)
+#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
   cublasCommonArgs args(self, mat2, result);
 
   at::cuda::blas::int8_gemm(
@@ -860,6 +872,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
 }
 
 static bool _scaled_mm_allowed_device() {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
     auto dprops = at::cuda::getCurrentDeviceProperties();
 #ifdef USE_ROCM
     std::string device_arch = dprops->gcnArchName;
@@ -874,6 +887,9 @@ static bool _scaled_mm_allowed_device() {
 #else
     return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
 #endif
+#else
+    return false;
+#endif
 }
 
 namespace{
@@ -995,6 +1011,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   // Check sizes
   bool allowed_device = _scaled_mm_allowed_device();
   TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
   TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
   TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
   TORCH_CHECK(
@@ -1182,7 +1199,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
         amax.data_ptr(),
         use_fast_accum);
   }
-
+#endif
   return out;
 }
 
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index e78305e0a..9b2591112 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1061,6 +1061,9 @@ if(USE_ROCM)
     if(HIP_NEW_TYPE_ENUMS)
       list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
     endif()
+    if(hipblaslt_FOUND)
+      list(APPEND HIP_CXX_FLAGS -DUSE_HIPBLASLT)
+    endif()
     add_definitions(-DROCM_VERSION=${ROCM_VERSION_DEV_INT})
     add_definitions(-DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
     message("TORCH_HIP_VERSION=${TORCH_HIP_VERSION} is added as a compiler defines")
@@ -1089,8 +1092,9 @@ if(USE_ROCM)
 
     set(Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
       hip::amdhip64 MIOpen hiprtc::hiprtc) # libroctx will be linked in with MIOpen
-    list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS roc::hipblaslt)
-
+    if(hipblaslt_FOUND)
+      list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS roc::hipblaslt)
+    endif()
     list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
       roc::hipblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver)
 
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index 1c0d3a203..a9fe05a2a 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -156,7 +156,7 @@ if(HIP_FOUND)
   find_package_and_print_version(hiprand REQUIRED)
   find_package_and_print_version(rocblas REQUIRED)
   find_package_and_print_version(hipblas REQUIRED)
-  find_package_and_print_version(hipblaslt REQUIRED)
+  find_package_and_print_version(hipblaslt)
   find_package_and_print_version(miopen REQUIRED)
   find_package_and_print_version(hipfft REQUIRED)
   find_package_and_print_version(hipsparse REQUIRED)