Use global scratch space for temporary objects, better errors

This commit is contained in:
2024-02-14 21:44:46 -06:00
parent bc948d1a61
commit 8b57e0c441
4 changed files with 149 additions and 102 deletions

View File

@@ -7,5 +7,4 @@ insert_final_newline = true
[{*.lua,*.rockspec,.luacov}] [{*.lua,*.rockspec,.luacov}]
indent_style = space indent_style = space
indent_size = 3 indent_size = 3
call_parentheses = nosingletable max_line_length = 105
max_line_length = 98

View File

@@ -117,15 +117,18 @@ const KiwiErr* new_error(const KiwiErr* base, const std::exception& ex) {
static const constexpr KiwiErr kKiwiErrUnhandledCxxException { static const constexpr KiwiErr kKiwiErrUnhandledCxxException {
KiwiErrUnknown, KiwiErrUnknown,
"An unhandled C++ exception occurred."}; "An unhandled C++ exception occurred."
};
static const constexpr KiwiErr kKiwiErrNullObjectArg0 { static const constexpr KiwiErr kKiwiErrNullObjectArg0 {
KiwiErrNullObject, KiwiErrNullObject,
"null object passed as argument #0 (self)"}; "null object passed as argument #0 (self)"
};
static const constexpr KiwiErr kKiwiErrNullObjectArg1 { static const constexpr KiwiErr kKiwiErrNullObjectArg1 {
KiwiErrNullObject, KiwiErrNullObject,
"null object passed as argument #1"}; "null object passed as argument #1"
};
template<typename F> template<typename F>
inline const KiwiErr* wrap_err(F&& f) { inline const KiwiErr* wrap_err(F&& f) {
@@ -134,42 +137,49 @@ inline const KiwiErr* wrap_err(F&& f) {
} catch (const UnsatisfiableConstraint& ex) { } catch (const UnsatisfiableConstraint& ex) {
static const constexpr KiwiErr err { static const constexpr KiwiErr err {
KiwiErrUnsatisfiableConstraint, KiwiErrUnsatisfiableConstraint,
"The constraint cannot be satisfied."}; "The constraint cannot be satisfied."
};
return &err; return &err;
} catch (const UnknownConstraint& ex) { } catch (const UnknownConstraint& ex) {
static const constexpr KiwiErr err { static const constexpr KiwiErr err {
KiwiErrUnknownConstraint, KiwiErrUnknownConstraint,
"The constraint has not been added to the solver."}; "The constraint has not been added to the solver."
};
return &err; return &err;
} catch (const DuplicateConstraint& ex) { } catch (const DuplicateConstraint& ex) {
static const constexpr KiwiErr err { static const constexpr KiwiErr err {
KiwiErrDuplicateConstraint, KiwiErrDuplicateConstraint,
"The constraint has already been added to the solver."}; "The constraint has already been added to the solver."
};
return &err; return &err;
} catch (const UnknownEditVariable& ex) { } catch (const UnknownEditVariable& ex) {
static const constexpr KiwiErr err { static const constexpr KiwiErr err {
KiwiErrUnknownEditVariable, KiwiErrUnknownEditVariable,
"The edit variable has not been added to the solver."}; "The edit variable has not been added to the solver."
};
return &err; return &err;
} catch (const DuplicateEditVariable& ex) { } catch (const DuplicateEditVariable& ex) {
static const constexpr KiwiErr err { static const constexpr KiwiErr err {
KiwiErrDuplicateEditVariable, KiwiErrDuplicateEditVariable,
"The edit variable has already been added to the solver."}; "The edit variable has already been added to the solver."
};
return &err; return &err;
} catch (const BadRequiredStrength& ex) { } catch (const BadRequiredStrength& ex) {
static const constexpr KiwiErr err { static const constexpr KiwiErr err {
KiwiErrBadRequiredStrength, KiwiErrBadRequiredStrength,
"A required strength cannot be used in this context."}; "A required strength cannot be used in this context."
};
return &err; return &err;
} catch (const InternalSolverError& ex) { } catch (const InternalSolverError& ex) {
static const constexpr KiwiErr base { static const constexpr KiwiErr base {
KiwiErrInternalSolverError, KiwiErrInternalSolverError,
"An internal solver error occurred."}; "An internal solver error occurred."
};
return new_error(&base, ex); return new_error(&base, ex);
} catch (std::bad_alloc&) { } catch (std::bad_alloc&) {
static const constexpr KiwiErr err {KiwiErrAlloc, "A memory allocation failed."}; static const constexpr KiwiErr err {KiwiErrAlloc, "A memory allocation failed."};

View File

@@ -1,10 +1,10 @@
#ifndef CKIWI_H_ #ifndef KIWI_CKIWI_H_
#define CKIWI_H_ #define KIWI_CKIWI_H_
#include <stddef.h>
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#else
#include <stdbool.h>
#endif #endif
#define KIWI_REF_ISNULL(ref) ((ref).impl_ == NULL) #define KIWI_REF_ISNULL(ref) ((ref).impl_ == NULL)
@@ -98,4 +98,4 @@ char* kiwi_solver_dumps(const KiwiSolver* sp);
// Local Variables: // Local Variables:
// mode: c++ // mode: c++
// End: // End:
#endif // CKIWI_H_ #endif // KIWI_CKIWI_H_

188
kiwi.lua
View File

@@ -95,7 +95,8 @@ void free(void *);
]]) ]])
local strformat = string.format local strformat = string.format
local ffi_gc, ffi_istype, ffi_new, ffi_string = ffi.gc, ffi.istype, ffi.new, ffi.string local ffi_copy, ffi_gc, ffi_istype, ffi_new, ffi_string =
ffi.copy, ffi.gc, ffi.istype, ffi.new, ffi.string
local concat = table.concat local concat = table.concat
local has_table_new, new_tab = pcall(require, "table.new") local has_table_new, new_tab = pcall(require, "table.new")
@@ -125,13 +126,18 @@ kiwi.ErrKind = ffi.typeof("enum KiwiErrKind") --[[@as kiwi.ErrKind]]
---| '"EQ"' # == (equal) ---| '"EQ"' # == (equal)
kiwi.RelOp = ffi.typeof("enum KiwiRelOp") kiwi.RelOp = ffi.typeof("enum KiwiRelOp")
kiwi.Strength = { kiwi.strength = {
REQUIRED = 1001001000.0, REQUIRED = 1001001000.0,
STRONG = 1000000.0, STRONG = 1000000.0,
MEDIUM = 1000.0, MEDIUM = 1000.0,
WEAK = 1.0, WEAK = 1.0,
} }
do
local function clamp(n)
return math.max(0, math.min(1000, n))
end
--- Create a custom constraint strength. --- Create a custom constraint strength.
---@param a number: Scale factor 1e6 ---@param a number: Scale factor 1e6
---@param b number: Scale factor 1e3 ---@param b number: Scale factor 1e3
@@ -139,13 +145,11 @@ kiwi.Strength = {
---@param w? number: Weight ---@param w? number: Weight
---@return number ---@return number
---@nodiscard ---@nodiscard
function kiwi.Strength.create(a, b, c, w) function kiwi.strength.create(a, b, c, w)
local function clamp(n)
return math.max(0, math.min(1000, n))
end
w = w or 1.0 w = w or 1.0
return clamp(a * w) * 1000000.0 + clamp(b * w) * 1000.0 + clamp(c * w) return clamp(a * w) * 1000000.0 + clamp(b * w) * 1000.0 + clamp(c * w)
end end
end
local Var = ffi.typeof("struct KiwiVarRefType") --[[@as kiwi.Var]] local Var = ffi.typeof("struct KiwiVarRefType") --[[@as kiwi.Var]]
kiwi.Var = Var kiwi.Var = Var
@@ -183,8 +187,8 @@ end
---@param var kiwi.Var ---@param var kiwi.Var
---@param coeff number? ---@param coeff number?
---@nodiscard ---@nodiscard
local function new_expr_one_temp(constant, var, coeff) local function new_expr_one(constant, var, coeff)
local ret = ffi_new(Expression, 1) --[[@as kiwi.Expression]] local ret = ffi_gc(ffi_new(Expression, 1), ckiwi.kiwi_expression_del_vars) --[[@as kiwi.Expression]]
local dt = ret.terms_[0] local dt = ret.terms_[0]
dt.var = ckiwi.kiwi_var_clone(var) dt.var = ckiwi.kiwi_var_clone(var)
dt.coefficient = coeff or 1.0 dt.coefficient = coeff or 1.0
@@ -193,14 +197,6 @@ local function new_expr_one_temp(constant, var, coeff)
return ret return ret
end end
---@param constant number
---@param var kiwi.Var
---@param coeff number?
---@nodiscard
local function new_expr_one(constant, var, coeff)
return ffi_gc(new_expr_one_temp(constant, var, coeff), ckiwi.kiwi_expression_del_vars) --[[@as kiwi.Expression]]
end
---@param constant number ---@param constant number
---@param var1 kiwi.Var ---@param var1 kiwi.Var
---@param var2 kiwi.Var ---@param var2 kiwi.Var
@@ -220,45 +216,78 @@ local function new_expr_pair(constant, var1, var2, coeff1, coeff2)
return ret return ret
end end
local Strength = kiwi.Strength local function typename(o)
if ffi.istype(Var, o) then
return "Var"
elseif ffi.istype(Term, o) then
return "Term"
elseif ffi.istype(Expression, o) then
return "Expression"
elseif ffi.istype(Constraint, o) then
return "Constraint"
else
return type(o)
end
end
local function op_error(a, b, op)
--stylua: ignore
-- level 3 works for arithmetic without TCO (no return), and for rel with TCO forced (explicit return)
error(strformat(
"invalid operand type for '%s' %.40s('%.99s') and %.40s('%.99s')",
op, typename(a), tostring(a), typename(b), tostring(b)), 3)
end
local Strength = kiwi.strength
local REQUIRED = Strength.REQUIRED local REQUIRED = Strength.REQUIRED
local OP_NAMES = {
LE = "<=",
GE = ">=",
EQ = "==",
}
local SIZEOF_TERM = ffi.sizeof(Term) --[[@as integer]]
local tmpexpr = ffi_new(Expression, 2) --[[@as kiwi.Expression]]
local tmpexpr_r = ffi_new(Expression, 1) --[[@as kiwi.Expression]]
local function toexpr(o, temp)
if ffi_istype(Expression, o) then
return o --[[@as kiwi.Expression]]
elseif type(o) == "number" then
temp.constant = o
temp.term_count = 0
return temp
end
temp.constant = 0
temp.term_count = 1
local t = temp.terms_[0]
if ffi_istype(Var, o) then
t.var = o --[[@as kiwi.Var]]
t.coefficient = 1.0
elseif ffi_istype(Term, o) then
ffi_copy(t, o, SIZEOF_TERM)
else
return nil
end
return temp
end
---@param lhs kiwi.Expression|kiwi.Term|kiwi.Var|number ---@param lhs kiwi.Expression|kiwi.Term|kiwi.Var|number
---@param rhs kiwi.Expression|kiwi.Term|kiwi.Var|number ---@param rhs kiwi.Expression|kiwi.Term|kiwi.Var|number
---@param op kiwi.RelOp ---@param op kiwi.RelOp
---@param strength? number ---@param strength? number
---@nodiscard ---@nodiscard
local function rel(lhs, rhs, op, strength) local function rel(lhs, rhs, op, strength)
local function to_expr(o) local el = toexpr(lhs, tmpexpr)
if ffi_istype(Expression, o) then local er = toexpr(rhs, tmpexpr_r)
return o --[[@as kiwi.Expression]] if el == nil or er == nil then
elseif type(o) == "number" then op_error(lhs, rhs, OP_NAMES[op])
if o == 0 then
return nil
end
local ret = ffi_new(Expression, 0) --[[@as kiwi.Expression]]
ret.constant = o
ret.term_count = 0
return ret
end
local var
local coeff = 1.0
if ffi_istype(Var, o) then
var = o --[[@as kiwi.Var]]
elseif ffi_istype(Term, o) then
var = o.var
coeff = o.coefficient
else
error("Expected Expression|Term|Var|number, got " .. type(o) .. " instead")
end
return new_expr_one_temp(0.0, var, coeff)
end end
return ffi_gc( return ffi_gc(ckiwi.kiwi_constraint_new(el, er, op, strength or REQUIRED), ckiwi.kiwi_constraint_del) --[[@as kiwi.Constraint]]
ckiwi.kiwi_constraint_new(to_expr(lhs), to_expr(rhs), op, strength or REQUIRED),
ckiwi.kiwi_constraint_del
) --[[@as kiwi.Constraint]]
end end
--- Define a constraint with expressions as `a <= b`. --- Define a constraint with expressions as `a <= b`.
@@ -353,11 +382,13 @@ do
elseif type(b) == "number" then elseif type(b) == "number" then
return Term(a, b) return Term(a, b)
end end
error("Invalid var *") op_error(a, b, "*")
end end
function Var_mt.__div(a, b) function Var_mt.__div(a, b)
assert(type(b) == "number", "Invalid var /") if type(b) ~= "number" then
op_error(a, b, "/")
end
return Term(a, 1.0 / b) return Term(a, 1.0 / b)
end end
@@ -379,7 +410,7 @@ do
elseif type(b) == "number" then elseif type(b) == "number" then
return new_expr_one(b, a) return new_expr_one(b, a)
end end
error("Invalid var +") op_error(a, b, "+")
end end
function Var_mt.__sub(a, b) function Var_mt.__sub(a, b)
@@ -426,10 +457,12 @@ do
local Term_mt = { __index = Term_cls } local Term_mt = { __index = Term_cls }
function Term_mt.__new(T, var, coefficient) local function term_gc(term)
return ffi_gc(ffi_new(T, ckiwi.kiwi_var_clone(var), coefficient or 1.0), function(term)
ckiwi.kiwi_var_del(term.var) ckiwi.kiwi_var_del(term.var)
end) end
function Term_mt.__new(T, var, coefficient)
return ffi_gc(ffi_new(T, ckiwi.kiwi_var_clone(var), coefficient or 1.0), term_gc)
end end
function Term_mt.__mul(a, b) function Term_mt.__mul(a, b)
@@ -438,11 +471,13 @@ do
elseif type(a) == "number" then elseif type(a) == "number" then
return Term(b.var, b.coefficient * a) return Term(b.var, b.coefficient * a)
end end
error("Invalid term *") op_error(a, b, "*")
end end
function Term_mt.__div(a, b) function Term_mt.__div(a, b)
assert(type(b) == "number", "Invalid term /") if type(b) ~= "number" then
op_error(a, b, "/")
end
return Term(a.var, a.coefficient / b) return Term(a.var, a.coefficient / b)
end end
@@ -464,7 +499,7 @@ do
elseif type(b) == "number" then elseif type(b) == "number" then
return new_expr_one(b, a.var, a.coefficient) return new_expr_one(b, a.var, a.coefficient)
end end
error("Invalid term + op") op_error(a, b, "+")
end end
function Term_mt.__sub(a, b) function Term_mt.__sub(a, b)
@@ -603,11 +638,13 @@ do
elseif type(b) == "number" then elseif type(b) == "number" then
return mul_expr_constant(a, b) return mul_expr_constant(a, b)
end end
error("Invalid expr *") op_error(a, b, "*")
end end
function Expression_mt.__div(a, b) function Expression_mt.__div(a, b)
assert(type(b) == "number", "Invalid expr /") if type(b) ~= "number" then
op_error(a, b, "/")
end
return mul_expr_constant(a, 1.0 / b) return mul_expr_constant(a, 1.0 / b)
end end
@@ -629,7 +666,7 @@ do
elseif type(b) == "number" then elseif type(b) == "number" then
return new_expr_constant(a, a.constant + b) return new_expr_constant(a, a.constant + b)
end end
error("Invalid expr +") op_error(a, b, "+")
end end
function Expression_mt.__sub(a, b) function Expression_mt.__sub(a, b)
@@ -639,7 +676,8 @@ do
function Expression_mt:__tostring() function Expression_mt:__tostring()
local tab = new_tab(self.term_count + 1, 0) local tab = new_tab(self.term_count + 1, 0)
for i = 0, self.term_count - 1 do for i = 0, self.term_count - 1 do
tab[i + 1] = tostring(self.terms_[i]) local t = self.terms_[i]
tab[i + 1] = tostring(t.coefficient) .. " " .. t.var:name()
end end
tab[self.term_count + 1] = self.constant tab[self.term_count + 1] = self.constant
return concat(tab, " + ") return concat(tab, " + ")
@@ -693,17 +731,18 @@ do
) )
end end
function Constraint_mt:__tostring() local OPS = { [0] = "<=", ">=", "==" }
local ops = { [0] = "<=", ">=", "==" } local STRENGTH_NAMES = {
local strengths = {
[Strength.REQUIRED] = "required", [Strength.REQUIRED] = "required",
[Strength.STRONG] = "strong", [Strength.STRONG] = "strong",
[Strength.MEDIUM] = "medium", [Strength.MEDIUM] = "medium",
[Strength.WEAK] = "weak", [Strength.WEAK] = "weak",
} }
function Constraint_mt:__tostring()
local strength = self:strength() local strength = self:strength()
local strength_str = strengths[strength] or tostring(strength) local strength_str = STRENGTH_NAMES[strength] or tostring(strength)
local op = ops[tonumber(self:op())] local op = OPS[tonumber(self:op())]
return strformat("%s %s 0 | %s", tostring(self:expression()), op, strength_str) return strformat("%s %s 0 | %s", tostring(self:expression()), op, strength_str)
end end
@@ -726,18 +765,17 @@ do
---@nodiscard ---@nodiscard
function constraints.pair_ratio(left, coeff, right, constant, op, strength) function constraints.pair_ratio(left, coeff, right, constant, op, strength)
assert(ffi_istype(Var, left) and ffi_istype(Var, right)) assert(ffi_istype(Var, left) and ffi_istype(Var, right))
local lhs = ffi_new(Expression, 2) --[[@as kiwi.Expression]] local dt = tmpexpr.terms_[0]
local dt = lhs.terms_[0]
dt.var = left dt.var = left
dt.coefficient = 1.0 dt.coefficient = 1.0
dt = lhs.terms_[1] dt = tmpexpr.terms_[1]
dt.var = right dt.var = right
dt.coefficient = -coeff dt.coefficient = -coeff
lhs.constant = -(constant or 0.0) tmpexpr.constant = constant ~= nil and constant or 0
lhs.term_count = 2 tmpexpr.term_count = 2
return ffi_gc( return ffi_gc(
ckiwi.kiwi_constraint_new(lhs, nil, op or "EQ", strength or REQUIRED), ckiwi.kiwi_constraint_new(tmpexpr, nil, op or "EQ", strength or REQUIRED),
ckiwi.kiwi_constraint_del ckiwi.kiwi_constraint_del
) --[[@as kiwi.Constraint]] ) --[[@as kiwi.Constraint]]
end end
@@ -767,13 +805,13 @@ do
---@nodiscard ---@nodiscard
function constraints.single(var, constant, op, strength) function constraints.single(var, constant, op, strength)
assert(ffi_istype(Var, var)) assert(ffi_istype(Var, var))
tmpexpr.constant = -(constant or 0)
tmpexpr.term_count = 1
local t = tmpexpr.terms_[0]
t.var = var
t.coefficient = 1.0
return ffi_gc( return ffi_gc(
ckiwi.kiwi_constraint_new( ckiwi.kiwi_constraint_new(tmpexpr, nil, op or "EQ", strength or REQUIRED),
new_expr_one_temp(-(constant or 0.0), var, 1.0),
nil,
op or "EQ",
strength or REQUIRED
),
ckiwi.kiwi_constraint_del ckiwi.kiwi_constraint_del
) --[[@as kiwi.Constraint]] ) --[[@as kiwi.Constraint]]
end end