This commit is contained in:
2024-10-29 13:21:34 -05:00
parent 4dda811a1c
commit 16eaad818e
4 changed files with 97 additions and 39 deletions

View File

@@ -1,10 +1,16 @@
internal import PFFFTLib
public final class FFTDoubleImpl: FFTImplProtocol {
public final class FFTDouble: PFFFTProtocol {
let ptr: OpaquePointer
let n: Int
let type: FFTType
static let sharedCache = SetupCache<Element>()
public static func setup(for n: Int, type: FFTType) throws -> FFTDouble {
try sharedCache.get(for: n, type: type)
}
public init(n: Int, type: FFTType) throws {
guard let ptr = pffftd_new_setup(Int32(n), pffft_transform_t(type)) else { throw FFTError.invalidSize }
self.ptr = ptr
@@ -12,37 +18,56 @@ public final class FFTDoubleImpl: FFTImplProtocol {
self.type = type
}
public func fft(input: borrowing Buffer<Double>, output: borrowing Buffer<Double>, work: borrowing Buffer<Double>?, sign: FFTSign) {
func transform(_ input: borrowing Buffer<Element>, _ output: borrowing Buffer<Element>, _ work: borrowing Buffer<Element>?, _ dir: pffft_direction_t) {
checkFftBufferCounts(n: n, type: type, input: input, output: output, work: work)
let workAddress: UnsafeMutablePointer<Element>! = switch work {
case let .some(b): b.baseAddress
case .none: nil
}
pffftd_transform_ordered(ptr, input.baseAddress, output.baseAddress, workAddress, dir)
}
func transformUnordered(_ input: borrowing Buffer<Double>, _ output: borrowing Buffer<Double>, _ work: borrowing Buffer<Double>?, _ dir: pffft_direction_t) {
checkFftBufferCounts(n: n, type: type, input: input, output: output, work: work)
let workAddress: UnsafeMutablePointer<Double>! = switch work {
case let .some(b): b.baseAddress
case .none: nil
}
pffftd_transform_ordered(ptr, input.baseAddress, output.baseAddress, workAddress, pffft_direction_t(sign))
pffftd_transform(ptr, input.baseAddress, output.baseAddress, workAddress, dir)
}
public func fftUnordered(input: borrowing Buffer<Double>, output: borrowing Buffer<Double>, work: borrowing Buffer<Double>?, sign: FFTSign) {
checkFftBufferCounts(n: n, type: type, input: input, output: output, work: work)
let workAddress: UnsafeMutablePointer<Double>! = switch work {
case let .some(b): b.baseAddress
case .none: nil
}
pffftd_transform(ptr, input.baseAddress, output.baseAddress, workAddress, pffft_direction_t(sign))
public func forward(input: borrowing Buffer<Double>, output: borrowing Buffer<Double>, work: borrowing Buffer<Double>?) {
transform(input, output, work, PFFFT_FORWARD)
}
public func zReorder(input: borrowing Buffer<Double>, output: borrowing Buffer<Double>, sign: FFTSign) {
public func inverse(input: borrowing Buffer<Double>, output: borrowing Buffer<Double>, work: borrowing Buffer<Double>?) {
transform(input, output, work, PFFFT_BACKWARD)
}
public func forwardUnordered(input: borrowing Buffer<Double>, output: borrowing Buffer<Double>, work: borrowing Buffer<Double>?) {
transformUnordered(input, output, work, PFFFT_FORWARD)
}
public func inverseUnordered(input: borrowing Buffer<Double>, output: borrowing Buffer<Double>, work: borrowing Buffer<Double>?) {
transformUnordered(input, output, work, PFFFT_BACKWARD)
}
public func reorderSpectrum(input: borrowing Buffer<Double>, output: borrowing Buffer<Double>) {
checkFftBufferCounts(n: n, type: type, input: input, output: output, work: nil)
pffftd_zreorder(ptr, input.baseAddress, output.baseAddress, pffft_direction_t(sign))
pffftd_zreorder(ptr, input.baseAddress, output.baseAddress, PFFFT_FORWARD)
}
public func zConvolveAccumulate(dftA: borrowing Buffer<Double>, dftB: borrowing Buffer<Double>, dftAB: borrowing Buffer<Double>, scaling: Double) {
public func convolveAccumulate(dftA: borrowing Buffer<Double>, dftB: borrowing Buffer<Double>, dftAB: borrowing Buffer<Double>, scaling: Double) {
checkConvolveBufferCounts(n: n, type: type, a: dftA, b: dftB, ab: dftAB)
pffftd_zconvolve_accumulate(ptr, dftA.baseAddress, dftB.baseAddress, dftAB.baseAddress, scaling)
}
public func zConvolve(dftA: borrowing Buffer<Double>, dftB: borrowing Buffer<Double>, dftAB: borrowing Buffer<Double>, scaling: Double) {
public func convolve(dftA: borrowing Buffer<Double>, dftB: borrowing Buffer<Double>, dftAB: borrowing Buffer<Double>, scaling: Double) {
checkConvolveBufferCounts(n: n, type: type, a: dftA, b: dftB, ab: dftAB)
pffftd_zconvolve_no_accu(ptr, dftA.baseAddress, dftB.baseAddress, dftAB.baseAddress, scaling)
}
@@ -65,5 +90,5 @@ public final class FFTDoubleImpl: FFTImplProtocol {
}
extension Double: FFTElement {
public typealias FFTImpl = FFTDoubleImpl
public typealias FFTImpl = FFTDouble
}

View File

@@ -1,10 +1,16 @@
internal import PFFFTLib
public final class FFTFloatImpl: FFTImplProtocol {
public final class FFTFloat: PFFFTProtocol {
let ptr: OpaquePointer
let n: Int
let type: FFTType
static let sharedCache = SetupCache<Element>()
public static func setup(for n: Int, type: FFTType) throws -> FFTFloat {
try sharedCache.get(for: n, type: type)
}
public init(n: Int, type: FFTType) throws {
guard let ptr = pffft_new_setup(Int32(n), pffft_transform_t(type)) else { throw FFTError.invalidSize }
self.ptr = ptr
@@ -12,7 +18,7 @@ public final class FFTFloatImpl: FFTImplProtocol {
self.type = type
}
public func fft(input: borrowing Buffer<Float>, output: borrowing Buffer<Float>, work: borrowing Buffer<Float>?, sign: FFTSign) {
public func forward(input: borrowing Buffer<Float>, output: borrowing Buffer<Float>, work: borrowing Buffer<Float>?, sign: FFTSign) {
checkFftBufferCounts(n: n, type: type, input: input, output: output, work: work)
let workAddress: UnsafeMutablePointer<Float>! = switch work {
@@ -32,17 +38,17 @@ public final class FFTFloatImpl: FFTImplProtocol {
pffft_transform(ptr, input.baseAddress, output.baseAddress, workAddress, pffft_direction_t(sign))
}
public func zReorder(input: borrowing Buffer<Float>, output: borrowing Buffer<Float>, sign: FFTSign) {
public func reorderSpectrum(input: borrowing Buffer<Float>, output: borrowing Buffer<Float>) {
checkFftBufferCounts(n: n, type: type, input: input, output: output, work: nil)
pffft_zreorder(ptr, input.baseAddress, output.baseAddress, pffft_direction_t(sign))
pffft_zreorder(ptr, input.baseAddress, output.baseAddress, PFFFT_FORWARD)
}
public func zConvolveAccumulate(dftA: borrowing Buffer<Float>, dftB: borrowing Buffer<Float>, dftAB: borrowing Buffer<Float>, scaling: Float) {
public func convolveAccumulate(dftA: borrowing Buffer<Float>, dftB: borrowing Buffer<Float>, dftAB: borrowing Buffer<Float>, scaling: Float) {
checkConvolveBufferCounts(n: n, type: type, a: dftA, b: dftB, ab: dftAB)
pffft_zconvolve_accumulate(ptr, dftA.baseAddress, dftB.baseAddress, dftAB.baseAddress, scaling)
}
public func zConvolve(dftA: borrowing Buffer<Float>, dftB: borrowing Buffer<Float>, dftAB: borrowing Buffer<Float>, scaling: Float) {
public func convolve(dftA: borrowing Buffer<Float>, dftB: borrowing Buffer<Float>, dftAB: borrowing Buffer<Float>, scaling: Float) {
checkConvolveBufferCounts(n: n, type: type, a: dftA, b: dftB, ab: dftAB)
pffft_zconvolve_no_accu(ptr, dftA.baseAddress, dftB.baseAddress, dftAB.baseAddress, scaling)
}
@@ -65,5 +71,5 @@ public final class FFTFloatImpl: FFTImplProtocol {
}
extension Float: FFTElement {
public typealias FFTImpl = FFTFloatImpl
public typealias FFTImpl = FFTFloat
}

View File

@@ -17,12 +17,22 @@ public enum FFTError: Error {
}
public protocol FFTElement {
associatedtype FFTImpl: FFTImplProtocol
associatedtype FFTImpl: PFFFTProtocol
}
public protocol FFTImplProtocol<Element>: ~Copyable {
public protocol PFFFTProtocol<Element> {
associatedtype Element
/// Get an FFT interface for the given size and FFT Type.
///
/// This call is backed by a global cache to avoid repeated setup costs.
/// - Parameters:
/// - n: The size of the FFT.
/// - type: The type of FFT.
/// - Returns: An FFT interface.
/// - Throws: `FFTError.invalidSize` if the size is invalid.
static func setup(for n: Int, type: FFTType) throws -> Self
/// Initialize the FFT implementation with the given size and type.
/// - Parameters:
/// - n: The size of the FFT.
@@ -30,7 +40,7 @@ public protocol FFTImplProtocol<Element>: ~Copyable {
/// - Throws: `FFTError.invalidSize` if the size is invalid.
init(n: Int, type: FFTType) throws
/// Perform a forward or backward FFT on the input buffer.
/// Perform a forward FFT on the input buffer.
///
/// The input and output buffers may be the same.
/// The data is stores in order as expected (interleaved complex components ordered by frequency).
@@ -50,9 +60,11 @@ public protocol FFTImplProtocol<Element>: ~Copyable {
/// - output: The output buffer.
/// - work: An optional work buffer. Must have capacity of at least `n` for real FFTs and `2 * n` for complex FFTs.
/// - sign: The direction of the FFT.
func fft(input: borrowing Buffer<Element>, output: borrowing Buffer<Element>, work: borrowing Buffer<Element>?, sign: FFTSign)
func forward(input: borrowing Buffer<Element>, output: borrowing Buffer<Element>, work: borrowing Buffer<Element>?)
/// Perform a forward or backward FFT on the input buffer, with implementation defined order.
func inverse(input: borrowing Buffer<Element>, output: borrowing Buffer<Element>, work: borrowing Buffer<Element>?)
/// Perform a forward FFT on the input buffer, with implementation defined order.
///
/// This function behaves similarly to `fft` however the z-domain data is stored in most efficient ordering,
/// which is suitable for transforming back with this function, or for convolution.
@@ -61,9 +73,11 @@ public protocol FFTImplProtocol<Element>: ~Copyable {
/// - output: The output buffer.
/// - work: An optional work buffer. Must have capacity of at least `n` for real FFTs and `2 * n` for complex FFTs.
/// - sign: The direction of the FFT.
func fftUnordered(input: borrowing Buffer<Element>, output: borrowing Buffer<Element>, work: borrowing Buffer<Element>?, sign: FFTSign)
func forwardUnordered(input: borrowing Buffer<Element>, output: borrowing Buffer<Element>, work: borrowing Buffer<Element>?)
func zReorder(input: borrowing Buffer<Element>, output: borrowing Buffer<Element>, sign: FFTSign)
func inverseUnordered(input: borrowing Buffer<Element>, output: borrowing Buffer<Element>, work: borrowing Buffer<Element>?)
func reorder(input: borrowing Buffer<Element>, output: borrowing Buffer<Element>, sign: FFTSign)
/// Perform a convolution of two complex signals in the frequency domain.
///
@@ -74,7 +88,7 @@ public protocol FFTImplProtocol<Element>: ~Copyable {
/// - dftB: The second input buffer of frequency domain data.
/// - dftAB: The output buffer of frequency domain data.
/// - scaling: The scaling factor to apply to the result.
func zConvolveAccumulate(dftA: borrowing Buffer<Element>, dftB: borrowing Buffer<Element>, dftAB: borrowing Buffer<Element>, scaling: Element)
func convolveAccumulate(dftA: borrowing Buffer<Element>, dftB: borrowing Buffer<Element>, dftAB: borrowing Buffer<Element>, scaling: Element)
/// Perform a convolution of two complex signals in the frequency domain.
///
@@ -85,7 +99,7 @@ public protocol FFTImplProtocol<Element>: ~Copyable {
/// - dftB: The second input buffer of frequency domain data.
/// - dftAB: The output buffer of frequency domain data.
/// - scaling: The scaling factor to apply to the result.
func zConvolve(dftA: borrowing Buffer<Element>, dftB: borrowing Buffer<Element>, dftAB: borrowing Buffer<Element>, scaling: Element)
func convolve(dftA: borrowing Buffer<Element>, dftB: borrowing Buffer<Element>, dftAB: borrowing Buffer<Element>, scaling: Element)
/// Returns the minimum FFT size for the given type.
///