File size: 4,925 Bytes
838a6a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
#!/usr/bin/env python3
# /// script
# dependencies = [
# "datasets>=2.14.0",
# ]
# ///
"""
Validate dataset format for TRL training.
Usage:
python validate_dataset.py <dataset_name> <method>
Examples:
python validate_dataset.py trl-lib/Capybara sft
python validate_dataset.py Anthropic/hh-rlhf dpo
"""
import sys
from datasets import load_dataset
def validate_sft_dataset(dataset) -> bool:
    """Validate that a dataset has a format usable for SFT training.

    The dataset must have a ``messages`` column and/or a ``text`` column.
    Only the first example is inspected, and for ``messages`` only the first
    message of that example is checked.

    Args:
        dataset: Object exposing ``column_names``, ``len()`` and integer
            indexing that returns a mapping from column name to value
            (presumably a ``datasets.Dataset``, since ``main`` passes the
            result of ``load_dataset``).

    Returns:
        True if every check passes, False otherwise. A diagnostic line is
        printed for each finding.
    """
    print("🔍 Validating SFT dataset...")

    # Check for common fields
    columns = dataset.column_names
    print(f"📋 Columns: {columns}")

    has_messages = "messages" in columns
    has_text = "text" in columns
    if not (has_messages or has_text):
        print("❌ Dataset must have 'messages' or 'text' field")
        return False

    # Guard the dataset[0] lookup below: an empty dataset would raise
    # IndexError instead of producing a validation result.
    if len(dataset) == 0:
        print("❌ Dataset is empty")
        return False

    # Check first example
    example = dataset[0]

    if has_messages:
        messages = example["messages"]
        if not isinstance(messages, list):
            print("❌ 'messages' field must be a list")
            return False
        if not messages:
            print("❌ 'messages' field is empty")
            return False

        # Check message format (first message only)
        msg = messages[0]
        if not isinstance(msg, dict):
            print("❌ Messages must be dictionaries")
            return False
        if "role" not in msg or "content" not in msg:
            print("❌ Messages must have 'role' and 'content' keys")
            return False

        print("✅ Messages format valid")
        # str() keeps the preview from crashing when content is None or
        # another non-sliceable value; for str content it is a no-op.
        preview = str(msg["content"])[:50]
        print(f" First message: {msg['role']}: {preview}...")

    if has_text:
        text = example["text"]
        if not isinstance(text, str):
            print("❌ 'text' field must be a string")
            return False
        if not text:
            print("❌ 'text' field is empty")
            return False

        print("✅ Text format valid")
        print(f" First text: {text[:100]}...")

    return True
def validate_dpo_dataset(dataset) -> bool:
    """Validate that a dataset has a format usable for DPO training.

    Requires the columns ``prompt``, ``chosen`` and ``rejected``. In the
    first example, each of them must be a non-empty string or a non-empty
    list. Only the first example is inspected.

    Args:
        dataset: Object exposing ``column_names``, ``len()`` and integer
            indexing that returns a mapping from column name to value
            (presumably a ``datasets.Dataset``, since ``main`` passes the
            result of ``load_dataset``).

    Returns:
        True if every check passes, False otherwise. A diagnostic line is
        printed for each finding.
    """
    print("🔍 Validating DPO dataset...")

    columns = dataset.column_names
    print(f"📋 Columns: {columns}")

    # NOTE(review): TRL may also accept implicit-prompt preference datasets
    # that have no 'prompt' column -- confirm whether 'prompt' should stay
    # mandatory here.
    required = ["prompt", "chosen", "rejected"]
    missing = [col for col in required if col not in columns]
    if missing:
        print(f"❌ Missing required fields: {missing}")
        return False

    # Guard the dataset[0] lookup below: an empty dataset would raise
    # IndexError instead of producing a validation result.
    if len(dataset) == 0:
        print("❌ Dataset is empty")
        return False

    # Check first example
    example = dataset[0]

    for field in required:
        value = example[field]
        if isinstance(value, str):
            kind = "string"
        elif isinstance(value, list):
            kind = "list of messages"
        else:
            print(f"❌ '{field}' must be string or list")
            return False

        # The emptiness rule is the same for strings and lists.
        if not value:
            print(f"❌ '{field}' field is empty")
            return False
        print(f"✅ '{field}' format valid ({kind})")

    return True
def validate_kto_dataset(dataset) -> bool:
    """Validate that a dataset has a format usable for KTO training.

    Requires the columns ``prompt``, ``completion`` and ``label``, and checks
    that ``label`` in the first example is a ``bool``. The contents of
    ``prompt`` and ``completion`` are not inspected.

    Args:
        dataset: Object exposing ``column_names``, ``len()`` and integer
            indexing that returns a mapping from column name to value
            (presumably a ``datasets.Dataset``, since ``main`` passes the
            result of ``load_dataset``).

    Returns:
        True if every check passes, False otherwise. A diagnostic line is
        printed for each finding.
    """
    print("🔍 Validating KTO dataset...")

    columns = dataset.column_names
    print(f"📋 Columns: {columns}")

    required = ["prompt", "completion", "label"]
    missing = [col for col in required if col not in columns]
    if missing:
        print(f"❌ Missing required fields: {missing}")
        return False

    # Guard the dataset[0] lookup below: an empty dataset would raise
    # IndexError instead of producing a validation result.
    if len(dataset) == 0:
        print("❌ Dataset is empty")
        return False

    # Check first example
    example = dataset[0]

    if not isinstance(example["label"], bool):
        print("❌ 'label' field must be boolean")
        return False

    print("✅ KTO format valid")
    return True
def main():
    """Command-line entry point: load a dataset and validate it for a method.

    Reads ``<dataset_name>`` and ``<method>`` from ``sys.argv``, loads the
    ``train`` split with ``load_dataset``, runs the validator registered for
    the method, and prints the verdict.

    Always terminates through ``sys.exit``: status 0 when the dataset is
    valid; status 1 on wrong usage, an unknown method, a load failure, or
    an invalid dataset.
    """
    if len(sys.argv) != 3:
        print("Usage: python validate_dataset.py <dataset_name> <method>")
        print("Methods: sft, dpo, kto")
        sys.exit(1)

    dataset_name = sys.argv[1]
    method = sys.argv[2].lower()

    validators = {
        "sft": validate_sft_dataset,
        "dpo": validate_dpo_dataset,
        "kto": validate_kto_dataset,
    }

    # Reject an unknown method before loading anything, so a typo in the
    # method name does not cost a full dataset load first.
    if method not in validators:
        print(f"❌ Unknown method: {method}")
        print(f"Supported methods: {list(validators.keys())}")
        sys.exit(1)

    print(f"📦 Loading dataset: {dataset_name}")
    try:
        dataset = load_dataset(dataset_name, split="train")
        print(f"✅ Dataset loaded: {len(dataset)} examples")
    except Exception as e:
        # Top-level CLI boundary: report any load failure and exit non-zero.
        print(f"❌ Failed to load dataset: {e}")
        sys.exit(1)

    if validators[method](dataset):
        print(f"\n✅ Dataset is valid for {method.upper()} training")
        sys.exit(0)

    print(f"\n❌ Dataset is NOT valid for {method.upper()} training")
    sys.exit(1)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|