|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Validate dataset format for TRL training. |
|
|
|
|
|
Usage: |
|
|
python validate_dataset.py <dataset_name> <method> |
|
|
|
|
|
Examples: |
|
|
python validate_dataset.py trl-lib/Capybara sft |
|
|
python validate_dataset.py Anthropic/hh-rlhf dpo |
|
|
""" |
|
|
|
|
|
import sys |
|
|
from datasets import load_dataset |
|
|
|
|
|
def validate_sft_dataset(dataset):
    """Validate that a dataset is usable for SFT training.

    The dataset must expose a ``messages`` column (conversational format)
    and/or a ``text`` column (plain-text format). Only the first example is
    inspected, so this is a quick sanity check rather than a full scan.

    Args:
        dataset: Object providing ``column_names``, ``len()`` and integer
            indexing that returns a mapping from column name to value.

    Returns:
        True when the first example passes every check, False otherwise.
        A diagnostic line is printed for each outcome.
    """
    print("🔍 Validating SFT dataset...")

    # Check for the columns SFT can consume.
    columns = dataset.column_names
    print(f"📋 Columns: {columns}")

    has_messages = "messages" in columns
    has_text = "text" in columns

    if not (has_messages or has_text):
        print("❌ Dataset must have 'messages' or 'text' field")
        return False

    # Guard before indexing: there is no first example in an empty dataset.
    if len(dataset) == 0:
        print("❌ Dataset is empty")
        return False

    # Inspect the first example only.
    example = dataset[0]

    if has_messages:
        messages = example["messages"]
        if not isinstance(messages, list):
            print("❌ 'messages' field must be a list")
            return False

        if not messages:
            print("❌ 'messages' field is empty")
            return False

        # Only the first message of the conversation is checked.
        msg = messages[0]
        if not isinstance(msg, dict):
            print("❌ Messages must be dictionaries")
            return False

        if "role" not in msg or "content" not in msg:
            print("❌ Messages must have 'role' and 'content' keys")
            return False

        print("✅ Messages format valid")
        # Content may be None or a non-string value; go through str() so
        # building the preview can never crash an otherwise valid check.
        preview = str(msg["content"])[:50]
        print(f" First message: {msg['role']}: {preview}...")

    if has_text:
        text = example["text"]
        if not isinstance(text, str):
            print("❌ 'text' field must be a string")
            return False

        if not text:
            print("❌ 'text' field is empty")
            return False

        print("✅ Text format valid")
        print(f" First text: {text[:100]}...")

    return True
|
|
|
|
|
def validate_dpo_dataset(dataset):
    """Validate that a dataset is usable for DPO training.

    ``chosen`` and ``rejected`` columns are mandatory. ``prompt`` is checked
    when present; when it is absent the dataset is treated as implicit-prompt
    preference data (the prompt is embedded in ``chosen``/``rejected``) and a
    warning is printed instead of failing. Only the first example is
    inspected.

    Args:
        dataset: Object providing ``column_names``, ``len()`` and integer
            indexing that returns a mapping from column name to value.

    Returns:
        True when every checked field of the first example is a non-empty
        string or a non-empty list, False otherwise.
    """
    print("🔍 Validating DPO dataset...")

    columns = dataset.column_names
    print(f"📋 Columns: {columns}")

    required = ["chosen", "rejected"]
    missing = [col for col in required if col not in columns]

    if missing:
        print(f"❌ Missing required fields: {missing}")
        return False

    # An explicit prompt is optional; validate it only when it exists.
    fields = list(required)
    if "prompt" in columns:
        fields.insert(0, "prompt")
    else:
        print("⚠️ No 'prompt' field; assuming implicit-prompt preference format")

    # Guard before indexing: there is no first example in an empty dataset.
    if len(dataset) == 0:
        print("❌ Dataset is empty")
        return False

    # Inspect the first example only.
    example = dataset[0]

    for field in fields:
        value = example[field]

        if isinstance(value, str):
            kind = "string"
        elif isinstance(value, list):
            kind = "list of messages"
        else:
            print(f"❌ '{field}' must be string or list")
            return False

        if not value:
            print(f"❌ '{field}' field is empty")
            return False

        print(f"✅ '{field}' format valid ({kind})")

    return True
|
|
|
|
|
def validate_kto_dataset(dataset):
    """Validate that a dataset is usable for KTO training.

    Requires ``prompt``, ``completion`` and ``label`` columns, and checks
    that ``label`` in the first example is a boolean. ``prompt`` and
    ``completion`` are only checked for presence, not for content.

    Args:
        dataset: Object providing ``column_names``, ``len()`` and integer
            indexing that returns a mapping from column name to value.

    Returns:
        True when the columns exist and the first label is a ``bool``,
        False otherwise.
    """
    print("🔍 Validating KTO dataset...")

    columns = dataset.column_names
    print(f"📋 Columns: {columns}")

    required = ["prompt", "completion", "label"]
    missing = [col for col in required if col not in columns]

    if missing:
        print(f"❌ Missing required fields: {missing}")
        return False

    # Guard before indexing: there is no first example in an empty dataset.
    if len(dataset) == 0:
        print("❌ Dataset is empty")
        return False

    # Inspect the first example only.
    example = dataset[0]

    if not isinstance(example["label"], bool):
        print("❌ 'label' field must be boolean")
        return False

    print("✅ KTO format valid")
    return True
|
|
|
|
|
def main():
    """Command-line entry point: load a dataset and validate it for a method.

    Expects exactly two command-line arguments, ``<dataset_name>`` and
    ``<method>`` (``sft``, ``dpo`` or ``kto``; case-insensitive). Loads the
    ``train`` split of the dataset and runs the matching validator.

    Always terminates the process via ``sys.exit``: status 0 when the dataset
    is valid, status 1 on bad usage, unknown method, load failure, or an
    invalid dataset.
    """
    if len(sys.argv) != 3:
        print("Usage: python validate_dataset.py <dataset_name> <method>")
        print("Methods: sft, dpo, kto")
        sys.exit(1)

    dataset_name = sys.argv[1]
    method = sys.argv[2].lower()

    validators = {
        "sft": validate_sft_dataset,
        "dpo": validate_dpo_dataset,
        "kto": validate_kto_dataset,
    }

    # Reject an unknown method before loading, so a typo does not cost a
    # potentially slow dataset load.
    if method not in validators:
        print(f"❌ Unknown method: {method}")
        print(f"Supported methods: {list(validators)}")
        sys.exit(1)

    print(f"📦 Loading dataset: {dataset_name}")
    try:
        dataset = load_dataset(dataset_name, split="train")
    except Exception as e:
        # Top-level boundary: report any loader failure and exit non-zero.
        print(f"❌ Failed to load dataset: {e}")
        sys.exit(1)
    print(f"✅ Dataset loaded: {len(dataset)} examples")

    if validators[method](dataset):
        print(f"\n✅ Dataset is valid for {method.upper()} training")
        sys.exit(0)

    print(f"\n❌ Dataset is NOT valid for {method.upper()} training")
    sys.exit(1)
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|