4c0dcdd8a4
Upstream issue that points to several commits: https://github.com/OpenMathLib/OpenBLAS/issues/4475 Downstream trackers: https://bugzilla.redhat.com/show_bug.cgi?id=2261415 https://issues.redhat.com/browse/RHEL-24745
753 lines
20 KiB
Diff
753 lines
20 KiB
Diff
This is a compilation of more upstream commits related to:
|
|
https://github.com/OpenMathLib/OpenBLAS/issues/4475
|
|
|
|
|
|
From 63004fa5f76ef1058975271314bc4591e7878726 Mon Sep 17 00:00:00 2001
|
|
From: Honza Horak <hhorak@redhat.com>
|
|
Date: Fri, 9 Feb 2024 09:49:41 +0100
|
|
Subject: [PATCH 1/6] Fix incompatible pointer type in BFLOAT16 mode
|
|
|
|
Upstream commit:
|
|
|
|
commit 68d354814f9f846338e1988c4f609c8add419012
|
|
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
|
|
Date: Sun Feb 4 01:14:22 2024 +0100
|
|
|
|
Fix incompatible pointer type in BFLOAT16 mode
|
|
---
|
|
interface/gemmt.c | 4 ++--
|
|
1 file changed, 2 insertions(+), 2 deletions(-)
|
|
|
|
diff --git a/interface/gemmt.c b/interface/gemmt.c
|
|
index 046432670..2fb9954ad 100644
|
|
--- a/interface/gemmt.c
|
|
+++ b/interface/gemmt.c
|
|
@@ -478,7 +478,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
|
#endif
|
|
// for alignment
|
|
buffer_size = (buffer_size + 3) & ~3;
|
|
- STACK_ALLOC(buffer_size, FLOAT, buffer);
|
|
+ STACK_ALLOC(buffer_size, IFLOAT, buffer);
|
|
|
|
#ifdef SMP
|
|
|
|
@@ -567,7 +567,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
|
#endif
|
|
// for alignment
|
|
buffer_size = (buffer_size + 3) & ~3;
|
|
- STACK_ALLOC(buffer_size, FLOAT, buffer);
|
|
+ STACK_ALLOC(buffer_size, IFLOAT, buffer);
|
|
|
|
#ifdef SMP
|
|
|
|
--
|
|
2.41.0
|
|
|
|
From edfd4f52f3f22344863c233411ae792fb12aa81b Mon Sep 17 00:00:00 2001
|
|
From: Honza Horak <hhorak@redhat.com>
|
|
Date: Fri, 9 Feb 2024 09:53:40 +0100
|
|
Subject: [PATCH 2/6] Separate the interface for SBGEMMT from GEMMT due to
|
|
differences in GEMV arguments
|
|
|
|
Upstream commit:
|
|
|
|
commit d4db6a9f16a5c82bbe1860f591cc731c4d83d7c8
|
|
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
|
|
Date: Tue Feb 6 22:23:47 2024 +0100
|
|
|
|
Separate the interface for SBGEMMT from GEMMT due to differences in GEMV arguments
|
|
---
|
|
interface/CMakeLists.txt | 1 +
|
|
interface/Makefile | 4 +-
|
|
interface/sbgemmt.c | 447 +++++++++++++++++++++++++++++++++++++++
|
|
3 files changed, 450 insertions(+), 2 deletions(-)
|
|
create mode 100644 interface/sbgemmt.c
|
|
|
|
diff --git a/interface/CMakeLists.txt b/interface/CMakeLists.txt
|
|
index 4e082928b..3110f2e90 100644
|
|
--- a/interface/CMakeLists.txt
|
|
+++ b/interface/CMakeLists.txt
|
|
@@ -119,6 +119,7 @@ endif ()
|
|
if (BUILD_BFLOAT16)
|
|
GenerateNamedObjects("bf16dot.c" "" "sbdot" ${CBLAS_FLAG} "" "" true "BFLOAT16")
|
|
GenerateNamedObjects("gemm.c" "" "sbgemm" ${CBLAS_FLAG} "" "" true "BFLOAT16")
|
|
+ GenerateNamedObjects("sbgemmt.c" "" "sbgemmt" ${CBLAS_FLAG} "" "" true "BFLOAT16")
|
|
GenerateNamedObjects("sbgemv.c" "" "sbgemv" ${CBLAS_FLAG} "" "" true "BFLOAT16")
|
|
GenerateNamedObjects("tobf16.c" "SINGLE_PREC" "sbstobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")
|
|
GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16")
|
|
diff --git a/interface/Makefile b/interface/Makefile
|
|
index 78335357b..d106ca568 100644
|
|
--- a/interface/Makefile
|
|
+++ b/interface/Makefile
|
|
@@ -1301,7 +1301,7 @@ xhpr2.$(SUFFIX) xhpr2.$(PSUFFIX) : zhpr2.c
|
|
ifeq ($(BUILD_BFLOAT16),1)
|
|
sbgemm.$(SUFFIX) sbgemm.$(PSUFFIX) : gemm.c ../param.h
|
|
$(CC) -c $(CFLAGS) $< -o $(@F)
|
|
-sbgemmt.$(SUFFIX) sbgemmt.$(PSUFFIX) : gemmt.c ../param.h
|
|
+sbgemmt.$(SUFFIX) sbgemmt.$(PSUFFIX) : sbgemmt.c ../param.h
|
|
$(CC) -c $(CFLAGS) $< -o $(@F)
|
|
endif
|
|
|
|
@@ -1932,7 +1932,7 @@ cblas_sgemmt.$(SUFFIX) cblas_sgemmt.$(PSUFFIX) : gemmt.c ../param.h
|
|
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
|
|
|
|
ifeq ($(BUILD_BFLOAT16),1)
|
|
-cblas_sbgemmt.$(SUFFIX) cblas_sbgemmt.$(PSUFFIX) : gemmt.c ../param.h
|
|
+cblas_sbgemmt.$(SUFFIX) cblas_sbgemmt.$(PSUFFIX) : sbgemmt.c ../param.h
|
|
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
|
|
endif
|
|
|
|
diff --git a/interface/sbgemmt.c b/interface/sbgemmt.c
|
|
new file mode 100644
|
|
index 000000000..759af4bfb
|
|
--- /dev/null
|
|
+++ b/interface/sbgemmt.c
|
|
@@ -0,0 +1,447 @@
|
|
+/*********************************************************************/
|
|
+/* Copyright 2024, The OpenBLAS Project. */
|
|
+/* All rights reserved. */
|
|
+/* */
|
|
+/* Redistribution and use in source and binary forms, with or */
|
|
+/* without modification, are permitted provided that the following */
|
|
+/* conditions are met: */
|
|
+/* */
|
|
+/* 1. Redistributions of source code must retain the above */
|
|
+/* copyright notice, this list of conditions and the following */
|
|
+/* disclaimer. */
|
|
+/* */
|
|
+/* 2. Redistributions in binary form must reproduce the above */
|
|
+/* copyright notice, this list of conditions and the following */
|
|
+/* disclaimer in the documentation and/or other materials */
|
|
+/* provided with the distribution. */
|
|
+/* */
|
|
+/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */
|
|
+/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */
|
|
+/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */
|
|
+/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */
|
|
+/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */
|
|
+/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */
|
|
+/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */
|
|
+/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */
|
|
+/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */
|
|
+/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */
|
|
+/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */
|
|
+/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */
|
|
+/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */
|
|
+/* POSSIBILITY OF SUCH DAMAGE. */
|
|
+/* */
|
|
+/*********************************************************************/
|
|
+
|
|
+#include <stdio.h>
|
|
+#include <stdlib.h>
|
|
+#include "common.h"
|
|
+
|
|
+#define SMP_THRESHOLD_MIN 65536.0
|
|
+#define ERROR_NAME "SBGEMMT "
|
|
+
|
|
+#ifndef GEMM_MULTITHREAD_THRESHOLD
|
|
+#define GEMM_MULTITHREAD_THRESHOLD 4
|
|
+#endif
|
|
+
|
|
+#ifndef CBLAS
|
|
+
|
|
+void NAME(char *UPLO, char *TRANSA, char *TRANSB,
|
|
+ blasint * M, blasint * K,
|
|
+ FLOAT * Alpha,
|
|
+ IFLOAT * a, blasint * ldA,
|
|
+ IFLOAT * b, blasint * ldB, FLOAT * Beta, FLOAT * c, blasint * ldC)
|
|
+{
|
|
+
|
|
+ blasint m, k;
|
|
+ blasint lda, ldb, ldc;
|
|
+ int transa, transb, uplo;
|
|
+ blasint info;
|
|
+
|
|
+ char transA, transB, Uplo;
|
|
+ blasint nrowa, nrowb;
|
|
+ IFLOAT *buffer;
|
|
+ IFLOAT *aa, *bb;
|
|
+ FLOAT *cc;
|
|
+ FLOAT alpha, beta;
|
|
+
|
|
+ PRINT_DEBUG_NAME;
|
|
+
|
|
+ m = *M;
|
|
+ k = *K;
|
|
+
|
|
+ alpha = *Alpha;
|
|
+ beta = *Beta;
|
|
+
|
|
+ lda = *ldA;
|
|
+ ldb = *ldB;
|
|
+ ldc = *ldC;
|
|
+
|
|
+ transA = *TRANSA;
|
|
+ transB = *TRANSB;
|
|
+ Uplo = *UPLO;
|
|
+ TOUPPER(transA);
|
|
+ TOUPPER(transB);
|
|
+ TOUPPER(Uplo);
|
|
+
|
|
+ transa = -1;
|
|
+ transb = -1;
|
|
+ uplo = -1;
|
|
+
|
|
+ if (transA == 'N')
|
|
+ transa = 0;
|
|
+ if (transA == 'T')
|
|
+ transa = 1;
|
|
+
|
|
+ if (transA == 'R')
|
|
+ transa = 0;
|
|
+ if (transA == 'C')
|
|
+ transa = 1;
|
|
+
|
|
+ if (transB == 'N')
|
|
+ transb = 0;
|
|
+ if (transB == 'T')
|
|
+ transb = 1;
|
|
+
|
|
+ if (transB == 'R')
|
|
+ transb = 0;
|
|
+ if (transB == 'C')
|
|
+ transb = 1;
|
|
+
|
|
+ if (Uplo == 'U')
|
|
+ uplo = 0;
|
|
+ if (Uplo == 'L')
|
|
+ uplo = 1;
|
|
+ nrowa = m;
|
|
+ if (transa & 1) nrowa = k;
|
|
+ nrowb = k;
|
|
+ if (transb & 1) nrowb = m;
|
|
+
|
|
+ info = 0;
|
|
+
|
|
+ if (ldc < MAX(1, m))
|
|
+ info = 13;
|
|
+ if (ldb < MAX(1, nrowb))
|
|
+ info = 10;
|
|
+ if (lda < MAX(1, nrowa))
|
|
+ info = 8;
|
|
+ if (k < 0)
|
|
+ info = 5;
|
|
+ if (m < 0)
|
|
+ info = 4;
|
|
+ if (transb < 0)
|
|
+ info = 3;
|
|
+ if (transa < 0)
|
|
+ info = 2;
|
|
+ if (uplo < 0)
|
|
+ info = 1;
|
|
+
|
|
+ if (info != 0) {
|
|
+ BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME));
|
|
+ return;
|
|
+ }
|
|
+#else
|
|
+
|
|
+void CNAME(enum CBLAS_ORDER order, enum CBLAS_UPLO Uplo,
|
|
+ enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANSPOSE TransB, blasint m,
|
|
+ blasint k,
|
|
+ FLOAT alpha,
|
|
+ IFLOAT * A, blasint LDA,
|
|
+ IFLOAT * B, blasint LDB, FLOAT beta, FLOAT * c, blasint ldc)
|
|
+{
|
|
+ IFLOAT *aa, *bb;
|
|
+ FLOAT *cc;
|
|
+
|
|
+ int transa, transb, uplo;
|
|
+ blasint info;
|
|
+ blasint lda, ldb;
|
|
+ IFLOAT *a, *b;
|
|
+ XFLOAT *buffer;
|
|
+
|
|
+ PRINT_DEBUG_CNAME;
|
|
+
|
|
+ uplo = -1;
|
|
+ transa = -1;
|
|
+ transb = -1;
|
|
+ info = 0;
|
|
+
|
|
+ if (order == CblasColMajor) {
|
|
+ if (Uplo == CblasUpper) uplo = 0;
|
|
+ if (Uplo == CblasLower) uplo = 1;
|
|
+
|
|
+ if (TransA == CblasNoTrans)
|
|
+ transa = 0;
|
|
+ if (TransA == CblasTrans)
|
|
+ transa = 1;
|
|
+
|
|
+ if (TransA == CblasConjNoTrans)
|
|
+ transa = 0;
|
|
+ if (TransA == CblasConjTrans)
|
|
+ transa = 1;
|
|
+
|
|
+ if (TransB == CblasNoTrans)
|
|
+ transb = 0;
|
|
+ if (TransB == CblasTrans)
|
|
+ transb = 1;
|
|
+
|
|
+ if (TransB == CblasConjNoTrans)
|
|
+ transb = 0;
|
|
+ if (TransB == CblasConjTrans)
|
|
+ transb = 1;
|
|
+
|
|
+ a = (void *)A;
|
|
+ b = (void *)B;
|
|
+ lda = LDA;
|
|
+ ldb = LDB;
|
|
+
|
|
+ info = -1;
|
|
+
|
|
+ blasint nrowa;
|
|
+ blasint nrowb;
|
|
+ nrowa = m;
|
|
+ if (transa & 1) nrowa = k;
|
|
+ nrowb = k;
|
|
+ if (transb & 1) nrowb = m;
|
|
+
|
|
+ if (ldc < MAX(1, m))
|
|
+ info = 13;
|
|
+ if (ldb < MAX(1, nrowb))
|
|
+ info = 10;
|
|
+ if (lda < MAX(1, nrowa))
|
|
+ info = 8;
|
|
+ if (k < 0)
|
|
+ info = 5;
|
|
+ if (m < 0)
|
|
+ info = 4;
|
|
+ if (transb < 0)
|
|
+ info = 3;
|
|
+ if (transa < 0)
|
|
+ info = 2;
|
|
+ if (uplo < 0)
|
|
+ info = 1;
|
|
+ }
|
|
+
|
|
+ if (order == CblasRowMajor) {
|
|
+
|
|
+ a = (void *)B;
|
|
+ b = (void *)A;
|
|
+
|
|
+ lda = LDB;
|
|
+ ldb = LDA;
|
|
+
|
|
+ if (Uplo == CblasUpper) uplo = 0;
|
|
+ if (Uplo == CblasLower) uplo = 1;
|
|
+
|
|
+ if (TransB == CblasNoTrans)
|
|
+ transa = 0;
|
|
+ if (TransB == CblasTrans)
|
|
+ transa = 1;
|
|
+
|
|
+ if (TransB == CblasConjNoTrans)
|
|
+ transa = 0;
|
|
+ if (TransB == CblasConjTrans)
|
|
+ transa = 1;
|
|
+
|
|
+ if (TransA == CblasNoTrans)
|
|
+ transb = 0;
|
|
+ if (TransA == CblasTrans)
|
|
+ transb = 1;
|
|
+
|
|
+ if (TransA == CblasConjNoTrans)
|
|
+ transb = 0;
|
|
+ if (TransA == CblasConjTrans)
|
|
+ transb = 1;
|
|
+
|
|
+ info = -1;
|
|
+
|
|
+ blasint ncola;
|
|
+ blasint ncolb;
|
|
+
|
|
+ ncola = m;
|
|
+ if (transa & 1) ncola = k;
|
|
+ ncolb = k;
|
|
+
|
|
+ if (transb & 1) {
|
|
+ ncolb = m;
|
|
+ }
|
|
+
|
|
+ if (ldc < MAX(1,m))
|
|
+ info = 13;
|
|
+ if (ldb < MAX(1, ncolb))
|
|
+ info = 8;
|
|
+ if (lda < MAX(1, ncola))
|
|
+ info = 10;
|
|
+ if (k < 0)
|
|
+ info = 5;
|
|
+ if (m < 0)
|
|
+ info = 4;
|
|
+ if (transb < 0)
|
|
+ info = 2;
|
|
+ if (transa < 0)
|
|
+ info = 3;
|
|
+ if (uplo < 0)
|
|
+ info = 1;
|
|
+ }
|
|
+
|
|
+ if (info >= 0) {
|
|
+ BLASFUNC(xerbla) (ERROR_NAME, &info, sizeof(ERROR_NAME));
|
|
+ return;
|
|
+ }
|
|
+
|
|
+#endif
|
|
+ int buffer_size;
|
|
+ blasint i, j;
|
|
+
|
|
+#ifdef SMP
|
|
+ int nthreads;
|
|
+#endif
|
|
+
|
|
+
|
|
+#ifdef SMP
|
|
+ static int (*gemv_thread[]) (BLASLONG, BLASLONG, FLOAT, IFLOAT *,
|
|
+ BLASLONG, IFLOAT *, BLASLONG, FLOAT,
|
|
+ FLOAT *, BLASLONG, int) = {
|
|
+ sbgemv_thread_n, sbgemv_thread_t,
|
|
+ };
|
|
+#endif
|
|
+ int (*gemv[]) (BLASLONG, BLASLONG, FLOAT, IFLOAT *, BLASLONG,
|
|
+ IFLOAT *, BLASLONG, FLOAT, FLOAT *, BLASLONG) = {
|
|
+ SBGEMV_N, SBGEMV_T,};
|
|
+
|
|
+
|
|
+ if (m == 0)
|
|
+ return;
|
|
+
|
|
+ IDEBUG_START;
|
|
+
|
|
+ const blasint incb = ((transb & 1) == 0) ? 1 : ldb;
|
|
+
|
|
+ if (uplo == 1) {
|
|
+ for (i = 0; i < m; i++) {
|
|
+ j = m - i;
|
|
+
|
|
+ aa = a + i;
|
|
+ bb = b + i * ldb;
|
|
+ if (transa & 1) {
|
|
+ aa = a + lda * i;
|
|
+ }
|
|
+ if (transb & 1)
|
|
+ bb = b + i;
|
|
+ cc = c + i * ldc + i;
|
|
+
|
|
+#if 0
|
|
+ if (beta != ONE)
|
|
+ SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0);
|
|
+
|
|
+ if (alpha == ZERO)
|
|
+ continue;
|
|
+#endif
|
|
+
|
|
+ IDEBUG_START;
|
|
+
|
|
+ buffer_size = j + k + 128 / sizeof(FLOAT);
|
|
+#ifdef WINDOWS_ABI
|
|
+ buffer_size += 160 / sizeof(FLOAT);
|
|
+#endif
|
|
+ // for alignment
|
|
+ buffer_size = (buffer_size + 3) & ~3;
|
|
+ STACK_ALLOC(buffer_size, IFLOAT, buffer);
|
|
+
|
|
+#ifdef SMP
|
|
+
|
|
+ if (1L * j * k < 2304L * GEMM_MULTITHREAD_THRESHOLD)
|
|
+ nthreads = 1;
|
|
+ else
|
|
+ nthreads = num_cpu_avail(2);
|
|
+
|
|
+ if (nthreads == 1) {
|
|
+#endif
|
|
+
|
|
+ if (!(transa & 1))
|
|
+ (gemv[(int)transa]) (j, k, alpha, aa, lda,
|
|
+ bb, incb, beta, cc, 1);
|
|
+ else
|
|
+ (gemv[(int)transa]) (k, j, alpha, aa, lda,
|
|
+ bb, incb, beta, cc, 1);
|
|
+
|
|
+#ifdef SMP
|
|
+ } else {
|
|
+ if (!(transa & 1))
|
|
+ (gemv_thread[(int)transa]) (j, k, alpha, aa,
|
|
+ lda, bb, incb, beta, cc,
|
|
+ 1, nthreads);
|
|
+ else
|
|
+ (gemv_thread[(int)transa]) (k, j, alpha, aa,
|
|
+ lda, bb, incb, beta, cc,
|
|
+ 1, nthreads);
|
|
+
|
|
+ }
|
|
+#endif
|
|
+
|
|
+ STACK_FREE(buffer);
|
|
+ }
|
|
+ } else {
|
|
+
|
|
+ for (i = 0; i < m; i++) {
|
|
+ j = i + 1;
|
|
+
|
|
+ bb = b + i * ldb;
|
|
+ if (transb & 1) {
|
|
+ bb = b + i;
|
|
+ }
|
|
+ cc = c + i * ldc;
|
|
+
|
|
+#if 0
|
|
+ if (beta != ONE)
|
|
+ SCAL_K(l, 0, 0, beta, cc, 1, NULL, 0, NULL, 0);
|
|
+
|
|
+ if (alpha == ZERO)
|
|
+ continue;
|
|
+#endif
|
|
+ IDEBUG_START;
|
|
+
|
|
+ buffer_size = j + k + 128 / sizeof(FLOAT);
|
|
+#ifdef WINDOWS_ABI
|
|
+ buffer_size += 160 / sizeof(FLOAT);
|
|
+#endif
|
|
+ // for alignment
|
|
+ buffer_size = (buffer_size + 3) & ~3;
|
|
+ STACK_ALLOC(buffer_size, IFLOAT, buffer);
|
|
+
|
|
+#ifdef SMP
|
|
+
|
|
+ if (1L * j * k < 2304L * GEMM_MULTITHREAD_THRESHOLD)
|
|
+ nthreads = 1;
|
|
+ else
|
|
+ nthreads = num_cpu_avail(2);
|
|
+
|
|
+ if (nthreads == 1) {
|
|
+#endif
|
|
+
|
|
+ if (!(transa & 1))
|
|
+ (gemv[(int)transa]) (j, k, alpha, a, lda, bb,
|
|
+ incb, beta, cc, 1);
|
|
+ else
|
|
+ (gemv[(int)transa]) (k, j, alpha, a, lda, bb,
|
|
+ incb, beta, cc, 1);
|
|
+
|
|
+#ifdef SMP
|
|
+ } else {
|
|
+ if (!(transa & 1))
|
|
+ (gemv_thread[(int)transa]) (j, k, alpha, a, lda,
|
|
+ bb, incb, beta, cc, 1,
|
|
+ nthreads);
|
|
+ else
|
|
+ (gemv_thread[(int)transa]) (k, j, alpha, a, lda,
|
|
+ bb, incb, beta, cc, 1,
|
|
+ nthreads);
|
|
+ }
|
|
+#endif
|
|
+
|
|
+ STACK_FREE(buffer);
|
|
+ }
|
|
+ }
|
|
+
|
|
+ IDEBUG_END;
|
|
+
|
|
+ return;
|
|
+}
|
|
--
|
|
2.41.0
|
|
|
|
From 9a4c2d61a345866e4540f9d6da87eb881419b411 Mon Sep 17 00:00:00 2001
|
|
From: Honza Horak <hhorak@redhat.com>
|
|
Date: Fri, 9 Feb 2024 09:54:52 +0100
|
|
Subject: [PATCH 3/6] fix type conversion warnings
|
|
|
|
Upstream commit:
|
|
|
|
commit fb99fc2e6e4ec8ecdcfffe1ca1aeb787464d2825
|
|
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
|
|
Date: Wed Feb 7 13:42:08 2024 +0100
|
|
|
|
fix type conversion warnings
|
|
---
|
|
test/compare_sgemm_sbgemm.c | 18 ++++++++++++++----
|
|
1 file changed, 14 insertions(+), 4 deletions(-)
|
|
|
|
diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c
|
|
index cf808b56d..4afa8bf93 100644
|
|
--- a/test/compare_sgemm_sbgemm.c
|
|
+++ b/test/compare_sgemm_sbgemm.c
|
|
@@ -81,6 +81,16 @@ float16to32 (bfloat16_bits f16)
|
|
return f32.v;
|
|
}
|
|
|
|
+float
|
|
+float32to16 (float32_bits f32)
|
|
+{
|
|
+ bfloat16_bits f16;
|
|
+ f16.bits.s = f32.bits.s;
|
|
+ f16.bits.e = f32.bits.e;
|
|
+ f16.bits.m = (uint32_t) f32.bits.m >> 16;
|
|
+ return f32.v;
|
|
+}
|
|
+
|
|
int
|
|
main (int argc, char *argv[])
|
|
{
|
|
@@ -108,16 +118,16 @@ main (int argc, char *argv[])
|
|
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
|
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
|
C[j * k + i] = 0;
|
|
- AA[j * k + i].v = *(uint32_t *) & A[j * k + i] >> 16;
|
|
- BB[j * k + i].v = *(uint32_t *) & B[j * k + i] >> 16;
|
|
+ AA[j * k + i].v = float32to16( A[j * k + i] );
|
|
+ BB[j * k + i].v = float32to16( B[j * k + i] );
|
|
CC[j * k + i] = 0;
|
|
DD[j * k + i] = 0;
|
|
}
|
|
}
|
|
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
|
|
&m, B, &k, &beta, C, &m);
|
|
- SBGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
|
|
- &m, BB, &k, &beta, CC, &m);
|
|
+ SBGEMM (&transA, &transB, &m, &n, &k, &alpha, (bfloat16*) AA,
|
|
+ &m, (bfloat16*)BB, &k, &beta, CC, &m);
|
|
for (i = 0; i < n; i++)
|
|
for (j = 0; j < m; j++)
|
|
if (fabs (CC[i * m + j] - C[i * m + j]) > 1.0)
|
|
--
|
|
2.41.0
|
|
|
|
From 5593507ddbd5d35d088cd4db6285de6b9d84a405 Mon Sep 17 00:00:00 2001
|
|
From: Honza Horak <hhorak@redhat.com>
|
|
Date: Fri, 9 Feb 2024 09:56:11 +0100
|
|
Subject: [PATCH 4/6] fix prototype for c/zaxpby
|
|
|
|
Upstream commit:
|
|
|
|
commit b3fa16345d83b723b8984b78dc6a2bb5d9f3d479
|
|
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
|
|
Date: Thu Feb 8 13:15:34 2024 +0100
|
|
|
|
fix prototype for c/zaxpby
|
|
---
|
|
common_interface.h | 4 ++--
|
|
1 file changed, 2 insertions(+), 2 deletions(-)
|
|
|
|
diff --git a/common_interface.h b/common_interface.h
|
|
index 318827920..1f6cb5f6d 100644
|
|
--- a/common_interface.h
|
|
+++ b/common_interface.h
|
|
@@ -764,8 +764,8 @@ xdouble BLASFUNC(qlamc3)(xdouble *, xdouble *);
|
|
|
|
void BLASFUNC(saxpby) (blasint *, float *, float *, blasint *, float *, float *, blasint *);
|
|
void BLASFUNC(daxpby) (blasint *, double *, double *, blasint *, double *, double *, blasint *);
|
|
-void BLASFUNC(caxpby) (blasint *, float *, float *, blasint *, float *, float *, blasint *);
|
|
-void BLASFUNC(zaxpby) (blasint *, double *, double *, blasint *, double *, double *, blasint *);
|
|
+void BLASFUNC(caxpby) (blasint *, void *, float *, blasint *, void *, float *, blasint *);
|
|
+void BLASFUNC(zaxpby) (blasint *, void *, double *, blasint *, void *, double *, blasint *);
|
|
|
|
void BLASFUNC(somatcopy) (char *, char *, blasint *, blasint *, float *, float *, blasint *, float *, blasint *);
|
|
void BLASFUNC(domatcopy) (char *, char *, blasint *, blasint *, double *, double *, blasint *, double *, blasint *);
|
|
--
|
|
2.41.0
|
|
|
|
From 42b30ed2c54034b2b1dbb15bb9e3e705e704b6a9 Mon Sep 17 00:00:00 2001
|
|
From: Honza Horak <hhorak@redhat.com>
|
|
Date: Fri, 9 Feb 2024 09:56:46 +0100
|
|
Subject: [PATCH 5/6] fix incompatible pointer types
|
|
|
|
Upstream commit:
|
|
|
|
commit 500ac4de5e20596d5cd797d745db97dd0a62ff86
|
|
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
|
|
Date: Thu Feb 8 13:18:34 2024 +0100
|
|
|
|
fix incompatible pointer types
|
|
---
|
|
interface/zaxpby.c | 4 +++-
|
|
1 file changed, 3 insertions(+), 1 deletion(-)
|
|
|
|
diff --git a/interface/zaxpby.c b/interface/zaxpby.c
|
|
index 3a4db7403..e5065270d 100644
|
|
--- a/interface/zaxpby.c
|
|
+++ b/interface/zaxpby.c
|
|
@@ -39,12 +39,14 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
#ifndef CBLAS
|
|
|
|
-void NAME(blasint *N, FLOAT *ALPHA, FLOAT *x, blasint *INCX, FLOAT *BETA, FLOAT *y, blasint *INCY)
|
|
+void NAME(blasint *N, void *VALPHA, FLOAT *x, blasint *INCX, void *VBETA, FLOAT *y, blasint *INCY)
|
|
{
|
|
|
|
blasint n = *N;
|
|
blasint incx = *INCX;
|
|
blasint incy = *INCY;
|
|
+ FLOAT* ALPHA = (FLOAT*) VALPHA;
|
|
+ FLOAT* BETA = (FLOAT*) VBETA;
|
|
|
|
#else
|
|
|
|
--
|
|
2.41.0
|
|
|
|
From 1c525b6e704523912a04fbd026300a2ff95341f3 Mon Sep 17 00:00:00 2001
|
|
From: Honza Horak <hhorak@redhat.com>
|
|
Date: Fri, 9 Feb 2024 15:29:17 +0100
|
|
Subject: [PATCH 6/6] fix sbgemm bfloat16 conversion errors introduced in PR
|
|
4488
|
|
|
|
Upstream commit:
|
|
|
|
commit e9f480111e1d5b6f69c8053f79375b0a4242712f
|
|
Author: Martin Kroeker <martin@ruby.chemie.uni-freiburg.de>
|
|
Date: Wed Feb 7 19:57:18 2024 +0100
|
|
|
|
fix sbgemm bfloat16 conversion errors introduced in PR 4488
|
|
---
|
|
test/compare_sgemm_sbgemm.c | 18 ++++++------------
|
|
1 file changed, 6 insertions(+), 12 deletions(-)
|
|
|
|
diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c
|
|
index 4afa8bf93..bc74233ab 100644
|
|
--- a/test/compare_sgemm_sbgemm.c
|
|
+++ b/test/compare_sgemm_sbgemm.c
|
|
@@ -81,16 +81,6 @@ float16to32 (bfloat16_bits f16)
|
|
return f32.v;
|
|
}
|
|
|
|
-float
|
|
-float32to16 (float32_bits f32)
|
|
-{
|
|
- bfloat16_bits f16;
|
|
- f16.bits.s = f32.bits.s;
|
|
- f16.bits.e = f32.bits.e;
|
|
- f16.bits.m = (uint32_t) f32.bits.m >> 16;
|
|
- return f32.v;
|
|
-}
|
|
-
|
|
int
|
|
main (int argc, char *argv[])
|
|
{
|
|
@@ -110,6 +100,8 @@ main (int argc, char *argv[])
|
|
float C[m * n];
|
|
bfloat16_bits AA[m * k], BB[k * n];
|
|
float DD[m * n], CC[m * n];
|
|
+ bfloat16 atmp,btmp;
|
|
+ blasint one=1;
|
|
|
|
for (j = 0; j < m; j++)
|
|
{
|
|
@@ -118,8 +110,10 @@ main (int argc, char *argv[])
|
|
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
|
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
|
|
C[j * k + i] = 0;
|
|
- AA[j * k + i].v = float32to16( A[j * k + i] );
|
|
- BB[j * k + i].v = float32to16( B[j * k + i] );
|
|
+ sbstobf16_(&one, &A[j*k+i], &one, &atmp, &one);
|
|
+ sbstobf16_(&one, &B[j*k+i], &one, &btmp, &one);
|
|
+ AA[j * k + i].v = atmp;
|
|
+ BB[j * k + i].v = btmp;
|
|
CC[j * k + i] = 0;
|
|
DD[j * k + i] = 0;
|
|
}
|
|
--
|
|
2.41.0
|
|
|