Skip to content
Snippets Groups Projects
Commit 338dd2d9 authored by Alexander Hartelt's avatar Alexander Hartelt
Browse files

added scripts

parent 3e5ef39c
Branches
No related tags found
No related merge requests found
import multiprocessing
from ast import literal_eval
import argparse
import json
import numpy as np
import os
import tqdm
from PIL import Image
import itertools
import random
def get_image_colors(path_to_mask: np.array):
image_pil = Image.open(path_to_mask)
if image_pil.mode == 'RGBA':
image_pil = image_pil.convert('RGB')
mask = np.asarray(image_pil)
if mask.ndim == 2 or mask.shape[2] == 2:
return [(255, 255, 255), (0, 0, 0)]
tt = mask.view()
tt.shape = -1, 3
height, width, depth = mask.shape
ifl = tt[..., 0].astype(np.int) * height * width + tt[..., 1].astype(np.int) * width + tt[..., 2].astype(np.int)
colors = np.unique(ifl, return_inverse=False)
colors = [(int(color / (height * width)), int((color / width) % height), int(color % width)) for color in colors]
return colors
def compute_image_map(input_dir, output_dir, max_images=-1, processes=4):
if not os.path.exists(input_dir):
raise Exception("Cannot open {}".format(input_dir))
files = [os.path.join(input_dir, f) for f in os.listdir(input_dir)]
if max_images > 0:
files = random.sample(files, max_images)
with multiprocessing.Pool(processes=processes) as p:
colors = [v for v in
tqdm.tqdm(p.imap(get_image_colors, files), total=len(files))
]
colors = set(itertools.chain.from_iterable(colors))
colors = sorted(colors, key=lambda element: (element[0], element[1], element[2]))[::-1]
color_dict = {str(key): (value, "label") for (value, key) in enumerate(colors)}
if output_dir:
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, 'image_map.json'), 'w') as fp:
json.dump(color_dict, fp)
def load_image_map_from_file(path):
if not os.path.exists(path):
raise Exception("Cannot open {}".format(path))
with open(path) as f:
data = json.load(f)
color_map = {literal_eval(k): v for k, v in data.items()}
return color_map
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", type=str, required=True,
help="Mask directory to process")
parser.add_argument("--output_dir", type=str, required=True,
help="The output dir for the color map")
parser.add_argument("--max_image", type=int, default=-1,
help="Max images to check for color. -1 to check every mask")
parser.add_argument("--processes", type=int, default=4,
help="Number of processes to run")
args = parser.parse_args()
compute_image_map(args.input_dir, args.output_dir, args.max_image, args.processes)
if __name__ == '__main__':
main()
\ No newline at end of file
import argparse
import json
from os import path
from typing import List
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
def dir_path(string):
if path.isdir(string):
return string
else:
raise NotADirectoryError(string)
def main():
from segmentation.network import TrainSettings, dirs_to_pandaframe, load_image_map_from_file, MaskSetting, MaskType, PCGTSVersion, XMLDataset, Network, compose, MaskGenerator, MaskDataset
from segmentation.settings import Architecture
parser = argparse.ArgumentParser()
parser.add_argument("-L", "--l-rate", type=float, default=1e-4,
help="set learning rate")
parser.add_argument("-O", "--output", type=str, default="./",
help="target directory for model and logs")
parser.add_argument("--load", type=str, default=None,
help="load an existing model and continue training")
parser.add_argument("-E", "--n-epoch", type=int, default=100,
help="number of epochs")
parser.add_argument("--data-augmentation", action="store_true",
help="Enable data augmentation")
parser.add_argument("--train_input", type=dir_path, nargs="+", default=[], help="Path to folder(s) containing train images")
parser.add_argument("--train_mask", type=dir_path, nargs="+", default=[], help="Path to folder(s) containing train xmls")
parser.add_argument("--test_input", type=dir_path, nargs="*", default=[], help="Path to folder(s) containing test images")
parser.add_argument("--train_mask", type=dir_path, nargs="+", default=[], help="Path to folder(s) containing test xmls")
parser.add_argument("--color-map", dest="map", type=str, default="image_map.json",
help="color map to load")
parser.add_argument('--architecture',
default=Architecture.UNET,
const=Architecture.UNET,
nargs='?',
choices=[x.value for x in list(Architecture)],
help='Network architecture to use for training')
parser.add_argument('--encoder',
default="efficientnet-b3",
const="efficientnet-b3",
help='Network architecture to use for training')
args = parser.parse_args()
train = dirs_to_pandaframe(args.train_input, args.train_mask)
test = dirs_to_pandaframe(args.test_input, args.train_mask) if len(args.test_input > 0) else train
map = load_image_map_from_file(args.map)
from segmentation.dataset import base_line_transform
settings = MaskSetting(MASK_TYPE=MaskType.BASE_LINE, PCGTS_VERSION=PCGTSVersion.PCGTS2013, LINEWIDTH=5,
BASELINELENGTH=10)
train_dataset = XMLDataset(train, map, transform=compose([base_line_transform()]),
mask_generator=MaskGenerator(settings=settings))
test_dataset = XMLDataset(test, map, transform=compose([base_line_transform()]),
mask_generator=MaskGenerator(settings=settings))
setting = TrainSettings(CLASSES=len(map), TRAIN_DATASET=train_dataset, VAL_DATASET=test_dataset,
OUTPUT_PATH=args.output,
MODEL_PATH=args.load)
trainer = Network(setting, color_map=map)
trainer.train()
if __name__ == "__main__":
main()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment