File 1102-Fuse-multiplication-with-addition.patch of Package erlang
From 6a02b044052c802b76ddf570b72d2a5943605df4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bj=C3=B6rn=20Gustavsson?= <bjorn@erlang.org>
Date: Wed, 28 Jun 2023 09:28:26 +0200
Subject: [PATCH 2/7] Fuse multiplication with addition
Fuse a multiplication operator immediately followed by an addition
operator into a single instruction. That generally reduces the number
of instructions compared to emitting the operators separately.
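
For illustration only (not part of the change itself), code of the
following shape is the kind of multiply-then-add sequence that becomes
a candidate for fusion; the helper name is made up for this example:

    %% Each iteration evaluates Acc * 10 + D, i.e. a '*' whose result
    %% feeds directly into a '+' with the same destination.
    digits_to_integer(Digits) ->
        lists:foldl(fun(D, Acc) -> Acc * 10 + D end, 0, Digits).
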
---
erts/emulator/beam/big.c | 30 ++
erts/emulator/beam/big.h | 1 +
erts/emulator/beam/erl_arith.c | 92 +++++
erts/emulator/beam/global.h | 1 +
.../beam/jit/arm/beam_asm_global.hpp.pl | 7 +-
erts/emulator/beam/jit/arm/instr_arith.cpp | 382 +++++++++++++-----
erts/emulator/beam/jit/arm/ops.tab | 23 +-
erts/emulator/beam/jit/beam_jit_common.cpp | 18 +
erts/emulator/beam/jit/beam_jit_common.hpp | 2 +
.../beam/jit/x86/beam_asm_global.hpp.pl | 6 +-
erts/emulator/beam/jit/x86/instr_arith.cpp | 248 ++++++++++--
erts/emulator/beam/jit/x86/ops.tab | 22 +-
erts/emulator/test/big_SUITE.erl | 26 +-
erts/emulator/test/small_SUITE.erl | 226 ++++++++++-
14 files changed, 922 insertions(+), 162 deletions(-)
diff --git a/erts/emulator/beam/big.c b/erts/emulator/beam/big.c
index 3c5a2bed4b..e41444979e 100644
--- a/erts/emulator/beam/big.c
+++ b/erts/emulator/beam/big.c
@@ -2560,6 +2560,36 @@ Eterm big_times(Eterm x, Eterm y, Eterm *r)
return big_norm(r, rsz, sign);
}
+/*
+** Fused multiplication and addition of bignums
+*/
+
+Eterm big_mul_add(Eterm x, Eterm y, Eterm z, Eterm *r)
+{
+ Eterm* xp = big_val(x);
+ Eterm* yp = big_val(y);
+ Eterm* zp = big_val(z);
+
+ short sign = BIG_SIGN(xp) != BIG_SIGN(yp);
+ dsize_t xsz = BIG_SIZE(xp);
+ dsize_t ysz = BIG_SIZE(yp);
+ dsize_t rsz;
+
+ if (ysz == 1)
+ rsz = D_mul(BIG_V(xp), xsz, BIG_DIGIT(yp, 0), BIG_V(r));
+ else if (xsz == 1)
+ rsz = D_mul(BIG_V(yp), ysz, BIG_DIGIT(xp, 0), BIG_V(r));
+ else if (xsz >= ysz) {
+ rsz = I_mul_karatsuba(BIG_V(xp), xsz, BIG_V(yp), ysz, BIG_V(r));
+ }
+ else {
+ rsz = I_mul_karatsuba(BIG_V(yp), ysz, BIG_V(xp), xsz, BIG_V(r));
+ }
+ return B_plus_minus(BIG_V(r), rsz, sign,
+ BIG_V(zp), BIG_SIZE(zp), (short) BIG_SIGN(zp),
+ r);
+}
+
/*
** Fused div_rem for bignums
*/
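
The new big_mul_add() computes X*Y and then adds Z into the same result
buffer, so it must agree with evaluating the operators separately. A
minimal sketch of that invariant at the Erlang level, with arbitrarily
chosen bignum operands:

    X = 1 bsl 100,
    Y = (1 bsl 80) + 17,
    Z = -(1 bsl 90),
    P = X * Y,         %% big_times()
    Sum = X * Y + Z,   %% may take the fused big_mul_add() path
    Sum = P + Z.       %% must equal the product plus the addend
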
diff --git a/erts/emulator/beam/big.h b/erts/emulator/beam/big.h
index ceb35a84b8..b705421ca9 100644
--- a/erts/emulator/beam/big.h
+++ b/erts/emulator/beam/big.h
@@ -135,6 +135,7 @@ Eterm small_times(Sint, Sint, Eterm*);
Eterm big_plus(Wterm, Wterm, Eterm*);
Eterm big_minus(Eterm, Eterm, Eterm*);
Eterm big_times(Eterm, Eterm, Eterm*);
+Eterm big_mul_add(Eterm x, Eterm y, Eterm z, Eterm *r);
int big_div_rem(Eterm lhs, Eterm rhs,
Eterm *q_hp, Eterm *q,
diff --git a/erts/emulator/beam/erl_arith.c b/erts/emulator/beam/erl_arith.c
index 3e7f023e5a..88223778da 100644
--- a/erts/emulator/beam/erl_arith.c
+++ b/erts/emulator/beam/erl_arith.c
@@ -867,6 +867,98 @@ erts_mixed_times(Process* p, Eterm arg1, Eterm arg2)
}
}
+Eterm
+erts_mul_add(Process* p, Eterm arg1, Eterm arg2, Eterm arg3, Eterm* pp)
+{
+ Eterm tmp_big1[2];
+ Eterm tmp_big2[2];
+ Eterm tmp_big3[2];
+ Eterm hdr;
+ Eterm res;
+ Eterm big_arg1, big_arg2, big_arg3;
+ dsize_t sz1, sz2, sz3, sz;
+ int need_heap;
+ Eterm* hp;
+ Eterm product;
+
+ big_arg1 = arg1;
+ big_arg2 = arg2;
+ big_arg3 = arg3;
+ switch (big_arg1 & _TAG_PRIMARY_MASK) {
+ case TAG_PRIMARY_IMMED1:
+ if (is_not_small(big_arg1)) {
+ break;
+ }
+ big_arg1 = small_to_big(signed_val(big_arg1), tmp_big1);
+ /* Fall through */
+ case TAG_PRIMARY_BOXED:
+ hdr = *boxed_val(big_arg1);
+ switch ((hdr & _TAG_HEADER_MASK) >> _TAG_PRIMARY_SIZE) {
+ case (_TAG_HEADER_POS_BIG >> _TAG_PRIMARY_SIZE):
+ case (_TAG_HEADER_NEG_BIG >> _TAG_PRIMARY_SIZE):
+ switch (big_arg2 & _TAG_PRIMARY_MASK) {
+ case TAG_PRIMARY_IMMED1:
+ if (is_not_small(big_arg2)) {
+ break;
+ }
+ big_arg2 = small_to_big(signed_val(big_arg2), tmp_big2);
+ /* Fall through */
+ case TAG_PRIMARY_BOXED:
+ hdr = *boxed_val(big_arg2);
+ switch ((hdr & _TAG_HEADER_MASK) >> _TAG_PRIMARY_SIZE) {
+ case (_TAG_HEADER_POS_BIG >> _TAG_PRIMARY_SIZE):
+ case (_TAG_HEADER_NEG_BIG >> _TAG_PRIMARY_SIZE):
+ switch (big_arg3 & _TAG_PRIMARY_MASK) {
+ case TAG_PRIMARY_IMMED1:
+ if (is_not_small(big_arg3)) {
+ break;
+ }
+ big_arg3 = small_to_big(signed_val(big_arg3), tmp_big3);
+ /* Fall through */
+ case TAG_PRIMARY_BOXED:
+ hdr = *boxed_val(big_arg3);
+ switch ((hdr & _TAG_HEADER_MASK) >> _TAG_PRIMARY_SIZE) {
+ case (_TAG_HEADER_POS_BIG >> _TAG_PRIMARY_SIZE):
+ case (_TAG_HEADER_NEG_BIG >> _TAG_PRIMARY_SIZE):
+ sz1 = big_size(big_arg1);
+ sz2 = big_size(big_arg2);
+ sz3 = big_size(big_arg3);
+ sz = sz1 + sz2;
+ sz = MAX(sz, sz3) + 1;
+ need_heap = BIG_NEED_SIZE(sz);
+#ifdef DEBUG
+ need_heap++;
+#endif
+ hp = HeapFragOnlyAlloc(p, need_heap);
+
+#ifdef DEBUG
+ hp[need_heap-1] = ERTS_HOLE_MARKER;
+#endif
+ res = big_mul_add(big_arg1, big_arg2, big_arg3, hp);
+ ASSERT(hp[need_heap-1] == ERTS_HOLE_MARKER);
+ maybe_shrink(p, hp, res, need_heap);
+ if (is_nil(res)) {
+ p->freason = SYSTEM_LIMIT;
+ return THE_NON_VALUE;
+ }
+ return res;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /* At least one of the arguments is a float or invalid. */
+ product = erts_mixed_times(p, arg1, arg2);
+ *pp = product;
+ if (is_non_value(product)) {
+ return product;
+ } else {
+ return erts_mixed_plus(p, product, arg3);
+ }
+}
+
Eterm
erts_mixed_div(Process* p, Eterm arg1, Eterm arg2)
{
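
When any operand is a float or not a number, erts_mul_add() falls back
to erts_mixed_times() followed by erts_mixed_plus(), so float results
and the operator blamed in a badarith exception are unchanged. A short
sketch of that observable behaviour, in the style of the new
small_SUITE cases (id/1 is the identity function the suites use to
prevent compile-time evaluation):

    6.0 = id(2.0) * id(3.0) + id(0.0),
    13.0 = id(2.0) * id(3.0) + id(7),
    {'EXIT',{badarith,[{erlang,'*',[a,2],_}|_]}} = catch id(a) * id(2) + id(42),
    {'EXIT',{badarith,[{erlang,'+',[6,c],_}|_]}} = catch id(2) * id(3) + id(c).
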
diff --git a/erts/emulator/beam/global.h b/erts/emulator/beam/global.h
index 75db8fe792..7f8e54b949 100644
--- a/erts/emulator/beam/global.h
+++ b/erts/emulator/beam/global.h
@@ -1598,6 +1598,7 @@ Eterm erts_unary_minus(Process* p, Eterm arg1);
Eterm erts_mixed_plus(Process* p, Eterm arg1, Eterm arg2);
Eterm erts_mixed_minus(Process* p, Eterm arg1, Eterm arg2);
Eterm erts_mixed_times(Process* p, Eterm arg1, Eterm arg2);
+Eterm erts_mul_add(Process* p, Eterm arg1, Eterm arg2, Eterm arg3, Eterm* pp);
Eterm erts_mixed_div(Process* p, Eterm arg1, Eterm arg2);
int erts_int_div_rem(Process* p, Eterm arg1, Eterm arg2, Eterm *q, Eterm *r);
diff --git a/erts/emulator/beam/jit/arm/beam_asm_global.hpp.pl b/erts/emulator/beam/jit/arm/beam_asm_global.hpp.pl
index 59524b32c7..93b239ddbd 100644
--- a/erts/emulator/beam/jit/arm/beam_asm_global.hpp.pl
+++ b/erts/emulator/beam/jit/arm/beam_asm_global.hpp.pl
@@ -92,11 +92,16 @@ my @beam_global_funcs = qw(
i_loop_rec_shared
i_test_yield_shared
i_bxor_body_shared
+ int128_to_big_shared
int_div_rem_body_shared
int_div_rem_guard_shared
is_in_range_shared
is_ge_lt_shared
minus_body_shared
+ mul_add_body_shared
+ mul_add_guard_shared
+ mul_body_shared
+ mul_guard_shared
new_map_shared
update_map_assoc_shared
unloaded_fun
@@ -106,8 +111,6 @@ my @beam_global_funcs = qw(
raise_exception
raise_exception_shared
store_unaligned
- times_body_shared
- times_guard_shared
unary_minus_body_shared
update_map_exact_guard_shared
update_map_exact_body_shared
diff --git a/erts/emulator/beam/jit/arm/instr_arith.cpp b/erts/emulator/beam/jit/arm/instr_arith.cpp
index 485f93956d..8ca898b675 100644
--- a/erts/emulator/beam/jit/arm/instr_arith.cpp
+++ b/erts/emulator/beam/jit/arm/instr_arith.cpp
@@ -82,9 +82,15 @@ void BeamModuleAssembler::emit_are_both_small(const ArgSource &LHS,
a.and_(TMP1, lhs_reg, rhs_reg);
emit_is_boxed(next, TMP1);
} else {
- ERTS_CT_ASSERT(_TAG_IMMED1_SMALL == _TAG_IMMED1_MASK);
- a.and_(TMP1, lhs_reg, rhs_reg);
- a.and_(TMP1, TMP1, imm(_TAG_IMMED1_MASK));
+ if (always_small(RHS)) {
+ a.and_(TMP1, lhs_reg, imm(_TAG_IMMED1_MASK));
+ } else if (always_small(LHS)) {
+ a.and_(TMP1, rhs_reg, imm(_TAG_IMMED1_MASK));
+ } else {
+ ERTS_CT_ASSERT(_TAG_IMMED1_SMALL == _TAG_IMMED1_MASK);
+ a.and_(TMP1, lhs_reg, rhs_reg);
+ a.and_(TMP1, TMP1, imm(_TAG_IMMED1_MASK));
+ }
a.cmp(TMP1, imm(_TAG_IMMED1_SMALL));
a.b_eq(next);
}
@@ -376,45 +382,35 @@ void BeamModuleAssembler::emit_i_minus(const ArgLabel &Fail,
mov_arg(Dst, ARG1);
}
-/* ARG2 = LHS
- * ARG3 = RHS
+/*
+ * Create a bignum from the 128-bit product of two smalls shifted
+ * left _TAG_IMMED1_SIZE bits.
*
- * The result is returned in ARG1 (set to THE_NON_VALUE if
- * the call failed).
+ * ARG1 = low 64 bits
+ * TMP2 = high 64 bits
+ *
+ * The result is returned in ARG1.
*/
-void BeamGlobalAssembler::emit_times_guard_shared() {
- Label generic = a.newLabel();
-
- /* Speculatively untag and multiply. */
- a.and_(TMP1, ARG2, imm(~_TAG_IMMED1_MASK));
- a.asr(TMP2, ARG3, imm(_TAG_IMMED1_SIZE));
- a.mul(TMP3, TMP1, TMP2);
- a.smulh(TMP4, TMP1, TMP2);
+void BeamGlobalAssembler::emit_int128_to_big_shared() {
+ Label positive = a.newLabel();
- /* Check that both operands are small integers. */
- ERTS_CT_ASSERT(_TAG_IMMED1_SMALL == _TAG_IMMED1_MASK);
- a.and_(TMP1, ARG2, ARG3);
- a.and_(TMP1, TMP1, imm(_TAG_IMMED1_MASK));
- a.cmp(TMP1, imm(_TAG_IMMED1_SMALL));
- a.b_ne(generic);
+ a.extr(ARG3, TMP2, ARG1, imm(_TAG_IMMED1_SIZE));
+ a.asr(ARG4, TMP2, imm(_TAG_IMMED1_SIZE));
- /* The high 65 bits of result will all be the same if no overflow
- * occurred. Another way to say that is that the sign bit of the
- * low 64 bits repeated 64 times must be equal to the high 64 bits
- * of the product. */
- a.cmp(TMP4, TMP3, arm::asr(63));
- a.b_ne(generic);
+ a.mov(ARG1, c_p);
- a.orr(ARG1, TMP3, imm(_TAG_IMMED1_SMALL));
- a.ret(a64::x30);
+ a.cmp(ARG4, imm(0));
+ a.cset(ARG2, arm::CondCode::kMI);
- a.bind(generic);
+ a.b_pl(positive);
+ a.negs(ARG3, ARG3);
+ a.ngc(ARG4, ARG4);
+ a.bind(positive);
emit_enter_runtime_frame();
emit_enter_runtime();
- a.mov(ARG1, c_p);
- runtime_call<3>(erts_mixed_times);
+ runtime_call<4>(beam_jit_int128_to_big);
emit_leave_runtime();
emit_leave_runtime_frame();
@@ -422,111 +418,295 @@ void BeamGlobalAssembler::emit_times_guard_shared() {
a.ret(a64::x30);
}
-/* ARG2 = LHS
- * ARG3 = RHS
+/* ARG2 = Src1
+ * ARG3 = Src2
+ * ARG4 = Src4
*
* The result is returned in ARG1.
*/
-void BeamGlobalAssembler::emit_times_body_shared() {
- Label generic = a.newLabel(), error = a.newLabel();
+void BeamGlobalAssembler::emit_mul_add_body_shared() {
+ Label mul_only = a.newLabel(), error = a.newLabel(),
+ mul_error = a.newLabel(), do_error = a.newLabel();
- /* Speculatively untag and multiply. */
- a.and_(TMP1, ARG2, imm(~_TAG_IMMED1_MASK));
- a.asr(TMP2, ARG3, imm(_TAG_IMMED1_SIZE));
- a.mul(TMP3, TMP1, TMP2);
- a.smulh(TMP4, TMP1, TMP2);
+ emit_enter_runtime_frame();
+ emit_enter_runtime();
- /* Check that both operands are integers. */
- ERTS_CT_ASSERT(_TAG_IMMED1_SMALL == _TAG_IMMED1_MASK);
- a.and_(TMP1, ARG2, ARG3);
- a.and_(TMP1, TMP1, imm(_TAG_IMMED1_MASK));
- a.cmp(TMP1, imm(_TAG_IMMED1_SMALL));
- a.b_ne(generic);
+ /* Save original arguments. */
+ a.stp(ARG2, ARG3, TMP_MEM1q);
+ a.mov(ARG1, c_p);
+ a.cmp(ARG4, imm(make_small(0)));
+ a.b_eq(mul_only);
+ a.str(ARG4, TMP_MEM4q);
- /* The high 65 bits of result will all be the same if no overflow
- * occurred. Another way to say that is that the sign bit of the
- * low 64 bits repeated 64 times must be equal to the high 64 bits
- * of the product. */
- a.cmp(TMP4, TMP3, arm::asr(63));
- a.b_ne(generic);
+ lea(ARG5, TMP_MEM3q);
+ runtime_call<5>(erts_mul_add);
+
+ emit_leave_runtime();
+ emit_leave_runtime_frame();
- a.orr(ARG1, TMP3, imm(_TAG_IMMED1_SMALL));
+ emit_branch_if_not_value(ARG1, error);
a.ret(a64::x30);
- a.bind(generic);
+ a.bind(mul_only);
+ {
+ runtime_call<3>(erts_mixed_times);
- /* Save original arguments for the error path. */
- a.stp(ARG2, ARG3, TMP_MEM1q);
+ emit_leave_runtime();
+ emit_leave_runtime_frame();
+
+ emit_branch_if_not_value(ARG1, mul_error);
+ a.ret(a64::x30);
+ }
+
+ a.bind(error);
+ {
+ static const ErtsCodeMFA mul_mfa = {am_erlang, am_Times, 2};
+ static const ErtsCodeMFA add_mfa = {am_erlang, am_Plus, 2};
+
+ a.ldp(XREG0, XREG1, TMP_MEM3q);
+ mov_imm(ARG4, &add_mfa);
+ emit_branch_if_value(XREG0, do_error);
+
+ a.bind(mul_error);
+ a.ldp(XREG0, XREG1, TMP_MEM1q);
+ mov_imm(ARG4, &mul_mfa);
+
+ a.bind(do_error);
+ a.b(labels[raise_exception]);
+ }
+}
+
+/* ARG2 = Src1
+ * ARG3 = Src2
+ * ARG4 = Src4
+ *
+ * The result is returned in ARG1 (set to THE_NON_VALUE if
+ * the call failed).
+ */
+void BeamGlobalAssembler::emit_mul_add_guard_shared() {
+ Label mul_failed = a.newLabel();
+
+ a.str(ARG4, TMP_MEM1q);
emit_enter_runtime_frame();
emit_enter_runtime();
a.mov(ARG1, c_p);
runtime_call<3>(erts_mixed_times);
+ emit_branch_if_not_value(ARG1, mul_failed);
+
+ a.ldr(ARG3, TMP_MEM1q);
+ a.mov(ARG2, ARG1);
+ a.mov(ARG1, c_p);
+ runtime_call<3>(erts_mixed_plus);
+ a.bind(mul_failed);
emit_leave_runtime();
emit_leave_runtime_frame();
- emit_branch_if_not_value(ARG1, error);
-
a.ret(a64::x30);
+}
- a.bind(error);
- {
- static const ErtsCodeMFA bif_mfa = {am_erlang, am_Times, 2};
+/* ARG2 = Src1
+ * ARG3 = Src2
+ *
+ * The result is returned in ARG1.
+ */
+void BeamGlobalAssembler::emit_mul_body_shared() {
+ mov_imm(ARG4, make_small(0));
+ a.b(labels[mul_add_body_shared]);
+}
- /* Place the original arguments in x-registers. */
- a.ldp(XREG0, XREG1, TMP_MEM1q);
- mov_imm(ARG4, &bif_mfa);
- a.b(labels[raise_exception]);
- }
+/* ARG2 = Src1
+ * ARG3 = Src2
+ *
+ * The result is returned in ARG1 (set to THE_NON_VALUE if
+ * the call failed).
+ */
+void BeamGlobalAssembler::emit_mul_guard_shared() {
+ mov_imm(ARG4, make_small(0));
+ a.b(labels[mul_add_guard_shared]);
}
-void BeamModuleAssembler::emit_i_times(const ArgLabel &Fail,
- const ArgWord &Live,
- const ArgSource &LHS,
- const ArgSource &RHS,
- const ArgRegister &Dst) {
- bool is_small_result = is_product_small_if_args_are_small(LHS, RHS);
+void BeamModuleAssembler::emit_i_mul_add(const ArgLabel &Fail,
+ const ArgSource &Src1,
+ const ArgSource &Src2,
+ const ArgSource &Src3,
+ const ArgSource &Src4,
+ const ArgRegister &Dst) {
+ bool is_product_small = is_product_small_if_args_are_small(Src1, Src2);
+ bool is_sum_small = is_sum_small_if_args_are_small(Src3, Src4);
+ bool is_increment_zero =
+ Src4.isSmall() && Src4.as<ArgSmall>().getSigned() == 0;
+ Sint factor = 0;
+ int left_shift = -1;
+
+ if (is_increment_zero) {
+ comment("(adding zero)");
+ }
- if (always_small(LHS) && always_small(RHS) && is_small_result) {
+ if (Src2.isSmall()) {
+ factor = Src2.as<ArgSmall>().getSigned();
+ if (Support::isPowerOf2(factor)) {
+ left_shift = Support::ctz<Eterm>(factor);
+ }
+ }
+
+ if (always_small(Src1) && Src2.isSmall() && always_small(Src4) &&
+ is_product_small && is_sum_small) {
auto dst = init_destination(Dst, ARG1);
- comment("multiplication without overflow check");
- if (RHS.isSmall()) {
- auto lhs = load_source(LHS, ARG2);
- Sint factor = RHS.as<ArgSmall>().getSigned();
+ auto [src1, src4] = load_sources(Src1, ARG2, Src4, ARG3);
+
+ comment("multiplication and addition without overflow check");
+ a.and_(TMP1, src1.reg, imm(~_TAG_IMMED1_MASK));
+ if (left_shift > 0) {
+ comment("optimized multiplication by replacing with left "
+ "shift");
+ a.add(dst.reg, src4.reg, TMP1, arm::lsl(left_shift));
+ } else {
+ mov_imm(TMP2, factor);
+ a.madd(dst.reg, TMP1, TMP2, src4.reg);
+ }
+ flush_var(dst);
+ } else {
+ Label small = a.newLabel();
+ Label store_result = a.newLabel();
+ auto [src1, src2] = load_sources(Src1, ARG2, Src2, ARG3);
+ auto src4 = load_source(ArgXRegister(0), XREG0);
- a.and_(TMP1, lhs.reg, imm(~_TAG_IMMED1_MASK));
- if (Support::isPowerOf2(factor)) {
- int trailing_bits = Support::ctz<Eterm>(factor);
- comment("optimized multiplication by replacing with left "
- "shift");
- a.lsl(TMP1, TMP1, imm(trailing_bits));
+ if (!is_increment_zero) {
+ src4 = load_source(Src4, ARG4);
+ }
+
+ if (always_small(Src1) && always_small(Src2) && always_small(Src4)) {
+ comment("skipped test for small operands since they are always "
+ "small");
+ } else {
+ if (always_small(Src4)) {
+ emit_are_both_small(Src1, src1.reg, Src2, src2.reg, small);
+ } else if (always_small(Src2)) {
+ emit_are_both_small(Src1, src1.reg, Src4, src4.reg, small);
+ } else {
+ ASSERT(!is_increment_zero);
+ ERTS_CT_ASSERT(_TAG_IMMED1_SMALL == _TAG_IMMED1_MASK);
+ a.and_(TMP1, src1.reg, src2.reg);
+ a.and_(TMP1, TMP1, src4.reg);
+ if (always_one_of<BeamTypeId::Integer, BeamTypeId::AlwaysBoxed>(
+ Src1) &&
+ always_one_of<BeamTypeId::Integer, BeamTypeId::AlwaysBoxed>(
+ Src2) &&
+ always_one_of<BeamTypeId::Integer, BeamTypeId::AlwaysBoxed>(
+ Src4)) {
+ emit_is_boxed(small, TMP1);
+ } else {
+ a.and_(TMP1, TMP1, imm(_TAG_IMMED1_MASK));
+ a.cmp(TMP1, imm(_TAG_IMMED1_SMALL));
+ a.b_eq(small);
+ }
+ }
+
+ mov_var(ARG2, src1);
+ mov_var(ARG3, src2);
+
+ if (Fail.get() != 0) {
+ if (is_increment_zero) {
+ fragment_call(ga->get_mul_guard_shared());
+ } else {
+ mov_var(ARG4, src4);
+ fragment_call(ga->get_mul_add_guard_shared());
+ }
+ emit_branch_if_not_value(ARG1,
+ resolve_beam_label(Fail, dispUnknown));
} else {
- mov_imm(TMP2, factor);
- a.mul(TMP1, TMP1, TMP2);
+ if (is_increment_zero) {
+ fragment_call(ga->get_mul_body_shared());
+ } else {
+ mov_var(ARG4, src4);
+ fragment_call(ga->get_mul_add_body_shared());
+ }
}
+
+ a.b(store_result);
+ }
+
+ a.bind(small);
+ if (is_increment_zero) {
+ comment("multiply smalls");
} else {
- auto [lhs, rhs] = load_sources(LHS, ARG2, RHS, ARG3);
- a.and_(TMP1, lhs.reg, imm(~_TAG_IMMED1_MASK));
- a.asr(TMP2, rhs.reg, imm(_TAG_IMMED1_SIZE));
- a.mul(TMP1, TMP1, TMP2);
+ comment("multiply and add smalls");
}
- a.orr(dst.reg, TMP1, imm(_TAG_IMMED1_SMALL));
- flush_var(dst);
- } else {
- auto [lhs, rhs] = load_sources(LHS, ARG2, RHS, ARG3);
- mov_var(ARG2, lhs);
- mov_var(ARG3, rhs);
- if (Fail.get() != 0) {
- fragment_call(ga->get_times_guard_shared());
- emit_branch_if_not_value(ARG1,
- resolve_beam_label(Fail, dispUnknown));
+ if (is_product_small && is_sum_small) {
+ arm::Gp increment_reg;
+
+ a.and_(TMP3, src1.reg, imm(~_TAG_IMMED1_MASK));
+
+ if (is_increment_zero) {
+ mov_imm(TMP1, make_small(0));
+ increment_reg = TMP1;
+ } else {
+ increment_reg = src4.reg;
+ }
+
+ if (left_shift > 0) {
+ comment("optimized multiplication by replacing with left "
+ "shift");
+ a.add(ARG1, increment_reg, TMP3, arm::lsl(left_shift));
+ } else {
+ a.asr(TMP4, src2.reg, imm(_TAG_IMMED1_SIZE));
+ a.madd(ARG1, TMP3, TMP4, increment_reg);
+ }
+
+ comment("skipped test for small result");
} else {
- fragment_call(ga->get_times_body_shared());
+ auto min_increment = std::get<0>(getClampedRange(Src4));
+
+ a.and_(TMP3, src1.reg, imm(~_TAG_IMMED1_MASK));
+ if (left_shift == 0) {
+ comment("optimized multiplication by one");
+ a.mov(ARG1, TMP3);
+ a.asr(TMP2, TMP3, imm(63));
+ } else if (left_shift > 0) {
+ comment("optimized multiplication by replacing with left "
+ "shift");
+ a.lsl(ARG1, TMP3, imm(left_shift));
+ a.asr(TMP2, TMP3, imm(64 - left_shift));
+ } else {
+ ASSERT(left_shift == -1);
+ a.asr(TMP4, src2.reg, imm(_TAG_IMMED1_SIZE));
+ a.mul(ARG1, TMP3, TMP4);
+ a.smulh(TMP2, TMP3, TMP4);
+ }
+
+ if (is_increment_zero) {
+ a.add(ARG1, ARG1, imm(_TAG_IMMED1_SMALL));
+ } else {
+ arm::Gp sign_reg;
+
+ if (min_increment > 0) {
+ sign_reg = ZERO;
+ } else {
+ sign_reg = TMP3;
+ a.asr(sign_reg, src4.reg, imm(63));
+ }
+
+ a.adds(ARG1, ARG1, src4.reg);
+ a.adc(TMP2, TMP2, sign_reg);
+ }
+
+ comment("test whether the result fits in a small");
+ /* The high 65 bits of the result will all be the same if no
+ * overflow occurred. Another way to say that is that the
+ * sign bit of the low 64 bits repeated 64 times must be
+ * equal to the high 64 bits of the result. */
+ a.asr(TMP3, ARG1, imm(SMALL_BITS + _TAG_IMMED1_SIZE - 1));
+ a.cmp(TMP2, TMP3);
+ a.b_eq(store_result);
+
+ fragment_call(ga->get_int128_to_big_shared());
}
+ a.bind(store_result);
mov_arg(Dst, ARG1);
}
}
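
On the small-integer path the JIT forms the full 128-bit product with
mul/smulh, adds the increment with carry, and only then checks whether
the high 65 bits are copies of the sign bit; when they are not, the
value is handed to int128_to_big_shared. A worked example of that
boundary, using the 64-bit MAX_SMALL value documented in small_SUITE:

    MaxSmall = (1 bsl 59) - 1,           %% MAX_SMALL on 64-bit (see small_SUITE)
    MaxSmall = (MaxSmall div 2) * 2 + 1, %% result still fits in a small
    Big = MaxSmall * 2 + 1,              %% result overflows into a bignum
    Big = (1 bsl 60) - 1.
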
diff --git a/erts/emulator/beam/jit/arm/ops.tab b/erts/emulator/beam/jit/arm/ops.tab
index ed8c51ae3a..b0c79d3e2c 100644
--- a/erts/emulator/beam/jit/arm/ops.tab
+++ b/erts/emulator/beam/jit/arm/ops.tab
@@ -1256,6 +1256,23 @@ i_get_map_element f S S S
# Arithmetic instructions.
#
+gc_bif2 Fail1 Live1 u$bif:erlang:stimes/2 S1 S2 Dst1 |
+ gc_bif2 Fail2 Live2 u$bif:erlang:splus/2 S3 S4 Dst2 |
+ equal(Dst1, S3) |
+ equal(Dst1, Dst2) |
+ equal(Fail1, Fail2) =>
+ i_mul_add Fail1 S1 S2 S3 S4 Dst1
+
+gc_bif2 Fail1 Live1 u$bif:erlang:stimes/2 S1 S2 Dst1 |
+ gc_bif2 Fail2 Live2 u$bif:erlang:splus/2 S3 S4 Dst2 |
+ equal(Dst1, S4) |
+ equal(Dst1, Dst2) |
+ equal(Fail1, Fail2) =>
+ i_mul_add Fail1 S1 S2 S4 S3 Dst1
+
+gc_bif2 Fail Live u$bif:erlang:stimes/2 S1 S2 Dst =>
+ i_mul_add Fail S1 S2 Dst i Dst
+
gc_bif2 Fail Live u$bif:erlang:splus/2 Src1 Src2 Dst =>
i_plus Fail Live Src1 Src2 Dst
@@ -1265,9 +1282,6 @@ gc_bif1 Fail Live u$bif:erlang:sminus/1 Src Dst =>
gc_bif2 Fail Live u$bif:erlang:sminus/2 Src1 Src2 Dst =>
i_minus Fail Live Src1 Src2 Dst
-gc_bif2 Fail Live u$bif:erlang:stimes/2 S1 S2 Dst =>
- i_times Fail Live S1 S2 Dst
-
gc_bif2 Fail Live u$bif:erlang:div/2 S1 S2 Dst =>
i_m_div Fail Live S1 S2 Dst
@@ -1332,10 +1346,11 @@ gc_bif2 Fail Live u$bif:erlang:bsr/2 S1 S2 Dst =>
gc_bif2 Fail Live u$bif:erlang:bsl/2 S1 S2 Dst =>
i_bsl Fail Live S1 S2 Dst
+i_mul_add j s s s s d
+
i_plus j I s s d
i_unary_minus j I s d
i_minus j I s s d
-i_times j I s s d
i_m_div j I s s d
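
These transformations fuse two consecutive gc_bif2 instructions, an
erlang:'*'/2 whose destination is consumed as either operand of an
immediately following erlang:'+'/2 with the same destination and the
same failure label, into one i_mul_add; a lone '*' becomes i_mul_add
with a literal zero increment. Illustrative Erlang source whose body is
intended to be loaded as a single i_mul_add:

    %% The compiler emits gc_bif2 '*' followed by gc_bif2 '+' into the
    %% same register; the loader rewrites that pair to i_mul_add.
    mul_add(X, Y, Z) ->
        X * Y + Z.
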
diff --git a/erts/emulator/beam/jit/beam_jit_common.cpp b/erts/emulator/beam/jit/beam_jit_common.cpp
index 1465c3842f..8e78e2cf16 100644
--- a/erts/emulator/beam/jit/beam_jit_common.cpp
+++ b/erts/emulator/beam/jit/beam_jit_common.cpp
@@ -1087,6 +1087,24 @@ Sint beam_jit_bs_bit_size(Eterm term) {
return (Sint)-1;
}
+Eterm beam_jit_int128_to_big(Process *p, Uint sign, Uint low, Uint high) {
+ Eterm *hp;
+ Uint arity;
+
+ arity = high ? 2 : 1;
+ hp = HeapFragOnlyAlloc(p, BIG_NEED_SIZE(arity));
+ if (sign) {
+ hp[0] = make_neg_bignum_header(arity);
+ } else {
+ hp[0] = make_pos_bignum_header(arity);
+ }
+ BIG_DIGIT(hp, 0) = low;
+ if (arity == 2) {
+ BIG_DIGIT(hp, 1) = high;
+ }
+ return make_big(hp);
+}
+
ErtsMessage *beam_jit_decode_dist(Process *c_p, ErtsMessage *msgp) {
if (!erts_proc_sig_decode_dist(c_p, ERTS_PROC_LOCK_MAIN, msgp, 0)) {
/*
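
beam_jit_int128_to_big() receives the sign and the magnitude of the
128-bit result split into two 64-bit words and builds a one- or
two-digit bignum in a heap fragment. A worked sketch of the digit
layout it assumes (64-bit digits), with an arbitrarily chosen value:

    V = -(16#1_0000_0000_0000_0003),   %% needs two digits plus a sign
    Mag = abs(V),
    Low = Mag band ((1 bsl 64) - 1),   %% low digit: 3
    High = Mag bsr 64,                 %% high digit: 1
    true = V < 0,                      %% selects the negative bignum header
    Mag = Low + (High bsl 64).
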
diff --git a/erts/emulator/beam/jit/beam_jit_common.hpp b/erts/emulator/beam/jit/beam_jit_common.hpp
index b6f7239fae..ddd9c245dc 100644
--- a/erts/emulator/beam/jit/beam_jit_common.hpp
+++ b/erts/emulator/beam/jit/beam_jit_common.hpp
@@ -628,6 +628,8 @@ void beam_jit_bs_construct_fail_info(Process *c_p,
Eterm arg1);
Sint beam_jit_bs_bit_size(Eterm term);
+Eterm beam_jit_int128_to_big(Process *p, Uint sign, Uint low, Uint high);
+
void beam_jit_take_receive_lock(Process *c_p);
void beam_jit_wait_locked(Process *c_p, ErtsCodePtr cp);
void beam_jit_wait_unlocked(Process *c_p, ErtsCodePtr cp);
diff --git a/erts/emulator/beam/jit/x86/beam_asm_global.hpp.pl b/erts/emulator/beam/jit/x86/beam_asm_global.hpp.pl
index 3c620462b3..9782bbb226 100755
--- a/erts/emulator/beam/jit/x86/beam_asm_global.hpp.pl
+++ b/erts/emulator/beam/jit/x86/beam_asm_global.hpp.pl
@@ -90,6 +90,10 @@ my @beam_global_funcs = qw(
is_ge_lt_shared
minus_body_shared
minus_guard_shared
+ mul_add_body_shared
+ mul_add_guard_shared
+ mul_body_shared
+ mul_guard_shared
new_map_shared
plus_body_shared
plus_guard_shared
@@ -98,8 +102,6 @@ my @beam_global_funcs = qw(
raise_exception
raise_exception_shared
store_unaligned
- times_body_shared
- times_guard_shared
unary_minus_body_shared
unary_minus_guard_shared
unloaded_fun
diff --git a/erts/emulator/beam/jit/x86/instr_arith.cpp b/erts/emulator/beam/jit/x86/instr_arith.cpp
index 888f3109f1..fdb021fa7c 100644
--- a/erts/emulator/beam/jit/x86/instr_arith.cpp
+++ b/erts/emulator/beam/jit/x86/instr_arith.cpp
@@ -823,16 +823,32 @@ void BeamModuleAssembler::emit_i_m_div(const ArgLabel &Fail,
mov_arg(Dst, RET);
}
-/* ARG2 = LHS, ARG3 (!) = RHS
+/* ARG2 = Src1
+ * ARG3 = Src2
+ * ARG4 = Increment
*
* Result is returned in RET, error is indicated by ZF. */
-void BeamGlobalAssembler::emit_times_guard_shared() {
+void BeamGlobalAssembler::emit_mul_add_guard_shared() {
+ Label done = a.newLabel();
+
emit_enter_frame();
emit_enter_runtime();
+ a.mov(TMP_MEM1q, ARG4);
+
a.mov(ARG1, c_p);
runtime_call<3>(erts_mixed_times);
+ emit_test_the_non_value(RET);
+ a.short_().je(done);
+
+ a.mov(ARG3, TMP_MEM1q);
+ a.mov(ARG2, RET);
+ a.mov(ARG1, c_p);
+ a.cmp(ARG3, imm(make_small(0)));
+ a.short_().je(done);
+ runtime_call<3>(erts_mixed_plus);
+ a.bind(done);
emit_leave_runtime();
emit_leave_frame();
@@ -841,13 +857,14 @@ void BeamGlobalAssembler::emit_times_guard_shared() {
a.ret();
}
-/* ARG2 = LHS, ARG3 (!) = RHS
+/* ARG2 = Src1
+ * ARG3 = Src2
+ * ARG4 = Increment
*
* Result is returned in RET. */
-void BeamGlobalAssembler::emit_times_body_shared() {
- static const ErtsCodeMFA bif_mfa = {am_erlang, am_Times, 2};
-
- Label error = a.newLabel();
+void BeamGlobalAssembler::emit_mul_add_body_shared() {
+ Label mul_only = a.newLabel(), error = a.newLabel(),
+ mul_error = a.newLabel(), do_error = a.newLabel();
emit_enter_frame();
emit_enter_runtime();
@@ -855,61 +872,166 @@ void BeamGlobalAssembler::emit_times_body_shared() {
/* Save original arguments for the error path. */
a.mov(TMP_MEM1q, ARG2);
a.mov(TMP_MEM2q, ARG3);
-
a.mov(ARG1, c_p);
- runtime_call<3>(erts_mixed_times);
+ a.cmp(ARG4, imm(make_small(0)));
+ a.short_().je(mul_only);
+ a.mov(TMP_MEM4q, ARG4);
+
+ a.lea(ARG5, TMP_MEM3q);
+ runtime_call<5>(erts_mul_add);
emit_leave_runtime();
emit_leave_frame();
emit_test_the_non_value(RET);
a.short_().je(error);
+
a.ret();
+ a.bind(mul_only);
+ {
+ runtime_call<3>(erts_mixed_times);
+
+ emit_leave_runtime();
+ emit_leave_frame();
+
+ emit_test_the_non_value(RET);
+ a.short_().je(mul_error);
+
+ a.ret();
+ }
+
a.bind(error);
{
- /* Place the original arguments in x-registers. */
+ static const ErtsCodeMFA mul_mfa = {am_erlang, am_Times, 2};
+ static const ErtsCodeMFA add_mfa = {am_erlang, am_Plus, 2};
+
+ a.mov(ARG1, TMP_MEM3q);
+ a.mov(ARG2, TMP_MEM4q);
+ mov_imm(ARG4, &add_mfa);
+ emit_test_the_non_value(ARG1);
+ a.short_().jne(do_error);
+
+ a.bind(mul_error);
a.mov(ARG1, TMP_MEM1q);
a.mov(ARG2, TMP_MEM2q);
+ mov_imm(ARG4, &mul_mfa);
+
+ a.bind(do_error);
a.mov(getXRef(0), ARG1);
a.mov(getXRef(1), ARG2);
-
- a.mov(ARG4, imm(&bif_mfa));
a.jmp(labels[raise_exception]);
}
}
-void BeamModuleAssembler::emit_i_times(const ArgLabel &Fail,
- const ArgSource &LHS,
- const ArgSource &RHS,
- const ArgRegister &Dst) {
- bool small_result = is_product_small_if_args_are_small(LHS, RHS);
+/* ARG2 = Src1
+ * ARG3 = Src2
+ *
+ * The result is returned in RET.
+ */
+void BeamGlobalAssembler::emit_mul_body_shared() {
+ mov_imm(ARG4, make_small(0));
+ a.jmp(labels[mul_add_body_shared]);
+}
- if (always_small(LHS) && always_small(RHS) && small_result) {
- comment("multiplication without overflow check");
- if (RHS.isSmall()) {
- Sint factor = RHS.as<ArgSmall>().getSigned();
+/* ARG2 = Src1
+ * ARG3 = Src2
+ *
+ * Result is returned in RET, error is indicated by ZF.
+ */
+void BeamGlobalAssembler::emit_mul_guard_shared() {
+ mov_imm(ARG4, make_small(0));
+ a.jmp(labels[mul_add_guard_shared]);
+}
+
+void BeamModuleAssembler::emit_i_mul_add(const ArgLabel &Fail,
+ const ArgSource &Src1,
+ const ArgSource &Src2,
+ const ArgSource &Src3,
+ const ArgSource &Src4,
+ const ArgRegister &Dst) {
+ bool is_product_small = is_product_small_if_args_are_small(Src1, Src2);
+ bool is_sum_small = is_sum_small_if_args_are_small(Src3, Src4);
+ bool is_increment_zero =
+ Src4.isSmall() && Src4.as<ArgSmall>().getSigned() == 0;
+ Sint factor = 0;
+ int left_shift = -1;
+
+ if (is_increment_zero) {
+ comment("(adding zero)");
+ }
+
+ if (Src2.isSmall()) {
+ factor = Src2.as<ArgSmall>().getSigned();
+ if (Support::isPowerOf2(factor)) {
+ left_shift = Support::ctz<Eterm>(factor);
+ }
+ }
- mov_arg(RET, LHS);
+ if (always_small(Src1) && Src2.isSmall() && Src4.isSmall() &&
+ is_product_small && is_sum_small) {
+ x86::Mem p;
+ Sint increment = Src4.as<ArgSmall>().get();
+ increment -= factor * _TAG_IMMED1_SMALL;
+
+ switch (factor) {
+ case 2:
+ p = ptr(RET, RET, 0, increment);
+ break;
+ case 3:
+ p = ptr(RET, RET, 1, increment);
+ break;
+ case 4:
+ p = ptr(x86::Gp(), RET, 2, increment);
+ break;
+ case 5:
+ p = ptr(RET, RET, 2, increment);
+ break;
+ case 8:
+ p = ptr(x86::Gp(), RET, 3, increment);
+ break;
+ case 9:
+ p = ptr(RET, RET, 3, increment);
+ break;
+ }
+
+ if (Support::isInt32(increment) && p.hasIndex()) {
+ comment("optimizing multiplication and addition using LEA");
+ mov_arg(RET, Src1);
+ a.lea(RET, p);
+ mov_arg(Dst, RET);
+ return;
+ }
+ }
+
+ if (always_small(Src1) && Src2.isSmall() && always_small(Src4) &&
+ is_product_small && is_sum_small) {
+ comment("multiplication and addition without overflow check");
+ if (Src2.isSmall()) {
+ mov_arg(RET, Src1);
a.and_(RET, imm(~_TAG_IMMED1_MASK));
if (Support::isPowerOf2(factor)) {
- int trailing_bits = Support::ctz<Eterm>(factor);
comment("optimized multiplication by replacing with left "
"shift");
- a.shl(RET, imm(trailing_bits));
+ a.shl(RET, imm(left_shift));
} else {
mov_imm(ARG2, factor);
a.imul(RET, ARG2);
}
} else {
- mov_arg(RET, LHS);
- mov_arg(ARG2, RHS);
+ mov_arg(RET, Src1);
+ mov_arg(ARG2, Src2);
a.and_(RET, imm(~_TAG_IMMED1_MASK));
a.sar(ARG2, imm(_TAG_IMMED1_SIZE));
a.imul(RET, ARG2);
}
- a.or_(RET, imm(_TAG_IMMED1_SMALL));
+ if (is_increment_zero) {
+ a.or_(RET, imm(_TAG_IMMED1_SMALL));
+ } else {
+ mov_arg(ARG2, Src4);
+ a.add(RET, ARG2);
+ }
mov_arg(Dst, RET);
return;
@@ -917,39 +1039,81 @@ void BeamModuleAssembler::emit_i_times(const ArgLabel &Fail,
Label next = a.newLabel(), mixed = a.newLabel();
- mov_arg(ARG2, LHS); /* Used by erts_mixed_times in this slot */
- mov_arg(ARG3, RHS); /* Used by erts_mixed_times in this slot */
+ mov_arg(ARG2, Src1);
+ mov_arg(ARG3, Src2);
+ if (!is_increment_zero) {
+ mov_arg(ARG4, Src4);
+ }
- if (RHS.isSmall()) {
- Sint val = RHS.as<ArgSmall>().getSigned();
- emit_is_small(mixed, LHS, ARG2);
+ if (Src2.isSmall()) {
+ Sint val = Src2.as<ArgSmall>().getSigned();
+ emit_are_both_small(mixed, Src1, ARG2, Src4, ARG4);
a.mov(RET, ARG2);
- a.mov(ARG4, imm(val));
+ mov_imm(ARG5, val);
} else {
- emit_are_both_small(mixed, LHS, ARG2, RHS, ARG3);
+ if (is_increment_zero) {
+ emit_are_both_small(mixed, Src1, ARG2, Src2, ARG3);
+ } else if (always_small(Src1)) {
+ emit_are_both_small(mixed, Src2, ARG3, Src4, ARG4);
+ } else {
+ a.mov(RETd, ARG2.r32());
+ a.and_(RETd, ARG3.r32());
+ a.and_(RETd, ARG4.r32());
+ if (always_one_of<BeamTypeId::Integer, BeamTypeId::AlwaysBoxed>(
+ Src1) &&
+ always_one_of<BeamTypeId::Integer, BeamTypeId::AlwaysBoxed>(
+ Src2) &&
+ always_one_of<BeamTypeId::Integer, BeamTypeId::AlwaysBoxed>(
+ Src4)) {
+ emit_is_not_boxed(mixed, RET);
+ } else {
+ a.and_(RETb, imm(_TAG_IMMED1_MASK));
+ a.cmp(RETb, imm(_TAG_IMMED1_SMALL));
+ a.short_().jne(mixed);
+ }
+ }
a.mov(RET, ARG2);
- a.mov(ARG4, ARG3);
- a.sar(ARG4, imm(_TAG_IMMED1_SIZE));
+ a.mov(ARG5, ARG3);
+ a.sar(ARG5, imm(_TAG_IMMED1_SIZE));
}
a.and_(RET, imm(~_TAG_IMMED1_MASK));
- a.imul(RET, ARG4);
- if (small_result) {
- comment("skipped overflow check because the result is always small");
+ a.imul(RET, ARG5);
+ if (is_product_small) {
+ comment("skipped overflow check because product is always small");
} else {
a.short_().jo(mixed);
}
- a.or_(RET, imm(_TAG_IMMED1_SMALL));
+
+ if (is_increment_zero) {
+ a.or_(RET, imm(_TAG_IMMED1_SMALL));
+ } else {
+ a.add(RET, ARG4);
+ if (is_sum_small) {
+ comment("skipped overflow check because sum is always small");
+ } else {
+ a.short_().jo(mixed);
+ }
+ }
+
a.short_().jmp(next);
/* Call mixed multiplication. */
a.bind(mixed);
{
if (Fail.get() != 0) {
- safe_fragment_call(ga->get_times_guard_shared());
+ if (is_increment_zero) {
+ safe_fragment_call(ga->get_mul_guard_shared());
+ } else {
+ safe_fragment_call(ga->get_mul_add_guard_shared());
+ }
a.je(resolve_beam_label(Fail));
} else {
- safe_fragment_call(ga->get_times_body_shared());
+ if (is_increment_zero) {
+ safe_fragment_call(ga->get_mul_body_shared());
+ } else {
+ safe_fragment_call(ga->get_mul_add_body_shared());
+ }
}
}
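
When the multiplier is a literal 2, 3, 4, 5, 8 or 9 and both the
product and the sum are known to stay small, the x86 code folds the
whole multiply-add into a single LEA, using the scaled-index address
form and folding the increment together with the tag correction into
the displacement. Illustrative source that is a candidate for that
path (the guard only serves to give the JIT a known integer range):

    %% X*8 maps onto the index*8 scale; the +1 and the small tag are
    %% folded into the LEA displacement.
    scale8(X) when is_integer(X), 0 =< X, X < 1000 ->
        X * 8 + 1.
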
diff --git a/erts/emulator/beam/jit/x86/ops.tab b/erts/emulator/beam/jit/x86/ops.tab
index e96590b534..bbc2313118 100644
--- a/erts/emulator/beam/jit/x86/ops.tab
+++ b/erts/emulator/beam/jit/x86/ops.tab
@@ -1229,13 +1229,27 @@ gc_bif2 Fail Live u$bif:erlang:sminus/2 S1 S2 Dst =>
# Arithmetic instructions.
#
+gc_bif2 Fail1 Live1 u$bif:erlang:stimes/2 S1 S2 Dst1 |
+ gc_bif2 Fail2 Live2 u$bif:erlang:splus/2 S3 S4 Dst2 |
+ equal(Dst1, S3) |
+ equal(Dst1, Dst2) |
+ equal(Fail1, Fail2) =>
+ i_mul_add Fail1 S1 S2 S3 S4 Dst1
+
+gc_bif2 Fail1 Live1 u$bif:erlang:stimes/2 S1 S2 Dst1 |
+ gc_bif2 Fail2 Live2 u$bif:erlang:splus/2 S3 S4 Dst2 |
+ equal(Dst1, S4) |
+ equal(Dst1, Dst2) |
+ equal(Fail1, Fail2) =>
+ i_mul_add Fail1 S1 S2 S4 S3 Dst1
+
+gc_bif2 Fail Live u$bif:erlang:stimes/2 S1 S2 Dst =>
+ i_mul_add Fail S1 S2 Dst i Dst
+
gen_plus Fail Live S1 S2 Dst => i_plus S1 S2 Fail Dst
gen_minus Fail Live S1 S2 Dst => i_minus S1 S2 Fail Dst
-gc_bif2 Fail Live u$bif:erlang:stimes/2 S1 S2 Dst =>
- i_times Fail S1 S2 Dst
-
gc_bif2 Fail Live u$bif:erlang:div/2 S1 S2 Dst =>
i_m_div Fail S1 S2 Dst
@@ -1304,7 +1318,7 @@ i_minus s s j d
i_unary_minus s j d
-i_times j s s d
+i_mul_add j s s s s d
i_m_div j s s d
diff --git a/erts/emulator/test/big_SUITE.erl b/erts/emulator/test/big_SUITE.erl
index 2839a97061..635abc8800 100644
--- a/erts/emulator/test/big_SUITE.erl
+++ b/erts/emulator/test/big_SUITE.erl
@@ -177,6 +177,7 @@ eval({op,_,Op,A0}, LFH) ->
eval({op,_,Op,A0,B0}, LFH) ->
[A,B] = eval_list([A0,B0], LFH),
Res = eval_op(Op, A, B),
+ ok = eval_op_guard(Op, A, B, Res),
erlang:garbage_collect(),
Res;
eval({integer,_,I}, _) ->
@@ -207,6 +208,18 @@ eval_op('bxor', A, B) -> A bxor B;
eval_op('bsl', A, B) -> A bsl B;
eval_op('bsr', A, B) -> A bsr B.
+eval_op_guard('-', A, B, Res) when Res =:= A - B -> ok;
+eval_op_guard('+', A, B, Res) when Res =:= A + B -> ok;
+eval_op_guard('*', A, B, Res) when Res =:= A * B -> ok;
+eval_op_guard('div', A, B, Res) when Res =:= A div B -> ok;
+eval_op_guard('rem', A, B, Res) when Res =:= A rem B -> ok;
+eval_op_guard('band', A, B, Res) when Res =:= A band B -> ok;
+eval_op_guard('bor', A, B, Res) when Res =:= A bor B -> ok;
+eval_op_guard('bxor', A, B, Res) when Res =:= A bxor B -> ok;
+eval_op_guard('bsl', A, B, Res) when Res =:= A bsl B -> ok;
+eval_op_guard('bsr', A, B, Res) when Res =:= A bsr B -> ok;
+eval_op_guard(Op, A, B, Res) -> {error,{Op,A,B,Res}}.
+
test_squaring(I) ->
%% Multiplying an integer by itself is specially optimized, so we
%% should take special care to test squaring. The optimization
@@ -520,12 +533,13 @@ properties(_Config) ->
_ = [begin
A = id(rand_int()),
B = id(rand_int()),
- io:format("~.36#\n~.36#\n", [A,B]),
- test_properties(A, B)
+ C = id(rand_int()),
+ io:format("~.36#\n~.36#\n~.36#\n", [A,B,C]),
+ test_properties(A, B, C)
end || _ <- lists:seq(1, 1000)],
ok.
-test_properties(A, B) ->
+test_properties(A, B, C) ->
SquaredA = id(A * A),
SquaredB = id(B * B),
@@ -543,6 +557,11 @@ test_properties(A, B) ->
A = id(Sum - B),
B = id(Sum - A),
0 = Sum - A - B,
+ C = id(A + B + C) - Sum,
+
+ PS = id(A * B + C),
+ PS = P + C,
+ ok = test_mul_add_guard(A, B, C, PS),
NegA = id(-A),
A = -NegA,
@@ -563,6 +582,7 @@ test_properties(A, B) ->
ok.
+test_mul_add_guard(A, B, C, Res) when Res =:= A * B + C -> ok.
rand_int() ->
Sz = max(floor(rand:normal() * 512 + 256), 7),
diff --git a/erts/emulator/test/small_SUITE.erl b/erts/emulator/test/small_SUITE.erl
index 2ab944d85e..c8a1b2fbf2 100644
--- a/erts/emulator/test/small_SUITE.erl
+++ b/erts/emulator/test/small_SUITE.erl
@@ -23,10 +23,12 @@
-export([all/0, suite/0, groups/0]).
-export([edge_cases/1,
- addition/1, subtraction/1, negation/1, multiplication/1, division/1,
+ addition/1, subtraction/1, negation/1,
+ multiplication/1, mul_add/1, division/1,
test_bitwise/1, test_bsl/1,
element/1,
range_optimization/1]).
+-export([mul_add/0]).
-include_lib("common_test/include/ct.hrl").
@@ -40,7 +42,7 @@ all() ->
groups() ->
[{p, [parallel],
[edge_cases,
- addition, subtraction, negation, multiplication, division,
+ addition, subtraction, negation, multiplication, mul_add, division,
test_bitwise, test_bsl,
element,
range_optimization]}].
@@ -420,7 +422,9 @@ mul_gen_pairs() ->
_ <- lists:seq(1, 75)],
%% Generate pairs of numbers whose product is small.
- Pairs1 = [{N, MaxSmall div N} || N <- [1,2,3,5,17,63,64,1111,22222]] ++ Pairs0,
+ SmallPairs = [{N, MaxSmall div N} ||
+ N <- [1,2,3,4,5,8,16,17,32,63,64,1111,22222]],
+ Pairs1 = [{N,M-1} || {N,M} <- SmallPairs] ++ SmallPairs ++ Pairs0,
%% Add prime factors of 2^59 - 1 (MAX_SMALL for 64-bit architecture
%% at the time of writing).
@@ -460,7 +464,11 @@ gen_mul_function({Name,{A,B}}) ->
Res = Y * X;
'@Name@'(X, fixed, number) when -_@APlusOne@ < X, X < _@APlusOne@ ->
X * _@B@;
+ '@Name@'(X, fixed, any) ->
+ X * _@B@;
'@Name@'(fixed, Y, number) when -_@BPlusOne@ < Y, Y < _@BPlusOne@ ->
+ _@A@ * Y;
+ '@Name@'(fixed, Y, any) ->
_@A@ * Y. ").
test_multiplication([{Name,{A,B}}|T], Mod) ->
@@ -474,7 +482,9 @@ test_multiplication([{Name,{A,B}}|T], Mod) ->
Res0 = F(-A, -B, false),
Res0 = F(A, B, number),
Res0 = F(fixed, B, number),
+ Res0 = F(fixed, B, any),
Res0 = F(A, fixed, number),
+ Res0 = F(A, fixed, any),
Res0 = F(-A, -B, number),
Res1 = -(A * B),
@@ -483,7 +493,9 @@ test_multiplication([{Name,{A,B}}|T], Mod) ->
Res1 = F(-A, B, number),
Res1 = F(A, -B, number),
Res1 = F(-A, fixed, number),
- Res1 = F(fixed, -B, number)
+ Res1 = F(-A, fixed, any),
+ Res1 = F(fixed, -B, number),
+ Res1 = F(fixed, -B, any)
catch
C:R:Stk ->
io:format("~p failed. numbers: ~p ~p\n", [Name,A,B]),
@@ -494,6 +506,212 @@ test_multiplication([{Name,{A,B}}|T], Mod) ->
test_multiplication([], _) ->
ok.
+mul_add() ->
+ [{timetrap, {minutes, 5}}].
+mul_add(_Config) ->
+ _ = rand:uniform(), %Seed generator
+ io:format("Seed: ~p", [rand:export_seed()]),
+ Mod = list_to_atom(lists:concat([?MODULE,"_",?FUNCTION_NAME])),
+ Triples = mul_add_triples(),
+ Fs0 = gen_func_names(Triples, 0),
+ Fs = [gen_mul_add_function(F) || F <- Fs0],
+ Tree = ?Q(["-module('@Mod@').",
+ "-compile([export_all,nowarn_export_all]).",
+ "id(I) -> I."]) ++ Fs,
+ %% merl:print(Tree),
+ {ok,_Bin} = merl:compile_and_load(Tree, []),
+ test_mul_add(Fs0, Mod),
+ unload(Mod),
+
+ test_mul_add_float(),
+ test_mul_add_exceptions(),
+
+ ok.
+
+mul_add_triples() ->
+ {_, MaxSmall} = determine_small_limits(0),
+ SqrtMaxSmall = floor(math:sqrt(MaxSmall)),
+
+ Numbers0 = [1,2,3,4,5,8,9,
+ (MaxSmall div 2) band -2,
+ MaxSmall band -2,
+ MaxSmall * 2],
+ Numbers = [rand:uniform(SqrtMaxSmall) || _ <- lists:seq(1, 5)] ++ Numbers0,
+
+ %% Generate pairs of numbers whose product is small.
+ SmallPairs = [{MaxSmall div M,M} || M <- Numbers],
+ Pairs = [{N+M,M} || {N,M} <- SmallPairs] ++ SmallPairs,
+
+ Triples0 = [{A,B,rand:uniform(MaxSmall)} || {A,B} <- Pairs],
+ Triples1a = [{A,B,abs(MaxSmall - A * B)} || {A,B} <- Pairs],
+ Triples1 = [{A,B,C+Offset} ||
+ {A,B,C} <- Triples1a,
+ Offset <- [-2,-1,0,1,2],
+ C + Offset >= 0],
+ Triples2 = [{A,B,MaxSmall+1} || {A,B} <- Pairs],
+ [{3,4,5},
+ {MaxSmall div 2,2,42}, %Result is not small.
+ {MaxSmall,MaxSmall,MaxSmall}|Triples0 ++ Triples1 ++ Triples2].
+
+gen_mul_add_function({Name,{A,B,C}}) ->
+ APlusOne = A + 1,
+ BPlusOne = B + 1,
+ CPlusOne = C + 1,
+ ?Q("'@Name@'(int_vvv_plus_z, X, Y, Z)
+ when is_integer(X), is_integer(Y), is_integer(Z),
+ -_@APlusOne@ < X, X < _@APlusOne@,
+ -_@BPlusOne@ < Y, Y < _@BPlusOne@,
+ -_@CPlusOne@ < Z, Z < _@CPlusOne@ ->
+ Res = id(X * Y + Z),
+ Res = id(Y * X + Z),
+ Res = id(Z + X * Y),
+ Res = id(Z + Y * X),
+ Res;
+ '@Name@'(int_vvv_minus_z, X, Y, Z)
+ when is_integer(X), is_integer(Y), is_integer(Z),
+ -_@APlusOne@ < X, X < _@APlusOne@,
+ -_@BPlusOne@ < Y, Y < _@BPlusOne@,
+ -_@CPlusOne@ < Z, Z < _@CPlusOne@ ->
+ Res = id(X * Y - Z),
+ Res = id(Y * X - Z),
+ Res;
+ '@Name@'(pos_int_vvv_plus_z, X, Y, Z)
+ when is_integer(X), is_integer(Y), is_integer(Z),
+ 0 =< X, X < _@APlusOne@,
+ 0 =< Y, Y < _@BPlusOne@,
+ 0 =< Z, Z < _@CPlusOne@ ->
+ Res = id(X * Y + Z),
+ Res = id(Y * X + Z),
+ Res = id(Z + X * Y),
+ Res = id(Z + Y * X),
+ Res;
+ '@Name@'(neg_int_vvv_plus_z, X, Y, Z)
+ when is_integer(X), is_integer(Y), is_integer(Z),
+ -_@APlusOne@ < X, X < 0,
+ -_@BPlusOne@ < Y, Y < 0,
+ -_@CPlusOne@ < Z, Z < 0 ->
+ Res = id(X * Y + Z),
+ Res = id(Y * X + Z),
+ Res = id(Z + X * Y),
+ Res = id(Z + Y * X),
+ Res;
+ '@Name@'(any_vvv_plus_z, X, Y, Z) ->
+ Res = id(X * Y + Z),
+ Res = id(Y * X + Z),
+ Res = id(Z + X * Y),
+ Res = id(Z + Y * X),
+ Res = '@Name@'(int_vvv_plus_z, id(X), id(Y), id(Z)),
+ Res;
+ '@Name@'(any_vvv_minus_z, X, Y, Z) ->
+ Res = id(X * Y - Z),
+ Res = id(Y * X - Z),
+ Res = '@Name@'(int_vvv_minus_z, id(X), id(Y), id(Z)),
+ Res;
+ '@Name@'(any_vvi_plus_z, X, Y, _Z) ->
+ Z = _@C@,
+ Res = id(X * Y + Z),
+ Res = id(Y * X + Z),
+ Res = id(Z + X * Y),
+ Res = id(Z + Y * X),
+ Res = '@Name@'(any_vvv_plus_z, X, Y, id(Z)),
+ Res = '@Name@'(any_vvv_minus_z, X, Y, id(-Z)),
+ Res;
+ '@Name@'(any_vvi_minus_z, X, Y, _Z) ->
+ Z = _@C@,
+ Res = id(X * Y - Z),
+ Res = id(Y * X - Z),
+ Res = id(-Z + X * Y),
+ Res = id(-Z + Y * X),
+ Res = '@Name@'(any_vvv_plus_z, X, Y, id(-Z)),
+ Res = '@Name@'(any_vvv_minus_z, X, Y, id(Z)),
+ Res;
+ '@Name@'(any_vii_plus_z, X, fixed, fixed) ->
+ Y = _@B@,
+ Z = _@C@,
+ Res = id(X * Y + Z),
+ Res = id(Y * X + Z),
+ Res = id(Z + X * Y),
+ Res = id(Z + Y * X),
+ Res = '@Name@'(any_vvi_plus_z, X, id(Y), fixed),
+ Res = '@Name@'(any_vvv_minus_z, X, id(Y), id(-Z)),
+ Res;
+ '@Name@'(any_vii_minus_z, X, fixed, fixed) ->
+ Y = _@B@,
+ Z = _@C@,
+ Res = id(X * Y - Z),
+ Res = id(Y * X - Z),
+ Res = id(-Z + X * Y),
+ Res = id(-Z + Y * X),
+ Res = '@Name@'(any_vvi_minus_z, X, id(Y), fixed),
+ Res = '@Name@'(any_vvv_plus_z, X, Y, id(-Z)),
+ Res;
+ '@Name@'({guard_plus_z,Res}, X, Y, Z) when X * Y + Z =:= Res ->
+ ok;
+ '@Name@'({guard_minus_z,Res}, X, Y, Z) when X * Y - Z =:= Res ->
+ ok. ").
+
+test_mul_add([{Name,{A,B,C}}|T], Mod) ->
+ F = fun Mod:Name/4,
+ try
+ Res0 = A * B + C,
+ Res0 = F(any_vii_plus_z, A, fixed, fixed),
+ Res0 = F(pos_int_vvv_plus_z, A, B, C),
+ ok = F({guard_plus_z,Res0}, A, B, C),
+ ok = F({guard_plus_z,Res0}, -A, -B, C),
+
+ Res1 = A * B - C,
+ Res1 = F(any_vii_minus_z, A, fixed, fixed),
+ Res1 = if
+ A > 0, B > 0, C > 0 ->
+ F(neg_int_vvv_plus_z, -A, -B, -C);
+ true ->
+ Res1
+ end,
+ ok = F({guard_minus_z,Res1}, A, B, C),
+ ok = F({guard_minus_z,Res1}, -A, -B, C),
+
+ Res2 = -A * B + C,
+ Res2 = A * -B + C,
+ Res2 = F(any_vii_plus_z, -A, fixed, fixed),
+ ok = F({guard_plus_z,Res2}, -A, B, C),
+
+ Res3 = -A * B - C,
+ Res3 = A * -B - C,
+ Res3 = F(any_vii_minus_z, -A, fixed, fixed),
+ ok = F({guard_minus_z,Res3}, -A, B, C)
+ catch
+ Class:R:Stk ->
+ io:format("~p failed. numbers: ~p ~p ~p\n", [Name,A,B,C]),
+ erlang:raise(Class, R, Stk)
+ end,
+ test_mul_add(T, Mod);
+test_mul_add([], _) ->
+ ok.
+
+test_mul_add_float() ->
+ Res = madd(id(2.0), id(3.0), id(7.0)),
+ Res = madd(id(2.0), id(3.0), id(7)),
+ ok = madd(id(2.0), id(3.0), id(7), id(Res)).
+
+test_mul_add_exceptions() ->
+ error = madd(id(a), id(2), id(3), id(whatever)),
+ error = madd(id(7), id(b), id(3), id(whatever)),
+ error = madd(id(7), id(15), id(c), id(whatever)),
+
+ {'EXIT',{badarith,[{erlang,'*',[a,2],_}|_]}} = catch madd(id(a), id(2), id(0)),
+ {'EXIT',{badarith,[{erlang,'*',[a,2],_}|_]}} = catch madd(id(a), id(2), id(42)),
+ {'EXIT',{badarith,[{erlang,'*',[a,2],_}|_]}} = catch madd(id(a), id(2), id(c)),
+ {'EXIT',{badarith,[{erlang,'*',[3,b],_}|_]}} = catch madd(id(3), id(b), id(c)),
+ {'EXIT',{badarith,[{erlang,'+',[6,c],_}|_]}} = catch madd(id(2), id(3), id(c)),
+
+ ok.
+
+madd(A, B, C) -> A * B + C.
+
+madd(A, B, C, Res) when Res =:= A * B + C -> ok;
+madd(_, _, _, _) -> error.
+
+
%% Test that the JIT only omits the overflow check when it's safe.
division(_Config) ->
_ = rand:uniform(), %Seed generator
--
2.35.3