Spaces:
Sleeping
Sleeping
| import random | |
| import pytest | |
| import torch | |
| from ding.torch_utils.metric import levenshtein_distance, hamming_distance | |
class TestMetric():
    """
    Overview:
        Tests for ``levenshtein_distance`` and ``hamming_distance``.
    """

    def test_levenshtein_distance(self):
        """
        Overview:
            Test the Levenshtein Distance
        """
        pred = torch.LongTensor([1, 4, 6, 4, 1])

        # Plain distance: each entry is (target sequence, expected distance).
        # The empty target checks that the distance equals len(pred).
        plain_cases = [
            (torch.LongTensor([1, 6, 4, 4, 1]), 2),
            (torch.LongTensor([]), 5),
            (torch.LongTensor([6, 4, 1]), 2),
        ]
        for target, expected in plain_cases:
            result = levenshtein_distance(pred, target)
            assert result.item() == expected

        # Distance with an extra cost term: the sequences themselves are also
        # passed as the third/fourth positional arguments and combined by
        # ``extra_fn``. The expected totals are pinned values for the
        # implementation under test.
        def add(x, y):
            return x + y

        extra_cases = [
            (torch.LongTensor([6, 4, 1]), 13),
            (torch.LongTensor([1, 4, 1]), 14),
        ]
        for target, expected in extra_cases:
            result = levenshtein_distance(pred, target, pred, target, extra_fn=add)
            assert result.item() == expected

    def test_hamming_distance(self):
        """
        Overview:
            Test the Hamming Distance
        """
        positions = list(range(8))
        zeros = torch.zeros(8).long()
        for _ in range(2):
            # Draw four distinct "on" positions for each binary vector.
            pred_idx = random.sample(positions, 4)
            target_idx = random.sample(positions, 4)

            pred = zeros.clone()
            pred[pred_idx] = 1
            target = zeros.clone()
            target[target_idx] = 1

            # A leading dimension of size 1 is added before calling the metric.
            distance = hamming_distance(pred.unsqueeze(0), target.unsqueeze(0))

            # Reference value: number of positions set in exactly one of the
            # two vectors, i.e. the symmetric difference of the index sets.
            expected = len(set(pred_idx) ^ set(target_idx))
            assert distance.item() == expected