# BaseTrainer

## Trained With [EasyDeL](https://github.com/erfanzar/EasyDeL)

EasyDeL is an open-source framework for training and serving machine learning models. Built primarily on JAX, it provides tooling for training and serving Flax/JAX models on TPUs and GPUs.

## Installation & Usage

```python
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax

# Replace REPO_ID with the repository namespace that hosts this checkpoint.
model = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID/BaseTrainer",
    dtype=...,        # computation dtype, e.g. jnp.bfloat16 (the dtype used in training)
    param_dtype=...,  # parameter dtype, e.g. jnp.bfloat16 (the dtype used in training)
    precision=lax.Precision("fastest"),
    auto_shard_model=True,
)
```

## Training Configuration

### Model Details
- **Architecture**: qwen2
- **Platform**: TPU
- **Number of Devices**: 16

### Training Parameters
- **Learning Rate**: 5e-05 → 5e-06
- **Optimizer**: adamw
- **Scheduler**: cosine
- **Warmup Steps**: 160
- **Weight Decay**: 0.02
- **Loss Config** (`LossConfig`):
  - `ignore_index`: -100
  - `label_smoothing`: 0.0
  - `z_loss`: 0.0
  - `loss_normalizing_factor`: `'NUM_REAL_TARGET_TOKENS'`
  - `num_labels`: None
  - `problem_type`: None
  - `divide_weight_sum`: False
  - `shift_tokens`: True
  - `break_on_nan`: True
  - `reduction`: None
  - `num_classification_labels`: None
  - `classification_problem_type`: None
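
The learning-rate entries above describe a warmup-then-cosine schedule. The plain-Python sketch below illustrates one reading of it, under stated assumptions: a linear warmup from 0 to the peak of 5e-05 over the 160 warmup steps, then cosine decay to 5e-06. The card does not record the total step count (Max Training Steps is None), so `total_steps` is a hypothetical parameter, and EasyDeL's own scheduler may differ in details such as the warmup starting value. In Optax, `optax.warmup_cosine_decay_schedule` paired with `optax.adamw(..., weight_decay=0.02)` expresses a similar shape.

```python
import math


def warmup_cosine_lr(
    step: int,
    total_steps: int,
    warmup_steps: int = 160,
    peak_lr: float = 5e-05,
    end_lr: float = 5e-06,
) -> float:
    """Learning rate at `step` for linear warmup followed by cosine decay.

    Assumes warmup starts at 0.0. `total_steps` is hypothetical and
    includes the warmup phase.
    """
    if step < warmup_steps:
        # Linear warmup: 0 -> peak_lr over warmup_steps
        return peak_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1]
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    # Cosine factor goes from 1.0 (start of decay) to 0.0 (end of decay)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_lr + (peak_lr - end_lr) * cosine


if __name__ == "__main__":
    hypothetical_total = 10_000  # illustrative only; not taken from this run
    for s in (0, 80, 160, 5_000, 10_000):
        print(s, warmup_cosine_lr(s, hypothetical_total))
```

Clamping `progress` keeps the rate at 5e-06 if training runs past the assumed `total_steps`.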
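
The `LossConfig` fields suggest a causal-LM cross-entropy with three properties. Labels are shifted by one position (`shift_tokens: True`). Targets equal to `ignore_index` (-100) are masked out. The summed loss is divided by the number of unmasked targets (`'NUM_REAL_TARGET_TOKENS'`). The NumPy sketch below illustrates that reading. It is not EasyDeL's implementation, the function name is hypothetical, and it omits label smoothing and z-loss because both are 0.0 here.

```python
import numpy as np

IGNORE_INDEX = -100  # ignore_index from the LossConfig above


def masked_causal_lm_loss(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean next-token cross-entropy over non-ignored targets (hypothetical helper).

    logits: float array of shape [batch, seq_len, vocab_size]
    labels: int array of shape [batch, seq_len]
    """
    # shift_tokens=True: the prediction at position t is scored against token t + 1
    logits = logits[:, :-1, :]
    targets = labels[:, 1:]

    # Targets equal to ignore_index contribute nothing to the loss
    mask = targets != IGNORE_INDEX
    safe_targets = np.where(mask, targets, 0)  # any valid index; masked out below

    # Numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    # Negative log-likelihood of each target token
    nll = -np.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]

    # NUM_REAL_TARGET_TOKENS: normalize by the count of unmasked targets
    num_real = max(int(mask.sum()), 1)
    return float((nll * mask).sum() / num_real)


# Uniform logits over an 8-token vocabulary give a loss of log(8)
example_logits = np.zeros((1, 4, 8))
example_labels = np.array([[1, 2, IGNORE_INDEX, 3]])
print(masked_causal_lm_loss(example_logits, example_labels))
```

Dividing by the count of real targets, rather than by `batch * seq_len`, keeps padding-heavy batches from diluting the loss.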

### Training Setup
- **Epochs**: 5
- **Batch Size**: 16
- **Sequence Length**: 4096
- **Dtype**: `jax.numpy.bfloat16`
- **Params Dtype**: `jax.numpy.bfloat16`
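
With gradient accumulation at 1 (listed under Advanced Configuration), the setup above determines how many token positions each optimizer step processes. The sketch assumes that Batch Size is the global batch across all 16 devices and that every sequence is padded or packed to the full 4096 tokens. If the batch size is per device instead, multiply by the device count. The card does not record the dataset size, so the steps-per-epoch figure uses a hypothetical `num_examples`.

```python
def tokens_per_step(batch_size: int, seq_len: int, grad_accum_steps: int = 1) -> int:
    """Token positions processed per optimizer step (padding included)."""
    return batch_size * seq_len * grad_accum_steps


def steps_per_epoch(num_examples: int, batch_size: int, grad_accum_steps: int = 1) -> int:
    """Optimizer steps per epoch, assuming a final partial batch is dropped."""
    return num_examples // (batch_size * grad_accum_steps)


# Values from this card: batch size 16, sequence length 4096, accumulation 1
print(tokens_per_step(16, 4096, 1))  # 65536

# num_examples is hypothetical; the dataset size is not recorded in this card
num_examples = 100_000
print(steps_per_epoch(num_examples, 16) * 5)  # 31250 optimizer steps over 5 epochs
```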

### Advanced Configuration
- **Gradient Checkpointing**:
- **Gradient Accumulation Steps**: 1
- **Max Training Steps**: None
- **Max Evaluation Steps**: None
- **Training Duration**: 7H

### Sharding Configuration
```python
from jax.sharding import PartitionSpec

# Partition rules: (parameter-path regex, PartitionSpec) pairs
partition_rules = (
    ("model/embed_tokens/embedding", PartitionSpec("tp", ("fsdp", "sp"))),
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("self_attn/o_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
    ("mlp/gate_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("mlp/down_proj/kernel", PartitionSpec("tp", ("fsdp", "sp"))),
    ("mlp/up_proj/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    ("input_layernorm/kernel", PartitionSpec(None,)),
    ("post_attention_layernorm/kernel", PartitionSpec(None,)),
    ("model/norm/kernel", PartitionSpec(None,)),
    ("lm_head/kernel", PartitionSpec(("fsdp", "sp"), "tp")),
    (".*", PartitionSpec(None,)),  # catch-all: replicate every other parameter
)
```
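
The rules are regular expressions matched against parameter paths. The trailing `'.*'` catch-all suggests that the first matching rule wins, and a nested entry such as `("fsdp", "sp")` shards a single dimension jointly across both mesh axes. The sketch below illustrates that assumed first-match lookup with `re.search`. The parameter paths (such as `model/layers/0/self_attn/q_proj/kernel`) are hypothetical examples of the naming the rules target. The specs are plain tuples standing in for `PartitionSpec`, so the snippet runs without JAX. EasyDeL's own matching code may differ.

```python
import re

# (regex, spec) pairs copied from the rules above; each tuple stands in for a PartitionSpec
RULES = (
    ("model/embed_tokens/embedding", ("tp", ("fsdp", "sp"))),
    ("self_attn/(q_proj|k_proj|v_proj)/kernel", (("fsdp", "sp"), "tp")),
    ("self_attn/o_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("mlp/gate_proj/kernel", (("fsdp", "sp"), "tp")),
    ("mlp/down_proj/kernel", ("tp", ("fsdp", "sp"))),
    ("mlp/up_proj/kernel", (("fsdp", "sp"), "tp")),
    ("input_layernorm/kernel", (None,)),
    ("post_attention_layernorm/kernel", (None,)),
    ("model/norm/kernel", (None,)),
    ("lm_head/kernel", (("fsdp", "sp"), "tp")),
    (".*", (None,)),
)


def spec_for(param_path: str, rules=RULES):
    """Return the spec of the first rule whose regex occurs in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")


# Hypothetical parameter paths for a Qwen2-style model
for path in (
    "model/layers/0/self_attn/q_proj/kernel",  # dim 0 over fsdp+sp, dim 1 over tp
    "model/layers/0/self_attn/q_proj/bias",    # no specific rule: falls through to '.*'
    "model/layers/0/mlp/down_proj/kernel",
    "lm_head/kernel",
):
    print(path, "->", spec_for(path))
```

Because the lookup stops at the first match, the order of the rules matters. The broad `'.*'` entry must stay last, or it would shadow every specific rule.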

---
*Generated with EasyDeL v0.1.2*