From 1a8b14a71c6ce20be9abb1704b7239ba47105fb8 Mon Sep 17 00:00:00 2001 From: Kevin Joly Date: Wed, 13 Dec 2023 14:17:56 +0100 Subject: [PATCH] Add IIR filter module A fixed point IIR filter module with pipelined accumulations. Signed-off-by: Kevin Joly --- README.md | 1 + src/main/scala/chisel/lib/iirfilter/README.md | 7 + .../chisel/lib/iirfilter/iirfilter.scala | 208 ++++++++++++ .../lib/iirfilter/RandomSignalTest.scala | 151 +++++++++ .../lib/iirfilter/SimpleIIRFilterTest.scala | 297 ++++++++++++++++++ 5 files changed, 664 insertions(+) create mode 100644 src/main/scala/chisel/lib/iirfilter/README.md create mode 100644 src/main/scala/chisel/lib/iirfilter/iirfilter.scala create mode 100644 src/test/scala/chisel/lib/iirfilter/RandomSignalTest.scala create mode 100644 src/test/scala/chisel/lib/iirfilter/SimpleIIRFilterTest.scala diff --git a/README.md b/README.md index 2edb157..b976f82 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ the community. This is the place to do it. | spi2wb | maven | Fabien Marteau | Drive a wishbone master bus with SPI | | dclib | internal | Guy Hutchison | Utility components for DecoupledIO interfaces | | ecc | internal | Guy Hutchison | Hamming Error-Correcting code modules | +| iir | internal | Kevin Joly | Infinite Impulse Response filter module | ### Using ip-contributions diff --git a/src/main/scala/chisel/lib/iirfilter/README.md b/src/main/scala/chisel/lib/iirfilter/README.md new file mode 100644 index 0000000..f0cb7d7 --- /dev/null +++ b/src/main/scala/chisel/lib/iirfilter/README.md @@ -0,0 +1,7 @@ +# IIR Filter + +Simple fixed point IIR filter with pipelined computation. + +Tests are run as follow: +```sbt "testOnly chisel.lib.iirfilter.SimpleIIRFilterTest -- -DwriteVcd=1"``` +```sbt "testOnly chisel.lib.iirfilter.RandomSignalTest -- -DwriteVcd=1"``` diff --git a/src/main/scala/chisel/lib/iirfilter/iirfilter.scala b/src/main/scala/chisel/lib/iirfilter/iirfilter.scala new file mode 100644 index 0000000..e3d6c93 --- /dev/null +++ b/src/main/scala/chisel/lib/iirfilter/iirfilter.scala @@ -0,0 +1,208 @@ +/* + * + * A fixed point IIR filter module. + * + * Author: Kevin Joly (kevin.joly@armadeus.com) + * + */ + +package chisel.lib.iirfilter + +import chisel3._ +import chisel3.experimental.ChiselEnum +import chisel3.util._ + +/* + * IIR filter module + * + * Apply filter on input samples passed by ready/valid handshake. Numerators + * and denominators are to be set prior to push any input sample. + * + * All the computations are done in fixed point. The user should manage the + * input decimal width by himself. Output width should be sufficient in order + * to not overflow (i.e. in case of overshoot). A minimum output width of + * inputWidht+coefWidth+log2Ceil(numeratorNum + denominatorNum) + 1 is + * requested. + * + */ +class IIRFilter( + inputWidth: Int, + coefWidth: Int, + coefDecimalWidth: Int, + outputWidth: Int, + numeratorNum: Int, + denominatorNum: Int) + extends Module { + val io = IO(new Bundle { + /* + * Input samples + */ + val input = Flipped(Decoupled(SInt(inputWidth.W))) + /* + * Numerator's coefficients b[0], b[1], ... + */ + val num = Input(Vec(numeratorNum, SInt(coefWidth.W))) + /* + * The first coefficient of the denominator should be omitted and should be a[0] == 1. + * a[1], a[2], ... + */ + val den = Input(Vec(denominatorNum, SInt(coefWidth.W))) + /* + * Filtered samples. Fixed point format is: + * (outputWidth-coefDecimalWidth).coefDecimalWidth + * Thus, output should be right shifted to the right of 'coefDecimalWidth' bits. + */ + val output = Decoupled(SInt(outputWidth.W)) + }) + + assert(coefWidth >= coefDecimalWidth) + + val minOutputWidth = inputWidth + coefWidth + log2Ceil(numeratorNum + denominatorNum) + 1 + assert(outputWidth >= minOutputWidth) + + val coefNum = RegInit(0.U(log2Ceil(math.max(numeratorNum, denominatorNum)).W)) + + object IIRFilterState extends ChiselEnum { + val Idle, ComputeNum, ComputeDen, Valid, StoreLast = Value + } + + val state = RegInit(IIRFilterState.Idle) + + switch(state) { + is(IIRFilterState.Idle) { + when(io.input.valid) { + state := IIRFilterState.ComputeNum + } + } + is(IIRFilterState.ComputeNum) { + when(coefNum === (numeratorNum - 1).U) { + state := IIRFilterState.ComputeDen + } + } + is(IIRFilterState.ComputeDen) { + when(coefNum === (denominatorNum - 1).U) { + state := IIRFilterState.StoreLast + } + } + is(IIRFilterState.StoreLast) { + state := IIRFilterState.Valid + } + is(IIRFilterState.Valid) { + when(io.output.ready) { + state := IIRFilterState.Idle + } + } + } + + when((state === IIRFilterState.Idle) && io.input.valid) { + coefNum := 1.U + }.elsewhen(state === IIRFilterState.ComputeNum) { + when(coefNum === (numeratorNum - 1).U) { + coefNum := 0.U + }.otherwise { + coefNum := coefNum + 1.U + } + }.elsewhen(state === IIRFilterState.ComputeDen) { + when(coefNum === (denominatorNum - 1).U) { + coefNum := 0.U + }.otherwise { + coefNum := coefNum + 1.U + } + }.otherwise { + coefNum := 0.U + } + + val inputReg = RegInit(0.S(inputWidth.W)) + val inputMem = Mem(numeratorNum - 1, SInt(inputWidth.W)) + val inputMemAddr = RegInit(0.U(math.max(log2Ceil(numeratorNum - 1), 1).W)) + val inputMemOut = Wire(SInt(inputWidth.W)) + val inputRdWr = inputMem(inputMemAddr) + + inputMemOut := DontCare + + when(state === IIRFilterState.StoreLast) { + inputRdWr := inputReg + }.elsewhen((state === IIRFilterState.Idle) && io.input.valid) { + inputReg := io.input.bits // Delayed write + inputMemOut := inputRdWr + }.otherwise { + inputMemOut := inputRdWr + } + + when((state === IIRFilterState.ComputeNum) && (coefNum < (numeratorNum - 1).U)) { + when(inputMemAddr === (numeratorNum - 2).U) { + inputMemAddr := 0.U + }.otherwise { + inputMemAddr := inputMemAddr + 1.U + } + } + + val outputMem = Mem(denominatorNum, SInt(outputWidth.W)) + val outputMemAddr = RegInit(0.U(math.max(log2Ceil(denominatorNum), 1).W)) + val outputMemOut = Wire(SInt(outputWidth.W)) + val outputRdWr = outputMem(outputMemAddr) + + outputMemOut := DontCare + + when((state === IIRFilterState.Valid) && (RegNext(state) === IIRFilterState.StoreLast)) { + outputRdWr := io.output.bits + }.otherwise { + outputMemOut := outputRdWr + } + + when((state === IIRFilterState.ComputeDen) && (coefNum < (denominatorNum - 1).U)) { + when(outputMemAddr === (denominatorNum - 1).U) { + outputMemAddr := 0.U + }.otherwise { + outputMemAddr := outputMemAddr + 1.U + } + } + + val inputSum = RegInit(0.S((inputWidth + coefWidth + log2Ceil(numeratorNum)).W)) + val outputSum = RegInit(0.S((outputWidth + coefWidth + log2Ceil(denominatorNum)).W)) + + val multOut = Wire(SInt((outputWidth + coefWidth).W)) + val multOutReg = RegInit(0.S((outputWidth + coefWidth).W)) + val multIn = Wire(SInt(outputWidth.W)) + val multCoef = Wire(SInt(coefWidth.W)) + + when((state === IIRFilterState.Idle) && io.input.valid) { + multOutReg := multOut + outputSum := 0.S + inputSum := 0.S + }.elsewhen(state === IIRFilterState.ComputeNum) { + multOutReg := multOut + inputSum := inputSum +& multOutReg + }.elsewhen(state === IIRFilterState.ComputeDen) { + multOutReg := multOut + + when (coefNum === 0.U) { + // Store numerator's last value + inputSum := inputSum +& multOutReg + }.otherwise { + outputSum := outputSum +& multOutReg + } + }.elsewhen(state === IIRFilterState.StoreLast) { + outputSum := outputSum +& multOutReg + } + + when(state === IIRFilterState.ComputeNum) { + multIn := inputMemOut + }.elsewhen(state === IIRFilterState.ComputeDen) { + multIn := outputMemOut + }.otherwise { + multIn := io.input.bits + } + + when(state === IIRFilterState.ComputeDen) { + multCoef := io.den(coefNum) + }.otherwise { + multCoef := io.num(coefNum) + } + + multOut := multIn * multCoef + + io.input.ready := state === IIRFilterState.Idle + io.output.valid := state === IIRFilterState.Valid + io.output.bits := inputSum -& (outputSum >> coefDecimalWidth) +} diff --git a/src/test/scala/chisel/lib/iirfilter/RandomSignalTest.scala b/src/test/scala/chisel/lib/iirfilter/RandomSignalTest.scala new file mode 100644 index 0000000..f338f62 --- /dev/null +++ b/src/test/scala/chisel/lib/iirfilter/RandomSignalTest.scala @@ -0,0 +1,151 @@ +/* + * Filter a random signal using IIRFilter module and compare with the expected output. + * + * See README.md for license details. + */ + +package chisel.lib.iirfilter + +import chisel3._ +import chisel3.experimental.VecLiterals._ +import chisel3.util.log2Ceil +import chiseltest._ + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import scala.util.Random + +trait IIRFilterBehavior { + + this: AnyFlatSpec with ChiselScalatestTester => + + def testFilter( + inputWidth: Int, + inputDecimalWidth: Int, + coefWidth: Int, + coefDecimalWidth: Int, + outputWidth: Int, + numerators: Seq[Int], + denominators: Seq[Int], + inputData: Seq[Int], + expectedOutput: Seq[Double], + precision: Double + ): Unit = { + + it should "work" in { + test( + new IIRFilter( + inputWidth = inputWidth, + coefWidth = coefWidth, + coefDecimalWidth = coefDecimalWidth, + outputWidth = outputWidth, + numeratorNum = numerators.length, + denominatorNum = (denominators.length - 1) + ) + ) { dut => + // Set numerators and denominators + dut.io.num.poke(Vec.Lit(numerators.map(_.S(coefWidth.W)): _*)) + dut.io.den.poke(Vec.Lit(denominators.drop(1).map(_.S(coefWidth.W)): _*)) + + dut.io.output.ready.poke(true.B) + + for ((d, e) <- (inputData.zip(expectedOutput))) { + + dut.io.input.ready.expect(true.B) + + // Push input sample + dut.io.input.bits.poke(d.S(inputWidth.W)) + dut.io.input.valid.poke(true.B) + + dut.clock.step(1) + + dut.io.input.valid.poke(false.B) + + for (i <- 0 until (numerators.length + denominators.length - 1)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + // Check output + val outputDecimalWidth = coefDecimalWidth + inputDecimalWidth + val output = dut.io.output.bits.peek().litValue.toFloat / math.pow(2, outputDecimalWidth).toFloat + val upperBound = e + precision + val lowerBound = e - precision + + assert(output < upperBound) + assert(output > lowerBound) + + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + } + } + } + } +} + +class RandomSignalTest extends AnyFlatSpec with IIRFilterBehavior with ChiselScalatestTester with Matchers { + + def computeExpectedOutput(num: Seq[Double], den: Seq[Double], inputData: Seq[Double]): Seq[Double] = { + var outputMem = Seq.fill(den.length - 1)(0.0) + return for (i <- 0 until inputData.length) yield { + val outputSum = (for ((d, y) <- (den.drop(1).zip(outputMem))) yield { + d * y + }).reduce(_ + _) + + val inputSum = (for (j <- i until math.max(i - num.length, -1) by -1) yield { + inputData(j) * num(i - j) + }).reduce(_ + _) + + outputMem = outputMem.dropRight(1) + outputMem = (inputSum - outputSum) +: outputMem + + outputMem(0) + } + } + + behavior.of("IIRFilter") + + Random.setSeed(53297103) + + // Stable Butterworth high-pass filter + val num = Seq(0.89194287, -2.6758286, 2.6758286, -0.89194287) + val den = Seq(1.0, -2.77154144, 2.56843944, -0.79556205) + + // Setup data width + val inputWidth = 16 + val inputDecimalWidth = 12 + + val coefWidth = 32 + val coefDecimalWidth = 28 + + val outputWidth = inputWidth + coefWidth + log2Ceil(num.length + den.length) + 1 + + // Generate random input data [-1., 1.] + val inputData = Seq.fill(100)(-1.0 + Random.nextDouble * 2.0) + + // Compute expected outputs + val expectedOutput = computeExpectedOutput(num, den, inputData) + + // Floating point to fixed point data + val numInt = for (n <- num) yield { (n * math.pow(2, coefDecimalWidth)).toInt } + val denInt = for (d <- den) yield { (d * math.pow(2, coefDecimalWidth)).toInt } + val inputDataInt = for (x <- inputData) yield (x * math.pow(2, inputDecimalWidth)).toInt + + (it should behave).like( + testFilter( + inputWidth, + inputDecimalWidth, + coefWidth, + coefDecimalWidth, + outputWidth, + numInt, + denInt, + inputDataInt, + expectedOutput, + 0.001 + ) + ) +} diff --git a/src/test/scala/chisel/lib/iirfilter/SimpleIIRFilterTest.scala b/src/test/scala/chisel/lib/iirfilter/SimpleIIRFilterTest.scala new file mode 100644 index 0000000..759279c --- /dev/null +++ b/src/test/scala/chisel/lib/iirfilter/SimpleIIRFilterTest.scala @@ -0,0 +1,297 @@ +/* + * A very simple test collection for IIRFilter module. + * + * See README.md for license details. + */ + +package chisel.lib.iirfilter + +import chisel3._ +import chisel3.experimental.VecLiterals._ +import chisel3.util.log2Ceil +import chiseltest._ + +import org.scalatest.flatspec.AnyFlatSpec + +class IIRFilterNumeratorTest extends AnyFlatSpec with ChiselScalatestTester { + "IIRFilter numerator" should "work" in { + + val inputWidth = 4 + val coefWidth = 3 + val coefDecimalWidth = 0 + val num = Seq(2, 1, 0, 3) + val den = Seq(0, 0) + val outputWidth = inputWidth + coefWidth + log2Ceil(num.length + den.length) + 1 + + test( + new IIRFilter( + inputWidth = inputWidth, + coefWidth = coefWidth, + coefDecimalWidth = coefDecimalWidth, + outputWidth = outputWidth, + numeratorNum = num.length, + denominatorNum = den.length + ) + ) { dut => + dut.io.num.poke(Vec.Lit(num.map(_.S(coefWidth.W)): _*)) + dut.io.den.poke(Vec.Lit(den.map(_.S(coefWidth.W)): _*)) + + dut.io.output.ready.poke(true.B) + + // Sample 1: Write 1. on input port + dut.io.input.bits.poke(1.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + dut.io.input.ready.expect(false.B) + + for (i <- 0 until (num.length + den.length)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(2.S) + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + + // Sample 2: Write 1. on input port + dut.io.input.bits.poke(1.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + + for (i <- 0 until (num.length + den.length)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(3.S) + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + + // Sample 3: Write 0. on input port + dut.io.input.bits.poke(0.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + + for (i <- 0 until (num.length + den.length)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(1.S) + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + + // Sample 4: Write 0. on input port + dut.io.input.bits.poke(0.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + + for (i <- 0 until (num.length + den.length)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(3.S) + dut.io.output.valid.expect(true.B) + } + } +} + +class IIRFilterDenominatorTest extends AnyFlatSpec with ChiselScalatestTester { + "IIRFilter denominator" should "work" in { + + val inputWidth = 4 + val coefWidth = 3 + val coefDecimalWidth = 0 + val num = Seq(1, 0) + val den = Seq(2, 3, 1) + val outputWidth = inputWidth + coefWidth + log2Ceil(num.length + den.length) + 1 + + test( + new IIRFilter( + inputWidth = inputWidth, + coefWidth = coefWidth, + coefDecimalWidth = coefDecimalWidth, + outputWidth = outputWidth, + numeratorNum = num.length, + denominatorNum = den.length + ) + ) { dut => + dut.io.num.poke(Vec.Lit(num.map(_.S(coefWidth.W)): _*)) + dut.io.den.poke(Vec.Lit(den.map(_.S(coefWidth.W)): _*)) + + dut.io.output.ready.poke(true.B) + + // Sample 1: Write 1. on input port + dut.io.input.bits.poke(1.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + dut.io.input.ready.expect(false.B) + + for (i <- 0 until (num.length + den.length)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(1.S) + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + + // Sample 2: Write 1. on input port + dut.io.input.bits.poke(1.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + + for (i <- 0 until (num.length + den.length)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(-1.S) + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + + // Sample 3: Write 1. on input port + dut.io.input.bits.poke(1.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + + for (i <- 0 until (num.length + den.length)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(0.S) + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + + // Sample 4: Write 1. on input port + dut.io.input.bits.poke(1.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + + for (i <- 0 until (num.length + den.length)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(3.S) + dut.io.output.valid.expect(true.B) + + dut.clock.step(1) + + // Sample 5: Write 0. on input port + dut.io.input.bits.poke(0.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + + for (i <- 0 until (num.length + den.length)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.bits.expect(-5.S) + dut.io.output.valid.expect(true.B) + } + } +} + +class IIRFilterReadyTest extends AnyFlatSpec with ChiselScalatestTester { + "IIRFilter" should "work" in { + + val inputWidth = 4 + val coefWidth = 3 + val coefDecimalWidth = 0 + val num = Seq(1, 2, 0) + val den = Seq(0, 0) + val outputWidth = inputWidth + coefWidth + log2Ceil(num.length + den.length) + 1 + + test( + new IIRFilter( + inputWidth = inputWidth, + coefWidth = coefWidth, + coefDecimalWidth = coefDecimalWidth, + outputWidth = outputWidth, + numeratorNum = num.length, + denominatorNum = den.length + ) + ) { dut => + dut.io.num.poke(Vec.Lit(num.map(_.S(coefWidth.W)): _*)) + dut.io.den.poke(Vec.Lit(den.map(_.S(coefWidth.W)): _*)) + + dut.io.output.ready.poke(false.B) + + // Sample 1: Write 1. on input port + dut.io.input.bits.poke(1.S) + dut.io.input.valid.poke(true.B) + dut.io.input.ready.expect(true.B) + dut.io.output.valid.expect(false.B) + dut.clock.step(1) + dut.io.input.valid.poke(false.B) + dut.io.input.ready.expect(false.B) + + for (i <- 0 until (num.length + den.length)) { + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + val extraClockCycles = 10 + for (i <- 0 until extraClockCycles) { + dut.io.output.valid.expect(true.B) + dut.io.input.ready.expect(false.B) + dut.clock.step(1) + } + + dut.io.output.ready.poke(true.B) + + dut.clock.step(1) + + dut.io.output.bits.expect(1.S) + dut.io.output.valid.expect(false.B) + dut.io.input.ready.expect(true.B) + } + } +}