Message ID | 87zi10neb6.fsf@linaro.org
---|---
State | New
Series | Support fused multiply-adds in fully-masked reductions
On Wed, May 16, 2018 at 11:26 AM Richard Sandiford <richard.sandiford@linaro.org> wrote:

> This patch adds support for fusing a conditional add or subtract
> with a multiplication, so that we can use fused multiply-add and
> multiply-subtract operations for fully-masked reductions. E.g.
> for SVE we vectorise:
>
>   double res = 0.0;
>   for (int i = 0; i < n; ++i)
>     res += x[i] * y[i];
>
> using a fully-masked loop in which the loop body has the form:
>
>   res_1 = PHI<0(preheader), res_2(latch)>;
>   avec = IFN_MASK_LOAD (loop_mask, a)
>   bvec = IFN_MASK_LOAD (loop_mask, b)
>   prod = avec * bvec;
>   res_2 = IFN_COND_ADD (loop_mask, res_1, prod);
>
> where the last statement does the equivalent of:
>
>   res_2 = loop_mask ? res_1 + prod : res_1;
>
> (operating elementwise). The point of the patch is to convert the last
> two statements into a single internal function that is the equivalent of:
>
>   res_2 = loop_mask ? fma (avec, bvec, res_1) : res_1;
>
> (again operating elementwise).
>
> All current conditional X operations have the form "do X or don't do X
> to the first operand" (add/don't add to first operand, etc.). However,
> the FMA optabs and functions are ordered so that the accumulator comes
> last. There were two obvious ways of resolving this: break the
> convention for conditional operators and have "add/don't add to the
> final operand" or break the convention for FMA and put the accumulator
> first. The patch goes for the latter, but adds _REV to make it obvious
> that the operands are in a different order.

Eh. I guess you'll do the same to SAD/DOT_PROD/WIDEN_SUM?

That said, I don't really see the "do or not do to the first operand", it's
"do or not do the operation on operands 1 to 2 (or 3)". None of the
current ops modify operand 1, they all produce a new value, no?

> Tested on aarch64-linux-gnu (with and without SVE), aarch64_be-elf
> and x86_64-linux-gnu. OK to install?

OK, but as said I don't see a reason for the operand order to differ in
the first place.

Richard.

> Richard
>
> 2018-05-16  Richard Sandiford  <richard.sandiford@linaro.org>
>             Alan Hayward  <alan.hayward@arm.com>
>             David Sherwood  <david.sherwood@arm.com>
>
> gcc/
>       * doc/md.texi (cond_fma_rev, cond_fnma_rev): Document.
>       * optabs.def (cond_fma_rev, cond_fnma_rev): New optabs.
>       * internal-fn.def (COND_FMA_REV, COND_FNMA_REV): New internal
>       functions.
>       * internal-fn.h (can_interpret_as_conditional_op_p): Declare.
>       * internal-fn.c (cond_ternary_direct): New macro.
>       (expand_cond_ternary_optab_fn): Likewise.
>       (direct_cond_ternary_optab_supported_p): Likewise.
>       (FOR_EACH_CODE_MAPPING): Likewise.
>       (get_conditional_internal_fn): Use FOR_EACH_CODE_MAPPING.
>       (conditional_internal_fn_code): New function.
>       (can_interpret_as_conditional_op_p): Likewise.
>       * tree-ssa-math-opts.c (fused_cond_internal_fn): New function.
>       (convert_mult_to_fma_1): Transform calls to IFN_COND_ADD to
>       IFN_COND_FMA_REV and calls to IFN_COND_SUB to IFN_COND_FNMA_REV.
>       (convert_mult_to_fma): Handle calls to IFN_COND_ADD and IFN_COND_SUB.
>       * genmatch.c (commutative_op): Handle CFN_COND_FMA_REV and
>       CFN_COND_FNMA_REV.
>       * config/aarch64/iterators.md (UNSPEC_COND_FMLA): New unspec.
>       (UNSPEC_COND_FMLS): Likewise.
>       (optab, sve_fp_op): Handle them.
>       (SVE_COND_INT_OP): Rename to...
>       (SVE_COND_INT2_OP): ...this.
>       (SVE_COND_FP_OP): Rename to...
>       (SVE_COND_FP2_OP): ...this.
>       (SVE_COND_FP3_OP): New iterator.
>       * config/aarch64/aarch64-sve.md (cond_<optab><mode>): Update
>       for new iterator names. Add a pattern for SVE_COND_FP3_OP.
> gcc/testsuite/ > * gcc.target/aarch64/sve/reduc_4.c: New test. > * gcc.target/aarch64/sve/reduc_6.c: Likewise. > * gcc.target/aarch64/sve/reduc_7.c: Likewise. > Index: gcc/doc/md.texi > =================================================================== > --- gcc/doc/md.texi 2018-05-16 10:23:03.590853492 +0100 > +++ gcc/doc/md.texi 2018-05-16 10:23:03.886838736 +0100 > @@ -6367,6 +6367,32 @@ be in a normal C @samp{?:} condition. > Operands 0, 2 and 3 all have mode @var{m}, while operand 1 has the mode > returned by @code{TARGET_VECTORIZE_GET_MASK_MODE}. > +@cindex @code{cond_fma_rev@var{mode}} instruction pattern > +@item @samp{cond_fma_rev@var{mode}} > +Similar to @samp{cond_add@var{m}}, but compute: > +@smallexample > +op0 = op1 ? fma (op3, op4, op2) : op2; > +@end smallexample > +for scalars and: > +@smallexample > +op0[I] = op1[I] ? fma (op3[I], op4[I], op2[I]) : op2[I]; > +@end smallexample > +for vectors. The @samp{_rev} indicates that the addend (operand 2) > +comes first. > + > +@cindex @code{cond_fnma_rev@var{mode}} instruction pattern > +@item @samp{cond_fnma_rev@var{mode}} > +Similar to @samp{cond_fma_rev@var{m}}, but negate operand 3 before > +multiplying it. That is, the instruction performs: > +@smallexample > +op0 = op1 ? fma (-op3, op4, op2) : op2; > +@end smallexample > +for scalars and: > +@smallexample > +op0[I] = op1[I] ? fma (-op3[I], op4[I], op2[I]) : op2[I]; > +@end smallexample > +for vectors. > + > @cindex @code{neg@var{mode}cc} instruction pattern > @item @samp{neg@var{mode}cc} > Similar to @samp{mov@var{mode}cc} but for conditional negation. Conditionally > Index: gcc/optabs.def > =================================================================== > --- gcc/optabs.def 2018-05-16 10:23:03.590853492 +0100 > +++ gcc/optabs.def 2018-05-16 10:23:03.887838686 +0100 > @@ -222,6 +222,8 @@ OPTAB_D (notcc_optab, "not$acc") > OPTAB_D (movcc_optab, "mov$acc") > OPTAB_D (cond_add_optab, "cond_add$a") > OPTAB_D (cond_sub_optab, "cond_sub$a") > +OPTAB_D (cond_fma_rev_optab, "cond_fma_rev$a") > +OPTAB_D (cond_fnma_rev_optab, "cond_fnma_rev$a") > OPTAB_D (cond_and_optab, "cond_and$a") > OPTAB_D (cond_ior_optab, "cond_ior$a") > OPTAB_D (cond_xor_optab, "cond_xor$a") > Index: gcc/internal-fn.def > =================================================================== > --- gcc/internal-fn.def 2018-05-16 10:23:03.590853492 +0100 > +++ gcc/internal-fn.def 2018-05-16 10:23:03.887838686 +0100 > @@ -59,7 +59,8 @@ along with GCC; see the file COPYING3. 
> - binary: a normal binary optab, such as vec_interleave_lo_<mode> > - ternary: a normal ternary optab, such as fma<mode>4 > - - cond_binary: a conditional binary optab, such as add<mode>cc > + - cond_binary: a conditional binary optab, such as cond_add<mode> > + - cond_ternary: a conditional ternary optab, such as cond_fma_rev<mode> > - fold_left: for scalar = FN (scalar, vector), keyed off the vector mode > @@ -143,6 +144,9 @@ DEF_INTERNAL_OPTAB_FN (FMS, ECF_CONST, f > DEF_INTERNAL_OPTAB_FN (FNMA, ECF_CONST, fnma, ternary) > DEF_INTERNAL_OPTAB_FN (FNMS, ECF_CONST, fnms, ternary) > +DEF_INTERNAL_OPTAB_FN (COND_FMA_REV, ECF_CONST, cond_fma_rev, cond_ternary) > +DEF_INTERNAL_OPTAB_FN (COND_FNMA_REV, ECF_CONST, cond_fnma_rev, cond_ternary) > + > DEF_INTERNAL_OPTAB_FN (COND_ADD, ECF_CONST, cond_add, cond_binary) > DEF_INTERNAL_OPTAB_FN (COND_SUB, ECF_CONST, cond_sub, cond_binary) > DEF_INTERNAL_SIGNED_OPTAB_FN (COND_MIN, ECF_CONST, first, > Index: gcc/internal-fn.h > =================================================================== > --- gcc/internal-fn.h 2018-05-16 10:23:03.590853492 +0100 > +++ gcc/internal-fn.h 2018-05-16 10:23:03.887838686 +0100 > @@ -191,6 +191,8 @@ direct_internal_fn_supported_p (internal > extern bool set_edom_supported_p (void); > extern internal_fn get_conditional_internal_fn (tree_code); > +extern bool can_interpret_as_conditional_op_p (gimple *, tree_code *, > + tree *, tree (&)[3]); > extern bool internal_load_fn_p (internal_fn); > extern bool internal_store_fn_p (internal_fn); > Index: gcc/internal-fn.c > =================================================================== > --- gcc/internal-fn.c 2018-05-16 10:23:03.590853492 +0100 > +++ gcc/internal-fn.c 2018-05-16 10:23:03.887838686 +0100 > @@ -93,6 +93,7 @@ #define binary_direct { 0, 0, true } > #define ternary_direct { 0, 0, true } > #define cond_unary_direct { 1, 1, true } > #define cond_binary_direct { 1, 1, true } > +#define cond_ternary_direct { 1, 1, true } > #define while_direct { 0, 2, false } > #define fold_extract_direct { 2, 2, false } > #define fold_left_direct { 1, 1, false } > @@ -2972,6 +2973,9 @@ #define expand_cond_unary_optab_fn(FN, S > #define expand_cond_binary_optab_fn(FN, STMT, OPTAB) \ > expand_direct_optab_fn (FN, STMT, OPTAB, 3) > +#define expand_cond_ternary_optab_fn(FN, STMT, OPTAB) \ > + expand_direct_optab_fn (FN, STMT, OPTAB, 4) > + > #define expand_fold_extract_optab_fn(FN, STMT, OPTAB) \ > expand_direct_optab_fn (FN, STMT, OPTAB, 3) > @@ -3054,6 +3058,7 @@ #define direct_binary_optab_supported_p > #define direct_ternary_optab_supported_p direct_optab_supported_p > #define direct_cond_unary_optab_supported_p direct_optab_supported_p > #define direct_cond_binary_optab_supported_p direct_optab_supported_p > +#define direct_cond_ternary_optab_supported_p direct_optab_supported_p > #define direct_mask_load_optab_supported_p direct_optab_supported_p > #define direct_load_lanes_optab_supported_p multi_vector_optab_supported_p > #define direct_mask_load_lanes_optab_supported_p multi_vector_optab_supported_p > @@ -3198,6 +3203,17 @@ #define DEF_INTERNAL_FN(CODE, FLAGS, FNS > 0 > }; > +/* Invoke T(CODE, IFN) for each conditional function IFN that maps to a > + tree code CODE. 
*/ > +#define FOR_EACH_CODE_MAPPING(T) \ > + T (PLUS_EXPR, IFN_COND_ADD) \ > + T (MINUS_EXPR, IFN_COND_SUB) \ > + T (MIN_EXPR, IFN_COND_MIN) \ > + T (MAX_EXPR, IFN_COND_MAX) \ > + T (BIT_AND_EXPR, IFN_COND_AND) \ > + T (BIT_IOR_EXPR, IFN_COND_IOR) \ > + T (BIT_XOR_EXPR, IFN_COND_XOR) > + > /* Return a function that performs the conditional form of CODE, i.e.: > LHS = RHS1 ? RHS2 CODE RHS3 : RHS2 > @@ -3210,25 +3226,78 @@ get_conditional_internal_fn (tree_code c > { > switch (code) > { > - case PLUS_EXPR: > - return IFN_COND_ADD; > - case MINUS_EXPR: > - return IFN_COND_SUB; > - case MIN_EXPR: > - return IFN_COND_MIN; > - case MAX_EXPR: > - return IFN_COND_MAX; > - case BIT_AND_EXPR: > - return IFN_COND_AND; > - case BIT_IOR_EXPR: > - return IFN_COND_IOR; > - case BIT_XOR_EXPR: > - return IFN_COND_XOR; > +#define CASE(CODE, IFN) case CODE: return IFN; > + FOR_EACH_CODE_MAPPING(CASE) > +#undef CASE > default: > return IFN_LAST; > } > } > +/* If IFN implements the conditional form of a tree code, return that > + tree code, otherwise return ERROR_MARK. */ > + > +static tree_code > +conditional_internal_fn_code (internal_fn ifn) > +{ > + switch (ifn) > + { > +#define CASE(CODE, IFN) case IFN: return CODE; > + FOR_EACH_CODE_MAPPING(CASE) > +#undef CASE > + default: > + return ERROR_MARK; > + } > +} > + > +/* Return true if STMT can be interpreted as a conditional tree code > + operation of the form: > + > + LHS = COND ? OP (RHS1, ...) : RHS1; > + > + operating elementwise if the operands are vectors. This includes > + the case of an all-true COND, so that the operation always happens. > + > + When returning true, set: > + > + - *CODE_OUT to the tree code > + - *COND_OUT to the condition COND, or to NULL_TREE if the condition > + is known to be all-true > + - OPS[I] to operand I of *CODE_OUT. */ > + > +bool > +can_interpret_as_conditional_op_p (gimple *stmt, tree_code *code_out, > + tree *cond_out, tree (&ops)[3]) > +{ > + if (gassign *assign = dyn_cast <gassign *> (stmt)) > + { > + *code_out = gimple_assign_rhs_code (assign); > + *cond_out = NULL_TREE; > + ops[0] = gimple_assign_rhs1 (assign); > + ops[1] = gimple_assign_rhs2 (assign); > + ops[2] = gimple_assign_rhs3 (assign); > + return true; > + } > + if (gcall *call = dyn_cast <gcall *> (stmt)) > + if (gimple_call_internal_p (call)) > + { > + internal_fn ifn = gimple_call_internal_fn (call); > + tree_code code = conditional_internal_fn_code (ifn); > + if (code != ERROR_MARK) > + { > + *code_out = code; > + *cond_out = gimple_call_arg (call, 0); > + if (integer_truep (*cond_out)) > + *cond_out = NULL_TREE; > + unsigned int nargs = gimple_call_num_args (call) - 1; > + for (unsigned int i = 0; i < 3; ++i) > + ops[i] = i < nargs ? gimple_call_arg (call, i + 1) : NULL_TREE; > + return true; > + } > + } > + return false; > +} > + > /* Return true if IFN is some form of load from memory. */ > bool > Index: gcc/tree-ssa-math-opts.c > =================================================================== > --- gcc/tree-ssa-math-opts.c 2018-05-16 10:23:03.590853492 +0100 > +++ gcc/tree-ssa-math-opts.c 2018-05-16 10:23:03.889838586 +0100 > @@ -2640,6 +2640,24 @@ convert_plusminus_to_widen (gimple_stmt_ > return true; > } > +/* Return the internal function that implements: > + > + LHS = COND ? A CODE B * C : A. 
*/ > + > +static internal_fn > +fused_cond_internal_fn (tree_code code) > +{ > + switch (code) > + { > + case PLUS_EXPR: > + return IFN_COND_FMA_REV; > + case MINUS_EXPR: > + return IFN_COND_FNMA_REV; > + default: > + gcc_unreachable (); > + } > +} > + > /* gimple_fold callback that "valueizes" everything. */ > static tree > @@ -2663,7 +2681,6 @@ convert_mult_to_fma_1 (tree mul_result, > FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result) > { > gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt); > - enum tree_code use_code; > tree addop, mulop1 = op1, result = mul_result; > bool negate_p = false; > gimple_seq seq = NULL; > @@ -2671,8 +2688,8 @@ convert_mult_to_fma_1 (tree mul_result, > if (is_gimple_debug (use_stmt)) > continue; > - use_code = gimple_assign_rhs_code (use_stmt); > - if (use_code == NEGATE_EXPR) > + if (is_gimple_assign (use_stmt) > + && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR) > { > result = gimple_assign_lhs (use_stmt); > use_operand_p use_p; > @@ -2683,23 +2700,30 @@ convert_mult_to_fma_1 (tree mul_result, > use_stmt = neguse_stmt; > gsi = gsi_for_stmt (use_stmt); > - use_code = gimple_assign_rhs_code (use_stmt); > negate_p = true; > } > - if (gimple_assign_rhs1 (use_stmt) == result) > - { > - addop = gimple_assign_rhs2 (use_stmt); > - /* a * b - c -> a * b + (-c) */ > - if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR) > - addop = gimple_build (&seq, NEGATE_EXPR, type, addop); > - } > + tree cond, ops[3]; > + tree_code code; > + if (!can_interpret_as_conditional_op_p (use_stmt, &code, &cond, ops)) > + gcc_unreachable (); > + addop = ops[0] == result ? ops[1] : ops[0]; > + > + internal_fn ifn; > + if (cond) > + ifn = fused_cond_internal_fn (code); > else > { > - addop = gimple_assign_rhs1 (use_stmt); > - /* a - b * c -> (-b) * c + a */ > - if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR) > - negate_p = !negate_p; > + ifn = IFN_FMA; > + if (code == MINUS_EXPR) > + { > + if (ops[0] == result) > + /* a * b - c -> a * b + (-c) */ > + addop = gimple_build (&seq, NEGATE_EXPR, type, addop); > + else > + /* a - b * c -> (-b) * c + a */ > + negate_p = !negate_p; > + } > } > if (negate_p) > @@ -2707,8 +2731,13 @@ convert_mult_to_fma_1 (tree mul_result, > if (seq) > gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT); > - fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop); > - gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt)); > + > + if (ifn == IFN_FMA) > + fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop); > + else > + fma_stmt = gimple_build_call_internal (ifn, 4, cond, addop, > + mulop1, op2); > + gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt)); > gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt)); > gsi_replace (&gsi, fma_stmt, true); > /* Valueize aggressively so that we generate FMS, FNMA and FNMS > @@ -2891,7 +2920,6 @@ convert_mult_to_fma (gimple *mul_stmt, t > as an addition. */ > FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result) > { > - enum tree_code use_code; > tree result = mul_result; > bool negate_p = false; > @@ -2912,13 +2940,9 @@ convert_mult_to_fma (gimple *mul_stmt, t > if (gimple_bb (use_stmt) != gimple_bb (mul_stmt)) > return false; > - if (!is_gimple_assign (use_stmt)) > - return false; > - > - use_code = gimple_assign_rhs_code (use_stmt); > - > /* A negate on the multiplication leads to FNMA. 
*/ > - if (use_code == NEGATE_EXPR) > + if (is_gimple_assign (use_stmt) > + && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR) > { > ssa_op_iter iter; > use_operand_p usep; > @@ -2940,17 +2964,19 @@ convert_mult_to_fma (gimple *mul_stmt, t > use_stmt = neguse_stmt; > if (gimple_bb (use_stmt) != gimple_bb (mul_stmt)) > return false; > - if (!is_gimple_assign (use_stmt)) > - return false; > - use_code = gimple_assign_rhs_code (use_stmt); > negate_p = true; > } > - switch (use_code) > + tree cond, ops[3]; > + tree_code code; > + if (!can_interpret_as_conditional_op_p (use_stmt, &code, &cond, ops)) > + return false; > + > + switch (code) > { > case MINUS_EXPR: > - if (gimple_assign_rhs2 (use_stmt) == result) > + if (ops[1] == result) > negate_p = !negate_p; > break; > case PLUS_EXPR: > @@ -2960,47 +2986,52 @@ convert_mult_to_fma (gimple *mul_stmt, t > return false; > } > - /* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed > - by a MULT_EXPR that we'll visit later, we might be able to > - get a more profitable match with fnma. > + if (cond) > + { > + /* The multiplication must be the second operand. */ > + if (cond == result || ops[0] == result) > + return false; > + internal_fn ifn = fused_cond_internal_fn (code); > + if (!direct_internal_fn_supported_p (ifn, type, opt_type)) > + return false; > + } > + > + /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that > + we'll visit later, we might be able to get a more profitable > + match with fnma. > OTOH, if we don't, a negate / fma pair has likely lower latency > that a mult / subtract pair. */ > - if (use_code == MINUS_EXPR && !negate_p > - && gimple_assign_rhs1 (use_stmt) == result > + if (code == MINUS_EXPR > + && !negate_p > + && ops[0] == result > && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type) > - && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)) > + && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type) > + && TREE_CODE (ops[1]) == SSA_NAME > + && has_single_use (ops[1])) > { > - tree rhs2 = gimple_assign_rhs2 (use_stmt); > - > - if (TREE_CODE (rhs2) == SSA_NAME) > - { > - gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2); > - if (has_single_use (rhs2) > - && is_gimple_assign (stmt2) > - && gimple_assign_rhs_code (stmt2) == MULT_EXPR) > - return false; > - } > + gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]); > + if (is_gimple_assign (stmt2) > + && gimple_assign_rhs_code (stmt2) == MULT_EXPR) > + return false; > } > - tree use_rhs1 = gimple_assign_rhs1 (use_stmt); > - tree use_rhs2 = gimple_assign_rhs2 (use_stmt); > /* We can't handle a * b + a * b. */ > - if (use_rhs1 == use_rhs2) > + if (ops[0] == ops[1]) > return false; > /* If deferring, make sure we are not looking at an instruction that > wouldn't have existed if we were not. 
*/ > if (state->m_deferring_p > - && (state->m_mul_result_set.contains (use_rhs1) > - || state->m_mul_result_set.contains (use_rhs2))) > + && (state->m_mul_result_set.contains (ops[0]) > + || state->m_mul_result_set.contains (ops[1]))) > return false; > if (check_defer) > { > - tree use_lhs = gimple_assign_lhs (use_stmt); > + tree use_lhs = gimple_get_lhs (use_stmt); > if (state->m_last_result) > { > - if (use_rhs2 == state->m_last_result > - || use_rhs1 == state->m_last_result) > + if (ops[1] == state->m_last_result > + || ops[0] == state->m_last_result) > defer = true; > else > defer = false; > @@ -3009,12 +3040,12 @@ convert_mult_to_fma (gimple *mul_stmt, t > { > gcc_checking_assert (!state->m_initial_phi); > gphi *phi; > - if (use_rhs1 == result) > - phi = result_of_phi (use_rhs2); > + if (ops[0] == result) > + phi = result_of_phi (ops[1]); > else > { > - gcc_assert (use_rhs2 == result); > - phi = result_of_phi (use_rhs1); > + gcc_assert (ops[1] == result); > + phi = result_of_phi (ops[0]); > } > if (phi) > Index: gcc/genmatch.c > =================================================================== > --- gcc/genmatch.c 2018-05-16 10:23:03.590853492 +0100 > +++ gcc/genmatch.c 2018-05-16 10:23:03.887838686 +0100 > @@ -485,6 +485,10 @@ commutative_op (id_base *id) > case CFN_FNMS: > return 0; > + case CFN_COND_FMA_REV: > + case CFN_COND_FNMA_REV: > + return 2; > + > default: > return -1; > } > Index: gcc/config/aarch64/iterators.md > =================================================================== > --- gcc/config/aarch64/iterators.md 2018-05-16 10:23:03.590853492 +0100 > +++ gcc/config/aarch64/iterators.md 2018-05-16 10:23:03.886838736 +0100 > @@ -449,6 +449,8 @@ (define_c_enum "unspec" > UNSPEC_COND_AND ; Used in aarch64-sve.md. > UNSPEC_COND_ORR ; Used in aarch64-sve.md. > UNSPEC_COND_EOR ; Used in aarch64-sve.md. > + UNSPEC_COND_FMLA ; Used in aarch64-sve.md. > + UNSPEC_COND_FMLS ; Used in aarch64-sve.md. > UNSPEC_COND_LT ; Used in aarch64-sve.md. > UNSPEC_COND_LE ; Used in aarch64-sve.md. > UNSPEC_COND_EQ ; Used in aarch64-sve.md. 
> @@ -1499,14 +1501,16 @@ (define_int_iterator UNPACK_UNSIGNED [UN > (define_int_iterator MUL_HIGHPART [UNSPEC_SMUL_HIGHPART UNSPEC_UMUL_HIGHPART]) > -(define_int_iterator SVE_COND_INT_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB > - UNSPEC_COND_SMAX UNSPEC_COND_UMAX > - UNSPEC_COND_SMIN UNSPEC_COND_UMIN > - UNSPEC_COND_AND > - UNSPEC_COND_ORR > - UNSPEC_COND_EOR]) > +(define_int_iterator SVE_COND_INT2_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB > + UNSPEC_COND_SMAX UNSPEC_COND_UMAX > + UNSPEC_COND_SMIN UNSPEC_COND_UMIN > + UNSPEC_COND_AND > + UNSPEC_COND_ORR > + UNSPEC_COND_EOR]) > -(define_int_iterator SVE_COND_FP_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB]) > +(define_int_iterator SVE_COND_FP2_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB]) > + > +(define_int_iterator SVE_COND_FP3_OP [UNSPEC_COND_FMLA UNSPEC_COND_FMLS]) > (define_int_iterator SVE_COND_FP_CMP [UNSPEC_COND_LT UNSPEC_COND_LE > UNSPEC_COND_EQ UNSPEC_COND_NE > @@ -1543,7 +1547,9 @@ (define_int_attr optab [(UNSPEC_ANDF "an > (UNSPEC_COND_UMIN "umin") > (UNSPEC_COND_AND "and") > (UNSPEC_COND_ORR "ior") > - (UNSPEC_COND_EOR "xor")]) > + (UNSPEC_COND_EOR "xor") > + (UNSPEC_COND_FMLA "fma_rev") > + (UNSPEC_COND_FMLS "fnma_rev")]) > (define_int_attr maxmin_uns [(UNSPEC_UMAXV "umax") > (UNSPEC_UMINV "umin") > @@ -1762,4 +1768,6 @@ (define_int_attr sve_int_op [(UNSPEC_CON > (UNSPEC_COND_EOR "eor")]) > (define_int_attr sve_fp_op [(UNSPEC_COND_ADD "fadd") > - (UNSPEC_COND_SUB "fsub")]) > + (UNSPEC_COND_SUB "fsub") > + (UNSPEC_COND_FMLA "fmla") > + (UNSPEC_COND_FMLS "fmls")]) > Index: gcc/config/aarch64/aarch64-sve.md > =================================================================== > --- gcc/config/aarch64/aarch64-sve.md 2018-05-16 10:23:03.590853492 +0100 > +++ gcc/config/aarch64/aarch64-sve.md 2018-05-16 10:23:03.883838885 +0100 > @@ -1764,7 +1764,7 @@ (define_insn "cond_<optab><mode>" > [(match_operand:<VPRED> 1 "register_operand" "Upl") > (match_operand:SVE_I 2 "register_operand" "0") > (match_operand:SVE_I 3 "register_operand" "w")] > - SVE_COND_INT_OP))] > + SVE_COND_INT2_OP))] > "TARGET_SVE" > "<sve_int_op>\t%0.<Vetype>, %1/m, %0.<Vetype>, %3.<Vetype>" > ) > @@ -2543,11 +2543,23 @@ (define_insn "cond_<optab><mode>" > [(match_operand:<VPRED> 1 "register_operand" "Upl") > (match_operand:SVE_F 2 "register_operand" "0") > (match_operand:SVE_F 3 "register_operand" "w")] > - SVE_COND_FP_OP))] > + SVE_COND_FP2_OP))] > "TARGET_SVE" > "<sve_fp_op>\t%0.<Vetype>, %1/m, %0.<Vetype>, %3.<Vetype>" > ) > +(define_insn "cond_<optab><mode>" > + [(set (match_operand:SVE_F 0 "register_operand" "=w") > + (unspec:SVE_F > + [(match_operand:<VPRED> 1 "register_operand" "Upl") > + (match_operand:SVE_F 2 "register_operand" "0") > + (match_operand:SVE_F 3 "register_operand" "w") > + (match_operand:SVE_F 4 "register_operand" "w")] > + SVE_COND_FP3_OP))] > + "TARGET_SVE" > + "<sve_fp_op>\t%0.<Vetype>, %1/m, %3.<Vetype>, %4.<Vetype>" > +) > + > ;; Shift an SVE vector left and insert a scalar into element 0. 
> (define_insn "vec_shl_insert_<mode>" > [(set (match_operand:SVE_ALL 0 "register_operand" "=w, w") > Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c > =================================================================== > --- /dev/null 2018-04-20 16:19:46.369131350 +0100 > +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c 2018-05-16 10:23:03.888838636 +0100 > @@ -0,0 +1,18 @@ > +/* { dg-do compile } */ > +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ > + > +double > +f (double *restrict a, double *restrict b, int *lookup) > +{ > + double res = 0.0; > + for (int i = 0; i < 512; ++i) > + res += a[lookup[i]] * b[i]; > + return res; > +} > + > +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+.d, p[0-7]/m, } 2 } } */ > +/* Check that the vector instructions are the only instructions. */ > +/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */ > +/* { dg-final { scan-assembler-not {\tfadd\t} } } */ > +/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */ > +/* { dg-final { scan-assembler-not {\tsel\t} } } */ > Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c > =================================================================== > --- /dev/null 2018-04-20 16:19:46.369131350 +0100 > +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c 2018-05-16 10:23:03.888838636 +0100 > @@ -0,0 +1,17 @@ > +/* { dg-do compile } */ > +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ > + > +#define REDUC(TYPE) \ > + TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count) \ > + { \ > + TYPE sum = 0; \ > + for (int i = 0; i < count; ++i) \ > + sum += x[i] * y[i]; \ > + return sum; \ > + } > + > +REDUC (float) > +REDUC (double) > + > +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */ > +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */ > Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c > =================================================================== > --- /dev/null 2018-04-20 16:19:46.369131350 +0100 > +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c 2018-05-16 10:23:03.889838586 +0100 > @@ -0,0 +1,17 @@ > +/* { dg-do compile } */ > +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ > + > +#define REDUC(TYPE) \ > + TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count) \ > + { \ > + TYPE sum = 0; \ > + for (int i = 0; i < count; ++i) \ > + sum -= x[i] * y[i]; \ > + return sum; \ > + } > + > +REDUC (float) > +REDUC (double) > + > +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */ > +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */
Richard Biener <richard.guenther@gmail.com> writes:
> On Wed, May 16, 2018 at 11:26 AM Richard Sandiford <
> richard.sandiford@linaro.org> wrote:
>
>> This patch adds support for fusing a conditional add or subtract
>> with a multiplication, so that we can use fused multiply-add and
>> multiply-subtract operations for fully-masked reductions. E.g.
>> for SVE we vectorise:
>
>>   double res = 0.0;
>>   for (int i = 0; i < n; ++i)
>>     res += x[i] * y[i];
>
>> using a fully-masked loop in which the loop body has the form:
>
>>   res_1 = PHI<0(preheader), res_2(latch)>;
>>   avec = IFN_MASK_LOAD (loop_mask, a)
>>   bvec = IFN_MASK_LOAD (loop_mask, b)
>>   prod = avec * bvec;
>>   res_2 = IFN_COND_ADD (loop_mask, res_1, prod);
>
>> where the last statement does the equivalent of:
>
>>   res_2 = loop_mask ? res_1 + prod : res_1;
>
>> (operating elementwise). The point of the patch is to convert the last
>> two statements into a single internal function that is the equivalent of:
>
>>   res_2 = loop_mask ? fma (avec, bvec, res_1) : res_1;
>
>> (again operating elementwise).
>
>> All current conditional X operations have the form "do X or don't do X
>> to the first operand" (add/don't add to first operand, etc.). However,
>> the FMA optabs and functions are ordered so that the accumulator comes
>> last. There were two obvious ways of resolving this: break the
>> convention for conditional operators and have "add/don't add to the
>> final operand" or break the convention for FMA and put the accumulator
>> first. The patch goes for the latter, but adds _REV to make it obvious
>> that the operands are in a different order.
>
> Eh. I guess you'll do the same to SAD/DOT_PROD/WIDEN_SUM?
>
> That said, I don't really see the "do or not do to the first operand", it's
> "do or not do the operation on operands 1 to 2 (or 3)". None of the
> current ops modify operand 1, they all produce a new value, no?

Yeah, neither the current functions nor these ones actually changed
operand 1. It was all about deciding what the "else" value should be.
The _REV thing was a "fix" for the fact that we wanted the else value
to be the final operand of fma.

Of course, the real fix was to make all the IFN_COND_* functions take an
explicit else value, as you suggested in the review of the other patch
in the series. So all this _REV stuff is redundant now.

Here's an updated version based on top of the IFN_COND_FMA patch
that I just posted. Tested in the same way.

Thanks,
Richard

2018-05-24  Richard Sandiford  <richard.sandiford@linaro.org>
            Alan Hayward  <alan.hayward@arm.com>
            David Sherwood  <david.sherwood@arm.com>

gcc/
        * internal-fn.h (can_interpret_as_conditional_op_p): Declare.
        * internal-fn.c (can_interpret_as_conditional_op_p): New function.
        * tree-ssa-math-opts.c (convert_mult_to_fma_1): Handle conditional
        plus and minus and convert them into IFN_COND_FMA-based sequences.
        (convert_mult_to_fma): Handle conditional plus and minus.

gcc/testsuite/
        * gcc.dg/vect/vect-fma-2.c: New test.
        * gcc.target/aarch64/sve/reduc_4.c: Likewise.
        * gcc.target/aarch64/sve/reduc_6.c: Likewise.
        * gcc.target/aarch64/sve/reduc_7.c: Likewise.
Index: gcc/internal-fn.h =================================================================== --- gcc/internal-fn.h 2018-05-24 13:05:46.049605128 +0100 +++ gcc/internal-fn.h 2018-05-24 13:08:24.643987582 +0100 @@ -196,6 +196,9 @@ extern internal_fn get_conditional_inter extern internal_fn get_conditional_internal_fn (internal_fn); extern tree_code conditional_internal_fn_code (internal_fn); extern internal_fn get_unconditional_internal_fn (internal_fn); +extern bool can_interpret_as_conditional_op_p (gimple *, tree *, + tree_code *, tree (&)[3], + tree *); extern bool internal_load_fn_p (internal_fn); extern bool internal_store_fn_p (internal_fn); Index: gcc/internal-fn.c =================================================================== --- gcc/internal-fn.c 2018-05-24 13:05:46.048606357 +0100 +++ gcc/internal-fn.c 2018-05-24 13:08:24.643987582 +0100 @@ -3333,6 +3333,62 @@ #define CASE(NAME) case IFN_COND_##NAME: } } +/* Return true if STMT can be interpreted as a conditional tree code + operation of the form: + + LHS = COND ? OP (RHS1, ...) : ELSE; + + operating elementwise if the operands are vectors. This includes + the case of an all-true COND, so that the operation always happens. + + When returning true, set: + + - *COND_OUT to the condition COND, or to NULL_TREE if the condition + is known to be all-true + - *CODE_OUT to the tree code + - OPS[I] to operand I of *CODE_OUT + - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the + condition is known to be all true. */ + +bool +can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out, + tree_code *code_out, + tree (&ops)[3], tree *else_out) +{ + if (gassign *assign = dyn_cast <gassign *> (stmt)) + { + *cond_out = NULL_TREE; + *code_out = gimple_assign_rhs_code (assign); + ops[0] = gimple_assign_rhs1 (assign); + ops[1] = gimple_assign_rhs2 (assign); + ops[2] = gimple_assign_rhs3 (assign); + *else_out = NULL_TREE; + return true; + } + if (gcall *call = dyn_cast <gcall *> (stmt)) + if (gimple_call_internal_p (call)) + { + internal_fn ifn = gimple_call_internal_fn (call); + tree_code code = conditional_internal_fn_code (ifn); + if (code != ERROR_MARK) + { + *cond_out = gimple_call_arg (call, 0); + *code_out = code; + unsigned int nops = gimple_call_num_args (call) - 2; + for (unsigned int i = 0; i < 3; ++i) + ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE; + *else_out = gimple_call_arg (call, nops + 1); + if (integer_truep (*cond_out)) + { + *cond_out = NULL_TREE; + *else_out = NULL_TREE; + } + return true; + } + } + return false; +} + /* Return true if IFN is some form of load from memory. 
*/ bool Index: gcc/tree-ssa-math-opts.c =================================================================== --- gcc/tree-ssa-math-opts.c 2018-05-18 09:26:37.749713749 +0100 +++ gcc/tree-ssa-math-opts.c 2018-05-24 13:08:24.644961583 +0100 @@ -2655,7 +2655,6 @@ convert_mult_to_fma_1 (tree mul_result, FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result) { gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt); - enum tree_code use_code; tree addop, mulop1 = op1, result = mul_result; bool negate_p = false; gimple_seq seq = NULL; @@ -2663,8 +2662,8 @@ convert_mult_to_fma_1 (tree mul_result, if (is_gimple_debug (use_stmt)) continue; - use_code = gimple_assign_rhs_code (use_stmt); - if (use_code == NEGATE_EXPR) + if (is_gimple_assign (use_stmt) + && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR) { result = gimple_assign_lhs (use_stmt); use_operand_p use_p; @@ -2675,22 +2674,23 @@ convert_mult_to_fma_1 (tree mul_result, use_stmt = neguse_stmt; gsi = gsi_for_stmt (use_stmt); - use_code = gimple_assign_rhs_code (use_stmt); negate_p = true; } - if (gimple_assign_rhs1 (use_stmt) == result) + tree cond, else_value, ops[3]; + tree_code code; + if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, + ops, &else_value)) + gcc_unreachable (); + addop = ops[0] == result ? ops[1] : ops[0]; + + if (code == MINUS_EXPR) { - addop = gimple_assign_rhs2 (use_stmt); - /* a * b - c -> a * b + (-c) */ - if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR) + if (ops[0] == result) + /* a * b - c -> a * b + (-c) */ addop = gimple_build (&seq, NEGATE_EXPR, type, addop); - } - else - { - addop = gimple_assign_rhs1 (use_stmt); - /* a - b * c -> (-b) * c + a */ - if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR) + else + /* a - b * c -> (-b) * c + a */ negate_p = !negate_p; } @@ -2699,8 +2699,13 @@ convert_mult_to_fma_1 (tree mul_result, if (seq) gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT); - fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop); - gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt)); + + if (cond) + fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1, + op2, addop, else_value); + else + fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop); + gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt)); gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt)); gsi_replace (&gsi, fma_stmt, true); /* Follow all SSA edges so that we generate FMS, FNMA and FNMS @@ -2883,7 +2888,6 @@ convert_mult_to_fma (gimple *mul_stmt, t as an addition. */ FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result) { - enum tree_code use_code; tree result = mul_result; bool negate_p = false; @@ -2904,13 +2908,9 @@ convert_mult_to_fma (gimple *mul_stmt, t if (gimple_bb (use_stmt) != gimple_bb (mul_stmt)) return false; - if (!is_gimple_assign (use_stmt)) - return false; - - use_code = gimple_assign_rhs_code (use_stmt); - /* A negate on the multiplication leads to FNMA. 
*/ - if (use_code == NEGATE_EXPR) + if (is_gimple_assign (use_stmt) + && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR) { ssa_op_iter iter; use_operand_p usep; @@ -2932,17 +2932,20 @@ convert_mult_to_fma (gimple *mul_stmt, t use_stmt = neguse_stmt; if (gimple_bb (use_stmt) != gimple_bb (mul_stmt)) return false; - if (!is_gimple_assign (use_stmt)) - return false; - use_code = gimple_assign_rhs_code (use_stmt); negate_p = true; } - switch (use_code) + tree cond, else_value, ops[3]; + tree_code code; + if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops, + &else_value)) + return false; + + switch (code) { case MINUS_EXPR: - if (gimple_assign_rhs2 (use_stmt) == result) + if (ops[1] == result) negate_p = !negate_p; break; case PLUS_EXPR: @@ -2952,47 +2955,50 @@ convert_mult_to_fma (gimple *mul_stmt, t return false; } - /* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed - by a MULT_EXPR that we'll visit later, we might be able to - get a more profitable match with fnma. + if (cond) + { + if (cond == result || else_value == result) + return false; + if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type)) + return false; + } + + /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that + we'll visit later, we might be able to get a more profitable + match with fnma. OTOH, if we don't, a negate / fma pair has likely lower latency that a mult / subtract pair. */ - if (use_code == MINUS_EXPR && !negate_p - && gimple_assign_rhs1 (use_stmt) == result + if (code == MINUS_EXPR + && !negate_p + && ops[0] == result && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type) - && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)) + && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type) + && TREE_CODE (ops[1]) == SSA_NAME + && has_single_use (ops[1])) { - tree rhs2 = gimple_assign_rhs2 (use_stmt); - - if (TREE_CODE (rhs2) == SSA_NAME) - { - gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2); - if (has_single_use (rhs2) - && is_gimple_assign (stmt2) - && gimple_assign_rhs_code (stmt2) == MULT_EXPR) - return false; - } + gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]); + if (is_gimple_assign (stmt2) + && gimple_assign_rhs_code (stmt2) == MULT_EXPR) + return false; } - tree use_rhs1 = gimple_assign_rhs1 (use_stmt); - tree use_rhs2 = gimple_assign_rhs2 (use_stmt); /* We can't handle a * b + a * b. */ - if (use_rhs1 == use_rhs2) + if (ops[0] == ops[1]) return false; /* If deferring, make sure we are not looking at an instruction that wouldn't have existed if we were not. 
*/ if (state->m_deferring_p - && (state->m_mul_result_set.contains (use_rhs1) - || state->m_mul_result_set.contains (use_rhs2))) + && (state->m_mul_result_set.contains (ops[0]) + || state->m_mul_result_set.contains (ops[1]))) return false; if (check_defer) { - tree use_lhs = gimple_assign_lhs (use_stmt); + tree use_lhs = gimple_get_lhs (use_stmt); if (state->m_last_result) { - if (use_rhs2 == state->m_last_result - || use_rhs1 == state->m_last_result) + if (ops[1] == state->m_last_result + || ops[0] == state->m_last_result) defer = true; else defer = false; @@ -3001,12 +3007,12 @@ convert_mult_to_fma (gimple *mul_stmt, t { gcc_checking_assert (!state->m_initial_phi); gphi *phi; - if (use_rhs1 == result) - phi = result_of_phi (use_rhs2); + if (ops[0] == result) + phi = result_of_phi (ops[1]); else { - gcc_assert (use_rhs2 == result); - phi = result_of_phi (use_rhs1); + gcc_assert (ops[1] == result); + phi = result_of_phi (ops[0]); } if (phi) Index: gcc/testsuite/gcc.dg/vect/vect-fma-2.c =================================================================== --- /dev/null 2018-04-20 16:19:46.369131350 +0100 +++ gcc/testsuite/gcc.dg/vect/vect-fma-2.c 2018-05-24 13:08:24.643987582 +0100 @@ -0,0 +1,17 @@ +/* { dg-do compile } */ +/* { dg-additional-options "-fdump-tree-optimized -fassociative-math -fno-trapping-math -fno-signed-zeros" } */ + +#include "tree-vect.h" + +#define N (VECTOR_BITS * 11 / 64 + 3) + +double +dot_prod (double *x, double *y) +{ + double sum = 0; + for (int i = 0; i < N; ++i) + sum += x[i] * y[i]; + return sum; +} + +/* { dg-final { scan-tree-dump { = \.COND_FMA } "optimized" { target { vect_double && { vect_fully_masked && scalar_all_fma } } } } } */ Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c =================================================================== --- /dev/null 2018-04-20 16:19:46.369131350 +0100 +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c 2018-05-24 13:08:24.643987582 +0100 @@ -0,0 +1,18 @@ +/* { dg-do compile } */ +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ + +double +f (double *restrict a, double *restrict b, int *lookup) +{ + double res = 0.0; + for (int i = 0; i < 512; ++i) + res += a[lookup[i]] * b[i]; + return res; +} + +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+.d, p[0-7]/m, } 2 } } */ +/* Check that the vector instructions are the only instructions. 
*/ +/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */ +/* { dg-final { scan-assembler-not {\tfadd\t} } } */ +/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */ +/* { dg-final { scan-assembler-not {\tsel\t} } } */ Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c =================================================================== --- /dev/null 2018-04-20 16:19:46.369131350 +0100 +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c 2018-05-24 13:08:24.643987582 +0100 @@ -0,0 +1,17 @@ +/* { dg-do compile } */ +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ + +#define REDUC(TYPE) \ + TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count) \ + { \ + TYPE sum = 0; \ + for (int i = 0; i < count; ++i) \ + sum += x[i] * y[i]; \ + return sum; \ + } + +REDUC (float) +REDUC (double) + +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */ +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */ Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c =================================================================== --- /dev/null 2018-04-20 16:19:46.369131350 +0100 +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c 2018-05-24 13:08:24.643987582 +0100 @@ -0,0 +1,17 @@ +/* { dg-do compile } */ +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ + +#define REDUC(TYPE) \ + TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count) \ + { \ + TYPE sum = 0; \ + for (int i = 0; i < count; ++i) \ + sum -= x[i] * y[i]; \ + return sum; \ + } + +REDUC (float) +REDUC (double) + +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */ +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */
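As a rough before/after picture of what the convert_mult_to_fma_1 change does to the fully-masked reduction described at the top of the thread, the following scalar sketch models a single vector lane. It only illustrates the semantics; the helper names are invented and this is not the vectorizer's internal representation.

    #include <math.h>

    /* Before: a separate multiply feeding the conditional add, i.e.
       prod = a * b;  res = IFN_COND_ADD (mask, res, prod, res).  */
    double lane_before (int mask, double a, double b, double res)
    {
      double prod = a * b;
      return mask ? res + prod : res;
    }

    /* After: one conditional FMA whose else value is the old accumulator,
       i.e.  res = IFN_COND_FMA (mask, a, b, res, res).  */
    double lane_after (int mask, double a, double b, double res)
    {
      return mask ? fma (a, b, res) : res;
    }

The separate multiply and conditional accumulation collapse into a single operation per lane, which is what lets the SVE tests scan for a lone predicated fmla or fmls rather than a separate multiply and add.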
On Thu, May 24, 2018 at 2:17 PM Richard Sandiford < richard.sandiford@linaro.org> wrote: > Richard Biener <richard.guenther@gmail.com> writes: > > On Wed, May 16, 2018 at 11:26 AM Richard Sandiford < > > richard.sandiford@linaro.org> wrote: > > > >> This patch adds support for fusing a conditional add or subtract > >> with a multiplication, so that we can use fused multiply-add and > >> multiply-subtract operations for fully-masked reductions. E.g. > >> for SVE we vectorise: > > > >> double res = 0.0; > >> for (int i = 0; i < n; ++i) > >> res += x[i] * y[i]; > > > >> using a fully-masked loop in which the loop body has the form: > > > >> res_1 = PHI<0(preheader), res_2(latch)>; > >> avec = IFN_MASK_LOAD (loop_mask, a) > >> bvec = IFN_MASK_LOAD (loop_mask, b) > >> prod = avec * bvec; > >> res_2 = IFN_COND_ADD (loop_mask, res_1, prod); > > > >> where the last statement does the equivalent of: > > > >> res_2 = loop_mask ? res_1 + prod : res_1; > > > >> (operating elementwise). The point of the patch is to convert the last > >> two statements into a single internal function that is the equivalent of: > > > >> res_2 = loop_mask ? fma (avec, bvec, res_1) : res_1; > > > >> (again operating elementwise). > > > >> All current conditional X operations have the form "do X or don't do X > >> to the first operand" (add/don't add to first operand, etc.). However, > >> the FMA optabs and functions are ordered so that the accumulator comes > >> last. There were two obvious ways of resolving this: break the > >> convention for conditional operators and have "add/don't add to the > >> final operand" or break the convention for FMA and put the accumulator > >> first. The patch goes for the latter, but adds _REV to make it obvious > >> that the operands are in a different order. > > > > Eh. I guess you'll do the same to SAD/DOT_PROD/WIDEN_SUM? > > > > That said, I don't really see the "do or not do to the first operand", it's > > "do or not do the operation on operands 1 to 2 (or 3)". None of the > > current ops modify operand 1, they all produce a new value, no? > Yeah, neither the current functions nor these ones actually changed > operand 1. It was all about deciding what the "else" value should be. > The _REV thing was a "fix" for the fact that we wanted the else value > to be the final operand of fma. > Of course, the real fix was to make all the IFN_COND_* functions take an > explicit else value, as you suggested in the review of the other patch > in the series. So all this _REV stuff is redundant now. > Here's an updated version based on top of the IFN_COND_FMA patch > that I just posted. Tested in the same way. OK. Thanks, Richard. > Thanks, > Richard > 2018-05-24 Richard Sandiford <richard.sandiford@linaro.org> > Alan Hayward <alan.hayward@arm.com> > David Sherwood <david.sherwood@arm.com> > gcc/ > * internal-fn.h (can_interpret_as_conditional_op_p): Declare. > * internal-fn.c (can_interpret_as_conditional_op_p): New function. > * tree-ssa-math-opts.c (convert_mult_to_fma_1): Handle conditional > plus and minus and convert them into IFN_COND_FMA-based sequences. > (convert_mult_to_fma): Handle conditional plus and minus. > gcc/testsuite/ > * gcc.dg/vect/vect-fma-2.c: New test. > * gcc.target/aarch64/sve/reduc_4.c: Likewise. > * gcc.target/aarch64/sve/reduc_6.c: Likewise. > * gcc.target/aarch64/sve/reduc_7.c: Likewise. 
> Index: gcc/internal-fn.h > =================================================================== > --- gcc/internal-fn.h 2018-05-24 13:05:46.049605128 +0100 > +++ gcc/internal-fn.h 2018-05-24 13:08:24.643987582 +0100 > @@ -196,6 +196,9 @@ extern internal_fn get_conditional_inter > extern internal_fn get_conditional_internal_fn (internal_fn); > extern tree_code conditional_internal_fn_code (internal_fn); > extern internal_fn get_unconditional_internal_fn (internal_fn); > +extern bool can_interpret_as_conditional_op_p (gimple *, tree *, > + tree_code *, tree (&)[3], > + tree *); > extern bool internal_load_fn_p (internal_fn); > extern bool internal_store_fn_p (internal_fn); > Index: gcc/internal-fn.c > =================================================================== > --- gcc/internal-fn.c 2018-05-24 13:05:46.048606357 +0100 > +++ gcc/internal-fn.c 2018-05-24 13:08:24.643987582 +0100 > @@ -3333,6 +3333,62 @@ #define CASE(NAME) case IFN_COND_##NAME: > } > } > +/* Return true if STMT can be interpreted as a conditional tree code > + operation of the form: > + > + LHS = COND ? OP (RHS1, ...) : ELSE; > + > + operating elementwise if the operands are vectors. This includes > + the case of an all-true COND, so that the operation always happens. > + > + When returning true, set: > + > + - *COND_OUT to the condition COND, or to NULL_TREE if the condition > + is known to be all-true > + - *CODE_OUT to the tree code > + - OPS[I] to operand I of *CODE_OUT > + - *ELSE_OUT to the fallback value ELSE, or to NULL_TREE if the > + condition is known to be all true. */ > + > +bool > +can_interpret_as_conditional_op_p (gimple *stmt, tree *cond_out, > + tree_code *code_out, > + tree (&ops)[3], tree *else_out) > +{ > + if (gassign *assign = dyn_cast <gassign *> (stmt)) > + { > + *cond_out = NULL_TREE; > + *code_out = gimple_assign_rhs_code (assign); > + ops[0] = gimple_assign_rhs1 (assign); > + ops[1] = gimple_assign_rhs2 (assign); > + ops[2] = gimple_assign_rhs3 (assign); > + *else_out = NULL_TREE; > + return true; > + } > + if (gcall *call = dyn_cast <gcall *> (stmt)) > + if (gimple_call_internal_p (call)) > + { > + internal_fn ifn = gimple_call_internal_fn (call); > + tree_code code = conditional_internal_fn_code (ifn); > + if (code != ERROR_MARK) > + { > + *cond_out = gimple_call_arg (call, 0); > + *code_out = code; > + unsigned int nops = gimple_call_num_args (call) - 2; > + for (unsigned int i = 0; i < 3; ++i) > + ops[i] = i < nops ? gimple_call_arg (call, i + 1) : NULL_TREE; > + *else_out = gimple_call_arg (call, nops + 1); > + if (integer_truep (*cond_out)) > + { > + *cond_out = NULL_TREE; > + *else_out = NULL_TREE; > + } > + return true; > + } > + } > + return false; > +} > + > /* Return true if IFN is some form of load from memory. 
*/ > bool > Index: gcc/tree-ssa-math-opts.c > =================================================================== > --- gcc/tree-ssa-math-opts.c 2018-05-18 09:26:37.749713749 +0100 > +++ gcc/tree-ssa-math-opts.c 2018-05-24 13:08:24.644961583 +0100 > @@ -2655,7 +2655,6 @@ convert_mult_to_fma_1 (tree mul_result, > FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result) > { > gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt); > - enum tree_code use_code; > tree addop, mulop1 = op1, result = mul_result; > bool negate_p = false; > gimple_seq seq = NULL; > @@ -2663,8 +2662,8 @@ convert_mult_to_fma_1 (tree mul_result, > if (is_gimple_debug (use_stmt)) > continue; > - use_code = gimple_assign_rhs_code (use_stmt); > - if (use_code == NEGATE_EXPR) > + if (is_gimple_assign (use_stmt) > + && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR) > { > result = gimple_assign_lhs (use_stmt); > use_operand_p use_p; > @@ -2675,22 +2674,23 @@ convert_mult_to_fma_1 (tree mul_result, > use_stmt = neguse_stmt; > gsi = gsi_for_stmt (use_stmt); > - use_code = gimple_assign_rhs_code (use_stmt); > negate_p = true; > } > - if (gimple_assign_rhs1 (use_stmt) == result) > + tree cond, else_value, ops[3]; > + tree_code code; > + if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, > + ops, &else_value)) > + gcc_unreachable (); > + addop = ops[0] == result ? ops[1] : ops[0]; > + > + if (code == MINUS_EXPR) > { > - addop = gimple_assign_rhs2 (use_stmt); > - /* a * b - c -> a * b + (-c) */ > - if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR) > + if (ops[0] == result) > + /* a * b - c -> a * b + (-c) */ > addop = gimple_build (&seq, NEGATE_EXPR, type, addop); > - } > - else > - { > - addop = gimple_assign_rhs1 (use_stmt); > - /* a - b * c -> (-b) * c + a */ > - if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR) > + else > + /* a - b * c -> (-b) * c + a */ > negate_p = !negate_p; > } > @@ -2699,8 +2699,13 @@ convert_mult_to_fma_1 (tree mul_result, > if (seq) > gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT); > - fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop); > - gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt)); > + > + if (cond) > + fma_stmt = gimple_build_call_internal (IFN_COND_FMA, 5, cond, mulop1, > + op2, addop, else_value); > + else > + fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop); > + gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt)); > gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt)); > gsi_replace (&gsi, fma_stmt, true); > /* Follow all SSA edges so that we generate FMS, FNMA and FNMS > @@ -2883,7 +2888,6 @@ convert_mult_to_fma (gimple *mul_stmt, t > as an addition. */ > FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result) > { > - enum tree_code use_code; > tree result = mul_result; > bool negate_p = false; > @@ -2904,13 +2908,9 @@ convert_mult_to_fma (gimple *mul_stmt, t > if (gimple_bb (use_stmt) != gimple_bb (mul_stmt)) > return false; > - if (!is_gimple_assign (use_stmt)) > - return false; > - > - use_code = gimple_assign_rhs_code (use_stmt); > - > /* A negate on the multiplication leads to FNMA. 
*/ > - if (use_code == NEGATE_EXPR) > + if (is_gimple_assign (use_stmt) > + && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR) > { > ssa_op_iter iter; > use_operand_p usep; > @@ -2932,17 +2932,20 @@ convert_mult_to_fma (gimple *mul_stmt, t > use_stmt = neguse_stmt; > if (gimple_bb (use_stmt) != gimple_bb (mul_stmt)) > return false; > - if (!is_gimple_assign (use_stmt)) > - return false; > - use_code = gimple_assign_rhs_code (use_stmt); > negate_p = true; > } > - switch (use_code) > + tree cond, else_value, ops[3]; > + tree_code code; > + if (!can_interpret_as_conditional_op_p (use_stmt, &cond, &code, ops, > + &else_value)) > + return false; > + > + switch (code) > { > case MINUS_EXPR: > - if (gimple_assign_rhs2 (use_stmt) == result) > + if (ops[1] == result) > negate_p = !negate_p; > break; > case PLUS_EXPR: > @@ -2952,47 +2955,50 @@ convert_mult_to_fma (gimple *mul_stmt, t > return false; > } > - /* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed > - by a MULT_EXPR that we'll visit later, we might be able to > - get a more profitable match with fnma. > + if (cond) > + { > + if (cond == result || else_value == result) > + return false; > + if (!direct_internal_fn_supported_p (IFN_COND_FMA, type, opt_type)) > + return false; > + } > + > + /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that > + we'll visit later, we might be able to get a more profitable > + match with fnma. > OTOH, if we don't, a negate / fma pair has likely lower latency > that a mult / subtract pair. */ > - if (use_code == MINUS_EXPR && !negate_p > - && gimple_assign_rhs1 (use_stmt) == result > + if (code == MINUS_EXPR > + && !negate_p > + && ops[0] == result > && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type) > - && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)) > + && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type) > + && TREE_CODE (ops[1]) == SSA_NAME > + && has_single_use (ops[1])) > { > - tree rhs2 = gimple_assign_rhs2 (use_stmt); > - > - if (TREE_CODE (rhs2) == SSA_NAME) > - { > - gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2); > - if (has_single_use (rhs2) > - && is_gimple_assign (stmt2) > - && gimple_assign_rhs_code (stmt2) == MULT_EXPR) > - return false; > - } > + gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]); > + if (is_gimple_assign (stmt2) > + && gimple_assign_rhs_code (stmt2) == MULT_EXPR) > + return false; > } > - tree use_rhs1 = gimple_assign_rhs1 (use_stmt); > - tree use_rhs2 = gimple_assign_rhs2 (use_stmt); > /* We can't handle a * b + a * b. */ > - if (use_rhs1 == use_rhs2) > + if (ops[0] == ops[1]) > return false; > /* If deferring, make sure we are not looking at an instruction that > wouldn't have existed if we were not. 
*/ > if (state->m_deferring_p > - && (state->m_mul_result_set.contains (use_rhs1) > - || state->m_mul_result_set.contains (use_rhs2))) > + && (state->m_mul_result_set.contains (ops[0]) > + || state->m_mul_result_set.contains (ops[1]))) > return false; > if (check_defer) > { > - tree use_lhs = gimple_assign_lhs (use_stmt); > + tree use_lhs = gimple_get_lhs (use_stmt); > if (state->m_last_result) > { > - if (use_rhs2 == state->m_last_result > - || use_rhs1 == state->m_last_result) > + if (ops[1] == state->m_last_result > + || ops[0] == state->m_last_result) > defer = true; > else > defer = false; > @@ -3001,12 +3007,12 @@ convert_mult_to_fma (gimple *mul_stmt, t > { > gcc_checking_assert (!state->m_initial_phi); > gphi *phi; > - if (use_rhs1 == result) > - phi = result_of_phi (use_rhs2); > + if (ops[0] == result) > + phi = result_of_phi (ops[1]); > else > { > - gcc_assert (use_rhs2 == result); > - phi = result_of_phi (use_rhs1); > + gcc_assert (ops[1] == result); > + phi = result_of_phi (ops[0]); > } > if (phi) > Index: gcc/testsuite/gcc.dg/vect/vect-fma-2.c > =================================================================== > --- /dev/null 2018-04-20 16:19:46.369131350 +0100 > +++ gcc/testsuite/gcc.dg/vect/vect-fma-2.c 2018-05-24 13:08:24.643987582 +0100 > @@ -0,0 +1,17 @@ > +/* { dg-do compile } */ > +/* { dg-additional-options "-fdump-tree-optimized -fassociative-math -fno-trapping-math -fno-signed-zeros" } */ > + > +#include "tree-vect.h" > + > +#define N (VECTOR_BITS * 11 / 64 + 3) > + > +double > +dot_prod (double *x, double *y) > +{ > + double sum = 0; > + for (int i = 0; i < N; ++i) > + sum += x[i] * y[i]; > + return sum; > +} > + > +/* { dg-final { scan-tree-dump { = \.COND_FMA } "optimized" { target { vect_double && { vect_fully_masked && scalar_all_fma } } } } } */ > Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c > =================================================================== > --- /dev/null 2018-04-20 16:19:46.369131350 +0100 > +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c 2018-05-24 13:08:24.643987582 +0100 > @@ -0,0 +1,18 @@ > +/* { dg-do compile } */ > +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ > + > +double > +f (double *restrict a, double *restrict b, int *lookup) > +{ > + double res = 0.0; > + for (int i = 0; i < 512; ++i) > + res += a[lookup[i]] * b[i]; > + return res; > +} > + > +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+.d, p[0-7]/m, } 2 } } */ > +/* Check that the vector instructions are the only instructions. 
*/ > +/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */ > +/* { dg-final { scan-assembler-not {\tfadd\t} } } */ > +/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */ > +/* { dg-final { scan-assembler-not {\tsel\t} } } */ > Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c > =================================================================== > --- /dev/null 2018-04-20 16:19:46.369131350 +0100 > +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c 2018-05-24 13:08:24.643987582 +0100 > @@ -0,0 +1,17 @@ > +/* { dg-do compile } */ > +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ > + > +#define REDUC(TYPE) \ > + TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count) \ > + { \ > + TYPE sum = 0; \ > + for (int i = 0; i < count; ++i) \ > + sum += x[i] * y[i]; \ > + return sum; \ > + } > + > +REDUC (float) > +REDUC (double) > + > +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */ > +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */ > Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c > =================================================================== > --- /dev/null 2018-04-20 16:19:46.369131350 +0100 > +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c 2018-05-24 13:08:24.643987582 +0100 > @@ -0,0 +1,17 @@ > +/* { dg-do compile } */ > +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ > + > +#define REDUC(TYPE) \ > + TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count) \ > + { \ > + TYPE sum = 0; \ > + for (int i = 0; i < count; ++i) \ > + sum -= x[i] * y[i]; \ > + return sum; \ > + } > + > +REDUC (float) > +REDUC (double) > + > +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */ > +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */
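For anyone wanting to reproduce the expected assembly by hand, one of the new tests can be compiled with an SVE-enabled cross compiler along these lines. The toolchain name and the -march string are assumptions for illustration rather than something taken from the thread; the testsuite harness supplies its own SVE options when it runs the tests.

    aarch64-linux-gnu-gcc -O2 -ftree-vectorize -ffast-math \
        -march=armv8-a+sve -S reduc_6.c -o -

The scan-assembler directives then look for the predicated fmla (or fmls for reduc_7.c) on both the .s and .d forms.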
Index: gcc/doc/md.texi =================================================================== --- gcc/doc/md.texi 2018-05-16 10:23:03.590853492 +0100 +++ gcc/doc/md.texi 2018-05-16 10:23:03.886838736 +0100 @@ -6367,6 +6367,32 @@ be in a normal C @samp{?:} condition. Operands 0, 2 and 3 all have mode @var{m}, while operand 1 has the mode returned by @code{TARGET_VECTORIZE_GET_MASK_MODE}. +@cindex @code{cond_fma_rev@var{mode}} instruction pattern +@item @samp{cond_fma_rev@var{mode}} +Similar to @samp{cond_add@var{m}}, but compute: +@smallexample +op0 = op1 ? fma (op3, op4, op2) : op2; +@end smallexample +for scalars and: +@smallexample +op0[I] = op1[I] ? fma (op3[I], op4[I], op2[I]) : op2[I]; +@end smallexample +for vectors. The @samp{_rev} indicates that the addend (operand 2) +comes first. + +@cindex @code{cond_fnma_rev@var{mode}} instruction pattern +@item @samp{cond_fnma_rev@var{mode}} +Similar to @samp{cond_fma_rev@var{m}}, but negate operand 3 before +multiplying it. That is, the instruction performs: +@smallexample +op0 = op1 ? fma (-op3, op4, op2) : op2; +@end smallexample +for scalars and: +@smallexample +op0[I] = op1[I] ? fma (-op3[I], op4[I], op2[I]) : op2[I]; +@end smallexample +for vectors. + @cindex @code{neg@var{mode}cc} instruction pattern @item @samp{neg@var{mode}cc} Similar to @samp{mov@var{mode}cc} but for conditional negation. Conditionally Index: gcc/optabs.def =================================================================== --- gcc/optabs.def 2018-05-16 10:23:03.590853492 +0100 +++ gcc/optabs.def 2018-05-16 10:23:03.887838686 +0100 @@ -222,6 +222,8 @@ OPTAB_D (notcc_optab, "not$acc") OPTAB_D (movcc_optab, "mov$acc") OPTAB_D (cond_add_optab, "cond_add$a") OPTAB_D (cond_sub_optab, "cond_sub$a") +OPTAB_D (cond_fma_rev_optab, "cond_fma_rev$a") +OPTAB_D (cond_fnma_rev_optab, "cond_fnma_rev$a") OPTAB_D (cond_and_optab, "cond_and$a") OPTAB_D (cond_ior_optab, "cond_ior$a") OPTAB_D (cond_xor_optab, "cond_xor$a") Index: gcc/internal-fn.def =================================================================== --- gcc/internal-fn.def 2018-05-16 10:23:03.590853492 +0100 +++ gcc/internal-fn.def 2018-05-16 10:23:03.887838686 +0100 @@ -59,7 +59,8 @@ along with GCC; see the file COPYING3. 
- binary: a normal binary optab, such as vec_interleave_lo_<mode> - ternary: a normal ternary optab, such as fma<mode>4 - - cond_binary: a conditional binary optab, such as add<mode>cc + - cond_binary: a conditional binary optab, such as cond_add<mode> + - cond_ternary: a conditional ternary optab, such as cond_fma_rev<mode> - fold_left: for scalar = FN (scalar, vector), keyed off the vector mode @@ -143,6 +144,9 @@ DEF_INTERNAL_OPTAB_FN (FMS, ECF_CONST, f DEF_INTERNAL_OPTAB_FN (FNMA, ECF_CONST, fnma, ternary) DEF_INTERNAL_OPTAB_FN (FNMS, ECF_CONST, fnms, ternary) +DEF_INTERNAL_OPTAB_FN (COND_FMA_REV, ECF_CONST, cond_fma_rev, cond_ternary) +DEF_INTERNAL_OPTAB_FN (COND_FNMA_REV, ECF_CONST, cond_fnma_rev, cond_ternary) + DEF_INTERNAL_OPTAB_FN (COND_ADD, ECF_CONST, cond_add, cond_binary) DEF_INTERNAL_OPTAB_FN (COND_SUB, ECF_CONST, cond_sub, cond_binary) DEF_INTERNAL_SIGNED_OPTAB_FN (COND_MIN, ECF_CONST, first, Index: gcc/internal-fn.h =================================================================== --- gcc/internal-fn.h 2018-05-16 10:23:03.590853492 +0100 +++ gcc/internal-fn.h 2018-05-16 10:23:03.887838686 +0100 @@ -191,6 +191,8 @@ direct_internal_fn_supported_p (internal extern bool set_edom_supported_p (void); extern internal_fn get_conditional_internal_fn (tree_code); +extern bool can_interpret_as_conditional_op_p (gimple *, tree_code *, + tree *, tree (&)[3]); extern bool internal_load_fn_p (internal_fn); extern bool internal_store_fn_p (internal_fn); Index: gcc/internal-fn.c =================================================================== --- gcc/internal-fn.c 2018-05-16 10:23:03.590853492 +0100 +++ gcc/internal-fn.c 2018-05-16 10:23:03.887838686 +0100 @@ -93,6 +93,7 @@ #define binary_direct { 0, 0, true } #define ternary_direct { 0, 0, true } #define cond_unary_direct { 1, 1, true } #define cond_binary_direct { 1, 1, true } +#define cond_ternary_direct { 1, 1, true } #define while_direct { 0, 2, false } #define fold_extract_direct { 2, 2, false } #define fold_left_direct { 1, 1, false } @@ -2972,6 +2973,9 @@ #define expand_cond_unary_optab_fn(FN, S #define expand_cond_binary_optab_fn(FN, STMT, OPTAB) \ expand_direct_optab_fn (FN, STMT, OPTAB, 3) +#define expand_cond_ternary_optab_fn(FN, STMT, OPTAB) \ + expand_direct_optab_fn (FN, STMT, OPTAB, 4) + #define expand_fold_extract_optab_fn(FN, STMT, OPTAB) \ expand_direct_optab_fn (FN, STMT, OPTAB, 3) @@ -3054,6 +3058,7 @@ #define direct_binary_optab_supported_p #define direct_ternary_optab_supported_p direct_optab_supported_p #define direct_cond_unary_optab_supported_p direct_optab_supported_p #define direct_cond_binary_optab_supported_p direct_optab_supported_p +#define direct_cond_ternary_optab_supported_p direct_optab_supported_p #define direct_mask_load_optab_supported_p direct_optab_supported_p #define direct_load_lanes_optab_supported_p multi_vector_optab_supported_p #define direct_mask_load_lanes_optab_supported_p multi_vector_optab_supported_p @@ -3198,6 +3203,17 @@ #define DEF_INTERNAL_FN(CODE, FLAGS, FNS 0 }; +/* Invoke T(CODE, IFN) for each conditional function IFN that maps to a + tree code CODE. */ +#define FOR_EACH_CODE_MAPPING(T) \ + T (PLUS_EXPR, IFN_COND_ADD) \ + T (MINUS_EXPR, IFN_COND_SUB) \ + T (MIN_EXPR, IFN_COND_MIN) \ + T (MAX_EXPR, IFN_COND_MAX) \ + T (BIT_AND_EXPR, IFN_COND_AND) \ + T (BIT_IOR_EXPR, IFN_COND_IOR) \ + T (BIT_XOR_EXPR, IFN_COND_XOR) + /* Return a function that performs the conditional form of CODE, i.e.: LHS = RHS1 ? 
RHS2 CODE RHS3 : RHS2 @@ -3210,25 +3226,78 @@ get_conditional_internal_fn (tree_code c { switch (code) { - case PLUS_EXPR: - return IFN_COND_ADD; - case MINUS_EXPR: - return IFN_COND_SUB; - case MIN_EXPR: - return IFN_COND_MIN; - case MAX_EXPR: - return IFN_COND_MAX; - case BIT_AND_EXPR: - return IFN_COND_AND; - case BIT_IOR_EXPR: - return IFN_COND_IOR; - case BIT_XOR_EXPR: - return IFN_COND_XOR; +#define CASE(CODE, IFN) case CODE: return IFN; + FOR_EACH_CODE_MAPPING(CASE) +#undef CASE default: return IFN_LAST; } } +/* If IFN implements the conditional form of a tree code, return that + tree code, otherwise return ERROR_MARK. */ + +static tree_code +conditional_internal_fn_code (internal_fn ifn) +{ + switch (ifn) + { +#define CASE(CODE, IFN) case IFN: return CODE; + FOR_EACH_CODE_MAPPING(CASE) +#undef CASE + default: + return ERROR_MARK; + } +} + +/* Return true if STMT can be interpreted as a conditional tree code + operation of the form: + + LHS = COND ? OP (RHS1, ...) : RHS1; + + operating elementwise if the operands are vectors. This includes + the case of an all-true COND, so that the operation always happens. + + When returning true, set: + + - *CODE_OUT to the tree code + - *COND_OUT to the condition COND, or to NULL_TREE if the condition + is known to be all-true + - OPS[I] to operand I of *CODE_OUT. */ + +bool +can_interpret_as_conditional_op_p (gimple *stmt, tree_code *code_out, + tree *cond_out, tree (&ops)[3]) +{ + if (gassign *assign = dyn_cast <gassign *> (stmt)) + { + *code_out = gimple_assign_rhs_code (assign); + *cond_out = NULL_TREE; + ops[0] = gimple_assign_rhs1 (assign); + ops[1] = gimple_assign_rhs2 (assign); + ops[2] = gimple_assign_rhs3 (assign); + return true; + } + if (gcall *call = dyn_cast <gcall *> (stmt)) + if (gimple_call_internal_p (call)) + { + internal_fn ifn = gimple_call_internal_fn (call); + tree_code code = conditional_internal_fn_code (ifn); + if (code != ERROR_MARK) + { + *code_out = code; + *cond_out = gimple_call_arg (call, 0); + if (integer_truep (*cond_out)) + *cond_out = NULL_TREE; + unsigned int nargs = gimple_call_num_args (call) - 1; + for (unsigned int i = 0; i < 3; ++i) + ops[i] = i < nargs ? gimple_call_arg (call, i + 1) : NULL_TREE; + return true; + } + } + return false; +} + /* Return true if IFN is some form of load from memory. */ bool Index: gcc/tree-ssa-math-opts.c =================================================================== --- gcc/tree-ssa-math-opts.c 2018-05-16 10:23:03.590853492 +0100 +++ gcc/tree-ssa-math-opts.c 2018-05-16 10:23:03.889838586 +0100 @@ -2640,6 +2640,24 @@ convert_plusminus_to_widen (gimple_stmt_ return true; } +/* Return the internal function that implements: + + LHS = COND ? A CODE B * C : A. */ + +static internal_fn +fused_cond_internal_fn (tree_code code) +{ + switch (code) + { + case PLUS_EXPR: + return IFN_COND_FMA_REV; + case MINUS_EXPR: + return IFN_COND_FNMA_REV; + default: + gcc_unreachable (); + } +} + /* gimple_fold callback that "valueizes" everything. 
*/ static tree @@ -2663,7 +2681,6 @@ convert_mult_to_fma_1 (tree mul_result, FOR_EACH_IMM_USE_STMT (use_stmt, imm_iter, mul_result) { gimple_stmt_iterator gsi = gsi_for_stmt (use_stmt); - enum tree_code use_code; tree addop, mulop1 = op1, result = mul_result; bool negate_p = false; gimple_seq seq = NULL; @@ -2671,8 +2688,8 @@ convert_mult_to_fma_1 (tree mul_result, if (is_gimple_debug (use_stmt)) continue; - use_code = gimple_assign_rhs_code (use_stmt); - if (use_code == NEGATE_EXPR) + if (is_gimple_assign (use_stmt) + && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR) { result = gimple_assign_lhs (use_stmt); use_operand_p use_p; @@ -2683,23 +2700,30 @@ convert_mult_to_fma_1 (tree mul_result, use_stmt = neguse_stmt; gsi = gsi_for_stmt (use_stmt); - use_code = gimple_assign_rhs_code (use_stmt); negate_p = true; } - if (gimple_assign_rhs1 (use_stmt) == result) - { - addop = gimple_assign_rhs2 (use_stmt); - /* a * b - c -> a * b + (-c) */ - if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR) - addop = gimple_build (&seq, NEGATE_EXPR, type, addop); - } + tree cond, ops[3]; + tree_code code; + if (!can_interpret_as_conditional_op_p (use_stmt, &code, &cond, ops)) + gcc_unreachable (); + addop = ops[0] == result ? ops[1] : ops[0]; + + internal_fn ifn; + if (cond) + ifn = fused_cond_internal_fn (code); else { - addop = gimple_assign_rhs1 (use_stmt); - /* a - b * c -> (-b) * c + a */ - if (gimple_assign_rhs_code (use_stmt) == MINUS_EXPR) - negate_p = !negate_p; + ifn = IFN_FMA; + if (code == MINUS_EXPR) + { + if (ops[0] == result) + /* a * b - c -> a * b + (-c) */ + addop = gimple_build (&seq, NEGATE_EXPR, type, addop); + else + /* a - b * c -> (-b) * c + a */ + negate_p = !negate_p; + } } if (negate_p) @@ -2707,8 +2731,13 @@ convert_mult_to_fma_1 (tree mul_result, if (seq) gsi_insert_seq_before (&gsi, seq, GSI_SAME_STMT); - fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop); - gimple_call_set_lhs (fma_stmt, gimple_assign_lhs (use_stmt)); + + if (ifn == IFN_FMA) + fma_stmt = gimple_build_call_internal (IFN_FMA, 3, mulop1, op2, addop); + else + fma_stmt = gimple_build_call_internal (ifn, 4, cond, addop, + mulop1, op2); + gimple_set_lhs (fma_stmt, gimple_get_lhs (use_stmt)); gimple_call_set_nothrow (fma_stmt, !stmt_can_throw_internal (use_stmt)); gsi_replace (&gsi, fma_stmt, true); /* Valueize aggressively so that we generate FMS, FNMA and FNMS @@ -2891,7 +2920,6 @@ convert_mult_to_fma (gimple *mul_stmt, t as an addition. */ FOR_EACH_IMM_USE_FAST (use_p, imm_iter, mul_result) { - enum tree_code use_code; tree result = mul_result; bool negate_p = false; @@ -2912,13 +2940,9 @@ convert_mult_to_fma (gimple *mul_stmt, t if (gimple_bb (use_stmt) != gimple_bb (mul_stmt)) return false; - if (!is_gimple_assign (use_stmt)) - return false; - - use_code = gimple_assign_rhs_code (use_stmt); - /* A negate on the multiplication leads to FNMA. 
*/ - if (use_code == NEGATE_EXPR) + if (is_gimple_assign (use_stmt) + && gimple_assign_rhs_code (use_stmt) == NEGATE_EXPR) { ssa_op_iter iter; use_operand_p usep; @@ -2940,17 +2964,19 @@ convert_mult_to_fma (gimple *mul_stmt, t use_stmt = neguse_stmt; if (gimple_bb (use_stmt) != gimple_bb (mul_stmt)) return false; - if (!is_gimple_assign (use_stmt)) - return false; - use_code = gimple_assign_rhs_code (use_stmt); negate_p = true; } - switch (use_code) + tree cond, ops[3]; + tree_code code; + if (!can_interpret_as_conditional_op_p (use_stmt, &code, &cond, ops)) + return false; + + switch (code) { case MINUS_EXPR: - if (gimple_assign_rhs2 (use_stmt) == result) + if (ops[1] == result) negate_p = !negate_p; break; case PLUS_EXPR: @@ -2960,47 +2986,52 @@ convert_mult_to_fma (gimple *mul_stmt, t return false; } - /* If the subtrahend (gimple_assign_rhs2 (use_stmt)) is computed - by a MULT_EXPR that we'll visit later, we might be able to - get a more profitable match with fnma. + if (cond) + { + /* The multiplication must be the second operand. */ + if (cond == result || ops[0] == result) + return false; + internal_fn ifn = fused_cond_internal_fn (code); + if (!direct_internal_fn_supported_p (ifn, type, opt_type)) + return false; + } + + /* If the subtrahend (OPS[1]) is computed by a MULT_EXPR that + we'll visit later, we might be able to get a more profitable + match with fnma. OTOH, if we don't, a negate / fma pair has likely lower latency that a mult / subtract pair. */ - if (use_code == MINUS_EXPR && !negate_p - && gimple_assign_rhs1 (use_stmt) == result + if (code == MINUS_EXPR + && !negate_p + && ops[0] == result && !direct_internal_fn_supported_p (IFN_FMS, type, opt_type) - && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type)) + && direct_internal_fn_supported_p (IFN_FNMA, type, opt_type) + && TREE_CODE (ops[1]) == SSA_NAME + && has_single_use (ops[1])) { - tree rhs2 = gimple_assign_rhs2 (use_stmt); - - if (TREE_CODE (rhs2) == SSA_NAME) - { - gimple *stmt2 = SSA_NAME_DEF_STMT (rhs2); - if (has_single_use (rhs2) - && is_gimple_assign (stmt2) - && gimple_assign_rhs_code (stmt2) == MULT_EXPR) - return false; - } + gimple *stmt2 = SSA_NAME_DEF_STMT (ops[1]); + if (is_gimple_assign (stmt2) + && gimple_assign_rhs_code (stmt2) == MULT_EXPR) + return false; } - tree use_rhs1 = gimple_assign_rhs1 (use_stmt); - tree use_rhs2 = gimple_assign_rhs2 (use_stmt); /* We can't handle a * b + a * b. */ - if (use_rhs1 == use_rhs2) + if (ops[0] == ops[1]) return false; /* If deferring, make sure we are not looking at an instruction that wouldn't have existed if we were not. 
*/ if (state->m_deferring_p - && (state->m_mul_result_set.contains (use_rhs1) - || state->m_mul_result_set.contains (use_rhs2))) + && (state->m_mul_result_set.contains (ops[0]) + || state->m_mul_result_set.contains (ops[1]))) return false; if (check_defer) { - tree use_lhs = gimple_assign_lhs (use_stmt); + tree use_lhs = gimple_get_lhs (use_stmt); if (state->m_last_result) { - if (use_rhs2 == state->m_last_result - || use_rhs1 == state->m_last_result) + if (ops[1] == state->m_last_result + || ops[0] == state->m_last_result) defer = true; else defer = false; @@ -3009,12 +3040,12 @@ convert_mult_to_fma (gimple *mul_stmt, t { gcc_checking_assert (!state->m_initial_phi); gphi *phi; - if (use_rhs1 == result) - phi = result_of_phi (use_rhs2); + if (ops[0] == result) + phi = result_of_phi (ops[1]); else { - gcc_assert (use_rhs2 == result); - phi = result_of_phi (use_rhs1); + gcc_assert (ops[1] == result); + phi = result_of_phi (ops[0]); } if (phi) Index: gcc/genmatch.c =================================================================== --- gcc/genmatch.c 2018-05-16 10:23:03.590853492 +0100 +++ gcc/genmatch.c 2018-05-16 10:23:03.887838686 +0100 @@ -485,6 +485,10 @@ commutative_op (id_base *id) case CFN_FNMS: return 0; + case CFN_COND_FMA_REV: + case CFN_COND_FNMA_REV: + return 2; + default: return -1; } Index: gcc/config/aarch64/iterators.md =================================================================== --- gcc/config/aarch64/iterators.md 2018-05-16 10:23:03.590853492 +0100 +++ gcc/config/aarch64/iterators.md 2018-05-16 10:23:03.886838736 +0100 @@ -449,6 +449,8 @@ (define_c_enum "unspec" UNSPEC_COND_AND ; Used in aarch64-sve.md. UNSPEC_COND_ORR ; Used in aarch64-sve.md. UNSPEC_COND_EOR ; Used in aarch64-sve.md. + UNSPEC_COND_FMLA ; Used in aarch64-sve.md. + UNSPEC_COND_FMLS ; Used in aarch64-sve.md. UNSPEC_COND_LT ; Used in aarch64-sve.md. UNSPEC_COND_LE ; Used in aarch64-sve.md. UNSPEC_COND_EQ ; Used in aarch64-sve.md. 
@@ -1499,14 +1501,16 @@ (define_int_iterator UNPACK_UNSIGNED [UN (define_int_iterator MUL_HIGHPART [UNSPEC_SMUL_HIGHPART UNSPEC_UMUL_HIGHPART]) -(define_int_iterator SVE_COND_INT_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB - UNSPEC_COND_SMAX UNSPEC_COND_UMAX - UNSPEC_COND_SMIN UNSPEC_COND_UMIN - UNSPEC_COND_AND - UNSPEC_COND_ORR - UNSPEC_COND_EOR]) +(define_int_iterator SVE_COND_INT2_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB + UNSPEC_COND_SMAX UNSPEC_COND_UMAX + UNSPEC_COND_SMIN UNSPEC_COND_UMIN + UNSPEC_COND_AND + UNSPEC_COND_ORR + UNSPEC_COND_EOR]) -(define_int_iterator SVE_COND_FP_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB]) +(define_int_iterator SVE_COND_FP2_OP [UNSPEC_COND_ADD UNSPEC_COND_SUB]) + +(define_int_iterator SVE_COND_FP3_OP [UNSPEC_COND_FMLA UNSPEC_COND_FMLS]) (define_int_iterator SVE_COND_FP_CMP [UNSPEC_COND_LT UNSPEC_COND_LE UNSPEC_COND_EQ UNSPEC_COND_NE @@ -1543,7 +1547,9 @@ (define_int_attr optab [(UNSPEC_ANDF "an (UNSPEC_COND_UMIN "umin") (UNSPEC_COND_AND "and") (UNSPEC_COND_ORR "ior") - (UNSPEC_COND_EOR "xor")]) + (UNSPEC_COND_EOR "xor") + (UNSPEC_COND_FMLA "fma_rev") + (UNSPEC_COND_FMLS "fnma_rev")]) (define_int_attr maxmin_uns [(UNSPEC_UMAXV "umax") (UNSPEC_UMINV "umin") @@ -1762,4 +1768,6 @@ (define_int_attr sve_int_op [(UNSPEC_CON (UNSPEC_COND_EOR "eor")]) (define_int_attr sve_fp_op [(UNSPEC_COND_ADD "fadd") - (UNSPEC_COND_SUB "fsub")]) + (UNSPEC_COND_SUB "fsub") + (UNSPEC_COND_FMLA "fmla") + (UNSPEC_COND_FMLS "fmls")]) Index: gcc/config/aarch64/aarch64-sve.md =================================================================== --- gcc/config/aarch64/aarch64-sve.md 2018-05-16 10:23:03.590853492 +0100 +++ gcc/config/aarch64/aarch64-sve.md 2018-05-16 10:23:03.883838885 +0100 @@ -1764,7 +1764,7 @@ (define_insn "cond_<optab><mode>" [(match_operand:<VPRED> 1 "register_operand" "Upl") (match_operand:SVE_I 2 "register_operand" "0") (match_operand:SVE_I 3 "register_operand" "w")] - SVE_COND_INT_OP))] + SVE_COND_INT2_OP))] "TARGET_SVE" "<sve_int_op>\t%0.<Vetype>, %1/m, %0.<Vetype>, %3.<Vetype>" ) @@ -2543,11 +2543,23 @@ (define_insn "cond_<optab><mode>" [(match_operand:<VPRED> 1 "register_operand" "Upl") (match_operand:SVE_F 2 "register_operand" "0") (match_operand:SVE_F 3 "register_operand" "w")] - SVE_COND_FP_OP))] + SVE_COND_FP2_OP))] "TARGET_SVE" "<sve_fp_op>\t%0.<Vetype>, %1/m, %0.<Vetype>, %3.<Vetype>" ) +(define_insn "cond_<optab><mode>" + [(set (match_operand:SVE_F 0 "register_operand" "=w") + (unspec:SVE_F + [(match_operand:<VPRED> 1 "register_operand" "Upl") + (match_operand:SVE_F 2 "register_operand" "0") + (match_operand:SVE_F 3 "register_operand" "w") + (match_operand:SVE_F 4 "register_operand" "w")] + SVE_COND_FP3_OP))] + "TARGET_SVE" + "<sve_fp_op>\t%0.<Vetype>, %1/m, %3.<Vetype>, %4.<Vetype>" +) + ;; Shift an SVE vector left and insert a scalar into element 0. 
(define_insn "vec_shl_insert_<mode>" [(set (match_operand:SVE_ALL 0 "register_operand" "=w, w") Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c =================================================================== --- /dev/null 2018-04-20 16:19:46.369131350 +0100 +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_4.c 2018-05-16 10:23:03.888838636 +0100 @@ -0,0 +1,18 @@ +/* { dg-do compile } */ +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ + +double +f (double *restrict a, double *restrict b, int *lookup) +{ + double res = 0.0; + for (int i = 0; i < 512; ++i) + res += a[lookup[i]] * b[i]; + return res; +} + +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+.d, p[0-7]/m, } 2 } } */ +/* Check that the vector instructions are the only instructions. */ +/* { dg-final { scan-assembler-times {\tfmla\t} 2 } } */ +/* { dg-final { scan-assembler-not {\tfadd\t} } } */ +/* { dg-final { scan-assembler-times {\tfaddv\td0,} 1 } } */ +/* { dg-final { scan-assembler-not {\tsel\t} } } */ Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c =================================================================== --- /dev/null 2018-04-20 16:19:46.369131350 +0100 +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_6.c 2018-05-16 10:23:03.888838636 +0100 @@ -0,0 +1,17 @@ +/* { dg-do compile } */ +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ + +#define REDUC(TYPE) \ + TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count) \ + { \ + TYPE sum = 0; \ + for (int i = 0; i < count; ++i) \ + sum += x[i] * y[i]; \ + return sum; \ + } + +REDUC (float) +REDUC (double) + +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.s, p[0-7]/m} 1 } } */ +/* { dg-final { scan-assembler-times {\tfmla\tz[0-9]+\.d, p[0-7]/m} 1 } } */ Index: gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c =================================================================== --- /dev/null 2018-04-20 16:19:46.369131350 +0100 +++ gcc/testsuite/gcc.target/aarch64/sve/reduc_7.c 2018-05-16 10:23:03.889838586 +0100 @@ -0,0 +1,17 @@ +/* { dg-do compile } */ +/* { dg-options "-O2 -ftree-vectorize -ffast-math" } */ + +#define REDUC(TYPE) \ + TYPE reduc_##TYPE (TYPE *x, TYPE *y, int count) \ + { \ + TYPE sum = 0; \ + for (int i = 0; i < count; ++i) \ + sum -= x[i] * y[i]; \ + return sum; \ + } + +REDUC (float) +REDUC (double) + +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.s, p[0-7]/m} 1 } } */ +/* { dg-final { scan-assembler-times {\tfmls\tz[0-9]+\.d, p[0-7]/m} 1 } } */