diff --git a/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaBlock.scala b/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaBlock.scala
index bbd11bb9ebfbbddc140f6c5cfe7d50e999ccdd28..dcfc27f2c3e8ccb43c17d79543c685e436026a12 100644
--- a/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaBlock.scala
+++ b/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaBlock.scala
@@ -5,20 +5,25 @@ package uk.ac.soton.ecs.can.core
 
 import chisel3._
 
-class ChaChaBlock extends MultiIOModule {
-  val muxIn = IO(Input(Bool()))
+class ChaChaBlock(val regBetweenRounds: Boolean = true) extends MultiIOModule { // regBetweenRounds: true puts a register between the column and diagonal rounds, false a plain wire
+  val roundLoop = IO(Input(Bool())) // true: feed afterRounds back into the column round; false: feed `in`
+  val initialState = IO(Input(Vec(16, UInt(32.W)))) // NOTE(review): not referenced anywhere else in this module - confirm intended use
   val in = IO(Input(Vec(16, UInt(32.W))))
   val out = IO(Output(Vec(16, UInt(32.W))))
 
-  val initialState = Reg(Vec(16, UInt(32.W)))
-  val doubleRound = Module(new ChaChaInnerBlock(regBetweenRounds = true))
-  val doubleRoundState = Reg(Vec(16, UInt(32.W)))
+  private val columnRound = Module(new ColumnRound)
+  private val diagonalRound = Module(new DiagonalRound)
+  private val betweenRounds = // carries the column round output to the diagonal round
+    if (regBetweenRounds) Reg(Vec(16, UInt(32.W)))
+    else Wire(Vec(16, UInt(32.W)))
+  private val afterRounds = Reg(Vec(16, UInt(32.W))) // diagonal round output, always registered
 
-  initialState := in
-  doubleRound.in := Mux(muxIn, initialState, doubleRoundState)
-  doubleRoundState := doubleRound.out
+  private val muxRoundLoop = Mux(roundLoop, afterRounds, in) // input selection for the column round
+  private val sumInRound = in.zip(afterRounds).map { case (i, r) => i + r } // element-wise sum of `in` and afterRounds
 
-  val addedState = doubleRoundState.zip(initialState).map(t => t._1 + t._2)
-
-  out := addedState
+  columnRound.in := muxRoundLoop
+  betweenRounds := columnRound.out
+  diagonalRound.in := betweenRounds
+  afterRounds := diagonalRound.out
+  out := sumInRound // `out` is driven straight from the sum above, with no further register
 }
diff --git a/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaInnerBlock.scala b/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaInnerBlock.scala
deleted file mode 100644
index 8c025db396b523160c37de4c74930794a0866a39..0000000000000000000000000000000000000000
--- a/src/main/scala/uk/ac/soton/ecs/can/core/ChaChaInnerBlock.scala
+++ /dev/null
@@ -1,25 +0,0 @@
-// SPDX-FileCopyrightText: 2021 Minyong Li <ml10g20@soton.ac.uk>
-// SPDX-License-Identifier: CERN-OHL-W-2.0
-
-package uk.ac.soton.ecs.can.core
-
-import chisel3._
-
-class ChaChaInnerBlock(regBetweenRounds: Boolean) extends MultiIOModule {
-  val in = IO(Input(Vec(16, UInt(32.W))))
-  val out = IO(Output(Vec(16, UInt(32.W))))
-
-  val betweenRounds =
-    if (regBetweenRounds)
-      Reg(Vec(16, UInt(32.W)))
-    else
-      Wire(Vec(16, UInt(32.W)))
-
-  val columnRound = Module(new ColumnRound)
-  val diagonalRound = Module(new DiagonalRound)
-
-  columnRound.in := in
-  betweenRounds := columnRound.out
-  diagonalRound.in := betweenRounds
-  out := diagonalRound.out
-}
diff --git a/src/test/scala/uk/ac/soton/ecs/can/core/ChaChaBlockTest.scala b/src/test/scala/uk/ac/soton/ecs/can/core/ChaChaBlockTest.scala
index cb7816765b670278f0c3e3082c4a5746aa2cf448..3438ecc1ef7bee5f0e111d1910d88abc36d03377 100644
--- a/src/test/scala/uk/ac/soton/ecs/can/core/ChaChaBlockTest.scala
+++ b/src/test/scala/uk/ac/soton/ecs/can/core/ChaChaBlockTest.scala
@@ -64,40 +64,43 @@ class ChaChaBlockTest extends FlatSpec with ChiselScalatestTester {
   )
 
   private def doTest(c: ChaChaBlock, testVector: Seq[(UInt, UInt)]) {
-    c.in.zip(testVector).foreach { t =>
-      t._1.poke(t._2._1)
+    c.in.zip(testVector).foreach { case (blockIn, (vectorIn, _)) => // first tuple component is the stimulus word
+      blockIn.poke(vectorIn)
     }
 
-    // Shift inputs into the initial state register
-    c.clock.step()
+    // Deassert roundLoop so the rounds take their input from the `in` port
+    c.roundLoop.poke(false.B)
 
-    // Select the initial state register as the input to the 2-round circuit
-    c.muxIn.poke(true.B)
+    // Clock the state through one column + diagonal pass: two cycles with the inter-round register, one without
+    c.clock.step(if (c.regBetweenRounds) 2 else 1)
 
-    // Shift the 2-rounded state to the round register
-    c.clock.step(2)
-
-    // Select the round register as the input to the 2-round circuit
-    c.muxIn.poke(false.B)
+    // Assert roundLoop so the rounds take their input from the block's round-output register
+    c.roundLoop.poke(true.B)
 
     // Depending on the ChaCha variant and the pipeline configuration, wait
     // for the correct time for the correct result. Note that one 2-round has
     // been processed in the above steps.
-    c.clock.step(19)
+    c.clock.step(if (c.regBetweenRounds) 19 else 9) // NOTE(review): 9 passes at 2 cycles each would be 18, not 19 - confirm the extra cycle is intended
 
-    c.out.zip(testVector).foreach { t =>
-      t._1.expect(t._2._2)
+    c.out.zip(testVector).foreach { case (blockOut, (_, vectorOut)) => // second tuple component is the expected word
+      blockOut.expect(vectorOut)
     }
   }
 
   behavior of "The ChaCha Block Function"
 
-  it should "compute RFC8439 2.3.2 test vector correctly" in
-    test(new ChaChaBlock)(doTest(_, rfc8439232TestVector))
+  it should "compute RFC8439 2.3.2 test vector correctly" in {
+    test(new ChaChaBlock(true))(doTest(_, rfc8439232TestVector)) // regBetweenRounds = true: register between the rounds
+    test(new ChaChaBlock(false))(doTest(_, rfc8439232TestVector)) // regBetweenRounds = false: wire between the rounds
+  }
 
-  it should "compute RFC8439 2.4.2 test vector (first block) correctly" in
-    test(new ChaChaBlock)(doTest(_, rfc8439242B1TestVector))
+  it should "compute RFC8439 2.4.2 test vector (first block) correctly" in {
+    test(new ChaChaBlock(true))(doTest(_, rfc8439242B1TestVector)) // regBetweenRounds = true: register between the rounds
+    test(new ChaChaBlock(false))(doTest(_, rfc8439242B1TestVector)) // regBetweenRounds = false: wire between the rounds
+  }
 
-  it should "compute RFC8439 2.4.2 test vector (second block) correctly" in
-    test(new ChaChaBlock)(doTest(_, rfc8439242B2TestVector))
+  it should "compute RFC8439 2.4.2 test vector (second block) correctly" in {
+    test(new ChaChaBlock(true))(doTest(_, rfc8439242B2TestVector)) // regBetweenRounds = true: register between the rounds
+    test(new ChaChaBlock(false))(doTest(_, rfc8439242B2TestVector)) // regBetweenRounds = false: wire between the rounds
+  }
 }