
[6/8] target/arm: Prepare bfdotadd() callers for FEAT_EBF support

Message ID 20240730160306.2959745-7-peter.maydell@linaro.org
State Accepted
Series target/arm: Implement FEAT_EBF16

Commit Message

Peter Maydell July 30, 2024, 4:03 p.m. UTC
We use bfdotadd() in four callsites for various helper functions. Currently
this all assumes that we have the FPCR.EBF=0 semantics. For FPCR.EBF=1
we will need to:
 * call a different routine to bfdotadd() because we need to do a
   fused multiply-add rather than separate multiply and add steps
 * use a different float_status that honours the FPCR rounding mode
   and denormal-flushing fields
 * pass in an extra float_status that has been set up to perform
   round-to-odd rounding

To prepare for this, refactor all the callsites so that instead of
   for (...) {
       x = bfdotadd(...);
   }

they are:
   float_status fpst, fpst_odd;
   if (is_ebf(env, &fpst, &fpst_odd)) {
       for (...) {
           x = bfdotadd_ebf(..., &fpst, &fpst_odd);
       }
   } else {
       for (...) {
           x = bfdotadd(..., &fpst);
       }
   }

For the moment the is_ebf() function always returns false, sets up
fpst for EBF=0 semantics and never sets up fpst_odd; bfdotadd_ebf()
will assert if called. We'll fill in the handling for EBF=1 in the
next commit.
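
As a rough sketch only (the helper names below are hypothetical, and the
real EBF=1 setup in the next commit may differ): a round-to-odd status can
be derived from a base float_status with the softfloat helpers, and the
fused step would use float32_muladd() rather than separate
float32_mul()/float32_add() calls:

   #include "qemu/osdep.h"
   #include "fpu/softfloat.h"
   #include "fpu/softfloat-helpers.h"

   /* Illustration only: hypothetical helpers, not part of this series. */
   static void setup_round_to_odd(float_status *oddstatusp,
                                  const float_status *basep)
   {
       *oddstatusp = *basep;   /* start from the base configuration */
       set_float_rounding_mode(float_round_to_odd, oddstatusp);
   }

   static float32 fused_madd_sketch(float32 a, float32 b, float32 c,
                                    float_status *fpst)
   {
       /* one rounding step, unlike float32_mul() followed by float32_add() */
       return float32_muladd(a, b, c, 0, fpst);
   }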

This change should be a zero-behaviour-change refactor.

Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
---
 target/arm/tcg/vec_internal.h |  37 ++++++++-
 target/arm/tcg/sme_helper.c   |  74 ++++++++++++------
 target/arm/tcg/vec_helper.c   | 141 +++++++++++++++++++++++++---------
 3 files changed, 192 insertions(+), 60 deletions(-)

Comments

Richard Henderson July 31, 2024, 1:43 a.m. UTC | #1
On 7/31/24 02:03, Peter Maydell wrote:
> We use bfdotadd() in four callsites for various helper functions. Currently
> this all assumes that we have the FPCR.EBF=0 semantics. For FPCR.EBF=1
> we will need to:
>   * call a different routine to bfdotadd() because we need to do a
>     fused multiply-add rather than separate multiply and add steps
>   * use a different float_status that honours the FPCR rounding mode
>     and denormal-flushing fields
>   * pass in an extra float_status that has been set up to perform
>     round-to-odd rounding
> 
> To prepare for this, refactor all the callsites so that instead of
>     for (...) {
>         x = bfdotadd(...);
>     }
> 
> they are:
>     float_status fpst, fpst_odd;
>     if (is_ebf(env, &fpst, &fpst_odd)) {
>         for (...) {
>             x = bfdotadd_ebf(..., &fpst, &fpst_odd);
>         }
>     } else {
>         for (...) {
>             x = bfdotadd(..., &fpst);
>         }
>     }
> 
> For the moment the is_ebf() function always returns false, sets up
> fpst for EBF=0 semantics and never sets up fpst_odd; bfdotadd_ebf()
> will assert if called. We'll fill in the handling for EBF=1 in the
> next commit.
> 
> This change should be a zero-behaviour-change refactor.
> 
> Signed-off-by: Peter Maydell <peter.maydell@linaro.org>
> ---
>   target/arm/tcg/vec_internal.h |  37 ++++++++-
>   target/arm/tcg/sme_helper.c   |  74 ++++++++++++------
>   target/arm/tcg/vec_helper.c   | 141 +++++++++++++++++++++++++---------
>   3 files changed, 192 insertions(+), 60 deletions(-)

Reviewed-by: Richard Henderson <richard.henderson@linaro.org>

r~
Richard Henderson July 31, 2024, 1:48 a.m. UTC | #2
On 7/31/24 02:03, Peter Maydell wrote:
> @@ -2790,7 +2790,7 @@ DO_MMLA_B(gvec_usmmla_b, do_usmmla_b)
>    * BFloat16 Dot Product
>    */
>   
> -float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2)
> +bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp)
>   {
>       /* FPCR is ignored for BFDOT and BFMMLA. */
>       float_status bf_status = {
> @@ -2800,29 +2800,50 @@ float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2)
>           .flush_inputs_to_zero = true,
>           .default_nan_mode = true,
>       };
> +
> +    *statusp = bf_status;
> +    return false;
> +}

Looking at the next patch, I think dropping the local variable is better.

   *statusp = (float_status){
       ...
   };


r~
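
For reference, a minimal standalone C sketch of the compound-literal
assignment idiom suggested above; the struct and field names here are
illustrative stand-ins, not QEMU's float_status:

   #include <stdbool.h>
   #include <stdio.h>

   /* Illustrative stand-in struct; not QEMU's float_status. */
   typedef struct {
       bool flush_to_zero;
       bool default_nan_mode;
   } status_t;

   /* Fill in the out-parameter in one step, with no named local variable. */
   static void init_status(status_t *statusp)
   {
       *statusp = (status_t){
           .flush_to_zero = true,
           .default_nan_mode = true,
       };
   }

   int main(void)
   {
       status_t s;
       init_status(&s);
       printf("%d %d\n", s.flush_to_zero, s.default_nan_mode);
       return 0;
   }

Compared with declaring a local bf_status and then copying it into
*statusp, this initializes the out-parameter directly in one step.
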
Peter Maydell July 31, 2024, 12:32 p.m. UTC | #3
On Wed, 31 Jul 2024 at 02:48, Richard Henderson
<richard.henderson@linaro.org> wrote:
>
> On 7/31/24 02:03, Peter Maydell wrote:
> > @@ -2790,7 +2790,7 @@ DO_MMLA_B(gvec_usmmla_b, do_usmmla_b)
> >    * BFloat16 Dot Product
> >    */
> >
> > -float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2)
> > +bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp)
> >   {
> >       /* FPCR is ignored for BFDOT and BFMMLA. */
> >       float_status bf_status = {
> > @@ -2800,29 +2800,50 @@ float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2)
> >           .flush_inputs_to_zero = true,
> >           .default_nan_mode = true,
> >       };
> > +
> > +    *statusp = bf_status;
> > +    return false;
> > +}
>
> Looking at the next patch, I think dropping the local variable is better.
>
>    *statusp = (float_status){
>        ...
>    };

Yes, I agree; I've updated this patch and the next accordingly.

-- PMM

Patch

diff --git a/target/arm/tcg/vec_internal.h b/target/arm/tcg/vec_internal.h
index 3ca1b94ccf9..094f5c169ca 100644
--- a/target/arm/tcg/vec_internal.h
+++ b/target/arm/tcg/vec_internal.h
@@ -223,13 +223,46 @@  int64_t do_sqrdmlah_d(int64_t, int64_t, int64_t, bool, bool);
  * bfdotadd:
  * @sum: addend
  * @e1, @e2: multiplicand vectors
+ * @fpst: floating-point status to use
  *
  * BFloat16 2-way dot product of @e1 & @e2, accumulating with @sum.
  * The @e1 and @e2 operands correspond to the 32-bit source vector
  * slots and contain two Bfloat16 values each.
  *
- * Corresponds to the ARM pseudocode function BFDotAdd.
+ * Corresponds to the ARM pseudocode function BFDotAdd, specialized
+ * for the FPCR.EBF == 0 case.
  */
-float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2);
+float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2, float_status *fpst);
+/**
+ * bfdotadd_ebf:
+ * @sum: addend
+ * @e1, @e2: multiplicand vectors
+ * @fpst: floating-point status to use
+ * @fpst_odd: floating-point status to use for round-to-odd operations
+ *
+ * BFloat16 2-way dot product of @e1 & @e2, accumulating with @sum.
+ * The @e1 and @e2 operands correspond to the 32-bit source vector
+ * slots and contain two Bfloat16 values each.
+ *
+ * Corresponds to the ARM pseudocode function BFDotAdd, specialized
+ * for the FPCR.EBF == 1 case.
+ */
+float32 bfdotadd_ebf(float32 sum, uint32_t e1, uint32_t e2,
+                     float_status *fpst, float_status *fpst_odd);
+
+/**
+ * is_ebf:
+ * @env: CPU state
+ * @statusp: pointer to floating point status to fill in
+ * @oddstatusp: pointer to floating point status to fill in for round-to-odd
+ *
+ * Determine whether a BFDotAdd operation should use FPCR.EBF = 0
+ * or FPCR.EBF = 1 semantics. On return, has initialized *statusp
+ * and *oddstatusp to suitable float_status arguments to use with either
+ * bfdotadd() or bfdotadd_ebf().
+ * Returns true for EBF = 1, false for EBF = 0. (The caller should use this
+ * to decide whether to call bfdotadd() or bfdotadd_ebf().)
+ */
+bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp);
 
 #endif /* TARGET_ARM_VEC_INTERNAL_H */
diff --git a/target/arm/tcg/sme_helper.c b/target/arm/tcg/sme_helper.c
index f172225b2f2..e3fbfa98fa5 100644
--- a/target/arm/tcg/sme_helper.c
+++ b/target/arm/tcg/sme_helper.c
@@ -1086,32 +1086,62 @@  void HELPER(sme_bfmopa)(CPUARMState *env, void *vza, void *vzn, void *vzm,
     intptr_t row, col, oprsz = simd_maxsz(desc);
     uint32_t neg = simd_data(desc) * 0x80008000u;
     uint16_t *pn = vpn, *pm = vpm;
+    float_status fpst, fpst_odd;
 
-    for (row = 0; row < oprsz; ) {
-        uint16_t prow = pn[H2(row >> 4)];
-        do {
-            void *vza_row = vza + tile_vslice_offset(row);
-            uint32_t n = *(uint32_t *)(vzn + H1_4(row));
+    if (is_ebf(env, &fpst, &fpst_odd)) {
+        for (row = 0; row < oprsz; ) {
+            uint16_t prow = pn[H2(row >> 4)];
+            do {
+                void *vza_row = vza + tile_vslice_offset(row);
+                uint32_t n = *(uint32_t *)(vzn + H1_4(row));
 
-            n = f16mop_adj_pair(n, prow, neg);
+                n = f16mop_adj_pair(n, prow, neg);
 
-            for (col = 0; col < oprsz; ) {
-                uint16_t pcol = pm[H2(col >> 4)];
-                do {
-                    if (prow & pcol & 0b0101) {
-                        uint32_t *a = vza_row + H1_4(col);
-                        uint32_t m = *(uint32_t *)(vzm + H1_4(col));
+                for (col = 0; col < oprsz; ) {
+                    uint16_t pcol = pm[H2(col >> 4)];
+                    do {
+                        if (prow & pcol & 0b0101) {
+                            uint32_t *a = vza_row + H1_4(col);
+                            uint32_t m = *(uint32_t *)(vzm + H1_4(col));
 
-                        m = f16mop_adj_pair(m, pcol, 0);
-                        *a = bfdotadd(*a, n, m);
-                    }
-                    col += 4;
-                    pcol >>= 4;
-                } while (col & 15);
-            }
-            row += 4;
-            prow >>= 4;
-        } while (row & 15);
+                            m = f16mop_adj_pair(m, pcol, 0);
+                            *a = bfdotadd_ebf(*a, n, m, &fpst, &fpst_odd);
+                        }
+                        col += 4;
+                        pcol >>= 4;
+                    } while (col & 15);
+                }
+                row += 4;
+                prow >>= 4;
+            } while (row & 15);
+        }
+    } else {
+        for (row = 0; row < oprsz; ) {
+            uint16_t prow = pn[H2(row >> 4)];
+            do {
+                void *vza_row = vza + tile_vslice_offset(row);
+                uint32_t n = *(uint32_t *)(vzn + H1_4(row));
+
+                n = f16mop_adj_pair(n, prow, neg);
+
+                for (col = 0; col < oprsz; ) {
+                    uint16_t pcol = pm[H2(col >> 4)];
+                    do {
+                        if (prow & pcol & 0b0101) {
+                            uint32_t *a = vza_row + H1_4(col);
+                            uint32_t m = *(uint32_t *)(vzm + H1_4(col));
+
+                            m = f16mop_adj_pair(m, pcol, 0);
+                            *a = bfdotadd(*a, n, m, &fpst);
+                        }
+                        col += 4;
+                        pcol >>= 4;
+                    } while (col & 15);
+                }
+                row += 4;
+                prow >>= 4;
+            } while (row & 15);
+        }
     }
 }
 
diff --git a/target/arm/tcg/vec_helper.c b/target/arm/tcg/vec_helper.c
index 77efb5f47d8..baf04a0561b 100644
--- a/target/arm/tcg/vec_helper.c
+++ b/target/arm/tcg/vec_helper.c
@@ -2790,7 +2790,7 @@  DO_MMLA_B(gvec_usmmla_b, do_usmmla_b)
  * BFloat16 Dot Product
  */
 
-float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2)
+bool is_ebf(CPUARMState *env, float_status *statusp, float_status *oddstatusp)
 {
     /* FPCR is ignored for BFDOT and BFMMLA. */
     float_status bf_status = {
@@ -2800,29 +2800,50 @@  float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2)
         .flush_inputs_to_zero = true,
         .default_nan_mode = true,
     };
+
+    *statusp = bf_status;
+    return false;
+}
+
+float32 bfdotadd(float32 sum, uint32_t e1, uint32_t e2, float_status *fpst)
+{
     float32 t1, t2;
 
     /*
      * Extract each BFloat16 from the element pair, and shift
      * them such that they become float32.
      */
-    t1 = float32_mul(e1 << 16, e2 << 16, &bf_status);
-    t2 = float32_mul(e1 & 0xffff0000u, e2 & 0xffff0000u, &bf_status);
-    t1 = float32_add(t1, t2, &bf_status);
-    t1 = float32_add(sum, t1, &bf_status);
+    t1 = float32_mul(e1 << 16, e2 << 16, fpst);
+    t2 = float32_mul(e1 & 0xffff0000u, e2 & 0xffff0000u, fpst);
+    t1 = float32_add(t1, t2, fpst);
+    t1 = float32_add(sum, t1, fpst);
 
     return t1;
 }
 
+float32 bfdotadd_ebf(float32 sum, uint32_t e1, uint32_t e2,
+                     float_status *fpst, float_status *fpst_odd)
+{
+    g_assert_not_reached();
+}
+
 void HELPER(gvec_bfdot)(void *vd, void *vn, void *vm, void *va,
                         void *envp, uint32_t desc)
 {
+    CPUARMState *env = envp;
     intptr_t i, opr_sz = simd_oprsz(desc);
     float32 *d = vd, *a = va;
     uint32_t *n = vn, *m = vm;
+    float_status fpst, fpst_odd;
 
-    for (i = 0; i < opr_sz / 4; ++i) {
-        d[i] = bfdotadd(a[i], n[i], m[i]);
+    if (is_ebf(env, &fpst, &fpst_odd)) {
+        for (i = 0; i < opr_sz / 4; ++i) {
+            d[i] = bfdotadd_ebf(a[i], n[i], m[i], &fpst, &fpst_odd);
+        }
+    } else {
+        for (i = 0; i < opr_sz / 4; ++i) {
+            d[i] = bfdotadd(a[i], n[i], m[i], &fpst);
+        }
     }
     clear_tail(d, opr_sz, simd_maxsz(desc));
 }
@@ -2830,18 +2851,30 @@  void HELPER(gvec_bfdot)(void *vd, void *vn, void *vm, void *va,
 void HELPER(gvec_bfdot_idx)(void *vd, void *vn, void *vm,
                             void *va, void *envp, uint32_t desc)
 {
+    CPUARMState *env = envp;
     intptr_t i, j, opr_sz = simd_oprsz(desc);
     intptr_t index = simd_data(desc);
     intptr_t elements = opr_sz / 4;
     intptr_t eltspersegment = MIN(16 / 4, elements);
     float32 *d = vd, *a = va;
     uint32_t *n = vn, *m = vm;
+    float_status fpst, fpst_odd;
 
-    for (i = 0; i < elements; i += eltspersegment) {
-        uint32_t m_idx = m[i + H4(index)];
+    if (is_ebf(env, &fpst, &fpst_odd)) {
+        for (i = 0; i < elements; i += eltspersegment) {
+            uint32_t m_idx = m[i + H4(index)];
 
-        for (j = i; j < i + eltspersegment; j++) {
-            d[j] = bfdotadd(a[j], n[j], m_idx);
+            for (j = i; j < i + eltspersegment; j++) {
+                d[j] = bfdotadd_ebf(a[j], n[j], m_idx, &fpst, &fpst_odd);
+            }
+        }
+    } else {
+        for (i = 0; i < elements; i += eltspersegment) {
+            uint32_t m_idx = m[i + H4(index)];
+
+            for (j = i; j < i + eltspersegment; j++) {
+                d[j] = bfdotadd(a[j], n[j], m_idx, &fpst);
+            }
         }
     }
     clear_tail(d, opr_sz, simd_maxsz(desc));
@@ -2850,40 +2883,76 @@  void HELPER(gvec_bfdot_idx)(void *vd, void *vn, void *vm,
 void HELPER(gvec_bfmmla)(void *vd, void *vn, void *vm, void *va,
                          void *envp, uint32_t desc)
 {
+    CPUARMState *env = envp;
     intptr_t s, opr_sz = simd_oprsz(desc);
     float32 *d = vd, *a = va;
     uint32_t *n = vn, *m = vm;
+    float_status fpst, fpst_odd;
 
-    for (s = 0; s < opr_sz / 4; s += 4) {
-        float32 sum00, sum01, sum10, sum11;
+    if (is_ebf(env, &fpst, &fpst_odd)) {
+        for (s = 0; s < opr_sz / 4; s += 4) {
+            float32 sum00, sum01, sum10, sum11;
 
-        /*
-         * Process the entire segment at once, writing back the
-         * results only after we've consumed all of the inputs.
-         *
-         * Key to indices by column:
-         *               i   j           i   k             j   k
-         */
-        sum00 = a[s + H4(0 + 0)];
-        sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)]);
-        sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)]);
+            /*
+             * Process the entire segment at once, writing back the
+             * results only after we've consumed all of the inputs.
+             *
+             * Key to indices by column:
+             *               i   j               i   k             j   k
+             */
+            sum00 = a[s + H4(0 + 0)];
+            sum00 = bfdotadd_ebf(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)], &fpst, &fpst_odd);
+            sum00 = bfdotadd_ebf(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)], &fpst, &fpst_odd);
 
-        sum01 = a[s + H4(0 + 1)];
-        sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)]);
-        sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)]);
+            sum01 = a[s + H4(0 + 1)];
+            sum01 = bfdotadd_ebf(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)], &fpst, &fpst_odd);
+            sum01 = bfdotadd_ebf(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)], &fpst, &fpst_odd);
 
-        sum10 = a[s + H4(2 + 0)];
-        sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)]);
-        sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)]);
+            sum10 = a[s + H4(2 + 0)];
+            sum10 = bfdotadd_ebf(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)], &fpst, &fpst_odd);
+            sum10 = bfdotadd_ebf(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)], &fpst, &fpst_odd);
 
-        sum11 = a[s + H4(2 + 1)];
-        sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)]);
-        sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)]);
+            sum11 = a[s + H4(2 + 1)];
+            sum11 = bfdotadd_ebf(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)], &fpst, &fpst_odd);
+            sum11 = bfdotadd_ebf(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)], &fpst, &fpst_odd);
 
-        d[s + H4(0 + 0)] = sum00;
-        d[s + H4(0 + 1)] = sum01;
-        d[s + H4(2 + 0)] = sum10;
-        d[s + H4(2 + 1)] = sum11;
+            d[s + H4(0 + 0)] = sum00;
+            d[s + H4(0 + 1)] = sum01;
+            d[s + H4(2 + 0)] = sum10;
+            d[s + H4(2 + 1)] = sum11;
+        }
+    } else {
+        for (s = 0; s < opr_sz / 4; s += 4) {
+            float32 sum00, sum01, sum10, sum11;
+
+            /*
+             * Process the entire segment at once, writing back the
+             * results only after we've consumed all of the inputs.
+             *
+             * Key to indices by column:
+             *               i   j           i   k             j   k
+             */
+            sum00 = a[s + H4(0 + 0)];
+            sum00 = bfdotadd(sum00, n[s + H4(0 + 0)], m[s + H4(0 + 0)], &fpst);
+            sum00 = bfdotadd(sum00, n[s + H4(0 + 1)], m[s + H4(0 + 1)], &fpst);
+
+            sum01 = a[s + H4(0 + 1)];
+            sum01 = bfdotadd(sum01, n[s + H4(0 + 0)], m[s + H4(2 + 0)], &fpst);
+            sum01 = bfdotadd(sum01, n[s + H4(0 + 1)], m[s + H4(2 + 1)], &fpst);
+
+            sum10 = a[s + H4(2 + 0)];
+            sum10 = bfdotadd(sum10, n[s + H4(2 + 0)], m[s + H4(0 + 0)], &fpst);
+            sum10 = bfdotadd(sum10, n[s + H4(2 + 1)], m[s + H4(0 + 1)], &fpst);
+
+            sum11 = a[s + H4(2 + 1)];
+            sum11 = bfdotadd(sum11, n[s + H4(2 + 0)], m[s + H4(2 + 0)], &fpst);
+            sum11 = bfdotadd(sum11, n[s + H4(2 + 1)], m[s + H4(2 + 1)], &fpst);
+
+            d[s + H4(0 + 0)] = sum00;
+            d[s + H4(0 + 1)] = sum01;
+            d[s + H4(2 + 0)] = sum10;
+            d[s + H4(2 + 1)] = sum11;
+        }
     }
     clear_tail(d, opr_sz, simd_maxsz(desc));
 }