This commit is contained in:
2024-10-29 22:43:33 -05:00
parent 03b467bfb5
commit 84e570ddaa
4 changed files with 105 additions and 137 deletions

View File

@@ -18,11 +18,11 @@ public enum FFTError: Error {
case invalidSize
}
public protocol FFTElemental {
associatedtype ScalarType: FFTScalar
associatedtype ComplexType = Complex<ScalarType>
public protocol FFTElement {
associatedtype FFTScalarType: FFTScalar
associatedtype FFTComplexType = Complex<FFTScalarType>
static func setupPfft(_ n: Int, _ type: FFTType) throws -> OpaquePointer
static func pffftSetup(_ n: Int, _ type: FFTType) throws -> OpaquePointer
static func pffftMinFftSize(_ type: FFTType) -> Int
static func pffftIsValidSize(_ n: Int, _ type: FFTType) -> Bool
static func pffftNearestValidSize(_ n: Int, _ type: FFTType, _ higher: Bool) -> Int
@@ -80,30 +80,30 @@ extension Double: FFTScalar {
}
}
extension Complex: FFTElemental where RealType: FFTElemental & FFTScalar {
public typealias ScalarType = RealType
extension Complex: FFTElement where RealType: FFTElement & FFTScalar {
public typealias FFTScalarType = RealType
public static func setupPfft(_ n: Int, _: FFTType) throws -> OpaquePointer {
return try ScalarType.setupPfft(n, .complex)
public static func pffftSetup(_ n: Int, _: FFTType) throws -> OpaquePointer {
return try FFTScalarType.pffftSetup(n, .complex)
}
public static func pffftMinFftSize(_: FFTType) -> Int {
return ScalarType.pffftMinFftSize(.complex)
return FFTScalarType.pffftMinFftSize(.complex)
}
public static func pffftIsValidSize(_ n: Int, _: FFTType) -> Bool {
return ScalarType.pffftIsValidSize(n, .complex)
return FFTScalarType.pffftIsValidSize(n, .complex)
}
public static func pffftNearestValidSize(_ n: Int, _: FFTType, _ higher: Bool) -> Int {
return ScalarType.pffftNearestValidSize(n, .complex, higher)
return FFTScalarType.pffftNearestValidSize(n, .complex, higher)
}
}
extension Double: FFTElemental {
public typealias ScalarType = Double
extension Double: FFTElement {
public typealias FFTScalarType = Double
public static func setupPfft(_ n: Int, _ type: FFTType) throws -> OpaquePointer {
public static func pffftSetup(_ n: Int, _ type: FFTType) throws -> OpaquePointer {
guard let ptr = pffftd_new_setup(Int32(n), pffft_transform_t(type)) else { throw FFTError.invalidSize }
return ptr
}
@@ -121,10 +121,10 @@ extension Double: FFTElemental {
}
}
extension Float: FFTElemental {
public typealias ScalarType = Float
extension Float: FFTElement {
public typealias FFTScalarType = Float
public static func setupPfft(_ n: Int, _ type: FFTType) throws -> OpaquePointer {
public static func pffftSetup(_ n: Int, _ type: FFTType) throws -> OpaquePointer {
guard let ptr = pffft_new_setup(Int32(n), pffft_transform_t(type)) else { throw FFTError.invalidSize }
return ptr
}
@@ -142,9 +142,9 @@ extension Float: FFTElemental {
}
}
public final class FFT<T: FFTElemental> {
public typealias ComplexType = T.ComplexType
public typealias ScalarType = T.ScalarType
public final class FFT<T: FFTElement> {
public typealias ComplexType = T.FFTComplexType
public typealias ScalarType = T.FFTScalarType
let ptr: OpaquePointer
let n: Int
@@ -153,21 +153,34 @@ public final class FFT<T: FFTElemental> {
/// Initialize the FFT implementation with the given size and type.
/// - Parameters:
/// - n: The size of the FFT.
/// - type: The type of FFT.
/// - Throws: `FFTError.invalidSize` if the size is invalid.
public init(n: Int) throws {
ptr = try T.setupPfft(n, .real)
ptr = try T.pffftSetup(n, .real)
self.n = n
let capacity = T.self == Complex<ScalarType>.self ? 2 * n : n
let capacity = T.self == ComplexType.self ? 2 * n : n
work = n > 4096 ? Buffer<ScalarType>(capacity: capacity) : nil
}
public func makeSignalBuffer(extra _: Int = 0) -> Buffer<T> {
Buffer(capacity: n)
/// Make a buffer for the FFT (time-domain) signal.
/// - Parameters:
/// - extra: An extra number of elements to allocate.
public func makeSignalBuffer(extra: Int = 0) -> Buffer<T> {
Buffer(capacity: n + extra)
}
public func makeSpectrumBuffer(extra _: Int = 0) -> Buffer<ComplexType> {
Buffer(capacity: n)
/// Make a buffer for the FFT (frequency-domain) spectrum.
/// - Parameters:
/// - extra: An extra number of elements to allocate.
public func makeSpectrumBuffer(extra: Int = 0) -> Buffer<ComplexType> {
Buffer(capacity: T.self == ComplexType.self ? (n + extra) : n / 2 + extra)
}
/// Make a buffer for the internal layout of the FFT (frequency-domain) spectrum.
/// - Parameters:
/// - extra: An extra number of points to allocate. For complex FFTs, 2 * extra
/// additional elements will be allocated.
public func makeInternalLayoutBuffer(extra: Int = 0) -> Buffer<ScalarType> {
Buffer(capacity: (T.self == ComplexType.self ? 2 : 1) * (n + extra))
}
@inline(__always)
@@ -188,7 +201,7 @@ public final class FFT<T: FFTElemental> {
guard signal.count >= n else {
fatalError("signal buffer too small")
}
guard spectrum.count >= n else {
guard spectrum.count >= (T.self == ComplexType.self ? n : n / 2) else {
fatalError("spectrum buffer too small")
}
}
@@ -198,14 +211,14 @@ public final class FFT<T: FFTElemental> {
guard signal.count >= n else {
fatalError("signal buffer too small")
}
guard spectrum.count >= (T.self == Complex<ScalarType>.self ? 2 * n : n) else {
guard spectrum.count >= (T.self == ComplexType.self ? 2 * n : n) else {
fatalError("spectrum buffer too small")
}
}
@inline(__always)
func checkConvolveBufferCounts(a: borrowing Buffer<ScalarType>, b: borrowing Buffer<ScalarType>, ab: borrowing Buffer<ScalarType>) {
let minCount = T.self == Complex<ScalarType>.self ? 2 * n : n
let minCount = T.self == ComplexType.self ? 2 * n : n
guard a.count >= minCount else {
fatalError("a buffer too small")
@@ -218,6 +231,26 @@ public final class FFT<T: FFTElemental> {
}
}
/// Perform a forward FFT on the input buffer.
///
/// The input and output buffers may be the same.
/// The data is stores in order as expected (interleaved complex components ordered by frequency).
/// The input and output buffer must have a capacity of at least `n` for real FFTs and `2 * n` for complex FFTs.
/// A fatal error will occur if any buffer is too small.
///
/// For a real forward transform with real input, the output array is organized as follows:
/// index k > 2 where k is even is the real part of the k/2-th complex coefficient.
/// index k > 2 where k is odd is the imaginary part of the k/2-th complex coefficient.
/// index k = 0 is the real part of the 0 frequency (DC) coefficient.
/// index k = 1 is the real part of the Nyquist coefficient.
///
/// Transforms are not scaled. fft_backward(fft_forward(x)) == n * x.
///
/// - Parameters:
/// - input: The input buffer.
/// - output: The output buffer.
/// - work: An optional work buffer. Must have capacity of at least `n` for real FFTs and `2 * n` for complex FFTs.
/// - sign: The direction of the FFT.
public func forward(signal: borrowing Buffer<T>, spectrum: borrowing Buffer<ComplexType>) {
checkFftBufferCounts(signal: signal, spectrum: spectrum)
ScalarType.pffftTransformOrdered(ptr, rebind(signal), rebind(spectrum), toAddress(work), .forward)
@@ -228,6 +261,15 @@ public final class FFT<T: FFTElemental> {
ScalarType.pffftTransformOrdered(ptr, rebind(spectrum), rebind(signal), toAddress(work), .backward)
}
/// Perform a forward FFT on the input buffer, with implementation defined order.
///
/// This function behaves similarly to `fft` however the z-domain data is stored in most efficient ordering,
/// which is suitable for transforming back with this function, or for convolution.
/// - Parameters:
/// - input: The input buffer.
/// - output: The output buffer.
/// - work: An optional work buffer. Must have capacity of at least `n` for real FFTs and `2 * n` for complex FFTs.
/// - sign: The direction of the FFT.
public func forwardToInternalLayout(signal: borrowing Buffer<T>, spectrum: borrowing Buffer<ScalarType>) {
checkFftInternalLayoutBufferCounts(signal: signal, spectrum: spectrum)
ScalarType.pffftTransform(ptr, rebind(signal), spectrum.baseAddress, toAddress(work), .forward)
@@ -239,7 +281,7 @@ public final class FFT<T: FFTElemental> {
}
public func reorder(spectrum: borrowing Buffer<ScalarType>, output: borrowing Buffer<ComplexType>) {
guard spectrum.count >= (T.self == Complex<ScalarType>.self ? 2 * n : n) else {
guard spectrum.count >= (T.self == ComplexType.self ? 2 * n : n) else {
fatalError("signal buffer too small")
}
guard output.count >= n else {
@@ -248,24 +290,54 @@ public final class FFT<T: FFTElemental> {
ScalarType.pffftZreorder(ptr, spectrum.baseAddress, rebind(output), .forward)
}
/// Perform a convolution of two complex signals in the frequency domain.
///
/// Multiplies frequency domain components of `dftA` and `dftB` and stores the result in `dftAB`.
/// The operation performed is `dftAB = (dftA * dftB) * scaling`.
/// - Parameters:
/// - dftA: The first input buffer of frequency domain data.
/// - dftB: The second input buffer of frequency domain data.
/// - dftAB: The output buffer of frequency domain data.
/// - scaling: The scaling factor to apply to the result.
public func convolve(dftA: borrowing Buffer<ScalarType>, dftB: borrowing Buffer<ScalarType>, dftAB: borrowing Buffer<ScalarType>, scaling: ScalarType) {
checkConvolveBufferCounts(a: dftA, b: dftB, ab: dftAB)
ScalarType.pffftZconvolveNoAccu(ptr, dftA.baseAddress, dftB.baseAddress, dftAB.baseAddress, scaling)
}
/// Perform a convolution of two complex signals in the frequency domain.
///
/// Multiplies frequency domain components of `dftA` and `dftB` and accumulates the result in `dftAB`.
/// The operation performed is `dftAB += (dftA * dftB) * scaling`.
/// - Parameters:
/// - dftA: The first input buffer of frequency domain data.
/// - dftB: The second input buffer of frequency domain data.
/// - dftAB: The output buffer of frequency domain data.
/// - scaling: The scaling factor to apply to the result.
public func convolveAccumulate(dftA: borrowing Buffer<ScalarType>, dftB: borrowing Buffer<ScalarType>, dftAB: borrowing Buffer<ScalarType>, scaling: ScalarType) {
checkConvolveBufferCounts(a: dftA, b: dftB, ab: dftAB)
ScalarType.pffftZconvolveAccumulate(ptr, dftA.baseAddress, dftB.baseAddress, dftAB.baseAddress, scaling)
}
/// Returns the minimum FFT size for this type of setup.
public static func minFftSize() -> Int {
T.pffftMinFftSize(.real)
}
/// Returns whether the given size is valid for the given type.
///
/// The PFFFT library requires `n` to be factorizable to `minFftSize` with factors of 2, 3, 5.
/// - Parameters:
/// - n: The size to check.
/// - Returns: Whether the size is valid.
public static func isValidSize(_ n: Int) -> Bool {
T.pffftIsValidSize(n, .real)
}
/// Returns the nearest valid size for the given type.
/// - Parameters:
/// - n: The size to check.
/// - higher: Whether to return the next higher size if `n` is invalid.
/// - Returns: The nearest valid size.
public static func nearestValidSize(_ n: Int, higher: Bool) -> Int {
T.pffftNearestValidSize(n, .real, higher)
}

View File

@@ -1,108 +1,4 @@
internal import PFFFTLib
import ComplexModule
import RealModule
public protocol PFFFTProtocol<T> {
associatedtype T: FFTElemental
associatedtype ComplexType = T.ComplexType
associatedtype ScalarType = T.ScalarType
/// Get an FFT interface for the given size and FFT Type.
///
/// This call is backed by a global cache to avoid repeated setup costs.
/// - Parameters:
/// - n: The size of the FFT.
/// - type: The type of FFT.
/// - Returns: An FFT interface.
/// - Throws: `FFTError.invalidSize` if the size is invalid.
static func setup(for n: Int) throws -> Self
/// Initialize the FFT implementation with the given size and type.
/// - Parameters:
/// - n: The size of the FFT.
/// - type: The type of FFT.
/// - Throws: `FFTError.invalidSize` if the size is invalid.
init(n: Int) throws
/// Perform a forward FFT on the input buffer.
///
/// The input and output buffers may be the same.
/// The data is stores in order as expected (interleaved complex components ordered by frequency).
/// The input and output buffer must have a capacity of at least `n` for real FFTs and `2 * n` for complex FFTs.
/// A fatal error will occur if any buffer is too small.
///
/// For a real forward transform with real input, the output array is organized as follows:
/// index k > 2 where k is even is the real part of the k/2-th complex coefficient.
/// index k > 2 where k is odd is the imaginary part of the k/2-th complex coefficient.
/// index k = 0 is the real part of the 0 frequency (DC) coefficient.
/// index k = 1 is the real part of the Nyquist coefficient.
///
/// Transforms are not scaled. fft_backward(fft_forward(x)) == n * x.
///
/// - Parameters:
/// - input: The input buffer.
/// - output: The output buffer.
/// - work: An optional work buffer. Must have capacity of at least `n` for real FFTs and `2 * n` for complex FFTs.
/// - sign: The direction of the FFT.
func forward(input: borrowing Buffer<T>, output: borrowing Buffer<ComplexType>)
func inverse(input: borrowing Buffer<ComplexType>, output: borrowing Buffer<T>)
/// Perform a forward FFT on the input buffer, with implementation defined order.
///
/// This function behaves similarly to `fft` however the z-domain data is stored in most efficient ordering,
/// which is suitable for transforming back with this function, or for convolution.
/// - Parameters:
/// - input: The input buffer.
/// - output: The output buffer.
/// - work: An optional work buffer. Must have capacity of at least `n` for real FFTs and `2 * n` for complex FFTs.
/// - sign: The direction of the FFT.
func forwardToInternalLayout(input: borrowing Buffer<T>, output: borrowing Buffer<ScalarType>)
func inverseFromInternalLayout(input: borrowing Buffer<ScalarType>, output: borrowing Buffer<T>)
func reorder(spectrum: borrowing Buffer<ScalarType>, output: borrowing Buffer<ComplexType>)
/// Perform a convolution of two complex signals in the frequency domain.
///
/// Multiplies frequency domain components of `dftA` and `dftB` and stores the result in `dftAB`.
/// The operation performed is `dftAB = (dftA * dftB) * scaling`.
/// - Parameters:
/// - dftA: The first input buffer of frequency domain data.
/// - dftB: The second input buffer of frequency domain data.
/// - dftAB: The output buffer of frequency domain data.
/// - scaling: The scaling factor to apply to the result.
func convolve(dftA: borrowing Buffer<ScalarType>, dftB: borrowing Buffer<ScalarType>, dftAB: borrowing Buffer<ScalarType>, scaling: ScalarType)
/// Perform a convolution of two complex signals in the frequency domain.
///
/// Multiplies frequency domain components of `dftA` and `dftB` and accumulates the result in `dftAB`.
/// The operation performed is `dftAB += (dftA * dftB) * scaling`.
/// - Parameters:
/// - dftA: The first input buffer of frequency domain data.
/// - dftB: The second input buffer of frequency domain data.
/// - dftAB: The output buffer of frequency domain data.
/// - scaling: The scaling factor to apply to the result.
func convolveAccumulate(dftA: borrowing Buffer<ScalarType>, dftB: borrowing Buffer<ScalarType>, dftAB: borrowing Buffer<ScalarType>, scaling: ScalarType)
/// Returns the minimum FFT size for the given type.
///
/// - Parameter type: The type of FFT.
static func minFftSize() -> Int
/// Returns whether the given size is valid for the given type.
///
/// The PFFFT library requires `n` to be factorizable to `minFftSize` with factors of 2, 3, 5.
/// - Parameters:
/// - n: The size to check.
/// - type: The type of FFT.
static func isValidSize(_ n: Int) -> Bool
/// Returns the nearest valid size for the given type.
static func nearestValidSize(_ n: Int, higher: Bool) -> Int
}
public var simdArch: String {
String(cString: pffft_simd_arch())

View File

@@ -1,6 +1,6 @@
import Foundation
public class SetupCache<T: FFTElemental> : @unchecked Sendable {
public class SetupCache<T: FFTElement> : @unchecked Sendable {
typealias FFTType = FFT<T>
private var cache: [Int: FFT<T>] = [:]

View File

@@ -75,7 +75,7 @@ struct RRProcess: ~Copyable {
init(nrr: Int) {
self.nrr = nrr
fft = try! FFT<Double>(n: nrr)
spectrum = fft.makeSpectrumBuffer()
spectrum = fft.makeSpectrumBuffer(extra: 1)
signal = fft.makeSignalBuffer()
}