"""Evaluate the symmetric Boolean function S_{1,3}^4 with a threshold-gate network.

Gate weights and biases are read from a safetensors file (``load_model``);
running this module as a script prints the network's truth table.
"""
import torch


def load_model(path='model.safetensors'):
    """Load the network's named tensors from a safetensors file.

    Args:
        path: Path to the ``.safetensors`` file.

    Returns:
        The dict of tensors returned by ``safetensors.torch.load_file``.
    """
    # Imported here rather than at module level so that ``s134`` can be used
    # with weights obtained elsewhere without the safetensors package installed.
    from safetensors.torch import load_file

    return load_file(path)


def _gate(inputs, weights, name):
    """Return 1 if threshold gate ``name`` fires on ``inputs``, else 0.

    The gate fires when ``inputs @ W.T + b >= 0``, where ``W`` is
    ``weights[name + '.weight']`` and ``b`` is ``weights[name + '.bias']``.
    The result must be a single element (``.item()``), i.e. the gate has
    exactly one output.
    """
    w = weights[f'{name}.weight']
    b = weights[f'{name}.bias']
    # Match the weight's dtype and device so weights stored as float64,
    # integers, or on another device do not trigger a mismatch error.
    # For 0/1 inputs this conversion is exact, and it is a no-op for
    # float32 CPU weights.
    x = torch.tensor([float(v) for v in inputs]).to(dtype=w.dtype, device=w.device)
    return int((x @ w.T + b >= 0).item())


def s134(x3, x2, x1, x0, weights):
    """Symmetric function S_{1,3}^4: output 1 iff exactly 1 or 3 of 4 inputs are 1.

    Evaluates a three-layer threshold network: four gates on the raw inputs,
    two gates on their outputs, and a final output gate ``y``.

    Args:
        x3, x2, x1, x0: The four input bits, fed to layer 1 in this order.
        weights: Mapping from ``'<gate>.weight'`` / ``'<gate>.bias'`` keys to
            tensors, e.g. as returned by ``load_model``.

    Returns:
        The output of gate ``y`` as an int (0 or 1).
    """
    # Layer 1: threshold tests on the raw inputs.
    layer1 = [
        _gate((x3, x2, x1, x0), weights, name)
        for name in ('at_least_1', 'at_most_1', 'at_least_3', 'at_most_3')
    ]
    # Layer 2: presumably "exactly 1" / "exactly 3", built from pairs of the
    # layer-1 outputs. This is suggested by the gate names; the actual logic
    # lives in the stored weights.
    layer2 = [_gate(layer1, weights, name) for name in ('exactly_1', 'exactly_3')]
    # Output gate combines the two layer-2 results.
    return _gate(layer2, weights, 'y')


def _main():
    """Print the truth table of the loaded network for all 16 inputs."""
    w = load_model()
    print('Symmetric S_{1,3}^4 (exactly 1 or 3 of 4):')
    for i in range(16):
        x3, x2, x1, x0 = (i >> 3) & 1, (i >> 2) & 1, (i >> 1) & 1, i & 1
        y = s134(x3, x2, x1, x0, w)
        s = x3 + x2 + x1 + x0
        print(f' {x3}{x2}{x1}{x0} (sum={s}) -> {y}')


if __name__ == '__main__':
    _main()