skiAI/createDataset.py
2026-01-10 11:53:07 +01:00

142 lines
4.4 KiB
Python

import json
import os
import random
import boto3
from urllib.parse import urlparse
from tqdm.auto import tqdm
# S3/MinIO connection settings.
# Every value can be overridden through an environment variable so that real
# credentials do not have to be committed; the literals are fallback defaults.
MINIO_CONFIG = {
    'endpoint_url': os.environ.get('MINIO_ENDPOINT_URL', 'https://minio.hgk.ch'),
    'access_key': os.environ.get('MINIO_ACCESS_KEY', 'meinAccessKey'),
    'secret_key': os.environ.get('MINIO_SECRET_KEY', 'meinSecretKey'),
    'bucket': os.environ.get('MINIO_BUCKET', 'skiai'),
}

# Input: annotation export that is read by createYOLOdataset().
JSON_PATH = 'datasets/skier_pose/labelstudio_export.json'

# Input: keypoint labels in the order they are written to each label line.
# The order must stay consistent, because the label files identify keypoints
# by position only.
KP_ORDER = [
    "leftski_tip", "leftski_tail", "rightski_tip", "rightski_tail",
    "leftpole_top", "leftpole_bottom", "rightpole_top", "rightpole_bottom",
]

# Output: dataset root directory and the fraction of entries used for training
# (the remainder goes to the validation split).
OUTPUT_DIR = 'datasets/skier_pose'
TRAIN_RATIO = 0.8
# create folder structure
def __setup_directories():
    """Create the ``images`` and ``labels`` folders for the train and val splits.

    The folders are created below ``OUTPUT_DIR``; folders that already exist
    are left as they are.
    """
    for split_name in ('train', 'val'):
        for kind in ('images', 'labels'):
            target = os.path.join(OUTPUT_DIR, split_name, kind)
            os.makedirs(target, exist_ok=True)
# download image from s3
# Lazily created boto3 client that is shared by all downloads.
_S3_CLIENT = None


def __get_s3_client():
    """Return the shared S3 client, creating it from MINIO_CONFIG on first use.

    Building a boto3 client for every single image is needless repeated work,
    so the client is created once and then reused. Changes made to the
    endpoint or credentials in MINIO_CONFIG after the first call are therefore
    not picked up.
    """
    global _S3_CLIENT
    if _S3_CLIENT is None:
        _S3_CLIENT = boto3.client(
            's3',
            endpoint_url=MINIO_CONFIG['endpoint_url'],
            aws_access_key_id=MINIO_CONFIG['access_key'],
            aws_secret_access_key=MINIO_CONFIG['secret_key'],
        )
    return _S3_CLIENT


def __download_from_minio(s3_path, local_path):
    """Download one object from the configured bucket to ``local_path``.

    Args:
        s3_path: Object location, either as a URL such as
            ``s3://bucket/some/key.jpg`` or as a plain key. Only the path
            component is used as the object key.
        local_path: Destination file path on the local disk.

    Any exception raised by boto3 is propagated to the caller.
    """
    parsed = urlparse(s3_path)
    # The bucket always comes from MINIO_CONFIG; a bucket name contained in
    # the URL (its network-location part) is ignored.
    bucket = MINIO_CONFIG['bucket']
    # For 's3://bucket/key' the path is '/key'; drop the leading slash.
    key = parsed.path.lstrip('/')
    __get_s3_client().download_file(bucket, key, local_path)
# create YOLO dataset
def __build_yolo_lines(results):
    """Turn the result list of one annotation into YOLO-pose label lines.

    Args:
        results: List of result dicts of a single annotation. Entries of type
            ``keypointlabels``, ``choices`` and ``rectanglelabels`` are used;
            every other type is ignored.

    Returns:
        One string per rectangle, formatted as
        ``0 cx cy w h`` followed by ``x y v`` for every label in ``KP_ORDER``.
        ``v`` is 0 for a missing keypoint, 1 when a ``choices`` result with
        the same id contains ``"1"``, and 2 otherwise.
    """
    keypoints = {}    # label -> (result id, x, y); first result per label wins
    visibility = {}   # result id -> 1 for keypoints flagged via a "1" choice
    boxes = []        # "cx cy w h" strings, one per rectangle

    for res in results:
        # NOTE(review): some Label Studio result entries (e.g. relations)
        # appear to carry no 'id' -- confirm against the export. Using .get()
        # keeps such entries from aborting the whole export with a KeyError.
        res_id = res.get('id')
        res_type = res.get('type')
        val = res.get('value', {})

        if res_type == 'keypointlabels':
            # Coordinates are divided by 100 (presumably percentages of the
            # image size -- TODO confirm) to obtain normalized values.
            keypoints.setdefault(
                val['keypointlabels'][0],
                (res_id, val['x'] / 100.0, val['y'] / 100.0),
            )
        elif res_type == 'choices':
            # A choice without an id cannot be matched to any keypoint.
            if res_id is not None and "1" in val.get('choices', []):
                visibility[res_id] = 1
        elif res_type == 'rectanglelabels':
            # Convert top-left corner + size into normalized center + size.
            bw = val['width'] / 100.0
            bh = val['height'] / 100.0
            bx = (val['x'] / 100.0) + (bw / 2.0)
            by = (val['y'] / 100.0) + (bh / 2.0)
            boxes.append(f"{bx:.6f} {by:.6f} {bw:.6f} {bh:.6f}")

    # The keypoint part does not depend on the box, so build it only once.
    kp_fields = []
    for kp_name in KP_ORDER:
        kp = keypoints.get(kp_name)
        if kp is None:
            kp_fields.append(" 0.000000 0.000000 0")
        else:
            kp_id, x, y = kp
            v = visibility.get(kp_id, 2)
            kp_fields.append(f" {x:.6f} {y:.6f} {v}")
    kp_suffix = "".join(kp_fields)

    # NOTE(review): every box receives the same keypoints; this presumably
    # relies on one annotated skier per image -- confirm.
    return [f"0 {box}{kp_suffix}" for box in boxes]


def createYOLOdataset():
    """Build the train/val dataset from the annotation export.

    Reads ``JSON_PATH``, shuffles the entries with a fixed seed, assigns the
    first ``TRAIN_RATIO`` share to ``train`` and the rest to ``val``, downloads
    each image and writes one label file per annotated image.

    Entries whose download fails are reported and skipped. Entries without
    annotations keep their downloaded image but get no label file. Only the
    first annotation of an entry is used.
    """
    __setup_directories()

    with open(JSON_PATH, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # A fixed seed keeps the train/val assignment reproducible across runs.
    random.seed(42)
    random.shuffle(data)
    split_idx = int(len(data) * TRAIN_RATIO)

    for i, entry in enumerate(tqdm(data, desc="Importing Images", unit="img")):
        split = 'train' if i < split_idx else 'val'

        image_s3_path = entry['data']['image']
        filename = os.path.basename(image_s3_path)
        base_name = os.path.splitext(filename)[0]
        img_local_path = os.path.join(OUTPUT_DIR, split, 'images', filename)
        label_local_path = os.path.join(OUTPUT_DIR, split, 'labels', f"{base_name}.txt")

        # Best effort: one failed download must not stop the whole import.
        try:
            __download_from_minio(image_s3_path, img_local_path)
        except Exception as e:
            tqdm.write(f"Error treating {filename}: {e}")
            continue

        if not entry.get('annotations'):
            continue

        results = entry['annotations'][0].get('result') or []
        yolo_lines = __build_yolo_lines(results)

        with open(label_local_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(yolo_lines))

    print(f"Finished! Dataset saved to: {os.path.abspath(OUTPUT_DIR)}")
# Build the dataset only when this file is run as a script, not on import.
if __name__ == "__main__":
    createYOLOdataset()