diff --git a/hyper_parallel/auto_parallel/fast-tuner/README.md b/hyper_parallel/auto_parallel/fast-tuner/README.md
index 32a1e9d4e562deb1287fe967b88eae0b95ba368b..0ab18663df7ef3f5138aebdc22836e6f6db5210d 100644
--- a/hyper_parallel/auto_parallel/fast-tuner/README.md
+++ b/hyper_parallel/auto_parallel/fast-tuner/README.md
@@ -64,4 +64,4 @@ pip install -e .
[基于Mindformers](./docs/mindformers.md) 使用文档\
[基于Torchtitan](./docs/torchtitan.md) 使用文档 \
-[基于Mindspeed](./docs/mindspeed.md) 使用文档
\ No newline at end of file
+[基于Mindspeed](./docs/mindspeed.md) 使用文档
diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/env_config/env_config.json b/hyper_parallel/auto_parallel/fast-tuner/config/env_config/env_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..06bcdb67382ec0cc8a49b6e1a267aec30d53d6e2
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/config/env_config/env_config.json
@@ -0,0 +1,10 @@
+{
+ "MS_SIMULATION_LEVEL": "1",
+ "GLOG_v": "3",
+ "MS_DEV_RUNTIME_CONF": "memory_statistics:True",
+ "MS_MEMORY_STATISTIC": "1",
+ "ENABLE_LAZY_INLINE_NO_PIPELINE": "1",
+ "MS_DEV_DUMP_IR_PASSES": "validate,graph_build,pipeline_split,recompute",
+ "MS_ALLOC_CONF": "memory_tracker:True",
+ "ENABLE_LESS_MEM_VPP": "0"
+}
\ No newline at end of file
diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/example/mcore/pretrain_deepseek3_671b_4k_A3_ptd.sh b/hyper_parallel/auto_parallel/fast-tuner/config/example/mcore/pretrain_deepseek3_671b_4k_A3_ptd.sh
new file mode 100644
index 0000000000000000000000000000000000000000..738be6ed879a38a6023900c6e98647d7309c92ff
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/config/example/mcore/pretrain_deepseek3_671b_4k_A3_ptd.sh
@@ -0,0 +1,194 @@
+#!/bin/bash
+
+# 需要切换MindSpeed版本
+# git checkout 9648d729e4866f8037d0bd76630029410a60e6a6 # checkout commit from MindSpeed core_r0.8.0 in 2025.06.14
+
+export CUDA_DEVICE_MAX_CONNECTIONS=1
+export CPU_AFFINITY_CONF=1
+export TASK_QUEUE_ENABLE=2
+export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True"
+export HCCL_CONNECT_TIMEOUT=3600
+export STREAMS_PER_DEVICE=32
+
+NPUS_PER_NODE=16
+MASTER_ADDR=localhost #主节点IP
+MASTER_PORT=6000
+NNODES=32
+NODE_RANK=0
+
+CKPT_SAVE_DIR="your model save ckpt path"
+DATA_PATH="your data path"
+TOKENIZER_PATH="your tokenizer path"
+CKPT_LOAD_DIR="your model ckpt path"
+
+TP=2
+PP=8
+EP=16
+CP=1
+CP_TYPE='ulysses_cp_algo'
+NUM_LAYERS=64
+SEQ_LEN=4096
+MBS=1
+GBS=3840
+
+DISTRIBUTED_ARGS="
+ --nproc_per_node $NPUS_PER_NODE \
+ --nnodes $NNODES \
+ --node_rank $NODE_RANK \
+ --master_addr $MASTER_ADDR \
+ --master_port $MASTER_PORT
+"
+
+MLA_ARGS="
+ --multi-head-latent-attention \
+ --qk-rope-head-dim 64 \
+ --qk-nope-head-dim 128 \
+ --q-lora-rank 1536 \
+ --kv-lora-rank 512 \
+ --v-head-dim 128 \
+ --qk-layernorm \
+ --mla-mm-split \
+ --mla-fa-without-pad \
+"
+
+MOE_ARGS="
+ --moe-grouped-gemm \
+ --moe-permutation-async-comm \
+ --use-fused-moe-token-permute-and-unpermute \
+ --moe-token-dispatcher-type alltoall \
+ --first-k-dense-replace 3 \
+ --moe-layer-freq 1 \
+ --n-shared-experts 1 \
+ --num-experts 256 \
+ --moe-router-topk 8 \
+ --moe-intermediate-size 2048 \
+ --moe-router-load-balancing-type noaux_tc \
+ --n-group 8 \
+ --topk-group 4 \
+ --routed-scaling-factor 2.5 \
+ --moe-aux-loss-coeff 0.0001 \
+ --seq-aux \
+ --norm-topk-prob \
+ --moe-router-score-function sigmoid \
+ --moe-router-enable-expert-bias \
+ --moe-tp-extend-ep \
+ --router-gating-in-fp32 \
+"
+
+MTP_ARGS="
+ --mtp-num-layers 1 \
+ --mtp-loss-scaling-factor 0.3 \
+"
+
+DUALPIPE_ARGS="
+ --moe-fb-overlap \
+ --schedules-method dualpipev \
+"
+
+MEM_ARGS="
+ --mtp-mem-efficient-logits \
+ --swap-optimizer \
+ --recompute-granularity full \
+ --recompute-method block \
+ --recompute-num-layers 8 \
+"
+
+ROPE_ARGS="
+ --rope-scaling-beta-fast 32 \
+ --rope-scaling-beta-slow 1 \
+ --rope-scaling-factor 40 \
+ --rope-scaling-mscale 1.0 \
+ --rope-scaling-mscale-all-dim 1.0 \
+ --rope-scaling-original-max-position-embeddings 4096 \
+ --rope-scaling-type yarn
+"
+
+GPT_ARGS="
+ --spec mindspeed_llm.tasks.models.spec.deepseek_spec layer_spec \
+ --reset-position-ids \
+ --gemm-gradient-accumulation-fusion \
+ --noop-layers 61,62,63 \
+ --manual-gc \
+ --manual-gc-interval 50 \
+ --no-shared-storage \
+ --use-distributed-optimizer \
+ --use-flash-attn \
+ --use-mcore-models \
+ --tensor-model-parallel-size ${TP} \
+ --pipeline-model-parallel-size ${PP} \
+ --expert-model-parallel-size ${EP} \
+ --sequence-parallel \
+ --context-parallel-size ${CP} \
+ --context-parallel-algo ${CP_TYPE} \
+ --num-layers ${NUM_LAYERS} \
+ --hidden-size 7168 \
+ --ffn-hidden-size 18432 \
+ --num-attention-heads 128 \
+ --tokenizer-type PretrainedFromHF \
+ --tokenizer-name-or-path ${TOKENIZER_PATH} \
+ --seq-length ${SEQ_LEN} \
+ --max-position-embeddings 163840 \
+ --micro-batch-size ${MBS} \
+ --global-batch-size ${GBS} \
+ --make-vocab-size-divisible-by 1 \
+ --lr 1.0e-5 \
+ --train-iters 2000 \
+ --lr-decay-style cosine \
+ --untie-embeddings-and-output-weights \
+ --disable-bias-linear \
+ --attention-dropout 0.0 \
+ --init-method-std 0.02 \
+ --hidden-dropout 0.0 \
+ --position-embedding-type rope \
+ --normalization RMSNorm \
+ --use-fused-rotary-pos-emb \
+ --use-rotary-position-embeddings \
+ --use-fused-swiglu \
+ --use-fused-rmsnorm \
+ --swiglu \
+ --no-masked-softmax-fusion \
+ --attention-softmax-in-fp32 \
+ --min-lr 1.0e-7 \
+ --weight-decay 1e-2 \
+ --lr-warmup-iters 500 \
+ --clip-grad 1.0 \
+ --adam-beta1 0.9 \
+ --adam-beta2 0.999 \
+ --initial-loss-scale 65536 \
+ --vocab-size 129280 \
+ --padded-vocab-size 129280 \
+ --rotary-base 10000 \
+ --norm-epsilon 1e-6 \
+ --no-load-optim \
+ --no-load-rng \
+ --bf16 \
+ --distributed-timeout-minutes 120 \
+"
+
+DATA_ARGS="
+ --data-path $DATA_PATH \
+ --split 100,0,0
+"
+
+OUTPUT_ARGS="
+ --log-interval 1 \
+ --save-interval 2000 \
+ --eval-interval 2000 \
+ --eval-iters 0 \
+ --no-save-optim \
+ --no-save-rng
+"
+
+python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \
+ $GPT_ARGS \
+ $DATA_ARGS \
+ $OUTPUT_ARGS \
+ $MLA_ARGS \
+ $DUALPIPE_ARGS \
+ $MEM_ARGS \
+ $ROPE_ARGS \
+ $MOE_ARGS \
+ $MTP_ARGS \
+ --save $CKPT_SAVE_DIR \
+ --load $CKPT_LOAD_DIR \
+ --distributed-backend nccl | tee logs/pretrain_deepseek3_671b_4k_A3_ptd.log
\ No newline at end of file
diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/example/mindspore/pretrain_deepseek3_671b.yaml b/hyper_parallel/auto_parallel/fast-tuner/config/example/mindspore/pretrain_deepseek3_671b.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1604b865c609fe518f7f066f3b87458c5109fc81
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/config/example/mindspore/pretrain_deepseek3_671b.yaml
@@ -0,0 +1,225 @@
+seed: 0
+output_dir: './output' # path to save checkpoint/strategy
+load_checkpoint: ''
+src_strategy_path_or_dir: ''
+auto_trans_ckpt: False # If true, auto transform load_checkpoint to load in distributed pipeline
+only_save_strategy: False
+resume_training: False
+use_parallel: True
+run_mode: 'train'
+
+# trainer config
+trainer:
+ type: CausalLanguageModelingTrainer
+ model_name: 'deepseekV3'
+
+# runner config
+runner_config:
+ epochs: 2
+ batch_size: 1
+ sink_mode: True
+ sink_size: 1
+
+# optimizer
+optimizer:
+ type: AdamW
+ betas: [0.9, 0.95]
+ eps: 1.e-8
+
+# lr schedule
+lr_schedule:
+ type: ConstantWarmUpLR
+ learning_rate: 2.2e-4
+ warmup_ratio: 0.02
+ total_steps: -1 # -1 means it will load the total steps of the dataset
+
+# dataset
+train_dataset: &train_dataset
+ data_loader:
+ type: BlendedMegatronDatasetDataLoader
+ datasets_type: "GPTDataset"
+ sizes:
+ - 5000 # train dataset size
+ - 0
+ - 0
+ config:
+ random_seed: 1234
+ seq_length: 4096
+ split: "1, 0, 0"
+ reset_position_ids: False
+ reset_attention_mask: False
+ eod_mask_loss: False
+ num_dataset_builder_threads: 1
+ create_attention_mask: False
+ data_path:
+ - '1'
+ - "./dataset"
+ shuffle: False
+ input_columns: ["input_ids", "labels", "loss_mask", "position_ids"]
+ construct_args_key: ["input_ids", "labels"]
+ num_parallel_workers: 8
+ python_multiprocessing: False
+ drop_remainder: True
+ repeat: 1
+ numa_enable: False
+ prefetch_size: 1
+train_dataset_task:
+ type: CausalLanguageModelDataset
+ dataset_config: *train_dataset
+
+# mindspore context init config
+context:
+ mode: 0 #0--Graph Mode; 1--Pynative Mode
+ device_target: "Ascend"
+ max_call_depth: 10000
+ max_device_memory: "55GB"
+ save_graphs: False
+ save_graphs_path: "./graph"
+ jit_config:
+ jit_level: "O1"
+ ascend_config:
+ parallel_speed_up_json_path: "./parallel_speed_up.json"
+
+# parallel config for device num = 1024
+parallel_config:
+ data_parallel: &dp 16
+ model_parallel: 4
+ pipeline_stage: 16
+ expert_parallel: 8
+  micro_batch_num: &micro_batch_num 32
+ vocab_emb_dp: True
+ use_seq_parallel: True
+ gradient_aggregation_group: 4
+# when pipeline parallel is greater than 1, we can set micro_batch_interleave_num=2, that may accelerate the train process.
+micro_batch_interleave_num: 1
+
+# parallel context config
+parallel:
+ parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel
+ gradients_mean: False
+ enable_alltoall: True
+ full_batch: False
+ dataset_strategy: [[*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1]]
+ search_mode: "sharding_propagation"
+ enable_parallel_optimizer: True
+ strategy_ckpt_config:
+ save_file: "./ckpt_strategy.ckpt"
+ only_trainable_params: False
+ parallel_optimizer_config:
+ gradient_accumulation_shard: False
+ parallel_optimizer_threshold: 64
+
+# recompute config
+recompute_config:
+ recompute: [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 2, 0]
+ select_recompute: False
+ parallel_optimizer_comm_recompute: True
+ mp_comm_recompute: True
+ recompute_slice_activation: True
+
+# pipeline config
+model:
+ model_config:
+ type: DeepseekV3Config
+ auto_register: deepseek3_config.DeepseekV3Config
+ batch_size: 1 # add for increase predict
+ seq_length: 4096
+ hidden_size: 7168
+ num_layers: &num_layers 61
+ num_heads: 128
+ max_position_embeddings: 4096
+ intermediate_size: 18432
+ kv_lora_rank: 512
+ n_kv_heads: 128
+ q_lora_rank: 1536
+ qk_rope_head_dim: 64
+ v_head_dim: 128
+ qk_nope_head_dim: 128
+ vocab_size: 129280
+ multiple_of: 256
+ rms_norm_eps: 1.0e-6
+ bos_token_id: 100000
+ eos_token_id: 100001
+ pad_token_id: 100001
+ ignore_token_id: -100
+ compute_dtype: "bfloat16"
+ layernorm_compute_type: "float32"
+ softmax_compute_type: "float32"
+ rotary_dtype: "float32"
+ router_dense_type: "float32"
+ param_init_type: "float32"
+ use_past: False
+ extend_method: "None"
+ use_flash_attention: True
+ input_sliced_sig: True
+ offset: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1]
+ checkpoint_name_or_path: ""
+ theta: 10000.0
+ return_extra_loss: True
+ mtp_depth: &mtp_depth 1
+ mtp_loss_factor: 0.3
+ arch:
+ type: DeepseekV3ForCausalLM
+ auto_register: deepseek3.DeepseekV3ForCausalLM
+
+#moe
+moe_config:
+ expert_num: &expert_num 256
+ expert_group_size: 8
+ capacity_factor: 1.5
+ aux_loss_factor: 0.05
+ num_experts_chosen: 8
+ routing_policy: "TopkRouterV2"
+ enable_sdrop: False
+ balance_via_topk_bias: &balance_via_topk_bias True
+ topk_bias_update_rate: &topk_bias_update_rate 0.001
+ use_fused_ops_topkrouter: True
+ group_wise_a2a: False
+ shared_expert_num: 1
+ routed_scaling_factor: 2.5
+ norm_topk_prob: False
+ first_k_dense_replace: 3
+ moe_intermediate_size: 2048
+ # greedy_group_limited strategy, select topk_group from n_group
+ topk_group: 4
+ n_group: 8
+ aux_loss_factors: [0.0001, 0., 0.]
+ aux_loss_types: ["expert", "device", "comm"]
+ z_loss_factor: 0.0
+ expert_model_parallel: 1
+ use_gating_sigmoid: True
+ callback_moe_droprate: False
+
+
+# callbacks
+callbacks:
+ - type: MFLossMonitor
+ per_print_times: 1
+ # balance topk bias with callback
+ - type: TopkBiasBalanceCallback
+ balance_via_topk_bias: *balance_via_topk_bias
+ topk_bias_update_rate: *topk_bias_update_rate
+ num_layers: *num_layers
+ mtp_depth: *mtp_depth
+ expert_num: *expert_num
+ micro_batch_num: *micro_batch_num
+ - type: CheckpointMonitor
+ prefix: "deepseekv3"
+ save_checkpoint_steps: 1000
+ keep_checkpoint_max: 5
+ integrated_save: False
+ async_save: False
+ checkpoint_format: "safetensors"
+
+# wrapper cell config
+runner_wrapper:
+ type: MFTrainOneStepCell
+ scale_sense: 1.0
+ use_clip_grad: True
+
+profile: False
+profile_start_step: 1
+profile_stop_step: 10
+init_start_profile: False
+profile_communication: False
+profile_memory: True
diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/example/torchtitan/debug_model.toml b/hyper_parallel/auto_parallel/fast-tuner/config/example/torchtitan/debug_model.toml
new file mode 100644
index 0000000000000000000000000000000000000000..1951cc43503574138e9208442c626c01007834bc
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/config/example/torchtitan/debug_model.toml
@@ -0,0 +1,79 @@
+[job]
+dump_folder = "./outputs"
+description = "DeepSeek-V3 debug training"
+print_config = false
+
+[profiling]
+enable_profiling = false
+save_traces_folder = "profile_trace"
+profile_freq = 10
+enable_memory_snapshot = false
+save_memory_snapshot_folder = "memory_snapshot"
+
+[metrics]
+log_freq = 1
+disable_color_printing = false
+enable_tensorboard = false
+save_tb_folder = "tb"
+enable_wandb = false
+
+[model]
+name = "deepseek_v3"
+flavor = "debugmodel"
+# test tokenizer, for debug purpose only
+hf_assets_path = "./tests/assets/tokenizer"
+# converters = ["float8"]
+
+[optimizer]
+name = "AdamW"
+lr = 8e-4
+eps = 1e-8
+
+[lr_scheduler]
+warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
+decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
+decay_type = "linear"
+min_lr_factor = 0.0
+
+[training]
+local_batch_size = 8
+seq_len = 2048
+max_norm = 1.0 # grad norm clipping
+steps = 10
+dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
+fsdp_reshard_after_forward = "default" # default / never / always
+tensor_parallel_degree = 1
+enable_async_tensor_parallel = false
+pipeline_parallel_degree = 1
+pipeline_parallel_schedule = "1F1B"
+context_parallel_degree = 1
+expert_parallel_degree = 1
+expert_tensor_parallel_degree = 1
+
+[checkpoint]
+enable = false
+folder = "checkpoint"
+interval = 10
+last_save_model_only = false
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = "selective" # ["none", "selective", "full"]
+selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
+
+[compile]
+enable = false
+components = ["model", "loss"]
+
+[quantize.linear.float8]
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false
+filter_fqns = ["output", "router.gate"]
+
+[quantize.grouped_mm.float8]
+fqns = ["experts"]
diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_dryrun.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_dryrun.json
new file mode 100644
index 0000000000000000000000000000000000000000..1319893154730045d294d7802251ab9e1a3b05bd
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_dryrun.json
@@ -0,0 +1,12 @@
+{
+ "train_yaml": "",
+ "mindformers_dir": "",
+ "output_path": "dryrun_output",
+ "offset": [0,1,1,1,1,1,1,0],
+ "is_recompute": true,
+ "recompute_layers": [1,1,1,1,1,1,1,1],
+ "is_select_recompute": false,
+ "select_recompute_layers": null,
+ "env_json": "./config/boss_env_config.json",
+ "dryrun_lim": 16
+}
\ No newline at end of file
diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool.json
new file mode 100644
index 0000000000000000000000000000000000000000..6c866e16580787afbe3e5c748fc56dceb4f83378
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool.json
@@ -0,0 +1,9 @@
+{
+ "yaml_path" : "./config/pretrain_deepseek3_671b.yaml",
+ "mindformers_dir" : "./config/pretrain_gpt.py",
+ "rank_num" : 64,
+ "dataset" : "./config/dataset/wiki103-4k.mindrecord",
+ "output_path": "./output/nd_output/",
+ "env_json" : "./config/env_config.json",
+ "gbs": 1024
+}
\ No newline at end of file
diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_mcore.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_mcore.json
new file mode 100644
index 0000000000000000000000000000000000000000..2f0e19131809d03eda6f6d216102b5a1c9af9036
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_mcore.json
@@ -0,0 +1,10 @@
+{
+ "shell_path" : "./config/example/mcore/pretrain_deepseek3_671b_4k_A3_ptd.sh",
+ "mindspeed_path": "./config/pretrain_gpt.py",
+ "rank_num" : 512,
+ "dataset" : "./config/dataset/wiki103-4k.mindrecord",
+ "max_expert_parallel": 64,
+ "output_path": "./output/nd_output/",
+ "env_json" : "./config/env_config.json",
+ "gbs": 3840
+}
\ No newline at end of file
diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_titan.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_titan.json
new file mode 100644
index 0000000000000000000000000000000000000000..21d2c2aaf8e3041dbe7b40b939d63f9de4ee4682
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_titan.json
@@ -0,0 +1,16 @@
+{
+ "toml_path" : "./config/example/torchtitan/debug_model.toml",
+ "torchtitan_path": "./torchtitan/run_train.sh",
+ "rank_num" : 8,
+ "output_path": "./output/",
+ "npus_per_node": 8,
+ "nnodes": 1,
+ "strategy": {
+ "DP": true,
+ "TP": true,
+ "EP": true,
+ "FSDP": true,
+ "PP": false,
+ "CP": false
+ }
+}
\ No newline at end of file
diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_parallel.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_parallel.json
new file mode 100644
index 0000000000000000000000000000000000000000..bc0862ac2805a7d6a4c4eda6b9f98aa2f7eeb58d
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_parallel.json
@@ -0,0 +1,16 @@
+{
+ "llm_class": "0",
+ "train_yaml": "",
+ "mindformers_dir": "",
+ "layer_ratio": 0.33,
+ "backward_ratio": 2.0,
+ "head_loss": 1.5,
+ "recompute_ratio": 1,
+ "time_limit": 9999999999,
+ "dryrun": true,
+ "check": true,
+ "fit_level": 0,
+ "extract": false,
+ "env_json": "./config/boss_env_config.json",
+ "dryrun_lim": 16
+}
\ No newline at end of file
diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_tool.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_tool.json
new file mode 100644
index 0000000000000000000000000000000000000000..e4a794ebd311b8747b984403fabfb8eed610e2eb
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_tool.json
@@ -0,0 +1,10 @@
+{
+ "yaml_dir": "./output/dryrun_yaml/",
+ "mindformers_dir": "./mindformers/run_mindformer.py",
+ "profile_data_dir": "./profile_data/",
+ "parser_result": null,
+ "env_json": "./config/boss_env_config.json",
+ "dryrun_lim": 16,
+ "topn": 0,
+ "max_vpp": 3
+}
\ No newline at end of file
diff --git a/hyper_parallel/auto_parallel/fast-tuner/docs/mindformers.md b/hyper_parallel/auto_parallel/fast-tuner/docs/mindformers.md
new file mode 100644
index 0000000000000000000000000000000000000000..fefc77209c5ddb05d24f5f03b5639e4029e26983
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/docs/mindformers.md
@@ -0,0 +1,115 @@
+# 基于Mindformers的自动搜索
+
+## 1 快速上手
+
+在安装了fast-tuner工具之后,使用如下命令:
+
+```bash
+fast-tuner-parallel --config ./config/setup_config/args_for_parallel_tool.json
+```
+
+其中 json 文件示例如下:
+
+```json
+{
+ // 基础配置
+ "rank_num" : 8,
+ "gbs": 32,
+ "output_path": "./output/nd_output/",
+
+ // Mindspore相关配置
+ "yaml_path" : "./config/pretrain.yaml",
+ "mindformers_dir" : "./mindformers",
+ "dataset" : "./config/dataset/wiki103-4k.mindrecord"
+}
+```
+
+参考文件:[配置文件 args_for_parallel_tool.json](../config/setup_config/args_for_parallel_tool.json)
+
+输出推荐并行配置:
+
+```js
+dp, tp, pp, ep, step_time(μs)
+ 2, 1, 4, 1, 14383
+ 1, 1, 8, 1, 15811
+ 2, 2, 2, 1, 22738
+ 4, 1, 2, 1, N/A
+```
+
+## 2 参数说明
+
+**基础参数**
+
+| 参数 | 含义 | 默认值 |
+|---------------|--------------------------------------------|------------------------------------------------------------------------|
+| rank_num | 工具可用于做profile的卡数 | 8 |
+| strategy | 工具搜索的并行策略,可选策略有 [DP, TP, PP, CP, EP, FSDP] | {"DP":true, "TP":true, "EP":true, "FSDP":true, "CP":false, "PP":false} |
+| gbs | global batch size | 32 |
+| output_path | 日志输出路径 | ./output/nd_output/ |
+
+**加速库参数**
+
+| 加速库 | 参数 | 含义 | 示例 |
+|-------------|-----------------|--------------------------------------------------|-----------------------------------------|
+| mindformers | mindformers_dir | run_mindformer.py路径,基于原生Mindspore搜索时需要填写该参数 | ./mindformers |
+| | yaml_path | 模型配置文件 | ./config/pretrain.yaml |
+| | dataset | 数据集路径 | ./config/dataset/wiki103-4k.mindrecord |
+
+通常用户只需要配置好基础参数,也就是模型规模,支持的并行范式等信息,以及所用的大模型训练套件信息,即可开始配置搜索。对于熟悉本工具的用户,也提供一些更高自由度的选项。
+
+**高级参数**
+
+| 参数 | 含义 | 默认值 |
+|---------------------|--------------------------------------------|----------------------------------------|
+| select_recompute | 是否搜索自定义选重 | True |
+| dryrun_data_dir | 已有dryrun数据路径 | None |
+| profile_data_dir | 已有profile数据路径 | None |
+| parallel_num | dryrun并行数 | 2 |
+| max_expert_parallel | 搜索范围的最大EP数 | 64 |
+| parser_result | profile解析结果csv文件,不需要解析profile时填写此参数即可 | None |
+| dryrun | 是否使用dryrun计算内存信息 | True |
+| check | 是否使用double_check进行内存拟合 | True |
+| alg_phase | 选择使用搜索算法阶段;0--全流程搜索, 1--ND搜索, 2--流水线负载均衡搜索 | 1 |
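+
+高级参数与基础参数写在同一个 json 配置文件中即可,下面是一个示意片段(取值仅为示例,请结合上表含义与实际环境调整):
+
+```json
+{
+  "rank_num" : 8,
+  "gbs": 32,
+  "yaml_path" : "./config/pretrain.yaml",
+  "mindformers_dir" : "./mindformers",
+  "dataset" : "./config/dataset/wiki103-4k.mindrecord",
+
+  // 高级参数(可选)
+  "parallel_num": 2,
+  "max_expert_parallel": 64,
+  "dryrun": true,
+  "alg_phase": 1
+}
+```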
+
+## 3 参数配置
+
+对于想要修改的参数,直接在json配置文件中修改即可。比如说在支持DP,TP,EP和PP,不支持CP,FSDP的场景下,可以按照如下方式修改:
+
+```json
+{
+ // 其他配置
+ "xxx": x,
+ ...
+
+ // 自定义需要调优的并行配置
+ "strategy": {
+ "DP": true,
+ "TP": true,
+ "PP": true,
+ "EP": true,
+ "CP": false,
+ "FSDP": false
+ }
+}
+```
+
+[返回原文档](../README.md)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/docs/mindspeed.md b/hyper_parallel/auto_parallel/fast-tuner/docs/mindspeed.md
new file mode 100644
index 0000000000000000000000000000000000000000..ab57c258ba385ebd97567fcfa9da309895dc886c
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/docs/mindspeed.md
@@ -0,0 +1,105 @@
+# 基于Mindspeed的自动搜索
+
+## 1 快速上手
+
+在安装了fast-tuner工具之后,使用如下命令:
+
+```bash
+fast-tuner-parallel --config ./config/setup_config/args_for_parallel_tool_mcore.json
+```
+
+其中 json 文件示例如下:
+
+```json
+{
+ // 基础配置
+ "rank_num" : 8,
+ "gbs": 32,
+ "output_path": "./output/nd_output/",
+
+ // Mindspeed相关配置
+ "shell_path" : "./config/example/mcore/pretrain.sh",
+ "mindspeed_path" : "./config/pretrain_gpt.py"
+}
+```
+
+参考文件:[配置文件 args_for_parallel_tool_mcore.json](../config/setup_config/args_for_parallel_tool_mcore.json)
+
+输出推荐并行配置:
+
+```js
+dp, tp, pp, ep, step_time(μs)
+ 2, 1, 4, 1, 14383
+ 1, 1, 8, 1, 15811
+ 2, 2, 2, 1, 22738
+ 4, 1, 2, 1, N/A
+```
+
+## 2 参数说明
+
+**基础参数**
+
+| 参数 | 含义 | 默认值 |
+|---------------|----------------------------------|------------------------------------------------------------------------|
+| rank_num | 工具可用于做profile的卡数 | 8 |
+| strategy | 工具搜索的并行策略,可选策略有 [DP, TP, PP, CP, EP, FSDP] | {"DP":true, "TP":true, "EP":true, "FSDP":true, "CP":false, "PP":false} |
+| gbs | global batch size | 32 |
+| output_path | 日志输出路径 | ./output/nd_output/ |
+
+**加速库参数**
+
+| 加速库 | 参数 | 含义 | 示例 |
+|-----------|----------------|----------------------|--------------------------|
+| Mindspeed | mindspeed_path | MindSpeed训练脚本路径 | ./home/pretrain_gpt.py |
+| | shell_path | 模型配置文件 | ./config/pretrain.sh |
+
+通常用户只需要配置好基础参数,也就是模型规模,支持的并行范式等信息,以及所用的大模型训练套件信息,即可开始配置搜索。对于熟悉本工具的用户,也提供一些更高自由度的选项。
+
+**高级参数**
+
+| 参数 | 含义 | 默认值 |
+|---------------------|--------------------------------------------|----------------------------------------|
+| select_recompute | 是否搜索自定义选重 | True |
+| profile_data_dir | 已有profile数据路径 | None |
+| max_expert_parallel | 搜索范围的最大EP数 | 64 |
+| parser_result | profile解析结果csv文件,不需要解析profile时填写此参数即可 | None |
+| alg_phase | 选择使用搜索算法阶段;0--全流程搜索, 1--ND搜索, 2--流水线负载均衡搜索 | 1 |
+
+## 3 参数配置
+
+对于想要修改的参数,直接在json配置文件中修改即可。比如说在支持DP,TP,EP和PP,不支持CP,FSDP的场景下,可以按照如下方式修改:
+
+```json
+{
+ // 其他配置
+ "xxx": x,
+ ...
+
+ // 自定义需要调优的并行配置
+ "strategy": {
+ "DP": true,
+ "TP": true,
+ "PP": true,
+ "EP": true,
+ "CP": false,
+ "FSDP": false
+ }
+}
+```
+
+[返回原文档](../README.md)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/docs/torchtitan.md b/hyper_parallel/auto_parallel/fast-tuner/docs/torchtitan.md
new file mode 100644
index 0000000000000000000000000000000000000000..d1d5337eab1fe67cea6d1a9a0ae942553023f804
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/docs/torchtitan.md
@@ -0,0 +1,110 @@
+# 基于Torchtitan的自动搜索
+
+## 1 快速上手
+
+在安装了fast-tuner工具之后,使用如下命令:
+
+```bash
+fast-tuner-parallel --config ./config/setup_config/args_for_parallel_tool_titan.json
+```
+
+其中 json 文件示例如下:
+
+```json
+{
+ // 基础配置
+ "rank_num" : 8,
+ "gbs": 32,
+ "npus_per_node": 8,
+ "nnodes": 1,
+ "output_path": "./output/nd_output/",
+
+ // Torchtitan相关配置
+ "toml_path" : "./config/example/torchtitan/debug_model.toml",
+ "torchtitan_path" : "./torchtitan/run_train.sh"
+}
+```
+
+参考文件:[配置文件 args_for_parallel_tool_titan.json](../config/setup_config/args_for_parallel_tool_titan.json)
+
+输出推荐并行配置:
+
+```js
+dp, tp, pp, ep, step_time(μs)
+ 2, 1, 4, 1, 14383
+ 1, 1, 8, 1, 15811
+ 2, 2, 2, 1, 22738
+ 4, 1, 2, 1, N/A
+```
+
+## 2 参数说明
+
+**基础参数**
+
+| 参数 | 含义 | 默认值 |
+|---------------|--------------------------------------------|------------------------------------------------------------------------|
+| rank_num | 工具可用于做profile的卡数 | 8 |
+| npus_per_node | 用户用于训练的每个节点的卡数 | 8 |
+| nnodes | 用户用于训练的节点数 | 1 |
+| strategy | 工具搜索的并行策略,可选策略有 [DP, TP, PP, CP, EP, FSDP] | {"DP":true, "TP":true, "EP":true, "FSDP":true, "CP":false, "PP":false} |
+| gbs | global batch size | 32 |
+| output_path | 日志输出路径 | ./output/nd_output/ |
+
+**加速库参数**
+
+| 加速库 | 参数 | 含义 | 示例 |
+|------------|-----------------|--------------------------------------|-----------------------------------------------|
+| torchtitan | torchtitan_path | 训练脚本路径,本工具需要此文件拉起profile | ./torchtitan/run_train.sh |
+| | toml_path | 模型参数配置文件 | ./config/example/torchtitan/debug_model.toml |
+
+通常用户只需要配置好基础参数,也就是模型规模,支持的并行范式等信息,以及所用的大模型训练套件信息,即可开始配置搜索。对于熟悉本工具的用户,也提供一些更高自由度的选项。
+
+**高级参数**
+
+| 参数 | 含义 | 默认值 |
+|---------------------|--------------------------------------------|----------------------------------------|
+| select_recompute | 是否搜索自定义选重 | True |
+| profile_data_dir | 已有profile数据路径 | None |
+| parallel_num | dryrun并行数 | 2 |
+| max_expert_parallel | 搜索范围的最大EP数 | 64 |
+| parser_result | profile解析结果csv文件,不需要解析profile时填写此参数即可 | None |
+| alg_phase | 选择使用搜索算法阶段;0--全流程搜索, 1--ND搜索, 2--流水线负载均衡搜索 | 1 |
+
+## 3 参数配置
+
+对于想要修改的参数,直接在json配置文件中修改即可。比如说在支持DP,TP,EP和PP,不支持CP,FSDP的场景下,可以按照如下方式修改:
+
+```json
+{
+ // 其他配置
+ "xxx": x,
+ ...
+
+ // 自定义需要调优的并行配置
+ "strategy": {
+ "DP": true,
+ "TP": true,
+ "PP": true,
+ "EP": true,
+ "CP": false,
+ "FSDP": false
+ }
+}
+```
+
+[返回原文档](../README.md)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/build_initial_spaces.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/build_initial_spaces.py
new file mode 100644
index 0000000000000000000000000000000000000000..54e56cb302a20d9901312f8b319112c67a9d37ff
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/build_initial_spaces.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""build initial space for nd"""
+
+import itertools
+import math
+
+def cal_factor(num):
+ factors = []
+ for i in range(1, num + 1):
+ if num % i == 0:
+ factors.append(i)
+ return factors
+
+def find_integers_less_than_ceil(num_layers, pp):
+ if pp == 1:
+ # 不开pp时, vpp为1
+ return [1]
+ ceil_value = math.ceil(num_layers / pp)
+ result = list(range(1, ceil_value))
+ return result
+
+def find_ep(expert_num):
+ if expert_num is None:
+ return [1]
+ ep_list = []
+ for i in range(1, expert_num + 1):
+ if i == 1 or i % 2 == 0:
+ ep_list.append(i)
+ return ep_list
+
+def build_initial_spaces(input_args, para):
+ """
+ generate initial space for nd
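+
+    返回的候选配置格式为 [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp],sp 为布尔开关;
+    当不搜索 FSDP 且 op != dp 时,op 记为 -1,表示优化器并行/FSDP 取剩余空间(见下方注释)。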
+ """
+ # part1. 求[dp, tp, cp, pp]的取值范围
+ # 找到 world_size 的所有因子
+ factors = cal_factor(input_args.world_size)
+
+ # 找出所有可能的四个正整数乘积等于world_size
+ part1_combinations = []
+ for combination in itertools.combinations_with_replacement(factors, 4):
+ product = 1
+ for num in combination:
+ product *= num
+ if product == input_args.world_size:
+ perms = list(itertools.permutations(combination))
+ unique_configs = list(set(perms))
+ part1_combinations.extend(unique_configs)
+ part1_combinations = filter_specify_strategy(para, part1_combinations)
+
+ # part2. [ep, op, vp, mbs]
+ # ep是dp*tp的约数
+ # vp是 < ceil(总层数/pp)的任意整数
+ # mbs 取值(1, 2, 4)中
+ part2_combinations = []
+ num_layers = input_args.num_layers
+ for world_size_config in part1_combinations:
+ op_options = get_op_options(world_size_config)
+ vp_options = find_integers_less_than_ceil(num_layers, world_size_config[3])
+ mbs_options = [1, 2, 4]
+ # todo: mindformers 和 mindspeed 对ep的限制不同,当前是mindformers的,待修改,把这个挪到专家剪枝
+ ep_options = find_ep(input_args.expert_num)
+ result = list(itertools.product(ep_options, op_options, vp_options, mbs_options))
+
+ dp = world_size_config[0]
+ for tmp_result in result:
+ op = tmp_result[1]
+ # 格式 [(dp, tp, cp, pp), (ep, op, vp, mbs)]
+ if not para.STRATEGY.FSDP:
+ if op != dp:
+ # 这里不搜FSDP的含义是: FSDP取剩余的空间
+ tmp_result = (tmp_result[0], -1, tmp_result[2], tmp_result[3])
+ part2_combinations.append([world_size_config, tmp_result])
+
+
+ if not para.STRATEGY.EP:
+ part2_combinations = [config for config in part2_combinations if config[1][0] == 1]
+ # part3. sp只有开关与否 格式 [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp]
+ final_combinations = []
+ for part2_config in part2_combinations:
+ final_combinations.append([part2_config, False])
+ if part2_config[0][1] != 1:
+ final_combinations.append([part2_config, True])
+ return final_combinations
+
+
+def filter_specify_strategy(para, part1_combinations):
+ if not para.STRATEGY.DP:
+ part1_combinations = [config for config in part1_combinations if config[0] == 1]
+ if not para.STRATEGY.TP:
+ part1_combinations = [config for config in part1_combinations if config[1] == 1]
+ if not para.STRATEGY.CP:
+ part1_combinations = [config for config in part1_combinations if config[2] == 1]
+ if not para.STRATEGY.PP:
+ part1_combinations = [config for config in part1_combinations if config[3] == 1]
+ return part1_combinations
+
+
+# 优化器并行可以是dp的任意因子
+def get_op_options(world_size_config):
+ dp = world_size_config[0]
+ op_options = cal_factor(dp)
+ return op_options
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/expert_filter_configs.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/expert_filter_configs.py
new file mode 100644
index 0000000000000000000000000000000000000000..b788712f67202a6d976f0d536d185555230a921c
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/expert_filter_configs.py
@@ -0,0 +1,200 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""expert filter"""
+
+import math
+from fast_tuner.utils.logger import logger
+
+
+class ExpertFilterManager:
+ """
+ expert experience filter nd configs
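+
+    候选配置的结构为 [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp],
+    下列 get_xx 静态方法均按该结构从 config 中取出对应的并行维度。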
+ """
+ def __init__(self, input_args, gbs):
+ self.expert_filters = []
+ self.input_args = input_args
+ self.gbs = gbs
+
+ @staticmethod
+ def sequential_combination(selected_experiences, candidate_space):
+ result = candidate_space
+ for exp in selected_experiences:
+ result = exp(result)
+ return result
+
+ @staticmethod
+ def get_cp(config):
+ return config[0][0][2]
+
+ @staticmethod
+ def get_dp(config):
+ return config[0][0][0]
+
+ @staticmethod
+ def get_op(config):
+ return config[0][1][1]
+
+ @staticmethod
+ def get_tp(config):
+ return config[0][0][1]
+
+ @staticmethod
+ def get_pp(config):
+ return config[0][0][3]
+
+ @staticmethod
+ def get_sp_switch(config):
+ return config[1]
+
+ @staticmethod
+ def get_world_size(config):
+ return math.prod(config[0][0])
+
+ @staticmethod
+ def get_mbs(config):
+ return config[0][1][3]
+
+ @staticmethod
+ def get_ep(config):
+ return config[0][1][0]
+
+ def get_gbs(self):
+ return self.gbs
+
+ def get_num_layers(self):
+ return self.input_args.num_layers
+
+ def get_mbn(self):
+ return self.input_args.mbn
+
+ def add_experience(self, experience_function):
+ """
+ 添加一个专家经验函数到列表中
+ :param experience_function: 要添加的专家经验函数
+ """
+ self.expert_filters.append(experience_function)
+ logger.info(f"add experience succ: {experience_function.__name__}")
+
+ def remove_experience(self, experience_function):
+ """
+ 从列表中移除一个专家经验函数
+ :param experience_function: 要移除的专家经验函数
+ """
+ if experience_function in self.expert_filters:
+ self.expert_filters.remove(experience_function)
+ logger.info(f"remove experience succ: {experience_function.__name__}")
+ else:
+            logger.info(f"can not find experience: {experience_function.__name__}, cannot remove it.")
+
+ def ep_for_torchtitan(self, candidate_space):
+ """
+ 这里默认为etp=1的场景
+ """
+ configs = []
+ for config in candidate_space:
+ ep = self.get_ep(config)
+ cp = self.get_cp(config)
+ tp = self.get_tp(config)
+ op = self.get_op(config)
+ if ep % (cp * tp) == 0 and (op * cp * tp) % ep == 0:
+ configs.append(config)
+ return configs
+
+ def ep_for_mindspore(self, candidate_space):
+ configs = []
+ for config in candidate_space:
+ ep = self.get_ep(config)
+ dp = self.get_dp(config)
+ tp = self.get_tp(config)
+ op = self.get_op(config)
+ if op % ((dp * tp) / ep) == 0:
+ configs.append(config)
+ return configs
+
+ def cp_for_deepseek_expert(self, candidate_space):
+ # 2.28deepseek版本cp = 1
+ return [config for config in candidate_space if self.get_cp(config) == 1]
+
+ def dp_cp_ep_for_megatron_expert(self, candidate_space):
+ # megatron dp * cp % ep == 0
+ return [config for config in candidate_space
+ if self.get_dp(config) * self.get_cp(config) % self.get_ep(config) == 0]
+
+ def pp_for_deepseek(self, candidate_space):
+ # 万卡训练deepseek, 要求pp>1, pp过小内存会超
+ return [config for config in candidate_space if self.get_pp(config) > 1]
+
+ def pp_for_768die(self, candidate_space):
+ # 768die, 要求pp<=32,
+ return [config for config in candidate_space if self.get_pp(config) <= 32]
+
+ def tp_for_910b_expert(self, candidate_space):
+ # 910b为1机8卡,机间通信耗时过大,因此不能超过8
+ return [config for config in candidate_space if self.get_tp(config) <= 8]
+
+ def tp_for_large_scale_expert(self, candidate_space):
+ # 超节点技术,专家经验不超过64即可
+ return [config for config in candidate_space if self.get_tp(config) <= 64]
+
+ def tp_for_large_scale_768die(self, candidate_space):
+ # 768die, tp要是2的幂
+ return [config for config in candidate_space if self.get_tp(config) % 3 != 0]
+
+ def tp_for_yoco_expert(self, candidate_space):
+ return [config for config in candidate_space if 56 % self.get_tp(config) == 0]
+
+ def ep_for_large_scale_expert(self, candidate_space):
+ return [config for config in candidate_space if self.get_ep(config) <= 64]
+
+ def sp_for_lm_expert(self, candidate_space):
+ # 在千卡规模及以上sp一定开启,千卡规模以下可支持sp搜索
+ world_size = self.get_world_size(candidate_space[0])
+ return [config for config in candidate_space
+ if world_size < 1000 or (self.get_tp(config) ==1 or self.get_sp_switch(config))]
+
+ def pp_for_mbs_expert(self, candidate_space):
+ return [config for config in candidate_space if
+ self.get_pp(config) <=
+ min(self.get_num_layers(), self.get_gbs() // self.get_dp(config) // self.get_mbs(config))]
+
+ def gbs_for_dp_expert(self, candidate_space):
+ return [config for config in candidate_space if self.get_gbs() % self.get_dp(config) == 0]
+
+def expert_filter_configs(search_spaces, input_args, gbs):
+ """
+
+ :param search_spaces: 初始搜索空间 [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp]
+ :param input_args: 用户输入模型配置信息
+ :param gbs: global batch size
+ :return: 使用专家经验剪枝搜索空间后得到的配置 [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp]
+ """
+ expert_manager = ExpertFilterManager(input_args, gbs)
+ expert_manager.add_experience(expert_manager.cp_for_deepseek_expert)
+ expert_manager.add_experience(expert_manager.tp_for_large_scale_expert)
+ expert_manager.add_experience(expert_manager.ep_for_large_scale_expert)
+ expert_manager.add_experience(expert_manager.sp_for_lm_expert)
+ expert_manager.add_experience(expert_manager.pp_for_mbs_expert)
+ expert_manager.add_experience(expert_manager.gbs_for_dp_expert)
+ #expert_manager.add_experience(expert_manager.pp_for_deepseek)
+ #expert_manager.add_experience(expert_manager.dp_cp_ep_for_megatron_expert)
+ expert_manager.add_experience(expert_manager.ep_for_torchtitan)
+ #expert_manager.add_experience(expert_manager.ep_for_mindspore)
+ # add for 768die
+ expert_manager.add_experience(expert_manager.pp_for_768die)
+ expert_manager.add_experience(expert_manager.tp_for_large_scale_768die)
+ # # add for yoco model
+ # expert_manager.add_experience(expert_manager.tp_for_yoco_expert)
+ valid_configs = expert_manager.sequential_combination(expert_manager.expert_filters, search_spaces)
+ return valid_configs
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/memory_model.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/memory_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed8e3048da43cbfd32a588e5d09d32fb79b23665
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/memory_model.py
@@ -0,0 +1,417 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""nd phase memory model"""
+
+import re
+import os
+import math
+import csv
+from enum import Enum
+from collections import defaultdict
+from fast_tuner.utils.logger import logger
+from fast_tuner.utils.common import generate_files, initial_offset, offset_for_dualpipe, is_dualpipe_open
+from fast_tuner.utils.dryrun_manage import launch_dryrun, read_dryrun_info
+
+
+def memory_simple_prune(config, profile_info):
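+    # 依据profile信息粗略估算峰值内存,并预留 coe_mem=20% 的余量;
+    # 估算值超过 memory_max 的配置直接剪枝。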
+ memory_max = config.memory_max
+ tmp_layer = config.num_layer // config.pp_size
+ profile_memory = config.pp_size * tmp_layer * profile_info.act_mem_full_recomp \
+ + profile_info.embedding_mem + tmp_layer * profile_info.static_layer_MoE
+ coe_mem = 0.2
+ if profile_memory * (1 + coe_mem) > memory_max:
+ return False
+ return True
+
+def trans_format(dryrun_info_init, test_ep):
+ """
+ trans dryrun_info_init to [dp, tp, pp, ep, peak_mem_ep1, peak_mem_ep2]
+ """
+ config_mem_dict = defaultdict(lambda: [0, 0])
+ for config in dryrun_info_init:
+ key = tuple(config[:3])
+ peak_mem = config[4]
+ ep = config[3]
+ if ep == test_ep[0]:
+ index = 0
+ elif ep == test_ep[1]:
+ index = 1
+ else:
+ logger.error(f"error ep {ep} is not supported")
+ continue
+ config_mem_dict[key][index] = peak_mem
+
+ return [list(key) + value for key, value in config_mem_dict.items()]
+
+def grey_box_memory_prune(mindformers_args, dryrun_info_init, test_ep, max_expert_parallel):
+ """
+ ep灰盒剪枝
+
+ :param max_expert_parallel: ep最大值
+ :param test_ep:
+ :param dryrun_info_init: [dp, tp, pp, ep, peak_value]
+ :param mindformers_args:
+ :return: [dp, tp, pp, ep, evaluate_peak_mem]
+ """
+ dryrun_info = trans_format(dryrun_info_init, test_ep)
+ ep1, ep2 = test_ep
+ logger.info(f"dryrun_info len: {len(dryrun_info)} format [dp, tp, pp, peak_mem_ep{ep1}, peak_mem_ep{ep2}]")
+ logger.info("\n".join(str(config) for config in dryrun_info))
+
+ try:
+ max_mem = int(re.search(r'\d+', mindformers_args.context.max_device_memory).group()) * 1024
+ except (AttributeError, ValueError):
+ max_mem = 58 * 1024
+
+ memory_aware_configs = []
+ logger.info("format: dp_tp_pp_ep_evaluateMem")
+ ep_power = find_power_of_two(max_expert_parallel)
+ for dp, tp, pp, peak_ep, peak_ep_double in dryrun_info:
+ # 线性拟合ep会影响到的内存和ep不会影响到的内存
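+        # 内存模型假设 peak(ep) ≈ base_memory + ep_memory / ep,其中 ep_memory 为全部路由专家占用的内存。
+        # 两组 dryrun 相减: peak(ep1) - peak(2*ep1) = ep_memory / (2*ep1),故 ep_memory = 差值 * test_ep[1];
+        # 该推导依赖 test_ep[1] == 2 * test_ep[0]。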
+ ep_memory = (peak_ep - peak_ep_double) * test_ep[1] #所有专家的内存
+ base_memory = peak_ep - ep_memory / test_ep[0]
+ # 确定ep最大能开多大,最大为6, ep最大64
+ ep_upperbound = 0
+ for i in range(ep_power+1):
+ if (dp*tp) % (2**i) == 0:
+ ep_upperbound += 1
+ # 输出满足内存上限的ep,如果ep64都不够就不返回
+ for j in range(ep_upperbound):
+ ep = 2 ** j
+ evaluate_mem = base_memory + ep_memory / ep
+ logger.info(f"{dp}_{tp}_{pp}_{ep}_{evaluate_mem}")
+ if evaluate_mem <= max_mem:
+ memory_aware_configs.append([dp, tp, pp, ep, evaluate_mem])
+ return memory_aware_configs
+
+def find_power_of_two(m):
+ if m <= 0:
+ return None
+ power = math.log2(m)
+ if power.is_integer():
+ return int(power)
+ return None
+
+def filter_oom(search_space, input_args, para):
+ """
+ filter evaluate oom configs
+ """
+    # todo: 是否dryrun返回值不同,需判断这里是否需要处理
+ if para.DRYRUN:
+ # 生成要做dryrun的配置
+ care_part_configs = select_dry_config(search_space, input_args)
+ test_ep = (8, 16)
+ dry_config = generate_dry_config(care_part_configs, input_args, test_ep)
+        # 未提供已有dryrun数据目录时,才需要自动拉起dryrun
+        dryrun_exe_switch = not bool(para.DRYRUN_DATA_DIR)
+ if dryrun_exe_switch:
+ logger.info("need auto dryrun process")
+ file_task = "dryrun_yaml" if para.YAML_PATH else "dryrun_shell"
+ dryrun_file_dir = os.path.abspath(para.OUTPUT_PATH) + os.sep + file_task + os.sep
+ generate_files(dry_config, dryrun_file_dir, file_task, para, input_args)
+ dryrun_data_dir = os.path.join(os.path.abspath(para.OUTPUT_PATH), "dryrun_output")
+ if input_args.mf_args:
+ # 基于mindformers(mindspore)才做dryrun
+ launch_dryrun(input_args, dryrun_file_dir, dryrun_data_dir, para)
+ else:
+ dryrun_data_dir = para.DRYRUN_DATA_DIR
+
+ if input_args.mf_args:
+ dryrun_info = read_dryrun_info(dryrun_data_dir)
+ candidate_configs = grey_box_memory_prune(input_args, dryrun_info, test_ep, para.MAX_EXPERT_PARALLEL)
+ else:
+ candidate_configs = dry_config
+ generate_csv(para.OUTPUT_PATH, candidate_configs, input_args)
+ return candidate_configs
+
+ candidate_configs = [] #the format is [dp, tp, pp, ep, cp, op, evaluate_mem]
+ for config in search_space:
+ op_disable = False
+ dp, tp, cp, pp = config[0][0]
+ ep, op = config[0][1][:2]
+ if op < 0:
+ op_disable = True
+ op = dp
+ input_args.op, input_args.tp, input_args.pp, input_args.ep = op, tp, pp, ep
+ dense_size, moe_size, vocab_size = compute_weight_and_optimizer_memory(input_args)
+ try:
+ max_mem = int(re.search(r'\d+', input_args.context.max_device_memory).group()) * 1024
+ except (AttributeError, ValueError):
+ max_mem = 58 * 1024
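+        # 用权重+优化器内存做粗估:pp == 1 时为 2*词表层 + 全部 dense/MoE 层;
+        # pp > 1 时分别估算首个stage(词表 + dense 层 + 剩余MoE层)与一般stage(约 (num_layers+2)/pp 个MoE层),取较大者。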
+ if input_args.pp == 1:
+ evaluate_memory = (2 * vocab_size+moe_size * (input_args.num_layers - input_args.first_k_dense_replace)
+ + dense_size * input_args.first_k_dense_replace)
+ elif input_args.first_k_dense_replace > input_args.num_layers // input_args.pp:
+ estimated_first = vocab_size + dense_size * (input_args.num_layers // input_args.pp)
+ estimated_general = moe_size * math.ceil((input_args.num_layers+2)/input_args.pp)
+ evaluate_memory = max(estimated_first, estimated_general)
+ else:
+ estimated_first = (vocab_size + dense_size * input_args.first_k_dense_replace
+ + moe_size * (input_args.num_layers // input_args.pp - input_args.first_k_dense_replace))
+ estimated_general = moe_size * math.ceil((input_args.num_layers + 2) / input_args.pp)
+ evaluate_memory = max(estimated_first, estimated_general)
+ if op_disable: op = -1
+ if evaluate_memory <= max_mem:
+ if [dp, tp, pp, ep, cp, op, evaluate_memory] not in candidate_configs:
+ candidate_configs.append([dp, tp, pp, ep, cp, op, evaluate_memory])
+ else:
+            logger.info(f"mem over limit: evaluate mem {evaluate_memory}, "
+                        f"config dp {dp} tp {tp} pp {pp} ep {ep} cp {cp} op {op}")
+    logger.info(f"Prune Search space size: {len(candidate_configs)}, "
+                f"format: [dp, tp, pp, ep, cp, op/fsdp, evaluate_peak_mem]")
+ return candidate_configs
+
+def select_dry_config(valid_configs, input_args):
+ """
+
+ :param valid_configs: [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp]
+ :param input_args: 配置文件参数
+ :return: [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp]
+ 其中vp=1, mbs=1, sp=true 每个(dp, tp, cp, pp),对应ep最大的配置列表
+ """
+ first = valid_configs[0][0][0]
+ max_ep = valid_configs[0][0][1][0]
+ op_with_ep = valid_configs[0][0][1][1]
+ ans = []
+ for config in valid_configs:
+ current_first = config[0][0]
+ current_ep = config[0][1][0]
+ current_op = config[0][1][1]
+ pp = config[0][0][3]
+ num_layers = input_args.num_layers
+
+ if is_dualpipe_open(input_args) and pp * 2 >= num_layers:
+ continue
+
+ if current_first == first:
+ if current_ep > max_ep:
+ max_ep = current_ep
+ op_with_ep = current_op
+ else:
+ ans.append([[first, [max_ep, op_with_ep, 1, 1]], True])
+ first = current_first
+ max_ep = current_ep
+ op_with_ep = current_op
+ # 添加最后一组数据
+ ans.append([[first, [max_ep, op_with_ep, 1, 1]], True])
+ logger.info(f"Dryrun candidate config size: {len(ans)}")
+ return ans
+
+def generate_csv(output_path, dryrun_config, input_args):
+ """
+ generate nd result to csv file
+ """
+ # 表头
+ if input_args.expert_num is not None:
+ headers = ['dp', 'tp', 'pp', 'ep', 'evaluate_mem']
+ else:
+ headers = ['dp', 'tp', 'pp', 'evaluate_mem']
+
+ # 写入 CSV 文件
+ try:
+ csv_path = os.path.join(os.path.abspath(output_path), "nd_candidate_config.csv")
+ with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
+ writer = csv.writer(csvfile)
+ # 写入表头
+ writer.writerow(headers)
+ # 写入数据
+ writer.writerows(dryrun_config)
+ logger.info("CSV file generate succ.")
+ except Exception as e:
+ logger.info(f"write CSV file fail: {e}")
+
+class CareTpye(Enum):
+ UNKNOWN = 0
+ WITH_EXPERT_MF = 1
+ NO_EXPERT = 2
+ WITH_EXPERT_NO_MF = 3
+
+def generate_dry_config(care_part_configs, input_args, test_ep):
+ """
+
+ :param test_ep: 要做dryrun的ep
+ :param care_part_configs: [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp]
+ :param input_args: 模型及环境等信息
+ :return: [dp, tp, pp, ep, offset] 或 [dp, tp, pp]
+ """
+ dry_run_config = []
+ layers_num = input_args.num_layers
+
+ if input_args.expert_num is not None and input_args.mf_args is not None:
+ care_type = CareTpye.WITH_EXPERT_MF
+ logger.info('dryrun configs format: dp_tp_pp_ep_offset')
+ elif input_args.expert_num is None:
+ care_type = CareTpye.NO_EXPERT
+ logger.info('dryrun configs format: dp_tp_pp')
+ else:
+ care_type = CareTpye.WITH_EXPERT_NO_MF
+ logger.info('dryrun configs format: dp_tp_pp_ep')
+
+ for config in care_part_configs:
+ dp, tp, _, pp = config[0][0]
+ ep = config[0][1][0]
+ # 若为deepseek模型
+ if care_type == CareTpye.WITH_EXPERT_MF:
+ for ep in test_ep:
+ # mindformers的约束
+ if input_args.mf_args is not None and dp * tp % ep != 0:
+ continue
+
+ use_zero_bubble_v = is_dualpipe_open(input_args)
+
+ offset_calculator = offset_for_dualpipe if use_zero_bubble_v else initial_offset
+ new_config = [dp, tp, pp, ep, offset_calculator(pp, layers_num)]
+ dry_run_config.append(new_config)
+ elif care_type == CareTpye.NO_EXPERT:
+ new_config = [dp, tp, pp]
+ dry_run_config.append(new_config)
+ else:
+ new_config = [dp, tp, pp, ep]
+ dry_run_config.append(new_config)
+
+
+ logger.info(f"Dryrun config size: {len(dry_run_config)}")
+ for config in dry_run_config:
+ config_str = "_".join(str(x) for x in config)
+ logger.info(config_str)
+ return dry_run_config
+
+NUM_BYTES_IN_MEGA = 1024 * 1024
+
+def compute_weight_and_optimizer_memory(input_args):
+ """
+
+ :param input_args: input training parameters
+ :return: [dense_layer mem, moe_layer mem, embedding_layer mem]/MB
+ """
+ #Attention projection size
+ if input_args.kv_channels:
+        query_projection_to_hidden_size_ratio = (
+            input_args.kv_channels * input_args.num_attention_heads / input_args.hidden_size)
+    else:
+        query_projection_to_hidden_size_ratio = 1
+ # Group Query Attention.
+ if not input_args.group_query_attention:
+ input_args.num_query_groups = input_args.num_attention_heads
+ # MoE.
+ num_experts = 1 if input_args.expert_num is None else input_args.expert_num
+ gated_linear_multiplier = 3 / 2 if input_args.swiglu else 1
+
+ if input_args.expert_num is not None:
+ moe_ffn_hidden_size = input_args.moe_intermediate_size
+ shared_expert_ffn_hidden_size = (
+ input_args.moe_intermediate_size
+ if input_args.moe_shared_expert_intermediate_size is None
+ else input_args.moe_shared_expert_intermediate_size
+ )
+ else:
+ moe_ffn_hidden_size = 0
+ shared_expert_ffn_hidden_size = 0
+
+ if input_args.multi_latent_attention:
+ if input_args.qk_head_dim:
+ qk_head_dim = input_args.qk_head_dim
+ elif input_args.qk_nope_head_dim:
+ qk_head_dim = input_args.qk_nope_head_dim
+ else:
+ qk_head_dim = 0
+            logger.warning('qk head dim not specified')
+ if input_args.qk_pos_emb_head_dim:
+ qk_pos_emb_head_dim = input_args.qk_pos_emb_head_dim
+ elif input_args.qk_pos_rope_head_dim:
+ qk_pos_emb_head_dim = input_args.qk_pos_rope_head_dim
+ else:
+ qk_pos_emb_head_dim = 0
+            logger.warning('qk pos head dim not specified')
+ assert not input_args.group_query_attention
+ if input_args.q_lora_rank is None:
+ q_term = input_args.hidden_size * input_args.num_attention_heads * (qk_head_dim + qk_pos_emb_head_dim)
+ else:
+ ## q lora + rope + q norm
+ q_term = input_args.q_lora_rank * (
+ input_args.hidden_size + input_args.num_attention_heads * (qk_head_dim + qk_pos_emb_head_dim) + 1)
+
+ self_attn_term = (
+ q_term
+ ## kv lora + rope + kv norm
+ + input_args.kv_lora_rank
+ * (input_args.hidden_size + input_args.num_attention_heads * (qk_head_dim + input_args.v_head_dim) + 1)
+ + input_args.hidden_size * qk_pos_emb_head_dim
+ ## o proj
+ + (input_args.num_attention_heads * input_args.v_head_dim) * input_args.hidden_size
+ )
+ else:
+ self_attn_term = (
+ 2
+ * input_args.hidden_size
+ * input_args.hidden_size
+ * (
+ # Attention.
+ (
+ (1 + (input_args.num_query_groups / input_args.num_attention_heads))
+ * query_projection_to_hidden_size_ratio
+ )
+ )
+ )
+
+ num_parameters_in_dense_ffn = (
+ 2
+ * input_args.hidden_size
+ * (
+ # Dense MoE MLP.
+ (input_args.ffn_hidden_size * gated_linear_multiplier)
+ # Transformer layer norms.
+ + 2
+ )
+ )
+
+ num_parameters_in_moe_ffn_without_routed_expert = (
+ 2
+ * input_args.hidden_size
+ * (
+ # MoE MLP.
+ # Shared MoE MLP.
+ + (shared_expert_ffn_hidden_size * gated_linear_multiplier)
+ # Transformer layer norms.
+ + 2
+ )
+ )
+
+ num_parameters_of_routed_expert = (2 * input_args.hidden_size * moe_ffn_hidden_size
+ * num_experts * gated_linear_multiplier / input_args.ep)
+
+ embedding_size = input_args.hidden_size * input_args.padded_vocab_size
+ final_layernorm = 2 * input_args.hidden_size
+
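+    # 每参数字节数沿用 Megatron 的估算口径:18 ≈ bf16 权重(2) + fp32 梯度(4) + fp32 优化器状态(12,
+    # 即 fp32 主权重与 Adam 一/二阶动量);开启分布式优化器时,12 字节部分按优化器并行度 op 切分。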
+ num_bytes_per_parameter = (
+ 18 if not input_args.use_distributed_optimizer else 6 + (12 / input_args.op)
+ )
+
+ if not input_args.use_distributed_optimizer:
+ num_bytes_per_parameter_routed_expert = 18
+ elif input_args.op < input_args.ep:
+ num_bytes_per_parameter_routed_expert = 18
+ else:
+ num_bytes_per_parameter_routed_expert = 6 + (12 / input_args.op * input_args.ep)
+
+ embedding_layer = embedding_size / input_args.tp * num_bytes_per_parameter
+ dense_layer = (num_parameters_in_dense_ffn + self_attn_term) / input_args.tp * num_bytes_per_parameter
+ moe_layer = (self_attn_term + num_parameters_in_moe_ffn_without_routed_expert) * num_bytes_per_parameter + \
+ num_parameters_of_routed_expert * num_bytes_per_parameter_routed_expert
+ moe_layer /= input_args.tp
+ final_layernorm /= input_args.tp
+
+ if input_args.pp == 1:
+ return [dense_layer/NUM_BYTES_IN_MEGA, moe_layer/NUM_BYTES_IN_MEGA,
+ (embedding_layer * 2 + final_layernorm)/NUM_BYTES_IN_MEGA]
+ return [dense_layer/NUM_BYTES_IN_MEGA, moe_layer/NUM_BYTES_IN_MEGA, embedding_layer/NUM_BYTES_IN_MEGA]
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/para_for_nd_search.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/para_for_nd_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..0928767ddbb033001dd754cd371fb54a7a099dde
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/para_for_nd_search.py
@@ -0,0 +1,317 @@
+"""class ParaForNd inner use"""
+
+import re
+import toml
+from fast_tuner.utils.common import cal_model_layers_num
+from fast_tuner.utils.input_config import InputConfig
+from fast_tuner.utils.logger import logger
+
+
+class ParaForNd:
+ """
+ trans different params to inner param
+ """
+ def __init__(self, para):
+ self.seq_len = 2048
+ self.mf_args = None
+ self.max_mem = 58 * 1024
+ self.world_size = 1
+ self.dp = 8 # dp_replicate_degree * dp_shard_degree(OP)
+ self.dp_replicate_degree = 1
+ self.dp_shard_degree = -1
+ self.op = 1 # equal to real fsdp value
+ self.pp = 1
+ self.tp = 1
+ self.cp = 1
+ self.ep = 1
+ self.etp = 1
+ self.gbs = 8
+ self.mbn = 1
+ self.mbs = 1
+ self.num_layers = 6
+ self.swiglu = False
+ self.qk_pos_emb_head_dim = 64
+ self.qk_head_dim = 128
+ self.q_lora_rank = None
+ self.kv_lora_rank = None
+ self.v_head_dim = 128
+ self.expert_num = 8
+ self.moe_intermediate_size = 256
+ self.hidden_size = 256
+ self.ffn_hidden_size = 1024
+ self.num_attention_heads = 16
+ self.padded_vocab_size = 2048
+ self.first_k_dense_replace = 1
+ self.use_distributed_optimizer = True
+ self.kv_channels = 128
+ self.group_query_attention = True
+ self.num_query_groups = 8
+ self.moe_shared_expert_intermediate_size = None
+ self.multi_latent_attention = False
+ # add for titan
+ self.enable_fsdp_float8_all_gather = False
+ self.precompute_float8_dynamic_scale_for_fsdp = False
+ self.fsdp_reshard_after_forward = "default"
+ self.enable_async_tensor_parallel = False
+ self.pipeline_parallel_schedule = "Interleaved1F1B"
+ self.model_name = None
+
+ # for profile parser todo: 只为mindspeed服务,不应该在这里
+ self.num_layer_list = [4, 3]
+ self.recompute_method = 'block'
+ self.recompute_num_layers = 1
+ self.profile_steps = 2
+ self.get_args_from_file(para)
+
+ def print_member_value(self):
+ for attr, value in self.__dict__.items():
+ logger.info(f"{attr}: {value}")
+
+ def convert_value(self, value_str):
+ """尝试将字符串转换为整数/浮点数,失败则返回原字符串"""
+ if not value_str:
+ return value_str
+ # 尝试整数转换(处理正负整数)
+ if value_str.lstrip('-').isdigit():
+ return int(value_str)
+ # 尝试浮点数转换(处理正负浮点数、科学计数法)
+ try:
+ return float(value_str)
+ except ValueError:
+ return value_str # 非数字则返回原字符串
+
+ def parse_toml_parameters(self, toml_file_path):
+ params = toml.load(toml_file_path)
+ return params
+
+ # 部分变量没有解析出来,多行参数和引用都能解析
+ def parse_sh_parameters(self, sh_file_path):
+ """
+        解析 shell 脚本文件中的参数,合并各模块(如 GPT_ARGS、DATA_ARGS)后返回扁平的参数字典
+
+        Args:
+            sh_file_path: shell 脚本文件路径
+        Returns:
+            dict: 合并后的扁平参数字典,格式为 {参数名: 参数值}(参数名中的 '-' 已替换为 '_')
+ """
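+        # 示例(示意):若脚本中有 TP=2 且 GPT_ARGS="--tensor-model-parallel-size ${TP} --use-flash-attn",
+        # 则返回结果约为 {'tensor_model_parallel_size': 2, 'use_flash_attn': True}。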
+ # 存储结果:模块名 -> {参数名: 参数值}
+
+ with open(sh_file_path, 'r', encoding='utf-8') as f:
+ sh_content = f.read()
+ result = {}
+
+ # 1. 提取所有变量(如 TP=1, DATA_PATH="xxx" 等),用于替换参数中的变量引用(如 ${TP})
+ variables = self.parse_variable(sh_content)
+
+ # 2. 提取模块参数(如 GPT_ARGS、DATA_ARGS)
+ self.parse_module_args(result, sh_content, variables)
+
+ # 3. 展开 result:将所有模块的参数字典合并为一个扁平字典
+ flattened_params = {}
+ for _, params in result.items():
+ # 合并到全局字典(若有重复参数,后面的模块会覆盖前面的)
+ flattened_params.update(params)
+
+ return flattened_params
+
+ def parse_module_args(self, result, sh_content, variables):
+ """
+        Match module definitions (e.g. GPT_ARGS="...args...") and parse their parameters.
+        """
+        module_pattern = re.compile(r'^(\w+)\s*=\s*"(.*?)"$', re.DOTALL | re.MULTILINE)
+        for module_match in module_pattern.finditer(sh_content):
+            module_name = module_match.group(1)  # module name, e.g. GPT_ARGS or DATA_ARGS
+            module_content = module_match.group(2)  # parameter content inside the module
+
+            # Parse parameters inside the module (--key value form)
+            param_pattern = re.compile(r'--(\w[\w-]*)\s+([^\s\\]+)')
+            # Handle value-less flags (e.g. --use-flash-attn)
+            flag_pattern = re.compile(r'--(\w[\w-]*)')
+
+            module_params = {}
+
+            # First extract parameters that take a value (e.g. --data-path $DATA_PATH)
+            for match in param_pattern.finditer(module_content):
+                param_key = match.group(1)  # parameter name, e.g. data-path
+                param_value = match.group(2)  # parameter value, e.g. $DATA_PATH
+
+                # Substitute variable references (e.g. replace $DATA_PATH with its actual value)
+                for var_name, var_val in variables.items():
+                    param_value = param_value.replace(f'${var_name}', var_val).replace(f'${{{var_name}}}', var_val)
+
+                module_params[param_key.replace('-', '_')] = self.convert_value(param_value)
+
+            # Then extract value-less flags (e.g. --use-flash-attn)
+            for match in flag_pattern.finditer(module_content):
+                param_key = match.group(1)
+                standard_key = param_key.replace('-', '_')
+                if standard_key not in module_params:  # avoid overriding parameters that already have a value
+                    module_params[standard_key] = True  # True marks the flag as enabled
+ if module_params:
+ result[module_name] = module_params
+
+ def parse_variable(self, sh_content):
+ """
+        Extract all variables (e.g. TP=1, DATA_PATH="xxx") used to substitute variable references (e.g. ${TP}).
+ """
+ variables = {}
+ var_pattern = re.compile(r'^(\w+)\s*=\s*([\'"])(.*?)\2$', re.MULTILINE)
+ for match in var_pattern.finditer(sh_content):
+            var_name = match.group(1)  # variable name (e.g. A)
+            var_value = match.group(3)  # variable value (e.g. cccc)
+            variables[var_name] = var_value
+        # Also handle variables without quotes (e.g. TP=1, PP=4)
+ var_pattern2 = re.compile(r'^(\w+)\s*=\s*([\w.]+)$', re.MULTILINE)
+ for match in var_pattern2.finditer(sh_content):
+ var_name = match.group(1)
+ var_value = match.group(2)
+ variables[var_name] = var_value
+ return variables
+
+ def get_args_from_file(self, para):
+ """
+ get args from yaml/shell/toml file
+ """
+ if para.YAML_PATH:
+ self.trans_yaml_param(para)
+
+ elif para.SHELL_PATH:
+ self.trans_shell_param(para)
+
+ elif para.TOML_PATH:
+ self.trans_toml_param(para)
+
+ else:
+            raise RuntimeError("Only yaml/shell/toml config files are supported, please provide a valid config file")
+
+ self.print_member_value()
+
+ def trans_toml_param(self, para):
+ """
+        Translate toml parameters into internal parameters.
+ """
+ parsed_args = self.parse_toml_parameters(para.TOML_PATH)
+ self.seq_len = parsed_args['training'].get('seq_len', 4096)
+ self.world_size = para.NPUS_PER_NODE * para.NNODES
+ self.pp = parsed_args['parallelism'].get('pipeline_parallel_degree', 1)
+ self.tp = parsed_args['parallelism'].get('tensor_parallel_degree', 1)
+ self.cp = parsed_args['parallelism'].get('context_parallel_degree', 1)
+ self.dp_shard_degree = parsed_args['parallelism'].get('data_parallel_shard_degree', -1)
+ if self.dp_shard_degree == -1:
+ fsdp = self.world_size // self.pp // self.tp // self.cp // parsed_args['parallelism'].get(
+ 'data_parallel_replicate_degree', 1)
+ else:
+ fsdp = self.dp_shard_degree
+ self.op = fsdp
+ self.dp = parsed_args['parallelism'].get('data_parallel_replicate_degree', 1) * fsdp
+ self.ep = parsed_args['parallelism'].get('expert_parallel_degree', 1)
+ self.etp = parsed_args['parallelism'].get('expert_tensor_parallel_degree', 1)
+ self.fsdp_reshard_after_forward = parsed_args['parallelism'].get('fsdp_reshard_after_forward', "default")
+ self.enable_async_tensor_parallel = parsed_args['parallelism'].get('enable_async_tensor_parallel', False)
+ local_batch_size = parsed_args['training'].get('local_batch_size', 1)
+ self.gbs = local_batch_size * self.dp
+ self.mbs = parsed_args['parallelism'].get('pipeline_parallel_microbatch_size', 1)
+ self.mbn = local_batch_size // self.mbs
+ self.enable_fsdp_float8_all_gather = parsed_args['quantize']['linear']['float8'].get(
+ 'enable_fsdp_float8_all_gather', False)
+ self.precompute_float8_dynamic_scale_for_fsdp = parsed_args['quantize']['linear']['float8'].get(
+ 'precompute_float8_dynamic_scale_for_fsdp', False)
+ self.model_name = parsed_args['model'].get('name', "llama3")
+ model_flavor = parsed_args['model'].get('flavor', "8B")
+ try:
+ # pylint: disable=C0415
+ import torchtitan.protocols.train_spec as train_spec_module
+ train_spec = train_spec_module.get_train_spec(self.model_name)
+ model_args = train_spec.model_args[model_flavor]
+
+ self.num_layers = model_args.n_layers
+ self.qk_pos_emb_head_dim = getattr(model_args, 'qk_rope_head_dim', 64)
+ self.qk_head_dim = getattr(model_args, 'qk_nope_head_dim', 128)
+ self.q_lora_rank = getattr(model_args, 'q_lora_rank', None)
+ self.kv_lora_rank = getattr(model_args, 'kv_lora_rank', None)
+ self.v_head_dim = getattr(model_args, 'v_head_dim', None)
+ self.expert_num = getattr(getattr(model_args, 'moe_args', None), 'num_experts', None)
+ self.moe_intermediate_size = getattr(model_args, 'moe_inter_dim', None)
+ self.hidden_size = getattr(model_args, 'dim', None)
+ self.ffn_hidden_size = getattr(model_args, 'inter_dim', 4 * self.hidden_size)
+ self.cal_ffn_hidden_size(model_args)
+ self.num_attention_heads = getattr(model_args, 'n_heads', None)
+ self.padded_vocab_size = getattr(model_args, 'vocab_size', None)
+ self.first_k_dense_replace = getattr(model_args, 'n_dense_layers', 0)
+ self.use_distributed_optimizer = self.dp_shard_degree != 1
+ self.kv_channels = self.hidden_size // self.num_attention_heads
+ n_kv_heads = getattr(model_args, 'n_kv_heads', None)
+ self.group_query_attention = (n_kv_heads is not None and n_kv_heads != self.num_attention_heads)
+ self.num_query_groups = n_kv_heads
+ self.moe_shared_expert_intermediate_size = self.moe_intermediate_size * getattr(
+ getattr(model_args, 'moe_args', None), 'num_shared_experts', 0)
+ self.multi_latent_attention = bool(self.q_lora_rank)
+
+ except Exception as e:
+            logger.warning(f'Failed to derive model args from torchtitan train spec: {e}')
+ self.profile_steps = 1
+
+ def trans_yaml_param(self, para):
+ """
+        Translate yaml parameters into internal parameters.
+ """
+ input_args = InputConfig(para.YAML_PATH)
+ self.dp = input_args.parallel_config.data_parallel
+ self.tp = input_args.parallel_config.model_parallel
+ self.pp = input_args.parallel_config.pipeline_stage
+ self.cp = 1 if input_args.parallel_config.get('context_parallel') is None else input_args.parallel_config.get(
+ 'context_parallel')
+ self.world_size = self.dp * self.tp * self.cp * self.pp
+ self.num_layers = cal_model_layers_num(input_args)
+ self.mbn = input_args.parallel_config.micro_batch_num
+ self.max_mem = input_args.context.max_device_memory
+ self.expert_num = None if input_args.moe_config is None else input_args.context.expert_num
+ self.mf_args = input_args
+
+ def cal_ffn_hidden_size(self, model_args):
+ """
+        Update ffn_hidden_size according to multiple_of / ffn_dim_multiplier.
+        """
+        print(f"origin_ffn_hidden_size {self.ffn_hidden_size}")
+        multiple_of = getattr(model_args, 'multiple_of', None)
+        if multiple_of is None:
+            return
+        ffn_dim_multiplier = getattr(model_args, 'ffn_dim_multiplier', None)
+        if multiple_of and ffn_dim_multiplier:
+            tmp_ffn_hidden_size = int(2 * self.ffn_hidden_size / 3)
+            tmp_ffn_hidden_size = int(ffn_dim_multiplier * tmp_ffn_hidden_size)
+            self.ffn_hidden_size = multiple_of * ((tmp_ffn_hidden_size + multiple_of - 1) // multiple_of)
+        print(f"current_ffn_hidden_size {self.ffn_hidden_size}")
+
+ def trans_shell_param(self, para):
+ """
+        Translate shell parameters into internal parameters.
+ """
+ parsed_args = self.parse_sh_parameters(para.SHELL_PATH)
+ self.world_size = parsed_args['nproc_per_node'] * parsed_args['nnodes']
+ self.gbs = parsed_args.get('global_batch_size')
+ self.mbs = parsed_args.get('micro_batch_size')
+ self.pp = parsed_args.get('pipeline_model_parallel_size', 1)
+ self.tp = parsed_args.get('tensor_model_parallel_size', 1)
+ self.cp = parsed_args.get('context_parallel_size', 1)
+ self.dp = self.world_size // self.pp // self.tp // self.cp
+ self.mbn = self.gbs // self.mbs // self.dp
+ self.num_layers = parsed_args.get('num_layers')
+ self.swiglu = parsed_args.get('swiglu', False)
+ self.qk_pos_emb_head_dim = parsed_args.get('qk_pos_emb_head_dim', 64)
+ self.qk_head_dim = parsed_args.get('qk_head_dim', 128)
+ self.q_lora_rank = parsed_args.get('q_lora_rank')
+ self.kv_lora_rank = parsed_args.get('kv_lora_rank', 32)
+ self.v_head_dim = parsed_args.get('v_head_dim', 128)
+ self.expert_num = parsed_args.get('num_experts')
+ self.hidden_size = parsed_args.get('hidden_size')
+ self.ffn_hidden_size = parsed_args.get('ffn_hidden_size', 4 * self.hidden_size)
+ self.moe_intermediate_size = parsed_args.get('moe_intermediate_size', self.ffn_hidden_size)
+ self.num_attention_heads = parsed_args.get('num_attention_heads')
+ self.padded_vocab_size = parsed_args.get('padded_vocab_size', parsed_args.get('vocab_size'))
+ self.first_k_dense_replace = parsed_args.get('first_k_dense_replace', 0)
+ self.use_distributed_optimizer = parsed_args.get('use_distributed_optimizer', False)
+ self.kv_channels = parsed_args.get('kv_channels', self.hidden_size // self.num_attention_heads)
+ self.group_query_attention = parsed_args.get('group_query_attention', False)
+ self.moe_shared_expert_intermediate_size = parsed_args.get('moe_shared_expert_intermediate_size')
+ self.multi_latent_attention = parsed_args.get('multi_latent_attention', False)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/parallel_tool.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/parallel_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..659df7bb563692b0e894fc2466e98969c6da8b8e
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/parallel_tool.py
@@ -0,0 +1,161 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""entry file"""
+
+import argparse
+import time
+import sys
+import types
+
+from fast_tuner.ndsearch.para_for_nd_search import ParaForNd
+from fast_tuner.ndsearch.memory_model import filter_oom
+from fast_tuner.ndsearch.build_initial_spaces import build_initial_spaces
+from fast_tuner.ndsearch.expert_filter_configs import expert_filter_configs
+
+from fast_tuner.utils.common import check_dryrun_parallel_number
+from fast_tuner.utils.input_param import InputParam
+from fast_tuner.utils.logger import logger
+from fast_tuner.utils.common import parse_args_from_json
+from fast_tuner.utils.ppc_input import ParallelInput
+from fast_tuner.utils.profiling.profile_launch import ProfileLaunch
+from fast_tuner.utils.profiling.profile_parser import ProfileParser, ProfileMemParser
+from fast_tuner.utils.profiling.profile_prepare import profile_prepare
+# TODO: this dependency should not live here
+from fast_tuner.pipeline_conductor import pp_util
+from fast_tuner.pipeline_conductor.pipeline_parallel import pipeline_proc
+
+
+__all__ = ['taylor_search_tool']
+
+def taylor_search_tool(para):
+ """
+    Find the optimal ND parallel configuration.
+
+    Args:
+        para: user-defined input parameters
+
+    Returns:
+        The parallel config [dp, cp, tp, ...] written back to the yaml file, plus a candidate csv file.
+ """
+ start_time = time.time()
+ para.print_params()
+
+ input_args = ParaForNd(para)
+
+ initial_configs = build_initial_spaces(input_args, para)
+ logger.info(f"Initial Search space size: {len(initial_configs)}")
+ expert_prune_search_space = expert_filter_configs(initial_configs, input_args, para.GBS)
+ if len(expert_prune_search_space) == 0:
+ logger.info("expert_prune_search_space is empty. Please check your expert rules.")
+ sys.exit()
+ logger.info(f"Expert Prune Search space size: {len(expert_prune_search_space)}")
+
+ mem_prune_space = filter_oom(expert_prune_search_space, input_args, para)
+ logger.info('%s', '\n'.join(str(item) for item in mem_prune_space))
+
+ end_time = time.time()
+ elapsed_time = end_time - start_time
+ logger.info(f"program before profiling cost time: {elapsed_time} s")
+
+ # profile_configs: list[dp, tp, pp, 0, 0, layers_num, toml/yaml/shell path, profile_result_dir]
+ # profile_file_dir: profile shell/toml file dir
+ profile_configs, profile_file_dir = profile_prepare(mem_prune_space, para, input_args)
+ if para.ALG_PHASE == 1:
+ logger.info(f"ALG_PHASE: {para.ALG_PHASE}, no need to profile and solve pipeline")
+ return
+    # Automatically run profiling
+ profile_launch = ProfileLaunch(profile_configs, para)
+ profile_launch.profile_launch(profile_file_dir)
+    # Automatically parse the profiling results
+ profile_parser = ProfileParser(input_args, para)
+ profile_parser.parse_batch_profile_result(profile_configs)
+
+    # Parse memory usage
+ profile_mem_parser = ProfileMemParser(input_args, para)
+ profile_mem_parser.mem_parser(profile_configs)
+
+    # Pipeline solving. TODO: pass candidate_configs in directly instead of reading them from csv
+ pipeline_input = ParallelInput(para, profile_file_dir)
+ pipeline_proc(pipeline_input)
+
+def main():
+ logger.info('start to run parallel tool')
+ parser = argparse.ArgumentParser(description='Run taylor_search_tool with user input parameters')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['YAML_PATH']}", default='',
+ help='Path to the YAML file')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['SHELL_PATH']}", default='',
+ help='Path to the SHELL type config file')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['TOML_PATH']}", default='',
+ help='Path to the TOML file')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['MINDFORMERS_DIR']}",
+ default='', help='Directory of mindformers')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['MINDSPEED_PATH']}", default='',
+ help='Path to the MindSpeed file')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['TORCHTITAN_PATH']}", default='',
+ help='Path to the Torchtitan file')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['DRYRUN_DATA_DIR']}", default='',
+ help='Directory of dryrun data')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['PROFILE_DATA_DIR']}",
+ default='', help='Directory of profile data')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['NNODES']}", type=int,
+ default=1, help='The number of nodes')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['NPUS_PER_NODE']}", type=int,
+                        default=8, help='Number of NPUs per node')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['DRYRUN_LIM']}", type=int,
+                        default=2, help='Maximum number of dryruns launched at once')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['RANK_NUM']}", type=int,
+                        default=64, help='Number of available devices')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['SOLVER_NAME']}", type=str,
+ default='HIGHS', help='Name of the solver')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['DATASET']}", type=str,
+ default='', help='Directory of dataset')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['MAX_EXPERT_PARALLEL']}", type=int,
+ default=64, help='Max number of expert parallel')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['OUTPUT_PATH']}", type=str,
+ default='./output/', help='Directory of output info')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['ENV_JSON']}", type=str,
+ default='', help='Environment variable config json file')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['GBS']}", type=int,
+ default=1024, help='Global batch size')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['CONFIG']}", type=str,
+ help='Path to the JSON config file')
+    parser.add_argument(f"--{InputParam.PARAM_MAPPING['SELECT_RECOMPUTE']}", type=pp_util.str2bool, default=True,
+                        help='Whether to search select recompute')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['ALG_PHASE']}", type=int, default=0,
+ help='Phase of parallel strategy search algorithm')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['PARSER_RESULT']}", type=str, default=None,
+ help='Profiling parser result file')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['DRYRUN']}", type=pp_util.str2bool, default=False,
+                        help='Whether to run dryrun automatically')
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['CHECK']}", type=pp_util.str2bool, default=False,
+                        help="Whether to double-check the result")
+ parser.add_argument(f"--{InputParam.PARAM_MAPPING['STRATEGY']}", type=pp_util.str2dict,
+ default={'DP':True, 'TP':True, 'EP':True, 'FSDP':True, 'PP':False, 'CP':False},
+ help="Which parallel strategies are enabled")
+
+ args = parser.parse_args()
+    # for test:
+    # args.config = './config/setup_config/args_for_parallel_tool_titan.json'
+ if args.config:
+ parse_args_from_json(args)
+ args.strategy = types.SimpleNamespace(**args.strategy)
+ check_dryrun_parallel_number(args.dryrun_lim)
+
+ para = InputParam(args)
+ taylor_search_tool(para)
+
+if __name__ == '__main__':
+ main()
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/dryrun.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/dryrun.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcd4c1e6427dc90324ac316a0b553697fd388880
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/dryrun.py
@@ -0,0 +1,204 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""dryrun operation"""
+import os
+import shutil
+import json
+import argparse
+import subprocess
+from fast_tuner.utils.logger import logger
+from fast_tuner.pipeline_conductor import pp_util
+from multiprocessing import Pool
+from fast_tuner.pipeline_conductor.pp_util import pipeline_output_file
+
+dryrun_config_error = 'The config_file location and ms_adapter_file location are required, please configure them!'
+
+
+class DryRun:
+ """dryrun operation"""
+ env_config_json = ''
+ register_path = 'research/jiutian'
+ dryrun_lim = 16
+ config_file_type = 0
+ is_write_to_file = True
+
+ def __init__(self, config_file, ms_adapter_file, output_name):
+ self.config_file = config_file
+ self.ms_adapter_file = ms_adapter_file
+ self.rank_gap = None
+ pp_output_file = os.path.join(os.getcwd(), pipeline_output_file)
+ if not os.path.exists(pp_output_file):
+ os.mkdir(pp_output_file)
+ self.log_file_name = os.path.join(pp_output_file, output_name)
+ if os.path.exists(self.log_file_name):
+ shutil.rmtree(self.log_file_name)
+ os.mkdir(self.log_file_name)
+
+ def dryrun(self, stage_num, rank_size):
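+        """Dryrun every pipeline stage once (one process per stage), at most `dryrun_lim` processes at a time."""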
+ self.rank_gap = rank_size // stage_num
+ self.set_env(rank_size, self.env_config_json)
+ remainder = stage_num
+ device_id = 0
+ while remainder > 0:
+ if remainder > self.dryrun_lim:
+ dryrun_num = self.dryrun_lim
+ else:
+ dryrun_num = remainder
+ remainder -= dryrun_num
+ with Pool(processes=dryrun_num) as pool:
+ pool.map(self.run_rank, range(device_id, device_id + dryrun_num))
+ device_id += dryrun_num
+        logger.info('finished launching dryrun for all stages!')
+
+ def set_env(self, rank_size, env_config_json):
+ with open(env_config_json, 'r', encoding='utf-8') as f:
+ env_vars = json.load(f)
+ env_vars['RANK_SIZE'] = str(rank_size)
+ os.environ.update(env_vars)
+
+ def start_dryrun(self, recompute_config, offset, num_layers, num_vpp, num_stage, rank_size, dense_layers, micro):
+ if self.config_file_type == 0:
+ name = pp_util.bulid_yaml(self.config_file, recompute_config, offset,
+ num_layers, num_vpp, num_stage, dense_layers, micro)
+ elif self.config_file_type == 1:
+ name = pp_util.bulid_shell(self.config_file, offset, num_layers, num_vpp, num_stage, dense_layers, micro)
+ else:
+ raise TypeError(dryrun_config_error)
+ self.config_file = name
+ self.dryrun(num_stage, rank_size)
+
+ def run_rank(self, stage):
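+        """Dryrun a single rank (the first rank of the given stage) and redirect its output to a log file."""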
+ device_id = stage
+ rank_id = stage * self.rank_gap
+ cwd = os.getcwd()
+ log_file = os.path.join(cwd, self.log_file_name, f'rank_{rank_id}.log')
+        logger.info(f"start dryrun for rank_{rank_id} on device_{device_id}, this may take a moment...")
+ if self.config_file_type == 0:
+ os.environ['ASCEND_RT_VISIBLE_DEVICES'] = str(device_id)
+ os.environ['RANK_ID'] = str(rank_id)
+ command = ['python', self.ms_adapter_file, '--register_path',
+ self.register_path, '--config', self.config_file]
+ with open(log_file, 'w', encoding='utf-8') as log:
+ subprocess.run(command, stdout=log, stderr=subprocess.STDOUT)
+ elif self.config_file_type == 1:
+ env = os.environ.copy()
+ env['RANK_ID'] = str(rank_id)
+ command = ['bash', self.config_file, str(device_id), self.ms_adapter_file, log_file]
+ subprocess.run(command, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ else:
+ raise TypeError(dryrun_config_error)
+
+ def extract_memory_info(self, num_stage):
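+        """Collect the peak memory reported in each stage's rank log (one value per stage)."""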
+ cwd = os.getcwd()
+ peak_mem = []
+ for stage in range(num_stage):
+ rank_id = self.rank_gap * stage
+ log_file = os.path.join(cwd, self.log_file_name, f"rank_{rank_id}.log")
+ peak_mem.append(pp_util.extract_peak_memory(log_file))
+ return peak_mem
+
+ def extract_memory_info_act(self, num_stage):
+ cwd = os.getcwd()
+ peak_mem = []
+ for stage in range(num_stage):
+ rank_id = self.rank_gap * stage
+ log_file = os.path.join(
+ cwd, self.log_file_name, f"rank_{rank_id}.log")
+ peak_mem.append(pp_util.extract_actual_peak_memory(log_file))
+ return peak_mem
+
+
+def one_rank_dryrun(stage, yaml_file, mindformer_file, output_file):
+ dry_run = DryRun(yaml_file, mindformer_file, output_file)
+ rank_size, pipeline_stage = pp_util.get_ranks_stages(yaml_file)
+ dry_run.rank_gap = rank_size // pipeline_stage
+ dry_run.set_env(rank_size, dry_run.env_config_json)
+ dry_run.run_rank(stage)
+
+
+def all_rank_dryrun(config_file, ms_adapter_file, output_file):
+ dry_run = DryRun(config_file, ms_adapter_file, output_file)
+ if DryRun.config_file_type == 0:
+ rank_size, pipeline_stage = pp_util.get_ranks_stages(config_file)
+ elif DryRun.config_file_type == 1:
+ rank_size, pipeline_stage = pp_util.get_shell_ranks_stages(config_file)
+ else:
+ raise TypeError(dryrun_config_error)
+ dry_run.dryrun(pipeline_stage, rank_size)
+ print(dry_run.extract_memory_info(pipeline_stage))
+ print(dry_run.extract_memory_info_act(pipeline_stage))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(prog='Yaml config dryrun',
+ description='Write config to the yaml file, and dryrun it', epilog='')
+ parser.add_argument('--yaml', '-y', type=str, default=None,
+ help="Path of training config (.yaml)")
+ parser.add_argument('--shell', '-sh', type=str, default=None,
+ help="Path of training config (.sh)")
+ parser.add_argument('--mindformers', '-mf', type=str, default=None,
+ help="Absolute path of run_mindformers (.py)")
+ parser.add_argument('--mindspeed', '-mp', type=str, default=None,
+ help="Absolute path of posttrain_gpt (.py)")
+ parser.add_argument('--output_file', '-f', type=str, default='dryrun_output',
+ help="The location to place the output files")
+ parser.add_argument('--offset', '-o', type=pp_util.str2list,
+ default=None, help="offset list")
+ parser.add_argument('--is_recompute', '-ir', type=pp_util.str2bool,
+ default=None, help="Whether to open recompute")
+ parser.add_argument('--recompute_layers', '-rl', type=pp_util.str2list,
+ default=None, help="recompute_layers list")
+ parser.add_argument('--is_select_recompute', '-is', type=pp_util.str2bool,
+ default=None, help="Whether to open select_recompute")
+ parser.add_argument('--select_recompute_layers', '-sl', type=pp_util.str2list,
+ default=None, help="select_recompute_layers list")
+ parser.add_argument('--env_config_json', '-e', type=str, required=True,
+ default='./config/boss_env_config.json', help="Path of environment config (.json)")
+ parser.add_argument('--register_path', '-rp', type=str, default='research/jiutian',
+ help="Path of register")
+ parser.add_argument('--dryrun_lim', '-dl', type=pp_util.str2int, default=16,
+ help="The number of dryrun at once")
+ args = parser.parse_args()
+
+ if args.yaml and args.mindformers:
+ config_file = args.yaml
+ ms_adapter_file = args.mindformers
+ DryRun.config_file_type = 0
+ elif args.shell and args.mindspeed:
+ config_file = args.shell
+ ms_adapter_file = args.mindspeed
+ DryRun.config_file_type = 1
+ else:
+ raise TypeError(dryrun_config_error)
+
+ output_file = args.output_file
+ DryRun.env_config_json = args.env_config_json
+ DryRun.register_path = args.register_path
+ DryRun.dryrun_lim = args.dryrun_lim
+ if args.recompute_layers and args.is_recompute is None:
+ args.is_recompute = True
+ if args.select_recompute_layers and args.is_select_recompute is None:
+ args.is_select_recompute = True
+
+ if args.offset is None and args.is_select_recompute is None and args.is_recompute is None:
+ logger.info('Use old yaml config to dryrun')
+ elif DryRun.config_file_type == 0:
+ config_file = pp_util.build_new_config_yaml(args)
+ elif DryRun.config_file_type == 1:
+ config_file = pp_util.build_new_config_shell(args)
+ else:
+ raise TypeError(dryrun_config_error)
+
+ all_rank_dryrun(config_file, ms_adapter_file, output_file)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/fitting.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/fitting.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c3268c2740c05dcd655fc4383d985132497211f
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/fitting.py
@@ -0,0 +1,242 @@
+"""fitting file"""
+import numpy as np
+from fast_tuner.utils.logger import logger
+from fast_tuner.pipeline_conductor import solution
+from fast_tuner.pipeline_conductor import dryrun
+from fast_tuner.pipeline_conductor import pp_util
+
+
+class FitMem:
+ cur_peak_mem = []
+
+ def __init__(self, cur_solution: solution.Solution):
+ self.cur_solution = cur_solution
+ self.cur_peak_mem = cur_solution.check_peak_mem
+ self.peaks = cur_solution.peak_num
+ self.init_config = cur_solution.init_config
+ self.init_memory = cur_solution.init_config.memory
+
+ def linear_fit(self):
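+        """
+        Fit the per-component memory model by least squares: the number of unknowns depends on whether
+        select/full recompute are enabled, one equation is formed per pipeline stage from the observed
+        peak memory, and the fitted values replace the initial estimates when the residual and rank
+        checks pass.
+        """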
+ if self.init_config.expert_input.is_select_recomp and self.init_config.expert_input.is_full_recomp:
+ length_x_lim = 11
+ elif not self.init_config.expert_input.is_select_recomp and not self.init_config.expert_input.is_full_recomp:
+ length_x_lim = 7
+ else:
+ length_x_lim = 9
+ if self.init_config.pipeline_stage < length_x_lim:
+ length_x = self.init_config.pipeline_stage
+ else:
+ length_x = length_x_lim
+ if length_x < 4:
+ logger.warning(f'can not fit for pipeline_stage = {self.init_config.pipeline_stage}')
+ return
+ coe_a, array_b = self.form_coe_matrix_mem_array(length_x)
+        np.set_printoptions(suppress=True, precision=3)  # disable scientific notation, keep 3 decimal places
+ modified_mem, res, rank, s = np.linalg.lstsq(coe_a, array_b, rcond=None)
+ logger.info(f'the residual = {np.linalg.norm(res)}')
+ logger.info(f'the rank = {rank}')
+ if np.linalg.norm(res) > 1e-3 or rank < self.init_config.pipeline_stage:
+            logger.warning('The fit is not reliable enough to correct the memory model!')
+ else:
+ modified_mem = list(np.round(np.array(modified_mem), decimals=1))
+ self.correct_mem(modified_mem)
+
+ def correct_mem(self, modify_mem):
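+        """Write the fitted memory values back into the memory model; the fields updated depend on how many unknowns were fitted."""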
+ self.init_memory.static_mem0 = modify_mem[0]
+ self.init_memory.static_mem = modify_mem[1]
+ self.init_memory.lm_head_mem = modify_mem[2]
+ self.init_memory.act_mem = modify_mem[3]
+
+ if len(modify_mem) == 5:
+ self.init_memory.layer_mem = modify_mem[4]
+ if len(modify_mem) == 6:
+ self.init_memory.layer_mem = modify_mem[4]
+ self.init_memory.re_comp_mem = modify_mem[5]
+ if len(modify_mem) == 8:
+ self.init_memory.layer_mem = modify_mem[4]
+ self.init_memory.re_comp_mem = modify_mem[5]
+ self.init_memory.layer_mem012 = modify_mem[6]
+ self.init_memory.act_mem12 = modify_mem[7]
+ self.init_memory.act_mem0 = modify_mem[7]
+ if self.init_config.expert_input.is_select_recomp:
+ if len(modify_mem) == 9:
+ self.init_memory.layer_mem = modify_mem[4]
+ self.init_memory.re_comp_mem = modify_mem[5]
+ self.init_memory.layer_mem012 = modify_mem[6]
+ self.init_memory.act_mem12 = modify_mem[7]
+ self.init_memory.re_comp_mem12 = modify_mem[8]
+ self.init_memory.re_comp_mem0 = modify_mem[8]
+ if len(modify_mem) == 10:
+ self.init_memory.layer_mem = modify_mem[4]
+ self.init_memory.re_comp_mem = modify_mem[5]
+ self.init_memory.layer_mem012 = modify_mem[6]
+ self.init_memory.act_mem12 = modify_mem[7]
+ self.init_memory.re_comp_mem12 = modify_mem[8]
+ self.init_memory.re_comp_mem0 = modify_mem[8]
+ self.init_memory.select_mem = modify_mem[9]
+ if len(modify_mem) > 10:
+ self.init_memory.layer_mem = modify_mem[4]
+ self.init_memory.re_comp_mem = modify_mem[5]
+ self.init_memory.layer_mem012 = modify_mem[6]
+ self.init_memory.act_mem12 = modify_mem[7]
+ self.init_memory.re_comp_mem12 = modify_mem[8]
+ self.init_memory.re_comp_mem0 = modify_mem[8]
+ self.init_memory.select_mem = modify_mem[9]
+ self.init_memory.select_mem12 = modify_mem[10]
+ self.init_memory.select_mem0 = self.init_memory.select_mem12
+ else:
+ if len(modify_mem) >= 9:
+ self.init_memory.layer_mem = modify_mem[4]
+ self.init_memory.re_comp_mem = modify_mem[5]
+ self.init_memory.layer_mem012 = modify_mem[6]
+ self.init_memory.act_mem12 = modify_mem[7]
+ self.init_memory.re_comp_mem12 = modify_mem[8]
+ self.init_memory.re_comp_mem0 = modify_mem[8]
+ self.init_memory.update_up_mem()
+ logger.info('The correct memory information:')
+ self.init_memory.print_mem()
+
+ def form_coe_matrix_mem_array(self, length_x):
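+        """
+        Build the least-squares system: one coefficient-matrix row per pipeline stage, with the
+        right-hand side derived from the observed peak memory after moving the terms that are not
+        being fitted over to it.
+        """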
+ coe_a = np.empty((self.init_config.pipeline_stage, length_x), float)
+ array_b = np.empty(self.init_config.pipeline_stage, float)
+ for stage in range(self.init_config.pipeline_stage):
+ if stage == 0:
+ coe_a[stage][0] = 1
+ coe_a[stage][1] = 0
+ coe_a[stage][2] = 0
+ elif stage == self.init_config.pipeline_stage - 1:
+ coe_a[stage][0] = 0
+ coe_a[stage][1] = 0
+ coe_a[stage][2] = 1
+ else:
+ coe_a[stage][0] = 0
+ coe_a[stage][1] = 1
+ coe_a[stage][2] = 0
+ coe_a[stage][3] = self.cur_solution.peak_num.peak_num_act_type2[stage]
+ if length_x == 4:
+ array_b[stage] = (self.cur_peak_mem[stage] - self.cur_solution.layer2_dis_stage[stage] *
+ self.init_memory.layer_mem -
+ self.peaks.peak_num_recompute_type2[stage] * self.init_memory.re_comp_mem -
+ self.cur_solution.layer1_dis_stage[stage] * self.init_memory.layer_mem012 -
+ self.peaks.peak_num_act_type1[stage] * self.init_memory.act_mem12 -
+ self.peaks.peak_num_recompute_type1[stage] * self.init_memory.re_comp_mem12 -
+ self.peaks.peak_num_select_recom_type2[stage] * self.init_memory.select_mem -
+                                 self.peaks.peak_num_select_recom_type1[stage] * self.init_memory.select_mem12)
+ continue
+ if length_x == 6:
+ coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage]
+ coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage]
+ array_b[stage] = (self.cur_peak_mem[stage] -
+ self.cur_solution.layer1_dis_stage[stage] * self.init_memory.layer_mem012 -
+ self.peaks.peak_num_act_type1[stage] * self.init_memory.act_mem12 -
+ self.peaks.peak_num_recompute_type1[stage] * self.init_memory.re_comp_mem12 -
+ self.peaks.peak_num_select_recom_type2[stage] * self.init_memory.select_mem -
+                                   self.peaks.peak_num_select_recom_type1[stage] * self.init_memory.select_mem12)
+ continue
+ if length_x == 8:
+ coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage]
+ coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage]
+ coe_a[stage][6] = self.cur_solution.layer1_dis_stage[stage]
+ coe_a[stage][7] = self.peaks.peak_num_act_type1[stage]
+ array_b[stage] = (self.cur_peak_mem[stage] - self.peaks.peak_num_recompute_type1[stage] *
+ self.init_memory.re_comp_mem12 - self.peaks.peak_num_select_recom_type2[stage] *
+ self.init_memory.select_mem -
+ self.peaks.peak_num_select_recom_type1[stage] * self.init_memory.select_mem12)
+ continue
+ if length_x == 9:
+ coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage]
+ coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage]
+ coe_a[stage][6] = self.cur_solution.layer1_dis_stage[stage]
+ coe_a[stage][7] = self.peaks.peak_num_act_type1[stage]
+ coe_a[stage][8] = self.peaks.peak_num_recompute_type1[stage]
+ array_b[stage] = (self.cur_peak_mem[stage] - self.peaks.peak_num_select_recom_type2[stage] *
+ self.init_memory.select_mem -
+                                   self.peaks.peak_num_select_recom_type1[stage] * self.init_memory.select_mem12)
+ continue
+ if self.init_config.expert_input.is_select_recomp:
+ if length_x == 10:
+ coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage]
+ coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage]
+ coe_a[stage][6] = self.cur_solution.layer1_dis_stage[stage]
+ coe_a[stage][7] = self.peaks.peak_num_act_type1[stage]
+ coe_a[stage][8] = self.peaks.peak_num_recompute_type1[stage]
+ coe_a[stage][9] = self.peaks.peak_num_select_recom_type2[stage]
+ array_b[stage] = (self.cur_peak_mem[stage] -
+                                       self.peaks.peak_num_select_recom_type1[stage] * self.init_memory.select_mem12)
+ continue
+ if length_x >= 11:
+ coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage]
+ coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage]
+ coe_a[stage][6] = self.cur_solution.layer1_dis_stage[stage]
+ coe_a[stage][7] = self.peaks.peak_num_act_type1[stage]
+ coe_a[stage][8] = self.peaks.peak_num_recompute_type1[stage]
+ coe_a[stage][9] = self.peaks.peak_num_select_recom_type2[stage]
+ coe_a[stage][10] = self.peaks.peak_num_select_recom_type1[stage]
+ array_b[stage] = self.cur_peak_mem[stage]
+ continue
+ else:
+ if length_x > 9:
+ coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage]
+ coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage]
+ coe_a[stage][6] = self.cur_solution.layer1_dis_stage[stage]
+ coe_a[stage][7] = self.peaks.peak_num_act_type1[stage]
+ coe_a[stage][8] = self.peaks.peak_num_recompute_type1[stage]
+ array_b[stage] = self.cur_peak_mem[stage]
+ continue
+ return coe_a, array_b
+
+ def is_over_mem(self):
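+        """
+        Dryrun the current solution; if no stage exceeds the memory limit, write the config back to the
+        yaml/shell file (when DryRun.is_write_to_file is set) and return (over_mem, False), otherwise
+        return (over_mem, True).
+        """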
+ over_mem = self.get_over_mem(self.cur_solution)
+ if all(over_mem[stage] == 0 for stage in range(len(over_mem))):
+ if dryrun.DryRun.is_write_to_file:
+ if dryrun.DryRun.config_file_type == 0:
+                    recompute_config = pp_util.build_recompute_config(
+                        True, True, self.cur_solution.rs_dis.tolist(), self.cur_solution.ra_dis.tolist())
+                    pp_util.write_config_to_yaml(
+                        recompute_config, self.cur_solution.offset.tolist(), self.init_config.config_file)
+                elif dryrun.DryRun.config_file_type == 1:
+                    pp_util.write_config_to_shell(
+                        self.cur_solution.offset.tolist(), self.init_config.config_file)
+ else:
+ raise TypeError(dryrun.dryrun_config_error)
+                logger.info(f'The result is available for training, the config has been written to '
+                            f'{self.init_config.config_file}!')
+ return over_mem, False
+
+ return over_mem, True
+
+ def reduce_mem_lim_for_fitting(self, over_mem, i):
+        if 0 in over_mem:
+            self.init_config.memory.mem_lim_stage0 -= over_mem[0] * (i + 1)
+        if self.init_config.pipeline_stage - 1 in over_mem:
+            self.init_config.memory.mem_lim_last -= over_mem[self.init_config.pipeline_stage - 1] * (i + 1)
+        over_mem = {key: value for key, value in over_mem.items() if key != 0 and
+                    key != self.init_config.pipeline_stage - 1}
+ if over_mem:
+ self.init_config.memory.mem_lim_others -= max(over_mem.values()) * (i + 1)
+
+ def set_peak_mem_by_dryrun(self, cur_solution: solution.Solution):
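+        """Dryrun the candidate solution and record the actual peak memory of every pipeline stage."""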
+ init_config = cur_solution.init_config
+ expert_input = init_config.expert_input
+ dry_run = dryrun.DryRun(expert_input.config_file, expert_input.ms_adapter_file,
+ expert_input.double_check_dryrun_filename)
+        recompute_config = pp_util.build_recompute_config(
+            True, True, cur_solution.rs_dis.tolist(), cur_solution.ra_dis.tolist())
+ total_number = init_config.num_layers_type1 + init_config.num_layers_type2
+ dry_run.start_dryrun(recompute_config, cur_solution.offset.tolist(), total_number,
+ init_config.pp_interleave_num,
+ init_config.pipeline_stage, init_config.rank_size, init_config.num_layers_type1,
+ init_config.micro_batch_num)
+ cur_solution.check_peak_mem = dry_run.extract_memory_info_act(init_config.pipeline_stage)
+ self.cur_peak_mem = cur_solution.check_peak_mem
+ logger.info(f'check peak_mem = {cur_solution.check_peak_mem}')
+
+ def get_over_mem(self, cur_solution: solution.Solution):
+ self.set_peak_mem_by_dryrun(cur_solution)
+ over_mem = {}
+ init_config = cur_solution.init_config
+ for stage in range(init_config.pipeline_stage):
+ if cur_solution.check_peak_mem[stage] - init_config.mem_lim > 0:
+                logger.warning(f'The stage {stage} result exceeds the memory limit! Please check the offset!')
+ over_mem[stage] = cur_solution.check_peak_mem[stage] - init_config.mem_lim
+ else:
+ over_mem[stage] = 0
+ return over_mem
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/math_model.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/math_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..57b3db0e342641c168e88197a489f473e69100a9
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/math_model.py
@@ -0,0 +1,319 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""math model construct for pp"""
+import numpy as np
+import sys
+from ortools.linear_solver import pywraplp
+from fast_tuner.utils.logger import logger
+
+from fast_tuner.pipeline_conductor import pp_util
+from fast_tuner.pipeline_conductor.start_service import InitConfig, HIGHS_NAME
+from fast_tuner.pipeline_conductor import micro
+
+
+class Model:
+ def __init__(self, model_input: InitConfig):
+        # Initialize inputs
+ self.input = model_input
+ self.memory = model_input.memory
+ self.expert_input = model_input.expert_input
+        # Per-(stage, interleave) layer counts and forward/backward compute times
+ self.x_type1 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+ self.x_type2 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+ self.offset = None
+ self.f_duration = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+ self.b_duration = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+        # Whether the current (stage, interleave) contains a layer of each type
+ self.indicator_type1 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+ self.indicator_type2 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+        # Select recompute, full recompute, and memory variables
+ self.rs_type1 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+ self.rs_type2 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+ self.ra_type1 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+ self.ra_type2 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+ self.rs_or_ra = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+ self.mem = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object)
+        # Forward/backward start times and the distribution of micro batches across parts
+ magic_number = 5
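+        # When micro_batch_num per stage is large, only `magic_number` rounds of micro batches are
+        # modelled explicitly; the remaining rounds (self.residue) are accounted for through the
+        # steady-state term residue * stable_dur in the objective.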
+ if self.input.micro_batch_num // self.input.pipeline_stage >= magic_number + 1:
+ dummy_mbn = self.input.micro_batch_num % self.input.pipeline_stage + magic_number * self.input.pipeline_stage
+ self.residue = self.input.micro_batch_num // self.input.pipeline_stage - magic_number
+ temp_parts = magic_number
+ self.input.parts = magic_number
+ else:
+ dummy_mbn = self.input.micro_batch_num
+ self.residue = 0
+ temp_parts = model_input.parts
+ self.input.parts = temp_parts
+ self.forward_s = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num, self.input.parts,
+ dummy_mbn, self.input.seq_splits), dtype=object)
+ self.backward_s = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num, self.input.parts,
+ dummy_mbn, self.input.seq_splits), dtype=object)
+ self.distribution = pp_util.construct_distribution(dummy_mbn, self.input.pipeline_stage)
+        # Build the ordered list of forward/backward tasks
+ self.sort_micro = micro.SortMicro(temp_parts, model_input.pp_interleave_num, model_input.pipeline_stage,
+ self.distribution, self.expert_input.low_mem, model_input.seq_splits)
+ self.final_orders = self.sort_micro.final_orders
+        # Solver initialization
+ self.solver = pywraplp.Solver.CreateSolver(model_input.expert_input.solver_name)
+ if not self.solver:
+ self.solver = pywraplp.Solver.CreateSolver(HIGHS_NAME)
+ self.layer_upper = 20
+ self.b_duration_upper = self.layer_upper * 3
+ self.mem_upper = 60000
+ self.time_upper = 99999
+ self.min_time = None
+ self.gap = None
+ self.model_status = None
+ self.solution_status = None
+ self.run_time = None
+
+ def define_variables(self):
+        self.stable_dur = self.solver.NumVar(0, self.time_upper, 'stable_duration')
+ for stage in range(self.input.pipeline_stage):
+ for vpp in range(self.input.pp_interleave_num):
+ self.x_type1[stage][vpp] = self.solver.IntVar(0, self.layer_upper, f'x_type1_{stage}_{vpp}')
+ self.x_type2[stage][vpp] = self.solver.IntVar(0, self.layer_upper, f'x_type2_{stage}_{vpp}')
+ self.indicator_type1[stage][vpp] = self.solver.BoolVar(f'indicator1_{stage}_{vpp}')
+ self.indicator_type2[stage][vpp] = self.solver.BoolVar(f'indicator2_{stage}_{vpp}')
+ self.f_duration[stage][vpp] = self.solver.NumVar(0, self.layer_upper, f'f_duration_{stage}_{vpp}')
+ self.b_duration[stage][vpp] = self.solver.NumVar(0, self.b_duration_upper, f'b_duration_{stage}_{vpp}')
+ if self.input.expert_input.is_select_recomp:
+ self.rs_type1[stage][vpp] = self.solver.IntVar(
+ 0, self.layer_upper, f'rs_type1_{stage}_{vpp}')
+ self.rs_type2[stage][vpp] = self.solver.IntVar(
+ 0, self.layer_upper, f'rs_type2_{stage}_{vpp}')
+ else:
+ self.rs_type1[stage][vpp] = self.solver.IntVar(0, 0, f'rs_type1_{stage}_{vpp}')
+ self.rs_type2[stage][vpp] = self.solver.IntVar(0, 0, f'rs_type2_{stage}_{vpp}')
+ if self.input.expert_input.is_full_recomp:
+ self.ra_type1[stage][vpp] = self.solver.IntVar(0, self.layer_upper, f'ra_type1_{stage}_{vpp}')
+ self.ra_type2[stage][vpp] = self.solver.IntVar(0, self.layer_upper, f'ra_type2_{stage}_{vpp}')
+ else:
+ self.ra_type1[stage][vpp] = self.solver.IntVar(0, 0, f'ra_type1_{stage}_{vpp}')
+ self.ra_type2[stage][vpp] = self.solver.IntVar(0, 0, f'ra_type2_{stage}_{vpp}')
+ if not self.expert_input.is_support_ra_with_rs:
+ self.rs_or_ra[stage][vpp] = self.solver.IntVar(0, 1, f'rs_or_ra_{stage}_{vpp}')
+ self.mem[stage][vpp] = self.solver.NumVar(0, self.mem_upper, f'mem_{stage}_{vpp}')
+
+ for stage in range(self.input.pipeline_stage):
+ for vpp in range(self.input.pp_interleave_num):
+ for part in range(self.input.parts):
+ for micro_id in range(self.distribution[part]):
+ for split in range(self.input.seq_splits):
+ self.forward_s[stage][vpp][part][micro_id][split] = (
+ self.solver.NumVar(0, self.time_upper,
+ f'forward_s_{stage}_{vpp}_{part}_{micro_id}_{split}'))
+ self.backward_s[stage][vpp][part][micro_id][split] = (
+ self.solver.NumVar(0, self.time_upper,
+ f'backward_s{stage}_{vpp}_{part}_{micro_id}_{split}'))
+ logger.info(f'Number of variables = {self.solver.NumVariables()}')
+
+ def define_constraint(self):
+        # Layer allocation constraints: type1 layers are placed before type2 layers
+ self.layer_alloc_constraints()
+
+        # Forward/backward compute time constraints
+ self.forward_backward_time_constraints()
+
+        # Ordering constraints between micro batches on the same stage
+        self.same_stage_micro_constraints()
+        # Dependency constraints across stages for the same micro batch
+        self.same_micro_stage_constraints()
+
+        # Memory constraints
+        self.mem_constraints()
+        # Layer count constraints
+        self.layers_constraints()
+ if self.input.parts > 1:
+ for s in range(self.input.pipeline_stage):
+ self.solver.Add(self.stable_dur >= self.forward_s[s][0][-1][0][0] - self.forward_s[s][0][-2][0][0])
+ logger.info(f"Number of constraints = {self.solver.NumConstraints()}")
+
+ def layer_alloc_constraints(self):
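+        """
+        Link the layer counts to the indicator variables and force type1 layers to occupy a prefix of
+        the (stage, interleave) order and type2 layers a suffix.
+        """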
+ for vpp in range(self.input.pp_interleave_num):
+ for stage in range(self.input.pipeline_stage):
+ self.solver.Add(self.x_type1[stage][vpp] <= self.layer_upper * self.indicator_type1[stage][vpp])
+ self.solver.Add(self.x_type1[stage][vpp] >= self.indicator_type1[stage][vpp])
+ self.solver.Add(self.x_type2[stage][vpp] <= self.layer_upper * self.indicator_type2[stage][vpp])
+ self.solver.Add(self.x_type2[stage][vpp] >= self.indicator_type2[stage][vpp])
+ if stage < self.input.pipeline_stage - 1:
+ self.solver.Add(self.indicator_type1[stage][vpp] >= self.indicator_type1[stage + 1][vpp])
+ self.solver.Add(self.indicator_type2[stage][vpp] <= self.indicator_type2[stage + 1][vpp])
+ if vpp < self.input.pp_interleave_num - 1:
+ self.solver.Add(
+ self.indicator_type1[self.input.pipeline_stage - 1][vpp] >= self.indicator_type1[0][vpp + 1])
+ self.solver.Add(
+ self.indicator_type2[self.input.pipeline_stage - 1][vpp] <= self.indicator_type2[0][vpp + 1])
+
+ def forward_backward_time_constraints(self):
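+        """
+        Express the forward time of each (stage, vpp) in units of one type2 layer (type1 layers scaled
+        by layer_ratio, plus a head/loss term on the last stage) and derive the backward time from the
+        backward and recompute ratios.
+        """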
+ if self.input.expert_input.is_head_loss_input:
+ head_loss = self.input.expert_input.head_loss
+ else:
+ try:
+                head_loss = ((self.input.vocab_size / 2 / self.input.hidden_size) /
+                             (1 + 1 + 3 * self.input.intermediate_size / 2 / self.input.hidden_size +
+                              self.input.seq_length / self.input.hidden_size) * 1.6)
+ except TypeError:
+ head_loss = 1
+ for stage in range(self.input.pipeline_stage):
+ for vpp in range(self.input.pp_interleave_num):
+ # self.expert_input.layer_ratio is an integer now (used to be a list of integers)
+ if stage == self.input.pipeline_stage - 1 and vpp == self.input.pp_interleave_num - 1:
+ self.solver.Add(self.f_duration[stage][vpp] == self.x_type2[stage][vpp] +
+ self.x_type1[stage][vpp] * self.expert_input.layer_ratio + head_loss)
+ else:
+ self.solver.Add(self.f_duration[stage][vpp] == self.x_type2[stage][vpp] +
+ self.x_type1[stage][vpp] * self.expert_input.layer_ratio)
+ self.solver.Add(self.b_duration[stage][vpp] == self.expert_input.backward_ratio *
+ self.f_duration[stage][vpp] + self.expert_input.srRatio *
+ (self.rs_type2[stage][vpp] + self.rs_type1[stage][vpp] * self.expert_input.layer_ratio)
+ + self.expert_input.recompute_ratio * self.ra_type2[stage][vpp] +
+ self.ra_type1[stage][vpp] * self.expert_input.layer_ratio)
+
+ def same_stage_micro_constraints(self):
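+        """Enforce the execution order of consecutive forward/backward tasks scheduled on the same stage."""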
+ for stage in range(self.input.pipeline_stage):
+ stage_order = self.final_orders[stage]
+ for i in range(len(stage_order) - 1):
+ p0, vpp0, state0, id0, split0 = (stage_order[i].part, stage_order[i].vpp, stage_order[i].state,
+ stage_order[i].micro_id, stage_order[i].split)
+                p1, vpp1, state1, id1, split1 = (stage_order[i + 1].part, stage_order[i + 1].vpp,
+                                                 stage_order[i + 1].state, stage_order[i + 1].micro_id,
+                                                 stage_order[i + 1].split)
+ if state0 == 'f':
+ if state1 == 'f':
+ self.solver.Add(self.forward_s[stage][vpp0][p0][id0][split0] + self.f_duration[stage][vpp0] /
+ self.input.seq_splits <= self.forward_s[stage][vpp1][p1][id1][split1])
+ else:
+ self.solver.Add(self.forward_s[stage][vpp0][p0][id0][split0] + self.f_duration[stage][vpp0] /
+ self.input.seq_splits <= self.backward_s[stage][vpp1][p1][id1][split1])
+ else:
+ if state1 == 'f':
+ self.solver.Add(self.backward_s[stage][vpp0][p0][id0][split0] + self.b_duration[stage][vpp0] /
+ self.input.seq_splits <= self.forward_s[stage][vpp1][p1][id1][split1])
+ else:
+ self.solver.Add(self.backward_s[stage][vpp0][p0][id0][split0] + self.b_duration[stage][vpp0] /
+ self.input.seq_splits <= self.backward_s[stage][vpp1][p1][id1][split1])
+
+ def same_micro_stage_constraints(self):
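+        """Enforce the forward/backward dependencies of the same micro batch across stages and interleaves."""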
+ for part in range(self.input.parts):
+ for micro_id in range(self.distribution[part]):
+ for split in range(self.input.seq_splits):
+ for vpp in range(self.input.pp_interleave_num):
+ for stage in range(self.input.pipeline_stage):
+                            # Forward:
+ if stage != self.input.pipeline_stage - 1:
+ self.solver.Add(
+ self.forward_s[stage][vpp][part][micro_id][split] + self.f_duration[stage][vpp] /
+ self.input.seq_splits <= self.forward_s[stage + 1][vpp][part][micro_id][split])
+ elif vpp != self.input.pp_interleave_num - 1:
+ self.solver.Add(
+ self.forward_s[stage][vpp][part][micro_id][split] + self.f_duration[stage][vpp] /
+ self.input.seq_splits <= self.forward_s[0][vpp + 1][part][micro_id][split])
+ else:
+ self.solver.Add(
+ self.forward_s[stage][vpp][part][micro_id][split] + self.f_duration[stage][vpp] /
+ self.input.seq_splits <= self.backward_s[stage][vpp][part][micro_id][split])
+                            # Backward:
+ if stage != 0:
+ self.solver.Add(
+ self.backward_s[stage][vpp][part][micro_id][split] + self.b_duration[stage][vpp] /
+ self.input.seq_splits <= self.backward_s[stage - 1][vpp][part][micro_id][split])
+ else:
+ if vpp != 0:
+ self.solver.Add(
+ self.backward_s[stage][vpp][part][micro_id][split] + self.b_duration[stage][vpp]
+ / self.input.seq_splits <= self.backward_s[
+ self.input.pipeline_stage - 1][vpp - 1][part][micro_id][split])
+
+ def layers_constraints(self):
+ layers_type1 = 0
+ layers_type2 = 0
+ indicator_total = 0
+ for stage in range(self.input.pipeline_stage):
+ for vpp in range(self.input.pp_interleave_num):
+ layers_type1 += self.x_type1[stage][vpp]
+ layers_type2 += self.x_type2[stage][vpp]
+ indicator_total += self.indicator_type1[stage][vpp] + self.indicator_type2[stage][vpp]
+ self.solver.Add(self.x_type1[stage][vpp] >= self.ra_type1[stage][vpp] + self.rs_type1[stage][vpp])
+ self.solver.Add(self.x_type2[stage][vpp] >= self.ra_type2[stage][vpp] + self.rs_type2[stage][vpp])
+ if not self.expert_input.is_support_ra_with_rs:
+ self.solver.Add(self.rs_type1[stage][vpp] <= self.layer_upper * self.rs_or_ra[stage][vpp])
+ self.solver.Add(self.rs_type2[stage][vpp] <= self.layer_upper * self.rs_or_ra[stage][vpp])
+ self.solver.Add(self.ra_type1[stage][vpp] <= self.layer_upper * (1 - self.rs_or_ra[stage][vpp]))
+ self.solver.Add(self.ra_type2[stage][vpp] <= self.layer_upper * (1 - self.rs_or_ra[stage][vpp]))
+ self.solver.Add(layers_type1 == self.input.num_layers_type1)
+ self.solver.Add(layers_type2 == self.input.num_layers_type2)
+ self.solver.Add(indicator_total >= self.input.pipeline_stage * self.input.pp_interleave_num)
+ self.solver.Add(indicator_total <= self.input.pipeline_stage * self.input.pp_interleave_num + 1)
+
+ def mem_constraints(self):
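+        """
+        Define the activation memory of one micro batch per (stage, vpp), then walk every prefix of the
+        stage's execution order, adding activation memory on forward tasks and releasing it on backward
+        tasks, and bound each prefix by the stage's memory limit.
+        """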
+ for stage in range(self.input.pipeline_stage):
+ for vpp in range(self.input.pp_interleave_num):
+ self.solver.Add(self.mem[stage][vpp] == self.memory.act_mem *
+ (self.x_type2[stage][vpp] - self.rs_type2[stage][vpp] - self.ra_type2[stage][vpp]) +
+ self.memory.select_mem * self.rs_type2[stage][vpp] +
+ self.memory.re_comp_mem * self.ra_type2[stage][vpp] +
+ self.memory.act_mem0 * (self.x_type1[stage][vpp] - self.rs_type1[stage][vpp] -
+ self.ra_type1[stage][vpp]) + self.memory.select_mem0 *
+ self.rs_type1[stage][vpp] + self.memory.re_comp_mem0 * self.ra_type1[stage][vpp])
+ for stage in range(self.input.pipeline_stage):
+ total_layer_mem = 0
+ for vpp in range(self.input.pp_interleave_num):
+ total_layer_mem += (self.x_type1[stage][vpp] * self.memory.layer_mem012 + self.x_type2[stage][vpp] *
+ self.memory.layer_mem)
+ for sub in range(1, len(self.final_orders[stage]) + 1):
+ consume_mem = total_layer_mem
+ for i in range(sub):
+ part, vpp, state, micro_id, split = (self.final_orders[stage][i].part,
+ self.final_orders[stage][i].vpp,
+ self.final_orders[stage][i].state,
+ self.final_orders[stage][i].micro_id,
+ self.final_orders[stage][i].split)
+                    # Compute time is split evenly across sequence splits, so activation memory is split evenly here as well
+ if state == 'f':
+ consume_mem += self.mem[stage][vpp] / self.input.seq_splits
+ else:
+ consume_mem -= self.mem[stage][vpp] / self.input.seq_splits
+ if stage == 0:
+ self.solver.Add(consume_mem <= self.memory.mem_lim_stage0)
+ elif stage == self.input.pipeline_stage - 1:
+ self.solver.Add(consume_mem <= self.memory.mem_lim_last)
+ else:
+ self.solver.Add(consume_mem <= self.memory.mem_lim_others)
+
+ def define_obj(self):
+ self.solver.Minimize(self.backward_s[0][0][-1][self.distribution[-1] - 1][0] + self.b_duration[0][0] /
+ self.input.seq_splits + self.residue * self.stable_dur)
+
+ def output_model(self, mps_file):
+ with open(mps_file, 'w', encoding='utf-8') as file:
+ mps_text = self.solver.ExportModelAsMpsFormat(False, False)
+ file.write(mps_text)
+        logger.info(f'Model written to file: {mps_file}')
+
+ def solve(self):
+ self.solver.EnableOutput()
+ logger.info(f'Solving with {self.solver.SolverVersion()}')
+ if self.expert_input.time_limit != sys.maxsize:
+ self.solver.SetTimeLimit(self.expert_input.time_limit)
+ self.solver.Solve()
+ self.min_time = self.solver.Objective().Value()
+ logger.info(f'The objective value = {self.min_time}')
+
+
+if __name__ == '__main__':
+ yaml_file = 'C:\\working\\768_4k.yaml'
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/memory.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0324ed0a301fef28a130139dbc5c3032fb31d04
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/memory.py
@@ -0,0 +1,74 @@
+"""memory model"""
+from fast_tuner.utils.logger import logger
+
+
+class Memory:
+ select_mem0 = 76
+ select_mem12 = 76
+ select_mem = 77
+ re_comp_mem0 = 3
+ re_comp_mem12 = 3
+ re_comp_mem = 5
+ act_mem0 = 79
+ act_mem12 = 79
+ act_mem = 79
+ layer_mem012 = 349
+ layer_mem = 340
+ static_mem0 = 734
+ static_mem = 116
+ lm_head_mem = 690
+
+ def __init__(self, mem_lim):
+ self.mem_lim = mem_lim
+ self.mem_lim_stage0 = mem_lim - self.static_mem0
+ self.mem_lim_others = mem_lim - self.static_mem
+ self.mem_lim_last = mem_lim - self.lm_head_mem
+
+ def update_up_mem(self):
+ self.mem_lim_stage0 = self.mem_lim - self.static_mem0
+ self.mem_lim_others = self.mem_lim - self.static_mem
+ self.mem_lim_last = self.mem_lim - self.lm_head_mem
+
+ def print_mem(self):
+ logger.info(f'select_mem0={self.select_mem0}, select_mem12={self.select_mem12}, select_mem={self.select_mem}, '
+ f're_comp_mem0={self.re_comp_mem0}, re_comp_mem12={self.re_comp_mem12}, '
+ f're_comp_mem={self.re_comp_mem}, '
+ f'act_mem0={self.act_mem0}, act_mem12={self.act_mem12}, act_mem={self.act_mem}, '
+ f'layer_mem012={self.layer_mem012}, layer_mem={self.layer_mem}, '
+ f'static_mem0={self.static_mem0}, static_mem={self.static_mem}, '
+ f'lm_head_mem={self.lm_head_mem}, mem_lim_stage0={self.mem_lim_stage0}, '
+ f'mem_lim_others={self.mem_lim_others}, mem_lim_last={self.mem_lim_last}')
+
+ def write_memory_to_file(self, mem_file):
+ with open(mem_file, 'w', encoding='utf-8') as file:
+ file.write(f'select_mem0={self.select_mem0}\n')
+ file.write(f'select_mem12={self.select_mem12}\n')
+ file.write(f'select_mem={self.select_mem}\n')
+ file.write(f're_comp_mem0={self.re_comp_mem0}\n')
+ file.write(f're_comp_mem12={self.re_comp_mem12}\n')
+ file.write(f're_comp_mem={self.re_comp_mem}\n')
+ file.write(f'act_mem0={self.act_mem0}\n')
+ file.write(f'act_mem12={self.act_mem12}\n')
+ file.write(f'act_mem={self.act_mem}\n')
+ file.write(f'layer_mem012={self.layer_mem012}\n')
+ file.write(f'layer_mem={self.layer_mem}\n')
+ file.write(f'static_mem0={self.static_mem0}\n')
+ file.write(f'static_mem={self.static_mem}\n')
+ file.write(f'lm_head_mem={self.lm_head_mem}\n')
+ logger.info(f'Write memory info to {mem_file}')
+
+ def get_mem(self):
+ mem = (f'select_mem0={self.select_mem0}\n'
+ f'select_mem12={self.select_mem12}\n'
+ f'select_mem={self.select_mem}\n'
+ f're_comp_mem0={self.re_comp_mem0}\n'
+ f're_comp_mem12={self.re_comp_mem12}\n'
+ f're_comp_mem={self.re_comp_mem}\n'
+ f'act_mem0={self.act_mem0}\n'
+ f'act_mem12={self.act_mem12}\n'
+ f'act_mem={self.act_mem}\n'
+ f'layer_mem012={self.layer_mem012}\n'
+ f'layer_mem={self.layer_mem}\n'
+ f'static_mem0={self.static_mem0}\n'
+ f'static_mem={self.static_mem}\n'
+ f'lm_head_mem={self.lm_head_mem}')
+ return mem
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/micro.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/micro.py
new file mode 100644
index 0000000000000000000000000000000000000000..281878a807f112103c1f33bbc9ae3854792fbfb3
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/micro.py
@@ -0,0 +1,149 @@
+"""Micro-batch ordering for pipeline parallelism."""
+from fast_tuner.pipeline_conductor.memory import Memory
+
+class Micro:
+    """One forward or backward micro-batch block."""
+    part: int
+    vpp: int
+    state: str = 'f'
+    micro_id: int
+    split: int
+
+ def __init__(self, part, vpp, state, micro_id, split):
+ self.part = part
+ self.vpp = vpp
+ self.state = state
+ self.micro_id = micro_id
+ self.split = split
+
+
+class SortMicro:
+    """Builds the per-stage execution order of forward/backward micro-batches."""
+    parts: int
+    num_vpp: int
+    num_stage: int
+    distribution: list
+    low_mem: bool
+    seq_split: int
+    is_f_then_b = False
+
+ def __init__(self, parts, num_vpp, num_stage, distribution, low_mem, seq_split):
+ self.forward = []
+ self.backward = []
+ self.warmup_num = []
+ self.final_orders = []
+ self.parts = parts
+ self.num_vpp = num_vpp
+ self.num_stage = num_stage
+ self.distribution = distribution
+ self.low_mem = low_mem
+ self.seq_split = seq_split
+ self.build_f_b_sort()
+ self.set_warmup_num()
+ self.set_micro_sort()
+
+ def build_f_b_sort(self):
+ for part in range(self.parts):
+ for vpp in range(self.num_vpp):
+ for micro_id in range(self.distribution[part]):
+ for split in range(self.seq_split):
+ micro = Micro(part, vpp, 'f', micro_id, split)
+ self.forward.append(micro)
+ for vpp in range(self.num_vpp - 1, -1, -1):
+ for micro_id in range(self.distribution[part]):
+ for split in range(self.seq_split - 1, -1, -1):
+ micro = Micro(part, vpp, 'b', micro_id, split)
+ self.backward.append(micro)
+
+ def set_warmup_num(self):
+ for stage in range(self.num_stage):
+ if self.low_mem:
+ warmup = min(((self.num_vpp - 1) * self.distribution[0] + (self.num_stage - stage - 1)) *
+ self.seq_split, len(self.forward))
+ else:
+ warmup = min(((self.num_vpp - 1) * self.distribution[0] + (self.num_stage - stage - 1) * 2) *
+ self.seq_split, len(self.forward))
+            # on the last stage, backward can start only after the first micro's forward finishes
+ if stage == self.num_stage - 1:
+ warmup = warmup + self.seq_split - 1
+ self.warmup_num.append(warmup)
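+    # Worked example (illustrative values, not from the source): with parts=1,
+    # num_vpp=2, num_stage=4, distribution=[8], seq_split=1 and low_mem=False,
+    # len(self.forward) = 16, so stage 0 gets min((1*8 + 3*2)*1, 16) = 14 warm-up
+    # micro-blocks and the last stage gets min((1*8 + 0*2)*1, 16) + 0 = 8.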
+
+ def set_micro_sort(self):
+ for stage in range(self.num_stage):
+ stage_order = []
+ stage_order += self.forward[: self.warmup_num[stage]]
+ for i in range(self.warmup_num[stage], len(self.forward)):
+ stage_order.append(self.forward[i])
+ stage_order.append(self.backward[i - self.warmup_num[stage]])
+ stage_order += self.backward[len(self.forward) - self.warmup_num[stage]:]
+ self.final_orders.append(stage_order)
+
+
+class PeakNum:
+    """Tracks, per stage, the activation/recomputation counts at the memory peak."""
+    sort_micro: SortMicro
+
+ def __init__(self, sort_micro: SortMicro):
+ self.peak_num_recompute_type1 = {}
+ self.peak_num_recompute_type2 = {}
+ self.peak_num_select_recom_type1 = {}
+ self.peak_num_select_recom_type2 = {}
+ self.peak_num_act_type1 = {}
+ self.peak_num_act_type2 = {}
+ self.max_mem = {}
+ self.micro_num_of_max_mem = {}
+ self.sort_micro = sort_micro
+ self.num_stage = sort_micro.num_stage
+ self.num_vpp = sort_micro.num_vpp
+
+ def set_peak_act_recompute_num(self, x_type2, rs_type2, ra_type2, x_type1, rs_type1, ra_type1, memory: Memory):
+ for stage in range(self.sort_micro.num_stage):
+            # memory at each micro step, plus the counts of activations and recomputations
+ self.max_mem[stage] = 0
+ if stage == 0:
+ static_mem = memory.static_mem0
+ elif stage == self.sort_micro.num_stage - 1:
+ static_mem = memory.lm_head_mem
+ else:
+ static_mem = memory.static_mem
+            layer_mem = (memory.layer_mem012 * sum(x_type1[vpp][stage] for vpp in range(self.num_vpp))
+                         + memory.layer_mem * sum(x_type2[vpp][stage] for vpp in range(self.num_vpp)))
+ mem_stage = static_mem + layer_mem
+ num_recom_type1_stage = 0
+ num_recom_type2_stage = 0
+ num_select_recom_type1_stage = 0
+ num_select_recom_type2_stage = 0
+ num_act_type1_stage = 0
+ num_act_type2_stage = 0
+ for i in range(len(self.sort_micro.final_orders[stage])):
+ micro_batch = self.sort_micro.final_orders[stage][i]
+ vpp = micro_batch.vpp
+ act_num_type1 = x_type1[vpp][stage] - rs_type1[vpp][stage] - ra_type1[vpp][stage]
+ act_num_type2 = x_type2[vpp][stage] - rs_type2[vpp][stage] - ra_type2[vpp][stage]
+ act_mem = memory.act_mem12 * act_num_type1 + memory.act_mem * act_num_type2
+ ra_mem = memory.re_comp_mem12 * ra_type1[vpp][stage] + memory.re_comp_mem * ra_type2[vpp][stage]
+ rs_mem = memory.select_mem12 * rs_type1[vpp][stage] + memory.select_mem * rs_type2[vpp][stage]
+                # the computation is cut into seq_split pieces, which can be seen as splitting the data, i.e. the dynamic memory, into seq_split pieces; layer_mem and static_mem stay unchanged
+ total_mem = (act_mem + ra_mem + rs_mem) / self.sort_micro.seq_split
+ if micro_batch.state == 'f':
+ mem_stage += total_mem
+ num_recom_type1_stage += ra_type1[vpp][stage]
+ num_recom_type2_stage += ra_type2[vpp][stage]
+ num_select_recom_type1_stage += rs_type1[vpp][stage]
+ num_select_recom_type2_stage += rs_type2[vpp][stage]
+ num_act_type1_stage += act_num_type1
+ num_act_type2_stage += act_num_type2
+ else:
+ mem_stage -= total_mem
+ num_recom_type1_stage -= ra_type1[vpp][stage]
+ num_recom_type2_stage -= ra_type2[vpp][stage]
+ num_select_recom_type1_stage -= rs_type1[vpp][stage]
+ num_select_recom_type2_stage -= rs_type2[vpp][stage]
+ num_act_type1_stage -= act_num_type1
+ num_act_type2_stage -= act_num_type2
+ if mem_stage > self.max_mem[stage]:
+ self.max_mem[stage] = mem_stage
+ self.peak_num_recompute_type1[stage] = num_recom_type1_stage
+ self.peak_num_recompute_type2[stage] = num_recom_type2_stage
+ self.peak_num_select_recom_type1[stage] = num_select_recom_type1_stage
+ self.peak_num_select_recom_type2[stage] = num_select_recom_type2_stage
+ self.peak_num_act_type1[stage] = num_act_type1_stage
+ self.peak_num_act_type2[stage] = num_act_type2_stage
+ self.micro_num_of_max_mem[stage] = i + 1
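+
+# A minimal usage sketch (hypothetical shapes; x/rs/ra are the [vpp][stage] layer
+# counts normally produced by the solver):
+#
+#   sort = SortMicro(parts=1, num_vpp=2, num_stage=4, distribution=[8],
+#                    low_mem=False, seq_split=1)
+#   peak = PeakNum(sort)
+#   peak.set_peak_act_recompute_num(x_type2, rs_type2, ra_type2,
+#                                   x_type1, rs_type1, ra_type1, Memory(58000))
+#   print(peak.max_mem)  # per-stage peak memory under this schedule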
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pipeline_parallel.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pipeline_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..762632687be3a7f7202b438641506b9777972052
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pipeline_parallel.py
@@ -0,0 +1,232 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""pp solver entry file"""
+import sys
+import os.path
+import argparse
+from fast_tuner.utils.logger import logger
+from fast_tuner.pipeline_conductor import math_model, pp_util
+from fast_tuner.pipeline_conductor.start_service import InitConfig, ExpertInput, HIGHS_NAME, pipeline_output_file
+from fast_tuner.pipeline_conductor import solution
+from fast_tuner.pipeline_conductor import fitting
+from fast_tuner.utils.ppc_input import ParallelInput
+from fast_tuner.pipeline_conductor.result_csv import ResultCsv
+from fast_tuner.pipeline_conductor.dryrun import DryRun, dryrun_config_error
+
+from fast_tuner.utils.common import check_dryrun_parallel_number
+
+mps_dir_name = 'model_mps'
+sol_dir_name = 'sol_output'
+
+
+def pp_calculator(expert_input: ExpertInput, model_args=None) -> solution.Solution:
+ init_config = InitConfig(expert_input)
+ cur_solution = solve_problem(init_config)
+ if cur_solution.solution_status == 'None':
+ return cur_solution
+ if expert_input.is_double_check:
+ fit_mem = fitting.FitMem(cur_solution)
+ over_mem, is_over_mem = fit_mem.is_over_mem()
+ if not is_over_mem:
+ return cur_solution
+ if expert_input.fit_level == 0:
+ i = 0
+ while i < 5 and is_over_mem:
+ fit_mem.reduce_mem_lim_for_fitting(over_mem, i)
+                logger.info(f'Correcting the memory, attempt {i + 1}!')
+ init_config.memory.print_mem()
+ cur_solution = solve_problem(init_config)
+ if cur_solution.solution_status == 'None':
+ logger.info('dryrun_check result is over memory limit')
+ return cur_solution
+ fit_mem = fitting.FitMem(cur_solution)
+ over_mem, is_over_mem = fit_mem.is_over_mem()
+ i += 1
+ else:
+ fit_mem.linear_fit()
+ cur_solution = solve_problem(init_config)
+ return cur_solution
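+
+# A minimal programmatic sketch (placeholder paths; the attributes mirror those set
+# in pipeline_proc below):
+#   expert_input = ExpertInput('/path/to/pretrain.yaml', '/path/to/run_mindformers.py')
+#   expert_input.low_mem = True
+#   best_solution = pp_calculator(expert_input)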
+
+
+def pipeline_proc(pipeline_input: ParallelInput):
+ if len(pipeline_input.candidate_configs) == 0:
+        raise ValueError('There are no candidate configs!')
+ is_low_mem = pipeline_input.is_lowmem
+ solver_name = pipeline_input.solver_name
+ ms_adapter_file = pipeline_input.ms_adapter_file
+ DryRun.env_config_json = pipeline_input.env_config_json
+ DryRun.dryrun_lim = pipeline_input.dryrun_lim
+ ExpertInput.is_dryrun = pipeline_input.dryrun
+ ExpertInput.is_double_check = pipeline_input.check
+ result_csv = ResultCsv(pipeline_input.output_path, pipeline_output_file)
+ num_all = len(pipeline_input.candidate_configs)
+ num_cur = 0
+ for candidate in pipeline_input.candidate_configs:
+ result_csv.config_to_csv(candidate, is_low_mem, solver_name)
+ for candidate in pipeline_input.candidate_configs:
+ candidate_input = ExpertInput(candidate.config_path, ms_adapter_file)
+ candidate_input.model_args = candidate.model_args
+ candidate_input.low_mem = is_low_mem
+ candidate_input.solver_name = solver_name
+ candidate_input.layer_ratio = candidate.profiling_info.dmratio
+ candidate_input.backward_ratio = candidate.profiling_info.bfratio
+ candidate_input.head_loss = candidate.profiling_info.hratio
+ candidate_input.recompute_ratio = candidate.profiling_info.re_grow_ration
+ num_cur += 1
+ logger.info(f'---------------------- Testing {num_cur}/{num_all}:{candidate.config_path} ----------------------')
+ try:
+ cur_solution = pp_calculator(candidate_input)
+ result_csv.result_to_csv(cur_solution)
+ except Exception as e:
+ logger.error(f'{candidate.config_path} error: {e}. Continue to next one')
+
+
+def solve_problem(init_config: InitConfig):
+ origin_model = math_model.Model(init_config)
+ origin_model.define_variables()
+ origin_model.define_constraint()
+ origin_model.define_obj()
+ cur_solution = solution.Solution(init_config)
+
+ # output mps
+ mps_dir = os.path.join(init_config.expert_input.output_file_dir, mps_dir_name)
+ if not os.path.exists(mps_dir):
+ os.mkdir(mps_dir)
+ mps_file = os.path.join(mps_dir, init_config.mps_sol_filename + '.mps')
+ origin_model.output_model(mps_file)
+
+ # solve
+ sol_dir = os.path.join(init_config.expert_input.output_file_dir, sol_dir_name)
+ if not os.path.exists(sol_dir):
+ os.mkdir(sol_dir)
+ sol_file = os.path.join(sol_dir, init_config.mps_sol_filename + '.sol')
+ if not os.path.exists(mps_file):
+ logger.error('build model error!')
+ if init_config.expert_input.solver_name == HIGHS_NAME:
+ is_origin_solver = False
+ pp_util.highs_solve_mps(mps_file, sol_file, origin_model, init_config.expert_input.time_limit)
+ elif init_config.expert_input.solver_name == 'QIUQI':
+ is_origin_solver = False
+        solver_file = '/home/zhugelu/MIXSolver/bin/MIXSolver'  # change this to the local solver path
+ pp_util.qiuqi_solver_mps(solver_file, mps_file, sol_file, origin_model)
+ else:
+ is_origin_solver = True
+ origin_model.solve()
+
+ cur_solution.set_solution(origin_model, is_origin_solver, sol_file)
+ if cur_solution.solution_status != 'None':
+ cur_solution.solution_print()
+ return cur_solution
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(prog='TAYLOR AutoBalancing', description=(
+ 'Balance layers onto pipeline stages, '
+ + 'considering recomputation and interleaving'),
+ epilog='')
+    # large model type
+ parser.add_argument('-llm', '--llm_class', type=int, default=0,
+ help="0-deepseek,1-boss")
+ # Training Yaml configuration
+ parser.add_argument('-yaml', '--train_yaml', type=str, default=None,
+ help="Path of training config (.yaml)")
+ parser.add_argument('-shell', '--train_shell', type=str, default=None,
+ help="Path of training config (.sh)")
+ # mindformers location
+ parser.add_argument('-mindformers', '--mindformers_loc', type=str, default=None,
+ help="Absolute path of run_mindformers (.py)")
+ parser.add_argument('-mindspeed', '--mindspeed_loc', type=str, default=None,
+ help="Absolute path of posttrain_gpt (.py)")
+ # solver name
+ parser.add_argument('-solver', '--solver_name', type=str, default=HIGHS_NAME,
+ help="The solver name")
+ # layer ratio
+    parser.add_argument('-layer_ratio', '--layer_ratio', type=float, default=0.33,
+ help="Time ratio of calculating of dense to moe")
+ # backward ratio
+    parser.add_argument('-b_ratio', '--backward_ratio', type=float, default=2.0,
+ help="Time ratio of calculating of backward to forward")
+ # head_loss
+    parser.add_argument('-head_loss', '--head_loss', type=float, default=1.5,
+                        help="Extra time added for the last (head/loss) layer")
+ # recompute_ratio
+    parser.add_argument('-ra_ratio', '--recompute_ratio', type=float, default=1,
+                        help="Time ratio of recomputation to forward")
+ # Search time
+ parser.add_argument('-t', '--time_limit', type=int, default=sys.maxsize,
+ help="Limitation on searching time")
+    # whether to run dryrun automatically
+    parser.add_argument('-dryrun', '--dryrun', type=pp_util.str2bool, default=True,
+                        help="Whether to dryrun automatically")
+    # whether to double-check automatically
+    parser.add_argument('-check', '--check', type=pp_util.str2bool, default=True,
+                        help="Whether to double-check the solution")
+    parser.add_argument('-is_write', '--is_write', type=pp_util.str2bool, default=True,
+                        help="Whether to write the solution to the config file")
+    # fit level, 0: on memory overflow, reduce the memory limit and re-solve; 1: fit the memory info by linear regression and re-solve
+ parser.add_argument('-fit', '--fit_level', type=int, default=0,
+ help="Fit memory when the result is over the limit: 0-reduce the memory limit;"
+ " 1 or >1-fit the memory info")
+    # whether to extract the solution info
+ parser.add_argument('-extract', '--extract', type=pp_util.str2bool, default=False,
+ help="Extract solution file separately")
+ parser.add_argument('-solution', '--solution', default=None, help="The solution file")
+ # env_config_json
+ parser.add_argument('-env', '--env_config_json', type=str, required=True,
+ default='./config/boss_env_config.json', help="Path of environment config (.json)")
+ parser.add_argument('-register', '--register_path', type=str, default='research/jiutian',
+ help="Path of register")
+ parser.add_argument('-dryrun_lim', '--dryrun_lim', type=pp_util.str2int, default=16,
+ help="The number of dryrun at once")
+
+ args = parser.parse_args()
+ check_dryrun_parallel_number(args.dryrun_lim)
+ if args.train_yaml and args.mindformers_loc:
+ config_file = args.train_yaml
+ ms_adapter_file = args.mindformers_loc
+ DryRun.config_file_type = 0
+ ExpertInput.is_full_recomp = True
+ elif args.train_shell and args.mindspeed_loc:
+ config_file = args.train_shell
+ ms_adapter_file = args.mindspeed_loc
+ DryRun.config_file_type = 1
+ ExpertInput.is_full_recomp = False
+ else:
+ raise TypeError(dryrun_config_error)
+
+ if args.extract:
+ solution.extract_solution_file(args.train_yaml, args.solution)
+ sys.exit(0)
+
+ expert_input = ExpertInput(config_file=config_file, ms_adapter_file=ms_adapter_file)
+ expert_input.solver_name = args.solver_name
+ expert_input.llm_class = int(args.llm_class)
+ expert_input.time_limit = int(args.time_limit)
+ if args.time_limit < sys.maxsize:
+        logger.warning('You have configured the time limit parameter! The solution may not be optimal!')
+ expert_input.is_dryrun = args.dryrun
+ expert_input.is_double_check = args.check
+ expert_input.fit_level = args.fit_level
+ expert_input.layer_ratio = args.layer_ratio
+ expert_input.backward_ratio = args.backward_ratio
+ expert_input.head_loss = args.head_loss
+ expert_input.recompute_ratio = args.recompute_ratio
+
+ DryRun.env_config_json = args.env_config_json
+ DryRun.register_path = args.register_path
+ DryRun.dryrun_lim = args.dryrun_lim
+ DryRun.is_write_to_file = args.is_write
+
+ pp_calculator(expert_input)
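+
+# Example invocation (a sketch; paths are placeholders):
+#   python pipeline_parallel.py \
+#       -yaml /path/to/pretrain.yaml \
+#       -mindformers /path/to/run_mindformers.py \
+#       -env ./config/boss_env_config.json \
+#       -t 3600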
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_simulator.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_simulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..675f8d7178627b5ed9771bfa8f1dc456e8e31353
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_simulator.py
@@ -0,0 +1,926 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""pp simulator"""
+from __future__ import annotations
+from collections.abc import Iterable
+import copy
+import sys
+import time
+from dataclasses import dataclass, field
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.patches import Rectangle, Polygon
+from matplotlib import colors
+from matplotlib.transforms import ScaledTranslation
+sys.setrecursionlimit(8192)
+
+def format_2d_inputs(a, row, col):
+    """Broadcast scalar / 1-D / 2-D inputs to a (row, col) array."""
+    if isinstance(a, (int, float)):
+        return np.broadcast_to(a, (row, col))
+ if isinstance(a, (list, tuple)):
+ if all(isinstance(item, (list, tuple)) for item in a):
+ return np.array(a)
+ if all(isinstance(item, (int, float)) for item in a):
+ return np.array([a])
+ raise ValueError(f"Unsupported inputs: {a}")
+ else:
+ raise ValueError(f"Unsupported inputs: {a}")
+
+def apply_color(l:list, c:list[str]):
+ for i in range(len(l)):
+ l[i] = f'{l[i]:.4f}' if isinstance(l[i], float) else l[i]
+ l[i] = f"\033[{c[i]}m{l[i]}\033[0m"
+ return l
+
+def apply_format(l:list):
+ s = f'{l[0]:^22}'
+ symbol = ['=', '+', '+', '+', '+', '+']
+ for i in range(len(l) - 1):
+ s = f'{s}{symbol[i]}{l[i + 1]:^22}'
+ return s
+
+def color_mix(c1, c2, w1=0.5, w2=0.5):
+ rgb = (np.array(colors.to_rgba(c1, 1)) * w1 + np.array(colors.to_rgba(c2, 1)) * w2) / (w1 + w2)
+ return colors.to_rgba(rgb)
+
+
+class CausalError(Exception):
+ def __init__(self, msg, blocks:list[list[MicroBlockSim]]=None, loop:list[BlockSim]=[]) -> None:
+ self.msg = msg
+ self.canvas = PlotMgr(num_plots=1, figsize=(12, 6))
+ self.canvas.draw_loop(blocks, loop, 0, False, False, True)
+ self.canvas.ax[0].set_title("Block pipeline dependency")
+ super().__init__()
+ print(f"{self.canvas.msg}")
+
+ def __str__(self):
+ plt.show()
+ return f"{self.msg}"
+
+
+class CausalCommError(Exception):
+ def __init__(self, msg, blocks:list[list[MicroBlockSim]]=None, loop:list[BlockSim]=[]) -> None:
+ self.msg = msg
+ self.canvas = PlotMgr(num_plots=1, figsize=(12, 6))
+ self.canvas.draw_comm_loop(blocks, loop, 0)
+ self.canvas.ax[0].set_title("Block comm pipeline dependency")
+ super().__init__()
+ print(f"{self.canvas.msg}")
+
+ def __str__(self):
+ plt.show()
+ return f"{self.msg}"
+
+def dfs_builder(comm=False):
+ def decorator(func):
+ def wrapper(*args, **kwargs):
+ self = args[0]
+ pre, left = (self.depend_pre, self.depend_left) if comm else (self.pre, self.left)
+ if self.finish:
+ return
+ if pre is None or left is None:
+ raise NotImplementedError
+ if self.in_queue:
+ raise ValueError
+ self.in_queue = True
+ res = func(*args, **kwargs)
+ self.finish = True
+ self.in_queue = False
+ return res
+ return wrapper
+ return decorator
+
+def timer(func):
+ def wrapper(*args, **kwargs):
+ T0 = time.time()
+ res = func(*args, **kwargs)
+ T1 = time.time() - T0
+ print(f"function `{func.__name__}` time used: {T1:.4f} s", flush=True)
+ return res
+ return wrapper
+
+
+class PlotMgr:
+ def __init__(self, num_plots=2, ax_type='block', subplot_args=None, *args, **kwargs):
+ self.fig = plt.figure(figsize=kwargs.get('figsize', (12, 8)))
+ self.fig.subplots_adjust(wspace=0, hspace=0.4)
+ ax_type = ax_type if isinstance(ax_type, (list, tuple)) else [ax_type] * num_plots
+ self.ax = []
+ for i in range(num_plots):
+ if subplot_args is None:
+ self.ax.append(self.fig.add_subplot(num_plots * 100 + 10 + i + 1))
+ elif isinstance(subplot_args, Iterable) and len(subplot_args) >= num_plots:
+ self.ax.append(self.fig.add_subplot(subplot_args[i]))
+ else:
+ raise ValueError(f"Unsupported subplot_args format: {subplot_args}")
+
+ def _set_block_ax(self, ax:plt.Axes, pp:int) -> plt.Axes:
+ ax.set_title("Pipeline Flow Timeline")
+ ax.set_yticks(range(pp), [f"stage {p}" for p in range(pp)])
+ for tick in ax.get_yticklabels():
+ tick.set_verticalalignment('top')
+ tick.set_transform(tick.get_transform() + ScaledTranslation(0, 0.05-1/pp, self.fig.dpi_scale_trans))
+ tick.set_fontsize(12)
+ ax.set_ylim(0, pp)
+ ax.invert_yaxis()
+
+ def _get_block_indices(self, blocks:list[list[MicroBlockSim]], mode='compact', equal_wide=False):
+ if mode not in ['compact', 'joint', 'timeline']:
+ raise ValueError(f"Get unsupported draw mode: {mode}")
+ if mode == 'timeline' and not blocks[-1][-1].finish:
+ raise ValueError(f"Block building should be finished before drawing timeline")
+ block_index = []
+ for p in range(len(blocks)):
+ inds = []
+ for block in blocks[p]:
+ if mode == 'compact':
+ if block._type == 'c':
+ inds.append(1 if equal_wide else block.time)
+ else:
+ inds.append(0)
+ elif mode == 'joint':
+ if block._type == 'c':
+ inds.append(1 if equal_wide else block.time)
+ else:
+ inds.append(block.time)
+ else:
+ inds.append(1)
+ inds.insert(0, 0)
+ inds = np.cumsum(inds)
+ block_index.append(inds)
+ return block_index
+
+ def draw_block(self, block_index:list[list[float]], blocks:list[list[MicroBlockSim]], ax_index:int = 0, equal_wide=False, width=1, phase=False):
+ for p in range(len(blocks)):
+ for b, block in enumerate(blocks[p]):
+ if block._type == 'c':
+ block.draw(self.ax[ax_index], index=block_index[p][b], equal_wide=equal_wide, width=width, phase=phase)
+ return self
+
+ def draw_comm(self, block_index:list[list[float]], blocks:list[list[MicroBlockSim]], ax_index:int = 0, equal_wide=False, mode='compact'):
+ for p in range(len(blocks)):
+ for b, block in enumerate(blocks[p]):
+ if block._type == 'c' and mode == 'compact':
+ if block.send_block:
+ block.send_block.draw(self.ax[ax_index], index=block_index[p][b], equal_wide=equal_wide)
+ if block.rec_block:
+ block.rec_block.draw(self.ax[ax_index], index=block_index[p][b], equal_wide=equal_wide)
+ elif block._type in ['s', 'r'] and mode in ['joint', 'timeline']:
+ block.draw(self.ax[ax_index], index=block_index[p][b], equal_wide=equal_wide, mode=mode)
+ return self
+
+ def draw_connect(self, block_index:list[list[float]], blocks:list[list[MicroBlockSim]], ax_index:int = 0, equal_wide=False, mode='compact'):
+ for p in range(len(blocks)):
+ for b, block in enumerate(blocks[p]):
+ if block._type == 'c' and mode == 'compact' and block.send_block:
+ dual_p = block.send_block.dual._stage
+ dual_ind = blocks[dual_p].index(block.send_block.dual.host)
+ block.send_block.draw_comm(self.ax[ax_index], index_from=block_index[p][b], index_to=block_index[dual_p][dual_ind], equal_wide=equal_wide, mode=mode)
+ elif block._type == 's' and mode in ['joint', 'timeline']:
+ dual_p = block.dual._stage
+ dual_ind = blocks[dual_p].index(block.dual)
+ block.draw_comm(self.ax[ax_index], index_from=block_index[p][b], index_to=block_index[dual_p][dual_ind], equal_wide=equal_wide, mode=mode)
+ return self
+
+ def draw(self, blocks:list[list[MicroBlockSim]], ax_index:int = 0, comm=False, connect=False, equal_wide=False, mode='compact', phase=False) -> PlotMgr:
+ pp = len(blocks)
+ block_index = self._get_block_indices(blocks, mode=mode, equal_wide=equal_wide)
+ width = max(np.max(block_index[p]) for p in range(pp)) if blocks[0][-1].end is None else max(blocks[p][-1].end for p in range(pp))
+ self.draw_block(block_index, blocks, ax_index, equal_wide, width, phase=phase)
+ if comm:
+ self.draw_comm(block_index, blocks, ax_index, equal_wide, mode)
+ if connect:
+ self.draw_connect(block_index, blocks, ax_index, equal_wide, mode)
+ self._set_block_ax(self.ax[ax_index], pp)
+ self.ax[ax_index].set_xlim(0, width)
+ return self
+
+ def draw_loop(self, blocks:list[list[MicroBlockSim]], loop:list[BlockSim], ax_index:int = 0, comm=False, connect=False, equal_wide=False) -> PlotMgr:
+ self.draw(blocks, ax_index, comm, connect, equal_wide, phase=True)
+ block_index = self._get_block_indices(blocks, equal_wide=equal_wide)
+ msg = 'dependency loop: '
+ for b in range(len(loop) - 1):
+ p = loop[b]._stage
+ ind = blocks[p].index(loop[b])
+ x1, y1, dx1, _ = loop[b].loc_size(block_index[p][ind], equal_wide)
+ p = loop[b + 1]._stage
+ ind = blocks[p].index(loop[b + 1])
+ x2, y2, dx2, _ = loop[b + 1].loc_size(block_index[p][ind], equal_wide)
+ msg = f'{msg} {loop[b].color_label} -> '
+ self.ax[ax_index].annotate(None, xy=(x1 + dx1 / 2, y1), xytext=(x2 + dx2 / 2, y2),
+ arrowprops=dict(fc='white', ec='r', arrowstyle='simple', shrinkA=5, shrinkB=5, connectionstyle="arc3,rad=-0.1"))
+ self.msg = f'{msg} {loop[len(loop) - 1].color_label}'
+ return self
+
+ def draw_comm_loop(self, lines:list[list[BlockSim]], loop:list[BlockSim], ax_index:int = 0) -> PlotMgr:
+ self.draw(lines, ax_index, True, True, True, 'joint', phase=True)
+ block_index = self._get_block_indices(lines, mode='joint', equal_wide=True)
+ msg = 'dependency loop: '
+ for b in range(len(loop) - 1):
+ p = loop[b]._stage
+ ind = lines[p].index(loop[b])
+ x1, y1, dx1, _ = loop[b].loc_size(block_index[p][ind], True, 'joint')
+ p = loop[b + 1]._stage
+ ind = lines[p].index(loop[b + 1])
+ x2, y2, dx2, _ = loop[b + 1].loc_size(block_index[p][ind], True, 'joint')
+ msg = f'{msg} {loop[b].color_label} -> '
+ self.ax[ax_index].annotate(None, xy=(x1 + abs(dx1) / 2, y1), xytext=(x2 + abs(dx2) / 2, y2), size=10,
+ arrowprops=dict(fc='white', ec='r', arrowstyle='simple', shrinkA=3, shrinkB=3, connectionstyle="arc3,rad=-0.1", lw=0.8))
+ self.msg = f'{msg} {loop[len(loop) - 1].color_label}'
+ return self
+
+ def draw_mem(self, block_mem_list:list[np.ndarray], ax_index:int = 0) -> PlotMgr:
+ for p in range(len(block_mem_list)):
+ self.ax[ax_index].plot((block_mem_list[p].T)[0], (block_mem_list[p].T)[1], label=f"stage-{p}")
+ self.ax[ax_index].set_title("Block Memory Timeline")
+ self.ax[ax_index].set_xlim(0, max(np.max((block_mem.T)[0]) for block_mem in block_mem_list))
+
+ def draw_info(self, bubble_info:dict, mem_info:list):
+ info_list = [f'{k} bubble: {v:.4f}' for k, v in bubble_info.items()]
+ self.fig.text(0.5, 0.5, ', '.join(info_list), ha='center', va='center',
+ fontdict={'fontsize':13, 'weight':'medium'}, color='C3')
+ info_list = [f"{v:.2f}" for v in mem_info]
+ self.fig.text(0.5, 0.05, f"peak memory: {', '.join(info_list)}", ha='center', va='center',
+ fontdict={'fontsize':13, 'weight':'medium'}, color='C0')
+
+    def show(self):
+        self.fig.legend(bbox_to_anchor=(0.22, 0.45))
+        self.fig.savefig("figure.pdf")  # saved to the current directory by default
+        plt.show()
+
+
+@dataclass
+class BlockSim:
+ _stage: int # p
+ _state: str # s
+ _id: int # m
+ _virtual: int # v
+ time: float
+ _type: str
+ start: float = None
+ end: float = None
+ pre: BlockSim = field(repr=False, default=None)
+ left: BlockSim = field(repr=False, default=None)
+ right: MicroBlockSim = field(repr=False, default=None)
+ depend_pre: BlockSim = field(repr=False, default=None)
+ depend_left: BlockSim = field(repr=False, default=None)
+ finish = False
+ in_queue = False
+ _flag = False
+ _color = '0;38'
+ father: BlockSim = field(repr=False, default=None)
+ @property
+ def label(self) -> tuple:
+ return (self._type, self._state, self._id, self._virtual, self._stage)
+ @property
+ def color_label(self) -> str:
+ return f"\033[{self._color}m{self.label}\033[0m"
+ @property
+ def repr(self) -> str:
+ raise NotImplementedError
+ def draw(self, ax:plt.Axes, *args, **kwargs):
+ raise NotImplementedError
+ @dfs_builder(False)
+ def build_without_comm(self) -> None:
+ r"""Build pipeline timeline without comm blocks and dependency."""
+ self.pre.build_without_comm()
+ self.left.build_without_comm()
+ self.start = max(self.pre.end, self.left.end)
+ self.end = self.start + self.time
+ @dfs_builder(True)
+ def build_with_comm(self) -> None:
+ r"""Build pipeline timeline with comm blocks and dependency."""
+ self.depend_pre.build_with_comm()
+ self.depend_left.build_with_comm()
+ self.start = max(self.depend_pre.end, self.depend_left.end)
+ self.end = self.start + self.time
+ def reset_time_recursive(self) -> None:
+ raise NotImplementedError
+ def reset_time(self) -> None:
+ self.start = None
+ self.end = None
+ self.finish = False
+ def loc_size(self, x:int = 0, equal_wide=False, mode='compact') -> tuple:
+ x = x if self.start is None else self.start
+ dx = 1 if equal_wide else self.time
+ return x, self._stage + 0.5, dx, 1
+ def loop(self, comm=False) -> list[BlockSim]:
+ if self._flag and not self.in_queue:
+ return []
+ l = []
+ if self.in_queue:
+ loop = [self]
+ block = self.father
+ while block.father and block is not self:
+ block = block.father
+ loop.append(block)
+ return loop
+ self._flag = True
+ self.in_queue = True
+ depends = [self.depend_pre, self.depend_left] if comm else [self.pre, self.left]
+ for dep in depends:
+ if dep:
+ dep.father = self
+ l.extend(dep.loop(comm=comm))
+ dep.father = None
+ self.in_queue = False
+ return l
+
+ def comm_loop(self) -> list[BlockSim]:
+ return self.loop(True)
+
+@dataclass
+class HeadBlockSim(BlockSim):
+ _stage: int # p
+ _type: str = 'h'
+ _id: int = field(repr=False, init=False)
+ _state: str = field(repr=False, init=False)
+ _virtual: int = field(repr=False, init=False)
+ time: float = 0.
+ start: float = 0.
+ end: float = 0.
+ finish = True
+ @property
+ def label(self) -> tuple:
+ return (self._type, self._stage)
+ @property
+ def repr(self) -> str:
+ s_list = []
+ block = self
+ while block:
+ s_list.append(block.__repr__())
+ block = block.right
+ return '\n'.join(s_list)
+ def draw(self, ax, *args, **kwargs):
+ return
+ def build_without_comm(self):
+ return
+ def build_with_comm(self):
+ return
+ def reset_time_recursive(self):
+ return
+
+@dataclass
+class MicroBlockSim(BlockSim):
+ _type: str = 'c'
+ mem: float = 0.
+ phase: str = None
+ send_block: SendBlockSim = field(repr=False, default=None)
+ rec_block: RecBlockSim = field(repr=False, default=None)
+ def __post_init__(self):
+ self._color = '1;34' if self._state == 'f' else '1;33'
+
+ def draw(self, ax:plt.Axes, *args, **kwargs) -> None:
+ x, y, dx, dy = self.loc_size(kwargs.get('index', 0), kwargs.get('equal_wide', False))
+ color = (167/255, 184/255, 231/255) if self._state == 'f' else (255/255, 213/255, 143/255)
+ mix_color = (240/255, 255/255, 245/255) if self._state == 'f' else (255/255, 240/255, 255/255)
+ color = color_mix(mix_color, color, w1=self._virtual / 3)
+ if self.phase == 'warmup' and kwargs.get('phase', False):
+ edgecolor = 'lightblue'
+ elif self.phase == 'cooldown' and kwargs.get('phase', False):
+ edgecolor = 'orange'
+ else:
+ edgecolor = 'black'
+ rect = Rectangle((x, y - dy / 2), dx, dy, facecolor=color, edgecolor=edgecolor, linewidth=0.4)
+ if dx > 0.008 * kwargs.get('width', 0):
+ ax.text(rect.xy[0] + dx / 2, rect.xy[1] + dy / 2, str(self._id), ha='center', va='center', color='black', fontdict={'fontsize':9})
+ ax.add_patch(rect)
+
+ def reset_time_recursive(self) -> None:
+ if self.finish:
+ self.pre.reset_time_recursive()
+ self.left.reset_time_recursive()
+ self.reset_time()
+
+@dataclass
+class CommBlockSim(BlockSim):
+ host: MicroBlockSim = field(repr=False, default=None)
+ dual: CommBlockSim = field(repr=False, default=None)
+ def joint_loc(self, index:int = 0, equal_wide=False) -> tuple:
+ raise NotImplementedError
+ def get_triangle(self, x, y, dx, dy) -> tuple:
+ raise NotImplementedError
+ def draw(self, ax:plt.Axes, *args, **kwargs) -> None:
+ color = (167/255, 184/255, 231/255) if self._state == 'f' else (255/255, 213/255, 143/255)
+ mix_color = (240/255, 255/255, 255/255) if self._state == 'f' else (255/255, 240/255, 255/255)
+ color = color_mix(mix_color, color, w1=1.2*self._virtual/3)
+ index, equal_wide, mode = (kwargs.get('index', 0), kwargs.get('equal_wide', False), kwargs.get('mode', 'compact'))
+ x, y, dx, dy = self.loc_size(index, equal_wide, mode)
+ xy = self.get_triangle(x, y, dx, dy)
+ tri = Polygon(xy, closed=True, facecolor=color, edgecolor='black', linewidth=0.4)
+ ax.add_patch(tri)
+
+@dataclass
+class SendBlockSim(CommBlockSim):
+ _type: str = 's'
+ _color = '35'
+ def loc_size(self, index:int = 0, equal_wide=False, mode='compact') -> tuple:
+ host_x, _, host_dx, _ = self.host.loc_size(index, equal_wide)
+ x, y, _, _ = super().loc_size(index, equal_wide)
+ _dx = self.time
+ _dy = min(np.sqrt(self.time) * 0.6, 0.6)
+ if mode == 'compact':
+ x = host_x + host_dx - _dx
+ return x, y, _dx, _dy
+ def get_triangle(self, x, y, dx, dy) -> tuple:
+ return [[x, y - dy / 2], [x, y + dy / 2], [x + dx, y]]
+
+ def draw_comm(self, ax:plt.Axes, *args, **kwargs) -> None:
+ index_from, index_to = (kwargs.get('index_from', 0), kwargs.get('index_to', 0))
+ equal_wide, mode = (kwargs.get('equal_wide', False), kwargs.get('mode', 'compact'))
+ x, y, dx, _ = self.loc_size(index_from, equal_wide, mode)
+ x_, y_, dx_, _ = self.dual.loc_size(index_to, equal_wide, mode)
+ ax.annotate(None, xy=(x_ - dx_/2, y_), xytext=(x + dx/2, y), arrowprops=dict(ec='grey', arrowstyle='->', shrinkA=2, shrinkB=2))
+
+ @dfs_builder(True)
+ def build_with_comm(self) -> None:
+ r"""Build pipeline timeline with comm blocks and dependency."""
+ self.dual.depend_left.build_with_comm()
+ self.depend_left.build_with_comm()
+ self.start = max(self.depend_left.end, self.dual.depend_left.end)
+ self.end = self.start + self.time
+
+ def loop(self, comm=False) -> list[BlockSim]:
+ if comm:
+ return self.comm_loop()
+ return super().loop(comm)
+
+ def comm_loop(self) -> list[BlockSim]:
+ if self._flag and not self.in_queue:
+ return []
+ l = []
+ if self.in_queue:
+ loop = [self]
+ block = self.father
+ while block.father and block is not self:
+ block = block.father
+ loop.append(block)
+ return loop
+ self._flag = True
+ self.in_queue = True
+ depends = [self.dual.depend_left, self.depend_left]
+ for dep in depends:
+ if dep:
+ dep.father = self
+ l.extend(dep.comm_loop())
+ dep.father = None
+ self.in_queue = False
+ return l
+
+@dataclass
+class RecBlockSim(CommBlockSim):
+ _type: str = 'r'
+ _color = '32'
+ def loc_size(self, index:int = 0, equal_wide=False, mode='compact') -> tuple:
+ host_x, _, _, _ = self.host.loc_size(index, equal_wide)
+ x, y, _, _ = super().loc_size(index, equal_wide)
+ _dx = self.time
+ _dy = min(np.sqrt(self.time) * 0.6, 0.6)
+ if mode == 'compact':
+ x = host_x
+ return x, y, -_dx, -_dy
+ def get_triangle(self, x, y, dx, dy) -> tuple:
+ return [[x, y], [x - dx, y + dy / 2], [x - dx, y - dy / 2]]
+
+ @dfs_builder(True)
+ def build_with_comm(self) -> None:
+ r"""Build pipeline timeline with comm blocks and dependency."""
+ self.dual.build_with_comm()
+ self.depend_left.build_with_comm()
+ self.start = max(self.depend_left.end, self.dual.start)
+ self.end = self.start + self.time
+
+class PipelineBuild:
+ @staticmethod
+ def _inter_merge(a: list[MicroBlockSim], b: list[MicroBlockSim], delta: int = 0) -> list[MicroBlockSim]:
+ res = []
+ if delta >= 0:
+ res.extend(a[:delta])
+ a = a[delta:]
+ else:
+ res.extend(b[:-delta])
+ b = b[-delta:]
+ stable_count = 0
+ while len(a):
+ block = a.pop(0)
+ block.phase = 'stable'
+ res.append(block)
+ stable_count += 1
+ if len(b):
+ block = b.pop(0)
+ block.phase = 'stable'
+ res.append(block)
+ stable_count += 1
+ else:
+ break
+ if stable_count:
+ res[-1].phase = 'cooldown'
+ if len(a):
+ res.extend(a)
+ elif len(b):
+ res.extend(b)
+ return res
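+    # Worked example (illustrative): _inter_merge([f0, f1, f2, f3], [b0, b1, b2, b3], delta=2)
+    # keeps the first 2 forward blocks as warm-up and then interleaves 1F1B, yielding
+    # [f0, f1, f2, b0, f3, b1, b2, b3]; the interleaved blocks are tagged 'stable' and
+    # the last of them (b1) is re-tagged 'cooldown'.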
+
+ @staticmethod
+ def _build_chain(line: list[MicroBlockSim], p: int) -> list[BlockSim]:
+ head = HeadBlockSim(p)
+ left = head
+ for item in line:
+ left.right = item
+ item.left = left
+ left = item
+ if p == 0:
+ head.right.pre = head
+ return line
+
+ @staticmethod
+ def build_1f1b(pp, micro_num, p, forward_time, backward_time, block_mem) -> list[BlockSim]:
+ for_line = [MicroBlockSim(p, 'f', i, 0, forward_time, mem=block_mem) for i in range(micro_num)]
+ back_line = [MicroBlockSim(p, 'b', i, 0, backward_time, mem=block_mem) for i in range(micro_num)]
+ line = PipelineBuild._inter_merge(for_line, back_line, pp - p - 1)
+ return PipelineBuild._build_chain(line, p)
+
+ @staticmethod
+    def build_virtualpipeline(pp, micro_num, vp, p, forward_time, backward_time, block_mem) -> list[BlockSim]:
+        for_line = []
+        back_line = []
+        r = micro_num % pp
+        for inter in range(micro_num // pp):
+            for i in range(vp):
+                if inter == 0:
+                    for_line.extend([MicroBlockSim(p, 'f', m, i, forward_time[i], mem=block_mem[i], phase='warmup') for m in range(r)])
+                    back_line.extend([MicroBlockSim(p, 'b', m, i, backward_time[i], mem=block_mem[i], phase='cooldown') for m in range(r)])
+                for_line.extend([MicroBlockSim(p, 'f', r + m + inter * pp, i, forward_time[i], mem=block_mem[i], phase='warmup') for m in range(pp)])
+                back_line.extend([MicroBlockSim(p, 'b', r + m + inter * pp, i, backward_time[i], mem=block_mem[i], phase='cooldown') for m in range(pp)])
+        line = PipelineBuild._inter_merge(for_line, back_line, (vp + 1) * pp - 2 * p - 2 + r * (vp - 1))
+        return PipelineBuild._build_chain(line, p)
+
+ @staticmethod
+ def build_virtualpipeline2(pp, micro_num, vp, p, forward_time, backward_time, block_mem):
+ for_line = []
+ back_line = []
+ r = micro_num % pp
+ for inter in range(micro_num // pp):
+ for i in range(vp):
+ if inter == 0:
+ for_line.extend([MicroBlockSim(p, 'f', m, i, forward_time[i], mem=block_mem[i]) for m in range(r)])
+ back_line.extend([MicroBlockSim(p, 'b', m, i, backward_time[i], mem=block_mem[i]) for m in range(r)])
+ for_line.extend([MicroBlockSim(p, 'f', r + m + inter * pp, i, forward_time[i], mem=block_mem[i]) for m in range(pp)])
+ back_line.extend([MicroBlockSim(p, 'b', r + m + inter * pp, i, backward_time[i], mem=block_mem[i]) for m in range(pp)])
+ line = PipelineBuild._inter_merge(for_line, back_line, vp * pp - p - 1)
+ return PipelineBuild._build_chain(line, p)
+
+
+class PipelineSimulator:
+ def __init__(self, block_time:list, micro_num:int, comm_time:float = 0.1, layer_recompute=False, block_mem=1, backward_ratio=2.):
+ self.init(block_time, micro_num, comm_time, layer_recompute, block_mem, backward_ratio)
+
+ def init(self, block_time, micro_num, comm_time, layer_recompute, block_mem, backward_ratio, *args, **kwargs):
+ self.micro_num = micro_num
+ self.pp, self.vp = self._base_init(block_time)
+ self.block_num = 2 * self.vp * self.micro_num
+ self.comm_time = comm_time
+ self._input_format(block_time, layer_recompute, block_mem, backward_ratio)
+ self._statistic_init()
+ self._comm = True
+ self.adjust_func_list = [self.swap_send_rec]
+ # Construct pipeline blocks
+ if self.vp == 1:
+ self.blocks = [PipelineBuild.build_1f1b(self.pp, self.micro_num, p,
+ self.block_time[0, p],
+ self.backward_time[0, p],
+ self.block_mem[0, p]) for p in range(self.pp)]
+ else:
+ self.blocks = [PipelineBuild.build_virtualpipeline(self.pp, self.micro_num, self.vp, p,
+ self.block_time[:, p],
+ self.backward_time[:, p],
+ self.block_mem[:, p]) for p in range(self.pp)]
+ if self.micro_num >= self.pp:
+ self.adjust_func_list = [self.vpp_send_delay, self.residue_delay] + self.adjust_func_list
+ self._build_block() # create connection among compute blocks
+ self._build_comm_block() # create comm blocks for each compute block
+
+ def run(self, comm=True, print_info=True) -> PipelineSimulator:
+ self._comm = comm
+ self._check_loop()
+ if comm:
+ self.lines = self._create_lines(*self.adjust_func_list)
+ self._check_comm_loop()
+ for b in range(self.block_num):
+ for p in range(self.pp):
+ self.blocks[p][b].build_with_comm()
+ self.lines[0][-1].build_with_comm()
+ else:
+ for p in range(self.pp):
+ for block in self.blocks[p]:
+ block.build_without_comm()
+ self._statistic_info()
+ if print_info:
+ self.print_info()
+ return self
+
+ def show(self, comm=True, connect=None) -> PipelineSimulator:
+ self.canvas = PlotMgr(2, ['block', 'memory'])
+ if self._comm:
+ connect = True if connect is None else connect
+ self.canvas.draw(self.lines, 0, comm, connect, False, 'timeline')
+ else:
+ connect = False if connect is None else connect
+ self.canvas.draw(self.blocks, 0, comm, connect, False, 'timeline')
+ self.canvas.draw_mem(self.states['block_mem_list'], 1)
+ self.canvas.draw_info(self.bubbles, self.peak_memory)
+ self.canvas.show()
+ return self
+
+ def print_info(self):
+ print('\033[1;37m' + '—'*13, f'pp:{self.pp:>2}, vp:{self.vp:>2}, micro:{self.micro_num:>3} ' + '—'*12 + '\033[0m')
+ print('-'*20, ' bubble ', '-'*20)
+ print(apply_format(apply_color(list(self.bubbles.keys()), ['1;33', '1;32', '1;31', '1;35', '1;36'])))
+ print(apply_format(apply_color(list(self.bubbles.values()), ['1;33', '1;32', '1;31', '1;35', '1;36'])))
+ print('-'*20, ' memory ', '-'*20)
+ print(f"peak memory: {', '.join([f'{v:.2f}' for v in self.peak_memory])}")
+ return self
+
+ def _base_init(self, block_time) -> tuple:
+ if isinstance(block_time, (list, tuple)):
+ if all(isinstance(item, (list, tuple)) for item in block_time):
+ vp = len(block_time)
+ pp = len(block_time[0])
+ elif all(isinstance(item, (int, float)) for item in block_time):
+ vp = 1
+ pp = len(block_time)
+ else:
+ raise ValueError(f"Unsupported input format block_time: {block_time}")
+ else:
+ raise ValueError(f"Unsupported input format block_time: {block_time}")
+ if self.micro_num < pp:
+ raise ValueError(f" `micro_num`({self.micro_num}) should equal or larger than `pp`({pp})")
+ return pp, vp
+
+ def _input_format(self, block_time, layer_recompute, block_mem, backward_ratio) -> None:
+ self.block_time = format_2d_inputs(block_time, self.vp, self.pp)
+ if isinstance(layer_recompute, bool):
+ self.layer_recompute = self.block_time if layer_recompute else format_2d_inputs(0, self.vp, self.pp)
+ else:
+ self.layer_recompute = format_2d_inputs(layer_recompute, self.vp, self.pp)
+ if isinstance(block_mem, (int, float)):
+ self.block_mem = self.block_time * block_mem
+ else:
+ self.block_mem = format_2d_inputs(block_mem, self.vp, self.pp)
+ self.backward_ratio = format_2d_inputs(backward_ratio, self.vp, self.pp)
+
+ def _statistic_init(self) -> None:
+ self.forward_time = self.block_time
+ self.backward_time = np.flip(self.block_time * self.backward_ratio + self.layer_recompute, axis=0)
+ self.states = {'last_time': np.zeros(self.pp),
+ 'warmup_time': np.zeros(self.pp),
+ 'cooldown_time': np.zeros(self.pp),
+ 'stable_free_time': (np.zeros((self.vp, self.pp)), np.zeros((self.vp, self.pp))),
+ 'block_mem_list': [np.array([[0, 0]]) for _ in range(self.pp)]}
+ self.model_compute_time = (np.sum(self.forward_time) + np.sum(self.forward_time * self.backward_ratio)) * self.micro_num
+ self.hardware_compute_time = (np.sum(self.forward_time) + np.sum(self.backward_time)) * self.micro_num
+ self.bubbles = {'real': 0,
+ 'ideal': (self.pp - 1) / self.vp / self.micro_num,
+ 'imba': 0,
+ 'comm': 0}
+ if np.sum(self.layer_recompute) > 1e-5:
+ self.bubbles['recompute'] = self.hardware_compute_time / self.model_compute_time - 1
+ p, v, m = self.pp, self.vp, self.micro_num
+ if self.vp == 1:
+ if self.pp == 2:
+ self.bubbles['comm'] = 4*m
+            elif self.pp % 2 == 0:
+ self.bubbles['comm'] = 4*p*m + 4*p**2 - 14*p
+ else:
+ self.bubbles['comm'] = 4*p*m + 4*p**2 - 12*p
+ elif self.pp <= 5:
+ comm_coef_list = [[4, -2, 0], [6, -2, -6], [4, 0, 12], [6, -2, 40]]
+ self.bubbles['comm'] = np.dot(np.array([p*v*m, m*p, 1]), comm_coef_list[self.pp - 2])
+ elif self.pp % 2 == 0:
+ self.bubbles['comm'] = 4*p*v*m + 4*p**2 - 13*p
+ else:
+            self.bubbles['comm'] = 6*p*v*m - 2*v*p**2 + 4*v*p - 2*p*m + 6*p**2 - 16*p
+
+ self.bubbles['comm'] *= self.comm_time / self.model_compute_time
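+        # e.g. pp=4, vp=2, micro=8 falls into the `pp <= 5` branch with coefficients
+        # [4, 0, 12], giving 4*4*2*8 + 0 + 12 = 268 comm units before the scaling by
+        # comm_time / model_compute_time (illustrative numbers, not from the source)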
+
+ def _statistic_info(self) -> None:
+ for p in range(self.pp):
+ blocks = self.lines[p] if self._comm else self.blocks[p]
+ current_mem = 0
+ for block in blocks:
+ if block._type == 'c' and block._state == 'f':
+ current_mem += block.mem
+ elif block._type == 'c' and block._state == 'b':
+ if not self._comm or not block.rec_block:
+ current_mem -= block.mem
+ elif block._type == 'r' and block.host._state == 'b':
+ current_mem -= block.host.mem
+ block = block.host
+ else:
+ continue
+ self.states['block_mem_list'][p] = np.append(self.states['block_mem_list'][p], np.array([[block.end, current_mem]]), axis=0)
+ self.states['block_mem_list'][p] = np.append(self.states['block_mem_list'][p], np.array([[blocks[-1].end, current_mem]]), axis=0)
+ self.peak_memory = [np.max((self.states['block_mem_list'][p].T)[1]) for p in range(self.pp)]
+ self.end_time = max(np.max((self.states['block_mem_list'][p].T)[0]) for p in range(self.pp))
+ self.bubbles['real'] = (self.pp * self.end_time - self.model_compute_time) / self.model_compute_time
+ self.bubbles['imba'] = self.bubbles['real'] - self.bubbles['ideal'] + 1e-10
+ if not self._comm:
+ self.bubbles.pop('comm')
+ else:
+ self.bubbles['imba'] -= self.bubbles['comm']
+ if self.bubbles.get('recompute'):
+ self.bubbles['imba'] -= self.bubbles['recompute']
+
+ def _get_pre_label(self, label:tuple) -> tuple:
+ t, s, m, v, p = label
+ if (s, v, p) == ('f', 0, 0):
+ return ('h', p)
+ if (s, p) == ('f', 0):
+ return (t, s, m, v - 1, self.pp - 1)
+ if (s, p) == ('b', self.pp - 1):
+ if v == 0:
+ return (t, 'f', m, self.vp - 1, p)
+ return (t, s, m, v - 1, 0)
+ if s == 'f':
+ return (t, s, m, v, p - 1)
+ if s == 'b':
+ return (t, s, m, v, p + 1)
+ raise ValueError(f"Illegal label: {label}")
+
+ def _build_block(self) -> None:
+ r"""Build `pre` relation for computation blocks."""
+ books = {self.blocks[0][0].pre.label: self.blocks[0][0].pre}
+ for p in range(self.pp):
+ for item in self.blocks[p]:
+ books[item.label] = item
+ for p in range(self.pp):
+ block = self.blocks[p][0]
+ while block:
+ pre_label = self._get_pre_label(block.label)
+ block.pre = books[pre_label]
+ block = block.right
+
+ def _build_comm_block(self) -> None:
+ r"""Build `send_block` and `rec_block` relation among a computation block and two comm blocks."""
+ for p in range(self.pp):
+ block = self.blocks[p][0]
+ while block:
+ pre = block.pre
+ if pre._stage != block._stage:
+ block.rec_block = RecBlockSim(p, block._state, block._id, block._virtual, self.comm_time)
+ pre.send_block = SendBlockSim(pre._stage, pre._state, pre._id, pre._virtual, self.comm_time)
+ block.rec_block.host = block
+ block.rec_block.dual = pre.send_block
+ pre.send_block.host = pre
+ pre.send_block.dual = block.rec_block
+ block.depend_pre = block.rec_block
+ block.rec_block.depend_pre = pre.send_block
+ pre.send_block.depend_pre = pre
+ else:
+ block.depend_pre = pre
+ block = block.right
+
+ def _check_loop(self) -> None:
+ loop = self.blocks[0][-1].loop()
+ if loop:
+ raise CausalError('Block dependency exist loops!', self.blocks, loop)
+ for p in range(self.pp):
+ for block in self.blocks[p]:
+ block._flag = False
+
+ def _check_comm_loop(self) -> None:
+ loop = self.lines[0][-1].comm_loop()
+ if loop:
+ raise CausalCommError('Block comm dependency exist loops!', self.lines, loop)
+ for p in range(self.pp):
+ for block in self.lines[p]:
+ block._flag = False
+
+ def _create_lines(self, *adjust_func) -> list[list[BlockSim]]:
+ lines = [copy.copy(self.blocks[p]) for p in range(self.pp)]
+ for p in range(self.pp):
+ for b in range(self.block_num):
+ block = self.blocks[p][b]
+ pre = block.pre
+ if block.rec_block:
+ lines[p].insert(lines[p].index(block), block.rec_block)
+ if pre._type == 'h':
+ lines[pre._stage].insert(0, pre.send_block)
+ else:
+ lines[pre._stage].insert(lines[pre._stage].index(pre) + 1, pre.send_block)
+ for func in adjust_func:
+ lines = func(lines)
+ for p in range(self.pp):
+ for b, block in enumerate(lines[p]):
+ if b == 0:
+ block.depend_left = block.left if block.left else block.host.left
+ else:
+ block.depend_left = lines[p][b - 1]
+ return lines
+
+ def _get_block_phase(self, p:int, b:int) -> str:
+ r = self.micro_num % self.pp
+ if b < (self.vp + 1) * self.pp - 2 * p - 2 + r:
+ return 'warmup'
+ if b > self.block_num - (self.vp + 1) * self.pp + 2 * p:
+ return 'cooldown'
+ return 'stable'
+
+ def _send_block_delay(self, lines, p:int, b:int, distance:int) -> None:
+ i_send = lines[p].index(self.blocks[p][b].send_block)
+ send_block = lines[p].pop(i_send)
+ i_new = lines[p].index(self.blocks[p][b + distance]) + 1
+ lines[p].insert(i_new, send_block)
+
+ def swap_send_rec(self, lines) -> list[list[BlockSim]]:
+ for p in range(self.pp):
+ for b, block in enumerate(self.blocks[p]):
+ if b >= len(self.blocks[p]) - 1:
+ continue
+ i_b = lines[p].index(block)
+ i_bn = lines[p].index(self.blocks[p][b + 1])
+ if i_bn - i_b == 3:
+ if p % 2 == 0 and lines[p][i_b+1]._type == 'r' and lines[p][i_b+2]._type == 's':
+ lines[p][i_b+1], lines[p][i_b+2] = lines[p][i_b+2], lines[p][i_b+1]
+ if p % 2 == 1 and lines[p][i_b+1]._type == 's' and lines[p][i_b+2]._type == 'r':
+ if block.phase == 'warmup' and self.blocks[p][b + 1].phase == 'cooldown':
+ continue
+ lines[p][i_b+1], lines[p][i_b+2] = lines[p][i_b+2], lines[p][i_b+1]
+ if lines[p][i_b+1].dual._stage == lines[p][i_b+2].dual._stage:
+ pd = lines[p][i_b+1].dual._stage
+ j_b1 = lines[pd].index(lines[p][i_b+1].dual)
+ j_b2 = lines[pd].index(lines[p][i_b+2].dual)
+ if j_b1 > j_b2:
+ lines[p][i_b+1], lines[p][i_b+2] = lines[p][i_b+2], lines[p][i_b+1]
+ self.diff_4_swap(i_b, i_bn, lines, p)
+ return lines
+
+ def diff_4_swap(self, i_b, i_bn, lines, p):
+ if i_bn - i_b == 4:
+ if lines[p][i_b + 1].dual._stage == lines[p][i_b + 2].dual._stage and lines[p][i_b + 2].dual._stage == \
+ lines[p][i_b + 3].dual._stage:
+ if lines[p][i_b + 1]._type == 's' and lines[p][i_b + 2]._type == 's' and lines[p][i_b + 3]._type == 'r':
+ lines[p][i_b + 1], lines[p][i_b + 2] = lines[p][i_b + 2], lines[p][i_b + 1]
+
+ def vpp_send_delay(self, lines) -> list[list[BlockSim]]:
+ if self.micro_num % self.pp != 0:
+ return lines
+ for p in range(self.pp):
+ for b, block in enumerate(self.blocks[p]):
+ if block.send_block is not None and block.phase == 'stable':
+ self._send_block_delay(lines, p, b, 1)
+ return lines
+
+ def residue_delay(self, lines) -> list[list[BlockSim]]:
+ r = self.micro_num % self.pp
+ if not r:
+ return lines
+ for p in range(self.pp):
+ for b, block in enumerate(self.blocks[p]):
+ if block.send_block is None:
+ continue
+ if p == self.pp - 1 and block._id < self.pp + r and block._state == 'f':
+ self._send_block_delay(lines, -1, b, r + max(0, block._id - self.pp + 1))
+ elif p == 0 and block._id < self.pp + r and block._state == 'b':
+ if self.micro_num // self.pp == 1:
+ self._send_block_delay(lines, 0, b, r)
+ else:
+ self._send_block_delay(lines, 0, b, r + self.pp)
+ elif block.phase == 'stable':
+ self._send_block_delay(lines, p, b, 1)
+ return lines
+
+def test_comm_imba_zero():
+ for pp in range(2, 10):
+ for m in range(pp, 3*pp+1): # check vp=1 imba
+ sim = PipelineSimulator(np.ones(pp).tolist(), m, comm_time=0.1).run(print_info=False)
+ if sim.bubbles['imba'] > 0.001:
+ sim.print_info()
+ for vp in range(2, 6): # check vp>1 imba
+ for m in range(pp, 4*pp+1, pp):
+ sim = PipelineSimulator(np.ones((vp, pp)).tolist(), m, comm_time=0.1).run(print_info=False)
+ if sim.bubbles['imba'] > 0.001:
+ sim.print_info()
+
+'''
+Known limitations:
+    1. With comm and vpp enabled, micro_num must be divisible by pp, otherwise dependency loops may occur;
+'''
+
+if __name__ == '__main__':
+ layer_dis = [[2*0.6, 1.6, 1, 3, 2, 3, 4, 2], [6, 6, 6, 4, 6, 5, 5, 5+1.5]]
+ recompute = [[2*0.6, 2, 1, 3, 2, 3, 4, 2], [6, 6, 6, 4, 6, 4, 5, 0]]
+ PipelineSimulator(layer_dis, 128, 0.07, recompute).run(False).show(False, False)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_util.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c10871fcf16caedc907b37f4b072d42d63db237
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_util.py
@@ -0,0 +1,521 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""pp utility"""
+import numpy as np
+import time
+import re
+import sys
+import subprocess
+import argparse
+import chardet
+import highspy
+import yaml
+import ast
+import os
+import shlex
+
+from fast_tuner.utils.logger import logger
+
+pipeline_output_file = 'pipeline_output'
+dryrun_yaml_dir = 'dryrun_yaml'
+dryrun_shell_dir = 'dryrun_shell'
+dryrun_error = 'Dryrun failed, please check the mindspore environment!'
+
+def update_yaml_value(yaml_file, key, value):
+ with open(yaml_file, 'r', encoding='utf-8') as file:
+ yaml_data = yaml.safe_load(file)
+ if key in yaml_data:
+        print(f"found {key}, updating the value")
+    else:
+        print(f"cannot find {key}")
+ yaml_data[key] = value
+ with open(yaml_file, 'w', encoding='utf-8') as file:
+ yaml.dump(yaml_data, file, default_flow_style=False)
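+
+# A minimal usage sketch (hypothetical file and key):
+#   update_yaml_value('pretrain.yaml', 'micro_batch_interleave_num', 2)
+# Note that this rewrites the whole file via yaml.dump, so comments and anchors
+# in the original YAML are not preserved.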
+
+
+def highs_solve_mps(mps_file, solution_file, origin_model, time_limit):
+ highs_solver = highspy.Highs()
+ highs_solver.readModel(mps_file)
+ if time_limit != sys.maxsize:
+ highs_solver.setOptionValue('time_limit', time_limit)
+ # highs_solver.setOptionValue('threads', 8)
+
+ highs_solver.run()
+
+ info = highs_solver.getInfo()
+ origin_model.solution_status = highs_solver.solutionStatusToString(info.primal_solution_status)
+ logger.info(f'Solution status = {origin_model.solution_status}')
+ origin_model.model_status = highs_solver.modelStatusToString(highs_solver.getModelStatus())
+ logger.info(f'Model status = {origin_model.model_status}')
+ origin_model.run_time = highs_solver.getRunTime()
+ logger.info(f'Run time = {origin_model.run_time}')
+ origin_model.gap = info.mip_gap
+ logger.info(f'Gap = {origin_model.gap * 100:.2f}%')
+ logger.info(f'Optimal objective = {info.objective_function_value}')
+ origin_model.min_time = info.objective_function_value
+ if origin_model.model_status == 'Time limit reached' and origin_model.gap > 0:
+ logger.info(f'Dual bound = {info.mip_dual_bound}')
+ if origin_model.solution_status == 'None':
+ if origin_model.model_status == 'Infeasible':
+ logger.error(f'{mps_file} is Infeasible, Please check the memory limit!')
+ elif origin_model.model_status == 'Time limit reached':
+ logger.error(f'{mps_file} is not finished!, Please check the time limit!')
+ else:
+ logger.error(f'{mps_file} is no solution, model_status = {origin_model.model_status}!')
+ highs_solver.writeSolution(solution_file, 0)
+ logger.info(f'Writing the solution to {solution_file}')
+
+
+def qiuqi_solver_mps(solver_file, model_file, solution_file_name, origin_model):
+ command0 = ['chmod', '+x', solver_file]
+ subprocess.run(command0)
+ command1 = [solver_file, '-Help']
+ subprocess.run(command1, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ command2 = [solver_file, f'-SolutionFile={solution_file_name}', model_file]
+ subprocess.run(command2)
+ logger.info(f'Writing the solution to {solution_file_name}')
+ with open(solution_file_name, 'r', encoding='utf-8') as file:
+ content = file.read()
+ result_content = re.search(r'# Objective: (\d+\.\d+)', content)
+ objective_function_value = result_content.group(1)
+ logger.info(f'Optimal objective = {objective_function_value}')
+ origin_model.min_time = objective_function_value
+
+
+def write_config_to_yaml(recompute_config, offset, yaml_file):
+ with open(yaml_file, 'r', encoding='utf-8') as file:
+ data = yaml.safe_load(file)
+ data['recompute_config'] = recompute_config
+ data['model']['model_config']['offset'] = offset
+ with open(yaml_file, 'w', encoding='utf-8') as file:
+ yaml.dump(data, file, default_flow_style=False, indent=4)
+
+
+def write_config_to_shell(offset, shell_file):
+ configs, unparses = parse_shell(shell_file)
+ layer_num = configs.get('NUM_LAYERS') // configs.get('PP')
+ layer_list = flatten_to_str(offset, layer_num)
+ configs['NUM_LAYERS_LIST'] = layer_list
+ configs_to_shell(shell_file, configs, unparses)
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def str2list(v):
+ if isinstance(v, list):
+ return v
+ try:
+ value = ast.literal_eval(v)
+ except (ValueError, SyntaxError) as err:
+ raise argparse.ArgumentTypeError('List value expected.') from err
+ if isinstance(value, list):
+ return value
+ raise argparse.ArgumentTypeError('List value expected.')
+
+
+def str2int(v):
+ if isinstance(v, int):
+ return v
+ try:
+ return int(v)
+ except (TypeError, ValueError) as err:
+ raise argparse.ArgumentTypeError('Int value expected.') from err
+
+
+def str2dict(v):
+ if isinstance(v, dict):
+ return v
+ try:
+ value = ast.literal_eval(v)
+ except (ValueError, SyntaxError) as err:
+ raise argparse.ArgumentTypeError('Dict value expected.') from err
+ if isinstance(value, dict):
+ return value
+ raise argparse.ArgumentTypeError('Dict value expected.')
+
+
+
+def build_new_config_yaml(args):
+ with open(args.yaml, 'r', encoding='utf-8') as file:
+ data = yaml.safe_load(file)
+ if args.offset:
+ data['model']['model_config']['offset'] = args.offset
+ logger.info(f'change offset to {args.offset}')
+ recompute_config = build_recompute_config(args.is_select_recompute, args.is_recompute,
+ args.select_recompute_layers, args.recompute_layers)
+ if args.is_select_recompute is None:
+ logger.info(f'is_select_recompute is None')
+ else:
+ data['recompute_config']['select_recompute'] = recompute_config['select_recompute']
+ data['recompute_config']['select_comm_recompute'] = recompute_config['select_comm_recompute']
+ logger.info(f'change is_select_recompute to {args.is_select_recompute}')
+ logger.info(f'change select_recompute_layers to {args.select_recompute_layers}')
+ if args.is_recompute is None:
+ logger.info(f'is_recompute is None')
+ else:
+ data['recompute_config']['recompute'] = recompute_config['recompute']
+ logger.info(f'change is_recompute to {args.is_recompute}')
+ logger.info(f'change recompute_layers to {args.recompute_layers}')
+ file_name, file_ext = os.path.splitext(os.path.basename(args.yaml))
+ yaml_path = os.path.dirname(os.path.abspath(args.yaml))
+ timestamp = time.strftime("%Y%m%d%H%M%S")
+ new_yaml = os.path.join(yaml_path, f'{file_name}_{timestamp}.yaml')
+ with open(new_yaml, 'w', encoding='utf-8') as file:
+ yaml.dump(data, file, default_flow_style=False, indent=4)
+ logger.info(f'build new yaml {new_yaml}')
+ return new_yaml
+
+
+def build_new_config_shell(args):
+ configs, unparses = parse_shell(args.shell)
+ layer_num = configs.get('NUM_LAYERS') // configs.get('PP')
+ layer_list = flatten_to_str(args.offset, layer_num)
+ configs['NUM_LAYERS_LIST'] = layer_list
+ logger.info(f'change offset to {args.offset}')
+ file_name, file_ext = os.path.splitext(os.path.basename(args.shell))
+ shell_path = os.path.dirname(os.path.abspath(args.shell))
+ timestamp = time.strftime("%Y%m%d%H%M%S")
+ new_shell = os.path.join(shell_path, f'{file_name}_{timestamp}.sh')
+ configs_to_shell(new_shell, configs, unparses)
+ logger.info(f'build new shell {new_shell}')
+ return new_shell
+
+
+def build_recompute_config(is_select_recompute, is_recompute, select_recompute_layers, recompute_layers):
+ recompute_config = {
+ 'parallel_optimizer_comm_recompute': False,
+ 'mp_comm_recompute': True,
+ 'recompute_slice_activation': True
+ }
+
+ if is_select_recompute:
+ # select the operators of the corresponding layers for recomputation
+ recompute_config['select_recompute'] = {}
+ recompute_config['select_recompute'][r'feed_forward\.null'] = select_recompute_layers
+ recompute_config['select_recompute'][r'feed_forward\.w1\.activation\.silu'] = select_recompute_layers
+ recompute_config['select_recompute'][r'feed_forward\.w1\.reshape'] = select_recompute_layers
+ recompute_config['select_recompute'][r'feed_forward\.w2\.reshape'] = select_recompute_layers
+ recompute_config['select_recompute'][r'add'] = select_recompute_layers
+ recompute_config['select_recompute'][r'cast_up'] = select_recompute_layers
+ # select the operators of the corresponding layers for communication recomputation
+ recompute_config['select_comm_recompute'] = {}
+ recompute_config['select_comm_recompute'][r'.*\.norm'] = select_recompute_layers
+ recompute_config['select_comm_recompute'][r'attention\.wq\.reshape'] = select_recompute_layers
+ recompute_config['select_comm_recompute'][r'attention\.wk\.reshape'] = select_recompute_layers
+ recompute_config['select_comm_recompute'][r'feed_forward\.w1\.reshape'] = select_recompute_layers
+ recompute_config['select_comm_recompute'][r'feed_forward\.w3\.reshape'] = select_recompute_layers
+ elif is_select_recompute is False:
+ recompute_config['select_recompute'] = False
+ recompute_config['select_comm_recompute'] = False
+
+ if is_recompute:
+ recompute_config['recompute'] = recompute_layers
+ elif is_recompute is False:
+ recompute_config['recompute'] = False
+
+ return recompute_config
+
+
+def construct_distribution(num_micro, num_stage):
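+ # split the micro batches into groups of num_stage; any remainder is folded into the first group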
+ parts, remainder = num_micro // num_stage, num_micro % num_stage
+ distribution = []
+ for part in range(parts):
+ if part == 0:
+ distribution.append(num_stage + remainder)
+ else:
+ distribution.append(num_stage)
+ return distribution
+
+
+def sort_micro(parts, num_vpp, num_stage, distribution, low_mem, seq_split, is_f_then_b: bool = False):
+ forward = []
+ backward = []
+ final_orders = []
+ for part in range(parts):
+ for vpp in range(num_vpp):
+ for micro_id in range(distribution[part]):
+ for split in range(seq_split):
+ forward.append((part, vpp, 'f', micro_id, split))
+ for vpp in range(num_vpp - 1, -1, -1):
+ for micro_id in range(distribution[part]):
+ for split in range(seq_split - 1, -1, -1):
+ backward.append((part, vpp, 'b', micro_id, split))
+ # f-then-b scheduling rule, not enabled yet
+ if is_f_then_b:
+ for stage in range(num_stage):
+ stage_order = []
+ for part in range(parts):
+ for micro_id in range(distribution[part]):
+ stage_order.append((part, 0, 'f', micro_id, 0))
+ for micro_id in range(distribution[part]):
+ stage_order.append((part, 0, 'b', micro_id, 0))
+ final_orders.append(stage_order)
+ return final_orders
+
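+ # interleaved 1F1B schedule: every stage first runs `warmup` forward micro batches, then alternates one
+ # forward with one backward, and finally drains the remaining backwards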
+ for stage in range(num_stage):
+ if low_mem:
+ warmup = min(((num_vpp - 1) * distribution[0] + (num_stage - stage - 1)) * seq_split, len(forward))
+ else:
+ warmup = min(((num_vpp - 1) * distribution[0] + (num_stage - stage - 1) * 2) * seq_split, len(forward))
+ # on the last stage, the backward can only start after the forward of the first micro batch finishes
+ if stage == num_stage - 1:
+ warmup = warmup + seq_split - 1
+ stage_order = []
+ stage_order += forward[: warmup]
+ for i in range(warmup, len(forward)):
+ stage_order.append(forward[i])
+ stage_order.append(backward[i - warmup])
+ stage_order += backward[len(forward) - warmup:]
+ final_orders.append(stage_order)
+ return final_orders
+
+
+def get_init_input_peak_mem():
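+ # placeholder: return manually measured per-stage peak memory here when the automatic dryrun is skipped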
+ return []
+
+
+def get_state_input_peak_mem():
+ return []
+
+
+def get_ranks_stages(yaml_file):
+ with open(yaml_file, 'r', encoding='utf-8') as file:
+ data = yaml.safe_load(file)
+ pipeline_stage = data['parallel_config']['pipeline_stage']
+ model_parallel = data['parallel_config']['model_parallel']
+ data_parallel = data['parallel_config']['data_parallel']
+ rank_size = pipeline_stage * model_parallel * data_parallel
+ return rank_size, pipeline_stage
+
+
+def get_shell_ranks_stages(shell_file):
+ configs, unparses = parse_shell(shell_file)
+ pipeline_stage = configs.get('PP')
+ model_parallel = configs.get('TP')
+ data_parallel = configs.get('DP')
+ rank_size = pipeline_stage * model_parallel * data_parallel
+ return rank_size, pipeline_stage
+
+
+def get_layers_distribution(offset, num_layers, num_stages, num_vpp):
+ layers_stage_vpp = [[0 for _ in range(num_vpp)] for _ in range(num_stages)]
+ for stage in range(num_stages):
+ for vpp in range(num_vpp):
+ layers_stage_vpp[stage][vpp] = offset[vpp][stage] + num_layers // (num_stages * num_vpp)
+ return layers_stage_vpp
+
+
+def find_most_times_stage(layer_num_of_stage):
+ # key: the layer count of a stage; value: the list of stage indices that have that layer count
+ frequency_dict = {}
+ layer_num = layer_num_of_stage[0]
+ max_time = 0
+ for i in range(len(layer_num_of_stage)):
+ if layer_num_of_stage[i] not in frequency_dict:
+ frequency_dict[layer_num_of_stage[i]] = []
+ frequency_dict[layer_num_of_stage[i]].append(i)
+ if len(frequency_dict[layer_num_of_stage[i]]) > max_time:
+ max_time = len(frequency_dict[layer_num_of_stage[i]])
+ layer_num = layer_num_of_stage[i]
+ return frequency_dict[layer_num]
+
+
+def extract_peak_memory(file_path):
+ with open(file_path, 'rb') as file:
+ raw_data = file.read()
+ result = chardet.detect(raw_data)
+ encoding = result['encoding']
+ with open(file_path, 'r', encoding=encoding) as file:
+ content = file.read()
+ result_content = re.search(r'Used peak memory usage \(without fragments\):\s*(\d+)M', content)
+ if result_content:
+ return int(result_content.group(1))
+ else:
+ raise ValueError(dryrun_error)
+
+
+def extract_actual_peak_memory(file_path):
+ with open(file_path, 'rb') as file:
+ raw_data = file.read()
+ result = chardet.detect(raw_data)
+ encoding = result['encoding']
+ with open(file_path, 'r', encoding=encoding) as file:
+ content = file.read()
+ result_content = re.search(
+ r'Actual peak memory usage \(with fragments\):\s*(\d+)M', content)
+ if result_content:
+ return int(result_content.group(1))
+ else:
+ raise ValueError(dryrun_error)
+
+
+def get_peak_batch(micro_batch_num, num_stage, num_vpp, low_mem, offset, num_layers):
+ distribution = construct_distribution(micro_batch_num, num_stage)
+ final_orders = sort_micro(micro_batch_num // num_stage, num_vpp, num_stage, distribution, low_mem, 1)
+ layers_stage_vpp = get_layers_distribution(offset, num_layers, num_stage, num_vpp)
+ # the number of activation copies held when memory peaks
+ peaks = [0] * num_stage
+ for stage in range(num_stage):
+ cur_mem = 0
+ for micro in final_orders[stage]:
+ if micro[2] == 'f':
+ cur_mem += layers_stage_vpp[stage][micro[1]]
+ else:
+ cur_mem -= layers_stage_vpp[stage][micro[1]]
+ peaks[stage] = max(peaks[stage], cur_mem)
+ return peaks
+
+
+def build_coe_array(peak_num_act, peak_num_select_recom, peak_num_recompute, x, stage_range):
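+ # one row per stage: [1, peak activation count, layer count, full-recompute peak count,
+ # select-recompute peak count], used as the coefficient matrix of the least-squares memory fit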
+ num_stage = len(peak_num_act)
+ layer_dis = [sum(x[vpp][stage] for vpp in range(1)) for stage in range(num_stage)]
+ coe_a = np.empty((stage_range[-1] - stage_range[0], 5), float)
+ for stage in range(stage_range[0], stage_range[-1]):
+ coe_a[stage - stage_range[0]][0] = 1
+ coe_a[stage - stage_range[0]][1] = peak_num_act[stage]
+ coe_a[stage - stage_range[0]][2] = layer_dis[stage]
+ coe_a[stage - stage_range[0]][3] = peak_num_recompute[stage]
+ coe_a[stage - stage_range[0]][4] = peak_num_select_recom[stage]
+ return coe_a
+
+
+def bulid_yaml(old_yaml_file, recompute_config, offset, num_layers, num_vpp, num_stage, dense_layers, micro):
+ with open(old_yaml_file, 'r', encoding='utf-8') as file:
+ data = yaml.safe_load(file)
+ data['recompute_config'] = recompute_config
+ data['model']['model_config']['offset'] = offset
+ if 'mtp_depth' in data['model']['model_config']:
+ mtp_depth = data['model']['model_config']['mtp_depth']
+ num_layers -= mtp_depth
+ data['model']['model_config']['num_layers'] = num_layers
+ if 'moe_config' in data:
+ if 'first_k_dense_replace' in data['moe_config']:
+ data['moe_config']['first_k_dense_replace'] = dense_layers
+ data['parallel_config']['pipeline_stage'] = num_stage
+ data['parallel_config']['micro_batch_num'] = micro
+ if 'pp_interleave_num' in data['model']['model_config']:
+ data['model']['model_config']['pp_interleave_num'] = num_vpp
+ pipeline_output = os.path.join(os.getcwd(), pipeline_output_file)
+ if not os.path.exists(pipeline_output):
+ os.mkdir(pipeline_output)
+ new_file = os.path.join(pipeline_output, dryrun_yaml_dir)
+ if not os.path.exists(new_file):
+ os.mkdir(new_file)
+ timestamp = time.strftime("%Y%m%d%H%M%S")
+ new_yaml_name = os.path.join(new_file, f'{timestamp}.yaml')
+ with open(new_yaml_name, 'w', encoding='utf-8') as file:
+ yaml.dump(data, file, default_flow_style=False, indent=4)
+ return new_yaml_name
+
+
+def bulid_shell(old_shell_file, offset, num_layers, num_vpp, num_stage, dense_layers, micro):
+ configs, unparses = parse_shell(old_shell_file)
+ layer_num = num_layers // num_stage
+ layer_list = flatten_to_str(offset, layer_num)
+ configs['NUM_LAYERS_LIST'] = layer_list
+ # memory is insufficient, so enable full recomputation
+ if 'RECOMPUTE_NUM_LAYERS' in configs:
+ configs['RECOMPUTE_NUM_LAYERS'] = max(layer_list)
+ if 'FIRST_K_DENSE_REPLACE' in configs:
+ configs['FIRST_K_DENSE_REPLACE'] = dense_layers
+ configs['PP'] = num_stage
+ configs['VPP'] = num_vpp
+ configs['MBS'] = 1
+ configs['GBS'] = micro * configs.get('MBS') * configs.get('DP')
+ configs['NUM_LAYERS'] = num_layers
+ pipeline_output = os.path.join(os.getcwd(), pipeline_output_file)
+ if not os.path.exists(pipeline_output):
+ os.mkdir(pipeline_output)
+ new_file = os.path.join(pipeline_output, dryrun_shell_dir)
+ if not os.path.exists(new_file):
+ os.mkdir(new_file)
+ timestamp = time.strftime("%Y%m%d%H%M%S")
+ new_shell_name = os.path.join(new_file, f'{timestamp}.sh')
+ configs_to_shell(new_shell_name, configs, unparses)
+ return new_shell_name
+
+
+def parse_shell(shell_file):
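+ # split the script by lines: KEY=VALUE assignments are collected into configs (values parsed to int where
+ # possible); all other lines are kept verbatim in unparses so configs_to_shell can rebuild the script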
+ with open(shell_file, 'r', encoding='utf-8') as f:
+ content = f.read()
+ lexer = shlex.shlex(content, posix=True)
+ lexer.whitespace_split = True
+ lexer.whitespace = '\n'
+ lexer.escape = ''
+ configs = {}
+ unparses = ''
+ for token in lexer:
+ if '=' in token:
+ key, value = token.split('=', 1)
+ key = key.strip()
+ try:
+ value = int(value)
+ except ValueError:
+ pass
+ configs[key] = value
+ else:
+ unparses += token + '\n'
+ return configs, unparses
+
+
+def parse_shell_config(config_value):
+ parts = config_value.split('--')
+ paras = {}
+ for part in parts:
+ part_split = part.strip().split(maxsplit=1)
+ if part_split:
+ key = part_split[0].strip()
+ value = part_split[1].strip(' \n\\') if len(part_split) > 1 else ''
+ try:
+ value = int(value)
+ except ValueError:
+ pass
+ paras[key] = value
+ return paras
+
+
+def change_shell_config(configs_dict, config, para, value):
+ config_value = configs_dict[config]
+ parse_config_value = parse_shell_config(config_value)
+ parse_config_value[para] = value
+ content = '\n'
+ for key, value in parse_config_value.items():
+ content += f' --{key} {value} \\\n'
+ configs_dict[config] = content
+
+
+def configs_to_shell(shell_name, configs, unparses):
+ with open(shell_name, 'w', encoding='utf-8') as f:
+ for key, value in configs.items():
+ if isinstance(value, int) or value.startswith('$'):
+ f.write(f'{key}={value}\n')
+ else:
+ f.write(f'{key}="{value}"\n')
+ f.write(unparses)
+
+
+def flatten_to_str(lst, num=0):
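+ # flatten a (possibly nested) offset list into a comma-separated string, adding the base layer count
+ # num to every integer entry, e.g. to build the NUM_LAYERS_LIST shell variable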
+ items = []
+ for item in lst:
+ if isinstance(item, list):
+ items.append(flatten_to_str(item, num))
+ elif isinstance(item, int):
+ items.append(str(item + num))
+ else:
+ items.append(str(item))
+ return ','.join(items)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/result_csv.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/result_csv.py
new file mode 100644
index 0000000000000000000000000000000000000000..1555002400d3c65f03a18b8b3a5194864a4d3337
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/result_csv.py
@@ -0,0 +1,159 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""generate result to csv"""
+
+import os
+import time
+import csv
+import yaml
+
+from fast_tuner.utils.logger import logger
+from fast_tuner.pipeline_conductor import pp_util
+from fast_tuner.pipeline_conductor.dryrun import DryRun, dryrun_config_error
+
+class ResultCsv:
+ """
+ generate result csv
+ """
+ def __init__(self, output_path, output_file='pipeline_output', name='test_result'):
+ self.header = ['test', 'layers', 'micro', 'dp', 'tp', 'pp', 'ep', 'vp', 'op/fsdp', 'dense:moe', '反向:正向',
+ '重计算增加比率', 'mtp+head', 'moe时长', 'low_mem', '目标值', 'cost', 'x', 'offset',
+ 'ra', '内存信息', '内存上限(GB)', '求解器', 'GAP', 'solution_status', 'model_status',
+ '求解耗时/s', 'dryrun_check']
+ self.name = name
+ timestamp = time.strftime("%Y%m%d%H%M%S")
+ csv_dir = os.path.join(os.path.abspath(output_path), output_file)
+ if not os.path.exists(csv_dir):
+ os.mkdir(csv_dir)
+ self.path = os.path.join(csv_dir, f'{self.name}_{timestamp}.csv')
+ self.create_csv_file()
+
+ def create_csv_file(self):
+ """
+ create csv file
+ """
+ with open(self.path, 'w', encoding='utf-8-sig') as file:
+ header = ','.join(self.header) + '\n'
+ file.write(header)
+ logger.info(f'Successfully created {self.path}')
+
+ def config_to_csv(self, candidate, low_mem, solver_name):
+ new_row = ['']*len(self.header)
+ h = self.header
+ new_row[h.index('test')] = candidate.config_path
+ new_row[h.index('求解器')] = solver_name
+ new_row[h.index('low_mem')] = low_mem
+ new_row[h.index('dense:moe')] = candidate.profiling_info.dmratio
+ new_row[h.index('反向:正向')] = candidate.profiling_info.bfratio
+ new_row[h.index('重计算增加比率')] = candidate.profiling_info.re_grow_ration
+ new_row[h.index('mtp+head')] = candidate.profiling_info.hratio
+ new_row[h.index('moe时长')] = candidate.profiling_info.moe_fw
+ if DryRun.config_file_type == 0:
+ self.yaml_to_row(candidate.config_path, new_row)
+ elif DryRun.config_file_type == 1:
+ self.shell_to_row(candidate.config_path, new_row)
+ elif DryRun.config_file_type == 2:
+ self.toml_to_row(candidate, new_row)
+ else:
+ raise TypeError(dryrun_config_error)
+ with open(self.path, 'a', newline='', encoding='utf-8-sig') as file:
+ writer = csv.writer(file)
+ writer.writerows([new_row])
+
+ def yaml_to_row(self, yaml_file, row):
+ """
+ trans yaml info to row of csv
+ """
+ with open(yaml_file, 'r', encoding='utf-8') as file:
+ data = yaml.safe_load(file)
+ h = self.header
+ row[h.index('layers')] = data['model']['model_config']['num_layers']
+ if 'mtp_depth' in data['model']['model_config']:
+ row[h.index('layers')] += data['model']['model_config']['mtp_depth']
+ row[h.index('micro')] = data['parallel_config']['micro_batch_num']
+ row[h.index('dp')] = data['parallel_config']['data_parallel']
+ row[h.index('tp')] = data['parallel_config']['model_parallel']
+ row[h.index('pp')] = data['parallel_config']['pipeline_stage']
+ if 'expert_parallel' in data['parallel_config']:
+ row[h.index('ep')] = data['parallel_config']['expert_parallel']
+ else:
+ row[h.index('ep')] = 1
+ if 'pp_interleave_num' in data['model']['model_config']:
+ row[h.index('vp')] = data['model']['model_config']['pp_interleave_num']
+ else:
+ row[h.index('vp')] = 1
+
+ def shell_to_row(self, shell_file, row):
+ h = self.header
+ configs, _ = pp_util.parse_shell(shell_file)
+ row[h.index('layers')] = configs.get('NUM_LAYERS')
+ row[h.index('micro')] = configs.get('GBS') // (configs.get('MBS') * configs.get('DP'))
+ row[h.index('dp')] = configs.get('DP')
+ row[h.index('tp')] = configs.get('TP')
+ row[h.index('pp')] = configs.get('PP')
+ row[h.index('ep')] = configs.get('EP', 1)
+ row[h.index('vp')] = configs.get('VPP', 1)
+
+ def toml_to_row(self, candidate, row):
+ """
+ trans toml info to row of csv
+ """
+ model_args = candidate.model_args
+ h = self.header
+ row[h.index('layers')] = model_args.num_layers
+ row[h.index('micro')] = model_args.mbn
+ row[h.index('dp')] = model_args.dp
+ row[h.index('tp')] = model_args.tp
+ row[h.index('pp')] = model_args.pp
+ row[h.index('ep')] = model_args.ep
+ row[h.index('vp')] = 1
+ row[h.index('op/fsdp')] = model_args.dp_shard_degree
+
+ def result_to_csv(self, solution):
+ """
+ fill solution result to csv
+ """
+ row_cost = []
+ row_no_cost = []
+ with open(self.path, 'r', newline='', encoding='utf-8-sig') as file:
+ reader = csv.reader(file)
+ h = self.header
+ for row in reader:
+ if row[h.index('test')] == str(solution.init_config.config_file):
+ row[h.index('内存信息')] = solution.init_config.memory.get_mem()
+ row[h.index('内存上限(GB)')] = solution.init_config.memory.mem_lim / 1024
+ row[h.index('求解耗时/s')] = solution.run_time
+ row[h.index('solution_status')] = solution.solution_status
+ row[h.index('model_status')] = solution.model_status
+ if solution.solution_status != 'None':
+ row[h.index('目标值')] = solution.object_value
+ row[h.index('x')] = solution.layer_dis.tolist()
+ row[h.index('offset')] = solution.offset.tolist()
+ row[h.index('ra')] = solution.ra_dis.tolist()
+ row[h.index('GAP')] = solution.gap
+ row[h.index('cost')] = solution.object_value * float(row[h.index('moe时长')])
+ row[h.index('dryrun_check')] = solution.check_peak_mem
+ if row[h.index('cost')]:
+ row_cost.append(row)
+ else:
+ row_no_cost.append(row)
+ sorted_rows = sorted(row_cost[1:], key=lambda x: float(x[self.header.index('cost')]))
+ sorted_rows = [self.header] + sorted_rows + row_no_cost
+ with open(self.path, 'w', newline='', encoding='utf-8-sig') as file:
+ writer = csv.writer(file)
+ writer.writerows(sorted_rows)
+
+if __name__ == '__main__':
+ result_csv = ResultCsv('./output')
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/solution.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/solution.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1f79440dcb24abbd2960b3ecc10d979c25c01ea
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/solution.py
@@ -0,0 +1,256 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+'''
+This module contains utility functions for solutions,
+including parsing, peak memory calculation
+and validity checks.
+'''
+
+import itertools
+import re
+import numpy as np
+from fast_tuner.pipeline_conductor.start_service import InitConfig
+from fast_tuner.pipeline_conductor.math_model import Model
+from fast_tuner.pipeline_conductor import micro
+from fast_tuner.pipeline_conductor.start_service import ExpertInput
+from fast_tuner.utils.logger import logger
+
+class Solution:
+ '''
+ parse and record various properties of
+ solutions, including a few simple analyses
+ '''
+ x_type1 = [int]
+ x_type2 = [int]
+ layer_dis = [int]
+ layer1_dis_stage = [int]
+ layer2_dis_stage = [int]
+ offset = [int]
+ indicator_type1 = [int]
+ indicator_type2 = [int]
+ rs_type1 = [int]
+ rs_type2 = [int]
+ rs_dis = [int]
+ ra_type1 = [int]
+ ra_type2 = [int]
+ ra_dis = [int]
+ forward_s = [float]
+ backward_s = [float]
+ peak_num = micro.PeakNum
+ check_peak_mem = []
+ object_value = float
+ gap = float
+ solution_status = str
+ model_status = str
+ run_time = float
+ sol_file = ''
+
+ def __init__(self, init_config: InitConfig):
+ self.init_config = init_config
+ self.x_type1 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int)
+ self.x_type2 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int)
+ self.offset = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int)
+ self.indicator_type1 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int)
+ self.indicator_type2 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int)
+ self.rs_type1 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int)
+ self.rs_type2 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int)
+ self.ra_type1 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int)
+ self.ra_type2 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int)
+ self.forward_s = np.empty((self.init_config.pipeline_stage, self.init_config.pp_interleave_num,
+ self.init_config.parts, self.init_config.micro_batch_num,
+ self.init_config.seq_splits), dtype=float)
+ self.backward_s = np.empty((self.init_config.pipeline_stage, self.init_config.pp_interleave_num,
+ self.init_config.parts, self.init_config.micro_batch_num,
+ self.init_config.seq_splits), dtype=float)
+
+ def set_solution(self, solved_model: Model, is_origin_solver, sol_file):
+ '''
+ parse the raw MIP solution into pipeline
+ parallelism layer distributions
+ '''
+ self.sol_file = sol_file
+ if is_origin_solver:
+ for stage in range(self.init_config.pipeline_stage):
+ for vpp in range(self.init_config.pp_interleave_num):
+ self.x_type1[vpp][stage] = solved_model.x_type1[stage][vpp].solution_value()
+ self.x_type2[vpp][stage] = solved_model.x_type2[stage][vpp].solution_value()
+ self.indicator_type1[vpp][stage] = solved_model.indicator_type1[stage][vpp].solution_value()
+ self.indicator_type2[vpp][stage] = solved_model.indicator_type2[stage][vpp].solution_value()
+ self.rs_type1[vpp][stage] = solved_model.rs_type1[stage][vpp].solution_value()
+ self.rs_type2[vpp][stage] = solved_model.rs_type2[stage][vpp].solution_value()
+ self.ra_type1[vpp][stage] = solved_model.ra_type1[stage][vpp].solution_value()
+ self.ra_type2[vpp][stage] = solved_model.ra_type2[stage][vpp].solution_value()
+ for part in range(self.init_config.parts):
+ for micro_ind, seq in itertools.product(range(solved_model.distribution[part]),
+ range(self.init_config.seq_splits)):
+ self.forward_s[stage][vpp][part][micro_ind][seq] = (
+ solved_model.forward_s[stage][vpp][part][micro_ind][seq].solution_value())
+ self.backward_s[stage][vpp][part][micro_ind][seq] = (
+ solved_model.backward_s[stage][vpp][part][micro_ind][seq].solution_value())
+ else:
+ with open(sol_file, 'r', encoding='utf-8') as file:
+ for line in file:
+ content1 = re.search(r'^x_type1_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)',
+ line)
+ content2 = re.search(r'^x_type2_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)',
+ line)
+ content3 = re.search(r'^indicator1_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)',
+ line)
+ content4 = re.search(r'^indicator2_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)',
+ line)
+ content5 = re.search(r'^rs_type1_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)',
+ line)
+ content6 = re.search(r'^rs_type2_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)',
+ line)
+ content7 = re.search(r'^ra_type1_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)',
+ line)
+ content8 = re.search(r'^ra_type2_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)',
+ line)
+ content9 = re.search(
+ r'^forward_s_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)',
+ line)
+ content10 = re.search(
+ r'^backward_s_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)',
+ line)
+ if content1:
+ self.x_type1[int(content1.group(2))][int(content1.group(1))] = round(float(content1.group(3)))
+ if content2:
+ self.x_type2[int(content2.group(2))][int(content2.group(1))] = round(float(content2.group(3)))
+ if content3:
+ self.indicator_type1[int(content3.group(2))][int(content3.group(1))] = round(
+ float(content3.group(3)))
+ if content4:
+ self.indicator_type2[int(content4.group(2))][int(content4.group(1))] = round(
+ float(content4.group(3)))
+ if content5:
+ self.rs_type1[int(content5.group(2))][int(content5.group(1))] = round(float(content5.group(3)))
+ if content6:
+ self.rs_type2[int(content6.group(2))][int(content6.group(1))] = round(float(content6.group(3)))
+ if content7:
+ self.ra_type1[int(content7.group(2))][int(content7.group(1))] = round(float(content7.group(3)))
+ if content8:
+ self.ra_type2[int(content8.group(2))][int(content8.group(1))] = round(float(content8.group(3)))
+
+ if content9:
+ self.forward_s[int(content9.group(1))][int(content9.group(2))][int(content9.group(3))][
+ int(content9.group(4))][int(content9.group(5))] = float(content9.group(6))
+ if content10:
+ self.backward_s[int(content10.group(1))][int(content10.group(2))][int(content10.group(3))][
+ int(content10.group(4))][int(content10.group(5))] = float(content10.group(6))
+ self.object_value = solved_model.min_time
+ self.gap = solved_model.gap
+ self.model_status = solved_model.model_status
+ self.solution_status = solved_model.solution_status
+ self.run_time = solved_model.run_time
+ if self.solution_status != 'None':
+ self.cal_peak_mem(solved_model)
+ self.set_total_dis()
+ self.check_time_list()
+
+ def cal_peak_mem(self, solved_model: Model):
+ self.peak_num = micro.PeakNum(solved_model.sort_micro)
+ self.peak_num.set_peak_act_recompute_num(self.x_type2, self.rs_type2, self.ra_type2,
+ self.x_type1, self.rs_type1, self.ra_type1,
+ solved_model.memory)
+
+ def set_total_dis(self):
+ self.layer_dis = self.x_type1 + self.x_type2
+ self.offset = (self.layer_dis - (self.init_config.num_layers_type1 + self.init_config.num_layers_type2) //
+ (self.init_config.pp_interleave_num * self.init_config.pipeline_stage))
+ self.rs_dis = self.rs_type1 + self.rs_type2
+ self.ra_dis = self.ra_type1 + self.ra_type2
+ self.layer1_dis_stage = [sum(self.x_type1[v][s] for v in range(self.init_config.pp_interleave_num)) for s in
+ range(self.init_config.pipeline_stage)]
+ self.layer2_dis_stage = [sum(self.x_type2[v][s] for v in range(self.init_config.pp_interleave_num)) for s in
+ range(self.init_config.pipeline_stage)]
+
+ def check_time_list(self):
+ '''
+ inner functions for testing whether
+ the starting times align with exec order
+ '''
+ s_time = [[] for _ in range(self.init_config.pipeline_stage)]
+ for stage in range(self.init_config.pipeline_stage):
+ for i in range(len(self.peak_num.sort_micro.final_orders[stage])):
+ micro_batch = self.peak_num.sort_micro.final_orders[stage][i]
+ if micro_batch.state == 'f':
+ s_time[stage].append(self.forward_s[stage][micro_batch.vpp][micro_batch.part][micro_batch.micro_id]
+ [micro_batch.split])
+ if micro_batch.state == 'b':
+ s_time[stage].append(self.backward_s[stage][micro_batch.vpp][micro_batch.part][micro_batch.micro_id]
+ [micro_batch.split])
+ # the micro-batch start times of each stage must be monotonically increasing
+ for stage in range(self.init_config.pipeline_stage):
+ for i in range(len(self.peak_num.sort_micro.final_orders[stage]) - 1):
+ if s_time[stage][i] > s_time[stage][i + 1]:
+ raise ValueError('Time sequence error!')
+
+ def solution_print(self):
+ '''print the information of the solution'''
+ logger.info('layer distribution: ')
+ logger.info(self.layer_dis.tolist())
+ logger.info('layer of type1 distribution: ')
+ logger.info(self.x_type1.tolist())
+ logger.info('the indicator of layer1: ')
+ logger.info(self.indicator_type1.tolist())
+
+ logger.info('layer of type2 distribution: ')
+ logger.info(self.x_type2.tolist())
+ logger.info('the indicator of layer2: ')
+ logger.info(self.indicator_type2.tolist())
+ logger.info('the offset: ')
+ logger.info(self.offset.tolist())
+
+ logger.info('rs distribution: ')
+ logger.info(self.rs_dis.tolist())
+ logger.info('layer of type1 rs distribution: ')
+ logger.info(self.rs_type1.tolist())
+ logger.info('layer of type2 rs distribution: ')
+ logger.info(self.rs_type2.tolist())
+
+ logger.info('ra distribution: ')
+ logger.info(self.ra_dis.tolist())
+ logger.info('layer of type1 ra distribution: ')
+ logger.info(self.ra_type1.tolist())
+ logger.info('layer of type2 ra distribution: ')
+ logger.info(self.ra_type2.tolist())
+
+ logger.info('the peak memory:')
+ logger.info(list(self.peak_num.max_mem.values()))
+ logger.info('the number of micro batch when peak memory')
+ logger.info(list(self.peak_num.micro_num_of_max_mem.values()))
+
+ logger.info('the number of activations of the type of layer1')
+ logger.info(list(self.peak_num.peak_num_act_type1.values()))
+ logger.info('the number of select recomputes of the type of layer1')
+ logger.info(list(self.peak_num.peak_num_select_recom_type1.values()))
+ logger.info('the number of full recomputes of the type of layer1')
+ logger.info(list(self.peak_num.peak_num_recompute_type1.values()))
+
+ logger.info('the number of activations of the type of layer2')
+ logger.info(list(self.peak_num.peak_num_act_type2.values()))
+ logger.info('the number of select recomputes of the type of layer2')
+ logger.info(list(self.peak_num.peak_num_select_recom_type2.values()))
+ logger.info('the number of full recomputes of the type of layer2')
+ logger.info(list(self.peak_num.peak_num_recompute_type2.values()))
+
+def extract_solution_file(yaml_file, sol_file):
+ expert_input = ExpertInput(yaml_file, '')
+ expert_input.is_dryrun = False
+ init_config = InitConfig(expert_input)
+ origin_model = Model(init_config)
+ file_solution = Solution(init_config)
+ file_solution.set_solution(origin_model, False, sol_file)
+ file_solution.solution_print()
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/start_service.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/start_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..8772d29cadc9323bd1693e30515243a908a8e833
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/start_service.py
@@ -0,0 +1,487 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""start service"""
+import os
+import re
+import numpy as np
+import yaml
+
+from fast_tuner.utils.logger import logger
+from fast_tuner.pipeline_conductor import dryrun, pp_util
+from fast_tuner.pipeline_conductor import micro
+from fast_tuner.pipeline_conductor.memory import Memory
+
+pipeline_output_file = 'pipeline_output'
+init_dryrun_file = 'init_dryrun'
+double_check_dryrun_filename = 'double_check_dryrun'
+HIGHS_NAME = 'HIGHS'
+default_low_mem = False
+default_time_limit = 1e10
+# 0: deepseek model; 1: boss model
+model_class = 0
+
+
+# Expert input: experts can adjust these values for their environment
+class ExpertInput:
+ config_file = ''
+ ms_adapter_file = ''
+ model_args = None
+
+ solver_name = HIGHS_NAME
+ llm_class = model_class
+ time_limit = int(default_time_limit)
+
+ is_dryrun = True
+ is_double_check = False
+ fit_level = 0
+
+ low_mem = default_low_mem
+ layer_ratio = 0.3
+ backward_ratio = 2
+ srRatio = 0.33
+ is_support_ra_with_rs = False
+ is_select_recomp = False
+ is_full_recomp = True
+
+ is_head_loss_input = True
+ head_loss = 1.493
+ recompute_ratio = 0.246
+
+ output_file = pipeline_output_file
+ output_file_dir = os.path.join(os.getcwd(), output_file)
+ double_check_dryrun_filename = double_check_dryrun_filename
+
+ def __init__(self, config_file, ms_adapter_file):
+ if not os.path.exists(self.output_file_dir):
+ os.mkdir(self.output_file_dir)
+ self.config_file = config_file
+ self.ms_adapter_file = ms_adapter_file
+
+
+class InitDryrun:
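+ """
+ Fixed probe configuration: it is dryrun once so that the per-layer, activation and recompute memory
+ coefficients can be estimated from the per-stage peak memory it produces.
+ """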
+ layers_stage_vpp = []
+ init_offset = []
+ recompute_config = {}
+ parts = int
+ distribution = []
+ x_type1 = []
+ x_type2 = []
+ rs_type1 = []
+ rs_type2 = []
+ ra_type1 = []
+ ra_type2 = []
+ dense_layers = None
+
+ def __init__(self, data_parallel, model_parallel, expert_input: ExpertInput):
+ if expert_input.llm_class == 0:
+ self.init_stages = 16
+ self.init_layers = self.init_stages + 2
+ self.init_micro = 32
+ self.parts = self.init_micro // self.init_stages
+ self.distribution = pp_util.construct_distribution(self.init_micro, self.init_stages)
+ self.init_vpp = 1
+ self.dense_layers = 9
+ self.deepseek_build_offset()
+ self.layers_stage_vpp = pp_util.get_layers_distribution(self.init_offset, self.init_layers,
+ self.init_stages,
+ self.init_vpp)
+ self.recompute_config = self.construct_rec_config_for_deepseek()
+ elif expert_input.llm_class == 1:
+ self.init_stages = 13
+ self.init_layers = 72
+ self.init_micro = self.init_stages * 3
+ self.parts = self.init_micro // self.init_stages
+ self.distribution = pp_util.construct_distribution(self.init_micro, self.init_stages)
+ self.init_vpp = 1
+ self.boss_build_offset()
+ self.layers_stage_vpp = pp_util.get_layers_distribution(self.init_offset, self.init_layers,
+ self.init_stages,
+ self.init_vpp)
+ self.x_type1 = np.zeros((self.init_vpp, self.init_stages), int)
+ self.x_type2 = np.zeros((self.init_vpp, self.init_stages), int)
+ for stage in range(0, 7):
+ self.x_type1[0][stage] = self.layers_stage_vpp[stage][0]
+ for stage in range(7, 13):
+ self.x_type2[0][stage] = self.layers_stage_vpp[stage][0]
+ self.recompute_config = self.construct_rec_config_for_boss()
+ else:
+ raise ValueError(f'can not support the model class number of {expert_input.llm_class}!')
+
+ self.rank_size = data_parallel * model_parallel * self.init_stages
+
+ self.config_file = expert_input.config_file
+ self.ms_adapter_file = expert_input.ms_adapter_file
+ self.dryrun_output = init_dryrun_file
+
+ def deepseek_build_offset(self):
+ self.init_offset = np.zeros((self.init_vpp, self.init_stages), int).tolist()
+ self.init_offset[0][7] = 1
+ self.init_offset[0][14] = 1
+
+ def boss_build_offset(self):
+ self.init_offset = [[1, 2, -1, 1, 0, 2, -4, 2, -1, 1, 1, 2, 1]]
+
+ def construct_rec_config_for_boss(self):
+ is_select_recompute = True
+ is_re_comp = True
+ self.rs_type1 = np.zeros((self.init_vpp, self.init_stages), int)
+ self.rs_type2 = np.zeros((self.init_vpp, self.init_stages), int)
+ self.ra_type1 = np.zeros((self.init_vpp, self.init_stages), int)
+ self.ra_type2 = np.zeros((self.init_vpp, self.init_stages), int)
+
+ select_recompute_layers = [[0, 0, 0, 0, 1, 4, 0, 1, 1, 0, 2, 4, 0]]
+ recompute_layers = [[6, 4, 2, 3, 0, 0, 0, 4, 2, 3, 0, 0, 6]]
+ for stage in range(0, 7):
+ self.rs_type1[0][stage] = select_recompute_layers[0][stage]
+ self.ra_type1[0][stage] = recompute_layers[0][stage]
+ for stage in range(7, 13):
+ self.rs_type2[0][stage] = select_recompute_layers[0][stage]
+ self.ra_type2[0][stage] = recompute_layers[0][stage]
+
+ recompute_config = pp_util.build_recompute_config(is_select_recompute, is_re_comp, select_recompute_layers,
+ recompute_layers)
+ return recompute_config
+
+ def construct_rec_config_for_deepseek(self):
+ is_select_recompute = True
+ is_re_comp = True
+ select_recompute_layers = np.zeros((self.init_vpp, self.init_stages), int).tolist()
+ select_recompute_layers[0][3] = self.layers_stage_vpp[3][0]
+ select_recompute_layers[0][4] = self.layers_stage_vpp[4][0]
+ select_recompute_layers[0][10] = self.layers_stage_vpp[10][0]
+ select_recompute_layers[0][11] = self.layers_stage_vpp[11][0]
+
+ recompute_layers = np.zeros((self.init_vpp, self.init_stages), int).tolist()
+ recompute_layers[0][0] = self.layers_stage_vpp[0][0]
+ recompute_layers[0][1] = self.layers_stage_vpp[1][0]
+ recompute_layers[0][2] = self.layers_stage_vpp[2][0]
+ recompute_layers[0][7] = self.layers_stage_vpp[7][0]
+ recompute_layers[0][8] = self.layers_stage_vpp[8][0]
+ recompute_layers[0][9] = self.layers_stage_vpp[9][0]
+ recompute_layers[0][14] = self.layers_stage_vpp[14][0]
+
+ recompute_layers[0][15] = self.layers_stage_vpp[15][0]
+
+ recompute_config = pp_util.build_recompute_config(is_select_recompute, is_re_comp, select_recompute_layers,
+ recompute_layers)
+ return recompute_config
+
+ def init_dryrun(self):
+ dry_run = dryrun.DryRun(self.config_file, self.ms_adapter_file, self.dryrun_output)
+ dry_run.start_dryrun(self.recompute_config, self.init_offset, self.init_layers, self.init_vpp, self.init_stages,
+ self.rank_size, self.dense_layers, self.init_micro)
+ peak_mem = [dry_run.extract_memory_info(self.init_stages), dry_run.extract_memory_info_act(self.init_stages)]
+ return peak_mem
+
+
+class InitConfig:
+ pipeline_stage = int
+ micro_batch_num = int
+ model_parallel = int
+ data_parallel = int
+ expert_parallel = int
+ rank_size = int
+ num_layers_type1 = int
+ num_layers_type2 = int
+ pp_interleave_num = int(1)
+ seq_length = int
+ hidden_size = int
+ intermediate_size = int
+ vocab_size = int
+ mem_lim = float
+ seq_splits = int(1)
+ parts = int
+ mps_sol_filename = ''
+
+ def __init__(self, expert_input: ExpertInput):
+ self.expert_input = expert_input
+ self.config_file = expert_input.config_file
+ self.ms_adapter_file = expert_input.ms_adapter_file
+ if dryrun.DryRun.config_file_type == 0:
+ self.get_yaml_config()
+ elif dryrun.DryRun.config_file_type == 1:
+ self.get_shell_config()
+ elif dryrun.DryRun.config_file_type == 2:
+ self.get_toml_config()
+ else:
+ raise TypeError(dryrun.dryrun_config_error)
+ self.rank_size = self.pipeline_stage * self.model_parallel * self.data_parallel
+ self.parts = self.micro_batch_num // self.pipeline_stage
+ if self.parts == 0:
+ raise ValueError(f'stage = {self.pipeline_stage} is greater than micro batch = {self.micro_batch_num}!, '
+ f'please check the config file!')
+ self.mps_sol_filename = (f'layers{self.num_layers_type1}_{self.num_layers_type2}_micro{self.micro_batch_num}_'
+ f'dp{self.data_parallel}'
+ f'_tp{self.model_parallel}_pp{self.pipeline_stage}_vp{self.pp_interleave_num}'
+ f'_ep{self.expert_parallel}')
+
+ if self.pp_interleave_num <= 1:
+ self.expert_input.low_mem = True
+
+ self.memory = Memory(self.mem_lim)
+ self.set_memory(expert_input.is_dryrun)
+
+ def set_memory(self, is_dryrun):
+ memory_dir = os.path.join(self.expert_input.output_file_dir, 'memory_info')
+ filename = (f'layers{self.num_layers_type1}_{self.num_layers_type2}_micro{self.micro_batch_num}_dp{self.data_parallel}'
+ f'_tp{self.model_parallel}_pp{self.pipeline_stage}_vp1'
+ f'_ep{self.expert_parallel}.txt')
+ memory_file_name = os.path.join(memory_dir, filename)
+ if is_dryrun:
+ self.mem_calculator_by_dryrun()
+ if not os.path.exists(memory_dir):
+ os.mkdir(memory_dir)
+ self.memory.write_memory_to_file(memory_file_name)
+ elif os.path.exists(memory_file_name):
+ logger.info(f'mem_lim = {self.mem_lim}')
+ with open(memory_file_name, 'r', encoding='utf-8') as file:
+ lines = file.readlines()
+ for line in lines:
+ key, value = line.split('=')
+ if hasattr(self.memory, key):
+ setattr(self.memory, key, float(value))
+ else:
+ logger.error(f'NameError: {key}')
+ self.memory.mem_lim_stage0 = self.memory.mem_lim - self.memory.static_mem0
+ self.memory.mem_lim_others = self.memory.mem_lim - self.memory.static_mem
+ self.memory.mem_lim_last = self.memory.mem_lim - self.memory.lm_head_mem
+ logger.info(f'Using the {filename} memory for vpp{self.pp_interleave_num}')
+ else:
+ logger.info(f'There is no memory file: {memory_file_name}! Using the default memory!')
+ self.memory.print_mem()
+
+ def get_yaml_config(self):
+ with open(self.config_file, 'r', encoding='utf-8') as file:
+ data = yaml.safe_load(file)
+ self.pipeline_stage = data['parallel_config']['pipeline_stage']
+ self.micro_batch_num = data['parallel_config']['micro_batch_num']
+ self.model_parallel = data['parallel_config']['model_parallel']
+ self.data_parallel = data['parallel_config']['data_parallel']
+ if 'expert_parallel' in data['parallel_config']:
+ self.expert_parallel = data['parallel_config']['expert_parallel']
+ else:
+ self.expert_parallel = 1
+ self.num_layers_type2 = data['model']['model_config']['num_layers']
+ if 'mtp_depth' in data['model']['model_config']:
+ self.num_layers_type2 += data['model']['model_config']['mtp_depth']
+ if 'pp_interleave_num' in data['model']['model_config']:
+ self.pp_interleave_num = data['model']['model_config']['pp_interleave_num']
+ else:
+ self.pp_interleave_num = 1
+ self.num_layers_type1 = 0
+ if 'moe_config' in data:
+ if 'first_k_dense_replace' in data['moe_config']:
+ self.num_layers_type1 = data['moe_config']['first_k_dense_replace']
+ if self.expert_input.llm_class == 1:
+ self.num_layers_type1 = 36
+ self.num_layers_type2 -= self.num_layers_type1
+ if not self.expert_input.is_head_loss_input:
+ self.seq_length = data['model']['model_config']['seq_length']
+ self.hidden_size = data['model']['model_config']['hidden_size']
+ self.intermediate_size = data['model']['model_config']['intermediate_size']
+ self.vocab_size = data['model']['model_config']['vocab_size']
+ self.mem_lim = int(re.search(r'(\d+)GB', data['context']['max_device_memory']).group(1)) * 1024.0
+
+
+ def get_shell_config(self):
+ configs, unparses = pp_util.parse_shell(self.config_file)
+ self.pipeline_stage = configs.get('PP')
+ self.micro_batch_num = configs.get('GBS') // configs.get('MBS') // configs.get('DP')
+ self.model_parallel = configs.get('TP')
+ self.data_parallel = configs.get('DP')
+ self.expert_parallel = configs.get('EP', 1)
+ self.pp_interleave_num = configs.get('VPP', 1)
+ self.num_layers_type1 = configs.get('FIRST_K_DENSE_REPLACE', 0)
+ if self.expert_input.llm_class == 1:
+ self.num_layers_type1 = 36
+ self.num_layers_type2 = configs.get('NUM_LAYERS') - self.num_layers_type1
+ if not self.expert_input.is_head_loss_input:
+ self.seq_length = configs.get('SEQ_LEN')
+ self.hidden_size = configs.get('HIDDEN_SIZE')
+ self.intermediate_size = configs.get('FFN_HIDDEN_SIZE')
+ if 'VOCAB_SIZE' in configs:
+ self.vocab_size = configs.get('VOCAB_SIZE')
+ self.mem_lim = configs.get('MAX_DEVICE_MEMORY', 58) * 1024.0
+
+ def get_toml_config(self):
+ model_args = self.expert_input.model_args
+ self.pipeline_stage = model_args.pp
+ self.micro_batch_num = model_args.mbn
+ self.model_parallel = model_args.tp
+
+ self.data_parallel = model_args.dp
+ self.expert_parallel = model_args.ep
+ self.pp_interleave_num = 1
+
+ self.num_layers_type1 = model_args.first_k_dense_replace
+ self.num_layers_type2 = model_args.num_layers - self.num_layers_type1
+
+ self.seq_length = model_args.seq_len
+ self.hidden_size = model_args.hidden_size
+ self.intermediate_size = model_args.ffn_hidden_size
+
+ self.vocab_size = model_args.padded_vocab_size
+ self.mem_lim = 58 * 1024.0
+
+
+ def mem_calculator_by_dryrun(self):
+ self.set_cur_memory_info()
+ self.memory.mem_lim_stage0 = self.mem_lim - self.memory.static_mem0
+ self.memory.mem_lim_others = self.mem_lim - self.memory.static_mem
+ self.memory.mem_lim_last = self.mem_lim - self.memory.lm_head_mem
+
+ def set_cur_memory_info(self):
+ init_dryrun = InitDryrun(self.data_parallel, self.model_parallel, self.expert_input)
+ is_input_mem = False
+ if not is_input_mem:
+ peak_mem = init_dryrun.init_dryrun()
+ logger.info(f'The first initial dryrun: peak_mem={peak_mem}')
+ else:
+ logger.info('Using input peak memory for act_layer memory!')
+ peak_mem = pp_util.get_init_input_peak_mem()
+ if any(mem == 0 for mem in peak_mem[1]):
+ raise ValueError(pp_util.dryrun_error)
+ if self.expert_input.llm_class == 0:
+ self.update_deepseek_memory(peak_mem, init_dryrun)
+ elif self.expert_input.llm_class == 1:
+ self.update_boss_memory(peak_mem, init_dryrun)
+ else:
+ raise ValueError(f'can not support the model class number of {self.expert_input.llm_class}!')
+
+ def update_deepseek_memory(self, peak_mem, init_dryrun: InitDryrun):
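+ # the per-copy coefficients are estimated by differencing dryrun peak memory between probe stages that
+ # share the same recompute setting but hold a different number of micro batches at the memory peak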
+ peaks = pp_util.get_peak_batch(init_dryrun.init_micro, init_dryrun.init_stages, init_dryrun.init_vpp,
+ True, init_dryrun.init_offset, init_dryrun.init_layers)
+ logger.info(f'peaks={peaks}')
+ self.memory.select_mem12 = (peak_mem[0][3] - peak_mem[0][4]) / (peaks[3] - peaks[4])
+ if self.memory.select_mem12 < 0:
+ self.memory.select_mem12 = 0
+ self.memory.select_mem0 = self.memory.select_mem12
+ self.memory.select_mem = (peak_mem[0][10] - peak_mem[0][11]) / (peaks[10] - peaks[11])
+ if self.memory.select_mem < 0:
+ self.memory.select_mem = 0
+ self.memory.re_comp_mem12 = (peak_mem[0][1] - peak_mem[0][2]) / (peaks[1] - peaks[2])
+ if self.memory.re_comp_mem12 < 0:
+ self.memory.re_comp_mem12 = 0
+ self.memory.re_comp_mem0 = self.memory.re_comp_mem12
+ self.memory.re_comp_mem = (peak_mem[0][8] - peak_mem[0][9]) / (peaks[8] - peaks[9])
+ if self.memory.re_comp_mem < 0:
+ self.memory.re_comp_mem = 0
+ self.memory.act_mem12 = (peak_mem[0][5] - peak_mem[0][6]) / (peaks[5] - peaks[6])
+ if self.memory.act_mem12 < 0:
+ self.memory.act_mem12 = 0
+ self.memory.act_mem0 = self.memory.act_mem12
+ self.memory.act_mem = (peak_mem[0][12] - peak_mem[0][13]) / (peaks[12] - peaks[13])
+ # correct re_comp_mem
+ if dryrun.DryRun.config_file_type == 1:
+ self.memory.re_comp_mem = self.memory.act_mem
+
+ self.memory.layer_mem012 = (((peak_mem[0][7] - peak_mem[0][1]) - self.memory.re_comp_mem12 * (peaks[7] - peaks[1]))
+ / (init_dryrun.layers_stage_vpp[7][0] - init_dryrun.layers_stage_vpp[1][0]))
+ self.memory.layer_mem = (((peak_mem[0][14] - peak_mem[0][9]) - self.memory.re_comp_mem * (peaks[14] - peaks[9]))
+ / (init_dryrun.layers_stage_vpp[14][0] - init_dryrun.layers_stage_vpp[9][0]))
+
+ layers_stage = [sum(init_dryrun.layers_stage_vpp[stage][vpp] for vpp in range(init_dryrun.init_vpp))
+ for stage in range(init_dryrun.init_stages)]
+ if init_dryrun.layers_stage_vpp[0][0] >= 3:
+ self.memory.static_mem0 = (peak_mem[1][0] - self.memory.layer_mem012 * 3 - self.memory.layer_mem *
+ (layers_stage[0] - 3) - self.memory.re_comp_mem12 * peaks[0] * 3 /
+ layers_stage[0] - self.memory.re_comp_mem * peaks[0] * (layers_stage[0] - 3) /
+ layers_stage[0])
+ else:
+ self.memory.static_mem0 = (peak_mem[1][0] - self.memory.layer_mem012 * init_dryrun.layers_stage_vpp[0][0] -
+ self.memory.layer_mem * (layers_stage[0] - init_dryrun.layers_stage_vpp[0][0]) -
+ self.memory.re_comp_mem12 * peaks[0] * init_dryrun.layers_stage_vpp[0][0] /
+ layers_stage[0] - self.memory.re_comp_mem * peaks[0] *
+ (layers_stage[0] - init_dryrun.layers_stage_vpp[0][0]) / layers_stage[0])
+ self.memory.static_mem = (peak_mem[1][-2] - self.memory.layer_mem * layers_stage[-2]
+ - self.memory.re_comp_mem * peaks[-2])
+ self.memory.lm_head_mem = (peak_mem[1][-1] - self.memory.layer_mem * layers_stage[-1]
+ - self.memory.re_comp_mem * peaks[-1])
+
+ def update_boss_memory(self, peak_mem, init_dryrun: InitDryrun):
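+ # fit per-stage peak memory by least squares against [static, activation, layer, full-recompute,
+ # select-recompute] counts, separately for the two layer types of the probe config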
+ sort_micro = micro.SortMicro(init_dryrun.parts, init_dryrun.init_vpp, init_dryrun.init_stages,
+ init_dryrun.distribution, True, 1)
+ peaks = micro.PeakNum(sort_micro)
+ init_mem = Memory(self.mem_lim)
+ peaks.set_peak_act_recompute_num(init_dryrun.x_type2, init_dryrun.rs_type2, init_dryrun.ra_type2,
+ init_dryrun.x_type1, init_dryrun.rs_type1, init_dryrun.ra_type1, init_mem)
+ stage_range = [1, 6]
+ coe_a1 = pp_util.build_coe_array(peaks.peak_num_act_type1, peaks.peak_num_select_recom_type1,
+ peaks.peak_num_recompute_type1, init_dryrun.x_type1, stage_range)
+ mem_b1 = [peak_mem[0][stage] for stage in range(stage_range[0], stage_range[-1])]
+ mem_result, res, rank, s = np.linalg.lstsq(coe_a1, mem_b1, rcond=None)
+ mem_result = np.round(mem_result, decimals=1)
+ logger.info('the type of layer1:')
+ logger.info(f'the residual = {res}')
+ logger.info(f'the norm of the residual = {np.linalg.norm(res)}')
+ logger.info(f'the rank = {rank}')
+ static1 = mem_result[0]
+ self.memory.act_mem12 = mem_result[1]
+ self.memory.act_mem0 = mem_result[1]
+ self.memory.layer_mem012 = mem_result[2]
+ self.memory.re_comp_mem12 = mem_result[3]
+ self.memory.re_comp_mem0 = mem_result[3]
+ self.memory.select_mem12 = mem_result[4]
+ self.memory.select_mem0 = mem_result[4]
+
+ stage_range = [7, 12]
+ coe_a2 = pp_util.build_coe_array(peaks.peak_num_act_type2, peaks.peak_num_select_recom_type2,
+ peaks.peak_num_recompute_type2, init_dryrun.x_type2, stage_range)
+ mem_b2 = [peak_mem[0][stage] for stage in range(stage_range[0], stage_range[-1])]
+ mem_result, res, rank, s = np.linalg.lstsq(coe_a2, mem_b2, rcond=None)
+ mem_result = np.round(mem_result, decimals=1)
+ logger.info('the type of layer2:')
+ logger.info(f'the residual = {res}')
+ logger.info(f'the norm of the residual = {np.linalg.norm(res)}')
+ logger.info(f'the rank = {rank}')
+
+ static2 = mem_result[0]
+ self.memory.act_mem = mem_result[1]
+ self.memory.layer_mem = mem_result[2]
+ self.memory.re_comp_mem = mem_result[3]
+ self.memory.select_mem = mem_result[4]
+
+
+ static_mem0 = (peak_mem[1][0] - self.memory.layer_mem012 * init_dryrun.x_type1[0][0] -
+ self.memory.layer_mem * init_dryrun.x_type2[0][0] -
+ self.memory.act_mem12 * peaks.peak_num_act_type1[0] -
+ self.memory.act_mem * peaks.peak_num_act_type2[0] -
+ self.memory.re_comp_mem12 * peaks.peak_num_recompute_type1[0] -
+ self.memory.re_comp_mem * peaks.peak_num_recompute_type2[0] -
+ self.memory.select_mem12 * peaks.peak_num_select_recom_type1[0] -
+ self.memory.select_mem * peaks.peak_num_select_recom_type2[0])
+ logger.info(f'static1 = {static1}')
+ logger.info(f'static2 = {static2}')
+ static_mem = ((static2 + static1) / 2)
+ lm_head_mem = (peak_mem[1][-1] - self.memory.layer_mem012 * init_dryrun.x_type1[0][-1] -
+ self.memory.layer_mem * init_dryrun.x_type2[0][-1] -
+ self.memory.act_mem12 * peaks.peak_num_act_type1[init_dryrun.init_stages - 1] -
+ self.memory.act_mem * peaks.peak_num_act_type2[init_dryrun.init_stages - 1] -
+ self.memory.re_comp_mem12 * peaks.peak_num_recompute_type1[init_dryrun.init_stages - 1] -
+ self.memory.re_comp_mem * peaks.peak_num_recompute_type2[init_dryrun.init_stages - 1] -
+ self.memory.select_mem12 * peaks.peak_num_select_recom_type1[init_dryrun.init_stages - 1]
+ - self.memory.select_mem *
+ peaks.peak_num_select_recom_type2[init_dryrun.init_stages - 1])
+ self.memory.static_mem0 = np.round(static_mem0, decimals=1)
+ self.memory.static_mem = np.round(static_mem, decimals=1)
+ self.memory.lm_head_mem = np.round(lm_head_mem, decimals=1)
+
+
+if __name__ == '__main__':
+ expert_input = ExpertInput(config_file='C:\\working\\768_4k.yaml',
+ ms_adapter_file='~/mindformer.py')
+ expert_input.is_dryrun = False
+ model_input = InitConfig(expert_input)
+ print(model_input)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_tool.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..602748901e4958576fbebd3dfbab884f003144b4
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_tool.py
@@ -0,0 +1,65 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""entry file for pipeline"""
+
+import argparse
+
+from fast_tuner.utils.common import check_dryrun_parallel_number
+from fast_tuner.utils.ppc_input import ParallelInput
+from fast_tuner.utils.logger import logger
+from fast_tuner.pipeline_conductor.pipeline_parallel import pipeline_proc
+from fast_tuner.pipeline_conductor import pp_util
+
+
+if __name__ == '__main__':
+ logger.info('start to run pipeline tool')
+ # the user provides profiling results, candidate configs and other info; the tool ranks the configs by cost
+ parser = argparse.ArgumentParser(description='Run taylor pipeline_search_tool with user input parameters')
+ parser.add_argument('--files_dir', type=str, default='./output/dryrun_yaml/',
+ help='Path to the YAML or SHELL file directory')
+ parser.add_argument('--yaml_path', type=str, default=None,
+ help='Path of training config (.yaml)')
+ parser.add_argument('--shell_path', type=str, default=None,
+ help="Path of training config (.sh)")
+ parser.add_argument('--mindformers_dir', type=str, default=None,
+ help='Directory of mindformers')
+ parser.add_argument('--mindspeed_path', type=str, default=None,
+ help="Absolute path of posttrain_gpt (.py)")
+ parser.add_argument('--profile_data_dir', type=str, default='./profile_data/',
+ help='Directory of profile data')
+ parser.add_argument('--solver_name', type=str, default='HIGHS',
+ help='Name of the solver')
+ parser.add_argument('--parser_result', type=str, default=None,
+ help='Profiling parser result file')
+ parser.add_argument('--nd_path', type=str, default='./config/nd_result.csv',
+ help='Path to nd result file')
+ parser.add_argument('--env_json', type=str, required=True,
+ help="Path of environment config (.json), e.g. ./config/boss_env_config.json")
+ parser.add_argument('--register_path', type=str, default='research/jiutian',
+ help="Path of register")
+ parser.add_argument('--parallel_num', type=pp_util.str2int, default=16,
+ help="The number of dryrun at once")
+ parser.add_argument('--dryrun', type=pp_util.str2bool, default=True,
+ help="Is auto dryrun")
+ parser.add_argument('--check', type=pp_util.str2bool, default=True,
+ help="Is double check")
+ parser.add_argument('--output_path', type=str,
+ default='./output/',
+ help='Directory of output info')
+
+ args = parser.parse_args()
+ check_dryrun_parallel_number(args.parallel_num)
+ pipeline_input = ParallelInput(args)
+ pipeline_proc(pipeline_input)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/__init__.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/common.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9970957f0baba03147ad68be0e04ef459dbf7bc
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/common.py
@@ -0,0 +1,402 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""common use functions"""
+import os
+import json
+
+from pathlib import Path
+import shutil
+import yaml
+import toml
+import numpy as np
+
+from fast_tuner.pipeline_conductor.pp_util import configs_to_shell, parse_shell
+from fast_tuner.utils.logger import logger
+
+
+def cal_model_layers_num(input_args):
+ if input_args.model.model_config.mtp_depth is None:
+ input_args.model.model_config.mtp_depth = 0
+ return input_args.model.model_config.num_layers + input_args.model.model_config.mtp_depth
+
+GENERAL_TOML = 'DP{}_TP{}_PP{}_EP{}_CP{}_FSDP{}_pretrain.toml'
+MOE_PATTERN_YAML = 'DP{}_TP{}_PP{}_EP{}_pretrain.yaml'
+LLAMA_PATTERN_YAML = 'DP{}_TP{}_PP{}_pretrain.yaml'
+MOE_PATTERN_SHELL = 'DP{}_TP{}_PP{}_EP{}_pretrain.sh'
+LLAMA_PATTERN_SHELL = 'DP{}_TP{}_PP{}_pretrain.sh'
+MODULE_PATTERN_REG_YAML = r'DP(\d+)_TP(\d+)_PP(\d+)_EP(\d+)_pretrain.yaml'
+MOE_PATTERN_REG_SHELL = r'DP(\d+)_TP(\d+)_PP(\d+)_EP(\d+)_pretrain.sh'
+LLAMA_PATTERN_REG_SHELL = r'DP(\d+)_TP(\d+)_PP(\d+)_pretrain.sh'
+
+def initial_offset(pipeline_stage, num_layers):
+ '''
+ set offset in a memory friendly way such that
+ we can estimate the minimal memory requirement
+ '''
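+ # Example (illustrative): initial_offset(4, 61) returns [1, 0, 0, 0]; the single
+ # remainder layer is stacked on stage 0 on top of the even 15-layers-per-stage base.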
+
+ if pipeline_stage == 1 or num_layers % pipeline_stage == 0:
+ offset = 0
+ return offset
+
+ pp_interleave_num = 1
+ offset = np.zeros((pp_interleave_num, pipeline_stage), dtype=int).tolist()
+ remainder = num_layers % (pp_interleave_num * pipeline_stage)
+ for vpp in range(pp_interleave_num):
+ offset[0][0] += 1
+ remainder -= 1
+ for stage in range(pipeline_stage):
+ if remainder == 0:
+ break
+ offset[vpp][pipeline_stage - stage - 2] += 1
+ remainder -= 1
+ return offset[0]
+
+def offset_for_dualpipe(pipeline_stage, num_layers):
+ '''
+ set offset in a memory friendly way such that
+ we can estimate the minimal memory requirement
+ '''
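+ # Note (illustrative): this returns a 2 x pipeline_stage offset matrix (one row per
+ # virtual pipeline chunk), expressed relative to (num_layers - 2) // pipeline_stage,
+ # so entries may be negative.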
+ pp_interleave_num = 2
+ offset = np.zeros((pp_interleave_num, pipeline_stage), dtype=int).tolist()
+ new_layers = num_layers - 2
+ remainder = new_layers % pipeline_stage
+ base = new_layers // pipeline_stage
+ origin_base = new_layers // pipeline_stage
+
+ for stage in range(pipeline_stage):
+ # number of layers on the current stage
+ if remainder == 0:
+ cur_layer = base
+ else:
+ cur_layer = base + 1
+ remainder -= 1
+ # assign layers to each vpp chunk
+ vpp1_layer = cur_layer // pp_interleave_num
+ if pipeline_stage - stage - 1 == 0:
+ offset[0][pipeline_stage - stage - 1] = vpp1_layer + 2 - origin_base
+ else:
+ offset[0][pipeline_stage - stage - 1] = vpp1_layer - origin_base
+ vpp2_layer = cur_layer - vpp1_layer
+ offset[1][stage] = vpp2_layer - origin_base
+ return offset
+
+def cal_world_size(input_args):
+ '''compute the worldsize from config'''
+ world_size = input_args.dp * input_args.tp * input_args.pp * input_args.cp
+ return world_size
+
+def generate_dryrun_yaml(destination_file, config, para):
+ """
+ :param para: user input parameters
+ :param destination_file: config yaml file whose parallel settings are modified
+ :param config: [dp, tp, pp, ep]
+ :return:
+ """
+ # copy the YAML file
+ shutil.copy2(para.YAML_PATH, destination_file)
+
+ # read the copied YAML file and modify the parallel config fields
+ with open(destination_file, 'r', encoding='utf-8') as file:
+ yaml_data = yaml.safe_load(file)
+ yaml_data['parallel_config']['data_parallel'] = config[0]
+ yaml_data['parallel_config']['model_parallel'] = config[1]
+ yaml_data['parallel_config']['pipeline_stage'] = config[2]
+ yaml_data['parallel_config']['expert_parallel'] = config[3]
+ yaml_data['parallel_config']['micro_batch_num'] = para.GBS // config[0]
+ yaml_data['model']['model_config']['offset'] = config[4]
+ if yaml_data['parallel'].get('dataset_strategy') is not None:
+ strategy_size = len(yaml_data['parallel']['dataset_strategy'])
+ yaml_data['parallel']['dataset_strategy'] = [[config[0], 1] for _ in range(strategy_size)]
+ # set recompute to true
+ yaml_data['recompute_config']['recompute'] = True
+
+ # todo: adapt to tnd
+ yaml_data['train_dataset']['data_loader']['dataset_dir'] = para.DATASET
+
+
+ # write the modified data back to the YAML file
+ with open(destination_file, 'w', encoding='utf-8') as file:
+ yaml.dump(yaml_data, file, default_flow_style=False, allow_unicode=True)
+
+ logger.info(f"The dryrun YAML file copied and modified, new file is: {destination_file}")
+
+def generate_dryrun_shell(destination_file, config, para):
+ """
+ :param para: user input parameters
+ :param destination_file: config shell file whose parallel settings are modified
+ :param config: [dp, tp, pp, ep, offset] or [dp, tp, pp]
+ :return:
+ """
+ configs, unparses = parse_shell(para.SHELL_PATH)
+ configs['DP'] = config[0]
+ configs['TP'] = config[1]
+ configs['PP'] = config[2]
+ configs_to_shell(destination_file, configs, unparses)
+ logger.info(f'The dryrun SHELL file copied and modified, new file is {destination_file}')
+
+def generate_profile_yaml(destination_file, config, para):
+ """
+ :param para: user input parameters
+ :param destination_file:
+ :param config: [dp, tp, pp, ep, offset, num_layers]
+ :return:
+ """
+ # copy the YAML file
+ shutil.copy2(para.YAML_PATH, destination_file)
+
+ # read the copied YAML file and modify the parallel config fields
+ with open(destination_file, 'r', encoding='utf-8') as file:
+ yaml_data = yaml.safe_load(file)
+ yaml_data['parallel_config']['data_parallel'] = config[0]
+ yaml_data['parallel_config']['model_parallel'] = config[1]
+ yaml_data['parallel_config']['pipeline_stage'] = config[2]
+ yaml_data['parallel_config']['expert_parallel'] = config[3]
+ yaml_data['parallel_config']['micro_batch_num'] = para.GBS // config[0]
+ yaml_data['model']['model_config']['offset'] = config[4]
+ yaml_data['recompute_config']['recompute'] = config[5]
+ yaml_data['model']['model_config']['num_layers'] = config[6]
+
+ yaml_data['profile'] = True
+ yaml_data['profile_output'] = os.path.join(Path(destination_file).parent, "output")
+ yaml_data['init_start_profile'] = True
+ yaml_data['profile_communication'] = True
+ yaml_data['profile_memory'] = True
+ yaml_data['op_time'] = True
+ yaml_data['profile_level'] = 1
+ yaml_data['profile_start_step'] = 4
+ yaml_data['profile_stop_step'] = 6
+
+ # write the modified data back to the YAML file
+ with open(destination_file, 'w', encoding='utf-8') as file:
+ yaml.dump(yaml_data, file, default_flow_style=False, allow_unicode=True)
+
+ logger.info(f"The profile YAML file copied and modified, config_para is: {config}")
+
+def generate_profile_shell(destination_file, config, para):
+ """
+ :param para: user input parameters
+ :param destination_file:
+ :param config: [dp, tp, pp, ep, offset, num_layers]
+ :return:
+ """
+ configs, unparses = parse_shell(para.SHELL_PATH)
+ configs['DP'] = config[0]
+ configs['TP'] = config[1]
+ configs['PP'] = config[2]
+ if '_EP' in destination_file:
+ configs['EP'] = config[3]
+ configs['FIRST_K_DENSE_RAPLACE'] = 3
+ else:
+ configs['EP'] = 0 # alternatively, EP may not need to be written here
+ # TODO: NPUS_PER_NODE should be read from the config; it is currently hard-coded to 8
+ profile_need_rank_num = configs['DP'] * configs['TP'] * configs['PP']
+ if profile_need_rank_num < para.RANK_NUM:
+ configs['NNODES'] = profile_need_rank_num // 8
+ configs['NPUS_PER_NODE'] = 8
+ if configs['NNODES'] == 0:
+ configs['NNODES'] = 1
+ configs['NPUS_PER_NODE'] = profile_need_rank_num
+ else:
+ configs['NNODES'] = para.RANK_NUM // 8
+ configs['NPUS_PER_NODE'] = 8
+ configs['NUM_LAYERS'] = config[-1]
+ configs['MINDSPEED_PATH'] = para.MINDSPEED_PATH
+ # build the list of rank ids to profile
+ step = profile_need_rank_num // config[2]
+ profile_ranks = ' '.join(map(str, range(0, profile_need_rank_num, step)))
+ if '_EP' in destination_file:
+ output_dir = os.path.join(para.OUTPUT_PATH, "profile_result",
+ f"DP{config[0]}_TP{config[1]}_PP{config[2]}_EP{config[3]}_profile")
+ else:
+ output_dir = os.path.join(para.OUTPUT_PATH, "profile_result",
+ f"DP{config[0]}_TP{config[1]}_PP{config[2]}_profile")
+ # build the argument string with f-strings for readability
+ profile_args = (
+ "--profile "
+ "--profile-step-start 5 "
+ "--profile-step-end 7 "
+ "--profile-level level1 "
+ "--profile-with-stack "
+ "--profile-with-cpu "
+ f"--profile-ranks {profile_ranks} " # 确保参数之间有空格
+ f"--profile-save-path {output_dir}"
+ )
+ # todo: needs to be adapted to the original args content
+ if configs['EP'] == 0:
+ mem_args = (
+ "--use-distributed-optimizer "
+ "--recompute-method block "
+ "--recompute-granularity full "
+ "--recompute-num-layers 1 "
+ "--num-layer-list 4,3 "
+ )
+ else:
+ mem_args = (
+ "--use-distributed-optimizer "
+ "--recompute-method block "
+ "--recompute-granularity full "
+ "--recompute-num-layers 1 "
+ "--num-layer-list 4,3 "
+ "--first-k-dense-replace 3 "
+ )
+
+ # store the generated argument strings in the configs dict
+ configs['PROFILE_ARGS'] = profile_args
+ configs['MEM_ARGS'] = mem_args
+ configs_to_shell(destination_file, configs, unparses)
+ insert_profile_args_final(destination_file)
+ logger.info(f'The profile SHELL file copied and modified, new file is {destination_file}')
+ return output_dir
+
+def generate_profile_toml(destination_file, config, para):
+ '''generate toml file for profile'''
+ def read_toml(file_path):
+ """读取 TOML 文件并返回字典"""
+ with open(file_path, 'r', encoding='utf-8') as f:
+ return toml.load(f)
+
+ def write_toml(data, file_path):
+ """将字典写入 TOML 文件"""
+ with open(file_path, 'w', encoding='utf-8') as f:
+ toml.dump(data, f)
+
+ dp, tp, pp, ep, cp, op = config[:6]
+ template_data = read_toml(para.TOML_PATH)
+ # [profiling]
+ template_data['profiling']['enable_profiling'] = True
+ template_data['profiling']['profile_freq'] = 20
+ output_trace_dir = os.path.join(os.path.abspath(para.OUTPUT_PATH), "profile_result",
+ f"DP{dp}_TP{tp}_PP{pp}_EP{ep}_CP{cp}_OP{op}_trace")
+ template_data['profiling']['save_traces_folder'] = output_trace_dir
+ # [model]
+ template_data['model']['flavor'] = 'tune'
+ # [training]
+ template_data['training']['steps'] = 20
+ # [parallelism]
+ template_data['parallelism']['data_parallel_replicate_degree'] = dp // op
+ template_data['parallelism']['data_parallel_shard_degree'] = op
+ template_data['parallelism']['tensor_parallel_degree'] = tp
+ template_data['parallelism']['context_parallel_degree'] = cp
+ template_data['parallelism']['expert_parallel_degree'] = ep
+ template_data['parallelism']['pipeline_parallel_degree'] = pp
+ write_toml(template_data, destination_file)
+ return output_trace_dir
+
+
+
+def insert_profile_args_final(shell_file_path, profile_args="$PROFILE_ARGS \\"):
+ '''modify shell files for profile'''
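+ # Sketch of the intent (illustrative): the line containing "$MEM_ARGS" marks the torchrun
+ # argument block, and "$PROFILE_ARGS \" is appended on the line right after it.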
+ with open(shell_file_path, 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+
+ # flag indicating whether we are inside the torchrun command block
+ in_torchrun_block = False
+ modified_lines = []
+
+ for line in lines:
+ stripped = line.strip()
+
+ # detect the start of the command block
+ if "$MEM_ARGS" in stripped:
+ in_torchrun_block = True
+ modified_lines.append(line)
+ modified_lines.append(f" {profile_args}\n")
+ continue
+
+ # detect the end of the command block (the last argument line ends with \)
+ if in_torchrun_block and stripped.endswith("\\"):
+ modified_lines.append(line)
+ continue
+
+ # lines after the command block
+ if in_torchrun_block:
+ in_torchrun_block = False
+
+ # ordinary line
+ modified_lines.append(line)
+
+ # write back the modified content
+ with open(shell_file_path, 'w', encoding='utf-8') as f:
+ f.writelines(modified_lines)
+
+ logger.info(f"insert PROFILE_ARGS to {shell_file_path}")
+
+
+# map each task name to its corresponding generator function
+TASK_FUNCTION_MAP = {
+ "dryrun_yaml": generate_dryrun_yaml,
+ "dryrun_shell": generate_dryrun_shell,
+ "profile_yaml": generate_profile_yaml,
+ "profile_shell": generate_profile_shell,
+ "profile_toml": generate_profile_toml,
+}
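+# Usage sketch (illustrative): generate_files(configs, './output/dryrun_yaml/', 'dryrun_yaml',
+# para, input_args) looks up generate_dryrun_yaml in this map and applies it to every config.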
+
+def generate_files(candidate_configs, des_file_directory, file_task, para, input_args):
+ '''generate config files w.r.t your training library'''
+ # create the directory if it does not exist
+ if not os.path.exists(des_file_directory):
+ os.makedirs(des_file_directory)
+
+ # check that file_task is valid
+ if file_task not in TASK_FUNCTION_MAP:
+ logger.error(f"Invalid file_task value: {file_task}. Please use one of {list(TASK_FUNCTION_MAP.keys())}.")
+ return None
+
+ generate_function = TASK_FUNCTION_MAP[file_task]
+ if para.SHELL_PATH:
+ pattern = LLAMA_PATTERN_SHELL if input_args.expert_num is None else MOE_PATTERN_SHELL
+ elif para.YAML_PATH:
+ pattern = LLAMA_PATTERN_YAML if input_args.expert_num is None else MOE_PATTERN_YAML
+ else:
+ pattern = GENERAL_TOML
+
+ file_path_list = []
+ output_dir_list = []
+ for config in candidate_configs:
+ # build the output file path, including the file name
+ destination_file = (des_file_directory + pattern).format(*config)
+ file_path_list.append(destination_file)
+ # only the profile tasks (profile_shell / profile_toml) return an output_dir; other tasks return None
+ output_dir = generate_function(destination_file, config, para)
+ output_dir_list.append(output_dir)
+ return file_path_list, output_dir_list
+
+
+def is_dualpipe_open(input_args):
+ if input_args.mf_args is None:
+ return False
+ parallel_cfg = input_args.mf_args.parallel
+ use_zero_bubble_v = (
+ hasattr(parallel_cfg, 'pipeline_config') and
+ hasattr(parallel_cfg.pipeline_config, 'pipeline_scheduler') and
+ parallel_cfg.pipeline_config.pipeline_scheduler == 'zero_bubble_v'
+ )
+ return use_zero_bubble_v
+
+def check_dryrun_parallel_number(parallel_num):
+ if parallel_num > 16:
+ raise Exception(f"The parallel number {parallel_num} is too large.")
+
+def parse_args_from_json(args):
+ with open(args.config, 'r', encoding='utf-8') as f:
+ data = json.load(f)
+ for key, value in data.items():
+ if key not in vars(args):
+ logger.error(f"The json contains invalid parameters: {key}")
+ else:
+ setattr(args, key, value)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/dryrun_manage.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/dryrun_manage.py
new file mode 100644
index 0000000000000000000000000000000000000000..b45a4528ef3a767388fab69103edf1dd00f32ebc
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/dryrun_manage.py
@@ -0,0 +1,135 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""dryrun operation"""
+import os
+import subprocess
+import re
+import json
+from multiprocessing import Pool
+from fast_tuner.utils.common import cal_world_size, is_dualpipe_open, \
+ LLAMA_PATTERN_REG_SHELL, MOE_PATTERN_REG_SHELL, MODULE_PATTERN_REG_YAML
+from fast_tuner.utils.logger import logger
+
+
+def env_update(rank_size, env_variable_json):
+ # set environment variables
+ with open(env_variable_json, 'r', encoding='utf-8') as f:
+ env_vars = json.load(f)
+ env_vars['RANK_SIZE'] = str(rank_size)
+ os.environ.update(env_vars)
+
+def create_target_dir(file, target_directory):
+ output_dir = os.path.splitext(file)[0]
+ target_dir = os.path.join(target_directory, f"dryrun_{output_dir}")
+ return target_dir
+
+def execute_command(para, config_file_path, target_dir, rank_id, tp):
+ '''execute dryrun command'''
+ if para.YAML_PATH:
+ command = ['python', os.path.join(para.MINDFORMERS_DIR, 'run_mindformer.py'), '--config', config_file_path,
+ '--register_path', para.REGISTER_PATH]
+ else:
+ command = ['bash', config_file_path]
+ try:
+ # build the log file path
+ os.environ['RANK_ID'] = str(rank_id)
+ os.environ['MODEL_PARALLEL'] = str(tp)
+ os.makedirs(target_dir, exist_ok=True)
+ log_file_path = os.path.join(target_dir, 'dryrun_new.log')
+ with open(log_file_path, 'w', encoding='utf-8') as log_file:
+ logger.info(f"Command: {' '.join(command)} rank {rank_id} start run")
+ subprocess.run(command, stdout=log_file, stderr=subprocess.STDOUT, check=False)
+ except Exception as e:
+ logger.error(f"The command execution failed.: {e}")
+
+def calculate_rank_id(input_args, match, layer_num, rank_size):
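+ # Example (illustrative): with layer_num=61, pp=8 and rank_size=2048, x = 8 - 61 % 8 = 3,
+ # so the dryrun is pinned to rank 3 * 2048 // 8 = 768.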
+ if is_dualpipe_open(input_args):
+ return rank_size - 1
+ pp = int(match.group(3))
+ x = pp - layer_num % pp
+ rank_id = x * rank_size // pp
+ if rank_id == rank_size:
+ rank_id -= 1
+ return rank_id
+
+def read_dryrun_info(root_dir):
+ '''read the peak mem from dryrun log'''
+ logger.info(f"Reading dryrun info from {root_dir}")
+ result_list = []
+ # pre-compile the regex for faster matching
+ pattern = re.compile(r'DP(\d+)_TP(\d+)_PP(\d+)_EP(\d+)')
+
+ for entry in os.scandir(root_dir):
+ if not entry.is_dir():
+ continue
+
+ log_file_path = os.path.join(entry.path, 'dryrun_new.log')
+ if not os.path.isfile(log_file_path):
+ continue
+
+ with open(log_file_path, 'r', encoding='utf-8') as log_file:
+ for line in log_file:
+ if "Used peak memory usage (without fragments):" in line:
+ peak_value = line.split(': ')[-1].strip()
+ peak_value = int(peak_value.rstrip('M'))
+ break
+ else:
+ # skip this directory if no peak value was found
+ continue
+
+ match = pattern.search(entry.name)
+ if match:
+ dp_num, tp_num, pp_num, ep_num = map(int, match.groups())
+ result_list.append([dp_num, tp_num, pp_num, ep_num, peak_value])
+
+ logger.info(f"read dryrun info done, result size {len(result_list)} format [dp, tp, pp, ep, peak_value]")
+ logger.info('\n'.join(str(item) for item in result_list))
+ return result_list
+
+def get_file_pattern(root_dir, input_args, para):
+ logger.info(f"Reading file pattern from {root_dir}")
+ if para.SHELL_PATH:
+ if input_args.expert_num is None:
+ pattern = LLAMA_PATTERN_REG_SHELL
+ else:
+ pattern = MOE_PATTERN_REG_SHELL
+ else:
+ pattern = MODULE_PATTERN_REG_YAML
+ return pattern
+
+def launch_dryrun(input_args, dryrun_file_dir, dryrun_data_dir, para):
+ '''dryrun launcher'''
+ logger.info('start dryrun')
+ rank_size = cal_world_size(input_args)
+ env_update(rank_size, para.ENV_JSON)
+ tasks = []
+ layer_num = input_args.num_layers
+
+ # walk all files under the given directory
+ for root, _, files in os.walk(dryrun_file_dir):
+ for file in files:
+ pattern = get_file_pattern(root, input_args, para)
+ match = re.match(pattern, file)
+ if not match:
+ continue
+ rank_id = calculate_rank_id(input_args, match, layer_num, rank_size)
+ tp = int(match.group(2))
+ target_dir = create_target_dir(file, dryrun_data_dir)
+ yaml_file_path = os.path.join(root, file)
+ tasks.append((para, yaml_file_path, target_dir, rank_id, tp))
+
+ with Pool(processes=para.DRYRUN_LIM) as pool:
+ pool.starmap(execute_command, tasks)
+ logger.info("all dryrun done")
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_config.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..38f2751809b09d91b044bf51bf557073bd87d7b0
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_config.py
@@ -0,0 +1,87 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""" Transformer-Config dict parse module """
+
+import os
+import yaml
+
+class InputConfig(dict):
+ """
+ A class for configuration that inherits from Python's dict class.
+ Can parse configuration parameters from yaml files
+
+ Args:
+ args (Any): Extensible parameter list, a yaml configuration file path
+ kwargs (Any): Extensible parameter dictionary, a yaml configuration file path
+
+ Returns:
+ An instance of the class.
+
+ Examples:
+ >>> cfg = InputConfig('./test.yaml')
+ >>> cfg.a
+ """
+
+ def __init__(self, *args):
+ super().__init__()
+ cfg_dict = {}
+
+ for arg in args:
+ if isinstance(arg, str):
+ if arg.endswith('yaml') or arg.endswith('yml'):
+ raw_dict = InputConfig._file2dict(arg)
+ cfg_dict.update(raw_dict)
+
+ InputConfig._dict2config(self, cfg_dict)
+
+ def __getattr__(self, key):
+ if key not in self:
+ return None
+ return self[key]
+
+ @staticmethod
+ def _file2dict(filename=None):
+ """Convert config file to dictionary.
+
+ Args:
+ filename (str) : config file.
+ """
+ if filename is None:
+ raise NameError('The config file name cannot be empty.')
+
+ filepath = os.path.realpath(filename)
+ with open(filepath, encoding='utf-8') as fp:
+ # reset the file pointer to the beginning of the file
+ fp.seek(0)
+ cfg_dict = yaml.safe_load(fp)
+
+ return cfg_dict
+
+ @staticmethod
+ def _dict2config(config, dic):
+ """Convert dictionary to config.
+
+ Args:
+ config : Config object
+ dic (dict) : dictionary
+ """
+ if isinstance(dic, dict):
+ for key, value in dic.items():
+ if isinstance(value, dict):
+ sub_config = InputConfig()
+ dict.__setitem__(config, key, sub_config)
+ InputConfig._dict2config(sub_config, value)
+ else:
+ config[key] = dic[key]
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_param.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_param.py
new file mode 100644
index 0000000000000000000000000000000000000000..25261961b4ef8731e775d6ff3f1edb93b6cf6254
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_param.py
@@ -0,0 +1,60 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""inner para class"""
+from fast_tuner.utils.logger import logger
+
+
+class InputParam:
+ '''All information that we need to record'''
+ PARAM_MAPPING = {
+ "ENV_JSON" : "env_json",
+ "DATASET" : "dataset",
+ "YAML_PATH" : "yaml_path",
+ "SHELL_PATH" : "shell_path",
+ "TOML_PATH" : "toml_path",
+ "MINDFORMERS_DIR" : "mindformers_dir",
+ "MINDSPEED_PATH" : "mindspeed_path",
+ "TORCHTITAN_PATH": "torchtitan_path",
+ "DRYRUN_DATA_DIR" : "dryrun_data_dir",
+ "PROFILE_DATA_DIR" : "profile_data_dir",
+ "DRYRUN_LIM" : "dryrun_lim",
+ "RANK_NUM" : "rank_num",
+ "SOLVER_NAME" : "solver_name",
+ "MAX_EXPERT_PARALLEL": "max_expert_parallel",
+ "OUTPUT_PATH": "output_path",
+ "GBS": "gbs",
+ "SELECT_RECOMPUTE": "select_recompute",
+ "ALG_PHASE": "alg_phase",
+ "PARSER_RESULT": "parser_result",
+ "DRYRUN": "dryrun",
+ "CHECK": "check",
+ "CONFIG": "config",
+ "NPUS_PER_NODE": "npus_per_node",
+ "NNODES": "nnodes",
+ "STRATEGY": "strategy",
+ }
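+ # Usage sketch (illustrative): InputParam(args).ENV_JSON resolves to args.env_json
+ # through this mapping in __getattr__.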
+
+ def __init__(self, args):
+ self.args = args
+
+ def __getattr__(self, name):
+ if name in self.PARAM_MAPPING:
+ param_name = self.PARAM_MAPPING[name]
+ return getattr(self.args, param_name)
+ raise AttributeError(f"'InputParam' object has no attribute '{name}'")
+
+ def print_params(self):
+ for key, value in self.PARAM_MAPPING.items():
+ logger.info(f"{key}: {value} = {getattr(self, key)}")
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/logger.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bd4716beedb4c868444680c5f52568230e1dc0e
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/logger.py
@@ -0,0 +1,40 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""logger file"""
+import logging
+
+DEFAULT_STDOUT_FORMAT = '%(levelname)s %(asctime)s %(filename)s:%(lineno)d - %(message)s'
+FORMATTER = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+OUTPUT_LEVEL_NUM = logging.WARNING
+logging.addLevelName(OUTPUT_LEVEL_NUM, "OUTPUT")
+
+
+def setup_logger(name: str, level: int = logging.DEBUG):
+ """setup a logger"""
+ ch = logging.StreamHandler()
+ ch.setLevel(level)
+ ch.setFormatter(FORMATTER)
+
+ def output(self, message, *args):
+ self.warning(message, *args)
+
+ logging.Logger.output = output
+ ppb_logger = logging.getLogger(name)
+ ppb_logger.setLevel(level)
+ ppb_logger.addHandler(ch)
+
+ return ppb_logger
+
+logger = setup_logger('symphony', level=logging.INFO)
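+
+# Usage sketch (illustrative): logger.output("message") logs at the custom OUTPUT level,
+# which setup_logger installs as an alias for the WARNING level.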
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/ppc_input.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/ppc_input.py
new file mode 100644
index 0000000000000000000000000000000000000000..832e0f582da8e07c281dd33807b35c5a649ae869
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/ppc_input.py
@@ -0,0 +1,139 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""interface between nd and pp """
+import os
+import re
+import csv
+from typing import List
+from pathlib import Path
+from fast_tuner.utils.logger import logger
+from fast_tuner.utils.common import GENERAL_TOML
+from fast_tuner.utils.profiling.profile_info import ProfileInfo
+from fast_tuner.pipeline_conductor.start_service import ExpertInput
+from fast_tuner.pipeline_conductor.dryrun import DryRun, dryrun_config_error
+from fast_tuner.ndsearch.para_for_nd_search import ParaForNd
+
+class ParallelConfig:
+ def __init__(self, gbs, config):
+ self.dp = config[0]
+ self.tp = config[1]
+ self.pp = config[2]
+ self.vp = 1
+ self.ep = config[3]
+ self.micro = gbs / self.dp
+
+class PipelineInputConfig:
+ def __init__(self, profiling_info: ProfileInfo, config_path, model_args = None):
+ self.profiling_info = profiling_info
+ self.config_path = config_path
+ self.model_args = model_args
+
+class ParallelInput:
+ def __init__(self, para, profile_file_dir=None):
+ self.candidate_configs: List[PipelineInputConfig] = []
+
+ self.init_configs_info(para, profile_file_dir)
+
+ self.is_lowmem = int(os.getenv('ENABLE_LESS_MEM_VPP', 0))
+ args = para.args
+ self.solver_name = args.solver_name
+ self.env_config_json = args.env_json
+ self.dryrun_lim = args.dryrun_lim
+ self.dryrun = args.dryrun
+ self.check = args.check
+ self.output_path = args.output_path
+ if DryRun.config_file_type == 0:
+ self.ms_adapter_file = args.mindformers_dir
+ elif DryRun.config_file_type == 1:
+ self.ms_adapter_file = args.mindspeed_path
+ elif DryRun.config_file_type == 2:
+ self.ms_adapter_file = args.torchtitan_path
+ else:
+ raise TypeError(dryrun_config_error)
+
+ @staticmethod
+ def parse_results_by_csv(csv_file):
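+ # Expected layout (illustrative): the leading columns (e.g. dp, tp, pp, ep) form the
+ # key "dp_tp_pp_ep"; the trailing five columns are dmratio, bfratio, re_grow_ration,
+ # hratio and moe_fw.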
+ result_dict = {}
+ with open(csv_file, mode='r', encoding='utf-8') as csvfile:
+ reader = csv.DictReader(csvfile)
+ for row in reader:
+ # key = "_".join([row['dp'], row['tp'], row['pp'], row['ep']])
+ key_columns = list(row.keys())[:-5] # all columns except the trailing five metrics form the key
+ key = "_".join([row[col] for col in key_columns])
+ value = [float(row['dmratio']), float(row['bfratio']), float(row['re_grow_ration']),
+ float(row['hratio']), float(row['moe_fw'])]
+ result_dict[key] = value
+ return result_dict
+
+ def get_model_files_dir(self, args, profile_file_dir=None):
+ if profile_file_dir is not None:
+ files_dir = Path(profile_file_dir)
+ elif args.files_dir is not None:
+ files_dir = Path(args.files_dir)
+ else:
+ raise RuntimeError('Must specify either files_dir or profile_file_dir')
+
+ if args.yaml_path:
+ model_files = files_dir.glob('*.yaml')
+ elif args.shell_path:
+ model_files = files_dir.glob('*.sh')
+ elif args.toml_path:
+ model_files = files_dir.glob('*.toml')
+ else:
+ raise Exception("No yaml_path or shell_path specified")
+ return model_files
+
+ def get_args_info(self, para, config_file):
+ if config_file.name.endswith('.toml'):
+ para.TOML_PATH = config_file
+ model_args = ParaForNd(para)
+ return model_args
+ else:
+ return None
+
+ def init_configs_info(self, para, profile_file_dir=None):
+ args = para.args
+ model_files = self.get_model_files_dir(args, profile_file_dir)
+ csv_result = {}
+ # if the user directly supplies parsed profiling info as a csv file, read it from that file
+ if args.parser_result is not None:
+ csv_result = self.parse_results_by_csv(args.parser_result)
+
+ if args.yaml_path:
+ pattern = re.compile(r'DP(\d+)_TP(\d+)_PP(\d+)_EP(\d+)')
+ DryRun.config_file_type = 0
+ elif args.shell_path:
+ pattern = re.compile(r'DP(\d+)_TP(\d+)_PP(\d+)_EP(\d+)')
+ DryRun.config_file_type = 1
+ elif args.toml_path:
+ pattern = re.compile(GENERAL_TOML.replace('{}', r'(\d+)'))
+ DryRun.config_file_type = 2
+
+ for config_file in model_files:
+ match = pattern.search(config_file.name)
+ if match:
+ # dp_num, tp_num, pp_num, ep_num = map(int, match.groups())
+ # config = [dp_num, tp_num, pp_num, ep_num]
+ config = [int(x) for x in match.groups()]
+ config_str = "_".join(map(str, config))
+ if config_str in csv_result:
+ profile_list = csv_result[config_str]
+ else:
+ profile_list = []
+ # todo: the profiling parser input (e.g. which ranks) still needs to be confirmed
+ profile_info = ProfileInfo(args.profile_data_dir, profile_list)
+ input_args = self.get_args_info(para, config_file)
+ pipeline_input_config = PipelineInputConfig(profile_info, config_file, input_args)
+ self.candidate_configs.append(pipeline_input_config)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_info.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dd0f246b34ba9a16a02c129eabb0eddc9826327
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_info.py
@@ -0,0 +1,65 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""profile parser result interface"""
+import csv
+from fast_tuner.utils.logger import logger
+
+
+class ProfileInfo:
+ def __init__(self, profiling_path, input_data):
+ # if input_data is empty, parse the profiling result instead
+ if not input_data:
+ logger.info('no parsed input data supplied, parsing the profiling result instead')
+ input_data = self.parse_profiling_result(profiling_path)
+ self.dmratio = input_data[0]
+ self.bfratio = input_data[1]
+ self.re_grow_ration = input_data[2]
+ self.hratio = input_data[3]
+ self.moe_fw = input_data[4]
+ logger.info(f'{input_data}')
+
+ @staticmethod
+ def generate_csv():
+ # column names of the CSV file
+ headers = ['dp', 'tp', 'pp', 'ep', 'vp', 'dmratio', 'bfratio', 'hratio', 'moe_bw', 're_grow_ration']
+
+ # sample data; modify it or add more rows as needed
+ data = [
+ [128, 1, 8, 16, 1, 0.1, 0.2, 0.3, 100, 0.34],
+ [128, 1, 8, 8, 1, 0.4, 0.5, 0.6, 200, 0.24]
+ ]
+
+ # path of the CSV file to save
+ csv_file_path = './config/profiling_result.csv'
+
+ # open the file and write the data
+ with open(csv_file_path, mode='w', newline='', encoding='utf-8') as csvfile:
+ writer = csv.writer(csvfile)
+
+ # write the header row
+ writer.writerow(headers)
+
+ # write the data rows
+ for row in data:
+ writer.writerow(row)
+
+ logger.info(f"CSV 文件已生成,路径为: {csv_file_path}")
+
+ def parse_profiling_result(self, profiling_path):
+ # todo: to be implemented
+ profiling_data = [0.1, 0.2, 0.3, 100, 2]
+ return profiling_data
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_launch.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..437a580285ad3ca39209c4a2a041ee9b6687b006
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_launch.py
@@ -0,0 +1,73 @@
+"""launch profile"""
+from fast_tuner.utils.common import LLAMA_PATTERN_REG_SHELL, MOE_PATTERN_REG_SHELL, GENERAL_TOML
+from fast_tuner.utils.logger import logger
+import subprocess
+import os
+import re
+import time
+from pathlib import Path
+class ProfileLaunch:
+ # 自动执行profile
+ def __init__(self, profile_configs, para):
+ self.profile_configs = profile_configs
+ self.para = para
+
+ def profile_launch(self, profile_file_dir):
+ # walk all files under the given directory
+ for root, _, files in os.walk(profile_file_dir):
+ for file in files:
+ # if '_EP' in file:
+ # pattern = MOE_PATTERN_REG_SHELL
+ # else:
+ # pattern = LLAMA_PATTERN_REG_SHELL
+ # match = re.match(pattern, file)
+ # if not match:
+ # continue
+ profile_file_path = os.path.join(root, file)
+ if file.endswith(".toml"):
+ self.run_torchtitan(profile_file_path)
+ elif file.endswith(".sh"):
+ self.run_shell(profile_file_path)
+ else:
+ logger.error(f"file type: {file} is not supported.")
+
+ def run_shell(self, profile_file_dir):
+ cmd = ["bash", profile_file_dir]
+ logger.info(f"profile command: {cmd}")
+
+ process = subprocess.run(
+ cmd,
+ preexec_fn=os.setpgrp,
+ check=False,
+ )
+ # sleep so that profile file generation does not fail because the profiling subprocess has not finished yet
+ time.sleep(60)
+ return_code = process.returncode
+ logger.info("Last job returns %d.", return_code)
+
+
+ def run_torchtitan(self, profile_file_toml):
+ run_file = self.para.TORCHTITAN_PATH
+ regex_pattern = GENERAL_TOML.replace('{}', r'(\d+)')
+ match = re.match(regex_pattern, Path(profile_file_toml).name)
+ if match:
+ dp = int(match.group(1))
+ tp = int(match.group(2))
+ pp = int(match.group(3))
+ cp = int(match.group(5))
+ world_size = dp * tp * pp * cp
+ else:
+ raise ValueError(f"Invalid profile_file_toml: {profile_file_toml}")
+ cmd = f'CONFIG_FILE={profile_file_toml} NGPU={world_size} {run_file}'
+ logger.info(f"run_torchtitan profile command: {cmd}")
+
+ process = subprocess.run(
+ cmd,
+ preexec_fn=os.setpgrp,
+ shell=True,
+ check=False,
+ )
+ # sleep so that profile file generation does not fail because the profiling subprocess has not finished yet
+ time.sleep(20)
+ return_code = process.returncode
+ logger.info("Last job returns %d.", return_code)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_parser.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..9de91989fb850b8dcf9c7f360c28667dadbbdc2c
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_parser.py
@@ -0,0 +1,681 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""profile parser"""
+import json
+import os
+import statistics
+import csv
+from collections import defaultdict
+from fast_tuner.pipeline_conductor.pp_util import parse_shell
+from fast_tuner.utils.logger import logger
+from fast_tuner.ndsearch.para_for_nd_search import ParaForNd
+from fast_tuner.ndsearch.memory_model import compute_weight_and_optimizer_memory
+from pathlib import Path
+
+encoding = 'utf-8'
+
+def find_file_by_name(directory, filename):
+ for root, _, files in os.walk(directory):
+ if filename in files:
+ return os.path.join(root, filename)
+ logger.error(f'No such file {filename} in directory {directory}')
+ return None
+
+def median_mean(durs):
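+ # Example (illustrative): median_mean([1000, 2000, 3000, 4000]) drops the lowest and highest
+ # quartile and returns 2.5 (the mean of the remaining values, converted from us to ms).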
+ if len(durs) == 0:
+ logger.info(f"get durs no values")
+ return 0
+ median_durs = durs[len(durs)//4 : len(durs) - len(durs)//4]
+ return round(statistics.mean(median_durs)/1000, 3)
+
+class ProfileParser:
+
+ '''
+ The recommend config for different pp is:
+ pp2: [6,5] recomp [6,0], reduce layers when OOM, layers at least [5,4]
+ if still OOM, do another profile wrt [3,1] no recomp w/ dense_replace = 1
+ pp4: [3,3,3,2] recomp [3,3,0,0]
+ pp8: [3,1,1,1,1,2,2,2] recomp [3,1,1,1,1,2,0,2] (or try pp4 with layers [3,2,2,2] w/ recomp [3,2,0,2])
+ a typical input is a string and a tuple of numbers. 2-tuple for pp2 and 4-tuple for higher pp,
+ which consists of the rank for each stage.
+ '''
+
+ profile_data = None
+ config = ''
+ # recomp_bws, moe_fws1, moe_fws2, moe_bws, dense_fws, dense_bws, totals = [], [], [], [], [], [], []
+ # moe_fw1, moe_fw2, moe_bw, recomp_bw, dense_fw, dense_bw, head_ratio, num_last_stage_layers, mtp = 0
+ dense_fws, dense_bws, recomp_dense_bws, moe_fws, moe_bws, recomp_moe_bws, totals = [], [], [], [], [], [], []
+ (dense_fw, dense_bw, recomp_dense_bw, moe_fw, moe_bw,
+ recomp_moe_bw, total, head_ratio, num_last_stage_layers, mtp) = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+ att_tid, grad_att_tid = 2, 2
+
+ def __init__(self, input_args, para):
+ self.para = para
+ self.mbn = input_args.mbn
+ self.layer_num = input_args.num_layers
+ self.num_layer_list = input_args.num_layer_list
+ self.recompute_num_layers = input_args.recompute_num_layers
+ self.dense_num = input_args.first_k_dense_replace
+ self.step = input_args.profile_steps
+ self.model_name = input_args.model_name
+
+ def parse_batch_profile_result(self, profile_configs_results):
+ # step1. parse the profile files
+ # ep is set to 1, dmratio: 0.33, bfratio = dense_bw / dense_fw, re_grow_ration = 1, hratio = 0.68
+ # [dp, tp, pp, dmratio, bfratio, re_grow_ration, hratio, moe_fw]
+ profile_result = []
+ dense_flag = False
+ # Todo: the para.YAML_PATH handling here still needs to be reworked
+ if self.para.YAML_PATH:
+ for result in profile_configs_results:
+ profile_dir = result[-1]
+ pp = result[2]
+ self.layer_num = result[-2]
+ if len(result) > 6:
+ config = result[:4]
+ else:
+ config = result[:3]
+ config += [0.33, 0, 1, 0.68, 0]
+ self.config_anal_func(profile_dir, pp, config)
+ self.refresh()
+ profile_result.append(config)
+ elif self.para.SHELL_PATH:
+ for result in profile_configs_results:
+ profile_dir = result[-1]
+ pp = result[2]
+ update_input_args, _ = parse_shell(result[-2])
+ self.mbn = update_input_args.get('GBS') // update_input_args.get('MBS') // update_input_args.get('DP')
+ self.layer_num = update_input_args.get('NUM_LAYERS')
+ if 'FIRST_K_DENSE_RAPLACE' in update_input_args:
+ self.dense_num = update_input_args.get('FIRST_K_DENSE_RAPLACE')
+ else:
+ self.dense_num = 3
+ logger.info(f"mbn: {self.mbn}, layer_num: {self.layer_num}.")
+ if len(result) > 8:
+ config = result[:4]
+ else:
+ config = result[:3]
+ config += [0.33, 0, 1, 0.68, 0]
+ self.config_anal_func(profile_dir, pp, config)
+ self.refresh()
+ profile_result.append(config)
+ elif self.para.TOML_PATH:
+ for result in profile_configs_results:
+ profile_dir = result[-1]
+ pp = result[2]
+ config = result[:-2]
+ config += [0.33, 1, 1, 0.68, 16]
+ self.config_anal_func(profile_dir, pp, config)
+ self.refresh()
+ profile_result.append(config)
+
+ # step2. generate the csv file
+ # a dense-model result has length 8, a moe-model result has length 9
+ if len(profile_result[0]) == 8:
+ dense_flag = True
+ self.write_result_to_csv(profile_result, dense_flag)
+
+ # write the profile results to a csv file
+ def write_result_to_csv(self, profile_result, is_llama=False):
+ if self.para.YAML_PATH or self.para.SHELL_PATH:
+ if is_llama:
+ headers = ['dp', 'tp', 'pp', 'dmratio', 'bfratio', 're_grow_ration', 'hratio', 'moe_fw']
+ else:
+ headers = ['dp', 'tp', 'pp', 'ep', 'dmratio', 'bfratio', 're_grow_ration', 'hratio', 'moe_fw']
+ else:
+ headers = ['dp', 'tp', 'pp', 'ep', 'cp', 'op', 'dmratio', 'bfratio', 're_grow_ration', 'hratio', 'moe_fw']
+ # write the CSV file
+ try:
+ csv_path = os.path.join(os.path.abspath(self.para.OUTPUT_PATH), 'profile_parser_result.csv')
+ with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
+ writer = csv.writer(csvfile)
+ # write the header
+ writer.writerow(headers)
+ # write the data rows
+ writer.writerows(profile_result)
+ logger.info(f"CSV file {csv_path} generate succ.")
+ self.para.args.parser_result = csv_path
+ except Exception as e:
+ print(f"write CSV file fail: {e}")
+
+ def load_profile(self, file, rk):
+ global encoding
+ path = os.path.abspath('.\\profile_info') + '\\' + str(file)
+ path = find_file_by_name(path + '\\rank_' + str(rk), 'trace_view.json')
+ with open(path, 'r', encoding=encoding) as f:
+ self.profile_data = json.load(f)
+
+ def load_profile_by_dir(self, rank_dir):
+ global encoding
+ path = find_file_by_name(rank_dir, 'trace_view.json')
+ with open(path, 'r', encoding=encoding) as f:
+ self.profile_data = json.load(f)
+
+ def load_profile_by_kernel(self, file):
+ global encoding
+ path = find_file_by_name(file, 'kernel_details.csv')
+ if path is None:
+ logger.error(f"can not find kernel details file {file}")
+ return False
+ logger.info(f'path of kernel_details is {path}')
+ with open(path, 'r', encoding=encoding) as f:
+ reader = csv.DictReader(f)
+ self.profile_data = list(reader)
+ return True
+
+ def load_profile_no_recompute(self, file, rk):
+ global encoding
+ path = os.path.abspath('.\\profile_info') + '\\' + str(file)
+ path = find_file_by_name(path + '\\rank_' + str(rk) + '_norecomp' , 'trace_view.json')
+ with open(path, 'r', encoding=encoding) as f:
+ self.profile_data = json.load(f)
+
+ def refresh(self):
+ self.profile_data = []
+ (self.dense_fws, self.dense_bws, self.recomp_dense_bws, self.moe_fws, self.moe_bws,
+ self.recomp_moe_bws, self.totals) = [
+ [] for _ in range(7)]
+ (self.dense_fw, self.dense_bw, self.recomp_dense_bw, self.moe_fw, self.moe_bw, self.recomp_moe_bw,
+ self.total, self.head_ratio, self.num_last_stage_layers, self.mtp) = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+
+ def release(self):
+ self.profile_data = []
+
+ def set_tid(self, a, g):
+ self.att_tid, self.grad_att_tid = a, g
+
+ def extract_atts(self):
+ atts, grad_atts = [], []
+ for line in self.profile_data:
+ name = line['name']
+ if line['tid'] == self.att_tid and name.endswith('FlashAttentionScore_FlashAttentionScore'):
+ atts.append(float(line['ts']))
+ if line['tid'] == self.grad_att_tid and name.endswith('FlashAttentionScoreGrad_FlashAttentionScoreGrad'):
+ grad_atts.append(float(line['ts']))
+ logger.info(f'num of atts is {len(atts)}, num of grad_atts is {len(grad_atts)}')
+ return atts, grad_atts
+
+ def extract_atts_by_kernel(self):
+ atts, grad_atts = [], []
+ try:
+ with open(self.profile_data, 'r', encoding='utf-8') as file:
+ reader = csv.DictReader(file)
+ for row in reader:
+ name_value = row.get('Name', '')
+ if isinstance(name_value, str):
+ if name_value.endswith('FlashAttentionScore_FlashAttentionScore'):
+ ts_value = float(row.get('Start Time(us)'))
+ atts.append(ts_value)
+ elif name_value.endswith('FlashAttentionScoreGrad_FlashAttentionScoreGrad'):
+ ts_value = float(row.get('Start Time(us)'))
+ grad_atts.append(ts_value)
+ logger.info(f'extract by kernel_details.csv: num of atts is {len(atts)}, num of grad_atts is {len(grad_atts)}')
+ return atts, grad_atts
+ except Exception as e:
+ logger.error(f'extract by kernel_details.csv fail: {e}')
+ return [], []
+
+ def extract_atts_by_loaded_data(self):
+ atts, grad_atts = [], []
+ try:
+ for row in self.profile_data:
+ if row['Type'] == 'FlashAttentionScore':
+ atts.append(float(row['Start Time(us)']))
+ if row['Type'] == 'FlashAttentionScoreGrad':
+ grad_atts.append(float(row['Start Time(us)']))
+ logger.info(f'num of atts is {len(atts)}, num of grad_atts is {len(grad_atts)}')
+ return atts, grad_atts
+ except Exception as e:
+ logger.error(f'extract by loaded_data fail: {e}')
+ return [], []
+
+ def extract_atts_for_titan(self):
+ atts, grad_atts = [], []
+ try:
+ for row in self.profile_data:
+ if row['Name'] == 'aclnnFlashAttentionScore_FlashAttentionScore_FlashAttentionScore':
+ atts.append(float(row['Start Time(us)']))
+ if row['Name'] == 'aclnnFlashAttentionScoreGrad_FlashAttentionScoreGrad_FlashAttentionScoreGrad':
+ grad_atts.append(float(row['Start Time(us)']))
+ logger.info(f'num of atts is {len(atts)}, num of grad_atts is {len(grad_atts)}')
+ return atts, grad_atts
+ except Exception as e:
+ logger.error(f'extract by loaded_data fail: {e}')
+ return [], []
+
+ # todo: to be revised
+ def stage_anal(self, pp, stage, isRecompute): #assuming we profiled 2 steps w/ micro = 32
+ atts, grad_atts = self.extract_atts_by_kernel()
+ layers, xlayers = len(grad_atts) // 64, 2 * len(grad_atts) // 64
+ warm_up = (pp-1-stage) * layers
+ atts = atts[warm_up : 32*xlayers - warm_up] + atts[32*xlayers + warm_up : 32*xlayers + 32*xlayers-warm_up]
+ grad_atts = (grad_atts[warm_up : 32*layers - warm_up]
+ + grad_atts[32*layers + warm_up : 32*layers + 32*layers-warm_up])
+ att_chunks = [atts[xlayers*i:xlayers*(i+1)] for i in range(len(atts)//xlayers-1)]
+ grad_chunks = [grad_atts[layers*i:layers*(i+1)] for i in range(len(grad_atts)//layers-1)]
+ if pp == 2:
+ self.pp2_analysis(att_chunks, grad_chunks, layers, pp, stage)
+ else:
+ self.pp_not2_analysis(att_chunks, grad_chunks, layers, pp, stage)
+
+ def pp_not2_analysis(self, att_chunks, grad_chunks, layers, pp, stage):
+ if stage == 0:
+ for chunk in att_chunks:
+ for i in range(layers - 1):
+ if i <= 2: self.dense_fws.append(chunk[i + 1] - chunk[i])
+ for chunk in grad_chunks:
+ for i in range(layers - 3, layers - 1): self.dense_bws.append(chunk[i + 1] - chunk[i])
+ elif stage == pp - 3:
+ for chunk in att_chunks:
+ for i in range(layers - 1):
+ self.moe_fws1.append(chunk[i + 1] - chunk[i])
+ for chunk in grad_chunks:
+ for i in range(layers - 3, layers - 1): self.moe_bws.append(chunk[i + 1] - chunk[i])
+ elif stage == pp - 2:
+ for chunk in att_chunks:
+ for i in range(layers - 1): self.moe_fws2.append(chunk[i + 1] - chunk[i])
+ for chunk in grad_chunks:
+ for i in range(layers - 3, layers - 1): self.moe_bws.append(chunk[i + 1] - chunk[i])
+ else:
+ for i in range(len(att_chunks) - 1): self.totals.append(att_chunks[i + 1][0] - att_chunks[i][0])
+ self.num_last_stage_layers = layers
+
+ def pp2_analysis(self, att_chunks, grad_chunks, layers, pp, stage):
+ if stage == 0:
+ for chunk in att_chunks:
+ for i in range(layers - 1):
+ if i <= 2:
+ self.dense_fws.append(chunk[i + 1] - chunk[i])
+ else:
+ self.moe_fws1.append(chunk[i + 1] - chunk[i])
+ for chunk in grad_chunks:
+ for i in range(layers - 3, layers - 1): self.dense_bws.append(chunk[i + 1] - chunk[i])
+ for i in range(layers - 4): self.recomp_bws.append(chunk[i + 1] - chunk[i])
+ for i in range(len(att_chunks) - 1): self.totals.append(att_chunks[i + 1][0] - att_chunks[i][0])
+ if stage == pp - 1:
+ for chunk in att_chunks:
+ for i in range(layers - 2): self.moe_fws2.append(chunk[i + 1] - chunk[i])
+ for chunk in grad_chunks:
+ for i in range(2, layers - 1): self.moe_bws.append(chunk[i + 1] - chunk[i])
+ for i in range(len(att_chunks) - 1): self.totals.append(att_chunks[i + 1][0] - att_chunks[i][0])
+ self.num_last_stage_layers = layers
+
+ def stage_anal_for_pp1(self):
+ atts, grad_atts = self.extract_atts_for_titan()
+ layer_num = 10 # default, same as in torchtitan's __init__.py
+ try:
+ import torchtitan.protocols.train_spec as train_spec_module
+ model_flavor = 'tune'
+ train_spec = train_spec_module.get_train_spec(self.model_name)
+ model_args = train_spec.model_args[model_flavor]
+ layer_num = model_args.n_layers
+ except Exception as e:
+ logger.error(f'Error is: {e}; could not get layer_num for the profile parser.')
+
+ atts_per_step = len(atts) // self.step
+ grad_atts_per_step = len(grad_atts) // self.step
+ att_chunks = [atts[atts_per_step * i:atts_per_step * (i + 1)] for i in range(self.step)]
+ grad_chunks = [grad_atts[grad_atts_per_step * i:grad_atts_per_step * (i + 1)] for i in range(self.step)]
+
+ for chunk in att_chunks:
+ for i in range(layer_num - 1):
+ self.moe_fws.append(chunk[i + 1] - chunk[i])
+
+ for chunk in grad_chunks:
+ for i in range(layer_num - 1):
+ self.moe_bws.append(chunk[i + 1] - chunk[i])
+
+ def stage_anal_for_pp2(self, pp, stage):
+ atts, grad_atts = self.extract_atts_by_loaded_data()
+ layer_num = self.num_layer_list[stage]
+ warm_up = (pp - 1 - stage) * layer_num
+ atts_per_step = len(atts) // self.step
+ grad_atts_per_step = len(grad_atts) // self.step
+
+ att_chunks = [atts[atts_per_step * i:atts_per_step * (i + 1)] for i in range(self.step)]
+ grad_chunks = [grad_atts[grad_atts_per_step * i:grad_atts_per_step * (i + 1)] for i in range(self.step)]
+
+ if pp == 2:
+ if stage == 0:
+ for chunk in att_chunks:
+ for i in range(self.mbn - 1):
+ idx = warm_up + i * (layer_num + self.recompute_num_layers)
+ for j in range(self.dense_num):
+ self.dense_fws.append(chunk[idx + j + 1] - chunk[idx + j])
+
+ for chunk in grad_chunks:
+ for i in range(self.mbn - 2):
+ idx = warm_up + i * layer_num
+ self.dense_bws.append(chunk[idx + 2] - chunk[idx + 1])
+ self.recomp_dense_bws.append(chunk[idx + 3] - chunk[idx + 2])
+
+ if stage == 1:
+ for chunk in att_chunks:
+ for i in range(self.mbn):
+ idx = warm_up + i * (layer_num + self.recompute_num_layers)
+ for j in range(layer_num - 1):
+ self.moe_fws.append(chunk[idx + j + 1] - chunk[idx + j])
+ if i > 1:
+ idx_last = warm_up + (i - 1) * (layer_num + self.recompute_num_layers)
+ self.totals.append(chunk[idx] - chunk[idx_last])
+
+ for chunk in grad_chunks:
+ for i in range(self.mbn):
+ idx = warm_up + i * layer_num
+ self.moe_bws.append(chunk[idx + 1] - chunk[idx])
+ self.recomp_moe_bws.append(chunk[idx + 2] - chunk[idx + 1])
+ self.num_last_stage_layers = layer_num
+
+ def config_anal(self, config, ranks):
+ self.config = config
+ pp = len(ranks)
+ for i in range(pp):
+ self.load_profile(config, ranks[i])
+ self.stage_anal(pp, i)
+ self.release()
+
+ def config_anal_func(self, profile_dir, pp, profile_result):
+ self.config = profile_dir
+ self.process_folders(profile_dir, pp, profile_result)
+
+ def process_folders(self, profile_result_dir, pp, profile_result):
+ """
+ Walk every folder under the root directory and parse its content with the parser functions.
+
+ Args:
+ profile_result_dir (str): root directory path
+ """
+ root_path = Path(profile_result_dir)
+ if not root_path.is_dir():
+ logger.info(f"profile_result_dir:{profile_result_dir} not exist")
+ return
+
+ # collect all sub-folders, one folder per rank
+ folders = [f for f in root_path.iterdir() if f.is_dir()]
+ folders.sort()
+ i = 0
+ for rank_folder in folders:
+ logger.info(f"parsing: {rank_folder}")
+ # Todo: the para.YAML_PATH branch still needs to be reworked
+ if self.para.YAML_PATH:
+ self.load_profile_by_dir(rank_folder)
+ self.stage_anal(pp, i, True)
+ else:
+ find_file = self.load_profile_by_kernel(rank_folder)
+ if not find_file:
+ continue
+ if self.para.SHELL_PATH:
+ self.stage_anal_for_pp2(pp, i)
+ elif self.para.TOML_PATH:
+ self.stage_anal_for_pp1()
+ self.release()
+ i += 1
+ self.refined_data()
+ self.fill_result(profile_result)
+
+ def data_display(self):
+ logger.info(f'moe_fws1 are {self.moe_fws1}')
+ logger.info(f'moe_fws2 are {self.moe_fws2}')
+ logger.info(f'moe_bws is {self.moe_bws}')
+ logger.info(f'recomp_bws is {self.recomp_bws}')
+ logger.info(f'dense_fw is {self.dense_fw}')
+ logger.info(f'dense_bw is {self.dense_bws}')
+ logger.info(f'totals is {self.totals}')
+
+ def refined_data(self):
+ if self.para.YAML_PATH or self.para.SHELL_PATH:
+ if len(self.dense_fws) == 0 or len(self.dense_bws) == 0:
+ logger.info(f"dense_fws len {len(self.dense_fws)} "
+ f"or dense_bws {len(self.dense_fws)} is empty")
+ return
+ self.dense_fw = median_mean(self.dense_fws)
+ self.dense_bw = median_mean(self.dense_bws)
+ self.recomp_dense_bw = median_mean(self.recomp_dense_bws)
+ self.moe_fw = median_mean(self.moe_fws)
+ self.moe_bw = median_mean(self.moe_bws)
+ self.recomp_moe_bw = median_mean(self.recomp_moe_bws)
+ self.total = median_mean(self.totals)
+ self.mtp = (self.total - self.num_last_stage_layers * self.moe_fw -
+ (self.num_last_stage_layers - self.recompute_num_layers) * self.moe_bw - self.recomp_moe_bw)
+ logger.info(
+ f'config is {self.config}, dense forward is {self.dense_fw}, '
+ f'dense backward is {self.dense_bw}, recompute dense backward is {self.recomp_dense_bw}, '
+ f'moe forward is {self.moe_fw}, moe backward is {self.moe_bw}, '
+ f'recompute moe backward is {self.recomp_moe_bw}, lmhead and MTP is {self.mtp}')
+ elif self.para.TOML_PATH:
+ if len(self.moe_fws) == 0 or len(self.moe_bws) == 0:
+ logger.info(f"fws len {len(self.moe_fws)} "
+ f"or bws {len(self.moe_bws)} is empty")
+ return
+ self.moe_fw = median_mean(self.moe_fws)
+ self.moe_bw = median_mean(self.moe_bws)
+ logger.info(
+ f'config is {self.config}, forward is {self.moe_fw}, backward is {self.moe_bw}')
+
+ def fill_result(self, profile_result):
+ if self.para.YAML_PATH or self.para.SHELL_PATH:
+ if len(self.dense_fws) == 0 or len(self.dense_bws) == 0:
+ logger.info(f"config {self.config} parse profile result is empty ")
+ return
+ profile_result[-1] = self.moe_fw
+ profile_result[-2] = self.mtp / self.moe_fw
+ profile_result[-3] = (self.recomp_moe_bw - self.moe_bw) / self.moe_fw
+ profile_result[-4] = self.moe_bw / self.moe_fw
+ profile_result[-5] = self.dense_fw / self.moe_fw
+ elif self.para.TOML_PATH:
+ if len(self.moe_fws) == 0 or len(self.moe_bws) == 0:
+ logger.info(f"config {self.config} parse profile result is empty ")
+ return
+ profile_result[-1] = self.moe_fw
+ profile_result[-2] = 1.5
+ profile_result[-3] = 1.0
+ profile_result[-4] = self.moe_bw / self.moe_fw
+ profile_result[-5] = 0.3
+
+class MemoryInfo:
+ select_mem0 = 76
+ select_mem12 = 76
+ select_mem = 77
+ re_comp_mem0 = 3
+ re_comp_mem12 = 3
+ re_comp_mem = 5
+ act_mem0 = 79
+ act_mem12 = 79
+ act_mem = 79
+ layer_mem012 = 349
+ layer_mem = 340
+ static_mem0 = 734
+ static_mem = 116
+ lm_head_mem = 690
+
+ def write_to_txt(self, memory_file_name):
+ with open(memory_file_name, 'w', encoding='utf-8') as file:
+ file.write(f'select_mem0={self.select_mem0}\n')
+ file.write(f'select_mem12={self.select_mem12}\n')
+ file.write(f'select_mem={self.select_mem}\n')
+ file.write(f're_comp_mem0={self.re_comp_mem0}\n')
+ file.write(f're_comp_mem12={self.re_comp_mem12}\n')
+ file.write(f're_comp_mem={self.re_comp_mem}\n')
+ file.write(f'act_mem0={self.act_mem0}\n')
+ file.write(f'act_mem12={self.act_mem12}\n')
+ file.write(f'act_mem={self.act_mem}\n')
+ file.write(f'layer_mem012={self.layer_mem012}\n')
+ file.write(f'layer_mem={self.layer_mem}\n')
+ file.write(f'static_mem0={self.static_mem0}\n')
+ file.write(f'static_mem={self.static_mem}\n')
+ file.write(f'lm_head_mem={self.lm_head_mem}\n')
+ logger.info(f'Write memory info to {memory_file_name}')
+
+class ProfileMemParser:
+ pipeline_output_file = 'pipeline_output'
+ memory_info_dir = 'memory_info'
+ def __init__(self, input_args, para):
+ self.profile_mem_data = None
+ self.input_args = input_args
+ self.para = para
+ self.peak_memory_usage = 0.0 # MB
+ self.static_memory_usage = 0.0 # MB
+
+ def write_mem_info_to_txt(self, config, mem_info):
+ '''Write the parsed memory info of one profiled config to a txt file.
+
+ Args:
+ config: [dp, tp, pp, ep, offset, recompute, num_layers]
+ mem_info: MemoryInfo instance to serialize
+ '''
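+ # Example (hypothetical values): config [8, 2, 4, 8, ...] with first_k_dense_replace=3,
+ # num_layers=61, gbs=3840, mbs=1 would produce
+ # pipeline_output/memory_info/layers3_58_micro480_dp8_tp2_pp4_vp1_ep8.txt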
+ memory_dir = os.path.join(os.getcwd(), self.pipeline_output_file, self.memory_info_dir)
+ os.makedirs(memory_dir, exist_ok=True)
+ dense_layer_num = self.input_args.first_k_dense_replace
+ moe_layer_num = self.input_args.num_layers - dense_layer_num
+
+ dp, tp, pp, ep = config[:4]
+ mbn = self.input_args.gbs // self.input_args.mbs // dp
+ filename = (
+ f'layers{dense_layer_num}_{moe_layer_num}_micro{mbn}_dp{dp}'
+ f'_tp{tp}_pp{pp}_vp1'
+ f'_ep{ep}.txt')
+ memory_file_name = os.path.join(memory_dir, filename)
+ mem_info.write_to_txt(memory_file_name)
+
+ def set_static_mem(self, mem_info, profile_file):
+ self.para.TOML_PATH = profile_file
+ input_args = ParaForNd(self.para)
+ dense_size, moe_size, vocab_size = compute_weight_and_optimizer_memory(input_args)
+ mem_info.layer_mem012 = dense_size
+ mem_info.layer_mem = moe_size
+ mem_info.static_mem0 = vocab_size
+ mem_info.static_mem = 0
+ mem_info.lm_head_mem = 1.5 * moe_size
+ return input_args
+
+ def cal_dynamic_mem(self, mem_info, cur_input_args):
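+ # Dynamic (activation) memory is what remains after subtracting the static memory
+ # from the measured peak; for pp=1 it is split evenly across the layers.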
+ dynamic_mem_total = self.peak_memory_usage - self.static_memory_usage
+ if cur_input_args.pp == 1:
+ dynamic_mem_layer = dynamic_mem_total // cur_input_args.num_layers
+ mem_info.act_mem0 = dynamic_mem_layer
+ mem_info.act_mem12 = dynamic_mem_layer
+ mem_info.act_mem = dynamic_mem_layer
+
+
+
+ def manage_mem_infos(self, need_cumlate=True):
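+ # Walk the allocation/release deltas in time order, accumulate them into a
+ # memory-vs-time curve, and record the peak usage (converted from KB to MB).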
+ sorted_times = sorted(self.real_memory_changes.keys())
+ cumulative_mem = 0
+ times = []
+ memory_usage = []
+ for time in sorted_times:
+ if need_cumlate:
+ cumulative_mem += self.real_memory_changes[time]
+ times.append(time)
+ memory_usage.append(cumulative_mem)
+ else:
+ times.append(time)
+ memory_usage.append(self.real_memory_changes[time])
+ max_index = memory_usage.index(max(memory_usage))
+ peak_time = times[max_index]
+ peak_memory_usage = memory_usage[max_index]
+ logger.info(f'Peak memory usage: {peak_memory_usage}, peak time: {peak_time}')
+ self.peak_memory_usage = peak_memory_usage / 1024
+
+ def parse_memory_block(self, filename):
+ real_memory_changes = defaultdict(int)
+ static_memory = 0
+ flag = False
+ with open(filename, mode='r', encoding='utf-8') as file:
+ reader = csv.DictReader(file)
+ # real memory
+ for row in reader:
+ alloc_size = row['Size(KB)']
+ if not alloc_size:
+ logger.debug('Size(KB) is none')
+ continue
+ size = float(alloc_size)
+
+ alloc_time = row['Allocation Time(us)']
+ if not alloc_time:
+ logger.debug(f'Allocation Time is none, size is {size}')
+ continue
+ start = int(alloc_time.strip().split('.')[0])
+
+ release_time = row['Active Release Time(us)']
+ if not release_time:
+ logger.debug(f'Active Release Time is none, size is {size}')
+ continue
+ end = int(release_time.strip().split('.')[0])
+
+ real_memory_changes[start] += size # memory increases at start
+ real_memory_changes[end - 1] += 0 # for plt
+ real_memory_changes[end] -= size # memory decreases at end
+
+ if not flag:
+ static_memory = int(row['Allocation Total Allocated(MB)'].strip().split('.')[0])
+
+ flag = True
+ real_memory_changes[start] += (static_memory * 1024)
+
+ logger.info(f'static mem is {static_memory} M')
+ self.static_memory_usage = static_memory
+ self.real_memory_changes = real_memory_changes
+ self.manage_mem_infos()
+
+ def parse_memory_file(self, profile_dir):
+ # step1. parse operator_memory.csv
+ root_path = Path(profile_dir)
+ if not root_path.is_dir():
+ logger.info(f"profile_result_dir:{profile_dir} not exist")
+ return
+ folders = [f for f in root_path.iterdir() if f.is_dir()]
+ folders.sort()
+ for rank_folder in folders:
+ logger.info(f"parsing: {rank_folder}")
+ path = find_file_by_name(rank_folder, 'operator_memory.csv')
+ if not path:
+ logger.error(f"No operator memory csv file in profile_dir {profile_dir} rank_dir {rank_folder}")
+ continue
+ self.parse_memory_block(path)
+
+ def parse_mem(self, profile_dir, profile_file):
+ '''Parse one profiling result and build a MemoryInfo.
+
+ Args:
+ profile_dir: path to the profiling results
+ profile_file: config file used for this profiling run
+
+ Returns:
+ MemoryInfo filled with static and dynamic memory estimates
+ '''
+ mem_info = MemoryInfo()
+ self.parse_memory_file(profile_dir)
+ cur_input_args = self.set_static_mem(mem_info, profile_file)
+ self.cal_dynamic_mem(mem_info, cur_input_args)
+ # mem_info is now filled with static and dynamic estimates
+ return mem_info
+
+ def mem_parser(self, profile_configs):
+ '''Parse memory profiling results and write one memory info file per config.
+
+ Args:
+ profile_configs: [dp, tp, pp, ep, offset, recompute, num_layers, profile_file, profile_result_dir]
+
+ Returns:
+ None. A memory info txt file is written for each config.
+ '''
+ for config in profile_configs:
+ profile_dir = config[-1]
+ profile_file = config[-2]
+ mem_info = self.parse_mem(profile_dir, profile_file)
+ self.write_mem_info_to_txt(config, mem_info)
+ return
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_prepare.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..18caaeadb1f80cb1708f6b9780b7472822a25b41
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_prepare.py
@@ -0,0 +1,262 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""profile prepare work"""
+import os
+from fast_tuner.utils.logger import logger
+from fast_tuner.utils.common import generate_files, cal_world_size
+
+
+def find_file_by_name(directory, filename):
+ for root, _dirs, files in os.walk(directory):
+ if filename in files:
+ return os.path.join(root, filename)
+ return None
+
+def decide_pp_dp(config, available_devices):
+ """
+
+ :param config: [config, available_devices]
+ :param available_devices: 可用来做profiling的卡数
+ :return: [dp, tp, pp, ep]
+ """
+ dp, tp, origin_pp, ep, cp, op = config[:6]
+ world_size = dp * tp * origin_pp * cp
+ device_multiple = world_size // available_devices
+ min_pp = 2 # pp is at least 2
+
+ # check that there are enough devices, dp * tp >= ep
+ if max(min_pp * tp * cp, min_pp * ep) > available_devices:
+ print(f'not enough devices for config: dp: {dp}, tp: {tp}, pp: 2, ep: {ep}')
+
+ # prune pp
+ if device_multiple >= origin_pp // min_pp:
+ pp = 2
+ else:
+ if device_multiple % 3 == 0:
+ if origin_pp % 3 == 0:
+ pp = origin_pp // 3
+ else:
+ pp = origin_pp // (device_multiple // 3)
+ else:
+ pp = origin_pp // device_multiple
+
+ # shrink dp to fit the device budget
+ cur_multiple = pp * dp * tp * cp // available_devices
+ if cur_multiple > 1:
+ dp //= cur_multiple
+ return [dp, tp, pp, ep, cp, op]
+
+def decide_dp(config, available_devices):
+ """
+ :param config: [config, available_devices]
+ :param available_devices: 可用来做profiling的卡数
+ :return: [dp, tp, pp, ep]
+ """
+ dp, tp, origin_pp, ep = config[:4]
+ pp = 2 # fix pp to 2
+
+ # check that there are enough devices, dp * tp >= ep
+ if max(pp * tp, pp * ep) > available_devices:
+ print(f'not enough devices for config: dp: {dp}, tp: {tp}, pp: {pp}, ep: {ep}')
+ return None
+
+ # shrink dp to fit the device budget
+ cur_multiple = pp * dp * tp // available_devices
+ if cur_multiple > 1:
+ dp //= cur_multiple
+ if dp % ep != 0:
+ print(
+ f'dp_size ({dp}) is not divisible by ep_size({ep}), for config: dp: {dp}, tp: {tp}, pp: {pp}, ep: {ep}')
+ return None
+ return [dp, tp, pp, ep]
+
+def decide_dp_for_titan(config, available_devices):
+ dp, tp, pp, ep, cp, op = config[:6]
+ world_size = dp * tp * pp * cp
+ pp = 1
+ if max(pp * tp, pp * ep) > available_devices:
+ logger.info(f'not enough devices for config: dp: {dp}, tp: {tp}, pp: {pp}, ep: {ep}, cp: {cp}, fsdp: {op}')
+ return None
+ if world_size <= 8:
+ if world_size != available_devices:
+ logger.info(f'Currently, we assume world size {world_size} '
+ f'equals the number of available devices {available_devices} within a single node')
+ return None
+ else:
+ return config[:6]
+ # shrink dp to fit the device budget
+ cur_multiple = pp * dp * cp * tp // available_devices
+ if cur_multiple > 1:
+ if op // cur_multiple == 0:
+ logger.info(f'pruned because fsdp cannot be split further: ddp: {dp//op}, tp: {tp}, pp: {pp}, ep: {ep}, cp: {cp}, fsdp: {op}')
+ return None
+ op //= cur_multiple
+ dp //= cur_multiple
+ return [dp, tp, pp, ep, cp, op]
+
+def decide_pp_dp_llama(config, available_devices):
+ """
+
+ :param config: [config, available_devices]
+ :param available_devices: 可用来做profiling的卡数
+ :return: [dp, tp, pp]
+ """
+ dp, tp, origin_pp = config[:3]
+ world_size = dp * tp * origin_pp
+ device_multiple = world_size // available_devices
+ min_pp = 1 # pp is at least 1
+
+ # check that there are enough devices
+ if min_pp * tp > available_devices:
+ print(f'not enough devices for config: dp: {dp}, tp: {tp}, pp: {min_pp}')
+ return None
+
+ # prune pp
+ if device_multiple >= origin_pp // min_pp:
+ pp = min_pp
+ else:
+ if device_multiple % 3 == 0:
+ if origin_pp % 3 == 0:
+ pp = origin_pp // 3
+ else:
+ pp = origin_pp // (device_multiple // 3)
+ else:
+ pp = origin_pp // device_multiple
+
+ # shrink dp to fit the device budget
+ cur_multiple = pp * dp * tp // available_devices
+ if cur_multiple > 1:
+ dp //= cur_multiple
+ return [dp, tp, pp]
+
+def trans_config_satisfy_rank_num(mem_prune_space, para, input_args):
+ """
+
+ :param mem_prune_space: [dp,tp,pp,ep,cp,op]
+ :param para: user config para for tool
+ :param input_args: model args and train args from config file
+ :return: list[[dp, tp, pp, ep, cp, op, offset, num_layers]]
+ """
+ profile_configs = []
+ rank_num = min(para.RANK_NUM, input_args.world_size)
+ for config in mem_prune_space:
+ trans_config = budget_profile_config_generator(config, para, rank_num)
+ if trans_config is not None and trans_config not in profile_configs:
+ profile_configs.append(trans_config)
+ logger.info(f"profile configs len: {len(profile_configs)} by {rank_num} devices")
+ # if there are too many candidates, keep only the 10 most promising ones
+ if len(profile_configs) > 10:
+ # pp <= 16
+ profile_configs = [cand for cand in profile_configs if cand[2] <= 16]
+ # prefer small tp and cp, and larger ep (sort ascending by tp, cp and descending by ep)
+ profile_configs = sorted(profile_configs, key=lambda x: (x[1], x[4], -x[3]))
+ profile_configs = profile_configs[:10]
+ print(*(config for config in profile_configs), sep='\n')
+ return profile_configs
+
+def budget_profile_config_generator(config, para, available_devices):
+ """
+ 根据可用卡数,将config转换成当前卡资源可满足的并行配置做profile,先裁剪pp, pp最小为2,然后再缩减dp
+ 一次profiling同时获取不开重计算和完全重计算的信息
+
+ :param config: [dp, tp, pp, ep, cp, op, evaluate_peak_mem]
+ :param para: user config para for tool
+ :param available_devices:
+ :return: [dp, tp, pp, ep, offset, recompute, num_layers], min(pp)=2, 层数(num_layers+MTP)
+ """
+ if para.YAML_PATH:
+ if len(config) < 4:
+ basic_config = decide_pp_dp_llama(config, available_devices)
+ else:
+ basic_config = decide_pp_dp(config, available_devices)
+ pp = basic_config[2]
+ if len(config) < 4:
+ offset = 0
+ recompute = 0
+ num_layers = 32
+ else:
+ if pp == 2:
+ offset = [1, 0]
+ recompute = [6, 0]
+ num_layers = 11
+ elif pp == 4:
+ offset = [1, 1, 1, 0]
+ recompute = [3, 3, 0, 0]
+ num_layers = 11
+ elif pp == 8:
+ offset = [2, 0, 0, 0, 0, 1, 1, 1]
+ recompute = [3, 1, 1, 1, 1, 2, 0, 2]
+ num_layers = 13
+ else:
+ print(f'pp {pp} not supported')
+ return None
+
+ config_num_layers = num_layers if para.SHELL_PATH else (num_layers - 1)
+ basic_config.extend([offset, recompute, config_num_layers])
+ return basic_config
+ elif para.SHELL_PATH:
+ if len(config) < 4:
+ # TODO: llama-family models still need adaptation
+ basic_config = decide_pp_dp_llama(config, available_devices)
+ else:
+ basic_config = decide_dp(config, available_devices)
+ if basic_config is None:
+ return None
+ num_layer_list = [4, 3]
+ recompute_num_layers = 1
+ num_layers = 7
+ basic_config.extend([num_layer_list, recompute_num_layers, num_layers])
+ return basic_config
+ elif para.TOML_PATH:
+ basic_config = decide_dp_for_titan(config, available_devices)
+ return basic_config
+ else:
+ print('Error: No config file path!')
+ return None
+
+def taylor_pp_adaptor(profile_info):
+ layer_ratio = (profile_info['dense_fw']+profile_info['dense_bw'])/(profile_info['moe_fw']+profile_info['moe_bw'])
+ backward_ratio = (profile_info['moe_bw']/profile_info['moe_fw'])
+ return layer_ratio, backward_ratio
+
+def sapp_adaptor(profile_info):
+ body_dense = (profile_info['dense_fw']+profile_info['dense_bw'])/3
+ body_moe = (profile_info['moe_fw']+profile_info['moe_bw'])/3
+ tail = profile_info['head']/3
+ return profile_info['embed'], body_dense, body_moe, tail
+
+def profile_prepare(mem_prune_space, para, input_args):
+ """
+ prepare for profiling
+
+ :param para: user input args for tool
+ :param mem_prune_space: [dp, tp, pp, ep, cp, op, evaluate_peak_mem] or [dp, tp, pp]
+ :param input_args: model args and train args from config file
+ :return: (configs with the generated launch file path and profile output dir appended, profile file dir)
+ """
+ # convert configs into ones that can be profiled with the currently available devices
+ profile_configs = trans_config_satisfy_rank_num(mem_prune_space, para, input_args)
+ profile_dir = 'profile_yaml' if para.YAML_PATH else 'profile_shell' if para.SHELL_PATH else 'profile_toml'
+ profile_file_dir = os.path.abspath(para.OUTPUT_PATH) + os.sep + profile_dir + os.sep
+ file_task = "profile_yaml" if para.YAML_PATH else "profile_shell" if para.SHELL_PATH else 'profile_toml'
+ file_path_list, output_dir_list = generate_files(profile_configs, profile_file_dir, file_task, para, input_args)
+
+ result = []
+ for config, shell_file_path, profile_dir in zip(profile_configs, file_path_list, output_dir_list):
+ # copy the sub-list before appending (avoid mutating the original list)
+ new_config = config.copy()
+ new_config.append(shell_file_path)
+ new_config.append(profile_dir)
+ result.append(new_config)
+ return result, profile_file_dir
diff --git a/hyper_parallel/auto_parallel/fast-tuner/pyproject.toml b/hyper_parallel/auto_parallel/fast-tuner/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..b8fff1c0f345f2f8bc7962e3c630cc59c8f033ab
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/pyproject.toml
@@ -0,0 +1,71 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "fast-tuner"
+version = "0.0.1dev0"
+description = "Fast Algorithm Supported by Theories. AI Framework hyper-parameter tuner."
+readme = ""
+requires-python = ">=3.9"
+authors = [
+ {name = "Taylor Lab"}
+]
+maintainers = [
+ {name = "Taylor Lab"}
+]
+classifiers = [
+ "Development Status :: 2 - Pre-Alpha",
+ "Intended Audience :: Developers",
+ "License :: Other/Proprietary License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+]
+dependencies = [
+ "pyyaml",
+ "numpy<2.0.0",
+ "ortools",
+ "matplotlib",
+ "chardet",
+ "highspy",
+ "toml",
+]
+
+[project.optional-dependencies]
+dev = [
+ "pytest>=7.0",
+ "black>=22.0",
+ "flake8>=4.0",
+ "mypy>=0.900",
+]
+
+[project.urls]
+Homepage = "http://example.com"
+Repository = "http://example.com"
+Issues = "http://example.com"
+
+[project.scripts]
+fast-tuner-parallel = "fast_tuner.parallel_tool:main"
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["fast_tuner*"]
+
+[tool.black]
+line-length = 120
+target-version = ['py39']
+
+[tool.mypy]
+python_version = "3.9"
+warn_return_any = true
+warn_unused_configs = true
+disallow_untyped_defs = true
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = ["test_*.py"]
+python_classes = ["Test*"]
+python_functions = ["test_*"]
diff --git a/hyper_parallel/auto_parallel/fast-tuner/test/units/nd_search_unit_test.py b/hyper_parallel/auto_parallel/fast-tuner/test/units/nd_search_unit_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbee4a279c11aed266cb2deba75fc6ebd5fbaa8c
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/test/units/nd_search_unit_test.py
@@ -0,0 +1,109 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import unittest
+
+from fast_tuner.ndsearch.build_initial_spaces import build_initial_spaces
+from fast_tuner.ndsearch.expert_filter_configs import expert_filter_configs
+from fast_tuner.ndsearch.memory_model import grey_box_memory_prune
+from fast_tuner.utils.input_config import InputConfig
+from fast_tuner.utils.profiling.profile_parser import ProfileParser
+
+
+# ND search unit test cases
+class NdTestCase(unittest.TestCase):
+ # Test pruning candidate configs based on dryrun memory info
+ def test_memory_prune(self):
+ # input construct
+ input_args = InputConfig('../../config/pretrain_deepseek3_671b.yaml')
+ dryrun_info_init = [
+ [32, 2, 16, 8, 142786],
+ [128, 2, 4, 8, 188944],
+ [16, 2, 32, 4, 153480],
+ [16, 8, 8, 8, 96950],
+ [64, 8, 2, 8, 825586],
+ [64, 2, 8, 4, 221281],
+ [8, 4, 32, 8, 81943],
+ [64, 4, 4, 4, 283005],
+ [64, 1, 16, 8, 241679],
+ [32, 8, 4, 4, 270048],
+ [8, 4, 32, 4, 98967],
+ [256, 2, 2, 4, 1656206],
+ [64, 4, 4, 8, 162045],
+ [8, 8, 16, 8, 68632],
+ [4, 8, 32, 8, 54686],
+ [128, 1, 8, 4, 299846],
+ [8, 8, 16, 4, 101784],
+ [32, 8, 4, 8, 148595],
+ [256, 1, 4, 4, 363705],
+ [32, 4, 8, 8, 116591],
+ [4, 8, 32, 4, 71710],
+ [32, 4, 8, 4, 181999],
+ [32, 2, 16, 4, 175938],
+ [16, 2, 32, 8, 136456],
+ [64, 8, 2, 4, 1635569],
+ [64, 2, 8, 8, 155874],
+ [128, 1, 8, 8, 234458],
+ [16, 4, 16, 8, 93350],
+ [256, 2, 2, 8, 847944],
+ [128, 4, 2, 8, 832465],
+ [16, 8, 8, 4, 162358],
+ [32, 1, 32, 8, 245502],
+ [512, 1, 2, 8, 927251],
+ [64, 1, 16, 4, 274811],
+ [16, 4, 16, 4, 126502],
+ [128, 2, 4, 4, 309904],
+ [32, 1, 32, 4, 262506],
+ [256, 1, 4, 8, 242746],
+ [128, 4, 2, 4, 1642448],
+ [512, 1, 2, 4, 1687161]
+ ]
+ test_ep = [4, 8]
+ max_expert_parallel = 64
+ # execute prune
+ memory_aware_configs = grey_box_memory_prune(input_args, dryrun_info_init, test_ep, max_expert_parallel)
+ print(f"Dryrun Search space len: {len(memory_aware_configs)}, format: [dp, tp, pp, ep, evaluate_peak_mem]")
+ # output check
+ assert len(memory_aware_configs) == 10
+ for sub_array in memory_aware_configs:
+ assert len(sub_array) == 5
+ assert sub_array[3] <= 64
+
+
+ def test_initial_expert_space(self):
+ input_args = InputConfig('../../config/pretrain_deepseek3_671b.yaml')
+ initial_configs = build_initial_spaces(input_args)
+ # search space generation test
+ assert len(initial_configs) == 576030
+
+ # expert pruning test
+ expert_prune_search_space = expert_filter_configs(initial_configs, input_args, 1024)
+ assert len(expert_prune_search_space) == 8753
+
+ # dryrun yaml generation test
+
+ # parser test
+ def test_profile_parser(self):
+ test = ProfileParser()
+ test.config_anal('dp8tp4pp4ep32', [0, 32, 64, 127])
+ test.refined_data()
+ test.refresh()
+ test.config_anal('dp8tp4pp2ep32', [0, 127])
+ test.refined_data()
+
+if __name__ == "__main__":
+ suite = unittest.TestLoader().loadTestsFromTestCase(NdTestCase)
+ runner = unittest.TextTestRunner(verbosity=2)
+ runner.run(suite)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/test/units/pipeline_alg_unit_test.py b/hyper_parallel/auto_parallel/fast-tuner/test/units/pipeline_alg_unit_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe0b9489c8a7bf987ecb563dbf0d1852aa5b4207
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/test/units/pipeline_alg_unit_test.py
@@ -0,0 +1,248 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import sys
+import os
+from pathlib import Path
+sys.path.append(os.getcwd())
+import unittest
+
+from fast_tuner.pipeline_conductor import pipeline_parallel
+from fast_tuner.pipeline_conductor.start_service import InitConfig, ExpertInput
+
+# Pipeline load-balancing algorithm unit test cases
+class PipeTestCase(unittest.TestCase):
+
+ def get_yaml_path(self, filename):
+ yaml_file_path = Path(__file__).resolve().parents[2] / 'config' / filename
+ return yaml_file_path
+
+ def test_micro64_dp16_tp4_pp16_vp1_ep8(self):
+ yaml_file = self.get_yaml_path('test_62_16_1_64_1_0.yaml')
+ mind_former_file = ''
+ expert_input = ExpertInput(yaml_file, mind_former_file)
+ expert_input.solver_name = 'HIGHS'
+ expert_input.is_dryrun = False
+ expert_input.backward_ratio = 1.839
+ expert_input.layer_ratio = 0.397
+ expert_input.head_loss = 1.5
+ expert_input.recompute_ratio = 1
+ model_input = InitConfig(expert_input)
+ model_input.memory.select_mem0 = 400.0
+ model_input.memory.select_mem12 = 400.0
+ model_input.memory.select_mem = 765.0
+ model_input.memory.re_comp_mem0 = 30.0
+ model_input.memory.re_comp_mem12 = 30.0
+ model_input.memory.re_comp_mem = 30.0
+ model_input.memory.act_mem0 = 479.0
+ model_input.memory.act_mem12 = 479.0
+ model_input.memory.act_mem = 765.0
+ model_input.memory.layer_mem012 = 866.0
+ model_input.memory.layer_mem = 10691.0
+ model_input.memory.static_mem0 = 3383.0
+ model_input.memory.static_mem = 2382.0
+ model_input.memory.lm_head_mem = 8675.0
+ model_input.memory.mem_lim_stage0 = 51913.0
+ model_input.memory.mem_lim_others = 52914.0
+ model_input.memory.mem_lim_last = 46621.0
+ model_input.memory.print_mem()
+ cur_solution = pipeline_parallel.solve_problem(model_input)
+ object_value = 1132.66
+ self.assertLessEqual(cur_solution.object_value, object_value + 0.01)
+ self.assertGreaterEqual(cur_solution.object_value, object_value - 0.01)
+
+ def test768_4k_gp(self):
+ yaml_file = self.get_yaml_path('768die4k_gp.yaml')
+ mind_former_file = '~/mindformers/run_mindformer.py'
+ expert_input = ExpertInput(yaml_file, mind_former_file)
+ expert_input.solver_name = 'HIGHS'
+ expert_input.is_dryrun = False
+ expert_input.is_double_check = False
+ expert_input.time_limit = 1 * 60
+ expert_input.layer_ratio = 0.71
+ expert_input.backward_ratio = 1.627
+ expert_input.recompute_ratio = 0.246
+ expert_input.head_loss = 1.493
+ model_input = InitConfig(expert_input)
+ model_input.memory.select_mem0 = 416.0
+ model_input.memory.select_mem12 = 416.0
+ model_input.memory.select_mem = 674.0
+ model_input.memory.re_comp_mem0 = 46.0
+ model_input.memory.re_comp_mem12 = 46.0
+ model_input.memory.re_comp_mem = 158.0
+ model_input.memory.act_mem0 = 494.0
+ model_input.memory.act_mem12 = 494.0
+ model_input.memory.act_mem = 673.0
+ model_input.memory.layer_mem012 = 1010.0
+ model_input.memory.layer_mem = 2644.0
+ model_input.memory.static_mem0 = 4243.0
+ model_input.memory.static_mem = 2279.0
+ model_input.memory.lm_head_mem = 8468.0
+ model_input.memory.mem_lim_stage0 = 53101.0
+ model_input.memory.mem_lim_others = 55065.0
+ model_input.memory.mem_lim_last = 48876.0
+ model_input.memory.print_mem()
+ cur_solution = pipeline_parallel.solve_problem(model_input)
+ object_value = 3603.39
+ self.assertLessEqual(cur_solution.object_value, object_value + 0.01)
+ self.assertGreaterEqual(cur_solution.object_value, object_value - 0.01)
+
+ def test512_8k(self):
+ yaml_file = self.get_yaml_path('layers62_micro960_dp4_tp8_pp16_vp1_ep32.yaml')
+ mind_former_file = ''
+ expert_input = ExpertInput(yaml_file, mind_former_file)
+ expert_input.solver_name = 'HIGHS'
+ expert_input.is_dryrun = False
+ expert_input.is_double_check = False
+ expert_input.time_limit = 3 * 60
+ expert_input.layer_ratio = 0.71
+ expert_input.backward_ratio = 1.627
+ expert_input.recompute_ratio = 0.246
+ expert_input.head_loss = 1.493
+ model_input = InitConfig(expert_input)
+ model_input.memory.select_mem0 = 480.0
+ model_input.memory.select_mem12 = 480.0
+ model_input.memory.select_mem = 737.0
+ model_input.memory.re_comp_mem0 = 78.0
+ model_input.memory.re_comp_mem12 = 78.0
+ model_input.memory.re_comp_mem = 190.0
+ model_input.memory.act_mem0 = 614.0
+ model_input.memory.act_mem12 = 614.0
+ model_input.memory.act_mem = 738.0
+ model_input.memory.layer_mem012 = 130.0
+ model_input.memory.layer_mem = 7230.0
+ model_input.memory.static_mem0 = 5789.0
+ model_input.memory.static_mem = 3210.0
+ model_input.memory.lm_head_mem = 7939.0
+ model_input.memory.mem_lim_stage0 = 51555.0
+ model_input.memory.mem_lim_others = 54134.0
+ model_input.memory.mem_lim_last = 49405.0
+ model_input.memory.print_mem()
+ cur_solution = pipeline_parallel.solve_problem(model_input)
+ object_value = 10943.91
+ self.assertLessEqual(cur_solution.object_value, object_value + 0.01)
+ self.assertGreaterEqual(cur_solution.object_value, object_value - 0.01)
+
+ def test512_8k_no_swap(self):
+ yaml_file = self.get_yaml_path('512_8k_no_swap.yaml')
+ mind_former_file = '~/mindformers/run_mindformer.py'
+ expert_input = ExpertInput(yaml_file, mind_former_file)
+ expert_input.solver_name = 'HIGHS'
+ expert_input.is_dryrun = False
+ expert_input.is_double_check = False
+ expert_input.time_limit = 5 * 60
+ expert_input.layer_ratio = 0.71
+ expert_input.backward_ratio = 1.627
+ expert_input.recompute_ratio = 0.246
+ expert_input.head_loss = 1.493
+ model_input = InitConfig(expert_input)
+ model_input.memory.select_mem0 = 832.0
+ model_input.memory.select_mem12 = 832.0
+ model_input.memory.select_mem = 1347.0
+ model_input.memory.re_comp_mem0 = 92.0
+ model_input.memory.re_comp_mem12 = 92.0
+ model_input.memory.re_comp_mem = 316.0
+ model_input.memory.act_mem0 = 988.0
+ model_input.memory.act_mem12 = 988.0
+ model_input.memory.act_mem = 1346.0
+ model_input.memory.layer_mem012 = 1109.0
+ model_input.memory.layer_mem = 3670.0
+ model_input.memory.static_mem0 = 5618.0
+ model_input.memory.static_mem = 2667.0
+ model_input.memory.lm_head_mem = 12405.0
+ model_input.memory.mem_lim_stage0 = 49678.0
+ model_input.memory.mem_lim_others = 52629.0
+ model_input.memory.mem_lim_last = 42891.0
+ model_input.memory.print_mem()
+ cur_solution = pipeline_parallel.solve_problem(model_input)
+ object_value = 5661.55
+ self.assertLessEqual(cur_solution.object_value, object_value + 0.01)
+ self.assertGreaterEqual(cur_solution.object_value, object_value - 0.01)
+
+ def test512_4k(self):
+ yaml_file = self.get_yaml_path('layers62_micro60_dp32_tp4_pp4_vp1_ep128.yaml')
+ mind_former_file = '~/mindformers/run_mindformer.py'
+ expert_input = ExpertInput(yaml_file, mind_former_file)
+ expert_input.solver_name = 'HIGHS'
+ expert_input.is_dryrun = False
+ expert_input.is_double_check = False
+ expert_input.time_limit = 3 * 60
+ expert_input.backward_ratio = 1.627
+ expert_input.layer_ratio = 0.71
+ expert_input.head_loss = 1.5
+ expert_input.recompute_ratio = 0.71
+ model_input = InitConfig(expert_input)
+ model_input.memory.select_mem0 = 415.0
+ model_input.memory.select_mem12 = 415.0
+ model_input.memory.select_mem = 672.0
+ model_input.memory.re_comp_mem0 = 30.0
+ model_input.memory.re_comp_mem12 = 30.0
+ model_input.memory.re_comp_mem = 142.0
+ model_input.memory.act_mem0 = 493.0
+ model_input.memory.act_mem12 = 493.0
+ model_input.memory.act_mem = 673.0
+ model_input.memory.layer_mem012 = 869.0
+ model_input.memory.layer_mem = 2111.0
+ model_input.memory.static_mem0 = 6704.0
+ model_input.memory.static_mem = 2377.0
+ model_input.memory.lm_head_mem = 7986.0
+ model_input.memory.mem_lim_stage0 = 48592.0
+ model_input.memory.mem_lim_others = 52919.0
+ model_input.memory.mem_lim_last = 47310.0
+ model_input.memory.print_mem()
+ cur_solution = pipeline_parallel.solve_problem(model_input)
+ object_value = 5868.49
+ self.assertLessEqual(cur_solution.object_value, object_value + 0.01)
+ self.assertGreaterEqual(cur_solution.object_value, object_value - 0.01)
+
+ # This test case takes a long time to run, roughly 40 minutes
+ def test512_8k_swap(self):
+ yaml_file = self.get_yaml_path('512_8k_swap.yaml')
+ mind_former_file = '~/mindformers/run_mindformer.py'
+ expert_input = ExpertInput(yaml_file, mind_former_file)
+ expert_input.solver_name = 'HIGHS'
+ expert_input.is_dryrun = False
+ expert_input.is_double_check = False
+ expert_input.time_limit = 1 * 60
+ expert_input.layer_ratio = 0.71
+ expert_input.backward_ratio = 1.627
+ expert_input.recompute_ratio = 0.246
+ expert_input.head_loss = 1.493
+ model_input = InitConfig(expert_input)
+ model_input.memory.select_mem0 = 832.0
+ model_input.memory.select_mem12 = 832.0
+ model_input.memory.select_mem = 1347.0
+ model_input.memory.re_comp_mem0 = 92.0
+ model_input.memory.re_comp_mem12 = 92.0
+ model_input.memory.re_comp_mem = 316.0
+ model_input.memory.act_mem0 = 988.0
+ model_input.memory.act_mem12 = 988.0
+ model_input.memory.act_mem = 1346.0
+ model_input.memory.layer_mem012 = 1109.0
+ model_input.memory.layer_mem = 3670.0
+ model_input.memory.static_mem0 = 5438.0
+ model_input.memory.static_mem = 2664.0
+ model_input.memory.lm_head_mem = 12402.0
+ model_input.memory.mem_lim_stage0 = 49858.0
+ model_input.memory.mem_lim_others = 52632.0
+ model_input.memory.mem_lim_last = 42894.0
+ model_input.memory.print_mem()
+ cur_solution = pipeline_parallel.solve_problem(model_input)
+ object_value = 5562.66
+ self.assertLessEqual(cur_solution.object_value, object_value + 20)
+ self.assertGreaterEqual(cur_solution.object_value, object_value - 20)
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/setup.py b/setup.py
index 49290d95acf748bf35c052955a73e5b66734c7e6..112170c3ae3cec1f01247fdac3cfb015b54598a6 100644
--- a/setup.py
+++ b/setup.py
@@ -59,7 +59,7 @@ def get_description():
os_info = get_platform()
cpu_info = platform.machine().strip()
- return 'hyper_parallel platform: %s, cpu: %s' % (os_info, cpu_info)
+ return f'hyper_parallel platform: {os_info}, cpu: {cpu_info}'
def get_install_requires():
@@ -69,7 +69,7 @@ def get_install_requires():
Returns:
list, list of dependent packages.
"""
- with open('requirements.txt') as file:
+ with open('requirements.txt', encoding='utf-8') as file:
return file.read().strip().splitlines()
@@ -151,7 +151,9 @@ if __name__ == '__main__':
long_description=get_readme_content(),
long_description_content_type="text/markdown",
test_suite="tests",
- packages=find_packages(exclude=["*tests*"]),
+ packages=find_packages(exclude=["*tests*",
+ "hyper_parallel.auto_parallel.fast-tuner",
+ "hyper_parallel.auto_parallel.fast-tuner.*"]),
platforms=[get_platform()],
include_package_data=True,
package_data=package_data,