diff mbox series

[v3,64/81] target/arm: Implement SVE2 complex integer multiply-add (indexed)

Message ID 20200918183751.2787647-65-richard.henderson@linaro.org
State Superseded
Headers show
Series target/arm: Implement SVE2 | expand

Commit Message

Richard Henderson Sept. 18, 2020, 6:37 p.m. UTC
Signed-off-by: Richard Henderson <richard.henderson@linaro.org>
---
 target/arm/helper-sve.h    |   9 +++
 target/arm/sve.decode      |  12 ++++
 target/arm/sve_helper.c    | 142 +++++++++++++++++++++++++++++++------
 target/arm/translate-sve.c |  38 +++++++---
 4 files changed, 169 insertions(+), 32 deletions(-)
diff mbox series

Patch

diff --git a/target/arm/helper-sve.h b/target/arm/helper-sve.h
index 90872fa0e6..671b3a8804 100644
--- a/target/arm/helper-sve.h
+++ b/target/arm/helper-sve.h
@@ -2715,3 +2715,12 @@  DEF_HELPER_FLAGS_5(sve2_umlsl_idx_s, TCG_CALL_NO_RWG,
                    void, ptr, ptr, ptr, ptr, i32)
 DEF_HELPER_FLAGS_5(sve2_umlsl_idx_d, TCG_CALL_NO_RWG,
                    void, ptr, ptr, ptr, ptr, i32)
+
+DEF_HELPER_FLAGS_5(sve2_cmla_idx_h, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, i32)
+DEF_HELPER_FLAGS_5(sve2_cmla_idx_s, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, i32)
+DEF_HELPER_FLAGS_5(sve2_sqrdcmlah_idx_h, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, i32)
+DEF_HELPER_FLAGS_5(sve2_sqrdcmlah_idx_s, TCG_CALL_NO_RWG,
+                   void, ptr, ptr, ptr, ptr, i32)
diff --git a/target/arm/sve.decode b/target/arm/sve.decode
index da77ad689f..e8011fe91b 100644
--- a/target/arm/sve.decode
+++ b/target/arm/sve.decode
@@ -825,6 +825,18 @@  SQDMLSLB_zzxw_d 01000100 .. 1 ..... 0011.0 ..... .....          @rrxw_d
 SQDMLSLT_zzxw_s 01000100 .. 1 ..... 0011.1 ..... .....          @rrxw_s
 SQDMLSLT_zzxw_d 01000100 .. 1 ..... 0011.1 ..... .....          @rrxw_d
 
+# SVE2 complex integer multiply-add (indexed)
+CMLA_zzxz_h     01000100 10 1 index:2 rm:3 0110 rot:2 rn:5 rd:5 \
+                ra=%reg_movprfx
+CMLA_zzxz_s     01000100 11 1 index:1 rm:4 0110 rot:2 rn:5 rd:5 \
+                ra=%reg_movprfx
+
+# SVE2 complex saturating integer multiply-add (indexed)
+SQRDCMLAH_zzxz_h  01000100 10 1 index:2 rm:3 0111 rot:2 rn:5 rd:5 \
+                  ra=%reg_movprfx
+SQRDCMLAH_zzxz_s  01000100 11 1 index:1 rm:4 0111 rot:2 rn:5 rd:5 \
+                  ra=%reg_movprfx
+
 # SVE2 multiply-add long (indexed)
 SMLALB_zzxw_s   01000100 .. 1 ..... 1000.0 ..... .....          @rrxw_s
 SMLALB_zzxw_d   01000100 .. 1 ..... 1000.0 ..... .....          @rrxw_d
diff --git a/target/arm/sve_helper.c b/target/arm/sve_helper.c
index f71df89cb7..b2f24d5cfd 100644
--- a/target/arm/sve_helper.c
+++ b/target/arm/sve_helper.c
@@ -1466,34 +1466,132 @@  void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \
     }                                                           \
 }
 
-#define do_cmla(N, M, A, S) (A + (N * M) * (S ? -1 : 1))
+static int8_t do_cmla_b(int8_t n, int8_t m, int8_t a, bool sub)
+{
+    return n * m * (sub ? -1 : 1) + a;
+}
 
-DO_CMLA(sve2_cmla_zzzz_b, uint8_t, H1, do_cmla)
-DO_CMLA(sve2_cmla_zzzz_h, uint16_t, H2, do_cmla)
-DO_CMLA(sve2_cmla_zzzz_s, uint32_t, H4, do_cmla)
-DO_CMLA(sve2_cmla_zzzz_d, uint64_t,   , do_cmla)
+static int16_t do_cmla_h(int16_t n, int16_t m, int16_t a, bool sub)
+{
+    return n * m * (sub ? -1 : 1) + a;
+}
 
-#define DO_SQRDMLAH_B(N, M, A, S) \
-    do_sqrdmlah_b(N, M, A, S, true)
-#define DO_SQRDMLAH_H(N, M, A, S) \
-    ({ uint32_t discard; do_sqrdmlah_h(N, M, A, S, true, &discard); })
-#define DO_SQRDMLAH_S(N, M, A, S) \
-    ({ uint32_t discard; do_sqrdmlah_s(N, M, A, S, true, &discard); })
-#define DO_SQRDMLAH_D(N, M, A, S) \
-    do_sqrdmlah_d(N, M, A, S, true)
+static int32_t do_cmla_s(int32_t n, int32_t m, int32_t a, bool sub)
+{
+    return n * m * (sub ? -1 : 1) + a;
+}
 
-DO_CMLA(sve2_sqrdcmlah_zzzz_b, int8_t, H1, DO_SQRDMLAH_B)
-DO_CMLA(sve2_sqrdcmlah_zzzz_h, int16_t, H2, DO_SQRDMLAH_H)
-DO_CMLA(sve2_sqrdcmlah_zzzz_s, int32_t, H4, DO_SQRDMLAH_S)
-DO_CMLA(sve2_sqrdcmlah_zzzz_d, int64_t,   , DO_SQRDMLAH_D)
+static int64_t do_cmla_d(int64_t n, int64_t m, int64_t a, bool sub)
+{
+    return n * m * (sub ? -1 : 1) + a;
+}
+
+DO_CMLA(sve2_cmla_zzzz_b, uint8_t, H1, do_cmla_b)
+DO_CMLA(sve2_cmla_zzzz_h, uint16_t, H2, do_cmla_h)
+DO_CMLA(sve2_cmla_zzzz_s, uint32_t, H4, do_cmla_s)
+DO_CMLA(sve2_cmla_zzzz_d, uint64_t,   , do_cmla_d)
+
+static int8_t do_sqrdcmlah_b(int8_t n, int8_t m, int8_t a, bool sub)
+{
+    return do_sqrdmlah_b(n, m, a, sub, true);
+}
+
+static int16_t do_sqrdcmlah_h(int16_t n, int16_t m, int16_t a, bool sub)
+{
+    uint32_t discard;
+    return do_sqrdmlah_h(n, m, a, sub, true, &discard);
+}
+
+static int32_t do_sqrdcmlah_s(int32_t n, int32_t m, int32_t a, bool sub)
+{
+    uint32_t discard;
+    return do_sqrdmlah_s(n, m, a, sub, true, &discard);
+}
+
+static int64_t do_sqrdcmlah_d(int64_t n, int64_t m, int64_t a, bool sub)
+{
+    return do_sqrdmlah_d(n, m, a, sub, true);
+}
+
+DO_CMLA(sve2_sqrdcmlah_zzzz_b, int8_t, H1, do_sqrdcmlah_b)
+DO_CMLA(sve2_sqrdcmlah_zzzz_h, int16_t, H2, do_sqrdcmlah_h)
+DO_CMLA(sve2_sqrdcmlah_zzzz_s, int32_t, H4, do_sqrdcmlah_s)
+DO_CMLA(sve2_sqrdcmlah_zzzz_d, int64_t,   , do_sqrdcmlah_d)
 
-#undef DO_SQRDMLAH_B
-#undef DO_SQRDMLAH_H
-#undef DO_SQRDMLAH_S
-#undef DO_SQRDMLAH_D
-#undef do_cmla
 #undef DO_CMLA
 
+static void do_cmla_idx_h(int16_t *d, int16_t *n, int16_t *m,
+                          int16_t *a, uint32_t desc,
+                          int16_t (*fn)(int16_t, int16_t, int16_t, bool))
+{
+    intptr_t i, j, oprsz = simd_oprsz(desc);
+    int rot = extract32(desc, SIMD_DATA_SHIFT, 2);
+    int idx = extract32(desc, SIMD_DATA_SHIFT + 2, 2) * 2;
+    int sel_a = rot & 1, sel_b = sel_a ^ 1;
+    bool sub_r = rot == 1 || rot == 2;
+    bool sub_i = rot >= 2;
+
+    for (i = 0; i < oprsz / 2; i += 16 / 2) {
+        int16_t elt2_a = m[H2(i + idx + sel_a)];
+        int16_t elt2_b = m[H2(i + idx + sel_b)];
+
+        for (j = 0; j < 16 / 2; j += 2) {
+            int16_t elt1_a = n[H2(i + j + sel_a)];
+
+            d[H2(i + j)] = fn(elt1_a, elt2_a, a[H2(i + j)], sub_r);
+            d[H2(i + j + 1)] = fn(elt1_a, elt2_b, a[H2(i + j + 1)], sub_i);
+        }
+    }
+}
+
+void HELPER(sve2_cmla_idx_h)(void *vd, void *vn, void *vm,
+                             void *va, uint32_t desc)
+{
+    do_cmla_idx_h(vd, vn, vm, va, desc, do_cmla_h);
+}
+
+void HELPER(sve2_sqrdcmlah_idx_h)(void *vd, void *vn, void *vm,
+                                  void *va, uint32_t desc)
+{
+    do_cmla_idx_h(vd, vn, vm, va, desc, do_sqrdcmlah_h);
+}
+
+static void do_cmla_idx_s(int32_t *d, int32_t *n, int32_t *m,
+                          int32_t *a, uint32_t desc,
+                          int32_t (*fn)(int32_t, int32_t, int32_t, bool))
+{
+    intptr_t i, j, oprsz = simd_oprsz(desc);
+    int rot = extract32(desc, SIMD_DATA_SHIFT, 2);
+    int idx = extract32(desc, SIMD_DATA_SHIFT + 2, 2) * 2;
+    int sel_a = rot & 1, sel_b = sel_a ^ 1;
+    bool sub_r = rot == 1 || rot == 2;
+    bool sub_i = rot >= 2;
+
+    for (i = 0; i < oprsz / 4; i += 16 / 4) {
+        int32_t elt2_a = m[H4(i + idx + sel_a)];
+        int32_t elt2_b = m[H4(i + idx + sel_b)];
+
+        for (j = 0; j < 16 / 4; j += 2) {
+            int32_t elt1_a = n[H4(i + j + sel_a)];
+
+            d[H4(i + j)] = fn(elt1_a, elt2_a, a[H4(i + j)], sub_r);
+            d[H4(i + j + 1)] = fn(elt1_a, elt2_b, a[H4(i + j + 1)], sub_i);
+        }
+    }
+}
+
+void HELPER(sve2_cmla_idx_s)(void *vd, void *vn, void *vm,
+                             void *va, uint32_t desc)
+{
+    do_cmla_idx_s(vd, vn, vm, va, desc, do_cmla_s);
+}
+
+void HELPER(sve2_sqrdcmlah_idx_s)(void *vd, void *vn, void *vm,
+                                  void *va, uint32_t desc)
+{
+    do_cmla_idx_s(vd, vn, vm, va, desc, do_sqrdcmlah_s);
+}
+
 #define DO_ZZXZ(NAME, TYPE, H, OP) \
 void HELPER(NAME)(void *vd, void *vn, void *vm, void *va, uint32_t desc) \
 {                                                                       \
diff --git a/target/arm/translate-sve.c b/target/arm/translate-sve.c
index 248fe4de42..4628198b76 100644
--- a/target/arm/translate-sve.c
+++ b/target/arm/translate-sve.c
@@ -3821,21 +3821,21 @@  static bool trans_DOT_zzzz(DisasContext *s, arg_DOT_zzzz *a)
  * SVE Multiply - Indexed
  */
 
-static bool do_zzxz_data(DisasContext *s, arg_rrxr_esz *a,
+static bool do_zzxz_data(DisasContext *s, int rd, int rn, int rm, int ra,
                          gen_helper_gvec_4 *fn, int data)
 {
     if (fn == NULL) {
         return false;
     }
     if (sve_access_check(s)) {
-        gen_gvec_ool_zzzz(s, fn, a->rd, a->rn, a->rm, a->ra, data);
+        gen_gvec_ool_zzzz(s, fn, rd, rn, rm, ra, data);
     }
     return true;
 }
 
 #define DO_RRXR(NAME, FUNC) \
     static bool NAME(DisasContext *s, arg_rrxr_esz *a)  \
-    { return do_zzxz_data(s, a, FUNC, a->index); }
+    { return do_zzxz_data(s, a->rd, a->rn, a->rm, a->ra, FUNC, a->index); }
 
 DO_RRXR(trans_SDOT_zzxw_s, gen_helper_gvec_sdot_idx_b)
 DO_RRXR(trans_SDOT_zzxw_d, gen_helper_gvec_sdot_idx_h)
@@ -3899,18 +3899,18 @@  DO_SVE2_RRX_TB(trans_SQDMULLT_zzx_d, gen_helper_sve2_sqdmull_idx_d, true)
 
 #undef DO_SVE2_RRX_TB
 
-static bool do_sve2_zzxz_data(DisasContext *s, arg_rrxr_esz *a,
-                              gen_helper_gvec_4 *fn, int data)
+static bool do_sve2_zzxz_data(DisasContext *s, int rd, int rn, int rm,
+                              int ra, gen_helper_gvec_4 *fn, int data)
 {
     if (!dc_isar_feature(aa64_sve2, s)) {
         return false;
     }
-    return do_zzxz_data(s, a, fn, data);
+    return do_zzxz_data(s, rd, rn, rm, ra, fn, data);
 }
 
 #define DO_SVE2_RRXR(NAME, FUNC) \
-    static bool NAME(DisasContext *s, arg_rrxr_esz *a)  \
-    { return do_sve2_zzxz_data(s, a, FUNC, a->index); }
+static bool NAME(DisasContext *s, arg_rrxr_esz *a)  \
+{ return do_sve2_zzxz_data(s, a->rd, a->rn, a->rm, a->ra, FUNC, a->index); }
 
 DO_SVE2_RRXR(trans_MLA_zzxz_h, gen_helper_gvec_mla_idx_h)
 DO_SVE2_RRXR(trans_MLA_zzxz_s, gen_helper_gvec_mla_idx_s)
@@ -3931,8 +3931,11 @@  DO_SVE2_RRXR(trans_SQRDMLSH_zzxz_d, gen_helper_sve2_sqrdmlsh_idx_d)
 #undef DO_SVE2_RRXR
 
 #define DO_SVE2_RRXR_TB(NAME, FUNC, TOP) \
-    static bool NAME(DisasContext *s, arg_rrxr_esz *a)  \
-    { return do_sve2_zzxz_data(s, a, FUNC, (a->index << 1) | TOP); }
+static bool NAME(DisasContext *s, arg_rrxr_esz *a)           \
+{                                                            \
+    return do_sve2_zzxz_data(s, a->rd, a->rn, a->rm, a->ra,  \
+                             FUNC, (a->index << 1) | TOP);   \
+}
 
 DO_SVE2_RRXR_TB(trans_SQDMLALB_zzxw_s, gen_helper_sve2_sqdmlal_idx_s, false)
 DO_SVE2_RRXR_TB(trans_SQDMLALB_zzxw_d, gen_helper_sve2_sqdmlal_idx_d, false)
@@ -3966,6 +3969,21 @@  DO_SVE2_RRXR_TB(trans_UMLSLT_zzxw_d, gen_helper_sve2_umlsl_idx_d, true)
 
 #undef DO_SVE2_RRXR_TB
 
+#define DO_SVE2_RRXR_ROT(NAME, FUNC) \
+static bool trans_##NAME(DisasContext *s, arg_##NAME *a)       \
+{                                                              \
+    return do_sve2_zzxz_data(s, a->rd, a->rn, a->rm, a->ra,    \
+                             FUNC, (a->index << 2) | a->rot);  \
+}
+
+DO_SVE2_RRXR_ROT(CMLA_zzxz_h, gen_helper_sve2_cmla_idx_h)
+DO_SVE2_RRXR_ROT(CMLA_zzxz_s, gen_helper_sve2_cmla_idx_s)
+
+DO_SVE2_RRXR_ROT(SQRDCMLAH_zzxz_h, gen_helper_sve2_sqrdcmlah_idx_h)
+DO_SVE2_RRXR_ROT(SQRDCMLAH_zzxz_s, gen_helper_sve2_sqrdcmlah_idx_s)
+
+#undef DO_SVE2_RRXR_ROT
+
 /*
  *** SVE Floating Point Multiply-Add Indexed Group
  */