From c7d640b17e05999b64745255c6cf3fe4402d56b1 Mon Sep 17 00:00:00 2001 From: Rerumu Date: Tue, 17 May 2022 21:06:50 -0400 Subject: [PATCH] Make Luau integers use free functions --- codegen-luau/runtime/numeric.lua | 407 +++++++++++++++---------------- 1 file changed, 190 insertions(+), 217 deletions(-) diff --git a/codegen-luau/runtime/numeric.lua b/codegen-luau/runtime/numeric.lua index f192d74..429ba17 100644 --- a/codegen-luau/runtime/numeric.lua +++ b/codegen-luau/runtime/numeric.lua @@ -1,114 +1,109 @@ local Numeric = {} -Numeric.__index = Numeric +local BIT_SET_31 = 0x80000000 +local BIT_SET_32 = 0x100000000 -local bit_band = bit32.band -local bit_bnot = bit32.bnot -local bit_bor = bit32.bor -local bit_xor = bit32.bxor +local K_ZERO, K_ONE, K_BIT_SET_24 local bit_lshift = bit32.lshift local bit_rshift = bit32.rshift local bit_arshift = bit32.arshift +local bit_and = bit32.band +local bit_or = bit32.bor +local bit_xor = bit32.bxor +local bit_not = bit32.bnot + +local math_ceil = math.ceil local math_floor = math.floor +local math_log = math.log +local math_max = math.max +local math_pow = math.pow -local N_2_TO_31 = 0x80000000 -local N_2_TO_32 = 0x100000000 - -local VAL_ZERO -local VAL_ONE -local VAL_2_TO_24 - -local op_is_equal -local op_is_greater_unsigned -local op_is_less_unsigned -local op_is_negative -local op_is_zero - -local op_bnot -local op_negate +local from_u32, into_u32, from_u64, into_u64 +local num_add, num_subtract, num_multiply, num_divide_unsigned, num_negate +local num_is_negative, num_is_zero, num_is_equal, num_is_less_unsigned, num_is_greater_unsigned -- TODO: Eventually support Vector3 -local function from_u32(low, high) - return setmetatable({ low, high }, Numeric) +function Numeric.from_u32(data_1, data_2) + return { data_1, data_2 } end -local function to_u32(value) - return value[1], value[2] +function Numeric.into_u32(data) + return data[1], data[2] end -local function from_f64(value) +function Numeric.from_u64(value) if value < 0 then - return op_negate(from_f64(-value)) + return num_negate(from_u64(-value)) else - return from_u32(value % N_2_TO_32, math_floor(value / N_2_TO_32)) + return from_u32(value % BIT_SET_32, math_floor(value / BIT_SET_32)) end end -local function to_f64(value) - local low, high = to_u32(value) +function Numeric.into_u64(value) + local data_1, data_2 = into_u32(value) - return low + high * N_2_TO_32 + return data_1 + data_2 * BIT_SET_32 end -local function op_add(lhs, rhs) - local low_a, high_a = to_u32(lhs) - local low_b, high_b = to_u32(rhs) +function Numeric.add(lhs, rhs) + local data_l_1, data_l_2 = into_u32(lhs) + local data_r_1, data_r_2 = into_u32(rhs) - local low = low_a + low_b - local high = high_a + high_b + local data_1 = data_l_1 + data_r_1 + local data_2 = data_l_2 + data_r_2 - if low >= N_2_TO_32 then - low = low - N_2_TO_32 - high = high + 1 + if data_1 >= BIT_SET_32 then + data_1 = data_1 - BIT_SET_32 + data_2 = data_2 + 1 end - if high >= N_2_TO_32 then - high = high - N_2_TO_32 + if data_2 >= BIT_SET_32 then + data_2 = data_2 - BIT_SET_32 end - return from_u32(low, high) + return from_u32(data_1, data_2) end -local function op_subtract(lhs, rhs) - local low_a, high_a = to_u32(lhs) - local low_b, high_b = to_u32(rhs) +function Numeric.subtract(lhs, rhs) + local data_l_1, data_l_2 = into_u32(lhs) + local data_r_1, data_r_2 = into_u32(rhs) - local low = low_a - low_b - local high = high_a - high_b + local data_1 = data_l_1 - data_r_1 + local data_2 = data_l_2 - data_r_2 - if low < 0 then - low = low + N_2_TO_32 - high = high - 1 + if data_1 < 0 then + data_1 = data_1 + BIT_SET_32 + data_2 = data_2 - 1 end - if high < 0 then - high = high + N_2_TO_32 + if data_2 < 0 then + data_2 = data_2 + BIT_SET_32 end - return from_u32(low, high) + return from_u32(data_1, data_2) end local function set_absolute(lhs, rhs) local has_negative = false - if op_is_negative(lhs) then - lhs = op_negate(lhs) + if num_is_negative(lhs) then + lhs = num_negate(lhs) has_negative = not has_negative end - if op_is_negative(rhs) then - rhs = op_negate(rhs) + if num_is_negative(rhs) then + rhs = num_negate(rhs) has_negative = not has_negative end return has_negative, lhs, rhs end -local function op_multiply(lhs, rhs) - if op_is_zero(lhs) or op_is_zero(rhs) then - return VAL_ZERO +function Numeric.multiply(lhs, rhs) + if num_is_zero(lhs) or num_is_zero(rhs) then + return K_ZERO end local has_negative @@ -116,13 +111,13 @@ local function op_multiply(lhs, rhs) has_negative, lhs, rhs = set_absolute(lhs, rhs) -- If both longs are small, use float multiplication - if op_is_less_unsigned(lhs, VAL_2_TO_24) and op_is_less_unsigned(rhs, VAL_2_TO_24) then - local low_a = to_u32(lhs) - local low_b = to_u32(rhs) - local result = from_f64(low_a * low_b) + if num_is_less_unsigned(lhs, K_BIT_SET_24) and num_is_less_unsigned(rhs, K_BIT_SET_24) then + local data_l_1 = into_u32(lhs) + local data_r_1 = into_u32(rhs) + local result = from_u64(data_l_1 * data_r_1) if has_negative then - result = op_negate(result) + result = num_negate(result) end return result @@ -130,58 +125,53 @@ local function op_multiply(lhs, rhs) -- Divide each long into 4 chunks of 16 bits, and then add up 4x4 products. -- We can skip products that would overflow. - local low_a, high_a = to_u32(lhs) - local low_b, high_b = to_u32(rhs) + local data_l_1, data_l_2 = into_u32(lhs) + local data_r_1, data_r_2 = into_u32(rhs) - local a48 = bit_rshift(high_a, 16) - local a32 = bit_band(high_a, 0xFFFF) - local a16 = bit_rshift(low_a, 16) - local a00 = bit_band(low_a, 0xFFFF) + local a48 = bit_rshift(data_l_2, 16) + local a32 = bit_and(data_l_2, 0xFFFF) + local a16 = bit_rshift(data_l_1, 16) + local a00 = bit_and(data_l_1, 0xFFFF) - local b48 = bit_rshift(high_b, 16) - local b32 = bit_band(high_b, 0xFFFF) - local b16 = bit_rshift(low_b, 16) - local b00 = bit_band(low_b, 0xFFFF) + local b48 = bit_rshift(data_r_2, 16) + local b32 = bit_and(data_r_2, 0xFFFF) + local b16 = bit_rshift(data_r_1, 16) + local b00 = bit_and(data_r_1, 0xFFFF) local c48, c32, c16, c00 = 0, 0, 0, 0 c00 = c00 + a00 * b00 c16 = c16 + bit_rshift(c00, 16) - c00 = bit_band(c00, 0xFFFF) + c00 = bit_and(c00, 0xFFFF) c16 = c16 + a16 * b00 c32 = c32 + bit_rshift(c16, 16) - c16 = bit_band(c16, 0xFFFF) + c16 = bit_and(c16, 0xFFFF) c16 = c16 + a00 * b16 c32 = c32 + bit_rshift(c16, 16) - c16 = bit_band(c16, 0xFFFF) + c16 = bit_and(c16, 0xFFFF) c32 = c32 + a32 * b00 c48 = c48 + bit_rshift(c32, 16) - c32 = bit_band(c32, 0xFFFF) + c32 = bit_and(c32, 0xFFFF) c32 = c32 + a16 * b16 c48 = c48 + bit_rshift(c32, 16) - c32 = bit_band(c32, 0xFFFF) + c32 = bit_and(c32, 0xFFFF) c32 = c32 + a00 * b32 c48 = c48 + bit_rshift(c32, 16) - c32 = bit_band(c32, 0xFFFF) + c32 = bit_and(c32, 0xFFFF) c48 = c48 + a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48 - c48 = bit_band(c48, 0xFFFF) + c48 = bit_and(c48, 0xFFFF) - local low_v = bit_bor(bit_lshift(c16, 16), c00) - local high_v = bit_bor(bit_lshift(c48, 16), c32) - local result = from_u32(low_v, high_v) + local data_1 = bit_or(bit_lshift(c16, 16), c00) + local data_2 = bit_or(bit_lshift(c48, 16), c32) + local result = from_u32(data_1, data_2) if has_negative then - result = op_negate(result) + result = num_negate(result) end return result end -local math_ceil = math.ceil -local math_log = math.log -local math_max = math.max -local math_pow = math.pow - local function get_approx_delta(rem, rhs) local approx = math_max(1, math_floor(rem / rhs)) local log = math_ceil(math_log(approx, 2)) @@ -190,242 +180,225 @@ local function get_approx_delta(rem, rhs) return approx, delta end -local function op_divide_unsigned(lhs, rhs) - if op_is_zero(rhs) then +function Numeric.divide_unsigned(lhs, rhs) + if num_is_zero(rhs) then error("division by zero") - elseif op_is_zero(lhs) then + elseif num_is_zero(lhs) then return 0 end - local rhs_number = to_f64(rhs) + local rhs_number = into_u64(rhs) local rem = lhs - local res = VAL_ZERO + local res = K_ZERO - while op_is_greater_unsigned(rem, rhs) or op_is_equal(rem, rhs) do - local res_approx, delta = get_approx_delta(to_f64(rem), rhs_number) - local res_temp = from_f64(res_approx) - local rem_temp = op_multiply(res_temp, rhs) + while num_is_greater_unsigned(rem, rhs) or num_is_equal(rem, rhs) do + local res_approx, delta = get_approx_delta(into_u64(rem), rhs_number) + local res_temp = from_u64(res_approx) + local rem_temp = num_multiply(res_temp, rhs) - while op_is_negative(rem_temp) or op_is_greater_unsigned(rem_temp, rem) do + while num_is_negative(rem_temp) or num_is_greater_unsigned(rem_temp, rem) do res_approx = res_approx - delta - res_temp = from_f64(res_approx) - rem_temp = op_multiply(res_temp, rhs) + res_temp = from_u64(res_approx) + rem_temp = num_multiply(res_temp, rhs) end - if op_is_zero(res_temp) then - res_temp = VAL_ONE + if num_is_zero(res_temp) then + res_temp = K_ONE end - res = op_add(res, res_temp) - rem = op_subtract(rem, rem_temp) + res = num_add(res, res_temp) + rem = num_subtract(rem, rem_temp) end return res end -local function op_divide_signed(lhs, rhs) +function Numeric.divide_signed(lhs, rhs) local has_negative has_negative, lhs, rhs = set_absolute(lhs, rhs) - local result = op_divide_unsigned(lhs, rhs) + local result = num_divide_unsigned(lhs, rhs) if has_negative then - result = op_negate(result) + result = num_negate(result) end return result end -function op_negate(value) - return op_add(op_bnot(value), VAL_ONE) +function Numeric.negate(value) + return num_add(bit_not(value), K_ONE) end -local function op_band(lhs, rhs) - local low_a, high_a = to_u32(lhs) - local low_b, high_b = to_u32(rhs) +function Numeric.bit_and(lhs, rhs) + local data_l_1, data_l_2 = into_u32(lhs) + local data_r_1, data_r_2 = into_u32(rhs) - return from_u32(bit_band(low_a, low_b), bit_band(high_a, high_b)) + return from_u32(bit_and(data_l_1, data_r_1), bit_and(data_l_2, data_r_2)) end -function op_bnot(value) - local low, high = to_u32(value) +function Numeric.bit_not(value) + local data_1, data_2 = into_u32(value) - return from_u32(bit_bnot(low), bit_bnot(high)) + return from_u32(bit_not(data_1), bit_not(data_2)) end -local function op_bor(lhs, rhs) - local low_a, high_a = to_u32(lhs) - local low_b, high_b = to_u32(rhs) +function Numeric.bit_or(lhs, rhs) + local data_l_1, data_l_2 = into_u32(lhs) + local data_r_1, data_r_2 = into_u32(rhs) - return from_u32(bit_bor(low_a, low_b), bit_bor(high_a, high_b)) + return from_u32(bit_or(data_l_1, data_r_1), bit_or(data_l_2, data_r_2)) end -local function op_bxor(lhs, rhs) - local low_a, high_a = to_u32(lhs) - local low_b, high_b = to_u32(rhs) +function Numeric.bit_xor(lhs, rhs) + local data_l_1, data_l_2 = into_u32(lhs) + local data_r_1, data_r_2 = into_u32(rhs) - return from_u32(bit_xor(low_a, low_b), bit_xor(high_a, high_b)) + return from_u32(bit_xor(data_l_1, data_r_1), bit_xor(data_l_2, data_r_2)) end -local function op_shift_left(lhs, rhs) - local count = to_f64(rhs) +function Numeric.shift_left(lhs, rhs) + local count = into_u64(rhs) if count < 32 then - local low_a, high_a = to_u32(lhs) + local data_l_1, data_l_2 = into_u32(lhs) - local low_v = bit_lshift(low_a, count) - local high_v = bit_bor(bit_lshift(high_a, count), bit_rshift(low_a, 32 - count)) + local data_1 = bit_lshift(data_l_1, count) + local data_2 = bit_or(bit_lshift(data_l_2, count), bit_rshift(data_l_1, 32 - count)) - return from_u32(low_v, high_v) + return from_u32(data_1, data_2) else - local _, high_a = to_u32(lhs) + local _, data_l_2 = into_u32(lhs) - local high_v = bit_lshift(high_a, count - 32) + local data_2 = bit_lshift(data_l_2, count - 32) - return from_u32(0, high_v) + return from_u32(0, data_2) end end -local function op_shift_right_unsigned(lhs, rhs) - local count = to_f64(rhs) +function Numeric.shift_right_unsigned(lhs, rhs) + local count = into_u64(rhs) if count < 32 then - local low_a, high_a = to_u32(lhs) + local data_l_1, data_l_2 = into_u32(lhs) - local low_v = bit_bor(bit_rshift(low_a, count), bit_lshift(high_a, 32 - count)) - local high_v = bit_rshift(high_a, count) + local data_1 = bit_or(bit_rshift(data_l_1, count), bit_lshift(data_l_2, 32 - count)) + local data_2 = bit_rshift(data_l_2, count) - return from_u32(low_v, high_v) + return from_u32(data_1, data_2) elseif count == 32 then - local _, high_a = to_u32(lhs) + local _, data_l_2 = into_u32(lhs) - return from_u32(high_a, 0) + return from_u32(data_l_2, 0) else - local _, high_a = to_u32(lhs) + local _, data_l_2 = into_u32(lhs) - return from_u32(bit_rshift(high_a, count - 32), 0) + return from_u32(bit_rshift(data_l_2, count - 32), 0) end end -local function op_shift_right_signed(lhs, rhs) - local count = to_f64(rhs) +function Numeric.shift_right_signed(lhs, rhs) + local count = into_u64(rhs) if count < 32 then - local low_a, high_a = to_u32(lhs) + local data_l_1, data_l_2 = into_u32(lhs) - local low_v = bit_bor(bit_rshift(low_a, count), bit_lshift(high_a, 32 - count)) - local high_v = bit_arshift(high_a, count) + local data_1 = bit_or(bit_rshift(data_l_1, count), bit_lshift(data_l_2, 32 - count)) + local data_2 = bit_arshift(data_l_2, count) - return from_u32(low_v, high_v) + return from_u32(data_1, data_2) else - local _, high_a = to_u32(lhs) + local _, data_l_2 = into_u32(lhs) - local low_v = bit_arshift(high_a, count - 32) - local high_v = high_a > N_2_TO_31 and N_2_TO_32 - 1 or 0 + local data_1 = bit_arshift(data_l_2, count - 32) + local data_2 = data_l_2 > BIT_SET_31 and BIT_SET_32 - 1 or 0 - return from_u32(low_v, high_v) + return from_u32(data_1, data_2) end end -function op_is_negative(value) - local _, high = to_u32(value) +function Numeric.is_negative(value) + local _, data_2 = into_u32(value) - return high > N_2_TO_31 + return data_2 > BIT_SET_31 end -function op_is_zero(value) - local low, high = to_u32(value) +function Numeric.is_zero(value) + local data_1, data_2 = into_u32(value) - return low == 0 and high == 0 + return data_1 == 0 and data_2 == 0 end -function op_is_equal(lhs, rhs) - local low_a, high_a = to_u32(lhs) - local low_b, high_b = to_u32(rhs) +function Numeric.is_equal(lhs, rhs) + local data_l_1, data_l_2 = into_u32(lhs) + local data_r_1, data_r_2 = into_u32(rhs) - return low_a == low_b and high_a == high_b + return data_l_1 == data_r_1 and data_l_2 == data_r_2 end -function op_is_less_unsigned(lhs, rhs) - local low_a, high_a = to_u32(lhs) - local low_b, high_b = to_u32(rhs) +function Numeric.is_less_unsigned(lhs, rhs) + local data_l_1, data_l_2 = into_u32(lhs) + local data_r_1, data_r_2 = into_u32(rhs) - return high_a < high_b or (high_a == high_b and low_a < low_b) + return data_l_2 < data_r_2 or (data_l_2 == data_r_2 and data_l_1 < data_r_1) end -function op_is_greater_unsigned(lhs, rhs) - local low_a, high_a = to_u32(lhs) - local low_b, high_b = to_u32(rhs) +function Numeric.is_greater_unsigned(lhs, rhs) + local data_l_1, data_l_2 = into_u32(lhs) + local data_r_1, data_r_2 = into_u32(rhs) - return high_a > high_b or (high_a == high_b and low_a > low_b) + return data_l_2 > data_r_2 or (data_l_2 == data_r_2 and data_l_1 > data_r_1) end -local function op_is_less_signed(lhs, rhs) - local neg_a = op_is_negative(lhs) - local neg_b = op_is_negative(rhs) +function Numeric.is_less_signed(lhs, rhs) + local neg_a = num_is_negative(lhs) + local neg_b = num_is_negative(rhs) if neg_a and not neg_b then return true elseif not neg_a and neg_b then return false else - return op_is_negative(op_subtract(lhs, rhs)) + return num_is_negative(num_subtract(lhs, rhs)) end end -local function op_is_greater_signed(lhs, rhs) - local neg_a = op_is_negative(lhs) - local neg_b = op_is_negative(rhs) +function Numeric.is_greater_signed(lhs, rhs) + local neg_a = num_is_negative(lhs) + local neg_b = num_is_negative(rhs) if neg_a and not neg_b then return false elseif not neg_a and neg_b then return true else - return op_is_negative(op_subtract(rhs, lhs)) + return num_is_negative(num_subtract(rhs, lhs)) end end -VAL_ZERO = from_f64(0) -VAL_ONE = from_f64(1) -VAL_2_TO_24 = from_f64(0x1000000) +from_u32 = Numeric.from_u32 +into_u32 = Numeric.into_u32 +from_u64 = Numeric.from_u64 +into_u64 = Numeric.into_u64 -Numeric.from_f64 = from_f64 -Numeric.from_u32 = from_u32 -Numeric.to_f64 = to_f64 +num_add = Numeric.add +num_subtract = Numeric.subtract +num_multiply = Numeric.multiply +num_divide_unsigned = Numeric.divide_unsigned +num_negate = Numeric.negate -Numeric.divide_unsigned = op_divide_unsigned +num_is_negative = Numeric.is_negative +num_is_zero = Numeric.is_zero +num_is_equal = Numeric.is_equal +num_is_less_unsigned = Numeric.is_less_unsigned +num_is_greater_unsigned = Numeric.is_greater_unsigned -Numeric.bit_and = op_band -Numeric.bit_not = op_bnot -Numeric.bit_or = op_bor -Numeric.bit_xor = op_bxor +K_ZERO = from_u64(0) +K_ONE = from_u64(1) +K_BIT_SET_24 = from_u64(0x1000000) -Numeric.shift_left = op_shift_left -Numeric.shift_right_signed = op_shift_right_signed -Numeric.shift_right_unsigned = op_shift_right_unsigned - -Numeric.is_greater_signed = op_is_greater_signed -Numeric.is_less_unsigned = op_is_less_unsigned -Numeric.is_greater_unsigned = op_is_greater_unsigned - -Numeric.__add = op_add -Numeric.__sub = op_subtract -Numeric.__mul = op_multiply -Numeric.__div = op_divide_signed - -Numeric.__unm = op_negate - -Numeric.__eq = op_is_equal -Numeric.__lt = op_is_less_signed - -function Numeric.__le(lhs, rhs) - return op_is_less_signed(lhs, rhs) or op_is_equal(lhs, rhs) -end - -function Numeric.__tostring(value) - return tostring(to_f64(value)) -end +Numeric.K_ZERO = K_ZERO +Numeric.K_ONE = K_ONE return Numeric