Message ID | 20220620175235.60881-31-richard.henderson@linaro.org |
---|---|
State | Superseded |
Series | target/arm: Scalable Matrix Extension |
On Mon, 20 Jun 2022 at 19:07, Richard Henderson
<richard.henderson@linaro.org> wrote:
>
> Signed-off-by: Richard Henderson <richard.henderson@linaro.org>

> +void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
> +                         void *vpm, void *vst, uint32_t desc)
> +{
> +    intptr_t row, col, oprsz = simd_maxsz(desc);
> +    uint32_t neg = simd_data(desc) << 31;
> +    uint16_t *pn = vpn, *pm = vpm;
> +
> +    bool save_dn = get_default_nan_mode(vst);
> +    set_default_nan_mode(true, vst);
> +
> +    for (row = 0; row < oprsz; ) {
> +        uint16_t pa = pn[H2(row >> 4)];
> +        do {
> +            if (pa & 1) {
> +                void *vza_row = vza + row * sizeof(ARMVectorReg);
> +                uint32_t n = *(uint32_t *)(vzn + row) ^ neg;
> +
> +                for (col = 0; col < oprsz; ) {
> +                    uint16_t pb = pm[H2(col >> 4)];
> +                    do {
> +                        if (pb & 1) {
> +                            uint32_t *a = vza_row + col;
> +                            uint32_t *m = vzm + col;
> +                            *a = float32_muladd(n, *m, *a, 0, vst);
> +                        }
> +                        col += 4;
> +                        pb >>= 4;
> +                    } while (col & 15);
> +                }
> +            }
> +            row += 4;
> +            pa >>= 4;
> +        } while (row & 15);
> +    }

The code for the double version seems straightforward:
row counts from 0 up to the number of rows, and we
do something per row. Why is the single precision version
doing something with an unrolled loop here? It's confusing
that 'oprsz' in the two functions isn't the same thing --
in the double version we divide by the element size, but
here we don't.

> +
> +    set_default_nan_mode(save_dn, vst);
> +}
> +
> +void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
> +                         void *vpm, void *vst, uint32_t desc)
> +{
> +    intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
> +    uint64_t neg = (uint64_t)simd_data(desc) << 63;
> +    uint64_t *za = vza, *zn = vzn, *zm = vzm;
> +    uint8_t *pn = vpn, *pm = vpm;
> +
> +    bool save_dn = get_default_nan_mode(vst);
> +    set_default_nan_mode(true, vst);
> +
> +    for (row = 0; row < oprsz; ++row) {
> +        if (pn[H1(row)] & 1) {
> +            uint64_t *za_row = &za[row * sizeof(ARMVectorReg)];
> +            uint64_t n = zn[row] ^ neg;
> +
> +            for (col = 0; col < oprsz; ++col) {
> +                if (pm[H1(col)] & 1) {
> +                    uint64_t *a = &za_row[col];
> +                    *a = float64_muladd(n, zm[col], *a, 0, vst);
> +                }
> +            }
> +        }
> +    }
> +
> +    set_default_nan_mode(save_dn, vst);
> +}

The pseudocode says that we ignore floating point exceptions
(ie do not accumulate them in the FPSR) -- it passes fpexc == false
to FPMulAdd(). Don't we need to do something special to arrange
for that ?

thanks
-- PMM
On 6/24/22 05:31, Peter Maydell wrote:
> On Mon, 20 Jun 2022 at 19:07, Richard Henderson
> <richard.henderson@linaro.org> wrote:
>>
>> Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
>
>> +void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
>> +                         void *vpm, void *vst, uint32_t desc)
>> +{
>> +    intptr_t row, col, oprsz = simd_maxsz(desc);
>> +    uint32_t neg = simd_data(desc) << 31;
>> +    uint16_t *pn = vpn, *pm = vpm;
>> +
>> +    bool save_dn = get_default_nan_mode(vst);
>> +    set_default_nan_mode(true, vst);
>> +
>> +    for (row = 0; row < oprsz; ) {
>> +        uint16_t pa = pn[H2(row >> 4)];
>> +        do {
>> +            if (pa & 1) {
>> +                void *vza_row = vza + row * sizeof(ARMVectorReg);
>> +                uint32_t n = *(uint32_t *)(vzn + row) ^ neg;
>> +
>> +                for (col = 0; col < oprsz; ) {
>> +                    uint16_t pb = pm[H2(col >> 4)];
>> +                    do {
>> +                        if (pb & 1) {
>> +                            uint32_t *a = vza_row + col;
>> +                            uint32_t *m = vzm + col;
>> +                            *a = float32_muladd(n, *m, *a, 0, vst);
>> +                        }
>> +                        col += 4;
>> +                        pb >>= 4;
>> +                    } while (col & 15);
>> +                }
>> +            }
>> +            row += 4;
>> +            pa >>= 4;
>> +        } while (row & 15);
>> +    }
>
> The code for the double version seems straightforward:
> row counts from 0 up to the number of rows, and we
> do something per row. Why is the single precision version
> doing something with an unrolled loop here? It's confusing
> that 'oprsz' in the two functions isn't the same thing --
> in the double version we divide by the element size, but
> here we don't.

It's all about the predicate addressing. For doubles, the bits are
spaced 8 bits apart, which makes it easy as you see. For singles, the
bits are spaced 4 bits apart, which is inconvenient.

Anyway, just as over in sve_helper.c, I load a uint16_t at a time and
shift to find each predicate bit. So it's not unrolled, exactly.
There's a second loop over predicates. And since this is a matrix op,
we get loops nested 4 deep.

> The pseudocode says that we ignore floating point exceptions
> (ie do not accumulate them in the FPSR) -- it passes fpexc == false
> to FPMulAdd(). Don't we need to do something special to arrange
> for that ?

Oops, somewhere I read that as "do not trap" not "do not accumulate".
But R_TGSKG is very clear on this as accumulate.

r~
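To make the predicate addressing concrete: the predicate bit governing an element of ESZ bytes sits at the element's byte offset within the vector, so one uint16_t of predicate covers 16 bytes of vector -- four fp32 elements or two fp64 elements. The self-contained sketch below mirrors the structure of the sme_fmopa_s inner loop; the names (walk_elements and so on) are invented for illustration, not QEMU APIs.

/*
 * Illustrative sketch (not QEMU code): walk a vector of 'oprsz' bytes
 * with element size 'esz' bytes (4 or 8), consuming one 16-bit
 * predicate word per 16 bytes of vector, the way the fp32 helper does.
 */
#include <stdint.h>
#include <stdio.h>

static void walk_elements(const uint16_t *pred, unsigned oprsz, unsigned esz)
{
    for (unsigned off = 0; off < oprsz; ) {
        uint16_t p = pred[off >> 4];      /* 16 predicate bits per 16 bytes */
        do {
            if (p & 1) {
                printf("esz=%u: element at byte offset %u is active\n",
                       esz, off);
            }
            off += esz;                   /* advance one element ...        */
            p >>= esz;                    /* ... and esz predicate bits     */
        } while (off & 15);
    }
}

int main(void)
{
    /*
     * Predicate bits 0, 8 and 20 set for a 32-byte vector.  Read as .S
     * (esz=4) they govern byte offsets 0, 8 and 20; read as .D (esz=8)
     * only bits at multiples of 8 matter, so offsets 0 and 8, and bit 20
     * is skipped over.
     */
    uint16_t pred[2] = { (1u << 0) | (1u << 8), (1u << 4) };

    walk_elements(pred, 32, 4);   /* prints offsets 0, 8, 20 */
    walk_elements(pred, 32, 8);   /* prints offsets 0, 8     */
    return 0;
}

For esz=8 the shift-by-8 per element means each predicate byte yields exactly one usable bit, which is why the double-precision helper can simply index pn[H1(row)] per row instead of carrying a second loop.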
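On the exceptions point, for readers without the Arm pseudocode to hand: fpexc == false means the operation's floating-point exceptions are not signalled, i.e. not accumulated into FPSR. The host-C sketch below uses <fenv.h> rather than QEMU's softfloat, so it is only an analogy of that behaviour, not the change eventually made to this patch: the flags raised by the arithmetic are discarded by restoring the previously accumulated exception state around it.

/*
 * Host-C analogue of "fpexc == false": snapshot the accumulated FP
 * exception flags, do the fused multiply-add, then restore the snapshot
 * so anything the operation raised is thrown away.  Build with -lm.
 */
#include <fenv.h>
#include <math.h>
#include <stdio.h>

#pragma STDC FENV_ACCESS ON

static double fmadd_no_fpexc(double n, double m, double a)
{
    fexcept_t saved;

    fegetexceptflag(&saved, FE_ALL_EXCEPT);  /* snapshot accumulated flags */
    double r = fma(n, m, a);                 /* may raise e.g. FE_INEXACT  */
    fesetexceptflag(&saved, FE_ALL_EXCEPT);  /* discard what fma() raised  */
    return r;
}

int main(void)
{
    volatile double n = 1.0, m = 3.0, a = 0.1;  /* 3.0 + 0.1 is inexact */

    feclearexcept(FE_ALL_EXCEPT);
    double r = fmadd_no_fpexc(n, m, a);
    printf("result %.17g, FE_INEXACT %s\n", r,
           fetestexcept(FE_INEXACT) ? "set" : "clear");  /* expect "clear" */
    return 0;
}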
diff --git a/target/arm/helper-sme.h b/target/arm/helper-sme.h
index 6f0fce7e2c..727095a3eb 100644
--- a/target/arm/helper-sme.h
+++ b/target/arm/helper-sme.h
@@ -119,3 +119,8 @@ DEF_HELPER_FLAGS_5(sme_addha_s, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_5(sme_addva_s, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_5(sme_addha_d, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_5(sme_addva_d, TCG_CALL_NO_RWG, void, ptr, ptr, ptr, ptr, i32)
+
+DEF_HELPER_FLAGS_7(sme_fmopa_s, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, ptr, i32)
+DEF_HELPER_FLAGS_7(sme_fmopa_d, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, ptr, ptr, i32)
diff --git a/target/arm/sme.decode b/target/arm/sme.decode
index 8cb6c4053c..ba4774d174 100644
--- a/target/arm/sme.decode
+++ b/target/arm/sme.decode
@@ -64,3 +64,12 @@ ADDHA_s 11000000 10 01000 0 ... ... ..... 000 .. @adda_32
 ADDVA_s 11000000 10 01000 1 ... ... ..... 000 .. @adda_32
 ADDHA_d 11000000 11 01000 0 ... ... ..... 00 ... @adda_64
 ADDVA_d 11000000 11 01000 1 ... ... ..... 00 ... @adda_64
+
+### SME Outer Product
+
+&op zad zn zm pm pn sub:bool
+@op_32 ........ ... zm:5 pm:3 pn:3 zn:5 sub:1 .. zad:2 &op
+@op_64 ........ ... zm:5 pm:3 pn:3 zn:5 sub:1 . zad:3 &op
+
+FMOPA_s 10000000 100 ..... ... ... ..... . 00 .. @op_32
+FMOPA_d 10000000 110 ..... ... ... ..... . 0 ... @op_64
diff --git a/target/arm/sme_helper.c b/target/arm/sme_helper.c
index 799e44c047..62d9690cae 100644
--- a/target/arm/sme_helper.c
+++ b/target/arm/sme_helper.c
@@ -25,6 +25,7 @@
 #include "exec/cpu_ldst.h"
 #include "exec/exec-all.h"
 #include "qemu/int128.h"
+#include "fpu/softfloat.h"
 #include "vec_internal.h"
 #include "sve_ldst_internal.h"
@@ -897,3 +898,69 @@ void HELPER(sme_addva_d)(void *vzda, void *vzn, void *vpn,
         }
     }
 }
+
+void HELPER(sme_fmopa_s)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, void *vst, uint32_t desc)
+{
+    intptr_t row, col, oprsz = simd_maxsz(desc);
+    uint32_t neg = simd_data(desc) << 31;
+    uint16_t *pn = vpn, *pm = vpm;
+
+    bool save_dn = get_default_nan_mode(vst);
+    set_default_nan_mode(true, vst);
+
+    for (row = 0; row < oprsz; ) {
+        uint16_t pa = pn[H2(row >> 4)];
+        do {
+            if (pa & 1) {
+                void *vza_row = vza + row * sizeof(ARMVectorReg);
+                uint32_t n = *(uint32_t *)(vzn + row) ^ neg;
+
+                for (col = 0; col < oprsz; ) {
+                    uint16_t pb = pm[H2(col >> 4)];
+                    do {
+                        if (pb & 1) {
+                            uint32_t *a = vza_row + col;
+                            uint32_t *m = vzm + col;
+                            *a = float32_muladd(n, *m, *a, 0, vst);
+                        }
+                        col += 4;
+                        pb >>= 4;
+                    } while (col & 15);
+                }
+            }
+            row += 4;
+            pa >>= 4;
+        } while (row & 15);
+    }
+
+    set_default_nan_mode(save_dn, vst);
+}
+
+void HELPER(sme_fmopa_d)(void *vza, void *vzn, void *vzm, void *vpn,
+                         void *vpm, void *vst, uint32_t desc)
+{
+    intptr_t row, col, oprsz = simd_oprsz(desc) / 8;
+    uint64_t neg = (uint64_t)simd_data(desc) << 63;
+    uint64_t *za = vza, *zn = vzn, *zm = vzm;
+    uint8_t *pn = vpn, *pm = vpm;
+
+    bool save_dn = get_default_nan_mode(vst);
+    set_default_nan_mode(true, vst);
+
+    for (row = 0; row < oprsz; ++row) {
+        if (pn[H1(row)] & 1) {
+            uint64_t *za_row = &za[row * sizeof(ARMVectorReg)];
+            uint64_t n = zn[row] ^ neg;
+
+            for (col = 0; col < oprsz; ++col) {
+                if (pm[H1(col)] & 1) {
+                    uint64_t *a = &za_row[col];
+                    *a = float64_muladd(n, zm[col], *a, 0, vst);
+                }
+            }
+        }
+    }
+
+    set_default_nan_mode(save_dn, vst);
+}
diff --git a/target/arm/translate-sme.c b/target/arm/translate-sme.c
index e9676b2415..e6e4541e76 100644
--- a/target/arm/translate-sme.c
+++ b/target/arm/translate-sme.c
@@ -273,3 +273,36 @@ TRANS_FEAT(ADDHA_s, aa64_sme, do_adda, a, MO_32, gen_helper_sme_addha_s)
 TRANS_FEAT(ADDVA_s, aa64_sme, do_adda, a, MO_32, gen_helper_sme_addva_s)
 TRANS_FEAT(ADDHA_d, aa64_sme_i16i64, do_adda, a, MO_64, gen_helper_sme_addha_d)
 TRANS_FEAT(ADDVA_d, aa64_sme_i16i64, do_adda, a, MO_64, gen_helper_sme_addva_d)
+
+static bool do_outprod_fpst(DisasContext *s, arg_op *a, MemOp esz,
+                            gen_helper_gvec_5_ptr *fn)
+{
+    uint32_t desc = simd_desc(s->svl, s->svl, a->sub);
+    TCGv_ptr za, zn, zm, pn, pm, fpst;
+
+    if (!sme_smza_enabled_check(s)) {
+        return true;
+    }
+
+    /* Sum XZR+zad to find ZAd. */
+    za = get_tile_rowcol(s, esz, 31, a->zad, false);
+    zn = vec_full_reg_ptr(s, a->zn);
+    zm = vec_full_reg_ptr(s, a->zm);
+    pn = pred_full_reg_ptr(s, a->pn);
+    pm = pred_full_reg_ptr(s, a->pm);
+    fpst = fpstatus_ptr(FPST_FPCR);
+
+    fn(za, zn, zm, pn, pm, fpst, tcg_constant_i32(desc));
+
+    tcg_temp_free_ptr(za);
+    tcg_temp_free_ptr(zn);
+    tcg_temp_free_ptr(pn);
+    tcg_temp_free_ptr(pm);
+    tcg_temp_free_ptr(fpst);
+    return true;
+}
+
+TRANS_FEAT(FMOPA_s, aa64_sme, do_outprod_fpst,
+           a, MO_32, gen_helper_sme_fmopa_s)
+TRANS_FEAT(FMOPA_d, aa64_sme_f64f64, do_outprod_fpst,
+           a, MO_64, gen_helper_sme_fmopa_d)
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
---
 target/arm/helper-sme.h    |  5 +++
 target/arm/sme.decode      |  9 +++++
 target/arm/sme_helper.c    | 67 ++++++++++++++++++++++++++++++++++++++
 target/arm/translate-sme.c | 33 +++++++++++++++++++
 4 files changed, 114 insertions(+)