Update README.md
Browse files
README.md
CHANGED
|
@@ -35,16 +35,17 @@ is not validated, it was primarily used a pre-training task.
|
|
| 35 |
|
| 36 |
### How to use
|
| 37 |
|
| 38 |
-
|
| 39 |
|
| 40 |
```python
|
| 41 |
-
|
| 42 |
-
|
|
|
|
| 43 |
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
```
|
| 49 |
|
| 50 |
|
|
@@ -54,9 +55,6 @@ AtomFormer is trained on an aggregated S2EF dataset from multiple sources such a
|
|
| 54 |
with structures and energies/forces for pre-training. The pre-training data includes total energies and formation
|
| 55 |
energies but trains using formation energy (which isn't included for OC22, indicated by "has_formation_energy" column).
|
| 56 |
|
| 57 |
-
## Training procedure
|
| 58 |
-
|
| 59 |
-
|
| 60 |
|
| 61 |
### Preprocessing
|
| 62 |
|
|
|
|
| 35 |
|
| 36 |
### How to use
|
| 37 |
|
| 38 |
+
Here is how to use the model to extract features from the pre-trained backbone:
|
| 39 |
|
| 40 |
```python
|
| 41 |
+
import torch
|
| 42 |
+
from transformers import AutoModel
|
| 43 |
+
model = AutoModel.from_pretrained("vector-institute/atomformer-base", trust_remote_code=True)
|
| 44 |
|
| 45 |
+
input_ids, coords, attn_mask = torch.randint(0, 100, (1, 10)), torch.randn(1, 10, 3), torch.ones(1, 10)
|
| 46 |
|
| 47 |
+
output = model(input_ids, coords=coords, attention_mask=attention_mask)
|
| 48 |
+
output[0].shape # (torch.Size([1, 10, 768])
|
| 49 |
```
|
| 50 |
|
| 51 |
|
|
|
|
| 55 |
with structures and energies/forces for pre-training. The pre-training data includes total energies and formation
|
| 56 |
energies but trains using formation energy (which isn't included for OC22, indicated by "has_formation_energy" column).
|
| 57 |
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
### Preprocessing
|
| 60 |
|