#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "heider"
__doc__ = r"""Download Flickr-Faces-HQ (FFHQ) dataset to current working directory.

Created on 7/5/22
"""
__all__ = ["FFHQDataset"]

from torch.utils.data import Dataset
import argparse
import glob
import hashlib
import html
import io
import itertools
import json
import os
import queue
import shutil
import sys
import threading
import time
import uuid
from collections import OrderedDict, defaultdict
import PIL.Image
import PIL.ImageFile
import numpy
import requests
import scipy.ndimage
PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True # avoid "Decompressed Data Too Large" error
# ----------------------------------------------------------------------------
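# Each spec dict below describes one downloadable file: its Google Drive URL, the
# relative path it is saved to, and the expected size and MD5 used to validate
# the download.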
json_spec = dict(
file_url="https://drive.google.com/uc?id=16N0RV4fHI6joBuKbQAoG34V_cQk7vxSA",
file_path="ffhq-dataset-v2.json",
file_size=267793842,
file_md5="425ae20f06a4da1d4dc0f46d40ba5fd6",
)
tfrecords_specs = [
dict(
file_url="https://drive.google.com/uc?id=1LnhoytWihRRJ7CfhLQ76F8YxwxRDlZN3",
file_path="tfrecords/ffhq/ffhq-r02.tfrecords",
file_size=6860000,
file_md5="63e062160f1ef9079d4f51206a95ba39",
),
dict(
file_url="https://drive.google.com/uc?id=1LWeKZGZ_x2rNlTenqsaTk8s7Cpadzjbh",
file_path="tfrecords/ffhq/ffhq-r03.tfrecords",
file_size=17290000,
file_md5="54fb32a11ebaf1b86807cc0446dd4ec5",
),
dict(
file_url="https://drive.google.com/uc?id=1Lr7Tiufr1Za85HQ18yg3XnJXstiI2BAC",
file_path="tfrecords/ffhq/ffhq-r04.tfrecords",
file_size=57610000,
file_md5="7164cc5531f6828bf9c578bdc3320e49",
),
dict(
file_url="https://drive.google.com/uc?id=1LnyiayZ-XJFtatxGFgYePcs9bdxuIJO_",
file_path="tfrecords/ffhq/ffhq-r05.tfrecords",
file_size=218890000,
file_md5="050cc7e5fd07a1508eaa2558dafbd9ed",
),
dict(
file_url="https://drive.google.com/uc?id=1Lt6UP201zHnpH8zLNcKyCIkbC-aMb5V_",
file_path="tfrecords/ffhq/ffhq-r06.tfrecords",
file_size=864010000,
file_md5="90bedc9cc07007cd66615b2b1255aab8",
),
dict(
file_url="https://drive.google.com/uc?id=1LwOP25fJ4xN56YpNCKJZM-3mSMauTxeb",
file_path="tfrecords/ffhq/ffhq-r07.tfrecords",
file_size=3444980000,
file_md5="bff839e0dda771732495541b1aff7047",
),
dict(
file_url="https://drive.google.com/uc?id=1LxxgVBHWgyN8jzf8bQssgVOrTLE8Gv2v",
file_path="tfrecords/ffhq/ffhq-r08.tfrecords",
file_size=13766900000,
file_md5="74de4f07dc7bfb07c0ad4471fdac5e67",
),
dict(
file_url="https://drive.google.com/uc?id=1M-ulhD5h-J7sqSy5Y1njUY_80LPcrv3V",
file_path="tfrecords/ffhq/ffhq-r09.tfrecords",
file_size=55054580000,
file_md5="05355aa457a4bd72709f74a81841b46d",
),
dict(
file_url="https://drive.google.com/uc?id=1M11BIdIpFCiapUqV658biPlaXsTRvYfM",
file_path="tfrecords/ffhq/ffhq-r10.tfrecords",
file_size=220205650000,
file_md5="bf43cab9609ab2a27892fb6c2415c11b",
),
]
license_specs = {
"json": dict(
file_url="https://drive.google.com/uc?id=1SHafCugkpMZzYhbgOz0zCuYiy-hb9lYX",
file_path="LICENSE.txt",
file_size=1610,
file_md5="724f3831aaecd61a84fe98500079abc2",
),
"images": dict(
file_url="https://drive.google.com/uc?id=1sP2qz8TzLkzG2gjwAa4chtdB31THska4",
file_path="images1024x1024/LICENSE.txt",
file_size=1610,
file_md5="724f3831aaecd61a84fe98500079abc2",
),
"thumbs": dict(
file_url="https://drive.google.com/uc?id=1iaL1S381LS10VVtqu-b2WfF9TiY75Kmj",
file_path="thumbnails128x128/LICENSE.txt",
file_size=1610,
file_md5="724f3831aaecd61a84fe98500079abc2",
),
"wilds": dict(
file_url="https://drive.google.com/uc?id=1rsfFOEQvkd6_Z547qhpq5LhDl2McJEzw",
file_path="in-the-wild-images/LICENSE.txt",
file_size=1610,
file_md5="724f3831aaecd61a84fe98500079abc2",
),
"tfrecords": dict(
file_url="https://drive.google.com/uc?id=1SYUmqKdLoTYq-kqsnPsniLScMhspvl5v",
file_path="tfrecords/ffhq/LICENSE.txt",
file_size=1610,
file_md5="724f3831aaecd61a84fe98500079abc2",
),
}
# ----------------------------------------------------------------------------
def download_file(session, file_spec, stats, chunk_size=128, num_attempts=10, **kwargs):
    """Download a single file described by ``file_spec`` and validate it.

    :param session: HTTP session to download with.
    :type session: requests.Session
    :param file_spec: dict with ``file_url`` and ``file_path``, optionally also
        ``file_size``, ``file_md5``, ``pixel_size`` and ``pixel_md5`` for validation.
    :type file_spec: dict
    :param stats: shared progress dict with ``lock``, ``files_done`` and ``bytes_done``.
    :type stats: dict
    :param chunk_size: download chunk size in kilobytes.
    :type chunk_size: int
    :param num_attempts: number of download attempts before giving up.
    :type num_attempts: int
    :param kwargs: ignored; absorbs unrelated options passed along by ``download_files``.
    """
file_path = file_spec["file_path"]
file_url = file_spec["file_url"]
file_dir = os.path.dirname(file_path)
tmp_path = file_path + ".tmp." + uuid.uuid4().hex
if file_dir:
os.makedirs(file_dir, exist_ok=True)
for attempts_left in reversed(range(num_attempts)):
data_size = 0
try:
# Download.
data_md5 = hashlib.md5()
with session.get(file_url, stream=True) as res:
res.raise_for_status()
with open(tmp_path, "wb") as f:
for chunk in res.iter_content(chunk_size=chunk_size << 10):
f.write(chunk)
data_size += len(chunk)
data_md5.update(chunk)
with stats["lock"]:
stats["bytes_done"] += len(chunk)
# Validate.
if "file_size" in file_spec and data_size != file_spec["file_size"]:
raise IOError("Incorrect file size", file_path)
if (
"file_md5" in file_spec
and data_md5.hexdigest() != file_spec["file_md5"]
):
raise IOError("Incorrect file MD5", file_path)
if "pixel_size" in file_spec or "pixel_md5" in file_spec:
with PIL.Image.open(tmp_path) as image:
if (
"pixel_size" in file_spec
and list(image.size) != file_spec["pixel_size"]
):
raise IOError("Incorrect pixel size", file_path)
if (
"pixel_md5" in file_spec
and hashlib.md5(numpy.array(image)).hexdigest()
!= file_spec["pixel_md5"]
):
raise IOError("Incorrect pixel MD5", file_path)
break
        except Exception:
with stats["lock"]:
stats["bytes_done"] -= data_size
# Handle known failure cases.
if data_size > 0 and data_size < 8192:
with open(tmp_path, "rb") as f:
data = f.read()
data_str = data.decode("utf-8")
# Google Drive virus checker nag.
links = [
html.unescape(link)
for link in data_str.split('"')
if "export=download" in link
]
if len(links) == 1:
if attempts_left:
file_url = requests.compat.urljoin(file_url, links[0])
continue
# Google Drive quota exceeded.
if "Google Drive - Quota exceeded" in data_str:
if not attempts_left:
raise IOError(
"Google Drive download quota exceeded -- please try again later"
)
# Last attempt => raise error.
if not attempts_left:
raise
# Rename temp file to the correct name.
os.replace(tmp_path, file_path) # atomic
with stats["lock"]:
stats["files_done"] += 1
# Attempt to clean up any leftover temps.
for filename in glob.glob(file_path + ".tmp.*"):
try:
os.remove(filename)
        except OSError:
            pass
# ----------------------------------------------------------------------------
def choose_bytes_unit(num_bytes):
    """Pick a human-readable unit for a byte count.

    :param num_bytes: number of bytes.
    :type num_bytes: float
    :return: unit name and its divisor, e.g. ``("MB", 1 << 20)``.
    :rtype: tuple
    """
b = int(numpy.rint(num_bytes))
if b < (100 << 0):
return "B", (1 << 0)
if b < (100 << 10):
return "kB", (1 << 10)
if b < (100 << 20):
return "MB", (1 << 20)
if b < (100 << 30):
return "GB", (1 << 30)
return "TB", (1 << 40)
# ----------------------------------------------------------------------------
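# ``format_time`` is used by ``download_files`` below to render the ETA, but its
# definition was missing from this copy of the script; the following is a minimal
# reimplementation of that helper.
def format_time(seconds):
    """Format a duration in seconds as e.g. ``"3m 05s"`` or ``"2h 14m"``."""
    s = int(numpy.rint(seconds))
    if s < 60:
        return "%ds" % s
    if s < 60 * 60:
        return "%dm %02ds" % (s // 60, s % 60)
    if s < 24 * 60 * 60:
        return "%dh %02dm" % (s // (60 * 60), (s // 60) % 60)
    return "%dd %02dh" % (s // (24 * 60 * 60), (s // (60 * 60)) % 24)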
# ----------------------------------------------------------------------------
def download_files(
    file_specs, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs
):
    """Download the given file specs concurrently while printing progress.

    :param file_specs: list of file spec dicts, see ``download_file``.
    :type file_specs: list
    :param num_threads: number of concurrent download threads.
    :type num_threads: int
    :param status_delay: seconds between status line updates.
    :type status_delay: float
    :param timing_window: number of samples used to estimate bandwidth and ETA.
    :type timing_window: int
    :param download_kwargs: extra keyword arguments forwarded to ``download_file``.
    """
# Determine which files to download.
done_specs = {
spec["file_path"]: spec
for spec in file_specs
if os.path.isfile(spec["file_path"])
}
missing_specs = [spec for spec in file_specs if spec["file_path"] not in done_specs]
files_total = len(file_specs)
bytes_total = sum(spec["file_size"] for spec in file_specs)
stats = dict(
files_done=len(done_specs),
bytes_done=sum(spec["file_size"] for spec in done_specs.values()),
lock=threading.Lock(),
)
if len(done_specs) == files_total:
print("All files already downloaded -- skipping.")
return
# Launch worker threads.
spec_queue = queue.Queue()
exception_queue = queue.Queue()
for spec in missing_specs:
spec_queue.put(spec)
thread_kwargs = dict(
spec_queue=spec_queue,
exception_queue=exception_queue,
stats=stats,
download_kwargs=download_kwargs,
)
for _thread_idx in range(min(num_threads, len(missing_specs))):
threading.Thread(
target=_download_thread, kwargs=thread_kwargs, daemon=True
).start()
# Monitor status until done.
bytes_unit, bytes_div = choose_bytes_unit(bytes_total)
spinner = "/-\\|"
timing = []
while True:
with stats["lock"]:
files_done = stats["files_done"]
bytes_done = stats["bytes_done"]
spinner = spinner[1:] + spinner[:1]
timing = timing[max(len(timing) - timing_window + 1, 0) :] + [
(time.time(), bytes_done)
]
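        # Estimate bandwidth over the sliding window of recent samples and
        # derive the remaining-time estimate (ETA) from it.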
bandwidth = max(
(timing[-1][1] - timing[0][1]) / max(timing[-1][0] - timing[0][0], 1e-8), 0
)
bandwidth_unit, bandwidth_div = choose_bytes_unit(bandwidth)
eta = format_time((bytes_total - bytes_done) / max(bandwidth, 1))
print(
"\r%s %6.2f%% done %d/%d files %-13s %-10s ETA: %-7s "
% (
spinner[0],
bytes_done / bytes_total * 100,
files_done,
files_total,
"%.2f/%.2f %s"
% (bytes_done / bytes_div, bytes_total / bytes_div, bytes_unit),
"%.2f %s/s" % (bandwidth / bandwidth_div, bandwidth_unit),
"done"
if bytes_total == bytes_done
else "..."
if len(timing) < timing_window or bandwidth == 0
else eta,
),
end="",
flush=True,
)
if files_done == files_total:
print()
break
try:
exc_info = exception_queue.get(timeout=status_delay)
raise exc_info[1].with_traceback(exc_info[2])
except queue.Empty:
pass
def _download_thread(spec_queue, exception_queue, stats, download_kwargs):
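    """Worker: pull file specs from the queue and download them, forwarding any exception to the main thread."""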
with requests.Session() as session:
while not spec_queue.empty():
spec = spec_queue.get()
try:
download_file(session, spec, stats, **download_kwargs)
except Exception as exc:
exception_queue.put(sys.exc_info())
# ----------------------------------------------------------------------------
def print_statistics(json_data):
    """Print per-category, per-license and per-country image counts as tables.

    :param json_data: parsed FFHQ metadata, mapping item index to item dict.
    :type json_data: dict
    """
categories = defaultdict(int)
licenses = defaultdict(int)
countries = defaultdict(int)
for item in json_data.values():
categories[item["category"]] += 1
licenses[item["metadata"]["license"]] += 1
country = item["metadata"]["country"]
countries[country if country else "<Unknown>"] += 1
for name in [
name for name, num in countries.items() if num / len(json_data) < 1e-3
]:
countries["<Other>"] += countries.pop(name)
rows = [[]] * 2
rows += [["Category", "Images", "% of all"]]
rows += [["---"] * 3]
for name, num in sorted(categories.items(), key=lambda x: -x[1]):
rows += [[name, "%d" % num, "%.2f" % (100.0 * num / len(json_data))]]
rows += [[]] * 2
rows += [["License", "Images", "% of all"]]
rows += [["---"] * 3]
for name, num in sorted(licenses.items(), key=lambda x: -x[1]):
rows += [[name, "%d" % num, "%.2f" % (100.0 * num / len(json_data))]]
rows += [[]] * 2
rows += [["Country", "Images", "% of all", "% of known"]]
rows += [["---"] * 4]
for name, num in sorted(
countries.items(), key=lambda x: -x[1] if x[0] != "<Other>" else 0
):
rows += [
[
name,
"%d" % num,
"%.2f" % (100.0 * num / len(json_data)),
"%.2f"
% (
0
if name == "<Unknown>"
else 100.0 * num / (len(json_data) - countries["<Unknown>"])
),
]
]
rows += [[]] * 2
widths = [
max(len(cell) for cell in column if cell is not None)
for column in itertools.zip_longest(*rows)
]
for row in rows:
print(
" ".join(
cell + " " * (width - len(cell)) for cell, width in zip(row, widths)
)
)
# ----------------------------------------------------------------------------
def recreate_aligned_images(
    json_data,
    source_dir,
    dst_dir="realign1024x1024",
    output_size=1024,
    transform_size=4096,
    enable_padding=True,
    rotate_level=True,
    random_shift=0.0,
    retry_crops=False,
):
    """Recreate the aligned 1024x1024 images from the in-the-wild images.

    :param json_data: parsed FFHQ metadata with per-item face landmarks.
    :type json_data: dict
    :param source_dir: directory containing the downloaded in-the-wild images.
    :type source_dir: str
    :param dst_dir: output directory for the aligned images.
    :type dst_dir: str
    :param output_size: resolution of the saved images.
    :type output_size: int
    :param transform_size: intermediate resolution used for the quad transform.
    :type transform_size: int
    :param enable_padding: apply blurred reflection padding near the image borders.
    :type enable_padding: bool
    :param rotate_level: rotate the crop rectangle to level the eyes.
    :type rotate_level: bool
    :param random_shift: standard deviation of the random crop jitter, relative to the crop size.
    :type random_shift: float
    :param retry_crops: re-draw the random shift (up to 1000 times) if the crop falls outside the image.
    :type retry_crops: bool
    """
print("Recreating aligned images...")
# Fix random seed for reproducibility
numpy.random.seed(12345)
# The following random numbers are unused in present implementation, but we consume them for reproducibility
_ = numpy.random.normal(0, 1, (len(json_data.values()), 2))
if dst_dir:
os.makedirs(dst_dir, exist_ok=True)
shutil.copyfile("LICENSE.txt", os.path.join(dst_dir, "LICENSE.txt"))
for item_idx, item in enumerate(json_data.values()):
print("\r%d / %d ... " % (item_idx, len(json_data)), end="", flush=True)
# Parse landmarks.
# pylint: disable=unused-variable
lm = numpy.array(item["in_the_wild"]["face_landmarks"])
lm_chin = lm[0:17] # left-right
lm_eyebrow_left = lm[17:22] # left-right
lm_eyebrow_right = lm[22:27] # left-right
lm_nose = lm[27:31] # top-down
lm_nostrils = lm[31:36] # top-down
lm_eye_left = lm[36:42] # left-clockwise
lm_eye_right = lm[42:48] # left-clockwise
lm_mouth_outer = lm[48:60] # left-clockwise
lm_mouth_inner = lm[60:68] # left-clockwise
# Calculate auxiliary vectors.
eye_left = numpy.mean(lm_eye_left, axis=0)
eye_right = numpy.mean(lm_eye_right, axis=0)
eye_avg = (eye_left + eye_right) * 0.5
eye_to_eye = eye_right - eye_left
mouth_left = lm_mouth_outer[0]
mouth_right = lm_mouth_outer[6]
mouth_avg = (mouth_left + mouth_right) * 0.5
eye_to_mouth = mouth_avg - eye_avg
# Choose oriented crop rectangle.
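        # ``x`` spans half the crop width: it follows the eye-to-eye direction
        # (corrected by the perpendicular of the eye-to-mouth vector) and is scaled
        # to the larger of 2.0x the eye distance and 1.8x the eye-to-mouth distance;
        # ``y`` is its perpendicular and ``c0`` the crop centre.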
if rotate_level:
x = eye_to_eye - numpy.flipud(eye_to_mouth) * [-1, 1]
x /= numpy.hypot(*x)
x *= max(numpy.hypot(*eye_to_eye) * 2.0, numpy.hypot(*eye_to_mouth) * 1.8)
y = numpy.flipud(x) * [-1, 1]
c0 = eye_avg + eye_to_mouth * 0.1
else:
x = numpy.array([1, 0], dtype=numpy.float64)
x *= max(numpy.hypot(*eye_to_eye) * 2.0, numpy.hypot(*eye_to_mouth) * 1.8)
y = numpy.flipud(x) * [-1, 1]
c0 = eye_avg + eye_to_mouth * 0.1
# Load in-the-wild image.
src_file = os.path.join(source_dir, item["in_the_wild"]["file_path"])
if not os.path.isfile(src_file):
print('\nCannot find source image. Please run "--wilds" before "--align".')
return
img = PIL.Image.open(src_file)
quad = numpy.stack([c0 - x - y, c0 - x + y, c0 + x + y, c0 + x - y])
qsize = numpy.hypot(*x) * 2
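        # ``quad`` holds the four corners of the oriented crop rectangle and
        # ``qsize`` its side length in source-image pixels.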
# Keep drawing new random crop offsets until we find one that is contained in the image
# and does not require padding
if random_shift != 0:
for _ in range(1000):
                # Offset the crop rectangle center by a random shift proportional to image dimension
# and the requested standard deviation
c = c0 + numpy.hypot(*x) * 2 * random_shift * numpy.random.normal(
0, 1, c0.shape
)
quad = numpy.stack([c - x - y, c - x + y, c + x + y, c + x - y])
crop = (
int(numpy.floor(min(quad[:, 0]))),
int(numpy.floor(min(quad[:, 1]))),
int(numpy.ceil(max(quad[:, 0]))),
int(numpy.ceil(max(quad[:, 1]))),
)
if not retry_crops or not (
crop[0] < 0
or crop[1] < 0
or crop[2] >= img.width
or crop[3] >= img.height
):
# We're happy with this crop (either it fits within the image, or retries are disabled)
break
else:
# rejected N times, give up and move to next image
# (does not happen in practice with the FFHQ data)
print("rejected image")
return
# Shrink.
shrink = int(numpy.floor(qsize / output_size * 0.5))
if shrink > 1:
rsize = (
int(numpy.rint(float(img.size[0]) / shrink)),
int(numpy.rint(float(img.size[1]) / shrink)),
)
            img = img.resize(rsize, PIL.Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
quad /= shrink
qsize /= shrink
# Crop.
border = max(int(numpy.rint(qsize * 0.1)), 3)
crop = (
int(numpy.floor(min(quad[:, 0]))),
int(numpy.floor(min(quad[:, 1]))),
int(numpy.ceil(max(quad[:, 0]))),
int(numpy.ceil(max(quad[:, 1]))),
)
crop = (
max(crop[0] - border, 0),
max(crop[1] - border, 0),
min(crop[2] + border, img.size[0]),
min(crop[3] + border, img.size[1]),
)
if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
img = img.crop(crop)
quad -= crop[0:2]
# Pad.
pad = (
int(numpy.floor(min(quad[:, 0]))),
int(numpy.floor(min(quad[:, 1]))),
int(numpy.ceil(max(quad[:, 0]))),
int(numpy.ceil(max(quad[:, 1]))),
)
pad = (
max(-pad[0] + border, 0),
max(-pad[1] + border, 0),
max(pad[2] - img.size[0] + border, 0),
max(pad[3] - img.size[1] + border, 0),
)
if enable_padding and max(pad) > border - 4:
pad = numpy.maximum(pad, int(numpy.rint(qsize * 0.3)))
img = numpy.pad(
numpy.float32(img),
((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
"reflect",
)
h, w, _ = img.shape
y, x, _ = numpy.ogrid[:h, :w, :1]
mask = numpy.maximum(
1.0
- numpy.minimum(
numpy.float32(x) / pad[0], numpy.float32(w - 1 - x) / pad[2]
),
1.0
- numpy.minimum(
numpy.float32(y) / pad[1], numpy.float32(h - 1 - y) / pad[3]
),
)
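            # Blur the reflection-padded border region and fade it towards the
            # median colour so the padding seam is not visible in the output.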
blur = qsize * 0.02
img += (
scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img
) * numpy.clip(mask * 3.0 + 1.0, 0.0, 1.0)
img += (numpy.median(img, axis=(0, 1)) - img) * numpy.clip(mask, 0.0, 1.0)
img = PIL.Image.fromarray(
numpy.uint8(numpy.clip(numpy.rint(img), 0, 255)), "RGB"
)
quad += pad[:2]
# Transform.
img = img.transform(
(transform_size, transform_size),
PIL.Image.QUAD,
(quad + 0.5).flatten(),
PIL.Image.BILINEAR,
)
if output_size < transform_size:
            img = img.resize((output_size, output_size), PIL.Image.LANCZOS)  # ANTIALIAS removed in Pillow 10
# Save aligned image.
dst_subdir = os.path.join(dst_dir, "%05d" % (item_idx - item_idx % 1000))
os.makedirs(dst_subdir, exist_ok=True)
img.save(os.path.join(dst_subdir, "%05d.png" % item_idx))
# All done.
print("\r%d / %d ... done" % (len(json_data), len(json_data)))
# ----------------------------------------------------------------------------
def run(tasks, **download_kwargs):
    """Execute the requested tasks.

    :param tasks: task names to perform, any of ``json``, ``stats``, ``images``,
        ``thumbs``, ``wilds``, ``tfrecords`` and ``align``.
    :type tasks: list
    :param download_kwargs: download and alignment options, see ``run_cmdline``.
    """
if not os.path.isfile(json_spec["file_path"]) or not os.path.isfile("LICENSE.txt"):
print("Downloading JSON metadata...")
download_files([json_spec, license_specs["json"]], **download_kwargs)
print("Parsing JSON metadata...")
with open(json_spec["file_path"], "rb") as f:
json_data = json.load(f, object_pairs_hook=OrderedDict)
if "stats" in tasks:
print_statistics(json_data)
specs = []
if "images" in tasks:
specs += [item["image"] for item in json_data.values()] + [
license_specs["images"]
]
if "thumbs" in tasks:
specs += [item["thumbnail"] for item in json_data.values()] + [
license_specs["thumbs"]
]
if "wilds" in tasks:
specs += [item["in_the_wild"] for item in json_data.values()] + [
license_specs["wilds"]
]
if "tfrecords" in tasks:
specs += tfrecords_specs + [license_specs["tfrecords"]]
if len(specs):
print("Downloading %d files..." % len(specs))
numpy.random.shuffle(specs) # to make the workload more homogeneous
download_files(specs, **download_kwargs)
if "align" in tasks:
recreate_aligned_images(
json_data,
source_dir=download_kwargs["source_dir"],
rotate_level=not download_kwargs["no_rotation"],
random_shift=download_kwargs["random_shift"],
enable_padding=not download_kwargs["no_padding"],
retry_crops=download_kwargs["retry_crops"],
)
# ----------------------------------------------------------------------------
def run_cmdline(argv):
    """Parse command line arguments and run the requested tasks.

    :param argv: command line arguments, typically ``sys.argv``.
    :type argv: list
    """
parser = argparse.ArgumentParser(
prog=argv[0],
description="Download Flickr-Face-HQ (FFHQ) dataset to current working directory.",
)
parser.add_argument(
"-j",
"--json",
help="download metadata as JSON (254 MB)",
dest="tasks",
action="append_const",
const="json",
)
parser.add_argument(
"-s",
"--stats",
help="print statistics about the dataset",
dest="tasks",
action="append_const",
const="stats",
)
parser.add_argument(
"-i",
"--images",
help="download 1024x1024 images as PNG (89.1 GB)",
dest="tasks",
action="append_const",
const="images",
)
parser.add_argument(
"-t",
"--thumbs",
help="download 128x128 thumbnails as PNG (1.95 GB)",
dest="tasks",
action="append_const",
const="thumbs",
)
parser.add_argument(
"-w",
"--wilds",
help="download in-the-wild images as PNG (955 GB)",
dest="tasks",
action="append_const",
const="wilds",
)
parser.add_argument(
"-r",
"--tfrecords",
help="download multi-resolution TFRecords (273 GB)",
dest="tasks",
action="append_const",
const="tfrecords",
)
parser.add_argument(
"-a",
"--align",
help="recreate 1024x1024 images from in-the-wild images",
dest="tasks",
action="append_const",
const="align",
)
parser.add_argument(
"--num_threads",
help="number of concurrent download threads (default: 32)",
type=int,
default=32,
metavar="NUM",
)
parser.add_argument(
"--status_delay",
help="time between download status prints (default: 0.2)",
type=float,
default=0.2,
metavar="SEC",
)
parser.add_argument(
"--timing_window",
help="samples for estimating download eta (default: 50)",
type=int,
default=50,
metavar="LEN",
)
parser.add_argument(
"--chunk_size",
help="chunk size for each download thread (default: 128)",
type=int,
default=128,
metavar="KB",
)
parser.add_argument(
"--num_attempts",
help="number of download attempts per file (default: 10)",
type=int,
default=10,
metavar="NUM",
)
parser.add_argument(
"--random-shift",
help="standard deviation of random crop rectangle jitter",
type=float,
default=0.0,
metavar="SHIFT",
)
parser.add_argument(
"--retry-crops",
help="retry random shift if crop rectangle falls outside image (up to 1000 times)",
dest="retry_crops",
default=False,
action="store_true",
)
parser.add_argument(
"--no-rotation",
help="keep the original orientation of images",
dest="no_rotation",
default=False,
action="store_true",
)
parser.add_argument(
"--no-padding",
help="do not apply blur-padding outside and near the image borders",
dest="no_padding",
default=False,
action="store_true",
)
parser.add_argument(
"--source-dir",
help="where to find already downloaded FFHQ source data",
default="",
metavar="DIR",
)
args = parser.parse_args()
if not args.tasks:
print('No tasks specified. Please see "-h" for help.')
        sys.exit(1)
run(**vars(args))
# ----------------------------------------------------------------------------
if __name__ == "__main__":
run_cmdline(sys.argv)
# ----------------------------------------------------------------------------
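# Example invocations (a sketch; run this script from the directory where the
# dataset should live):
#
#     python <this script> --json --stats           # metadata plus dataset statistics
#     python <this script> --json --images          # 1024x1024 aligned PNGs (89.1 GB)
#     python <this script> --json --wilds --align   # re-create the aligned images locally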
class FFHQDataset(Dataset):
    """Torch ``Dataset`` over the FFHQ metadata.

    :param json_data: parsed FFHQ metadata (``ffhq-dataset-v2.json``), mapping item index to item dict.
    :type json_data: dict
    :param source_dir: directory containing the downloaded ``images1024x1024`` files.
    :type source_dir: str
    :param transform: optional callable applied to each loaded image.
    :type transform: callable
    """

    def __init__(self, json_data, source_dir, transform=None):
        self.json_data = json_data
        # The metadata is keyed by string indices, so keep an ordered list of the
        # item dicts to allow plain integer indexing in ``__getitem__``.
        self.items = list(json_data.values())
        self.source_dir = source_dir
        self.transform = transform

    def __len__(self):
        """Return the number of items in the metadata."""
        return len(self.items)

    def __getitem__(self, idx):
        """Load the ``idx``-th aligned image and return ``(image, idx)``.

        The FFHQ metadata carries no separate per-item id field, so the integer
        index doubles as the label.
        """
        item = self.items[idx]
        image = self.load_image(item)
        if self.transform:
            image = self.transform(image)
        return image, idx

    def load_image(self, item):
        """Load the aligned image for ``item`` from disk, falling back to its URL.

        :param item: one entry of the FFHQ metadata.
        :type item: dict
        :return: the loaded image.
        :rtype: PIL.Image.Image
        """
        image = self.load_image_from_disk(item)
        if image is None:
            image = self.load_image_from_url(item)
        return image

    def load_image_from_disk(self, item):
        """Return the image below ``source_dir`` if it exists on disk, else ``None``."""
        image = None
        if item.get("image") is not None:
            path = os.path.join(self.source_dir, item["image"]["file_path"])
            image = self.load_image_from_disk_path(path)
        return image

    def load_image_from_disk_path(self, path):
        """Open ``path`` with PIL if it exists, else return ``None``."""
        image = None
        if path is not None and os.path.isfile(path):
            image = PIL.Image.open(path)
        return image

    def load_image_from_url(self, item):
        """Best-effort fallback: fetch the image from its ``file_url``.

        Note that Google Drive may serve an interstitial page instead of the raw
        image, so downloading with ``download_files`` beforehand is the reliable path.
        """
        response = requests.get(item["image"]["file_url"])
        response.raise_for_status()
        return PIL.Image.open(io.BytesIO(response.content))
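# ----------------------------------------------------------------------------
# Minimal usage sketch for ``FFHQDataset`` (assumes the metadata JSON and the
# 1024x1024 images were already fetched with the ``--json`` and ``--images``
# tasks, and that the current directory is the download directory):
#
#     with open(json_spec["file_path"], "rb") as f:
#         metadata = json.load(f, object_pairs_hook=OrderedDict)
#     dataset = FFHQDataset(metadata, source_dir=".")
#     image, index = dataset[0]  # a 1024x1024 PIL image and its integer index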