From d0ba0c470e4525365900be5161c84af27405cbb7 Mon Sep 17 00:00:00 2001
From: Fanis Baikas <fan.baikas@gmail.com>
Date: Tue, 10 Sep 2024 16:02:48 +0100
Subject: [PATCH] Added inference statistics at the end of classification

---
 .../C_Cxx/fast_knn_demo/fast_knn_demo.cpp     | 226 ++++++++++++++++--
 1 file changed, 201 insertions(+), 25 deletions(-)

diff --git a/nanosoc_board/C_Cxx/fast_knn_demo/fast_knn_demo.cpp b/nanosoc_board/C_Cxx/fast_knn_demo/fast_knn_demo.cpp
index 87a9d94..b3140f3 100644
--- a/nanosoc_board/C_Cxx/fast_knn_demo/fast_knn_demo.cpp
+++ b/nanosoc_board/C_Cxx/fast_knn_demo/fast_knn_demo.cpp
@@ -19,6 +19,7 @@
 
 #include "pico/multicore.h"
 #include <stdio.h>
+#include <stdarg.h>
 #include <cmath>
 
 extern "C"{
@@ -56,8 +57,8 @@ std::string console[CONSOLE_FIFO_DEPTH];
 #define DATA_REQ_GPIO_PIN 2
 #define DATA_SENT_GPIO_PIN 3
 
-#define NUM_UNLABELLED_IMGS 36
-#define NUM_LABELLED_IMGS 150
+#define NUM_UNLABELLED_IMGS 27
+#define NUM_LABELLED_IMGS 160
 
 #define NUM_OF_CLASSES 10
 #define NUM_LABELLED_IMAGES_PER_CLASS  NUM_LABELLED_IMGS/NUM_OF_CLASSES
@@ -104,6 +105,25 @@ std::string labelled_image_dirs[10] = {
 // Path to unlabelled images
 std::string unlabelled_image_dir = "0:/fmnist_data/test";
 
+// Number of correct predictions in the current run; numerator of the top-1 accuracy
+uint16_t correct_pred_count = 0;
+
+// Per-image time measurements in microseconds (one slot per unlabelled image), kept for statistics calculation
+int64_t compute_time_arr[NUM_UNLABELLED_IMGS];
+int64_t data_transfer_time_arr[NUM_UNLABELLED_IMGS];
+int64_t results_read_time_arr[NUM_UNLABELLED_IMGS];
+int64_t register_write_time_arr[NUM_UNLABELLED_IMGS];
+int64_t sorting_time_arr[NUM_UNLABELLED_IMGS];
+int64_t total_time_arr[NUM_UNLABELLED_IMGS];
+// Summary statistics filled in once the last unlabelled image is classified; mean times are in microseconds
+float acc = 0;
+int64_t compute_time_mean = 0; // NOTE(review): the *_mean values are built with += and do not appear to be cleared on restart - verify
+int64_t data_transfer_time_mean = 0;
+int64_t results_read_time_mean = 0;
+int64_t register_write_time_mean = 0;
+int64_t sorting_time_mean = 0;
+int64_t total_time_mean = 0;
+
 std::string classes[] = {
     "T-Shirt",
     "Trousers",
@@ -138,6 +158,7 @@ screen_state last_screen_state=POWER;
 bool DEMO_RUN = false;
 bool KNN_done = false;
 bool KNN_start = false;
+bool DEMO_done = false;
 bool nanosoc_console = false;
 bool nanosoc_console_draw = false;
 i2c_inst_t *i2c = i2c1;
@@ -161,6 +182,12 @@ static sd_card_t sd_cards[] = {{
 int img_x_grid[] = {3, 83, 163};
 int img_y_grid[] = {30, 118, 206};
 
+void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data, int64_t *compute_time, int64_t *data_transfer_time);
+int compare_indices(void *arr, const void *a, const void *b);
+void sort_indices(uint32_t *array, int *indices, size_t size);
+uint8_t predict_label(int *sorting_indices, int k);
+uint8_t find_max_index(uint8_t *array, size_t size);
+
 // Button IRQ (for controlling screen state)
 void BUTTON_irq(uint gpio, uint32_t events) {
     if(gpio==BUTTON_A){
@@ -351,7 +378,7 @@ void DEMO_ROUTINE(){
             result_location.y = img_y_grid[ytext]+60;
 
             // KNN finished, plot the result    
-            if (predicted_label == test_imgs_labels[current_unlab_img-1])
+            if (predicted_label == test_imgs_labels[current_unlab_img])
                 graphics.set_pen(GREEN);
             else
                 graphics.set_pen(RED);
@@ -366,10 +393,121 @@ void DEMO_ROUTINE(){
                 }
             }
             printf("printed result\n");
+
+            current_unlab_img++;
+
+            if(current_unlab_img >= NUM_UNLABELLED_IMGS) {
+                current_unlab_img = 0;
+
+                acc = (float) correct_pred_count / NUM_UNLABELLED_IMGS;
+    
+                for (int k = 0; k < NUM_UNLABELLED_IMGS; k++) {
+                    compute_time_mean += compute_time_arr[k];
+                    data_transfer_time_mean += data_transfer_time_arr[k];
+                    results_read_time_mean += results_read_time_arr[k];
+                    register_write_time_mean += register_write_time_arr[k];
+                    sorting_time_mean += sorting_time_arr[k];
+                    total_time_mean += total_time_arr[k];
+                }
+
+                compute_time_mean = compute_time_mean / NUM_UNLABELLED_IMGS;
+                data_transfer_time_mean = data_transfer_time_mean / NUM_UNLABELLED_IMGS;
+                results_read_time_mean = results_read_time_mean / NUM_UNLABELLED_IMGS;
+                register_write_time_mean = register_write_time_mean / NUM_UNLABELLED_IMGS;
+                sorting_time_mean = sorting_time_mean / NUM_UNLABELLED_IMGS;
+                total_time_mean = total_time_mean / NUM_UNLABELLED_IMGS;
+
+                DEMO_done = true;
+                DEMO_RUN = false;
+            }
+
             KNN_done=false;
         }
+        if (DEMO_done){
+            // Clear text area
+            graphics.set_pen(BG);
+            Rect blank(0, 18, 240, 284);
+            graphics.rectangle(blank);
+
+            // Edit footer text
+            graphics.set_pen(WHITE);
+            Rect blank_footer(0, 302, 240, 18);
+            graphics.rectangle(blank_footer);
+            graphics.set_pen(BG);
+            graphics.text("               Restart   ", footer_location, 600);
+
+            console_text_location.x = 5;
+            console_text_location.y = 30;
+            
+            graphics.set_pen(WHITE);
+
+            printf("\n------------ Inference statistics ------------\n");
+            char text[100];
+            sprintf(text, "%-20s %lu/%u", "No. correct pred.", correct_pred_count, NUM_UNLABELLED_IMGS);
+            printf("%s\n", text);
+            graphics.text(text, console_text_location, 240);
+            console_text_location.y += 16;
+
+            sprintf(text, "%-20s %0.2f", "Top-1 acc", acc);
+            printf("%s\n\n", text);
+            graphics.text(text, console_text_location, 600);
+            console_text_location.y += 32;
+
+            graphics.text("Mean time measurements", console_text_location, 600);
+            printf("Mean time measurements: \n");
+            console_text_location.y += 32;
+
+            sprintf(text, "%s %8lld %s", "Data trans.  ", data_transfer_time_mean, "us");
+            printf("%s\n", text);
+            graphics.text(text, console_text_location, 600);
+            console_text_location.y += 16;
+
+            sprintf(text, "%s %8lld %s", "Acc. compute ", compute_time_mean, "us");
+            printf("%s\n", text);
+            graphics.text(text, console_text_location, 600);
+            console_text_location.y += 16;
+
+            sprintf(text, "%s %8lld %s", "Results read ", results_read_time_mean, "us");
+            printf("%s\n", text);
+            graphics.text(text, console_text_location, 600);
+            console_text_location.y += 16;
+
+            sprintf(text, "%s %8lld %s", "Reg. write   ", register_write_time_mean, "us");
+            printf("%s\n", text);
+            graphics.text(text, console_text_location, 600);
+            console_text_location.y += 16;
+
+            sprintf(text, "%s %8lld %s", "Dist. sorting", sorting_time_mean, "us");
+            printf("%s\n", text);
+            graphics.text(text, console_text_location, 600);
+            console_text_location.y += 32;
+        
+            sprintf(text, "%s %8lld us", "Total time   ", total_time_mean);
+            printf("%s\n", text);
+            graphics.text(text, console_text_location, 600);
+        }
     }
     if(BUTTON_X_pressed){
+        if (DEMO_done) {
+            // Clear text area
+            graphics.set_pen(BG);
+            Rect blank(0, 18, 240, 284);
+            graphics.rectangle(blank);
+
+            // Clear footer text
+            graphics.set_pen(WHITE);
+            Rect blank_footer(0, 302, 240, 18);
+            graphics.rectangle(blank_footer);
+
+            correct_pred_count = 0;
+            x = 0;
+            y = 0;
+            xtext = 0;
+            ytext = 0;
+
+            DEMO_done = false;
+        }
+
         DEMO_RUN=true;
         BUTTON_X_pressed=false;
     }
@@ -517,12 +655,6 @@ void core1_entry(){
     }
 }
 
-void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data);
-int compare_indices(void *arr, const void *a, const void *b);
-void sort_indices(uint32_t *array, int *indices, size_t size);
-uint8_t predict_label(int *sorting_indices, int k);
-uint8_t find_max_index(uint8_t *array, size_t size);
-
 int main() {
     // Initialize GPIO pins
     gpio_init(DATA_REQ_GPIO_PIN);
@@ -715,21 +847,36 @@ int main() {
 
     // Main loop for Core 0
     while(1){
-        if (DEMO_RUN && !KNN_done) {                
+        if (DEMO_RUN && !KNN_done) {
+            int64_t compute_time = 0;
+            int64_t data_transfer_time = 0;
+            int64_t results_read_time = 0;
+            int64_t register_write_time = 0;
+            int64_t sorting_time = 0;
+
+            int64_t total_time = 0;
+
             // Reset accelerator
             nanosoc_download_buffer(pio, sm, &logic_zero[0], sw_reset_reg_addr, 4);
 
+            absolute_time_t total_time_start = get_absolute_time();
+
             // Send unlabelled image i
-            send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &unlabelled_buffer[current_unlab_img][0]);
+            send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &unlabelled_buffer[current_unlab_img][0], &compute_time, &data_transfer_time);
 
             KNN_start=true;
 
             for (int j = 0; j < NUM_LABELLED_IMGS; j++) {
                 // Send labelled image j
-                send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &labelled_buffer[j][0]);
+                send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &labelled_buffer[j][0], &compute_time, &data_transfer_time);
 
+                absolute_time_t register_write_start = get_absolute_time();
                 // Set priming mode to 0
                 nanosoc_download_buffer(pio, sm, &logic_zero[0], priming_mode_reg_addr, 4);
+                absolute_time_t register_write_end = get_absolute_time();
+                register_write_time += absolute_time_diff_us(register_write_start, register_write_end);
+
+                absolute_time_t results_read_start = get_absolute_time();
 
                 uint32_t unlab_img_dot_prod = nanosoc_read_reg32(pio, sm, unlab_img_dot_prod_addr);
                 uint32_t lab_img_dot_prod = nanosoc_read_reg32(pio, sm, lab_img_dot_prod_addr);
@@ -744,28 +891,48 @@ int main() {
                 // Compute distance between unlabelled image i and labelled image j
                 dist[j] = unlab_img_dot_prod + lab_img_dot_prod - 2*comb_dot_prod;
 
+                absolute_time_t results_read_end = get_absolute_time();
+                results_read_time += absolute_time_diff_us(results_read_start, results_read_end);
+
+                register_write_start = get_absolute_time();
+
                 // Clear lab_img_dot_prod and comb_dot_prod_reg
                 nanosoc_download_buffer(pio, sm, &logic_zero[0], lab_img_dot_prod_addr, 4);
                 nanosoc_download_buffer(pio, sm, &logic_zero[0], comb_dot_prod_addr, 4);
+                
+                register_write_end = get_absolute_time();
+                register_write_time += absolute_time_diff_us(register_write_start, register_write_end);
             }
 
-
+            absolute_time_t sorting_time_start = get_absolute_time();
             sort_indices(dist, indices, NUM_LABELLED_IMGS);
-            printf("kNN labels: ");
-            for (size_t i = 0; i < kNN_k; i++) {
-                printf("%d ", labelled_imgs_labels[indices[i]]);
-            }
-            printf("\n");
+            absolute_time_t sorting_time_end = get_absolute_time();
+            sorting_time = absolute_time_diff_us(sorting_time_start, sorting_time_end);
+
+            // printf("kNN labels: ");
+            // for (size_t i = 0; i < kNN_k; i++) {
+            //     printf("%d ", labelled_imgs_labels[indices[i]]);
+            // }
+            // printf("\n");
 
             predicted_label = predict_label(&indices[0], kNN_k);
+
+            absolute_time_t total_time_end = get_absolute_time();
+            total_time = absolute_time_diff_us(total_time_start, total_time_end);
+
             printf("Unlabelled image %s, Predicted label: %s\n", unlabelled_files[i].c_str(), classes[predicted_label].c_str());
 
-            KNN_done = true;
-            current_unlab_img++;
+            if (predicted_label == test_imgs_labels[current_unlab_img])
+                correct_pred_count++;
 
-            if(current_unlab_img >= NUM_UNLABELLED_IMGS) {
-                current_unlab_img=0;
-            }
+            compute_time_arr[current_unlab_img] = compute_time;
+            data_transfer_time_arr[current_unlab_img] = data_transfer_time;
+            results_read_time_arr[current_unlab_img] = results_read_time;
+            register_write_time_arr[current_unlab_img] = register_write_time;
+            sorting_time_arr[current_unlab_img] = sorting_time;
+            total_time_arr[current_unlab_img] = total_time;
+
+            KNN_done = true;
         }
 
         // Always read in and store STDIO from nanosoc
@@ -794,7 +961,7 @@ int main() {
     return 0;
 }
 
-void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data) {
+void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data, int64_t *compute_time, int64_t *data_transfer_time) {
     // Wait for DATA_REQ signal
     // printf("Waiting for DATA_REQ signal...\n");
     char buf[128];
@@ -806,14 +973,20 @@ void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr,
     };
     // printf("DATA_REQ signal received. Sending data...\n");
 
+    absolute_time_t data_transfer_start = get_absolute_time();
+
     // Send data
     nanosoc_download_buffer(*pio, *sm, data, *nanosoc_img_buffer_addr, 784);
 
+    absolute_time_t data_transfer_end = get_absolute_time();
+    *data_transfer_time += absolute_time_diff_us(data_transfer_start, data_transfer_end);
+
     // Pull DATA_SENT_GPIO_PIN HIGH
     gpio_put(DATA_SENT_GPIO_PIN, true);
     // printf("DATA_SENT signal pulled HIGH.\n");
     // printf("DATA_SENT_GPIO_PIN level: %d\n", gpio_get_out_level(DATA_SENT_GPIO_PIN));
 
+    absolute_time_t compute_start = get_absolute_time();
 
     // Wait for DATA_REQ signal to be pulled LOW 
     // printf("Waiting for DATA_REQ signal to be pulled LOW...\n");
@@ -823,6 +996,9 @@ void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr,
     };
     // printf("DATA_REQ signal pulled LOW.\n");
 
+    absolute_time_t compute_end = get_absolute_time();
+    *compute_time += absolute_time_diff_us(compute_start, compute_end);
+
     // Pull DATA_SENT_GPIO_PIN LOW
     gpio_put(DATA_SENT_GPIO_PIN, false);
     // printf("DATA_SENT signal pulled LOW.\n");
@@ -870,4 +1046,4 @@ uint8_t find_max_index(uint8_t *array, size_t size) {
     }
 
     return max_index;
-}
+}
\ No newline at end of file
-- 
GitLab