Wasynth/codegen/luau/runtime/numeric_v3.lua
2024-03-04 16:01:03 +00:00

422 lines
9.5 KiB
Lua

-- Module table that collects the 64-bit integer helpers defined below.
local Numeric = {}
-- The all-zero vector is used as the 64-bit value 0.
local NUM_ZERO = Vector3.zero
-- Constants declared here and assigned at the bottom of the file, once the
-- constructors they are built with have been bound.
local NUM_ONE, NUM_SIX_FOUR
local NUM_BIT_26, NUM_BIT_52
-- Local aliases for the bit32 library functions used throughout the file.
local bit_lshift = bit32.lshift
local bit_rshift = bit32.rshift
local bit_arshift = bit32.arshift
local bit_and = bit32.band
local bit_or = bit32.bor
local bit_xor = bit32.bxor
local bit_not = bit32.bnot
local bit_extract = bit32.extract
local bit_replace = bit32.replace
-- Forward declarations: each name is bound to the matching Numeric.* function
-- at the bottom of the file, so functions defined earlier can call functions
-- defined later.
local from_u32, from_u64, into_u64
local num_subtract, num_divide_unsigned, num_negate
local num_or, num_shift_left, num_shift_right_unsigned
local num_is_less_unsigned
-- Packed layout (a = data_1, the low word; b = data_2, the high word):
-- X: a[0 ..21]
-- Y: b[0 ..21]
-- Z: a[22..31]
-- | b[22..31]
local constructor = Vector3.new
-- Packs two 32-bit words into one Vector3.
-- X and Y receive the low 22 bits of data_1 and data_2 respectively. Z holds
-- the upper 10 bits of data_1 in its bits 0..9 and the upper 10 bits of
-- data_2 in its bits 10..19.
function Numeric.from_u32(data_1, data_2)
	local low_1 = bit_and(data_1, 0x3FFFFF)
	local low_2 = bit_and(data_2, 0x3FFFFF)
	local upper_1 = bit_rshift(data_1, 22)
	local upper_2 = bit_rshift(data_2, 22)
	return constructor(low_1, low_2, bit_replace(upper_1, upper_2, 10, 10))
end
-- Reports whether `value` compares equal to the zero constant.
local function num_is_zero(value)
	return NUM_ZERO == value
end
-- Rebuilds the low 32-bit word of `value`: bits 0..21 come from X and
-- bits 22..31 come from the low 10 bits of Z.
local function load_d1(value)
	local low_bits = bit_and(value.X, 0x3FFFFF)
	return bit_replace(low_bits, value.Z, 22, 10)
end
-- Rebuilds the high 32-bit word of `value`: bits 0..21 come from Y and
-- bits 22..31 come from bits 10..19 of Z.
local function load_d2(value)
	local low_bits = bit_and(value.Y, 0x3FFFFF)
	local high_bits = bit_rshift(value.Z, 10)
	return bit_replace(low_bits, high_bits, 22, 10)
end
-- Unpacks `value` into its two 32-bit words, low word first.
local function into_u32(value)
	local data_1 = load_d1(value)
	local data_2 = load_d2(value)
	return data_1, data_2
end
Numeric.into_u32 = into_u32
-- Builds a packed value from a plain number.
-- The low word is `value` modulo 2^32 and the high word is `value` divided
-- by 2^32. Each part goes through a single-argument bit_and, presumably to
-- coerce it to a 32-bit unsigned integer before packing.
function Numeric.from_u64(value)
	local data_1 = bit_and(value % 0x100000000)
	local data_2 = bit_and(value / 0x100000000)
	return from_u32(data_1, data_2)
end
-- Converts a packed value back into a single plain number by recombining
-- its high and low 32-bit words.
function Numeric.into_u64(value)
	local data_1, data_2 = into_u32(value)
	return data_2 * 0x100000000 + data_1
end
-- Adds two packed values word by word, wrapping around at 64 bits.
function Numeric.add(lhs, rhs)
	local lhs_1, lhs_2 = into_u32(lhs)
	local rhs_1, rhs_2 = into_u32(rhs)
	local low_sum = lhs_1 + rhs_1
	local high_sum = lhs_2 + rhs_2
	-- Carry out of the low word into the high word.
	if 0x100000000 <= low_sum then
		low_sum, high_sum = low_sum - 0x100000000, high_sum + 1
	end
	-- A carry out of the high word is discarded.
	if 0x100000000 <= high_sum then
		high_sum = high_sum - 0x100000000
	end
	return from_u32(low_sum, high_sum)
end
-- Subtracts rhs from lhs word by word, wrapping around at 64 bits.
function Numeric.subtract(lhs, rhs)
	local lhs_1, lhs_2 = into_u32(lhs)
	local rhs_1, rhs_2 = into_u32(rhs)
	local low_diff = lhs_1 - rhs_1
	local high_diff = lhs_2 - rhs_2
	-- Borrow from the high word when the low word underflows.
	if 0 > low_diff then
		low_diff, high_diff = low_diff + 0x100000000, high_diff - 1
	end
	-- An underflow of the high word wraps around.
	if 0 > high_diff then
		high_diff = high_diff + 0x100000000
	end
	return from_u32(low_diff, high_diff)
end
-- Multiplies two packed values and returns the low 64 bits of the product.
function Numeric.multiply(lhs, rhs)
if num_is_zero(lhs) or num_is_zero(rhs) then
-- Anything times zero is zero.
return NUM_ZERO
elseif num_is_less_unsigned(lhs, NUM_BIT_26) and num_is_less_unsigned(rhs, NUM_BIT_26) then
-- Fast path: both operands are below NUM_BIT_26 (2^26), so the product of
-- their low words is below 2^52 and is computed as a plain number.
return from_u64(load_d1(lhs) * load_d1(rhs))
end
-- Divide each long into 4 chunks of 16 bits, and then add up 4x4 products.
-- We can skip products that would overflow.
local lhs_1, lhs_2 = into_u32(lhs)
local rhs_1, rhs_2 = into_u32(rhs)
-- aNN / bNN is the 16-bit chunk of lhs / rhs that starts at bit NN.
local a48 = bit_rshift(lhs_2, 16)
local a32 = bit_and(lhs_2, 0xFFFF)
local a16 = bit_rshift(lhs_1, 16)
local a00 = bit_and(lhs_1, 0xFFFF)
local b48 = bit_rshift(rhs_2, 16)
local b32 = bit_and(rhs_2, 0xFFFF)
local b16 = bit_rshift(rhs_1, 16)
local b00 = bit_and(rhs_1, 0xFFFF)
-- cNN accumulates the result chunk that starts at bit NN. After each partial
-- product is added, the bits above the low 16 are carried into the next chunk
-- and the chunk is masked back down to 16 bits.
local c00 = a00 * b00
local c16 = bit_rshift(c00, 16)
c00 = bit_and(c00, 0xFFFF)
c16 = c16 + a16 * b00
local c32 = bit_rshift(c16, 16)
c16 = bit_and(c16, 0xFFFF)
c16 = c16 + a00 * b16
c32 = c32 + bit_rshift(c16, 16)
c16 = bit_and(c16, 0xFFFF)
c32 = c32 + a32 * b00
local c48 = bit_rshift(c32, 16)
c32 = bit_and(c32, 0xFFFF)
c32 = c32 + a16 * b16
c48 = c48 + bit_rshift(c32, 16)
c32 = bit_and(c32, 0xFFFF)
c32 = c32 + a00 * b32
c48 = c48 + bit_rshift(c32, 16)
c32 = bit_and(c32, 0xFFFF)
-- The top chunk takes every remaining product that lands at bit 48; only its
-- low 16 bits are kept, so anything above bit 63 is dropped.
c48 = c48 + a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48
c48 = bit_and(c48, 0xFFFF)
-- Reassemble the four 16-bit chunks into two 32-bit words.
local data_1 = bit_replace(c00, c16, 16, 16)
local data_2 = bit_replace(c32, c48, 16, 16)
return from_u32(data_1, data_2)
end
-- Unsigned 64-bit division.
-- Returns the quotient followed by the remainder. Raises "division by zero"
-- when rhs is zero.
function Numeric.divide_unsigned(lhs, rhs)
	if num_is_zero(rhs) then
		error("division by zero")
	end
	if num_is_zero(lhs) then
		return NUM_ZERO, NUM_ZERO
	end
	if num_is_less_unsigned(lhs, NUM_BIT_52) and num_is_less_unsigned(rhs, NUM_BIT_52) then
		-- Both operands are below NUM_BIT_52, so plain number arithmetic is used.
		local dividend = into_u64(lhs)
		local divisor = into_u64(rhs)
		return from_u64(dividend / divisor), from_u64(dividend % divisor)
	end
	-- Bit-by-bit long division, walking the dividend from bit 63 down to bit 0.
	local lhs_low, lhs_high = into_u32(lhs)
	local quotient, remainder = NUM_ZERO, NUM_ZERO
	for index = 63, 0, -1 do
		-- Shift the running remainder left and bring in the next dividend bit.
		local rem_low, rem_high = into_u32(num_shift_left(remainder, NUM_ONE))
		local next_bit
		if index > 31 then
			next_bit = bit_extract(lhs_high, index - 32, 1)
		else
			next_bit = bit_extract(lhs_low, index, 1)
		end
		remainder = from_u32(bit_or(rem_low, next_bit), rem_high)
		-- When the divisor fits, subtract it and set this quotient bit.
		if not num_is_less_unsigned(remainder, rhs) then
			remainder = num_subtract(remainder, rhs)
			quotient = num_or(quotient, num_shift_left(NUM_ONE, from_u32(index, 0)))
		end
	end
	return quotient, remainder
end
-- Reports whether the sign bit is set. Bit 19 of Z holds bit 31 of the high
-- word, so Z is at least 0x80000 exactly when that bit is set.
local function num_is_negative(value)
	return 0x80000 <= value.Z
end
-- Signed 64-bit division built on the unsigned routine.
-- Returns the quotient followed by the remainder. The quotient is negated
-- when the operand signs differ, and the remainder takes the sign of lhs.
function Numeric.divide_signed(lhs, rhs)
	local lhs_negative = num_is_negative(lhs)
	local rhs_negative = num_is_negative(rhs)
	-- Divide the magnitudes.
	local dividend = lhs
	if lhs_negative then
		dividend = num_negate(lhs)
	end
	local divisor = rhs
	if rhs_negative then
		divisor = num_negate(rhs)
	end
	local quotient, remainder = num_divide_unsigned(dividend, divisor)
	-- Restore the signs.
	if lhs_negative ~= rhs_negative then
		quotient = num_negate(quotient)
	end
	if lhs_negative then
		remainder = num_negate(remainder)
	end
	return quotient, remainder
end
-- Two's-complement negation: complement both words, add one to the low word
-- and carry into the high word.
function Numeric.negate(value)
	local value_1, value_2 = into_u32(value)
	local low = bit_not(value_1) + 1
	local high = bit_not(value_2)
	-- Carry out of the low word into the high word.
	if 0x100000000 <= low then
		low, high = low - 0x100000000, high + 1
	end
	-- A carry out of the high word is discarded.
	if 0x100000000 <= high then
		high = high - 0x100000000
	end
	return from_u32(low, high)
end
-- Bitwise AND, applied to each packed component in turn.
function Numeric.bit_and(lhs, rhs)
	return constructor(
		bit_and(lhs.X, rhs.X),
		bit_and(lhs.Y, rhs.Y),
		bit_and(lhs.Z, rhs.Z)
	)
end
-- Bitwise complement of a packed 64-bit value.
-- Each component is complemented and then masked to the width it occupies in
-- the packed layout: 22 bits for X and Y, 20 bits for Z. Masking to exactly
-- those widths keeps the result in the same canonical form that from_u32
-- produces. That matters because equality and the zero check compare vectors
-- directly, and the sign check compares Z against 0x80000; stray bits above
-- the payload would break all three.
function Numeric.bit_not(value)
	local x = bit_and(bit_not(value.X), 0x3FFFFF)
	local y = bit_and(bit_not(value.Y), 0x3FFFFF)
	local z = bit_and(bit_not(value.Z), 0xFFFFF)
	return constructor(x, y, z)
end
-- Bitwise OR, applied to each packed component in turn.
function Numeric.bit_or(lhs, rhs)
	return constructor(
		bit_or(lhs.X, rhs.X),
		bit_or(lhs.Y, rhs.Y),
		bit_or(lhs.Z, rhs.Z)
	)
end
-- Bitwise XOR, applied to each packed component in turn.
function Numeric.bit_xor(lhs, rhs)
	return constructor(
		bit_xor(lhs.X, rhs.X),
		bit_xor(lhs.Y, rhs.Y),
		bit_xor(lhs.Z, rhs.Z)
	)
end
-- Shifts lhs left by the amount held in rhs, taken modulo 64.
function Numeric.shift_left(lhs, rhs)
	local amount = rhs.X % 64
	if amount == 0 then
		return lhs
	end
	if amount >= 32 then
		-- The low word moves entirely into the high word.
		return from_u32(0, bit_lshift(load_d1(lhs), amount - 32))
	end
	local lhs_1, lhs_2 = into_u32(lhs)
	local kept = 32 - amount
	local data_1 = bit_lshift(lhs_1, amount)
	-- The top `amount` bits of the low word land at the bottom of the high
	-- word; the low `kept` bits of the old high word sit above them.
	local spilled = bit_rshift(lhs_1, kept)
	local data_2 = bit_replace(spilled, lhs_2, amount, kept)
	return from_u32(data_1, data_2)
end
-- Logical right shift of lhs by the amount held in rhs, taken modulo 64.
-- Vacated high bits are filled with zeros.
function Numeric.shift_right_unsigned(lhs, rhs)
	local amount = rhs.X % 64
	if amount == 0 then
		return lhs
	end
	if amount >= 32 then
		-- The high word moves entirely into the low word.
		return from_u32(bit_rshift(load_d2(lhs), amount - 32), 0)
	end
	local lhs_1, lhs_2 = into_u32(lhs)
	-- The low `amount` bits of the high word fill the top of the low word.
	local shifted_low = bit_rshift(lhs_1, amount)
	local data_1 = bit_replace(shifted_low, lhs_2, 32 - amount, amount)
	local data_2 = bit_rshift(lhs_2, amount)
	return from_u32(data_1, data_2)
end
-- Arithmetic right shift of lhs by the amount held in rhs, taken modulo 64.
-- Vacated high bits are filled with copies of the sign bit.
function Numeric.shift_right_signed(lhs, rhs)
	local amount = rhs.X % 64
	if amount == 0 then
		return lhs
	end
	if amount >= 32 then
		-- The high word moves into the low word and the new high word is
		-- filled with the sign.
		local high = load_d2(lhs)
		local data_1 = bit_arshift(high, amount - 32)
		local data_2 = 0
		if high >= 0x80000000 then
			data_2 = 0xFFFFFFFF
		end
		return from_u32(data_1, data_2)
	end
	local lhs_1, lhs_2 = into_u32(lhs)
	-- The low `amount` bits of the high word fill the top of the low word.
	local shifted_low = bit_rshift(lhs_1, amount)
	local data_1 = bit_replace(shifted_low, lhs_2, 32 - amount, amount)
	local data_2 = bit_arshift(lhs_2, amount)
	return from_u32(data_1, data_2)
end
-- Rotates lhs left by rhs bits: a left shift by rhs combined with a logical
-- right shift by (64 - rhs).
function Numeric.rotate_left(lhs, rhs)
	if num_is_zero(rhs) then
		return lhs
	end
	local shifted = num_shift_left(lhs, rhs)
	local wrapped = num_shift_right_unsigned(lhs, num_subtract(NUM_SIX_FOUR, rhs))
	return num_or(shifted, wrapped)
end
-- Rotates lhs right by rhs bits: a logical right shift by rhs combined with
-- a left shift by (64 - rhs).
function Numeric.rotate_right(lhs, rhs)
	if num_is_zero(rhs) then
		return lhs
	end
	local shifted = num_shift_right_unsigned(lhs, rhs)
	local wrapped = num_shift_left(lhs, num_subtract(NUM_SIX_FOUR, rhs))
	return num_or(shifted, wrapped)
end
-- Reports whether the two packed values compare equal.
function Numeric.is_equal(lhs, rhs)
	local same = lhs == rhs
	return same
end
-- Unsigned "less than": the high words decide unless they are equal, in
-- which case the low words are compared.
function Numeric.is_less_unsigned(lhs, rhs)
	local high_l = load_d2(lhs)
	local high_r = load_d2(rhs)
	if high_l ~= high_r then
		return high_l < high_r
	end
	return load_d1(lhs) < load_d1(rhs)
end
-- Unsigned "greater than": the high words decide unless they are equal, in
-- which case the low words are compared.
function Numeric.is_greater_unsigned(lhs, rhs)
	local high_l = load_d2(lhs)
	local high_r = load_d2(rhs)
	if high_l ~= high_r then
		return high_l > high_r
	end
	return load_d1(lhs) > load_d1(rhs)
end
-- Signed "less than". When the signs differ, the negative operand is the
-- smaller one; otherwise the sign of (lhs - rhs) decides.
function Numeric.is_less_signed(lhs, rhs)
	local lhs_negative = num_is_negative(lhs)
	local rhs_negative = num_is_negative(rhs)
	if lhs_negative ~= rhs_negative then
		-- lhs is smaller exactly when it is the negative one.
		return lhs_negative
	end
	return num_is_negative(num_subtract(lhs, rhs))
end
-- Signed "greater than". When the signs differ, the non-negative operand is
-- the larger one; otherwise the sign of (rhs - lhs) decides.
function Numeric.is_greater_signed(lhs, rhs)
	local lhs_negative = num_is_negative(lhs)
	local rhs_negative = num_is_negative(rhs)
	if lhs_negative ~= rhs_negative then
		-- lhs is larger exactly when rhs is the negative one.
		return rhs_negative
	end
	return num_is_negative(num_subtract(rhs, lhs))
end
-- Bind the forward-declared locals to their implementations. This must come
-- before the constants below, which are built with from_u64.
from_u32 = Numeric.from_u32
from_u64 = Numeric.from_u64
into_u64 = Numeric.into_u64
num_subtract = Numeric.subtract
num_divide_unsigned = Numeric.divide_unsigned
num_negate = Numeric.negate
num_or = Numeric.bit_or
num_shift_left = Numeric.shift_left
num_shift_right_unsigned = Numeric.shift_right_unsigned
-- Export the local predicate helpers on the module table.
Numeric.is_negative = num_is_negative
Numeric.is_zero = num_is_zero
num_is_less_unsigned = Numeric.is_less_unsigned
-- Shared constants: 1, 64, 2^26 (fast-path bound in multiply) and 2^52
-- (fast-path bound in divide_unsigned).
NUM_ONE = from_u64(1)
NUM_SIX_FOUR = from_u64(64)
NUM_BIT_26 = from_u64(0x4000000)
NUM_BIT_52 = from_u64(0x10000000000000)
-- Publicly exposed constants.
Numeric.ZERO = NUM_ZERO
Numeric.ONE = NUM_ONE
-- The module table is frozen before it is returned.
return table.freeze(Numeric)