| import torch
|
| from safetensors.torch import load_file
|
|
|
def load_model(path='model.safetensors'):
    """Read the tensors stored in a safetensors file.

    Args:
        path: Location of the file to read. Defaults to
            ``'model.safetensors'``.

    Returns:
        The object produced by ``safetensors.torch.load_file(path)``;
        per the safetensors documentation this is a dict mapping tensor
        names to ``torch.Tensor`` values.
    """
    tensors = load_file(path)
    return tensors
|
|
|
def segment_decode(b3, b2, b1, b0, weights):
    """Decode a BCD digit into 7-segment flags with two layers of threshold units.

    The four input bits feed ten units named ``d0`` .. ``d9``; their 0/1
    outputs are collected into a 10-element vector that feeds seven units
    named ``a`` .. ``g``.

    Args:
        b3, b2, b1, b0: Input bits. Each is converted with ``float()``;
            ``b3`` is placed first in the input vector and ``b0`` last.
        weights: Mapping holding ``'<name>.weight'`` and ``'<name>.bias'``
            tensors for the names ``d0`` .. ``d9`` and ``a`` .. ``g``.
            Shapes are not validated here; presumably every unit yields
            exactly one value, because the result is read with ``.item()``
            (TODO confirm against the model file).

    Returns:
        dict with the keys ``'a'`` .. ``'g'`` (inserted in that order), each
        mapped to the int 0 or 1.
    """

    def unit_output(vec, name):
        # One threshold unit: emits 1 when the affine pre-activation is >= 0.
        pre_activation = vec @ weights[f'{name}.weight'].T + weights[f'{name}.bias']
        return int((pre_activation >= 0).item())

    bcd_bits = torch.tensor([float(bit) for bit in (b3, b2, b1, b0)])

    # First layer: one unit per decimal digit, evaluated in order d0 .. d9.
    digit_flags = [unit_output(bcd_bits, f'd{number}') for number in range(10)]
    digit_vec = torch.tensor([float(flag) for flag in digit_flags])

    # Second layer: one unit per display segment, driven by the digit flags.
    return {seg: unit_output(digit_vec, seg) for seg in 'abcdefg'}
|
|
|
def display_digit(segs):
    """Render a 7-segment pattern as five lines of ASCII art.

    Args:
        segs: Mapping with the keys ``'a'`` .. ``'g'``; a truthy value
            means the segment is drawn, a falsy value leaves it blank.

    Returns:
        A string of five 5-character rows joined by newlines: top bar (a),
        upper posts (f left, b right), middle bar (g), lower posts (e left,
        c right) and bottom bar (d).
    """

    def bar(seg):
        # Horizontal segment: three underscores padded by one space each side.
        fill = '_' if segs[seg] else ' '
        return ' ' + fill * 3 + ' '

    def posts(left, right):
        # Pair of vertical segments separated by three spaces.
        left_char = '|' if segs[left] else ' '
        right_char = '|' if segs[right] else ' '
        return left_char + ' ' * 3 + right_char

    rows = (
        bar('a'),
        posts('f', 'b'),
        bar('g'),
        posts('e', 'c'),
        bar('d'),
    )
    return '\n'.join(rows)
|
|
|
if __name__ == '__main__':
    # Demo: load the weights once, then decode and draw each digit 0-9.
    w = load_model()
    print('7-Segment Display Decoder:')

    for digit in range(10):
        # Split the digit into its four bits, most significant first.
        b3, b2, b1, b0 = ((digit >> shift) & 1 for shift in (3, 2, 1, 0))
        result = segment_decode(b3, b2, b1, b0, w)

        # Segment flags concatenated in a..g order, e.g. '1111110'.
        pattern = ''.join(str(result[s]) for s in 'abcdefg')

        print(f'\nDigit {digit} ({b3}{b2}{b1}{b0}) -> {pattern}')
        print(display_digit(result))
|
|
|