Skip to content

Commit

Permalink
Add IIR filter module
Browse files Browse the repository at this point in the history
A fixed point IIR filter module with pipelined accumulations.

Signed-off-by: Kevin Joly <[email protected]>
  • Loading branch information
Kevin Joly committed Feb 2, 2024
1 parent 0e112f1 commit 1f581c7
Show file tree
Hide file tree
Showing 5 changed files with 664 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ the community. This is the place to do it.
| spi2wb | maven | Fabien Marteau | Drive a wishbone master bus with SPI |
| dclib | internal | Guy Hutchison | Utility components for DecoupledIO interfaces |
| ecc | internal | Guy Hutchison | Hamming Error-Correcting code modules |
| iir | internal | Kevin Joly | Infinite Impulse Response filter module |

### Using ip-contributions

Expand Down
7 changes: 7 additions & 0 deletions src/main/scala/chisel/lib/iirfilter/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# IIR Filter

Simple fixed point IIR filter with pipelined computation.

Tests are run as follow:
```sbt "testOnly chisel.lib.iirfilter.SimpleIIRFilterTest -- -DwriteVcd=1"```
```sbt "testOnly chisel.lib.iirfilter.RandomSignalTest -- -DwriteVcd=1"```
208 changes: 208 additions & 0 deletions src/main/scala/chisel/lib/iirfilter/iirfilter.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
/*
*
* A fixed point IIR filter module.
*
* Author: Kevin Joly ([email protected])
*
*/

package chisel.lib.iirfilter

import chisel3._
import chisel3.experimental.ChiselEnum
import chisel3.util._

/*
* IIR filter module
*
* Apply filter on input samples passed by ready/valid handshake. Numerators
* and denominators are to be set prior to push any input sample.
*
* All the computations are done in fixed point. The user should manage the
* input decimal width by himself. Output width should be sufficient in order
* to not overflow (i.e. in case of overshoot). A minimum output width of
* inputWidht+coefWidth+log2Ceil(numeratorNum + denominatorNum) + 1 is
* requested.
*
*/
class IIRFilter(
inputWidth: Int,
coefWidth: Int,
coefDecimalWidth: Int,
outputWidth: Int,
numeratorNum: Int,
denominatorNum: Int)
extends Module {
val io = IO(new Bundle {
/*
* Input samples
*/
val input = Flipped(Decoupled(SInt(inputWidth.W)))
/*
* Numerator's coefficients b[0], b[1], ...
*/
val num = Input(Vec(numeratorNum, SInt(coefWidth.W)))
/*
* The first coefficient of the denominator should be omitted and should be a[0] == 1.
* a[1], a[2], ...
*/
val den = Input(Vec(denominatorNum, SInt(coefWidth.W)))
/*
* Filtered samples. Fixed point format is:
* (outputWidth-coefDecimalWidth).coefDecimalWidth
* Thus, output should be right shifted to the right of 'coefDecimalWidth' bits.
*/
val output = Decoupled(SInt(outputWidth.W))
})

assert(coefWidth >= coefDecimalWidth)

val minOutputWidth = inputWidth + coefWidth + log2Ceil(numeratorNum + denominatorNum) + 1
assert(outputWidth >= minOutputWidth)

val coefNum = RegInit(0.U(log2Ceil(math.max(numeratorNum, denominatorNum)).W))

object IIRFilterState extends ChiselEnum {
val Idle, ComputeNum, ComputeDen, Valid, StoreLast = Value
}

val state = RegInit(IIRFilterState.Idle)

switch(state) {
is(IIRFilterState.Idle) {
when(io.input.valid) {
state := IIRFilterState.ComputeNum
}
}
is(IIRFilterState.ComputeNum) {
when(coefNum === (numeratorNum - 1).U) {
state := IIRFilterState.ComputeDen
}
}
is(IIRFilterState.ComputeDen) {
when(coefNum === (denominatorNum - 1).U) {
state := IIRFilterState.StoreLast
}
}
is(IIRFilterState.StoreLast) {
state := IIRFilterState.Valid
}
is(IIRFilterState.Valid) {
when(io.output.ready) {
state := IIRFilterState.Idle
}
}
}

when((state === IIRFilterState.Idle) && io.input.valid) {
coefNum := 1.U
}.elsewhen(state === IIRFilterState.ComputeNum) {
when(coefNum === (numeratorNum - 1).U) {
coefNum := 0.U
}.otherwise {
coefNum := coefNum + 1.U
}
}.elsewhen(state === IIRFilterState.ComputeDen) {
when(coefNum === (denominatorNum - 1).U) {
coefNum := 0.U
}.otherwise {
coefNum := coefNum + 1.U
}
}.otherwise {
coefNum := 0.U
}

val inputReg = RegInit(0.S(inputWidth.W))
val inputMem = Mem(numeratorNum - 1, SInt(inputWidth.W))
val inputMemAddr = RegInit(0.U(math.max(log2Ceil(numeratorNum - 1), 1).W))
val inputMemOut = Wire(SInt(inputWidth.W))
val inputRdWr = inputMem(inputMemAddr)

inputMemOut := DontCare

when(state === IIRFilterState.StoreLast) {
inputRdWr := inputReg
}.elsewhen((state === IIRFilterState.Idle) && io.input.valid) {
inputReg := io.input.bits // Delayed write
inputMemOut := inputRdWr
}.otherwise {
inputMemOut := inputRdWr
}

when((state === IIRFilterState.ComputeNum) && (coefNum < (numeratorNum - 1).U)) {
when(inputMemAddr === (numeratorNum - 2).U) {
inputMemAddr := 0.U
}.otherwise {
inputMemAddr := inputMemAddr + 1.U
}
}

val outputMem = Mem(denominatorNum, SInt(outputWidth.W))
val outputMemAddr = RegInit(0.U(math.max(log2Ceil(denominatorNum), 1).W))
val outputMemOut = Wire(SInt(outputWidth.W))
val outputRdWr = outputMem(outputMemAddr)

outputMemOut := DontCare

when((state === IIRFilterState.Valid) && (RegNext(state) === IIRFilterState.StoreLast)) {
outputRdWr := io.output.bits
}.otherwise {
outputMemOut := outputRdWr
}

when((state === IIRFilterState.ComputeDen) && (coefNum < (denominatorNum - 1).U)) {
when(outputMemAddr === (denominatorNum - 1).U) {
outputMemAddr := 0.U
}.otherwise {
outputMemAddr := outputMemAddr + 1.U
}
}

val inputSum = RegInit(0.S((inputWidth + coefWidth + log2Ceil(numeratorNum)).W))
val outputSum = RegInit(0.S((outputWidth + coefWidth + log2Ceil(denominatorNum)).W))

val multOut = Wire(SInt((outputWidth + coefWidth).W))
val multOutReg = RegInit(0.S((outputWidth + coefWidth).W))
val multIn = Wire(SInt(outputWidth.W))
val multCoef = Wire(SInt(coefWidth.W))

when((state === IIRFilterState.Idle) && io.input.valid) {
multOutReg := multOut
outputSum := 0.S
inputSum := 0.S
}.elsewhen(state === IIRFilterState.ComputeNum) {
multOutReg := multOut
inputSum := inputSum +& multOutReg
}.elsewhen(state === IIRFilterState.ComputeDen) {
multOutReg := multOut

when(coefNum === 0.U) {
// Store numerator's last value
inputSum := inputSum +& multOutReg
}.otherwise {
outputSum := outputSum +& multOutReg
}
}.elsewhen(state === IIRFilterState.StoreLast) {
outputSum := outputSum +& multOutReg
}

when(state === IIRFilterState.ComputeNum) {
multIn := inputMemOut
}.elsewhen(state === IIRFilterState.ComputeDen) {
multIn := outputMemOut
}.otherwise {
multIn := io.input.bits
}

when(state === IIRFilterState.ComputeDen) {
multCoef := io.den(coefNum)
}.otherwise {
multCoef := io.num(coefNum)
}

multOut := multIn * multCoef

io.input.ready := state === IIRFilterState.Idle
io.output.valid := state === IIRFilterState.Valid
io.output.bits := inputSum -& (outputSum >> coefDecimalWidth)
}
151 changes: 151 additions & 0 deletions src/test/scala/chisel/lib/iirfilter/RandomSignalTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Filter a random signal using IIRFilter module and compare with the expected output.
*
* See README.md for license details.
*/

package chisel.lib.iirfilter

import chisel3._
import chisel3.experimental.VecLiterals._
import chisel3.util.log2Ceil
import chiseltest._

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import scala.util.Random

trait IIRFilterBehavior {

this: AnyFlatSpec with ChiselScalatestTester =>

def testFilter(
inputWidth: Int,
inputDecimalWidth: Int,
coefWidth: Int,
coefDecimalWidth: Int,
outputWidth: Int,
numerators: Seq[Int],
denominators: Seq[Int],
inputData: Seq[Int],
expectedOutput: Seq[Double],
precision: Double
): Unit = {

it should "work" in {
test(
new IIRFilter(
inputWidth = inputWidth,
coefWidth = coefWidth,
coefDecimalWidth = coefDecimalWidth,
outputWidth = outputWidth,
numeratorNum = numerators.length,
denominatorNum = (denominators.length - 1)
)
) { dut =>
// Set numerators and denominators
dut.io.num.poke(Vec.Lit(numerators.map(_.S(coefWidth.W)): _*))
dut.io.den.poke(Vec.Lit(denominators.drop(1).map(_.S(coefWidth.W)): _*))

dut.io.output.ready.poke(true.B)

for ((d, e) <- (inputData.zip(expectedOutput))) {

dut.io.input.ready.expect(true.B)

// Push input sample
dut.io.input.bits.poke(d.S(inputWidth.W))
dut.io.input.valid.poke(true.B)

dut.clock.step(1)

dut.io.input.valid.poke(false.B)

for (i <- 0 until (numerators.length + denominators.length - 1)) {
dut.io.output.valid.expect(false.B)
dut.io.input.ready.expect(false.B)
dut.clock.step(1)
}

// Check output
val outputDecimalWidth = coefDecimalWidth + inputDecimalWidth
val output = dut.io.output.bits.peek().litValue.toFloat / math.pow(2, outputDecimalWidth).toFloat
val upperBound = e + precision
val lowerBound = e - precision

assert(output < upperBound)
assert(output > lowerBound)

dut.io.output.valid.expect(true.B)

dut.clock.step(1)
}
}
}
}
}

class RandomSignalTest extends AnyFlatSpec with IIRFilterBehavior with ChiselScalatestTester with Matchers {

def computeExpectedOutput(num: Seq[Double], den: Seq[Double], inputData: Seq[Double]): Seq[Double] = {
var outputMem = Seq.fill(den.length - 1)(0.0)
return for (i <- 0 until inputData.length) yield {
val outputSum = (for ((d, y) <- (den.drop(1).zip(outputMem))) yield {
d * y
}).reduce(_ + _)

val inputSum = (for (j <- i until math.max(i - num.length, -1) by -1) yield {
inputData(j) * num(i - j)
}).reduce(_ + _)

outputMem = outputMem.dropRight(1)
outputMem = (inputSum - outputSum) +: outputMem

outputMem(0)
}
}

behavior.of("IIRFilter")

Random.setSeed(53297103)

// Stable Butterworth high-pass filter
val num = Seq(0.89194287, -2.6758286, 2.6758286, -0.89194287)
val den = Seq(1.0, -2.77154144, 2.56843944, -0.79556205)

// Setup data width
val inputWidth = 16
val inputDecimalWidth = 12

val coefWidth = 32
val coefDecimalWidth = 28

val outputWidth = inputWidth + coefWidth + log2Ceil(num.length + den.length) + 1

// Generate random input data [-1., 1.]
val inputData = Seq.fill(100)(-1.0 + Random.nextDouble * 2.0)

// Compute expected outputs
val expectedOutput = computeExpectedOutput(num, den, inputData)

// Floating point to fixed point data
val numInt = for (n <- num) yield { (n * math.pow(2, coefDecimalWidth)).toInt }
val denInt = for (d <- den) yield { (d * math.pow(2, coefDecimalWidth)).toInt }
val inputDataInt = for (x <- inputData) yield (x * math.pow(2, inputDecimalWidth)).toInt

(it should behave).like(
testFilter(
inputWidth,
inputDecimalWidth,
coefWidth,
coefDecimalWidth,
outputWidth,
numInt,
denInt,
inputDataInt,
expectedOutput,
0.001
)
)
}
Loading

0 comments on commit 1f581c7

Please sign in to comment.