cambridge/libs/bigint/bigint.lua

567 lines
17 KiB
Lua

#!/usr/bin/env lua
-- Set `strict` to true to have every operation validate its bigint operands
-- (see bigint.check). The extra checking costs some speed, but malformed
-- bigints are reported as soon as they are used.
local strict = true
--------------------------------------------------------------------------------
-- Module table. Calling it like a function, e.g. bigint(42), is shorthand for
-- bigint.new(42).
local bigint = {}
local bigint_module_mt = {}
function bigint_module_mt.__call(_, arg)
    return bigint.new(arg)
end
setmetatable(bigint, bigint_module_mt)
-- Metatable shared by every bigint. Each metamethod forwards to the matching
-- bigint.* function, looked up when the operator is used, so those functions
-- can be defined further down the file.
local mt = {
    __add = function(lhs, rhs) return bigint.add(lhs, rhs) end,
    __sub = function(lhs, rhs) return bigint.subtract(lhs, rhs) end,
    __mul = function(lhs, rhs) return bigint.multiply(lhs, rhs) end,
    __div = function(lhs, rhs) return bigint.divide(lhs, rhs) end,
    __mod = function(lhs, rhs) return bigint.modulus(lhs, rhs) end,
    __pow = function(lhs, rhs) return bigint.exponentiate(lhs, rhs) end,
    __unm = function(arg) return bigint.negate(arg) end,
    __tostring = function(arg) return bigint.unserialize(arg, "s") end,
    __eq = function(lhs, rhs) return bigint.compare(lhs, rhs, "==") end,
    __lt = function(lhs, rhs) return bigint.compare(lhs, rhs, "<") end,
    __le = function(lhs, rhs) return bigint.compare(lhs, rhs, "<=") end
}
local named_powers = require("libs.bigint.named-powers-of-ten")
-- Render a Lua number as a plain decimal integer string with no exponent and
-- no fractional part, so bigint.new can read its digits. tostring is unsafe
-- here for two reasons:
--   * Under Lua 5.1/LuaJIT ("%.14g"), large values use exponent notation
--     ("1e+15" would be read as the digits 1, 1, 5).
--   * In Lua 5.3+, floats gain ".0" ("10.0" would be read as 100).
-- Non-integral values are truncated toward zero. nan/inf go through tostring
-- and so contribute no digits, as before.
local function number_to_integer_string(num)
    if math.type and math.type(num) == "integer" then
        -- Lua 5.3+ integers print exactly, with no exponent and no ".0"
        return tostring(num)
    end
    if num ~= num or num == math.huge or num == -math.huge then
        return tostring(num)
    end
    if num >= 0 then
        num = math.floor(num)
    else
        num = math.ceil(num)
    end
    return string.format("%.0f", num)
end
-- Create a new bigint from a number, a string, or another bigint.
-- Input handling:
--   * Strings, and bigints via their tostring form, contribute every decimal
--     digit character they contain, in order.
--   * A leading "-" makes the result negative.
--   * Numbers are first rendered by number_to_integer_string.
-- With no argument, returns a positive bigint with no digits, which
-- bigint.check rejects until digits are added.
-- A zero result is always given the "+" sign.
function bigint.new(num)
    local self = {
        sign = "+",
        digits = {}
    }
    -- Return a new bigint with the same sign and digits
    function self:clone()
        local newint = bigint.new()
        newint.sign = self.sign
        -- Numeric loop: pairs does not guarantee ascending key order
        for i = 1, #self.digits do
            newint.digits[i] = self.digits[i]
        end
        return newint
    end
    setmetatable(self, mt)
    if (num) then
        local num_string
        if type(num) == "number" then
            num_string = number_to_integer_string(num)
        else
            num_string = tostring(num)
        end
        for digit in string.gmatch(num_string, "[0-9]") do
            self.digits[#self.digits + 1] = tonumber(digit)
        end
        if string.sub(num_string, 1, 1) == "-" then
            self.sign = "-"
        end
    end
    bigint.strip(self)
    -- Never produce a negative zero (e.g. from "-0" or -0.4)
    if #self.digits == 1 and self.digits[1] == 0 then
        self.sign = "+"
    end
    return self
end
-- Check that `big` is a well-formed bigint. It must:
--   * use this module's metatable;
--   * hold at least one digit;
--   * have sign "+" or "-";
--   * have every digit be an integer number from 0 to 9.
-- Only runs when the file-level local `strict` is true, unless `force` is
-- truthy. Raises an error on the first violation; otherwise returns true.
function bigint.check(big, force)
    if (strict or force) then
        assert(getmetatable(big) == mt, "at least one arg is not a bigint")
        assert(#big.digits > 0, "bigint is empty")
        assert(big.sign == "+" or big.sign == "-", "bigint is unsigned")
        for _, digit in pairs(big.digits) do
            -- Messages are built only on failure. Passing them to assert
            -- concatenated a string for every digit of every operand on every
            -- call, even for valid input.
            if type(digit) ~= "number" then
                error("at least one digit is invalid")
            end
            if not (digit <= 9 and digit >= 0) then
                error(digit .. " is not between 0 and 9")
            end
            if math.floor(digit) ~= digit then
                error(digit .. " is not an integer")
            end
        end
    end
    return true
end
-- Remove leading zeroes from a big's digits in place and return the same big.
-- At least one digit is always kept, so zero remains the single digit 0.
function bigint.strip(big)
    local digits = big.digits
    local count = #digits
    -- Find the first digit to keep (never past the last one)
    local first = 1
    while first < count and digits[first] == 0 do
        first = first + 1
    end
    if first > 1 then
        local shift = first - 1
        -- Slide the kept digits to the front, then clear the vacated tail
        for i = 1, count - shift do
            digits[i] = digits[i + shift]
        end
        for i = count, count - shift + 1, -1 do
            digits[i] = nil
        end
    end
    return big
end
-- Absolute value: return a new big with the same digits and a "+" sign.
-- The argument itself is not modified.
function bigint.abs(big)
    bigint.check(big)
    local magnitude = big:clone()
    magnitude.sign = "+"
    return magnitude
end
-- Negation: return a new big with the same digits and the opposite sign
-- ("+" becomes "-"; any other sign becomes "+"). The argument is not modified.
function bigint.negate(big)
    bigint.check(big)
    local flipped = big:clone()
    flipped.sign = (big.sign == "+") and "-" or "+"
    return flipped
end
-- Count the decimal digits stored in a big (after validating it)
function bigint.digits(big)
    bigint.check(big)
    local count = #big.digits
    return count
end
-- Convert a big to a number or string.
-- Output types:
--   * nil, "number" or "n": a Lua number, the tonumber of the full digit
--     string (very large values lose precision).
--   * "string" or "s": the exact decimal string, e.g. "-12345".
--   * "human-readable", "human" or "h": the leading digits followed by a name
--     looked up in named_powers, for values with 3 to 10002 digits. Otherwise,
--     or when no name is found, the scientific form below is used.
--   * any other value: scientific form, e.g. "1.23*10^4".
-- precision (used only by the last two forms) is how many leading digits to
-- show. It defaults to min(#digits, 3) and must be a positive integer.
-- Requested digits past the end of the number are shown as 0.
-- A zero value is never rendered with a minus sign.
function bigint.unserialize(big, output_type, precision)
    bigint.check(big)
    local digits = big.digits
    local prefix = ""
    if big.sign == "-" and not (#digits == 1 and digits[1] == 0) then
        prefix = "-"
    end
    local as_number = (output_type == nil)
        or (output_type == "number")
        or (output_type == "n")
    if as_number or (output_type == "string") or (output_type == "s") then
        -- Reconstructing the whole number: collect the pieces and join them
        -- once, since repeated `..` is quadratic in the digit count
        local parts = { prefix }
        for i = 1, #digits do
            -- math.floor drops the ".0" a float digit would print with
            parts[i + 1] = math.floor(digits[i])
        end
        local num = table.concat(parts)
        if as_number then
            return tonumber(num)
        end
        return num
    end
    -- Human-readable and scientific forms only need the leading digits
    if (precision == nil) then
        precision = math.min(#digits, 3)
    else
        assert(precision > 0, "Precision cannot be less than 1")
        assert(math.floor(precision) == precision,
            "Precision must be a positive integer")
    end
    -- The first `precision` digits: the leading one, then the rest after a
    -- decimal point
    local parts = { prefix, math.floor(digits[1]) }
    if (precision > 1) then
        parts[#parts + 1] = "."
        for i = 2, precision do
            -- Pad with 0 past the last digit instead of indexing nil
            parts[#parts + 1] = math.floor(digits[i] or 0)
        end
    end
    local num = table.concat(parts)
    if ((output_type == "human-readable")
        or (output_type == "human")
        or (output_type == "h"))
        and (#digits >= 3 and #digits <= 10002) then
        -- Human-readable output contributed by 123eee555
        -- Check the powers of ten #digits-1 down to #digits-4 and use the
        -- first one named in named_powers. One step down is written as
        -- "ten <name>", two steps down as "hundred <name>".
        local name
        local walkback = 0
        for i = (#digits - 1), (#digits - 4), -1 do
            name = named_powers[i]
            if (name) then
                if (walkback == 1) then
                    name = "ten " .. name
                elseif (walkback == 2) then
                    name = "hundred " .. name
                end
                break
            end
            walkback = walkback + 1
        end
        -- No name within reach: fall back to the scientific form instead of
        -- failing on a nil concatenation
        if (name) then
            return num .. " " .. name
        end
    end
    return num .. "*10^" .. (#digits - 1)
end
-- Basic comparisons
-- True when every digit of the big is 0, whatever its sign
local function is_zero_magnitude(big)
    for i = 1, #big.digits do
        if big.digits[i] ~= 0 then
            return false
        end
    end
    return true
end
-- Three-way comparison of two digit sequences as unsigned magnitudes:
-- returns 1, 0 or -1. A longer sequence counts as larger, which relies on
-- the digits having no leading zeroes.
local function compare_magnitudes(digits1, digits2)
    if #digits1 ~= #digits2 then
        if #digits1 > #digits2 then
            return 1
        end
        return -1
    end
    for i = 1, #digits1 do
        local d1, d2 = digits1[i], digits2[i]
        if d1 > d2 then
            return 1
        elseif d1 < d2 then
            return -1
        end
    end
    return 0
end
-- Compare big1 with big2.
-- Accepts symbols (<, >, ==, >=, <=, ~=, !=) and Unix shell-like options
-- (lt, gt, eq, ge, le, ne). Returns true or false; any other comparison
-- string yields false.
-- Zero equals zero whatever the signs. bigint.add can yield a zero signed
-- "-" (e.g. 5 + -5), and it must still compare equal to bigint(0).
function bigint.compare(big1, big2, comparison)
    bigint.check(big1)
    bigint.check(big2)
    local order
    if is_zero_magnitude(big1) and is_zero_magnitude(big2) then
        order = 0
    elseif big1.sign ~= big2.sign then
        -- Mixed signs: the positive operand is the larger one
        order = (big1.sign == "+") and 1 or -1
    else
        order = compare_magnitudes(big1.digits, big2.digits)
        -- For two negatives, the larger magnitude is the smaller value
        if big1.sign == "-" then
            order = -order
        end
    end
    if (comparison == "<") or (comparison == "lt") then
        return order < 0
    elseif (comparison == ">") or (comparison == "gt") then
        return order > 0
    elseif (comparison == "==") or (comparison == "eq") then
        return order == 0
    elseif (comparison == ">=") or (comparison == "ge") then
        return order >= 0
    elseif (comparison == "<=") or (comparison == "le") then
        return order <= 0
    elseif (comparison == "~=") or (comparison == "!=")
        or (comparison == "ne") then
        return order ~= 0
    end
    return false
end
-- BACKEND: Add the magnitudes of big1 and big2, ignoring their signs.
-- Returns a new positive big.
function bigint.add_raw(big1, big2)
    bigint.check(big1)
    bigint.check(big2)
    local left, right = big1.digits, big2.digits
    local width = math.max(#left, #right)
    local total = bigint.new()
    local carry = 0
    -- Column sums from the least significant digit upward. A digit missing
    -- from the shorter operand counts as 0.
    for column = width, 1, -1 do
        local offset = width - column
        local value = (left[#left - offset] or 0)
            + (right[#right - offset] or 0)
            + carry
        if value >= 10 then
            carry, value = 1, value - 10
        else
            carry = 0
        end
        total.digits[column] = value
    end
    -- A carry out of the top column becomes a new leading 1 (e.g. 7 + 9)
    if carry == 1 then
        table.insert(total.digits, 1, 1)
    end
    return total
end
-- BACKEND: Subtract the magnitude of big2 from the magnitude of big1.
-- Requires |big1| >= |big2|. Returns a copy of big1 (so it keeps big1's
-- sign) holding the difference, with leading zeroes stripped.
function bigint.subtract_raw(big1, big2)
    -- Type checking is done by bigint.abs / bigint.compare
    if not bigint.compare(bigint.abs(big1), bigint.abs(big2), ">=") then
        -- Built only on failure: passing this message to assert serialized
        -- both operands on every subtraction, even successful ones
        error("Size of " .. bigint.unserialize(big1, "string")
            .. " is less than " .. bigint.unserialize(big2, "string"))
    end
    local result = big1:clone()
    local top, bottom = big1.digits, big2.digits
    local borrow = 0
    -- Walk right to left, like in long subtraction
    for offset = 0, #top - 1 do
        local diff = (top[#top - offset] or 0)
            - (bottom[#bottom - offset] or 0)
            - borrow
        if (diff < 0) then
            borrow = 1
            diff = diff + 10
        else
            borrow = 0
        end
        result.digits[#top - offset] = diff
    end
    return bigint.strip(result)
end
-- FRONTEND: Add two bigs, accounting for signs.
-- Same signs: add the magnitudes and keep the shared sign.
-- Different signs: subtract the smaller magnitude from the larger and take
-- the sign of the operand with the larger magnitude. When the magnitudes are
-- equal, the zero result carries big2's sign, so it can be "-".
-- Returns nil if the shared sign is neither "+" nor "-".
function bigint.add(big1, big2)
    -- Operand validation happens inside bigint.compare / bigint.add_raw
    if big1.sign == big2.sign then
        local sum
        if big1.sign == "+" or big1.sign == "-" then
            sum = bigint.add_raw(big1, big2)
            sum.sign = big1.sign
        end
        return sum
    end
    local bigger, smaller = big2, big1
    if bigint.compare(bigint.abs(big1), bigint.abs(big2), ">") then
        bigger, smaller = big1, big2
    end
    local difference = bigint.subtract_raw(bigger, smaller)
    difference.sign = bigger.sign
    return difference
end
-- FRONTEND: Subtract big2 from big1, accounting for signs.
-- Computed as big1 + (-big2) on a sign-flipped copy; big2 is not modified.
function bigint.subtract(big1, big2)
    local negated = big2:clone()
    negated.sign = (big2.sign == "+") and "-" or "+"
    return bigint.add(big1, negated)
end
-- BACKEND: Multiply a big by a single-digit big, ignoring signs.
-- Returns a new positive big with leading zeroes stripped (multiplying by 0
-- gives the single digit 0 rather than a run of zeroes).
function bigint.multiply_single(big1, big2)
    bigint.check(big1)
    bigint.check(big2)
    if #big2.digits ~= 1 then
        -- Message built only on failure instead of on every call
        error(bigint.unserialize(big2, "string") .. " has more than one digit")
    end
    local result = bigint.new()
    local multiplier = big2.digits[1]
    local count = #big1.digits
    local carry = 0
    -- Walk right to left, like in long multiplication. Each product is at
    -- most 9 * 9 + 8 = 89, so the carry always stays a single digit.
    for offset = 0, count - 1 do
        local product = big1.digits[count - offset] * multiplier + carry
        carry = math.floor(product / 10)
        result.digits[count - offset] = product - carry * 10
    end
    -- Carry out of the most significant digit becomes a new leading digit
    if (carry > 0) then
        table.insert(result.digits, 1, carry)
    end
    return bigint.strip(result)
end
-- FRONTEND: Multiply two bigs, accounting for signs.
-- Long multiplication, for each digit of the operand with fewer digits
-- (right to left):
--   1. multiply the other operand by that digit;
--   2. shift the partial product left with placeholder zeroes;
--   3. add it to the running total.
-- The result is negative when exactly one operand is negative. A zero
-- result is always "+".
function bigint.multiply(big1, big2)
    bigint.check(big1)
    bigint.check(big2)
    -- Zero test on the digits directly. The previous test serialized each
    -- operand (quadratic string build plus tonumber) on every call.
    local function is_zero(big)
        for i = 1, #big.digits do
            if big.digits[i] ~= 0 then
                return false
            end
        end
        return true
    end
    if is_zero(big1) or is_zero(big2) then
        return bigint.new(0)
    end
    -- Larger and smaller in terms of digits, not size
    local larger, smaller = big1, big2
    if #big2.digits > #big1.digits then
        larger, smaller = big2, big1
    end
    local result = bigint.new(0)
    local count = #smaller.digits
    for offset = 0, count - 1 do
        -- One-digit multiplier built directly rather than via bigint.new,
        -- which would round-trip the digit through tostring
        local multiplier = bigint.new()
        multiplier.digits[1] = smaller.digits[count - offset]
        local partial = bigint.multiply_single(larger, multiplier)
        -- Placeholder zeroes shift the partial product to this digit's column
        for _ = 1, offset do
            partial.digits[#partial.digits + 1] = 0
        end
        result = bigint.add(result, partial)
    end
    result.sign = (larger.sign == smaller.sign) and "+" or "-"
    return bigint.strip(result)
end
-- Raise a big to a non-negative power, which must itself be a bigint, by
-- binary exponentiation (square-and-multiply).
-- (TODO: negative integer power)
function bigint.exponentiate(big, power)
    -- Type checking for power done by bigint.compare, for big by
    -- bigint.multiply
    assert(bigint.compare(power, bigint.new(0), ">="),
        "negative powers are not supported")
    local one = bigint.new(1)
    if (bigint.compare(power, bigint.new(0), "==")) then
        return bigint.new(1)
    elseif (bigint.compare(power, one, "==")) then
        return big:clone()
    end
    local result = bigint.new(1)
    local base = big:clone()
    local exp = power:clone()
    local two = bigint.new(2)
    while (true) do
        -- exp is odd exactly when its last decimal digit is odd. Reading
        -- that digit replaces a full long division (bigint.modulus) per step.
        if (exp.digits[#exp.digits] % 2 == 1) then
            result = bigint.multiply(result, base)
        end
        if (bigint.compare(exp, one, "==")) then
            break
        end
        exp = bigint.divide(exp, two)
        base = bigint.multiply(base, base)
    end
    return result
end
-- BACKEND: Divide two bigs (decimals not supported), returning the big
-- quotient and the big remainder.
-- WARNING: Only supports non-negative integers; both arguments must carry
-- the "+" sign, and big2 must not be zero.
-- Schoolbook long division:
--   1. Bring down one digit of big1 at a time into a running remainder.
--   2. Subtract big2 from it as many times as it fits; that count is the
--      next quotient digit.
-- The count is at most 9, because the remainder was below big2 before the
-- digit was brought down.
function bigint.divide_raw(big1, big2)
    bigint.check(big1)
    bigint.check(big2)
    assert(big1.sign == "+", "error: big1 is not positive")
    assert(big2.sign == "+", "error: big2 is not positive")
    -- Checked before the shortcuts below, so 0 / 0 is rejected too
    assert(bigint.compare(big2, bigint.new(0), "!="), "error: divide by zero")
    if (bigint.compare(big1, big2, "==")) then
        return bigint.new(1), bigint.new(0)
    elseif (bigint.compare(big1, big2, "<")) then
        return bigint.new(0), big1:clone()
    end
    local quotient = bigint.new()
    local remainder = bigint.new(0)
    for i = 1, #big1.digits do
        -- remainder = remainder * 10 + next digit, never with a leading zero
        if #remainder.digits == 1 and remainder.digits[1] == 0 then
            remainder.digits[1] = big1.digits[i]
        else
            remainder.digits[#remainder.digits + 1] = big1.digits[i]
        end
        local factor = 0
        while bigint.compare(remainder, big2, ">=") do
            -- subtract_raw returns a stripped copy carrying remainder's "+"
            -- sign, so the remainder never turns negative or zero-padded
            remainder = bigint.subtract_raw(remainder, big2)
            factor = factor + 1
        end
        quotient.digits[#quotient.digits + 1] = factor
    end
    return bigint.strip(quotient), remainder
end
-- FRONTEND: Divide two bigs (decimals not supported), returning the big
-- quotient and the big remainder.
-- The quotient is truncated toward zero. It is negative when exactly one
-- operand is negative and is never a "-0".
-- The remainder is returned as a non-negative magnitude; bigint.modulus
-- applies the dividend's sign to it.
function bigint.divide(big1, big2)
    local result, remainder = bigint.divide_raw(bigint.abs(big1),
        bigint.abs(big2))
    -- After stripping, a value is zero exactly when its first digit is 0
    bigint.strip(result)
    bigint.strip(remainder)
    if (big1.sign == big2.sign) or (result.digits[1] == 0) then
        result.sign = "+"
    else
        result.sign = "-"
    end
    -- The remainder is a magnitude: force "+" so a zero remainder is never
    -- signed "-" (which would not compare equal to bigint(0))
    remainder.sign = "+"
    return result, remainder
end
-- FRONTEND: Return only the remainder from bigint.divide.
-- A nonzero remainder has the same sign as the dividend, per the C standard:
-- https://en.wikipedia.org/wiki/Modulo_operation#Remainder_calculation_for_the_modulo_operation
-- A zero remainder is always "+", so that e.g. (-4) % 2 == bigint(0).
function bigint.modulus(big1, big2)
    local _, remainder = bigint.divide(big1, big2)
    -- After stripping, the remainder is zero exactly when its first digit is 0
    bigint.strip(remainder)
    if remainder.digits[1] == 0 then
        remainder.sign = "+"
    else
        remainder.sign = big1.sign
    end
    return remainder
end
-- Export the module table (bigint.new and friends; also callable as bigint(x))
return bigint