Skip to content
Snippets Groups Projects
Commit d0ba0c47 authored by Fanis Baikas's avatar Fanis Baikas
Browse files

Added inference statistics at the end of classification

parent ed4d9aeb
Branches
No related tags found
No related merge requests found
Pipeline #11600 passed
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "pico/multicore.h" #include "pico/multicore.h"
#include <stdio.h> #include <stdio.h>
#include <stdarg.h>
#include <cmath> #include <cmath>
extern "C"{ extern "C"{
...@@ -56,8 +57,8 @@ std::string console[CONSOLE_FIFO_DEPTH]; ...@@ -56,8 +57,8 @@ std::string console[CONSOLE_FIFO_DEPTH];
#define DATA_REQ_GPIO_PIN 2 #define DATA_REQ_GPIO_PIN 2
#define DATA_SENT_GPIO_PIN 3 #define DATA_SENT_GPIO_PIN 3
#define NUM_UNLABELLED_IMGS 36 #define NUM_UNLABELLED_IMGS 27
#define NUM_LABELLED_IMGS 150 #define NUM_LABELLED_IMGS 160
#define NUM_OF_CLASSES 10 #define NUM_OF_CLASSES 10
#define NUM_LABELLED_IMAGES_PER_CLASS NUM_LABELLED_IMGS/NUM_OF_CLASSES #define NUM_LABELLED_IMAGES_PER_CLASS NUM_LABELLED_IMGS/NUM_OF_CLASSES
...@@ -104,6 +105,25 @@ std::string labelled_image_dirs[10] = { ...@@ -104,6 +105,25 @@ std::string labelled_image_dirs[10] = {
// Path to unlabelled images // Path to unlabelled images
std::string unlabelled_image_dir = "0:/fmnist_data/test"; std::string unlabelled_image_dir = "0:/fmnist_data/test";
// Counter for the number of correct predictions used for top-1 acc computation
uint16_t correct_pred_count = 0;
// Array buffer to hold time measurements for statistics calcumation
int64_t compute_time_arr[NUM_UNLABELLED_IMGS];
int64_t data_transfer_time_arr[NUM_UNLABELLED_IMGS];
int64_t results_read_time_arr[NUM_UNLABELLED_IMGS];
int64_t register_write_time_arr[NUM_UNLABELLED_IMGS];
int64_t sorting_time_arr[NUM_UNLABELLED_IMGS];
int64_t total_time_arr[NUM_UNLABELLED_IMGS];
float acc = 0;
int64_t compute_time_mean = 0;
int64_t data_transfer_time_mean = 0;
int64_t results_read_time_mean = 0;
int64_t register_write_time_mean = 0;
int64_t sorting_time_mean = 0;
int64_t total_time_mean = 0;
std::string classes[] = { std::string classes[] = {
"T-Shirt", "T-Shirt",
"Trousers", "Trousers",
...@@ -138,6 +158,7 @@ screen_state last_screen_state=POWER; ...@@ -138,6 +158,7 @@ screen_state last_screen_state=POWER;
bool DEMO_RUN = false; bool DEMO_RUN = false;
bool KNN_done = false; bool KNN_done = false;
bool KNN_start = false; bool KNN_start = false;
bool DEMO_done = false;
bool nanosoc_console = false; bool nanosoc_console = false;
bool nanosoc_console_draw = false; bool nanosoc_console_draw = false;
i2c_inst_t *i2c = i2c1; i2c_inst_t *i2c = i2c1;
...@@ -161,6 +182,12 @@ static sd_card_t sd_cards[] = {{ ...@@ -161,6 +182,12 @@ static sd_card_t sd_cards[] = {{
int img_x_grid[] = {3, 83, 163}; int img_x_grid[] = {3, 83, 163};
int img_y_grid[] = {30, 118, 206}; int img_y_grid[] = {30, 118, 206};
void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data, int64_t *compute_time, int64_t *data_transfer_time);
int compare_indices(void *arr, const void *a, const void *b);
void sort_indices(uint32_t *array, int *indices, size_t size);
uint8_t predict_label(int *sorting_indices, int k);
uint8_t find_max_index(uint8_t *array, size_t size);
// Button IRQ (for controlling screen state) // Button IRQ (for controlling screen state)
void BUTTON_irq(uint gpio, uint32_t events) { void BUTTON_irq(uint gpio, uint32_t events) {
if(gpio==BUTTON_A){ if(gpio==BUTTON_A){
...@@ -351,7 +378,7 @@ void DEMO_ROUTINE(){ ...@@ -351,7 +378,7 @@ void DEMO_ROUTINE(){
result_location.y = img_y_grid[ytext]+60; result_location.y = img_y_grid[ytext]+60;
// KNN finished, plot the result // KNN finished, plot the result
if (predicted_label == test_imgs_labels[current_unlab_img-1]) if (predicted_label == test_imgs_labels[current_unlab_img])
graphics.set_pen(GREEN); graphics.set_pen(GREEN);
else else
graphics.set_pen(RED); graphics.set_pen(RED);
...@@ -366,10 +393,121 @@ void DEMO_ROUTINE(){ ...@@ -366,10 +393,121 @@ void DEMO_ROUTINE(){
} }
} }
printf("printed result\n"); printf("printed result\n");
current_unlab_img++;
if(current_unlab_img >= NUM_UNLABELLED_IMGS) {
current_unlab_img = 0;
acc = (float) correct_pred_count / NUM_UNLABELLED_IMGS;
for (int k = 0; k < NUM_UNLABELLED_IMGS; k++) {
compute_time_mean += compute_time_arr[k];
data_transfer_time_mean += data_transfer_time_arr[k];
results_read_time_mean += results_read_time_arr[k];
register_write_time_mean += register_write_time_arr[k];
sorting_time_mean += sorting_time_arr[k];
total_time_mean += total_time_arr[k];
}
compute_time_mean = compute_time_mean / NUM_UNLABELLED_IMGS;
data_transfer_time_mean = data_transfer_time_mean / NUM_UNLABELLED_IMGS;
results_read_time_mean = results_read_time_mean / NUM_UNLABELLED_IMGS;
register_write_time_mean = register_write_time_mean / NUM_UNLABELLED_IMGS;
sorting_time_mean = sorting_time_mean / NUM_UNLABELLED_IMGS;
total_time_mean = total_time_mean / NUM_UNLABELLED_IMGS;
DEMO_done = true;
DEMO_RUN = false;
}
KNN_done=false; KNN_done=false;
} }
if (DEMO_done){
// Clear text area
graphics.set_pen(BG);
Rect blank(0, 18, 240, 284);
graphics.rectangle(blank);
// Edit footer text
graphics.set_pen(WHITE);
Rect blank_footer(0, 302, 240, 18);
graphics.rectangle(blank_footer);
graphics.set_pen(BG);
graphics.text(" Restart ", footer_location, 600);
console_text_location.x = 5;
console_text_location.y = 30;
graphics.set_pen(WHITE);
printf("\n------------ Inference statistics ------------\n");
char text[100];
sprintf(text, "%-20s %lu/%u", "No. correct pred.", correct_pred_count, NUM_UNLABELLED_IMGS);
printf("%s\n", text);
graphics.text(text, console_text_location, 240);
console_text_location.y += 16;
sprintf(text, "%-20s %0.2f", "Top-1 acc", acc);
printf("%s\n\n", text);
graphics.text(text, console_text_location, 600);
console_text_location.y += 32;
graphics.text("Mean time measurements", console_text_location, 600);
printf("Mean time measurements: \n");
console_text_location.y += 32;
sprintf(text, "%s %8lld %s", "Data trans. ", data_transfer_time_mean, "us");
printf("%s\n", text);
graphics.text(text, console_text_location, 600);
console_text_location.y += 16;
sprintf(text, "%s %8lld %s", "Acc. compute ", compute_time_mean, "us");
printf("%s\n", text);
graphics.text(text, console_text_location, 600);
console_text_location.y += 16;
sprintf(text, "%s %8lld %s", "Results read ", results_read_time_mean, "us");
printf("%s\n", text);
graphics.text(text, console_text_location, 600);
console_text_location.y += 16;
sprintf(text, "%s %8lld %s", "Reg. write ", register_write_time_mean, "us");
printf("%s\n", text);
graphics.text(text, console_text_location, 600);
console_text_location.y += 16;
sprintf(text, "%s %8lld %s", "Dist. sorting", sorting_time_mean, "us");
printf("%s\n", text);
graphics.text(text, console_text_location, 600);
console_text_location.y += 32;
sprintf(text, "%s %8lld us", "Total time ", total_time_mean);
printf("%s\n", text);
graphics.text(text, console_text_location, 600);
}
} }
if(BUTTON_X_pressed){ if(BUTTON_X_pressed){
if (DEMO_done) {
// Clear text area
graphics.set_pen(BG);
Rect blank(0, 18, 240, 284);
graphics.rectangle(blank);
// Clear footer text
graphics.set_pen(WHITE);
Rect blank_footer(0, 302, 240, 18);
graphics.rectangle(blank_footer);
correct_pred_count = 0;
x = 0;
y = 0;
xtext = 0;
ytext = 0;
DEMO_done = false;
}
DEMO_RUN=true; DEMO_RUN=true;
BUTTON_X_pressed=false; BUTTON_X_pressed=false;
} }
...@@ -517,12 +655,6 @@ void core1_entry(){ ...@@ -517,12 +655,6 @@ void core1_entry(){
} }
} }
void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data);
int compare_indices(void *arr, const void *a, const void *b);
void sort_indices(uint32_t *array, int *indices, size_t size);
uint8_t predict_label(int *sorting_indices, int k);
uint8_t find_max_index(uint8_t *array, size_t size);
int main() { int main() {
// Initialize GPIO pins // Initialize GPIO pins
gpio_init(DATA_REQ_GPIO_PIN); gpio_init(DATA_REQ_GPIO_PIN);
...@@ -716,20 +848,35 @@ int main() { ...@@ -716,20 +848,35 @@ int main() {
// Main loop for Core 0 // Main loop for Core 0
while(1){ while(1){
if (DEMO_RUN && !KNN_done) { if (DEMO_RUN && !KNN_done) {
int64_t compute_time = 0;
int64_t data_transfer_time = 0;
int64_t results_read_time = 0;
int64_t register_write_time = 0;
int64_t sorting_time = 0;
int64_t total_time = 0;
// Reset accelerator // Reset accelerator
nanosoc_download_buffer(pio, sm, &logic_zero[0], sw_reset_reg_addr, 4); nanosoc_download_buffer(pio, sm, &logic_zero[0], sw_reset_reg_addr, 4);
absolute_time_t total_time_start = get_absolute_time();
// Send unlabelled image i // Send unlabelled image i
send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &unlabelled_buffer[current_unlab_img][0]); send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &unlabelled_buffer[current_unlab_img][0], &compute_time, &data_transfer_time);
KNN_start=true; KNN_start=true;
for (int j = 0; j < NUM_LABELLED_IMGS; j++) { for (int j = 0; j < NUM_LABELLED_IMGS; j++) {
// Send labelled image j // Send labelled image j
send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &labelled_buffer[j][0]); send_data_to_accelerator(&pio, &sm, &nanosoc_img_buffer_addr, &labelled_buffer[j][0], &compute_time, &data_transfer_time);
absolute_time_t register_write_start = get_absolute_time();
// Set priming mode to 0 // Set priming mode to 0
nanosoc_download_buffer(pio, sm, &logic_zero[0], priming_mode_reg_addr, 4); nanosoc_download_buffer(pio, sm, &logic_zero[0], priming_mode_reg_addr, 4);
absolute_time_t register_write_end = get_absolute_time();
register_write_time += absolute_time_diff_us(register_write_start, register_write_end);
absolute_time_t results_read_start = get_absolute_time();
uint32_t unlab_img_dot_prod = nanosoc_read_reg32(pio, sm, unlab_img_dot_prod_addr); uint32_t unlab_img_dot_prod = nanosoc_read_reg32(pio, sm, unlab_img_dot_prod_addr);
uint32_t lab_img_dot_prod = nanosoc_read_reg32(pio, sm, lab_img_dot_prod_addr); uint32_t lab_img_dot_prod = nanosoc_read_reg32(pio, sm, lab_img_dot_prod_addr);
...@@ -744,28 +891,48 @@ int main() { ...@@ -744,28 +891,48 @@ int main() {
// Compute distance between unlabelled image i and labelled image j // Compute distance between unlabelled image i and labelled image j
dist[j] = unlab_img_dot_prod + lab_img_dot_prod - 2*comb_dot_prod; dist[j] = unlab_img_dot_prod + lab_img_dot_prod - 2*comb_dot_prod;
absolute_time_t results_read_end = get_absolute_time();
results_read_time += absolute_time_diff_us(results_read_start, results_read_end);
register_write_start = get_absolute_time();
// Clear lab_img_dot_prod and comb_dot_prod_reg // Clear lab_img_dot_prod and comb_dot_prod_reg
nanosoc_download_buffer(pio, sm, &logic_zero[0], lab_img_dot_prod_addr, 4); nanosoc_download_buffer(pio, sm, &logic_zero[0], lab_img_dot_prod_addr, 4);
nanosoc_download_buffer(pio, sm, &logic_zero[0], comb_dot_prod_addr, 4); nanosoc_download_buffer(pio, sm, &logic_zero[0], comb_dot_prod_addr, 4);
}
register_write_end = get_absolute_time();
register_write_time += absolute_time_diff_us(register_write_start, register_write_end);
}
absolute_time_t sorting_time_start = get_absolute_time();
sort_indices(dist, indices, NUM_LABELLED_IMGS); sort_indices(dist, indices, NUM_LABELLED_IMGS);
printf("kNN labels: "); absolute_time_t sorting_time_end = get_absolute_time();
for (size_t i = 0; i < kNN_k; i++) { sorting_time = absolute_time_diff_us(sorting_time_start, sorting_time_end);
printf("%d ", labelled_imgs_labels[indices[i]]);
} // printf("kNN labels: ");
printf("\n"); // for (size_t i = 0; i < kNN_k; i++) {
// printf("%d ", labelled_imgs_labels[indices[i]]);
// }
// printf("\n");
predicted_label = predict_label(&indices[0], kNN_k); predicted_label = predict_label(&indices[0], kNN_k);
absolute_time_t total_time_end = get_absolute_time();
total_time = absolute_time_diff_us(total_time_start, total_time_end);
printf("Unlabelled image %s, Predicted label: %s\n", unlabelled_files[i].c_str(), classes[predicted_label].c_str()); printf("Unlabelled image %s, Predicted label: %s\n", unlabelled_files[i].c_str(), classes[predicted_label].c_str());
KNN_done = true; if (predicted_label == test_imgs_labels[current_unlab_img])
current_unlab_img++; correct_pred_count++;
if(current_unlab_img >= NUM_UNLABELLED_IMGS) { compute_time_arr[current_unlab_img] = compute_time;
current_unlab_img=0; data_transfer_time_arr[current_unlab_img] = data_transfer_time;
} results_read_time_arr[current_unlab_img] = results_read_time;
register_write_time_arr[current_unlab_img] = register_write_time;
sorting_time_arr[current_unlab_img] = sorting_time;
total_time_arr[current_unlab_img] = total_time;
KNN_done = true;
} }
// Always read in and store STDIO from nanosoc // Always read in and store STDIO from nanosoc
...@@ -794,7 +961,7 @@ int main() { ...@@ -794,7 +961,7 @@ int main() {
return 0; return 0;
} }
void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data) { void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, uint8_t *data, int64_t *compute_time, int64_t *data_transfer_time) {
// Wait for DATA_REQ signal // Wait for DATA_REQ signal
// printf("Waiting for DATA_REQ signal...\n"); // printf("Waiting for DATA_REQ signal...\n");
char buf[128]; char buf[128];
...@@ -806,14 +973,20 @@ void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, ...@@ -806,14 +973,20 @@ void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr,
}; };
// printf("DATA_REQ signal received. Sending data...\n"); // printf("DATA_REQ signal received. Sending data...\n");
absolute_time_t data_transfer_start = get_absolute_time();
// Send data // Send data
nanosoc_download_buffer(*pio, *sm, data, *nanosoc_img_buffer_addr, 784); nanosoc_download_buffer(*pio, *sm, data, *nanosoc_img_buffer_addr, 784);
absolute_time_t data_transfer_end = get_absolute_time();
*data_transfer_time += absolute_time_diff_us(data_transfer_start, data_transfer_end);
// Pull DATA_SENT_GPIO_PIN HIGH // Pull DATA_SENT_GPIO_PIN HIGH
gpio_put(DATA_SENT_GPIO_PIN, true); gpio_put(DATA_SENT_GPIO_PIN, true);
// printf("DATA_SENT signal pulled HIGH.\n"); // printf("DATA_SENT signal pulled HIGH.\n");
// printf("DATA_SENT_GPIO_PIN level: %d\n", gpio_get_out_level(DATA_SENT_GPIO_PIN)); // printf("DATA_SENT_GPIO_PIN level: %d\n", gpio_get_out_level(DATA_SENT_GPIO_PIN));
absolute_time_t compute_start = get_absolute_time();
// Wait for DATA_REQ signal to be pulled LOW // Wait for DATA_REQ signal to be pulled LOW
// printf("Waiting for DATA_REQ signal to be pulled LOW...\n"); // printf("Waiting for DATA_REQ signal to be pulled LOW...\n");
...@@ -823,6 +996,9 @@ void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr, ...@@ -823,6 +996,9 @@ void send_data_to_accelerator(PIO *pio, uint *sm, int *nanosoc_img_buffer_addr,
}; };
// printf("DATA_REQ signal pulled LOW.\n"); // printf("DATA_REQ signal pulled LOW.\n");
absolute_time_t compute_end = get_absolute_time();
*compute_time += absolute_time_diff_us(compute_start, compute_end);
// Pull DATA_SENT_GPIO_PIN LOW // Pull DATA_SENT_GPIO_PIN LOW
gpio_put(DATA_SENT_GPIO_PIN, false); gpio_put(DATA_SENT_GPIO_PIN, false);
// printf("DATA_SENT signal pulled LOW.\n"); // printf("DATA_SENT signal pulled LOW.\n");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment