From d0ba0c470e4525365900be5161c84af27405cbb7 Mon Sep 17 00:00:00 2001 From: Fanis Baikas <fan.baikas@gmail.com> Date: Tue, 10 Sep 2024 16:02:48 +0100 Subject: [PATCH] Added inference statistics at the end of classification --- .../C_Cxx/fast_knn_demo/fast_knn_demo.cpp | 226 ++++++++++++++++-- 1 file changed, 201 insertions(+), 25 deletions(-) diff --git a/nanosoc_board/C_Cxx/fast_knn_demo/fast_knn_demo.cpp b/nanosoc_board/C_Cxx/fast_knn_demo/fast_knn_demo.cpp index 87a9d94..b3140f3 100644 --- a/nanosoc_board/C_Cxx/fast_knn_demo/fast_knn_demo.cpp +++ b/nanosoc_board/C_Cxx/fast_knn_demo/fast_knn_demo.cpp @@ -19,6 +19,7 @@ #include "pico/multicore.h" #include <stdio.h> +#include <stdarg.h> #include <cmath> extern "C"{ @@ -56,8 +57,8 @@ std::string console[CONSOLE_FIFO_DEPTH]; #define DATA_REQ_GPIO_PIN 2 #define DATA_SENT_GPIO_PIN 3 -#define NUM_UNLABELLED_IMGS 36 -#define NUM_LABELLED_IMGS 150 +#define NUM_UNLABELLED_IMGS 27 +#define NUM_LABELLED_IMGS 160 #define NUM_OF_CLASSES 10 #define NUM_LABELLED_IMAGES_PER_CLASS NUM_LABELLED_IMGS/NUM_OF_CLASSES @@ -104,6 +105,25 @@ std::string labelled_image_dirs[10] = { // Path to unlabelled images std::string unlabelled_image_dir = "0:/fmnist_data/test"; +// Counter for the number of correct predictions used for top-1 acc computation +uint16_t correct_pred_count = 0; + +// Array buffer to hold time measurements for statistics calcumation +int64_t compute_time_arr[NUM_UNLABELLED_IMGS]; +int64_t data_transfer_time_arr[NUM_UNLABELLED_IMGS]; +int64_t results_read_time_arr[NUM_UNLABELLED_IMGS]; +int64_t register_write_time_arr[NUM_UNLABELLED_IMGS]; +int64_t sorting_time_arr[NUM_UNLABELLED_IMGS]; +int64_t total_time_arr[NUM_UNLABELLED_IMGS]; + +float acc = 0; +int64_t compute_time_mean = 0; +int64_t data_transfer_time_mean = 0; +int64_t results_read_time_mean = 0; +int64_t register_write_time_mean = 0; +int64_t sorting_time_mean = 0; +int64_t total_time_mean = 0; + std::string classes[] = { "T-Shirt", "Trousers", @@ -138,6 +158,7 @@ screen_state last_screen_state=POWER; bool DEMO_RUN = false; bool KNN_done = false; bool KNN_start = false; +bool DEMO_done = false; bool nanosoc_console = false; bool nanosoc_console_draw = false; i2c_inst_t *i2c = i2c1; @@ -161,6 +182,12 @@ static sd_card_t sd_cards[] = {{ int img_x_grid[] = {3, 83, 163}; int img_y_grid[] = {30, 118, 206}; +void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data, int64_t *compute_time, int64_t *data_transfer_time); +int compare_indices(void *arr, const void *a, const void *b); +void sort_indices(uint32_t *array, int *indices, size_t size); +uint8_t predict_label(int *sorting_indices, int k); +uint8_t find_max_index(uint8_t *array, size_t size); + // Button IRQ (for controlling screen state) void BUTTON_irq(uint gpio, uint32_t events) { if(gpio==BUTTON_A){ @@ -351,7 +378,7 @@ void DEMO_ROUTINE(){ result_location.y = img_y_grid[ytext]+60; // KNN finished, plot the result - if (predicted_label == test_imgs_labels[current_unlab_img-1]) + if (predicted_label == test_imgs_labels[current_unlab_img]) graphics.set_pen(GREEN); else graphics.set_pen(RED); @@ -366,10 +393,121 @@ void DEMO_ROUTINE(){ } } printf("printed result\n"); + + current_unlab_img++; + + if(current_unlab_img >= NUM_UNLABELLED_IMGS) { + current_unlab_img = 0; + + acc = (float) correct_pred_count / NUM_UNLABELLED_IMGS; + + for (int k = 0; k < NUM_UNLABELLED_IMGS; k++) { + compute_time_mean += compute_time_arr[k]; + data_transfer_time_mean += data_transfer_time_arr[k]; + results_read_time_mean += results_read_time_arr[k]; + register_write_time_mean += register_write_time_arr[k]; + sorting_time_mean += sorting_time_arr[k]; + total_time_mean += total_time_arr[k]; + } + + compute_time_mean = compute_time_mean / NUM_UNLABELLED_IMGS; + data_transfer_time_mean = data_transfer_time_mean / NUM_UNLABELLED_IMGS; + results_read_time_mean = results_read_time_mean / NUM_UNLABELLED_IMGS; + register_write_time_mean = register_write_time_mean / NUM_UNLABELLED_IMGS; + sorting_time_mean = sorting_time_mean / NUM_UNLABELLED_IMGS; + total_time_mean = total_time_mean / NUM_UNLABELLED_IMGS; + + DEMO_done = true; + DEMO_RUN = false; + } + KNN_done=false; } + if (DEMO_done){ + // Clear text area + graphics.set_pen(BG); + Rect blank(0, 18, 240, 284); + graphics.rectangle(blank); + + // Edit footer text + graphics.set_pen(WHITE); + Rect blank_footer(0, 302, 240, 18); + graphics.rectangle(blank_footer); + graphics.set_pen(BG); + graphics.text(" Restart ", footer_location, 600); + + console_text_location.x = 5; + console_text_location.y = 30; + + graphics.set_pen(WHITE); + + printf("\n------------ Inference statistics ------------\n"); + char text[100]; + sprintf(text, "%-20s %lu/%u", "No. correct pred.", correct_pred_count, NUM_UNLABELLED_IMGS); + printf("%s\n", text); + graphics.text(text, console_text_location, 240); + console_text_location.y += 16; + + sprintf(text, "%-20s %0.2f", "Top-1 acc", acc); + printf("%s\n\n", text); + graphics.text(text, console_text_location, 600); + console_text_location.y += 32; + + graphics.text("Mean time measurements", console_text_location, 600); + printf("Mean time measurements: \n"); + console_text_location.y += 32; + + sprintf(text, "%s %8lld %s", "Data trans. ", data_transfer_time_mean, "us"); + printf("%s\n", text); + graphics.text(text, console_text_location, 600); + console_text_location.y += 16; + + sprintf(text, "%s %8lld %s", "Acc. compute ", compute_time_mean, "us"); + printf("%s\n", text); + graphics.text(text, console_text_location, 600); + console_text_location.y += 16; + + sprintf(text, "%s %8lld %s", "Results read ", results_read_time_mean, "us"); + printf("%s\n", text); + graphics.text(text, console_text_location, 600); + console_text_location.y += 16; + + sprintf(text, "%s %8lld %s", "Reg. write ", register_write_time_mean, "us"); + printf("%s\n", text); + graphics.text(text, console_text_location, 600); + console_text_location.y += 16; + + sprintf(text, "%s %8lld %s", "Dist. sorting", sorting_time_mean, "us"); + printf("%s\n", text); + graphics.text(text, console_text_location, 600); + console_text_location.y += 32; + + sprintf(text, "%s %8lld us", "Total time ", total_time_mean); + printf("%s\n", text); + graphics.text(text, console_text_location, 600); + } } if(BUTTON_X_pressed){ + if (DEMO_done) { + // Clear text area + graphics.set_pen(BG); + Rect blank(0, 18, 240, 284); + graphics.rectangle(blank); + + // Clear footer text + graphics.set_pen(WHITE); + Rect blank_footer(0, 302, 240, 18); + graphics.rectangle(blank_footer); + + correct_pred_count = 0; + x = 0; + y = 0; + xtext = 0; + ytext = 0; + + DEMO_done = false; + } + DEMO_RUN=true; BUTTON_X_pressed=false; } @@ -517,12 +655,6 @@ void core1_entry(){ } } -void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data); -int compare_indices(void *arr, const void *a, const void *b); -void sort_indices(uint32_t *array, int *indices, size_t size); -uint8_t predict_label(int *sorting_indices, int k); -uint8_t find_max_index(uint8_t *array, size_t size); - int main() { // Initialize GPIO pins gpio_init(DATA_REQ_GPIO_PIN); @@ -715,21 +847,36 @@ int main() { // Main loop for Core 0 while(1){ - if (DEMO_RUN && !KNN_done) { + if (DEMO_RUN && !KNN_done) { + int64_t compute_time = 0; + int64_t data_transfer_time = 0; + int64_t results_read_time = 0; + int64_t register_write_time = 0; + int64_t sorting_time = 0; + + int64_t total_time = 0; + // Reset accelerator nanosoc_download_buffer(pio, sm, &logic_zero[0], sw_reset_reg_addr, 4); + absolute_time_t total_time_start = get_absolute_time(); + // Send unlabelled image i - send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &unlabelled_buffer[current_unlab_img][0]); + send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &unlabelled_buffer[current_unlab_img][0], &compute_time, &data_transfer_time); KNN_start=true; for (int j = 0; j < NUM_LABELLED_IMGS; j++) { // Send labelled image j - send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &labelled_buffer[j][0]); + send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &labelled_buffer[j][0], &compute_time, &data_transfer_time); + absolute_time_t register_write_start = get_absolute_time(); // Set priming mode to 0 nanosoc_download_buffer(pio, sm, &logic_zero[0], priming_mode_reg_addr, 4); + absolute_time_t register_write_end = get_absolute_time(); + register_write_time += absolute_time_diff_us(register_write_start, register_write_end); + + absolute_time_t results_read_start = get_absolute_time(); uint32_t unlab_img_dot_prod = nanosoc_read_reg32(pio, sm, unlab_img_dot_prod_addr); uint32_t lab_img_dot_prod = nanosoc_read_reg32(pio, sm, lab_img_dot_prod_addr); @@ -744,28 +891,48 @@ int main() { // Compute distance between unlabelled image i and labelled image j dist[j] = unlab_img_dot_prod + lab_img_dot_prod - 2*comb_dot_prod; + absolute_time_t results_read_end = get_absolute_time(); + results_read_time += absolute_time_diff_us(results_read_start, results_read_end); + + register_write_start = get_absolute_time(); + // Clear lab_img_dot_prod and comb_dot_prod_reg nanosoc_download_buffer(pio, sm, &logic_zero[0], lab_img_dot_prod_addr, 4); nanosoc_download_buffer(pio, sm, &logic_zero[0], comb_dot_prod_addr, 4); + + register_write_end = get_absolute_time(); + register_write_time += absolute_time_diff_us(register_write_start, register_write_end); } - + absolute_time_t sorting_time_start = get_absolute_time(); sort_indices(dist, indices, NUM_LABELLED_IMGS); - printf("kNN labels: "); - for (size_t i = 0; i < kNN_k; i++) { - printf("%d ", labelled_imgs_labels[indices[i]]); - } - printf("\n"); + absolute_time_t sorting_time_end = get_absolute_time(); + sorting_time = absolute_time_diff_us(sorting_time_start, sorting_time_end); + + // printf("kNN labels: "); + // for (size_t i = 0; i < kNN_k; i++) { + // printf("%d ", labelled_imgs_labels[indices[i]]); + // } + // printf("\n"); predicted_label = predict_label(&indices[0], kNN_k); + + absolute_time_t total_time_end = get_absolute_time(); + total_time = absolute_time_diff_us(total_time_start, total_time_end); + printf("Unlabelled image %s, Predicted label: %s\n", unlabelled_files[i].c_str(), classes[predicted_label].c_str()); - KNN_done = true; - current_unlab_img++; + if (predicted_label == test_imgs_labels[current_unlab_img]) + correct_pred_count++; - if(current_unlab_img >= NUM_UNLABELLED_IMGS) { - current_unlab_img=0; - } + compute_time_arr[current_unlab_img] = compute_time; + data_transfer_time_arr[current_unlab_img] = data_transfer_time; + results_read_time_arr[current_unlab_img] = results_read_time; + register_write_time_arr[current_unlab_img] = register_write_time; + sorting_time_arr[current_unlab_img] = sorting_time; + total_time_arr[current_unlab_img] = total_time; + + KNN_done = true; } // Always read in and store STDIO from nanosoc @@ -794,7 +961,7 @@ int main() { return 0; } -void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data) { +void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data, int64_t *compute_time, int64_t *data_transfer_time) { // Wait for DATA_REQ signal // printf("Waiting for DATA_REQ signal...\n"); char buf[128]; @@ -806,14 +973,20 @@ void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, }; // printf("DATA_REQ signal received. Sending data...\n"); + absolute_time_t data_transfer_start = get_absolute_time(); + // Send data nanosoc_download_buffer(*pio, *sm, data, *nanosoc_img_buffer_addr, 784); + absolute_time_t data_transfer_end = get_absolute_time(); + *data_transfer_time += absolute_time_diff_us(data_transfer_start, data_transfer_end); + // Pull DATA_SENT_GPIO_PIN HIGH gpio_put(DATA_SENT_GPIO_PIN, true); // printf("DATA_SENT signal pulled HIGH.\n"); // printf("DATA_SENT_GPIO_PIN level: %d\n", gpio_get_out_level(DATA_SENT_GPIO_PIN)); + absolute_time_t compute_start = get_absolute_time(); // Wait for DATA_REQ signal to be pulled LOW // printf("Waiting for DATA_REQ signal to be pulled LOW...\n"); @@ -823,6 +996,9 @@ void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, }; // printf("DATA_REQ signal pulled LOW.\n"); + absolute_time_t compute_end = get_absolute_time(); + *compute_time += absolute_time_diff_us(compute_start, compute_end); + // Pull DATA_SENT_GPIO_PIN LOW gpio_put(DATA_SENT_GPIO_PIN, false); // printf("DATA_SENT signal pulled LOW.\n"); @@ -870,4 +1046,4 @@ uint8_t find_max_index(uint8_t *array, size_t size) { } return max_index; -} +} \ No newline at end of file -- GitLab