Skip to content

Commit

Permalink
As recommended in the docs, mark custom functions SQLITE_DIRECTONLY b…
Browse files Browse the repository at this point in the history
…y default (security hardening). Provide an initializer flag to override it if needed.
  • Loading branch information
gwynne committed May 4, 2024
1 parent ccc22b5 commit 0e5ca1a
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions Sources/SQLiteNIO/SQLiteCustomFunction.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ public final class SQLiteCustomFunction: Hashable {

/// The name of the SQL function
public var name: String { identity.name }

private let identity: Identity
private let pure: Bool
private let indirect: Bool
private let kind: Kind
private var eTextRep: Int32 { (SQLITE_UTF8 | (pure ? SQLITE_DETERMINISTIC : 0)) }
private var eTextRep: Int32 { (SQLITE_UTF8 | (pure ? SQLITE_DETERMINISTIC : 0) | (indirect ? 0 : SQLITE_DIRECTONLY)) }

public struct SQLiteCustomFunctionArgumentError: Error {
public let count: Int
Expand All @@ -38,19 +40,19 @@ public final class SQLiteCustomFunction: Hashable {
_ name: String,
argumentCount: Int32? = nil,
pure: Bool = false,
indirect: Bool = false,
function: @Sendable @escaping ([SQLiteData]) throws -> (any SQLiteDataConvertible)?)
{
self.identity = Identity(name: name, nArg: argumentCount ?? -1)
self.pure = pure
self.indirect = indirect
self.kind = .function { (argc, argv) in
let count = Int(argc)
let arguments = try (0 ..< count).map { index -> SQLiteData in
try function((0 ..< Int(argc)).map { index -> SQLiteData in
guard let value = argv?[index] else {
throw SQLiteCustomFunctionArgumentError(count: count, index: index)
throw SQLiteCustomFunctionArgumentError(count: Int(argc), index: index)
}
return try SQLiteData(sqliteValue: value)
}
return try function(arguments)
})
}
}

Expand Down Expand Up @@ -95,33 +97,30 @@ public final class SQLiteCustomFunction: Hashable {
_ name: String,
argumentCount: Int32? = nil,
pure: Bool = false,
indirect: Bool = false,
aggregate: Aggregate.Type
) {
self.identity = Identity(name: name, nArg: argumentCount ?? -1)
self.pure = pure
self.indirect = indirect
self.kind = .aggregate { Aggregate() }
}

/// Invokes `sqlite3_create_function_v2()` to install a custom function.
/// See https://sqlite.org/c3ref/create_function.html
func install(in connection: SQLiteConnection) throws {
// Retain the function definition
let definition = kind.definition
let definitionP = Unmanaged.passRetained(definition).toOpaque()

let code = sqlite_nio_sqlite3_create_function_v2(
connection.handle.raw,
identity.name,
identity.nArg,
eTextRep,
definitionP,
kind.xFunc,
kind.xStep,
kind.xFinal,
{ definitionP in
// Release the function definition
Unmanaged<AnyObject>.fromOpaque(definitionP!).release()
})
self.identity.name,
self.identity.nArg,
self.eTextRep,
Unmanaged.passRetained(self.kind.definition).toOpaque(),
self.kind.xFunc,
self.kind.xStep,
self.kind.xFinal,
{ Unmanaged<AnyObject>.fromOpaque($0!).release() } // Release the function definition
)

guard code == SQLITE_OK else {
throw SQLiteError(statusCode: code, connection: connection)
Expand All @@ -133,10 +132,11 @@ public final class SQLiteCustomFunction: Hashable {
func uninstall(in connection: SQLiteConnection) throws {
let code = sqlite_nio_sqlite3_create_function_v2(
connection.handle.raw,
identity.name,
identity.nArg,
eTextRep,
nil, nil, nil, nil, nil)
self.identity.name,
self.identity.nArg,
self.eTextRep,
nil, nil, nil, nil, nil
)

guard code == SQLITE_OK else {
throw SQLiteError(statusCode: code, connection: connection)
Expand Down

0 comments on commit 0e5ca1a

Please sign in to comment.