#!/usr/bin/env python3
# -*- coding: utf-8 -*-

__author__ = "heider"
__doc__ = r"""

           Created on 7/5/22
           """

__all__ = ["FFHQDataset"]


import argparse
import glob
import hashlib
import html
import itertools
import json
import os
import queue
import shutil
import sys
import threading
import time
import uuid
from collections import OrderedDict, defaultdict

import PIL.Image
import PIL.ImageFile
import numpy
import requests
import scipy.ndimage
from torch.utils.data import Dataset

PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True  # avoid "Decompressed Data Too Large" error

# ----------------------------------------------------------------------------

json_spec = dict(
    file_url="https://drive.google.com/uc?id=16N0RV4fHI6joBuKbQAoG34V_cQk7vxSA",
    file_path="ffhq-dataset-v2.json",
    file_size=267793842,
    file_md5="425ae20f06a4da1d4dc0f46d40ba5fd6",
)

tfrecords_specs = [
    dict(
        file_url="https://drive.google.com/uc?id=1LnhoytWihRRJ7CfhLQ76F8YxwxRDlZN3",
        file_path="tfrecords/ffhq/ffhq-r02.tfrecords",
        file_size=6860000,
        file_md5="63e062160f1ef9079d4f51206a95ba39",
    ),
    dict(
        file_url="https://drive.google.com/uc?id=1LWeKZGZ_x2rNlTenqsaTk8s7Cpadzjbh",
        file_path="tfrecords/ffhq/ffhq-r03.tfrecords",
        file_size=17290000,
        file_md5="54fb32a11ebaf1b86807cc0446dd4ec5",
    ),
    dict(
        file_url="https://drive.google.com/uc?id=1Lr7Tiufr1Za85HQ18yg3XnJXstiI2BAC",
        file_path="tfrecords/ffhq/ffhq-r04.tfrecords",
        file_size=57610000,
        file_md5="7164cc5531f6828bf9c578bdc3320e49",
    ),
    dict(
        file_url="https://drive.google.com/uc?id=1LnyiayZ-XJFtatxGFgYePcs9bdxuIJO_",
        file_path="tfrecords/ffhq/ffhq-r05.tfrecords",
        file_size=218890000,
        file_md5="050cc7e5fd07a1508eaa2558dafbd9ed",
    ),
    dict(
        file_url="https://drive.google.com/uc?id=1Lt6UP201zHnpH8zLNcKyCIkbC-aMb5V_",
        file_path="tfrecords/ffhq/ffhq-r06.tfrecords",
        file_size=864010000,
        file_md5="90bedc9cc07007cd66615b2b1255aab8",
    ),
    dict(
        file_url="https://drive.google.com/uc?id=1LwOP25fJ4xN56YpNCKJZM-3mSMauTxeb",
        file_path="tfrecords/ffhq/ffhq-r07.tfrecords",
        file_size=3444980000,
        file_md5="bff839e0dda771732495541b1aff7047",
    ),
    dict(
        file_url="https://drive.google.com/uc?id=1LxxgVBHWgyN8jzf8bQssgVOrTLE8Gv2v",
        file_path="tfrecords/ffhq/ffhq-r08.tfrecords",
        file_size=13766900000,
        file_md5="74de4f07dc7bfb07c0ad4471fdac5e67",
    ),
    dict(
        file_url="https://drive.google.com/uc?id=1M-ulhD5h-J7sqSy5Y1njUY_80LPcrv3V",
        file_path="tfrecords/ffhq/ffhq-r09.tfrecords",
        file_size=55054580000,
        file_md5="05355aa457a4bd72709f74a81841b46d",
    ),
    dict(
        file_url="https://drive.google.com/uc?id=1M11BIdIpFCiapUqV658biPlaXsTRvYfM",
        file_path="tfrecords/ffhq/ffhq-r10.tfrecords",
        file_size=220205650000,
        file_md5="bf43cab9609ab2a27892fb6c2415c11b",
    ),
]

license_specs = {
    "json": dict(
        file_url="https://drive.google.com/uc?id=1SHafCugkpMZzYhbgOz0zCuYiy-hb9lYX",
        file_path="LICENSE.txt",
        file_size=1610,
        file_md5="724f3831aaecd61a84fe98500079abc2",
    ),
    "images": dict(
        file_url="https://drive.google.com/uc?id=1sP2qz8TzLkzG2gjwAa4chtdB31THska4",
        file_path="images1024x1024/LICENSE.txt",
        file_size=1610,
        file_md5="724f3831aaecd61a84fe98500079abc2",
    ),
    "thumbs": dict(
        file_url="https://drive.google.com/uc?id=1iaL1S381LS10VVtqu-b2WfF9TiY75Kmj",
        file_path="thumbnails128x128/LICENSE.txt",
        file_size=1610,
        file_md5="724f3831aaecd61a84fe98500079abc2",
    ),
    "wilds": dict(
        file_url="https://drive.google.com/uc?id=1rsfFOEQvkd6_Z547qhpq5LhDl2McJEzw",
        file_path="in-the-wild-images/LICENSE.txt",
        file_size=1610,
        file_md5="724f3831aaecd61a84fe98500079abc2",
    ),
    "tfrecords": dict(
        file_url="https://drive.google.com/uc?id=1SYUmqKdLoTYq-kqsnPsniLScMhspvl5v",
        file_path="tfrecords/ffhq/LICENSE.txt",
        file_size=1610,
        file_md5="724f3831aaecd61a84fe98500079abc2",
    ),
}


# ----------------------------------------------------------------------------


def download_file(session, file_spec, stats, chunk_size=128, num_attempts=10, **kwargs):
    """Download a single file described by ``file_spec`` and validate it.

    :param session: requests.Session used for the HTTP transfer.
    :param file_spec: dict with ``file_url`` and ``file_path``, plus optional
        ``file_size``, ``file_md5``, ``pixel_size``, ``pixel_md5`` checks.
    :param stats: shared progress dict with ``lock``, ``files_done``, ``bytes_done``.
    :param chunk_size: download chunk size in KiB.
    :param num_attempts: number of download attempts before giving up.
    """
    file_path = file_spec["file_path"]
    file_url = file_spec["file_url"]
    file_dir = os.path.dirname(file_path)
    tmp_path = file_path + ".tmp." + uuid.uuid4().hex
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    for attempts_left in reversed(range(num_attempts)):
        data_size = 0
        try:
            # Download.
            data_md5 = hashlib.md5()
            with session.get(file_url, stream=True) as res:
                res.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in res.iter_content(chunk_size=chunk_size << 10):
                        f.write(chunk)
                        data_size += len(chunk)
                        data_md5.update(chunk)
                        with stats["lock"]:
                            stats["bytes_done"] += len(chunk)

            # Validate.
            if "file_size" in file_spec and data_size != file_spec["file_size"]:
                raise IOError("Incorrect file size", file_path)
            if "file_md5" in file_spec and data_md5.hexdigest() != file_spec["file_md5"]:
                raise IOError("Incorrect file MD5", file_path)
            if "pixel_size" in file_spec or "pixel_md5" in file_spec:
                with PIL.Image.open(tmp_path) as image:
                    if (
                        "pixel_size" in file_spec
                        and list(image.size) != file_spec["pixel_size"]
                    ):
                        raise IOError("Incorrect pixel size", file_path)
                    if (
                        "pixel_md5" in file_spec
                        and hashlib.md5(numpy.array(image)).hexdigest()
                        != file_spec["pixel_md5"]
                    ):
                        raise IOError("Incorrect pixel MD5", file_path)
            break

        except Exception:
            with stats["lock"]:
                stats["bytes_done"] -= data_size

            # Handle known failure cases.
            if 0 < data_size < 8192:
                with open(tmp_path, "rb") as f:
                    data = f.read()
                data_str = data.decode("utf-8")

                # Google Drive virus checker nag.
                links = [
                    html.unescape(link)
                    for link in data_str.split('"')
                    if "export=download" in link
                ]
                if len(links) == 1:
                    if attempts_left:
                        file_url = requests.compat.urljoin(file_url, links[0])
                        continue

                # Google Drive quota exceeded.
                if "Google Drive - Quota exceeded" in data_str:
                    if not attempts_left:
                        raise IOError(
                            "Google Drive download quota exceeded -- please try again later"
                        )

            # Last attempt => raise error.
            if not attempts_left:
                raise

    # Rename temp file to the correct name.
    os.replace(tmp_path, file_path)  # atomic
    with stats["lock"]:
        stats["files_done"] += 1

    # Attempt to clean up any leftover temps.
    for filename in glob.glob(file_path + ".tmp.*"):
        try:
            os.remove(filename)
        except OSError:
            pass
# ----------------------------------------------------------------------------
def choose_bytes_unit(num_bytes):
    """Pick a human-readable unit for a byte count.

    :param num_bytes: number of bytes.
    :return: tuple of (unit name, divisor).
    """
    b = int(numpy.rint(num_bytes))
    if b < (100 << 0):
        return "B", (1 << 0)
    if b < (100 << 10):
        return "kB", (1 << 10)
    if b < (100 << 20):
        return "MB", (1 << 20)
    if b < (100 << 30):
        return "GB", (1 << 30)
    return "TB", (1 << 40)
# ----------------------------------------------------------------------------
def format_time(seconds):
    """Format a duration in seconds as a short human-readable string.

    :param seconds: duration in seconds.
    :return: formatted string such as "5m 42s".
    """
    s = int(numpy.rint(seconds))
    if s < 60:
        return "%ds" % s
    if s < 60 * 60:
        return "%dm %02ds" % (s // 60, s % 60)
    if s < 24 * 60 * 60:
        return "%dh %02dm" % (s // (60 * 60), (s // 60) % 60)
    if s < 100 * 24 * 60 * 60:
        return "%dd %02dh" % (s // (24 * 60 * 60), (s // (60 * 60)) % 24)
    return ">100d"
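
# Quick sanity examples for the two helpers above (added; not in the original script):
#   choose_bytes_unit(5_000_000)  -> ("MB", 1 << 20)
#   format_time(3671)             -> "1h 01m"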
# ----------------------------------------------------------------------------
def download_files(
    file_specs, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs
):
    """Download ``file_specs`` concurrently and print a live progress line.

    :param file_specs: list of file spec dicts (see ``download_file``).
    :param num_threads: number of concurrent download threads.
    :param status_delay: time between status prints, in seconds.
    :param timing_window: number of samples used for the bandwidth/ETA estimate.
    :param download_kwargs: extra keyword arguments forwarded to ``download_file``.
    """
    # Determine which files to download.
    done_specs = {
        spec["file_path"]: spec
        for spec in file_specs
        if os.path.isfile(spec["file_path"])
    }
    missing_specs = [spec for spec in file_specs if spec["file_path"] not in done_specs]
    files_total = len(file_specs)
    bytes_total = sum(spec["file_size"] for spec in file_specs)
    stats = dict(
        files_done=len(done_specs),
        bytes_done=sum(spec["file_size"] for spec in done_specs.values()),
        lock=threading.Lock(),
    )
    if len(done_specs) == files_total:
        print("All files already downloaded -- skipping.")
        return

    # Launch worker threads.
    spec_queue = queue.Queue()
    exception_queue = queue.Queue()
    for spec in missing_specs:
        spec_queue.put(spec)
    thread_kwargs = dict(
        spec_queue=spec_queue,
        exception_queue=exception_queue,
        stats=stats,
        download_kwargs=download_kwargs,
    )
    for _thread_idx in range(min(num_threads, len(missing_specs))):
        threading.Thread(
            target=_download_thread, kwargs=thread_kwargs, daemon=True
        ).start()

    # Monitor status until done.
    bytes_unit, bytes_div = choose_bytes_unit(bytes_total)
    spinner = "/-\\|"
    timing = []
    while True:
        with stats["lock"]:
            files_done = stats["files_done"]
            bytes_done = stats["bytes_done"]
        spinner = spinner[1:] + spinner[:1]
        timing = timing[max(len(timing) - timing_window + 1, 0) :] + [
            (time.time(), bytes_done)
        ]
        bandwidth = max(
            (timing[-1][1] - timing[0][1]) / max(timing[-1][0] - timing[0][0], 1e-8), 0
        )
        bandwidth_unit, bandwidth_div = choose_bytes_unit(bandwidth)
        eta = format_time((bytes_total - bytes_done) / max(bandwidth, 1))
        print(
            "\r%s %6.2f%% done %d/%d files %-13s %-10s ETA: %-7s "
            % (
                spinner[0],
                bytes_done / bytes_total * 100,
                files_done,
                files_total,
                "%.2f/%.2f %s"
                % (bytes_done / bytes_div, bytes_total / bytes_div, bytes_unit),
                "%.2f %s/s" % (bandwidth / bandwidth_div, bandwidth_unit),
                "done"
                if bytes_total == bytes_done
                else "..."
                if len(timing) < timing_window or bandwidth == 0
                else eta,
            ),
            end="",
            flush=True,
        )
        if files_done == files_total:
            print()
            break

        try:
            exc_info = exception_queue.get(timeout=status_delay)
            raise exc_info[1].with_traceback(exc_info[2])
        except queue.Empty:
            pass
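
# Usage sketch (not part of the original script): fetch only the JSON metadata
# and its license file, mirroring what run() does before parsing.
#   download_files([json_spec, license_specs["json"]], num_threads=4)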
def _download_thread(spec_queue, exception_queue, stats, download_kwargs):
    with requests.Session() as session:
        while not spec_queue.empty():
            spec = spec_queue.get()
            try:
                download_file(session, spec, stats, **download_kwargs)
            except Exception:
                exception_queue.put(sys.exc_info())


# ----------------------------------------------------------------------------
def recreate_aligned_images(
    json_data,
    source_dir,
    dst_dir="realign1024x1024",
    output_size=1024,
    transform_size=4096,
    enable_padding=True,
    rotate_level=True,
    random_shift=0.0,
    retry_crops=False,
):
    """Recreate aligned face crops from the in-the-wild images.

    :param json_data: parsed FFHQ metadata (``ffhq-dataset-v2.json``).
    :param source_dir: directory containing the downloaded in-the-wild images.
    :param dst_dir: output directory for the aligned images.
    :param output_size: side length of the saved images, in pixels.
    :param transform_size: intermediate size used for the quad transform.
    :param enable_padding: apply blur-padding outside and near the image borders.
    :param rotate_level: rotate the crop rectangle to level the eyes.
    :param random_shift: standard deviation of random crop rectangle jitter.
    :param retry_crops: re-draw the random shift until the crop fits the image.
    """
    print("Recreating aligned images...")

    # Fix random seed for reproducibility.
    numpy.random.seed(12345)
    # The following random numbers are unused in the present implementation,
    # but we consume them for reproducibility.
    _ = numpy.random.normal(0, 1, (len(json_data.values()), 2))

    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
        shutil.copyfile("LICENSE.txt", os.path.join(dst_dir, "LICENSE.txt"))

    for item_idx, item in enumerate(json_data.values()):
        print("\r%d / %d ... " % (item_idx, len(json_data)), end="", flush=True)

        # Parse landmarks.
        # pylint: disable=unused-variable
        lm = numpy.array(item["in_the_wild"]["face_landmarks"])
        lm_chin = lm[0:17]  # left-right
        lm_eyebrow_left = lm[17:22]  # left-right
        lm_eyebrow_right = lm[22:27]  # left-right
        lm_nose = lm[27:31]  # top-down
        lm_nostrils = lm[31:36]  # top-down
        lm_eye_left = lm[36:42]  # left-clockwise
        lm_eye_right = lm[42:48]  # left-clockwise
        lm_mouth_outer = lm[48:60]  # left-clockwise
        lm_mouth_inner = lm[60:68]  # left-clockwise

        # Calculate auxiliary vectors.
        eye_left = numpy.mean(lm_eye_left, axis=0)
        eye_right = numpy.mean(lm_eye_right, axis=0)
        eye_avg = (eye_left + eye_right) * 0.5
        eye_to_eye = eye_right - eye_left
        mouth_left = lm_mouth_outer[0]
        mouth_right = lm_mouth_outer[6]
        mouth_avg = (mouth_left + mouth_right) * 0.5
        eye_to_mouth = mouth_avg - eye_avg

        # Choose oriented crop rectangle.
        if rotate_level:
            x = eye_to_eye - numpy.flipud(eye_to_mouth) * [-1, 1]
            x /= numpy.hypot(*x)
            x *= max(numpy.hypot(*eye_to_eye) * 2.0, numpy.hypot(*eye_to_mouth) * 1.8)
            y = numpy.flipud(x) * [-1, 1]
            c0 = eye_avg + eye_to_mouth * 0.1
        else:
            x = numpy.array([1, 0], dtype=numpy.float64)
            x *= max(numpy.hypot(*eye_to_eye) * 2.0, numpy.hypot(*eye_to_mouth) * 1.8)
            y = numpy.flipud(x) * [-1, 1]
            c0 = eye_avg + eye_to_mouth * 0.1

        # Load in-the-wild image.
        src_file = os.path.join(source_dir, item["in_the_wild"]["file_path"])
        if not os.path.isfile(src_file):
            print('\nCannot find source image. Please run "--wilds" before "--align".')
            return
        img = PIL.Image.open(src_file)

        quad = numpy.stack([c0 - x - y, c0 - x + y, c0 + x + y, c0 + x - y])
        qsize = numpy.hypot(*x) * 2

        # Keep drawing new random crop offsets until we find one that is
        # contained in the image and does not require padding.
        if random_shift != 0:
            for _ in range(1000):
                # Offset the crop rectangle center by a random shift proportional
                # to the image dimension and the requested standard deviation.
                c = c0 + numpy.hypot(*x) * 2 * random_shift * numpy.random.normal(
                    0, 1, c0.shape
                )
                quad = numpy.stack([c - x - y, c - x + y, c + x + y, c + x - y])
                crop = (
                    int(numpy.floor(min(quad[:, 0]))),
                    int(numpy.floor(min(quad[:, 1]))),
                    int(numpy.ceil(max(quad[:, 0]))),
                    int(numpy.ceil(max(quad[:, 1]))),
                )
                if not retry_crops or not (
                    crop[0] < 0
                    or crop[1] < 0
                    or crop[2] >= img.width
                    or crop[3] >= img.height
                ):
                    # We're happy with this crop (either it fits within the
                    # image, or retries are disabled).
                    break
            else:
                # Rejected N times; give up and move on to the next image.
                # (This does not happen in practice with the FFHQ data.)
                print("rejected image")
                return

        # Shrink.
        shrink = int(numpy.floor(qsize / output_size * 0.5))
        if shrink > 1:
            rsize = (
                int(numpy.rint(float(img.size[0]) / shrink)),
                int(numpy.rint(float(img.size[1]) / shrink)),
            )
            img = img.resize(rsize, PIL.Image.ANTIALIAS)
            quad /= shrink
            qsize /= shrink

        # Crop.
        border = max(int(numpy.rint(qsize * 0.1)), 3)
        crop = (
            int(numpy.floor(min(quad[:, 0]))),
            int(numpy.floor(min(quad[:, 1]))),
            int(numpy.ceil(max(quad[:, 0]))),
            int(numpy.ceil(max(quad[:, 1]))),
        )
        crop = (
            max(crop[0] - border, 0),
            max(crop[1] - border, 0),
            min(crop[2] + border, img.size[0]),
            min(crop[3] + border, img.size[1]),
        )
        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
            img = img.crop(crop)
            quad -= crop[0:2]

        # Pad.
        pad = (
            int(numpy.floor(min(quad[:, 0]))),
            int(numpy.floor(min(quad[:, 1]))),
            int(numpy.ceil(max(quad[:, 0]))),
            int(numpy.ceil(max(quad[:, 1]))),
        )
        pad = (
            max(-pad[0] + border, 0),
            max(-pad[1] + border, 0),
            max(pad[2] - img.size[0] + border, 0),
            max(pad[3] - img.size[1] + border, 0),
        )
        if enable_padding and max(pad) > border - 4:
            pad = numpy.maximum(pad, int(numpy.rint(qsize * 0.3)))
            img = numpy.pad(
                numpy.float32(img),
                ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
                "reflect",
            )
            h, w, _ = img.shape
            y, x, _ = numpy.ogrid[:h, :w, :1]
            mask = numpy.maximum(
                1.0
                - numpy.minimum(
                    numpy.float32(x) / pad[0], numpy.float32(w - 1 - x) / pad[2]
                ),
                1.0
                - numpy.minimum(
                    numpy.float32(y) / pad[1], numpy.float32(h - 1 - y) / pad[3]
                ),
            )
            blur = qsize * 0.02
            img += (
                scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img
            ) * numpy.clip(mask * 3.0 + 1.0, 0.0, 1.0)
            img += (numpy.median(img, axis=(0, 1)) - img) * numpy.clip(mask, 0.0, 1.0)
            img = PIL.Image.fromarray(
                numpy.uint8(numpy.clip(numpy.rint(img), 0, 255)), "RGB"
            )
            quad += pad[:2]

        # Transform.
        img = img.transform(
            (transform_size, transform_size),
            PIL.Image.QUAD,
            (quad + 0.5).flatten(),
            PIL.Image.BILINEAR,
        )
        if output_size < transform_size:
            img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)

        # Save aligned image.
        dst_subdir = os.path.join(dst_dir, "%05d" % (item_idx - item_idx % 1000))
        os.makedirs(dst_subdir, exist_ok=True)
        img.save(os.path.join(dst_subdir, "%05d.png" % item_idx))

    # All done.
    print("\r%d / %d ... done" % (len(json_data), len(json_data)))
# ----------------------------------------------------------------------------
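
# The original NVlabs downloader also defines print_statistics(); it is
# referenced by run() below but missing from this file. A minimal stand-in
# sketch (an assumption, not the original implementation):
def print_statistics(json_data):
    """Print basic counts from the parsed FFHQ metadata."""
    categories = defaultdict(int)
    for item in json_data.values():
        categories[item["category"]] += 1  # e.g. "training" / "validation"
    print("Total images: %d" % len(json_data))
    for category, count in sorted(categories.items()):
        print("  %s: %d" % (category, count))


# ----------------------------------------------------------------------------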
def run(tasks, **download_kwargs):
    """Execute the requested download/processing tasks.

    :param tasks: list of task names ("json", "stats", "images", "thumbs",
        "wilds", "tfrecords", "align").
    :param download_kwargs: remaining command-line options.
    """
    if not os.path.isfile(json_spec["file_path"]) or not os.path.isfile("LICENSE.txt"):
        print("Downloading JSON metadata...")
        download_files([json_spec, license_specs["json"]], **download_kwargs)

    print("Parsing JSON metadata...")
    with open(json_spec["file_path"], "rb") as f:
        json_data = json.load(f, object_pairs_hook=OrderedDict)

    if "stats" in tasks:
        print_statistics(json_data)

    specs = []
    if "images" in tasks:
        specs += [item["image"] for item in json_data.values()] + [
            license_specs["images"]
        ]
    if "thumbs" in tasks:
        specs += [item["thumbnail"] for item in json_data.values()] + [
            license_specs["thumbs"]
        ]
    if "wilds" in tasks:
        specs += [item["in_the_wild"] for item in json_data.values()] + [
            license_specs["wilds"]
        ]
    if "tfrecords" in tasks:
        specs += tfrecords_specs + [license_specs["tfrecords"]]

    if len(specs):
        print("Downloading %d files..." % len(specs))
        numpy.random.shuffle(specs)  # to make the workload more homogeneous
        download_files(specs, **download_kwargs)

    if "align" in tasks:
        recreate_aligned_images(
            json_data,
            source_dir=download_kwargs["source_dir"],
            rotate_level=not download_kwargs["no_rotation"],
            random_shift=download_kwargs["random_shift"],
            enable_padding=not download_kwargs["no_padding"],
            retry_crops=download_kwargs["retry_crops"],
        )
# ----------------------------------------------------------------------------
def run_cmdline(argv):
    """Parse command-line arguments and execute the selected tasks.

    :param argv: command-line argument vector (``sys.argv``).
    """
    parser = argparse.ArgumentParser(
        prog=argv[0],
        description="Download Flickr-Face-HQ (FFHQ) dataset to current working directory.",
    )
    parser.add_argument(
        "-j",
        "--json",
        help="download metadata as JSON (254 MB)",
        dest="tasks",
        action="append_const",
        const="json",
    )
    parser.add_argument(
        "-s",
        "--stats",
        help="print statistics about the dataset",
        dest="tasks",
        action="append_const",
        const="stats",
    )
    parser.add_argument(
        "-i",
        "--images",
        help="download 1024x1024 images as PNG (89.1 GB)",
        dest="tasks",
        action="append_const",
        const="images",
    )
    parser.add_argument(
        "-t",
        "--thumbs",
        help="download 128x128 thumbnails as PNG (1.95 GB)",
        dest="tasks",
        action="append_const",
        const="thumbs",
    )
    parser.add_argument(
        "-w",
        "--wilds",
        help="download in-the-wild images as PNG (955 GB)",
        dest="tasks",
        action="append_const",
        const="wilds",
    )
    parser.add_argument(
        "-r",
        "--tfrecords",
        help="download multi-resolution TFRecords (273 GB)",
        dest="tasks",
        action="append_const",
        const="tfrecords",
    )
    parser.add_argument(
        "-a",
        "--align",
        help="recreate 1024x1024 images from in-the-wild images",
        dest="tasks",
        action="append_const",
        const="align",
    )
    parser.add_argument(
        "--num_threads",
        help="number of concurrent download threads (default: 32)",
        type=int,
        default=32,
        metavar="NUM",
    )
    parser.add_argument(
        "--status_delay",
        help="time between download status prints (default: 0.2)",
        type=float,
        default=0.2,
        metavar="SEC",
    )
    parser.add_argument(
        "--timing_window",
        help="samples for estimating download eta (default: 50)",
        type=int,
        default=50,
        metavar="LEN",
    )
    parser.add_argument(
        "--chunk_size",
        help="chunk size for each download thread (default: 128)",
        type=int,
        default=128,
        metavar="KB",
    )
    parser.add_argument(
        "--num_attempts",
        help="number of download attempts per file (default: 10)",
        type=int,
        default=10,
        metavar="NUM",
    )
    parser.add_argument(
        "--random-shift",
        help="standard deviation of random crop rectangle jitter",
        type=float,
        default=0.0,
        metavar="SHIFT",
    )
    parser.add_argument(
        "--retry-crops",
        help="retry random shift if crop rectangle falls outside image (up to 1000 times)",
        dest="retry_crops",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--no-rotation",
        help="keep the original orientation of images",
        dest="no_rotation",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--no-padding",
        help="do not apply blur-padding outside and near the image borders",
        dest="no_padding",
        default=False,
        action="store_true",
    )
    parser.add_argument(
        "--source-dir",
        help="where to find already downloaded FFHQ source data",
        default="",
        metavar="DIR",
    )

    args = parser.parse_args()
    if not args.tasks:
        print('No tasks specified. Please see "-h" for help.')
        sys.exit(1)
    run(**vars(args))
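
# Example invocations (flags as defined above; assuming the module file is
# named nvlabs_ffhq.py):
#   python nvlabs_ffhq.py --json --stats
#   python nvlabs_ffhq.py --thumbs --num_threads 16
#   python nvlabs_ffhq.py --wilds --align --source-dir .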
# ----------------------------------------------------------------------------

if __name__ == "__main__":
    run_cmdline(sys.argv)

# ----------------------------------------------------------------------------
class FFHQDataset(Dataset):
    """Torch ``Dataset`` over the parsed FFHQ metadata.

    :param json_data: FFHQ metadata items, indexable by integer sample index.
    :param source_dir: directory containing the downloaded image files.
    :param transform: optional transform applied to each loaded image.
    """

    def __init__(self, json_data, source_dir, transform=None):
        self.json_data = json_data
        self.source_dir = source_dir
        self.transform = transform

    def __len__(self):
        """
        :return: number of samples in the dataset.
        """
        return len(self.json_data)

    def __getitem__(self, idx):
        """
        :param idx: integer sample index.
        :return: (image, id) tuple; assumes each item carries an "id" field.
        """
        item = self.json_data[idx]
        image = self.load_image(item)
        if self.transform:
            image = self.transform(image)
        return image, item["id"]

    def load_image(self, item):
        """Load an image from disk, falling back to its URL if not found."""
        image = self.load_image_from_disk(item)
        if image is None:
            image = self.load_image_from_url(item)
        return image

    def load_image_from_disk(self, item):
        """Load the image referenced by ``item["image"]``, if any."""
        image = None
        if item["image"] is not None:
            # item["image"] is a file spec dict (as in the JSON metadata above);
            # resolve its relative "file_path" against source_dir.
            image = self.load_image_from_disk_path(
                os.path.join(self.source_dir, item["image"]["file_path"])
            )
        return image

    def load_image_from_disk_path(self, path):
        """Open the image at ``path``; return None if it is absent on disk."""
        image = None
        if path is not None and os.path.isfile(path):
            image = PIL.Image.open(path)
        return image

    def load_image_from_url(self, item):
        """Fallback referenced by ``load_image`` but missing from the original
        source; a minimal sketch (an assumption) that streams the file from the
        spec URL. Google Drive links may additionally require the nag-page
        handling implemented in ``download_file``."""
        response = requests.get(item["image"]["file_url"], stream=True)
        response.raise_for_status()
        return PIL.Image.open(response.raw)
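
# Usage sketch (hypothetical; assumes the v2 JSON metadata is already on disk
# and that each item carries the "id" field expected by __getitem__):
#   with open(json_spec["file_path"], "rb") as f:
#       json_data = json.load(f, object_pairs_hook=OrderedDict)
#   dataset = FFHQDataset(list(json_data.values()), source_dir=".")
#   image, image_id = dataset[0]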