From 3e7e278d6f8d4635fbb2239445df0fd48951d605 Mon Sep 17 00:00:00 2001
From: Patrick Labatut <60359573+patricklabatut@users.noreply.github.com>
Date: Wed, 26 Apr 2023 01:08:35 +0200
Subject: [PATCH] Improve and fix ImageNet-1k dataset preparation (#60)

Document and fix implementation of extra metadata generation for ImageNet-1k.
---
 README.md                         |  42 ++++++--
 dinov2/data/datasets/image_net.py | 172 ++++++++++++++++++------------
 2 files changed, 138 insertions(+), 76 deletions(-)

diff --git a/README.md b/README.md
index f1d74b5..79e8a92 100644
--- a/README.md
+++ b/README.md
@@ -102,16 +102,38 @@ pip install -r requirements.txt
 
 The root directory of the dataset should hold the following contents:
 
-- `<root>/test/ILSVRC2012_test_00000001.JPEG`
-- `<root>/test/[..]`
-- `<root>/test/ILSVRC2012_test_00100000.JPEG`
-- `<root>/train/n01440764/n01440764_10026.JPEG`
-- `<root>/train/[...]`
-- `<root>/train/n15075141/n15075141_9993.JPEG`
-- `<root>/val/n01440764/ILSVRC2012_val_00000293.JPEG`
-- `<root>/val/[...]`
-- `<root>/val/n15075141/ILSVRC2012_val_00049174.JPEG`
-- `<root>/labels.txt`
+- `<ROOT>/test/ILSVRC2012_test_00000001.JPEG`
+- `<ROOT>/test/[..]`
+- `<ROOT>/test/ILSVRC2012_test_00100000.JPEG`
+- `<ROOT>/train/n01440764/n01440764_10026.JPEG`
+- `<ROOT>/train/[...]`
+- `<ROOT>/train/n15075141/n15075141_9993.JPEG`
+- `<ROOT>/val/n01440764/ILSVRC2012_val_00000293.JPEG`
+- `<ROOT>/val/[...]`
+- `<ROOT>/val/n15075141/ILSVRC2012_val_00049174.JPEG`
+- `<ROOT>/labels.txt`
+
+The provided dataset implementation expects a few additional metadata files to be present under the extra directory:
+
+- `<EXTRA>/class-ids-TRAIN.npy`
+- `<EXTRA>/class-ids-VAL.npy`
+- `<EXTRA>/class-names-TRAIN.npy`
+- `<EXTRA>/class-names-VAL.npy`
+- `<EXTRA>/entries-TEST.npy`
+- `<EXTRA>/entries-TRAIN.npy`
+- `<EXTRA>/entries-VAL.npy`
+
+These metadata files can be generated (once) with the following lines of Python code:
+
+```python
+from dinov2.data.datasets import ImageNet
+
+for split in ImageNet.Split:
+    dataset = ImageNet(split=split, root="<ROOT>", extra="<EXTRA>")
+    dataset.dump_extra()
+```
+
+Note that the root and extra directories do not have to be distinct directories.
 
 ### ImageNet-22k
 
diff --git a/dinov2/data/datasets/image_net.py b/dinov2/data/datasets/image_net.py
index a72407b..1e1c384 100644
--- a/dinov2/data/datasets/image_net.py
+++ b/dinov2/data/datasets/image_net.py
@@ -6,6 +6,7 @@
 
 import csv
 from enum import Enum
+import logging
 import os
 from typing import Callable, List, Optional, Tuple, Union
 
@@ -14,7 +15,8 @@ import numpy as np
 from .extended import ExtendedVisionDataset
 
 
-_Labels = int
+logger = logging.getLogger("dinov2")
+_Target = int
 
 
 class _Split(Enum):
@@ -52,7 +54,7 @@ class _Split(Enum):
 
 
 class ImageNet(ExtendedVisionDataset):
-    Labels = Union[_Labels]
+    Target = Union[_Target]
     Split = Union[_Split]
 
     def __init__(
@@ -67,112 +69,136 @@ class ImageNet(ExtendedVisionDataset):
     ) -> None:
         super().__init__(root, transforms, transform, target_transform)
         self._extra_root = extra
-
         self._split = split
 
-        entries_path = self._get_entries_path(split, root)
-        self._entries = self._load_extra(entries_path)
-
+        self._entries = None
         self._class_ids = None
         self._class_names = None
 
-        if split == _Split.TEST:
-            return
-
-        class_ids_path = self._get_class_ids_path(split, root)
-        self._class_ids = self._load_extra(class_ids_path)
-
-        class_names_path = self._get_class_names_path(split, root)
-        self._class_names = self._load_extra(class_names_path)
-
     @property
     def split(self) -> "ImageNet.Split":
         return self._split
 
+    def _get_extra_full_path(self, extra_path: str) -> str:
+        return os.path.join(self._extra_root, extra_path)
+
     def _load_extra(self, extra_path: str) -> np.ndarray:
-        extra_root = self._extra_root
-        extra_full_path = os.path.join(extra_root, extra_path)
+        extra_full_path = self._get_extra_full_path(extra_path)
         return np.load(extra_full_path, mmap_mode="r")
 
     def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
-        extra_root = self._extra_root
-        extra_full_path = os.path.join(extra_root, extra_path)
-        os.makedirs(extra_root, exist_ok=True)
+        extra_full_path = self._get_extra_full_path(extra_path)
+        os.makedirs(self._extra_root, exist_ok=True)
         np.save(extra_full_path, extra_array)
 
-    def _get_entries_path(self, split: "ImageNet.Split", root: Optional[str] = None) -> str:
-        return f"entries-{split.value.upper()}.npy"
+    @property
+    def _entries_path(self) -> str:
+        return f"entries-{self._split.value.upper()}.npy"
+
+    @property
+    def _class_ids_path(self) -> str:
+        return f"class-ids-{self._split.value.upper()}.npy"
 
-    def _get_class_ids_path(self, split: "ImageNet.Split", root: Optional[str] = None) -> str:
-        return f"class-ids-{split.value.upper()}.npy"
+    @property
+    def _class_names_path(self) -> str:
+        return f"class-names-{self._split.value.upper()}.npy"
+
+    def _get_entries(self) -> np.ndarray:
+        if self._entries is None:
+            self._entries = self._load_extra(self._entries_path)
+        assert self._entries is not None
+        return self._entries
+
+    def _get_class_ids(self) -> np.ndarray:
+        if self._split == _Split.TEST:
+            assert False, "Class IDs are not available in TEST split"
+        if self._class_ids is None:
+            self._class_ids = self._load_extra(self._class_ids_path)
+        assert self._class_ids is not None
+        return self._class_ids
 
-    def _get_class_names_path(self, split: "ImageNet.Split", root: Optional[str] = None) -> str:
-        return f"class-names-{split.value.upper()}.npy"
+    def _get_class_names(self) -> np.ndarray:
+        if self._split == _Split.TEST:
+            assert False, "Class names are not available in TEST split"
+        if self._class_names is None:
+            self._class_names = self._load_extra(self._class_names_path)
+        assert self._class_names is not None
+        return self._class_names
 
     def find_class_id(self, class_index: int) -> str:
-        assert self._class_ids is not None
-        return str(self._class_ids[class_index])
+        class_ids = self._get_class_ids()
+        return str(class_ids[class_index])
 
     def find_class_name(self, class_index: int) -> str:
-        assert self._class_names is not None
-        return str(self._class_names[class_index])
+        class_names = self._get_class_names()
+        return str(class_names[class_index])
 
     def get_image_data(self, index: int) -> bytes:
-        actual_index = self._entries[index]["actual_index"]
+        entries = self._get_entries()
+        actual_index = entries[index]["actual_index"]
+
         class_id = self.get_class_id(index)
+
         image_relpath = self.split.get_image_relpath(actual_index, class_id)
         image_full_path = os.path.join(self.root, image_relpath)
         with open(image_full_path, mode="rb") as f:
             image_data = f.read()
         return image_data
 
-    def get_target(self, index: int) -> Optional[_Labels]:
-        class_index = self._entries[index]["class_index"]
+    def get_target(self, index: int) -> Optional[Target]:
+        entries = self._get_entries()
+        class_index = entries[index]["class_index"]
         return None if self.split == _Split.TEST else int(class_index)
 
     def get_targets(self) -> Optional[np.ndarray]:
-        return None if self.split == _Split.TEST else self._entries["class_index"]
+        entries = self._get_entries()
+        return None if self.split == _Split.TEST else entries["class_index"]
 
     def get_class_id(self, index: int) -> Optional[str]:
-        class_id = self._entries[index]["class_id"]
+        entries = self._get_entries()
+        class_id = entries[index]["class_id"]
         return None if self.split == _Split.TEST else str(class_id)
 
     def get_class_name(self, index: int) -> Optional[str]:
-        class_name = self._entries[index]["class_name"]
+        entries = self._get_entries()
+        class_name = entries[index]["class_name"]
         return None if self.split == _Split.TEST else str(class_name)
 
     def __len__(self) -> int:
-        assert len(self._entries) == self.split.length
-        return len(self._entries)
+        entries = self._get_entries()
+        assert len(entries) == self.split.length
+        return len(entries)
 
-    def _load_labels(self, root: str) -> List[Tuple[str, str]]:
-        path = os.path.join(root, "labels.txt")
+    def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
+        labels_full_path = os.path.join(self.root, labels_path)
         labels = []
 
         try:
-            with open(path, "r") as f:
+            with open(labels_full_path, "r") as f:
                 reader = csv.reader(f)
                 for row in reader:
                     class_id, class_name = row
                     labels.append((class_id, class_name))
         except OSError as e:
-            raise RuntimeError(f'can not read labels file "{path}"') from e
+            raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e
 
         return labels
 
-    def _dump_entries(self, split: "ImageNet.Split", root: Optional[str] = None) -> None:
-        # NOTE: Using torchvision ImageFolder for consistency
-        from torchvision.datasets import ImageFolder
-
-        root = self.root
-        labels = self._load_labels(root)
-
+    def _dump_entries(self) -> None:
+        split = self.split
         if split == ImageNet.Split.TEST:
             dataset = None
             sample_count = split.length
             max_class_id_length, max_class_name_length = 0, 0
         else:
-            dataset_root = os.path.join(root, split.get_dirname())
+            labels_path = "labels.txt"
+            logger.info(f'loading labels from "{labels_path}"')
+            labels = self._load_labels(labels_path)
+
+            # NOTE: Using torchvision ImageFolder for consistency
+            from torchvision.datasets import ImageFolder
+
+            dataset_root = os.path.join(self.root, split.get_dirname())
             dataset = ImageFolder(dataset_root)
             sample_count = len(dataset)
             max_class_id_length, max_class_name_length = -1, -1
@@ -193,29 +219,43 @@ class ImageNet(ExtendedVisionDataset):
         entries_array = np.empty(sample_count, dtype=dtype)
 
         if split == ImageNet.Split.TEST:
+            old_percent = -1
             for index in range(sample_count):
-                entries_array[index] = (index + 1, np.uint32(-1), "", "")
+                percent = 100 * (index + 1) // sample_count
+                if percent > old_percent:
+                    logger.info(f"creating entries: {percent}%")
+                    old_percent = percent
+
+                actual_index = index + 1
+                class_index = np.uint32(-1)
+                class_id, class_name = "", ""
+                entries_array[index] = (actual_index, class_index, class_id, class_name)
         else:
             class_names = {class_id: class_name for class_id, class_name in labels}
 
             assert dataset
-            for index, _ in enumerate(dataset):
+            old_percent = -1
+            for index in range(sample_count):
+                percent = 100 * (index + 1) // sample_count
+                if percent > old_percent:
+                    logger.info(f"creating entries: {percent}%")
+                    old_percent = percent
+
                 image_full_path, class_index = dataset.samples[index]
-                image_relpath = os.path.relpath(image_full_path, root)
+                image_relpath = os.path.relpath(image_full_path, self.root)
                 class_id, actual_index = split.parse_image_relpath(image_relpath)
                 class_name = class_names[class_id]
                 entries_array[index] = (actual_index, class_index, class_id, class_name)
 
-        entries_path = self._get_entries_path(split, root)
-        self._save_extra(entries_array, entries_path)
+        logger.info(f'saving entries to "{self._entries_path}"')
+        self._save_extra(entries_array, self._entries_path)
 
-    def _dump_class_ids_and_names(self, split: "ImageNet.Split", root: Optional[str] = None) -> None:
+    def _dump_class_ids_and_names(self) -> None:
+        split = self.split
         if split == ImageNet.Split.TEST:
             return
 
-        root = self.get_root(root)
-        entries_path = self._get_entries_path(split, root)
-        entries_array = self._load_extra(entries_path)
+        entries_array = self._load_extra(self._entries_path)
 
         max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
         for entry in entries_array:
@@ -240,12 +280,12 @@ class ImageNet(ExtendedVisionDataset):
             class_ids_array[class_index] = class_id
             class_names_array[class_index] = class_name
 
-        class_ids_path = self._get_class_ids_path(split, root)
-        self._save_extra(class_ids_array, class_ids_path)
+        logger.info(f'saving class IDs to "{self._class_ids_path}"')
+        self._save_extra(class_ids_array, self._class_ids_path)
 
-        class_names_path = self._get_class_names_path(split, root)
-        self._save_extra(class_names_array, class_names_path)
+        logger.info(f'saving class names to "{self._class_names_path}"')
+        self._save_extra(class_names_array, self._class_names_path)
 
-    def dump_extra(self, split: "ImageNet.Split", root: Optional[str] = None) -> None:
-        self._dump_entries(split, root)
-        self._dump_class_ids_and_names(split, root)
+    def dump_extra(self) -> None:
+        self._dump_entries()
+        self._dump_class_ids_and_names()
-- 
GitLab