diff --git a/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaBlock.scala b/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaBlock.scala index bbd11bb9ebfbbddc140f6c5cfe7d50e999ccdd28..dcfc27f2c3e8ccb43c17d79543c685e436026a12 100644 --- a/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaBlock.scala +++ b/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaBlock.scala @@ -5,20 +5,25 @@ package uk.ac.soton.ecs.can.core import chisel3._ -class ChaChaBlock extends MultiIOModule { - val muxIn = IO(Input(Bool())) +class ChaChaBlock(val regBetweenRounds: Boolean = true) extends MultiIOModule { + val roundLoop = IO(Input(Bool())) + val initialState = IO(Input(Vec(16, UInt(32.W)))) val in = IO(Input(Vec(16, UInt(32.W)))) val out = IO(Output(Vec(16, UInt(32.W)))) - val initialState = Reg(Vec(16, UInt(32.W))) - val doubleRound = Module(new ChaChaInnerBlock(regBetweenRounds = true)) - val doubleRoundState = Reg(Vec(16, UInt(32.W))) + private val columnRound = Module(new ColumnRound) + private val diagonalRound = Module(new DiagonalRound) + private val betweenRounds = + if (regBetweenRounds) Reg(Vec(16, UInt(32.W))) + else Wire(Vec(16, UInt(32.W))) + private val afterRounds = Reg(Vec(16, UInt(32.W))) - initialState := in - doubleRound.in := Mux(muxIn, initialState, doubleRoundState) - doubleRoundState := doubleRound.out + private val muxRoundLoop = Mux(roundLoop, afterRounds, in) + private val sumInRound = in.zip(afterRounds).map { case (i, r) => i + r } - val addedState = doubleRoundState.zip(initialState).map(t => t._1 + t._2) - - out := addedState + columnRound.in := muxRoundLoop + betweenRounds := columnRound.out + diagonalRound.in := betweenRounds + afterRounds := diagonalRound.out + out := sumInRound } diff --git a/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaInnerBlock.scala b/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaInnerBlock.scala deleted file mode 100644 index 8c025db396b523160c37de4c74930794a0866a39..0000000000000000000000000000000000000000 --- a/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaInnerBlock.scala +++ /dev/null @@ -1,25 +0,0 @@ -// SPDX-FileCopyrightText: 2021 Minyong Li <ml10g20@soton.ac.uk> -// SPDX-License-Identifier: CERN-OHL-W-2.0 - -package uk.ac.soton.ecs.can.core - -import chisel3._ - -class ChaChaInnerBlock(regBetweenRounds: Boolean) extends MultiIOModule { - val in = IO(Input(Vec(16, UInt(32.W)))) - val out = IO(Output(Vec(16, UInt(32.W)))) - - val betweenRounds = - if (regBetweenRounds) - Reg(Vec(16, UInt(32.W))) - else - Wire(Vec(16, UInt(32.W))) - - val columnRound = Module(new ColumnRound) - val diagonalRound = Module(new DiagonalRound) - - columnRound.in := in - betweenRounds := columnRound.out - diagonalRound.in := betweenRounds - out := diagonalRound.out -} diff --git a/src/test/scala/uk/ac/soton/ecs/can/core/ChaChaBlockTest.scala b/src/test/scala/uk/ac/soton/ecs/can/core/ChaChaBlockTest.scala index cb7816765b670278f0c3e3082c4a5746aa2cf448..3438ecc1ef7bee5f0e111d1910d88abc36d03377 100644 --- a/src/test/scala/uk/ac/soton/ecs/can/core/ChaChaBlockTest.scala +++ b/src/test/scala/uk/ac/soton/ecs/can/core/ChaChaBlockTest.scala @@ -64,40 +64,43 @@ class ChaChaBlockTest extends FlatSpec with ChiselScalatestTester { ) private def doTest(c: ChaChaBlock, testVector: Seq[(UInt, UInt)]) { - c.in.zip(testVector).foreach { t => - t._1.poke(t._2._1) + c.in.zip(testVector).foreach { case (blockIn, (vectorIn, _)) => + blockIn.poke(vectorIn) } - // Shift inputs into the initial state register - c.clock.step() + // Select the input port as the input to the rounds + c.roundLoop.poke(false.B) - // Select the initial state register as the input to the 2-round circuit - c.muxIn.poke(true.B) + // Shift the state through the rounds + c.clock.step(if (c.regBetweenRounds) 2 else 1) - // Shift the 2-rounded state to the round register - c.clock.step(2) - - // Select the round register as the input to the 2-round circuit - c.muxIn.poke(false.B) + // Select the round register as the input to the rounds + c.roundLoop.poke(true.B) // Depending on the ChaCha variant and the pipeline configuration, wait // for the correct time for the correct result. Note that one 2-round has // been processed in the above steps. - c.clock.step(19) + c.clock.step(if (c.regBetweenRounds) 19 else 9) - c.out.zip(testVector).foreach { t => - t._1.expect(t._2._2) + c.out.zip(testVector).foreach { case (blockOut, (_, vectorOut)) => + blockOut.expect(vectorOut) } } behavior of "The ChaCha Block Function" - it should "compute RFC8439 2.3.2 test vector correctly" in - test(new ChaChaBlock)(doTest(_, rfc8439232TestVector)) + it should "compute RFC8439 2.3.2 test vector correctly" in { + test(new ChaChaBlock(true))(doTest(_, rfc8439232TestVector)) + test(new ChaChaBlock(false))(doTest(_, rfc8439232TestVector)) + } - it should "compute RFC8439 2.4.2 test vector (first block) correctly" in - test(new ChaChaBlock)(doTest(_, rfc8439242B1TestVector)) + it should "compute RFC8439 2.4.2 test vector (first block) correctly" in { + test(new ChaChaBlock(true))(doTest(_, rfc8439242B1TestVector)) + test(new ChaChaBlock(false))(doTest(_, rfc8439242B1TestVector)) + } - it should "compute RFC8439 2.4.2 test vector (second block) correctly" in - test(new ChaChaBlock)(doTest(_, rfc8439242B2TestVector)) + it should "compute RFC8439 2.4.2 test vector (second block) correctly" in { + test(new ChaChaBlock(true))(doTest(_, rfc8439242B2TestVector)) + test(new ChaChaBlock(false))(doTest(_, rfc8439242B2TestVector)) + } }