diff --git a/hyper_parallel/auto_parallel/fast-tuner/README.md b/hyper_parallel/auto_parallel/fast-tuner/README.md index 32a1e9d4e562deb1287fe967b88eae0b95ba368b..0ab18663df7ef3f5138aebdc22836e6f6db5210d 100644 --- a/hyper_parallel/auto_parallel/fast-tuner/README.md +++ b/hyper_parallel/auto_parallel/fast-tuner/README.md @@ -64,4 +64,4 @@ pip install -e . [基于Mindformers](./docs/mindformers.md) 使用文档\ [基于Torchtitan](./docs/torchtitan.md) 使用文档 \ -[基于Mindspeed](./docs/mindspeed.md) 使用文档 \ No newline at end of file +[基于Mindspeed](./docs/mindspeed.md) 使用文档 diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/env_config/env_config.json b/hyper_parallel/auto_parallel/fast-tuner/config/env_config/env_config.json new file mode 100644 index 0000000000000000000000000000000000000000..06bcdb67382ec0cc8a49b6e1a267aec30d53d6e2 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/config/env_config/env_config.json @@ -0,0 +1,10 @@ +{ + "MS_SIMULATION_LEVEL": "1", + "GLOG_v": "3", + "MS_DEV_RUNTIME_CONF": "memory_statistics:True", + "MS_MEMORY_STATISTIC": "1", + "ENABLE_LAZY_INLINE_NO_PIPELINE": "1", + "MS_DEV_DUMP_IR_PASSES": "validate,graph_build,pipeline_split,recompute", + "MS_ALLOC_CONF": "memory_tracker:True", + "ENABLE_LESS_MEM_VPP": "0" +} \ No newline at end of file diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/example/mcore/pretrain_deepseek3_671b_4k_A3_ptd.sh b/hyper_parallel/auto_parallel/fast-tuner/config/example/mcore/pretrain_deepseek3_671b_4k_A3_ptd.sh new file mode 100644 index 0000000000000000000000000000000000000000..738be6ed879a38a6023900c6e98647d7309c92ff --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/config/example/mcore/pretrain_deepseek3_671b_4k_A3_ptd.sh @@ -0,0 +1,194 @@ +#!/bin/bash + +# 需要切换MindSpeed版本 +# git checkout 9648d729e4866f8037d0bd76630029410a60e6a6 # checkout commit from MindSpeed core_r0.8.0 in 2025.06.14 + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export CPU_AFFINITY_CONF=1 +export TASK_QUEUE_ENABLE=2 +export PYTORCH_NPU_ALLOC_CONF="expandable_segments:True" +export HCCL_CONNECT_TIMEOUT=3600 +export STREAMS_PER_DEVICE=32 + +NPUS_PER_NODE=16 +MASTER_ADDR=localhost #主节点IP +MASTER_PORT=6000 +NNODES=32 +NODE_RANK=0 + +CKPT_SAVE_DIR="your model save ckpt path" +DATA_PATH="your data path" +TOKENIZER_PATH="your tokenizer path" +CKPT_LOAD_DIR="your model ckpt path" + +TP=2 +PP=8 +EP=16 +CP=1 +CP_TYPE='ulysses_cp_algo' +NUM_LAYERS=64 +SEQ_LEN=4096 +MBS=1 +GBS=3840 + +DISTRIBUTED_ARGS=" + --nproc_per_node $NPUS_PER_NODE \ + --nnodes $NNODES \ + --node_rank $NODE_RANK \ + --master_addr $MASTER_ADDR \ + --master_port $MASTER_PORT +" + +MLA_ARGS=" + --multi-head-latent-attention \ + --qk-rope-head-dim 64 \ + --qk-nope-head-dim 128 \ + --q-lora-rank 1536 \ + --kv-lora-rank 512 \ + --v-head-dim 128 \ + --qk-layernorm \ + --mla-mm-split \ + --mla-fa-without-pad \ +" + +MOE_ARGS=" + --moe-grouped-gemm \ + --moe-permutation-async-comm \ + --use-fused-moe-token-permute-and-unpermute \ + --moe-token-dispatcher-type alltoall \ + --first-k-dense-replace 3 \ + --moe-layer-freq 1 \ + --n-shared-experts 1 \ + --num-experts 256 \ + --moe-router-topk 8 \ + --moe-intermediate-size 2048 \ + --moe-router-load-balancing-type noaux_tc \ + --n-group 8 \ + --topk-group 4 \ + --routed-scaling-factor 2.5 \ + --moe-aux-loss-coeff 0.0001 \ + --seq-aux \ + --norm-topk-prob \ + --moe-router-score-function sigmoid \ + --moe-router-enable-expert-bias \ + --moe-tp-extend-ep \ + --router-gating-in-fp32 \ +" + +MTP_ARGS=" + --mtp-num-layers 1 \ + --mtp-loss-scaling-factor 0.3 
\ +" + +DUALPIPE_ARGS=" + --moe-fb-overlap \ + --schedules-method dualpipev \ +" + +MEM_ARGS=" + --mtp-mem-efficient-logits \ + --swap-optimizer \ + --recompute-granularity full \ + --recompute-method block \ + --recompute-num-layers 8 \ +" + +ROPE_ARGS=" + --rope-scaling-beta-fast 32 \ + --rope-scaling-beta-slow 1 \ + --rope-scaling-factor 40 \ + --rope-scaling-mscale 1.0 \ + --rope-scaling-mscale-all-dim 1.0 \ + --rope-scaling-original-max-position-embeddings 4096 \ + --rope-scaling-type yarn +" + +GPT_ARGS=" + --spec mindspeed_llm.tasks.models.spec.deepseek_spec layer_spec \ + --reset-position-ids \ + --gemm-gradient-accumulation-fusion \ + --noop-layers 61,62,63 \ + --manual-gc \ + --manual-gc-interval 50 \ + --no-shared-storage \ + --use-distributed-optimizer \ + --use-flash-attn \ + --use-mcore-models \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --expert-model-parallel-size ${EP} \ + --sequence-parallel \ + --context-parallel-size ${CP} \ + --context-parallel-algo ${CP_TYPE} \ + --num-layers ${NUM_LAYERS} \ + --hidden-size 7168 \ + --ffn-hidden-size 18432 \ + --num-attention-heads 128 \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path ${TOKENIZER_PATH} \ + --seq-length ${SEQ_LEN} \ + --max-position-embeddings 163840 \ + --micro-batch-size ${MBS} \ + --global-batch-size ${GBS} \ + --make-vocab-size-divisible-by 1 \ + --lr 1.0e-5 \ + --train-iters 2000 \ + --lr-decay-style cosine \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --attention-dropout 0.0 \ + --init-method-std 0.02 \ + --hidden-dropout 0.0 \ + --position-embedding-type rope \ + --normalization RMSNorm \ + --use-fused-rotary-pos-emb \ + --use-rotary-position-embeddings \ + --use-fused-swiglu \ + --use-fused-rmsnorm \ + --swiglu \ + --no-masked-softmax-fusion \ + --attention-softmax-in-fp32 \ + --min-lr 1.0e-7 \ + --weight-decay 1e-2 \ + --lr-warmup-iters 500 \ + --clip-grad 1.0 \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --initial-loss-scale 65536 \ + --vocab-size 129280 \ + --padded-vocab-size 129280 \ + --rotary-base 10000 \ + --norm-epsilon 1e-6 \ + --no-load-optim \ + --no-load-rng \ + --bf16 \ + --distributed-timeout-minutes 120 \ +" + +DATA_ARGS=" + --data-path $DATA_PATH \ + --split 100,0,0 +" + +OUTPUT_ARGS=" + --log-interval 1 \ + --save-interval 2000 \ + --eval-interval 2000 \ + --eval-iters 0 \ + --no-save-optim \ + --no-save-rng +" + +python -m torch.distributed.launch $DISTRIBUTED_ARGS pretrain_gpt.py \ + $GPT_ARGS \ + $DATA_ARGS \ + $OUTPUT_ARGS \ + $MLA_ARGS \ + $DUALPIPE_ARGS \ + $MEM_ARGS \ + $ROPE_ARGS \ + $MOE_ARGS \ + $MTP_ARGS \ + --save $CKPT_SAVE_DIR \ + --load $CKPT_LOAD_DIR \ + --distributed-backend nccl | tee logs/pretrain_deepseek3_671b_4k_A3_ptd.log \ No newline at end of file diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/example/mindspore/pretrain_deepseek3_671b.yaml b/hyper_parallel/auto_parallel/fast-tuner/config/example/mindspore/pretrain_deepseek3_671b.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1604b865c609fe518f7f066f3b87458c5109fc81 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/config/example/mindspore/pretrain_deepseek3_671b.yaml @@ -0,0 +1,225 @@ +seed: 0 +output_dir: './output' # path to save checkpoint/strategy +load_checkpoint: '' +src_strategy_path_or_dir: '' +auto_trans_ckpt: False # If true, auto transform load_checkpoint to load in distributed pipeline +only_save_strategy: False +resume_training: False +use_parallel: True +run_mode: 'train' + 
+# trainer config +trainer: + type: CausalLanguageModelingTrainer + model_name: 'deepseekV3' + +# runner config +runner_config: + epochs: 2 + batch_size: 1 + sink_mode: True + sink_size: 1 + +# optimizer +optimizer: + type: AdamW + betas: [0.9, 0.95] + eps: 1.e-8 + +# lr schedule +lr_schedule: + type: ConstantWarmUpLR + learning_rate: 2.2e-4 + warmup_ratio: 0.02 + total_steps: -1 # -1 means it will load the total steps of the dataset + +# dataset +train_dataset: &train_dataset + data_loader: + type: BlendedMegatronDatasetDataLoader + datasets_type: "GPTDataset" + sizes: + - 5000 # train dataset size + - 0 + - 0 + config: + random_seed: 1234 + seq_length: 4096 + split: "1, 0, 0" + reset_position_ids: False + reset_attention_mask: False + eod_mask_loss: False + num_dataset_builder_threads: 1 + create_attention_mask: False + data_path: + - '1' + - "./dataset" + shuffle: False + input_columns: ["input_ids", "labels", "loss_mask", "position_ids"] + construct_args_key: ["input_ids", "labels"] + num_parallel_workers: 8 + python_multiprocessing: False + drop_remainder: True + repeat: 1 + numa_enable: False + prefetch_size: 1 +train_dataset_task: + type: CausalLanguageModelDataset + dataset_config: *train_dataset + +# mindspore context init config +context: + mode: 0 #0--Graph Mode; 1--Pynative Mode + device_target: "Ascend" + max_call_depth: 10000 + max_device_memory: "55GB" + save_graphs: False + save_graphs_path: "./graph" + jit_config: + jit_level: "O1" + ascend_config: + parallel_speed_up_json_path: "./parallel_speed_up.json" + +# parallel config for device num = 1024 +parallel_config: + data_parallel: &dp 16 + model_parallel: 4 + pipeline_stage: 16 + expert_parallel: 8 + micro_batch_num: µ_batch_num 32 + vocab_emb_dp: True + use_seq_parallel: True + gradient_aggregation_group: 4 +# when pipeline parallel is greater than 1, we can set micro_batch_interleave_num=2, that may accelerate the train process. 
+micro_batch_interleave_num: 1 + +# parallel context config +parallel: + parallel_mode: 1 # 0-data parallel, 1-semi-auto parallel, 2-auto parallel, 3-hybrid parallel + gradients_mean: False + enable_alltoall: True + full_batch: False + dataset_strategy: [[*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1]] + search_mode: "sharding_propagation" + enable_parallel_optimizer: True + strategy_ckpt_config: + save_file: "./ckpt_strategy.ckpt" + only_trainable_params: False + parallel_optimizer_config: + gradient_accumulation_shard: False + parallel_optimizer_threshold: 64 + +# recompute config +recompute_config: + recompute: [3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 2, 0] + select_recompute: False + parallel_optimizer_comm_recompute: True + mp_comm_recompute: True + recompute_slice_activation: True + +# pipeline config +model: + model_config: + type: DeepseekV3Config + auto_register: deepseek3_config.DeepseekV3Config + batch_size: 1 # add for increase predict + seq_length: 4096 + hidden_size: 7168 + num_layers: &num_layers 61 + num_heads: 128 + max_position_embeddings: 4096 + intermediate_size: 18432 + kv_lora_rank: 512 + n_kv_heads: 128 + q_lora_rank: 1536 + qk_rope_head_dim: 64 + v_head_dim: 128 + qk_nope_head_dim: 128 + vocab_size: 129280 + multiple_of: 256 + rms_norm_eps: 1.0e-6 + bos_token_id: 100000 + eos_token_id: 100001 + pad_token_id: 100001 + ignore_token_id: -100 + compute_dtype: "bfloat16" + layernorm_compute_type: "float32" + softmax_compute_type: "float32" + rotary_dtype: "float32" + router_dense_type: "float32" + param_init_type: "float32" + use_past: False + extend_method: "None" + use_flash_attention: True + input_sliced_sig: True + offset: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1] + checkpoint_name_or_path: "" + theta: 10000.0 + return_extra_loss: True + mtp_depth: &mtp_depth 1 + mtp_loss_factor: 0.3 + arch: + type: DeepseekV3ForCausalLM + auto_register: deepseek3.DeepseekV3ForCausalLM + +#moe +moe_config: + expert_num: &expert_num 256 + expert_group_size: 8 + capacity_factor: 1.5 + aux_loss_factor: 0.05 + num_experts_chosen: 8 + routing_policy: "TopkRouterV2" + enable_sdrop: False + balance_via_topk_bias: &balance_via_topk_bias True + topk_bias_update_rate: &topk_bias_update_rate 0.001 + use_fused_ops_topkrouter: True + group_wise_a2a: False + shared_expert_num: 1 + routed_scaling_factor: 2.5 + norm_topk_prob: False + first_k_dense_replace: 3 + moe_intermediate_size: 2048 + # greedy_group_limited strategy, select topk_group from n_group + topk_group: 4 + n_group: 8 + aux_loss_factors: [0.0001, 0., 0.] 
+ aux_loss_types: ["expert", "device", "comm"] + z_loss_factor: 0.0 + expert_model_parallel: 1 + use_gating_sigmoid: True + callback_moe_droprate: False + + +# callbacks +callbacks: + - type: MFLossMonitor + per_print_times: 1 + # balance topk bias with callback + - type: TopkBiasBalanceCallback + balance_via_topk_bias: *balance_via_topk_bias + topk_bias_update_rate: *topk_bias_update_rate + num_layers: *num_layers + mtp_depth: *mtp_depth + expert_num: *expert_num + micro_batch_num: *micro_batch_num + - type: CheckpointMonitor + prefix: "deepseekv3" + save_checkpoint_steps: 1000 + keep_checkpoint_max: 5 + integrated_save: False + async_save: False + checkpoint_format: "safetensors" + +# wrapper cell config +runner_wrapper: + type: MFTrainOneStepCell + scale_sense: 1.0 + use_clip_grad: True + +profile: False +profile_start_step: 1 +profile_stop_step: 10 +init_start_profile: False +profile_communication: False +profile_memory: True diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/example/torchtitan/debug_model.toml b/hyper_parallel/auto_parallel/fast-tuner/config/example/torchtitan/debug_model.toml new file mode 100644 index 0000000000000000000000000000000000000000..1951cc43503574138e9208442c626c01007834bc --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/config/example/torchtitan/debug_model.toml @@ -0,0 +1,79 @@ +[job] +dump_folder = "./outputs" +description = "DeepSeek-V3 debug training" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel" +# test tokenizer, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.float8] +fqns = ["experts"] diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_dryrun.json 
b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_dryrun.json new file mode 100644 index 0000000000000000000000000000000000000000..1319893154730045d294d7802251ab9e1a3b05bd --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_dryrun.json @@ -0,0 +1,12 @@ +{ + "train_yaml": "", + "mindformers_dir": "", + "output_path": "dryrun_output", + "offset": [0,1,1,1,1,1,1,0], + "is_recompute": true, + "recompute_layers": [1,1,1,1,1,1,1,1], + "is_select_recompute": false, + "select_recompute_layers": null, + "env_json": "./config/boss_env_config.json", + "dryrun_lim": 16 +} \ No newline at end of file diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool.json new file mode 100644 index 0000000000000000000000000000000000000000..6c866e16580787afbe3e5c748fc56dceb4f83378 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool.json @@ -0,0 +1,9 @@ +{ + "yaml_path" : "./config/pretrain_deepseek3_671b.yaml", + "mindformers_dir" : "./config/pretrain_gpt.py", + "rank_num" : 64, + "dataset" : "./config/dataset/wiki103-4k.mindrecord", + "output_path": "./output/nd_output/", + "env_json" : "./config/env_config.json", + "gbs": 1024 +} \ No newline at end of file diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_mcore.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_mcore.json new file mode 100644 index 0000000000000000000000000000000000000000..2f0e19131809d03eda6f6d216102b5a1c9af9036 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_mcore.json @@ -0,0 +1,10 @@ +{ + "shell_path" : "./config/example/mcore/pretrain_deepseek3_671b_4k_A3_ptd.sh", + "mindspeed_path": "./config/pretrain_gpt.py", + "rank_num" : 512, + "dataset" : "./config/dataset/wiki103-4k.mindrecord", + "max_expert_parallel": 64, + "output_path": "./output/nd_output/", + "env_json" : "./config/env_config.json", + "gbs": 3840 +} \ No newline at end of file diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_titan.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_titan.json new file mode 100644 index 0000000000000000000000000000000000000000..21d2c2aaf8e3041dbe7b40b939d63f9de4ee4682 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_parallel_tool_titan.json @@ -0,0 +1,16 @@ +{ + "toml_path" : "./config/example/torchtitan/debug_model.toml", + "torchtitan_path": "./torchtitan/run_train.sh", + "rank_num" : 8, + "output_path": "./output/", + "npus_per_node": 8, + "nnodes": 1, + "strategy": { + "DP": true, + "TP": true, + "EP": true, + "FSDP": true, + "PP": false, + "CP": false + } +} \ No newline at end of file diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_parallel.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_parallel.json new file mode 100644 index 0000000000000000000000000000000000000000..bc0862ac2805a7d6a4c4eda6b9f98aa2f7eeb58d --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_parallel.json @@ -0,0 +1,16 @@ +{ + "llm_class": "0", + "train_yaml": "", + "mindformers_dir": "", + "layer_ratio": 0.33, + "backward_ratio": 2.0, + "head_loss": 
1.5, + "recompute_ratio": 1, + "time_limit": 9999999999, + "dryrun": true, + "check": true, + "fit_level": 0, + "extract": false, + "env_json": "./config/boss_env_config.json", + "dryrun_lim": 16 +} \ No newline at end of file diff --git a/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_tool.json b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_tool.json new file mode 100644 index 0000000000000000000000000000000000000000..e4a794ebd311b8747b984403fabfb8eed610e2eb --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/config/setup_config/args_for_pipeline_tool.json @@ -0,0 +1,10 @@ +{ + "yaml_dir": "./output/dryrun_yaml/", + "mindformers_dir": "./mindformers/run_mindformer.py", + "profile_data_dir": "./profile_data/", + "parser_result": null, + "env_json": "./config/boss_env_config.json", + "dryrun_lim": 16, + "topn": 0, + "max_vpp": 3 +} \ No newline at end of file diff --git a/hyper_parallel/auto_parallel/fast-tuner/docs/mindformers.md b/hyper_parallel/auto_parallel/fast-tuner/docs/mindformers.md new file mode 100644 index 0000000000000000000000000000000000000000..fefc77209c5ddb05d24f5f03b5639e4029e26983 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/docs/mindformers.md @@ -0,0 +1,115 @@ +# 基于Mindformers的自动搜索 + +## 1 快速上手 + +在安装了fast-tuner工具之后,使用如下命令: + +```bash +fast-tuner-parallel --config ./config/setup_config/args_for_parallel_tool.json +``` + +其中 json 文件示例如下: + +```json +{ + // 基础配置 + "rank_num" : 8, + "gbs": 32, + "output_path": "./output/nd_output/", + + // Mindspore相关配置 + "yaml_path" : "./config/pretrain.yaml", + "mindformers_dir" : "./mindformers", + "dataset" : "./config/dataset/wiki103-4k.mindrecord" +} +``` + +参考文件:[配置文件 args_for_parallel_tool.json](../config/setup_config/args_for_parallel_tool.json) + +输出推荐并行配置: + +```js +dp, tp, pp, ep, step_time(μs) + 2, 1, 4, 1, 14383 + 1, 1, 8, 1, 15811 + 2, 2, 2, 1, 22738 + 4, 1, 2, 1, N/A +``` + +## 2 参数说明 + +**基础参数** + +| 参数 | 含义 | 默认值 | +|---------------|--------------------------------------------|------------------------------------------------------------------------| +| rank_num | 工具可用于做profile的卡数 | 8 | +| strategy | 工具搜索的并行策略,可选策略有 [DP, TP, PP, CP, EP, FSDP] | {"DP":true, "TP":true, "EP":true, "FSDP":true, "CP":false, "PP":false} | +| gbs | global batch size | 32 | +| output_path | 日志输出路径 | ./output/nd_output/ | + +**加速库参数** + + + + + + + + + + + + + + + + + + + + + + + + +
+| 加速库 | 参数 | 含义 | 示例 |
+|-------------|-----------------|--------------------------------------------------|----------------------------------------|
+| mindformers | mindformers_dir | run_mindformer.py路径,基于原生Mindspore搜索时需要填写该参数 | ./mindformers |
+|             | yaml_path       | 模型配置文件 | ./config/pretrain.yaml |
+|             | dataset         | 数据集路径 | ./config/dataset/wiki103-4k.mindrecord |
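+
+上表中的加速库参数与基础参数写在同一个 json 配置文件中;参考配置 args_for_parallel_tool.json 中还通过 env_json 指定 dryrun 使用的环境变量文件(示例见 config/env_config/env_config.json)。下面是一个仅作示意的片段,取值沿用上表示例列,均为假设路径,请按实际环境替换:
+
+```json
+{
+  // 基础配置
+  "rank_num" : 8,
+  "output_path": "./output/nd_output/",
+
+  // Mindformers相关配置(示例取值)
+  "mindformers_dir" : "./mindformers",
+  "yaml_path" : "./config/pretrain.yaml",
+  "dataset" : "./config/dataset/wiki103-4k.mindrecord",
+  "env_json" : "./config/env_config.json"
+}
+```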
+ +通常用户只需要配置好基础参数,也就是模型规模,支持的并行范式等信息,以及所用的大模型训练套件信息,即可开始配置搜索。对于熟悉本工具的用户,也提供一些更高自由度的选项。 + +**高级参数** + +| 参数 | 含义 | 默认值 | +|---------------------|--------------------------------------------|----------------------------------------| +| select_recompute | 是否搜索自定义选重 | True | +| dryrun_data_dir | 已有dryrun数据路径 | None | +| profile_data_dir | 已有profile数据路径 | None | +| parallel_num | dryrun并行数 | 2 | +| max_expert_parallel | 搜索范围的最大EP数 | 64 | +| parser_result | profile解析结果csv文件,不需要解析profile时填写此参数即可 | None | +| dryrun | 是否使用dryrun计算内存信息 | True | +| check | 是否使用double_check进行内存拟合 | True | +| alg_phase | 选择使用搜索算法阶段;0--全流程搜索, 1--ND搜索, 2--流水线负载均衡搜索 | 1 | + +## 3 参数配置 + +对于想要修改的参数,直接在json配置文件中修改即可。比如说在支持DP,TP,EP和PP,不支持CP,FSDP的场景下,可以按照如下方式修改: + +```json +{ + #其他配置 + "xxx": x, + ... + + #自定义需要调优的并行配置 + "strategy": { + "DP": true, + "TP": true, + "PP": true, + "EP": true, + "CP": false, + "FSDP": false, + } +} +``` + +[返回原文档](../README.md) diff --git a/hyper_parallel/auto_parallel/fast-tuner/docs/mindspeed.md b/hyper_parallel/auto_parallel/fast-tuner/docs/mindspeed.md new file mode 100644 index 0000000000000000000000000000000000000000..ab57c258ba385ebd97567fcfa9da309895dc886c --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/docs/mindspeed.md @@ -0,0 +1,105 @@ +# 基于Mindspeed的自动搜索 + +## 1 快速上手 + +在安装了fast-tuner工具之后,使用如下命令: + +```bash +fast-tuner-parallel --config ./config/setup_config/args_for_parallel_tool_mcore.json +``` + +其中 json 文件示例如下: + +```json +{ + // 基础配置 + "rank_num" : 8, + "gbs": 32, + "output_path": "./output/nd_output/", + + // Mindspeed相关配置 + "shell_path" : "./config/example/mcore/pretrain.sh", + "mindformers_dir" : "./config/pretrain_gpt.py" +} +``` + +参考文件:[配置文件 args_for_parallel_tool_mcore.json](../config/setup_config/args_for_parallel_tool_mcore.json) + +输出推荐并行配置: + +```js +dp, tp, pp, ep, step_time(μs) + 2, 1, 4, 1, 14383 + 1, 1, 8, 1, 15811 + 2, 2, 2, 1, 22738 + 4, 1, 2, 1, N/A +``` + +## 2 参数说明 + +**基础参数** + +| 参数 | 含义 | 默认值 | +|---------------|----------------------------------|------------------------------------------------------------------------| +| rank_num | 工具可用于做profile的卡数 | 8 | +| strategy | 工具搜索的并行策略,可选策略有 [DP, TP, PP, CP, EP, FSDP] | {"DP":true, "TP":true, "EP":true, "FSDP":true, "CP":false, "PP":false} | +| gbs | global batch size | 32 | +| output_path | 日志输出路径 | ./output/nd_output/ | + +**加速库参数** + + + + + + + + + + + + + + + + + + + +
+| 加速库 | 参数 | 含义 | 示例 |
+|-----------|----------------|----------------------|------------------------|
+| Mindspeed | mindspeed_path | MindSpeed训练脚本路径 | ./home/pretrain_gpt.py |
+|           | shell_path     | 模型配置文件 | ./config/pretrain.sh |
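+
+上表中的加速库参数与基础参数写在同一个 json 配置文件中(完整示例见参考文件 args_for_parallel_tool_mcore.json)。下面是一个仅作示意的片段,取值沿用上表示例列,均为假设路径,请按实际环境替换:
+
+```json
+{
+  // 基础配置
+  "rank_num" : 8,
+  "output_path": "./output/nd_output/",
+
+  // Mindspeed相关配置(示例取值)
+  "mindspeed_path" : "./home/pretrain_gpt.py",
+  "shell_path" : "./config/pretrain.sh"
+}
+```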
+ +通常用户只需要配置好基础参数,也就是模型规模,支持的并行范式等信息,以及所用的大模型训练套件信息,即可开始配置搜索。对于熟悉本工具的用户,也提供一些更高自由度的选项。 + +**高级参数** + +| 参数 | 含义 | 默认值 | +|---------------------|--------------------------------------------|----------------------------------------| +| select_recompute | 是否搜索自定义选重 | True | +| profile_data_dir | 已有profile数据路径 | None | +| max_expert_parallel | 搜索范围的最大EP数 | 64 | +| parser_result | profile解析结果csv文件,不需要解析profile时填写此参数即可 | None | +| alg_phase | 选择使用搜索算法阶段;0--全流程搜索, 1--ND搜索, 2--流水线负载均衡搜索 | 1 | + +## 3 参数配置 + +对于想要修改的参数,直接在json配置文件中修改即可。比如说在支持DP,TP,EP和PP,不支持CP,FSDP的场景下,可以按照如下方式修改: + +```json +{ + // 其他配置 + "xxx": x, + ... + + // 自定义需要调优的并行配置 + "strategy": { + "DP": true, + "TP": true, + "PP": true, + "EP": true, + "CP": false, + "FSDP": false + } +} +``` + +[返回原文档](../README.md) diff --git a/hyper_parallel/auto_parallel/fast-tuner/docs/torchtitan.md b/hyper_parallel/auto_parallel/fast-tuner/docs/torchtitan.md new file mode 100644 index 0000000000000000000000000000000000000000..d1d5337eab1fe67cea6d1a9a0ae942553023f804 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/docs/torchtitan.md @@ -0,0 +1,110 @@ +# 基于Torchtitan的自动搜索 + +## 1 快速上手 + +在安装了fast-tuner工具之后,使用如下命令: + +```bash +fast-tuner-parallel --config ./config/args_for_parallel_tool_titan.json +``` + +其中 json 文件示例如下: + +```json +{ + // 基础配置 + "rank_num" : 8, + "gbs": 32, + "npus_per_node": 8, + "nnodes": 1, + "output_path": "./output/nd_output/", + + // Torchtitan相关配置 + "toml_path" : "./config/example/torchtitan/debug_model.toml", + "torchtitan_path" : "./torchtitan/run_train.sh" +} +``` + +参考文件:[配置文件 args_for_parallel_tool_titan.json](../config/setup_config/args_for_parallel_tool_titan.json) + +输出推荐并行配置: + +```js +dp, tp, pp, ep, step_time(μs) + 2, 1, 4, 1, 14383 + 1, 1, 8, 1, 15811 + 2, 2, 2, 1, 22738 + 4, 1, 2, 1, N/A +``` + +## 2 参数说明 + +**基础参数** + +| 参数 | 含义 | 默认值 | +|---------------|--------------------------------------------|------------------------------------------------------------------------| +| rank_num | 工具可用于做profile的卡数 | 8 | +| npu_per_nodes | 用户用于训练的每个节点的卡数 | 8 | +| nnodes | 用户用于训练的节点数 | 1 | +| strategy | 工具搜索的并行策略,可选策略有 [DP, TP, PP, CP, EP, FSDP] | {"DP":true, "TP":true, "EP":true, "FSDP":true, "CP":false, "PP":false} | +| gbs | global batch size | 32 | +| output_path | 日志输出路径 | ./output/nd_output/ | + +**加速库参数** + + + + + + + + + + + + + + + + + + + +
+| 加速库 | 参数 | 含义 | 示例 |
+|------------|-----------------|----------------------------------|----------------------------------------------|
+| torchtitan | torchtitan_path | 训练脚本路径,本工具需要此文件拉起profile | ./torchtitan/run_train.sh |
+|            | toml_path       | 模型参数配置文件 | ./config/example/torchtitan/debug_model.toml |
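+
+上表中的加速库参数与基础参数写在同一个 json 配置文件中;若需显式指定搜索的并行策略,也可一并给出 strategy 字段(下例取默认值,仅作示意,路径为假设值,请按实际环境替换):
+
+```json
+{
+  // 基础配置
+  "rank_num" : 8,
+  "npus_per_node": 8,
+  "nnodes": 1,
+
+  // Torchtitan相关配置(示例取值)
+  "torchtitan_path" : "./torchtitan/run_train.sh",
+  "toml_path" : "./config/example/torchtitan/debug_model.toml",
+
+  // 并行策略(默认值)
+  "strategy": {
+    "DP": true,
+    "TP": true,
+    "EP": true,
+    "FSDP": true,
+    "PP": false,
+    "CP": false
+  }
+}
+```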
+ +通常用户只需要配置好基础参数,也就是模型规模,支持的并行范式等信息,以及所用的大模型训练套件信息,即可开始配置搜索。对于熟悉本工具的用户,也提供一些更高自由度的选项。 + +**高级参数** + +| 参数 | 含义 | 默认值 | +|---------------------|--------------------------------------------|----------------------------------------| +| select_recompute | 是否搜索自定义选重 | True | +| profile_data_dir | 已有profile数据路径 | None | +| parallel_num | dryrun并行数 | 2 | +| max_expert_parallel | 搜索范围的最大EP数 | 64 | +| parser_result | profile解析结果csv文件,不需要解析profile时填写此参数即可 | None | +| alg_phase | 选择使用搜索算法阶段;0--全流程搜索, 1--ND搜索, 2--流水线负载均衡搜索 | 1 | + +## 3 参数配置 + +对于想要修改的参数,直接在json配置文件中修改即可。比如说在支持DP,TP,EP和PP,不支持CP,FSDP的场景下,可以按照如下方式修改: + +```json +{ + #其他配置 + "xxx": x, + ... + + #自定义需要调优的并行配置 + "strategy": { + "DP": true, + "TP": true, + "PP": true, + "EP": true, + "CP": flase, + "FSDP": false, + } +} +``` + +[返回原文档](../README.md) diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/build_initial_spaces.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/build_initial_spaces.py new file mode 100644 index 0000000000000000000000000000000000000000..54e56cb302a20d9901312f8b319112c67a9d37ff --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/build_initial_spaces.py @@ -0,0 +1,116 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""build initial space for nd""" + +import itertools +import math + +def cal_factor(num): + factors = [] + for i in range(1, num + 1): + if num % i == 0: + factors.append(i) + return factors + +def find_integers_less_than_ceil(num_layers, pp): + if pp == 1: + # 不开pp时, vpp为1 + return [1] + ceil_value = math.ceil(num_layers / pp) + result = list(range(1, ceil_value)) + return result + +def find_ep(expert_num): + if expert_num is None: + return [1] + ep_list = [] + for i in range(1, expert_num + 1): + if i == 1 or i % 2 == 0: + ep_list.append(i) + return ep_list + +def build_initial_spaces(input_args, para): + """ + generate initial space for nd + """ + # part1. 求[dp, tp, cp, pp]的取值范围 + # 找到 world_size 的所有因子 + factors = cal_factor(input_args.world_size) + + # 找出所有可能的四个正整数乘积等于world_size + part1_combinations = [] + for combination in itertools.combinations_with_replacement(factors, 4): + product = 1 + for num in combination: + product *= num + if product == input_args.world_size: + perms = list(itertools.permutations(combination)) + unique_configs = list(set(perms)) + part1_combinations.extend(unique_configs) + part1_combinations = filter_specify_strategy(para, part1_combinations) + + # part2. 
[ep, op, vp, mbs] + # ep是dp*tp的约数 + # vp是 < ceil(总层数/pp)的任意整数 + # mbs 取值(1, 2, 4)中 + part2_combinations = [] + num_layers = input_args.num_layers + for world_size_config in part1_combinations: + op_options = get_op_options(world_size_config) + vp_options = find_integers_less_than_ceil(num_layers, world_size_config[3]) + mbs_options = [1, 2, 4] + # todo: mindformers 和 mindspeed 对ep的限制不同,当前是mindformers的,待修改,把这个挪到专家剪枝 + ep_options = find_ep(input_args.expert_num) + result = list(itertools.product(ep_options, op_options, vp_options, mbs_options)) + + dp = world_size_config[0] + for tmp_result in result: + op = tmp_result[1] + # 格式 [(dp, tp, cp, pp), (ep, op, vp, mbs)] + if not para.STRATEGY.FSDP: + if op != dp: + # 这里不搜FSDP的含义是: FSDP取剩余的空间 + tmp_result = (tmp_result[0], -1, tmp_result[2], tmp_result[3]) + part2_combinations.append([world_size_config, tmp_result]) + + + if not para.STRATEGY.EP: + part2_combinations = [config for config in part2_combinations if config[1][0] == 1] + # part3. sp只有开关与否 格式 [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp] + final_combinations = [] + for part2_config in part2_combinations: + final_combinations.append([part2_config, False]) + if part2_config[0][1] != 1: + final_combinations.append([part2_config, True]) + return final_combinations + + +def filter_specify_strategy(para, part1_combinations): + if not para.STRATEGY.DP: + part1_combinations = [config for config in part1_combinations if config[0] == 1] + if not para.STRATEGY.TP: + part1_combinations = [config for config in part1_combinations if config[1] == 1] + if not para.STRATEGY.CP: + part1_combinations = [config for config in part1_combinations if config[2] == 1] + if not para.STRATEGY.PP: + part1_combinations = [config for config in part1_combinations if config[3] == 1] + return part1_combinations + + +# 优化器并行可以是dp的任意因子 +def get_op_options(world_size_config): + dp = world_size_config[0] + op_options = cal_factor(dp) + return op_options diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/expert_filter_configs.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/expert_filter_configs.py new file mode 100644 index 0000000000000000000000000000000000000000..b788712f67202a6d976f0d536d185555230a921c --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/expert_filter_configs.py @@ -0,0 +1,200 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""expert filter""" + +import math +from fast_tuner.utils.logger import logger + + +class ExpertFilterManager: + """ + expert experience filter nd configs + """ + def __init__(self, input_args, gbs): + self.expert_filters = [] + self.input_args = input_args + self.gbs = gbs + + @staticmethod + def sequential_combination(selected_experiences, candidate_space): + result = candidate_space + for exp in selected_experiences: + result = exp(result) + return result + + @staticmethod + def get_cp(config): + return config[0][0][2] + + @staticmethod + def get_dp(config): + return config[0][0][0] + + @staticmethod + def get_op(config): + return config[0][1][1] + + @staticmethod + def get_tp(config): + return config[0][0][1] + + @staticmethod + def get_pp(config): + return config[0][0][3] + + @staticmethod + def get_sp_switch(config): + return config[1] + + @staticmethod + def get_world_size(config): + return math.prod(config[0][0]) + + @staticmethod + def get_mbs(config): + return config[0][1][3] + + @staticmethod + def get_ep(config): + return config[0][1][0] + + def get_gbs(self): + return self.gbs + + def get_num_layers(self): + return self.input_args.num_layers + + def get_mbn(self): + return self.input_args.mbn + + def add_experience(self, experience_function): + """ + 添加一个专家经验函数到列表中 + :param experience_function: 要添加的专家经验函数 + """ + self.expert_filters.append(experience_function) + logger.info(f"add experience succ: {experience_function.__name__}") + + def remove_experience(self, experience_function): + """ + 从列表中移除一个专家经验函数 + :param experience_function: 要移除的专家经验函数 + """ + if experience_function in self.expert_filters: + self.expert_filters.remove(experience_function) + logger.info(f"remove experience succ: {experience_function.__name__}") + else: + logger.info(f"can not find experience: {experience_function.__name__},无法移除。") + + def ep_for_torchtitan(self, candidate_space): + """ + 这里默认为etp=1的场景 + """ + configs = [] + for config in candidate_space: + ep = self.get_ep(config) + cp = self.get_cp(config) + tp = self.get_tp(config) + op = self.get_op(config) + if ep % (cp * tp) == 0 and (op * cp * tp) % ep == 0: + configs.append(config) + return configs + + def ep_for_mindspore(self, candidate_space): + configs = [] + for config in candidate_space: + ep = self.get_ep(config) + dp = self.get_dp(config) + tp = self.get_tp(config) + op = self.get_op(config) + if op % ((dp * tp) / ep) == 0: + configs.append(config) + return configs + + def cp_for_deepseek_expert(self, candidate_space): + # 2.28deepseek版本cp = 1 + return [config for config in candidate_space if self.get_cp(config) == 1] + + def dp_cp_ep_for_megatron_expert(self, candidate_space): + # megatron dp * cp % ep == 0 + return [config for config in candidate_space + if self.get_dp(config) * self.get_cp(config) % self.get_ep(config) == 0] + + def pp_for_deepseek(self, candidate_space): + # 万卡训练deepseek, 要求pp>1, pp过小内存会超 + return [config for config in candidate_space if self.get_pp(config) > 1] + + def pp_for_768die(self, candidate_space): + # 768die, 要求pp<=32, + return [config for config in candidate_space if self.get_pp(config) <= 32] + + def tp_for_910b_expert(self, candidate_space): + # 910b为1机8卡,机间通信耗时过大,因此不能超过8 + return [config for config in candidate_space if self.get_tp(config) <= 8] + + def tp_for_large_scale_expert(self, candidate_space): + # 超节点技术,专家经验不超过64即可 + return [config for config in candidate_space if self.get_tp(config) <= 64] + + def 
tp_for_large_scale_768die(self, candidate_space): + # 768die, tp要是2的幂 + return [config for config in candidate_space if self.get_tp(config) % 3 != 0] + + def tp_for_yoco_expert(self, candidate_space): + return [config for config in candidate_space if 56 % self.get_tp(config) == 0] + + def ep_for_large_scale_expert(self, candidate_space): + return [config for config in candidate_space if self.get_ep(config) <= 64] + + def sp_for_lm_expert(self, candidate_space): + # 在千卡规模及以上sp一定开启,千卡规模以下可支持sp搜索 + world_size = self.get_world_size(candidate_space[0]) + return [config for config in candidate_space + if world_size < 1000 or (self.get_tp(config) ==1 or self.get_sp_switch(config))] + + def pp_for_mbs_expert(self, candidate_space): + return [config for config in candidate_space if + self.get_pp(config) <= + min(self.get_num_layers(), self.get_gbs() // self.get_dp(config) // self.get_mbs(config))] + + def gbs_for_dp_expert(self, candidate_space): + return [config for config in candidate_space if self.get_gbs() % self.get_dp(config) == 0] + +def expert_filter_configs(search_spaces, input_args, gbs): + """ + + :param search_spaces: 初始搜索空间 [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp] + :param input_args: 用户输入模型配置信息 + :param gbs: global batch size + :return: 使用专家经验剪枝搜索空间后得到的配置 [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp] + """ + expert_manager = ExpertFilterManager(input_args, gbs) + expert_manager.add_experience(expert_manager.cp_for_deepseek_expert) + expert_manager.add_experience(expert_manager.tp_for_large_scale_expert) + expert_manager.add_experience(expert_manager.ep_for_large_scale_expert) + expert_manager.add_experience(expert_manager.sp_for_lm_expert) + expert_manager.add_experience(expert_manager.pp_for_mbs_expert) + expert_manager.add_experience(expert_manager.gbs_for_dp_expert) + #expert_manager.add_experience(expert_manager.pp_for_deepseek) + #expert_manager.add_experience(expert_manager.dp_cp_ep_for_megatron_expert) + expert_manager.add_experience(expert_manager.ep_for_torchtitan) + #expert_manager.add_experience(expert_manager.ep_for_mindspore) + # add for 768die + expert_manager.add_experience(expert_manager.pp_for_768die) + expert_manager.add_experience(expert_manager.tp_for_large_scale_768die) + # # add for yoco model + # expert_manager.add_experience(expert_manager.tp_for_yoco_expert) + valid_configs = expert_manager.sequential_combination(expert_manager.expert_filters, search_spaces) + return valid_configs diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/memory_model.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/memory_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8e3048da43cbfd32a588e5d09d32fb79b23665 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/memory_model.py @@ -0,0 +1,417 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""nd phase memory model""" + +import re +import os +import math +import csv +from enum import Enum +from collections import defaultdict +from fast_tuner.utils.logger import logger +from fast_tuner.utils.common import generate_files, initial_offset, offset_for_dualpipe, is_dualpipe_open +from fast_tuner.utils.dryrun_manage import launch_dryrun, read_dryrun_info + + +def memory_simple_prune(config, profile_info): + memory_max = config.memory_max + tmp_layer = config.num_layer // config.pp_size + profile_memory = config.pp_size * tmp_layer * profile_info.act_mem_full_recomp \ + + profile_info.embedding_mem + tmp_layer * profile_info.static_layer_MoE + coe_mem = 0.2 + if profile_memory * (1 + coe_mem) > memory_max: + return False + return True + +def trans_format(dryrun_info_init, test_ep): + """ + trans dryrun_info_init to [dp, tp, pp, ep, peak_mem_ep1, peak_mem_ep2] + """ + config_mem_dict = defaultdict(lambda: [0, 0]) + for config in dryrun_info_init: + key = tuple(config[:3]) + peak_mem = config[4] + ep = config[3] + if ep == test_ep[0]: + index = 0 + elif ep == test_ep[1]: + index = 1 + else: + logger.error(f"error ep {ep} is not supported") + continue + config_mem_dict[key][index] = peak_mem + + return [list(key) + value for key, value in config_mem_dict.items()] + +def grey_box_memory_prune(mindformers_args, dryrun_info_init, test_ep, max_expert_parallel): + """ + ep灰盒剪枝 + + :param max_expert_parallel: ep最大值 + :param test_ep: + :param dryrun_info_init: [dp, tp, pp, ep, peak_value] + :param mindformers_args: + :return: [dp, tp, pp, ep, evaluate_peak_mem] + """ + dryrun_info = trans_format(dryrun_info_init, test_ep) + ep1, ep2 = test_ep + logger.info(f"dryrun_info len: {len(dryrun_info)} format [dp, tp, pp, peak_mem_ep{ep1}, peak_mem_ep{ep2}]") + logger.info("\n".join(str(config) for config in dryrun_info)) + + try: + max_mem = int(re.search(r'\d+', mindformers_args.context.max_device_memory).group()) * 1024 + except (AttributeError, ValueError): + max_mem = 58 * 1024 + + memory_aware_configs = [] + logger.info("format: dp_tp_pp_ep_evaluateMem") + ep_power = find_power_of_two(max_expert_parallel) + for dp, tp, pp, peak_ep, peak_ep_double in dryrun_info: + # 线性拟合ep会影响到的内存和ep不会影响到的内存 + ep_memory = (peak_ep - peak_ep_double) * test_ep[1] #所有专家的内存 + base_memory = peak_ep - ep_memory / test_ep[0] + # 确定ep最大能开多大,最大为6, ep最大64 + ep_upperbound = 0 + for i in range(ep_power+1): + if (dp*tp) % (2**i) == 0: + ep_upperbound += 1 + # 输出满足内存上限的ep,如果ep64都不够就不返回 + for j in range(ep_upperbound): + ep = 2 ** j + evaluate_mem = base_memory + ep_memory / ep + logger.info(f"{dp}_{tp}_{pp}_{ep}_{evaluate_mem}") + if evaluate_mem <= max_mem: + memory_aware_configs.append([dp, tp, pp, ep, evaluate_mem]) + return memory_aware_configs + +def find_power_of_two(m): + if m <= 0: + return None + power = math.log2(m) + if power.is_integer(): + return int(power) + return None + +def filter_oom(search_space, input_args, para): + """ + filter evaluate oom configs + """ + # todo: 是否dryurn返回值不同,需判断这里是否需要处理 + if para.DRYRUN: + # 生成要做dryrun的配置 + care_part_configs = select_dry_config(search_space, input_args) + test_ep = (8, 16) + dry_config = generate_dry_config(care_part_configs, input_args, test_ep) + dryrun_exe_switch = bool(para.DRYRUN_DATA_DIR) + if dryrun_exe_switch: + logger.info("need auto dryrun process") + file_task = "dryrun_yaml" if para.YAML_PATH else "dryrun_shell" + dryrun_file_dir = os.path.abspath(para.OUTPUT_PATH) + os.sep + 
file_task + os.sep + generate_files(dry_config, dryrun_file_dir, file_task, para, input_args) + dryrun_data_dir = os.path.join(os.path.abspath(para.OUTPUT_PATH), "dryrun_output") + if input_args.mf_args: + # 基于mindformers(mindspore)才做dryrun + launch_dryrun(input_args, dryrun_file_dir, dryrun_data_dir, para) + else: + dryrun_data_dir = para.DRYRUN_DATA_DIR + + if input_args.mf_args: + dryrun_info = read_dryrun_info(dryrun_data_dir) + candidate_configs = grey_box_memory_prune(input_args, dryrun_info, test_ep, para.MAX_EXPERT_PARALLEL) + else: + candidate_configs = dry_config + generate_csv(para.OUTPUT_PATH, candidate_configs, input_args) + return candidate_configs + + candidate_configs = [] #the format is [dp, tp, pp, ep, cp, op, evaluate_mem] + for config in search_space: + op_disable = False + dp, tp, cp, pp = config[0][0] + ep, op = config[0][1][:2] + if op < 0: + op_disable = True + op = dp + input_args.op, input_args.tp, input_args.pp, input_args.ep = op, tp, pp, ep + dense_size, moe_size, vocab_size = compute_weight_and_optimizer_memory(input_args) + try: + max_mem = int(re.search(r'\d+', input_args.context.max_device_memory).group()) * 1024 + except (AttributeError, ValueError): + max_mem = 58 * 1024 + if input_args.pp == 1: + evaluate_memory = (2 * vocab_size+moe_size * (input_args.num_layers - input_args.first_k_dense_replace) + + dense_size * input_args.first_k_dense_replace) + elif input_args.first_k_dense_replace > input_args.num_layers // input_args.pp: + estimated_first = vocab_size + dense_size * (input_args.num_layers // input_args.pp) + estimated_general = moe_size * math.ceil((input_args.num_layers+2)/input_args.pp) + evaluate_memory = max(estimated_first, estimated_general) + else: + estimated_first = (vocab_size + dense_size * input_args.first_k_dense_replace + + moe_size * (input_args.num_layers // input_args.pp - input_args.first_k_dense_replace)) + estimated_general = moe_size * math.ceil((input_args.num_layers + 2) / input_args.pp) + evaluate_memory = max(estimated_first, estimated_general) + if op_disable: op = -1 + if evaluate_memory <= max_mem: + if [dp, tp, pp, ep, cp, op, evaluate_memory] not in candidate_configs: + candidate_configs.append([dp, tp, pp, ep, cp, op, evaluate_memory]) + else: + logger.info(f"mem over limit: evaluate mem {evaluate_memory}" + f"config dp {dp} tp {tp} pp {pp} ep {ep} cp {cp} op {op}") + logger.info(f"Prune Search space size: {len(candidate_configs)}," + f"format: [dp, tp, pp, ep, cp, op/fsdp, evaluate_peak_mem]") + return candidate_configs + +def select_dry_config(valid_configs, input_args): + """ + + :param valid_configs: [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp] + :param input_args: 配置文件参数 + :return: [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp] + 其中vp=1, mbs=1, sp=true 每个(dp, tp, cp, pp),对应ep最大的配置列表 + """ + first = valid_configs[0][0][0] + max_ep = valid_configs[0][0][1][0] + op_with_ep = valid_configs[0][0][1][1] + ans = [] + for config in valid_configs: + current_first = config[0][0] + current_ep = config[0][1][0] + current_op = config[0][1][1] + pp = config[0][0][3] + num_layers = input_args.num_layers + + if is_dualpipe_open(input_args) and pp * 2 >= num_layers: + continue + + if current_first == first: + if current_ep > max_ep: + max_ep = current_ep + op_with_ep = current_op + else: + ans.append([[first, [max_ep, op_with_ep, 1, 1]], True]) + first = current_first + max_ep = current_ep + op_with_ep = current_op + # 添加最后一组数据 + ans.append([[first, [max_ep, op_with_ep, 1, 1]], True]) + logger.info(f"Dryrun candidate config size: 
{len(ans)}") + return ans + +def generate_csv(output_path, dryrun_config, input_args): + """ + generate nd result to csv file + """ + # 表头 + if input_args.expert_num is not None: + headers = ['dp', 'tp', 'pp', 'ep', 'evaluate_mem'] + else: + headers = ['dp', 'tp', 'pp', 'evaluate_mem'] + + # 写入 CSV 文件 + try: + csv_path = os.path.join(os.path.abspath(output_path), "nd_candidate_config.csv") + with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + # 写入表头 + writer.writerow(headers) + # 写入数据 + writer.writerows(dryrun_config) + logger.info("CSV file generate succ.") + except Exception as e: + logger.info(f"write CSV file fail: {e}") + +class CareTpye(Enum): + UNKNOWN = 0 + WITH_EXPERT_MF = 1 + NO_EXPERT = 2 + WITH_EXPERT_NO_MF = 3 + +def generate_dry_config(care_part_configs, input_args, test_ep): + """ + + :param test_ep: 要做dryrun的ep + :param care_part_configs: [[(dp, tp, cp, pp), (ep, op, vp, mbs)], sp] + :param input_args: 模型及环境等信息 + :return: [dp, tp, pp, ep, offset] 或 [dp, tp, pp] + """ + dry_run_config = [] + layers_num = input_args.num_layers + + if input_args.expert_num is not None and input_args.mf_args is not None: + care_type = CareTpye.WITH_EXPERT_MF + logger.info('dryrun configs format: dp_tp_pp_ep_offset') + elif input_args.expert_num is None: + care_type = CareTpye.NO_EXPERT + logger.info('dryrun configs format: dp_tp_pp') + else: + care_type = CareTpye.WITH_EXPERT_NO_MF + logger.info('dryrun configs format: dp_tp_pp_ep') + + for config in care_part_configs: + dp, tp, _, pp = config[0][0] + ep = config[0][1][0] + # 若为deepseek模型 + if care_type == CareTpye.WITH_EXPERT_MF: + for ep in test_ep: + # mindformers的约束 + if input_args.mf_args is not None and dp * tp % ep != 0: + continue + + use_zero_bubble_v = is_dualpipe_open(input_args) + + offset_calculator = offset_for_dualpipe if use_zero_bubble_v else initial_offset + new_config = [dp, tp, pp, ep, offset_calculator(pp, layers_num)] + dry_run_config.append(new_config) + elif care_type == CareTpye.NO_EXPERT: + new_config = [dp, tp, pp] + dry_run_config.append(new_config) + else: + new_config = [dp, tp, pp, ep] + dry_run_config.append(new_config) + + + logger.info(f"Dryrun config size: {len(dry_run_config)}") + for config in dry_run_config: + config_str = "_".join(str(x) for x in config) + logger.info(config_str) + return dry_run_config + +NUM_BYTES_IN_MEGA = 1024 * 1024 + +def compute_weight_and_optimizer_memory(input_args): + """ + + :param input_args: input training parameters + :return: [dense_layer mem, moe_layer mem, embedding_layer mem]/MB + """ + #Attention projection size + if input_args.kv_channels: + query_projection_to_hidden_size_ratio =( + input_args.kv_channels * input_args.num_attention_heads / input_args.hidden_size) + else: query_projection_to_hidden_size_ratio = 1 + # Group Query Attention. + if not input_args.group_query_attention: + input_args.num_query_groups = input_args.num_attention_heads + # MoE. 
+ num_experts = 1 if input_args.expert_num is None else input_args.expert_num + gated_linear_multiplier = 3 / 2 if input_args.swiglu else 1 + + if input_args.expert_num is not None: + moe_ffn_hidden_size = input_args.moe_intermediate_size + shared_expert_ffn_hidden_size = ( + input_args.moe_intermediate_size + if input_args.moe_shared_expert_intermediate_size is None + else input_args.moe_shared_expert_intermediate_size + ) + else: + moe_ffn_hidden_size = 0 + shared_expert_ffn_hidden_size = 0 + + if input_args.multi_latent_attention: + if input_args.qk_head_dim: + qk_head_dim = input_args.qk_head_dim + elif input_args.qk_nope_head_dim: + qk_head_dim = input_args.qk_nope_head_dim + else: + qk_head_dim = 0 + print('qk head dim not specified') + if input_args.qk_pos_emb_head_dim: + qk_pos_emb_head_dim = input_args.qk_pos_emb_head_dim + elif input_args.qk_pos_rope_head_dim: + qk_pos_emb_head_dim = input_args.qk_pos_rope_head_dim + else: + qk_pos_emb_head_dim = 0 + print('qk pos head dim not specified') + assert not input_args.group_query_attention + if input_args.q_lora_rank is None: + q_term = input_args.hidden_size * input_args.num_attention_heads * (qk_head_dim + qk_pos_emb_head_dim) + else: + ## q lora + rope + q norm + q_term = input_args.q_lora_rank * ( + input_args.hidden_size + input_args.num_attention_heads * (qk_head_dim + qk_pos_emb_head_dim) + 1) + + self_attn_term = ( + q_term + ## kv lora + rope + kv norm + + input_args.kv_lora_rank + * (input_args.hidden_size + input_args.num_attention_heads * (qk_head_dim + input_args.v_head_dim) + 1) + + input_args.hidden_size * qk_pos_emb_head_dim + ## o proj + + (input_args.num_attention_heads * input_args.v_head_dim) * input_args.hidden_size + ) + else: + self_attn_term = ( + 2 + * input_args.hidden_size + * input_args.hidden_size + * ( + # Attention. + ( + (1 + (input_args.num_query_groups / input_args.num_attention_heads)) + * query_projection_to_hidden_size_ratio + ) + ) + ) + + num_parameters_in_dense_ffn = ( + 2 + * input_args.hidden_size + * ( + # Dense MoE MLP. + (input_args.ffn_hidden_size * gated_linear_multiplier) + # Transformer layer norms. + + 2 + ) + ) + + num_parameters_in_moe_ffn_without_routed_expert = ( + 2 + * input_args.hidden_size + * ( + # MoE MLP. + # Shared MoE MLP. + + (shared_expert_ffn_hidden_size * gated_linear_multiplier) + # Transformer layer norms. 
+ + 2 + ) + ) + + num_parameters_of_routed_expert = (2 * input_args.hidden_size * moe_ffn_hidden_size + * num_experts * gated_linear_multiplier / input_args.ep) + + embedding_size = input_args.hidden_size * input_args.padded_vocab_size + final_layernorm = 2 * input_args.hidden_size + + num_bytes_per_parameter = ( + 18 if not input_args.use_distributed_optimizer else 6 + (12 / input_args.op) + ) + + if not input_args.use_distributed_optimizer: + num_bytes_per_parameter_routed_expert = 18 + elif input_args.op < input_args.ep: + num_bytes_per_parameter_routed_expert = 18 + else: + num_bytes_per_parameter_routed_expert = 6 + (12 / input_args.op * input_args.ep) + + embedding_layer = embedding_size / input_args.tp * num_bytes_per_parameter + dense_layer = (num_parameters_in_dense_ffn + self_attn_term) / input_args.tp * num_bytes_per_parameter + moe_layer = (self_attn_term + num_parameters_in_moe_ffn_without_routed_expert) * num_bytes_per_parameter + \ + num_parameters_of_routed_expert * num_bytes_per_parameter_routed_expert + moe_layer /= input_args.tp + final_layernorm /= input_args.tp + + if input_args.pp == 1: + return [dense_layer/NUM_BYTES_IN_MEGA, moe_layer/NUM_BYTES_IN_MEGA, + (embedding_layer * 2 + final_layernorm)/NUM_BYTES_IN_MEGA] + return [dense_layer/NUM_BYTES_IN_MEGA, moe_layer/NUM_BYTES_IN_MEGA, embedding_layer/NUM_BYTES_IN_MEGA] diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/para_for_nd_search.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/para_for_nd_search.py new file mode 100644 index 0000000000000000000000000000000000000000..0928767ddbb033001dd754cd371fb54a7a099dde --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/ndsearch/para_for_nd_search.py @@ -0,0 +1,317 @@ +"""class ParaForNd inner use""" + +import re +import toml +from fast_tuner.utils.common import cal_model_layers_num +from fast_tuner.utils.input_config import InputConfig +from fast_tuner.utils.logger import logger + + +class ParaForNd: + """ + trans different params to inner param + """ + def __init__(self, para): + self.seq_len = 2048 + self.mf_args = None + self.max_mem = 58 * 1024 + self.world_size = 1 + self.dp = 8 # dp_replicate_degree * dp_shard_degree(OP) + self.dp_replicate_degree = 1 + self.dp_shard_degree = -1 + self.op = 1 # equal to real fsdp value + self.pp = 1 + self.tp = 1 + self.cp = 1 + self.ep = 1 + self.etp = 1 + self.gbs = 8 + self.mbn = 1 + self.mbs = 1 + self.num_layers = 6 + self.swiglu = False + self.qk_pos_emb_head_dim = 64 + self.qk_head_dim = 128 + self.q_lora_rank = None + self.kv_lora_rank = None + self.v_head_dim = 128 + self.expert_num = 8 + self.moe_intermediate_size = 256 + self.hidden_size = 256 + self.ffn_hidden_size = 1024 + self.num_attention_heads = 16 + self.padded_vocab_size = 2048 + self.first_k_dense_replace = 1 + self.use_distributed_optimizer = True + self.kv_channels = 128 + self.group_query_attention = True + self.num_query_groups = 8 + self.moe_shared_expert_intermediate_size = None + self.multi_latent_attention = False + # add for titan + self.enable_fsdp_float8_all_gather = False + self.precompute_float8_dynamic_scale_for_fsdp = False + self.fsdp_reshard_after_forward = "default" + self.enable_async_tensor_parallel = False + self.pipeline_parallel_schedule = "Interleaved1F1B" + self.model_name = None + + # for profile parser todo: 只为mindspeed服务,不应该在这里 + self.num_layer_list = [4, 3] + self.recompute_method = 'block' + self.recompute_num_layers = 1 + self.profile_steps = 2 + self.get_args_from_file(para) 
+ + def print_member_value(self): + for attr, value in self.__dict__.items(): + logger.info(f"{attr}: {value}") + + def convert_value(self, value_str): + """尝试将字符串转换为整数/浮点数,失败则返回原字符串""" + if not value_str: + return value_str + # 尝试整数转换(处理正负整数) + if value_str.lstrip('-').isdigit(): + return int(value_str) + # 尝试浮点数转换(处理正负浮点数、科学计数法) + try: + return float(value_str) + except ValueError: + return value_str # 非数字则返回原字符串 + + def parse_toml_parameters(self, toml_file_path): + params = toml.load(toml_file_path) + return params + + # 部分变量没有解析出来,多行参数和引用都能解析 + def parse_sh_parameters(self, sh_file_path): + """ + 解析 shell 脚本内容中的参数,返回包含各模块参数的字典 + + Args: + sh_content: 读取的 shell 脚本内容(字符串) + Returns: + dict: 按模块分组的参数字典,格式为 {模块名: {参数名: 参数值}} + """ + # 存储结果:模块名 -> {参数名: 参数值} + + with open(sh_file_path, 'r', encoding='utf-8') as f: + sh_content = f.read() + result = {} + + # 1. 提取所有变量(如 TP=1, DATA_PATH="xxx" 等),用于替换参数中的变量引用(如 ${TP}) + variables = self.parse_variable(sh_content) + + # 2. 提取模块参数(如 GPT_ARGS、DATA_ARGS) + self.parse_module_args(result, sh_content, variables) + + # 3. 展开 result:将所有模块的参数字典合并为一个扁平字典 + flattened_params = {} + for _, params in result.items(): + # 合并到全局字典(若有重复参数,后面的模块会覆盖前面的) + flattened_params.update(params) + + return flattened_params + + def parse_module_args(self, result, sh_content, variables): + """ + 匹配模块定义(如 GPT_ARGS="...参数...") + """ + module_pattern = re.compile(r'^(\w+)\s*=\s*"(.*?)"$', re.DOTALL | re.MULTILINE) + for module_match in module_pattern.finditer(sh_content): + module_name = module_match.group(1) # 模块名:GPT_ARGS 或 DATA_ARGS + module_content = module_match.group(2) # 模块内的参数内容 + + # 解析模块内的参数(--key value 形式) + param_pattern = re.compile(r'--(\w[\w-]*)\s+([^\s\\]+)') + # 处理无值参数(如 --use-flash-attn) + flag_pattern = re.compile(r'--(\w[\w-]*)') + + module_params = {} + + # 先提取带值的参数(如 --data-path $DATA_PATH) + for match in param_pattern.finditer(module_content): + param_key = match.group(1) # 参数名:data-path + param_value = match.group(2) # 参数值:$DATA_PATH + + # 替换变量引用(如 $DATA_PATH 替换为实际值) + for var_name, var_val in variables.items(): + param_value = param_value.replace(f'${var_name}', var_val).replace(f'${{{var_name}}}', var_val) + + module_params[param_key.replace('-', '_')] = self.convert_value(param_value) + + # 再提取无值参数(如 --use-flash-attn) + for match in flag_pattern.finditer(module_content): + param_key = match.group(1) + standard_key = param_key.replace('-', '_') + if standard_key not in module_params: # 避免覆盖已提取的带值参数 + module_params[param_key.replace('-', '_')] = True # 用 True 表示开关参数开启 + if module_params: + result[module_name] = module_params + + def parse_variable(self, sh_content): + """ + 提取所有变量(如 TP=1, DATA_PATH="xxx" 等),用于替换参数中的变量引用(如 ${TP}) + """ + variables = {} + var_pattern = re.compile(r'^(\w+)\s*=\s*([\'"])(.*?)\2$', re.MULTILINE) + for match in var_pattern.finditer(sh_content): + var_name = match.group(1) # 变量名(如 A) + var_value = match.group(3) # 变量值(如 cccc) + variables[var_name] = var_value + # 补充处理不带引号的变量(如 TP=1, PP=4) + var_pattern2 = re.compile(r'^(\w+)\s*=\s*([\w.]+)$', re.MULTILINE) + for match in var_pattern2.finditer(sh_content): + var_name = match.group(1) + var_value = match.group(2) + variables[var_name] = var_value + return variables + + def get_args_from_file(self, para): + """ + get args from yaml/shell/toml file + """ + if para.YAML_PATH: + self.trans_yaml_param(para) + + elif para.SHELL_PATH: + self.trans_shell_param(para) + + elif para.TOML_PATH: + self.trans_toml_param(para) + + else: + raise RuntimeError("only support yaml/shell/toml file, 
pls input valid config file") + + self.print_member_value() + + def trans_toml_param(self, para): + """ + trans toml param to inner param + """ + parsed_args = self.parse_toml_parameters(para.TOML_PATH) + self.seq_len = parsed_args['training'].get('seq_len', 4096) + self.world_size = para.NPUS_PER_NODE * para.NNODES + self.pp = parsed_args['parallelism'].get('pipeline_parallel_degree', 1) + self.tp = parsed_args['parallelism'].get('tensor_parallel_degree', 1) + self.cp = parsed_args['parallelism'].get('context_parallel_degree', 1) + self.dp_shard_degree = parsed_args['parallelism'].get('data_parallel_shard_degree', -1) + if self.dp_shard_degree == -1: + fsdp = self.world_size // self.pp // self.tp // self.cp // parsed_args['parallelism'].get( + 'data_parallel_replicate_degree', 1) + else: + fsdp = self.dp_shard_degree + self.op = fsdp + self.dp = parsed_args['parallelism'].get('data_parallel_replicate_degree', 1) * fsdp + self.ep = parsed_args['parallelism'].get('expert_parallel_degree', 1) + self.etp = parsed_args['parallelism'].get('expert_tensor_parallel_degree', 1) + self.fsdp_reshard_after_forward = parsed_args['parallelism'].get('fsdp_reshard_after_forward', "default") + self.enable_async_tensor_parallel = parsed_args['parallelism'].get('enable_async_tensor_parallel', False) + local_batch_size = parsed_args['training'].get('local_batch_size', 1) + self.gbs = local_batch_size * self.dp + self.mbs = parsed_args['parallelism'].get('pipeline_parallel_microbatch_size', 1) + self.mbn = local_batch_size // self.mbs + self.enable_fsdp_float8_all_gather = parsed_args['quantize']['linear']['float8'].get( + 'enable_fsdp_float8_all_gather', False) + self.precompute_float8_dynamic_scale_for_fsdp = parsed_args['quantize']['linear']['float8'].get( + 'precompute_float8_dynamic_scale_for_fsdp', False) + self.model_name = parsed_args['model'].get('name', "llama3") + model_flavor = parsed_args['model'].get('flavor', "8B") + try: + # pylint: disable=C0415 + import torchtitan.protocols.train_spec as train_spec_module + train_spec = train_spec_module.get_train_spec(self.model_name) + model_args = train_spec.model_args[model_flavor] + + self.num_layers = model_args.n_layers + self.qk_pos_emb_head_dim = getattr(model_args, 'qk_rope_head_dim', 64) + self.qk_head_dim = getattr(model_args, 'qk_nope_head_dim', 128) + self.q_lora_rank = getattr(model_args, 'q_lora_rank', None) + self.kv_lora_rank = getattr(model_args, 'kv_lora_rank', None) + self.v_head_dim = getattr(model_args, 'v_head_dim', None) + self.expert_num = getattr(getattr(model_args, 'moe_args', None), 'num_experts', None) + self.moe_intermediate_size = getattr(model_args, 'moe_inter_dim', None) + self.hidden_size = getattr(model_args, 'dim', None) + self.ffn_hidden_size = getattr(model_args, 'inter_dim', 4 * self.hidden_size) + self.cal_ffn_hidden_size(model_args) + self.num_attention_heads = getattr(model_args, 'n_heads', None) + self.padded_vocab_size = getattr(model_args, 'vocab_size', None) + self.first_k_dense_replace = getattr(model_args, 'n_dense_layers', 0) + self.use_distributed_optimizer = self.dp_shard_degree != 1 + self.kv_channels = self.hidden_size // self.num_attention_heads + n_kv_heads = getattr(model_args, 'n_kv_heads', None) + self.group_query_attention = (n_kv_heads is not None and n_kv_heads != self.num_attention_heads) + self.num_query_groups = n_kv_heads + self.moe_shared_expert_intermediate_size = self.moe_intermediate_size * getattr( + getattr(model_args, 'moe_args', None), 'num_shared_experts', 0) + 
            self.multi_latent_attention = bool(self.q_lora_rank)
+
+        except Exception as e:
+            logger.warning(f'Failed to read torchtitan model args: {e}')
+        self.profile_steps = 1
+
+    def trans_yaml_param(self, para):
+        """
+        Translate yaml params into inner params
+        """
+        input_args = InputConfig(para.YAML_PATH)
+        self.dp = input_args.parallel_config.data_parallel
+        self.tp = input_args.parallel_config.model_parallel
+        self.pp = input_args.parallel_config.pipeline_stage
+        self.cp = 1 if input_args.parallel_config.get('context_parallel') is None else input_args.parallel_config.get(
+            'context_parallel')
+        self.world_size = self.dp * self.tp * self.cp * self.pp
+        self.num_layers = cal_model_layers_num(input_args)
+        self.mbn = input_args.parallel_config.micro_batch_num
+        self.max_mem = input_args.context.max_device_memory
+        self.expert_num = None if input_args.moe_config is None else input_args.context.expert_num
+        self.mf_args = input_args
+
+    def cal_ffn_hidden_size(self, model_args):
+        """
+        Update ffn_hidden_size
+        """
+        print(f"origin_ffn_hidden_size {self.ffn_hidden_size}")
+        multiple_of = getattr(model_args, 'multiple_of', None)
+        if multiple_of is None:
+            return
+        ffn_dim_multiplier = getattr(model_args, 'ffn_dim_multiplier', None)
+        if multiple_of and ffn_dim_multiplier:
+            tmp_ffn_hidden_size = int(2 * self.ffn_hidden_size / 3)
+            tmp_ffn_hidden_size = int(ffn_dim_multiplier * tmp_ffn_hidden_size)
+            self.ffn_hidden_size = multiple_of * ((tmp_ffn_hidden_size + multiple_of - 1) // multiple_of)
+        print(f"current_ffn_hidden_size {self.ffn_hidden_size}")
+
+    def trans_shell_param(self, para):
+        """
+        Translate shell params into inner params
+        """
+        parsed_args = self.parse_sh_parameters(para.SHELL_PATH)
+        self.world_size = parsed_args['nproc_per_node'] * parsed_args['nnodes']
+        self.gbs = parsed_args.get('global_batch_size')
+        self.mbs = parsed_args.get('micro_batch_size')
+        self.pp = parsed_args.get('pipeline_model_parallel_size', 1)
+        self.tp = parsed_args.get('tensor_model_parallel_size', 1)
+        self.cp = parsed_args.get('context_parallel_size', 1)
+        self.dp = self.world_size // self.pp // self.tp // self.cp
+        self.mbn = self.gbs // self.mbs // self.dp
+        self.num_layers = parsed_args.get('num_layers')
+        self.swiglu = parsed_args.get('swiglu', False)
+        self.qk_pos_emb_head_dim = parsed_args.get('qk_pos_emb_head_dim', 64)
+        self.qk_head_dim = parsed_args.get('qk_head_dim', 128)
+        self.q_lora_rank = parsed_args.get('q_lora_rank')
+        self.kv_lora_rank = parsed_args.get('kv_lora_rank', 32)
+        self.v_head_dim = parsed_args.get('v_head_dim', 128)
+        self.expert_num = parsed_args.get('num_experts')
+        self.hidden_size = parsed_args.get('hidden_size')
+        self.ffn_hidden_size = parsed_args.get('ffn_hidden_size', 4 * self.hidden_size)
+        self.moe_intermediate_size = parsed_args.get('moe_intermediate_size', self.ffn_hidden_size)
+        self.num_attention_heads = parsed_args.get('num_attention_heads')
+        self.padded_vocab_size = parsed_args.get('padded_vocab_size', parsed_args.get('vocab_size'))
+        self.first_k_dense_replace = parsed_args.get('first_k_dense_replace', 0)
+        self.use_distributed_optimizer = parsed_args.get('use_distributed_optimizer', False)
+        self.kv_channels = parsed_args.get('kv_channels', self.hidden_size // self.num_attention_heads)
+        self.group_query_attention = parsed_args.get('group_query_attention', False)
+        self.moe_shared_expert_intermediate_size = parsed_args.get('moe_shared_expert_intermediate_size')
+        self.multi_latent_attention = parsed_args.get('multi_latent_attention', False)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/parallel_tool.py
b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/parallel_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..659df7bb563692b0e894fc2466e98969c6da8b8e --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/parallel_tool.py @@ -0,0 +1,161 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""entry file""" + +import argparse +import time +import sys +import types + +from fast_tuner.ndsearch.para_for_nd_search import ParaForNd +from fast_tuner.ndsearch.memory_model import filter_oom +from fast_tuner.ndsearch.build_initial_spaces import build_initial_spaces +from fast_tuner.ndsearch.expert_filter_configs import expert_filter_configs + +from fast_tuner.utils.common import check_dryrun_parallel_number +from fast_tuner.utils.input_param import InputParam +from fast_tuner.utils.logger import logger +from fast_tuner.utils.common import parse_args_from_json +from fast_tuner.utils.ppc_input import ParallelInput +from fast_tuner.utils.profiling.profile_launch import ProfileLaunch +from fast_tuner.utils.profiling.profile_parser import ProfileParser, ProfileMemParser +from fast_tuner.utils.profiling.profile_prepare import profile_prepare +# 这里不应该有的 +from fast_tuner.pipeline_conductor import pp_util +from fast_tuner.pipeline_conductor.pipeline_parallel import pipeline_proc + + +__all__ = ['taylor_search_tool'] + +def taylor_search_tool(para): + """ + A function for find out optimal ND parallel configuration. + + Args: + param para: 用户输入自定义的参数 + + Returns: + parallel config fill back to yaml file: [dp, cp, tp, ...] and candidate csv file. + """ + start_time = time.time() + para.print_params() + + input_args = ParaForNd(para) + + initial_configs = build_initial_spaces(input_args, para) + logger.info(f"Initial Search space size: {len(initial_configs)}") + expert_prune_search_space = expert_filter_configs(initial_configs, input_args, para.GBS) + if len(expert_prune_search_space) == 0: + logger.info("expert_prune_search_space is empty. 
Please check your expert rules.") + sys.exit() + logger.info(f"Expert Prune Search space size: {len(expert_prune_search_space)}") + + mem_prune_space = filter_oom(expert_prune_search_space, input_args, para) + logger.info('%s', '\n'.join(str(item) for item in mem_prune_space)) + + end_time = time.time() + elapsed_time = end_time - start_time + logger.info(f"program before profiling cost time: {elapsed_time} s") + + # profile_configs: list[dp, tp, pp, 0, 0, layers_num, toml/yaml/shell path, profile_result_dir] + # profile_file_dir: profile shell/toml file dir + profile_configs, profile_file_dir = profile_prepare(mem_prune_space, para, input_args) + if para.ALG_PHASE == 1: + logger.info(f"ALG_PHASE: {para.ALG_PHASE}, no need to profile and solve pipeline") + return + # 自动执行profile + profile_launch = ProfileLaunch(profile_configs, para) + profile_launch.profile_launch(profile_file_dir) + # 自动profile解析 + profile_parser = ProfileParser(input_args, para) + profile_parser.parse_batch_profile_result(profile_configs) + + # 内存解析 + profile_mem_parser = ProfileMemParser(input_args, para) + profile_mem_parser.mem_parser(profile_configs) + + # 流水线求解 todo: 想办法把candidate_configs传进去,不用csv读取 + pipeline_input = ParallelInput(para, profile_file_dir) + pipeline_proc(pipeline_input) + +def main(): + logger.info('start to run parallel tool') + parser = argparse.ArgumentParser(description='Run taylor_search_tool with user input parameters') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['YAML_PATH']}", default='', + help='Path to the YAML file') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['SHELL_PATH']}", default='', + help='Path to the SHELL type config file') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['TOML_PATH']}", default='', + help='Path to the TOML file') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['MINDFORMERS_DIR']}", + default='', help='Directory of mindformers') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['MINDSPEED_PATH']}", default='', + help='Path to the MindSpeed file') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['TORCHTITAN_PATH']}", default='', + help='Path to the Torchtitan file') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['DRYRUN_DATA_DIR']}", default='', + help='Directory of dryrun data') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['PROFILE_DATA_DIR']}", + default='', help='Directory of profile data') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['NNODES']}", type=int, + default=1, help='The number of nodes') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['NPUS_PER_NODE']}", type=int, + default=8, help='The npus per node') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['DRYRUN_LIM']}", type=int, + default=2, help='The maximum number of dryrun at once') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['RANK_NUM']}", type=int, + default=64, help='Number of available device number') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['SOLVER_NAME']}", type=str, + default='HIGHS', help='Name of the solver') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['DATASET']}", type=str, + default='', help='Directory of dataset') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['MAX_EXPERT_PARALLEL']}", type=int, + default=64, help='Max number of expert parallel') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['OUTPUT_PATH']}", type=str, + default='./output/', help='Directory of output info') + parser.add_argument(f"--{InputParam.PARAM_MAPPING['ENV_JSON']}", type=str, + default='', help='Environment variable config 
json file')
+    parser.add_argument(f"--{InputParam.PARAM_MAPPING['GBS']}", type=int,
+                        default=1024, help='Global batch size')
+    parser.add_argument(f"--{InputParam.PARAM_MAPPING['CONFIG']}", type=str,
+                        help='Path to the JSON config file')
+    parser.add_argument(f"--{InputParam.PARAM_MAPPING['SELECT_RECOMPUTE']}", type=pp_util.str2bool, default=True,
+                        help='Whether to search selective recompute')
+    parser.add_argument(f"--{InputParam.PARAM_MAPPING['ALG_PHASE']}", type=int, default=0,
+                        help='Phase of the parallel strategy search algorithm')
+    parser.add_argument(f"--{InputParam.PARAM_MAPPING['PARSER_RESULT']}", type=str, default=None,
+                        help='Profiling parser result file')
+    parser.add_argument(f"--{InputParam.PARAM_MAPPING['DRYRUN']}", type=pp_util.str2bool, default=False,
+                        help='Whether to launch dryrun automatically')
+    parser.add_argument(f"--{InputParam.PARAM_MAPPING['CHECK']}", type=pp_util.str2bool, default=False,
+                        help='Whether to double-check the result with a dryrun')
+    parser.add_argument(f"--{InputParam.PARAM_MAPPING['STRATEGY']}", type=pp_util.str2dict,
+                        default={'DP': True, 'TP': True, 'EP': True, 'FSDP': True, 'PP': False, 'CP': False},
+                        help='Which parallel strategies are enabled')
+
+    args = parser.parse_args()
+    if args.config:
+        parse_args_from_json(args)
+    args.strategy = types.SimpleNamespace(**args.strategy)
+    check_dryrun_parallel_number(args.dryrun_lim)
+
+    para = InputParam(args)
+    taylor_search_tool(para)
+
+if __name__ == '__main__':
+    main()
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/dryrun.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/dryrun.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcd4c1e6427dc90324ac316a0b553697fd388880
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/dryrun.py
@@ -0,0 +1,204 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""dryrun operation"""
+import os
+import shutil
+import json
+import argparse
+import subprocess
+from multiprocessing import Pool
+from fast_tuner.utils.logger import logger
+from fast_tuner.pipeline_conductor import pp_util
+from fast_tuner.pipeline_conductor.pp_util import pipeline_output_file
+
+dryrun_config_error = 'The config_file location and ms_adapter_file location are essential, please configure them!'
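+
+# The DryRun class below launches one simulated ("dryrun") training process per pipeline
+# stage and parses each rank's log for that stage's peak memory. config_file_type selects
+# the frontend: 0 = MindFormers yaml run through the python adapter script, 1 = MindSpeed
+# shell script run through bash with RANK_ID exported per stage.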
+ + +class DryRun: + """dryrun operation""" + env_config_json = '' + register_path = 'research/jiutian' + dryrun_lim = 16 + config_file_type = 0 + is_write_to_file = True + + def __init__(self, config_file, ms_adapter_file, output_name): + self.config_file = config_file + self.ms_adapter_file = ms_adapter_file + self.rank_gap = None + pp_output_file = os.path.join(os.getcwd(), pipeline_output_file) + if not os.path.exists(pp_output_file): + os.mkdir(pp_output_file) + self.log_file_name = os.path.join(pp_output_file, output_name) + if os.path.exists(self.log_file_name): + shutil.rmtree(self.log_file_name) + os.mkdir(self.log_file_name) + + def dryrun(self, stage_num, rank_size): + self.rank_gap = rank_size // stage_num + self.set_env(rank_size, self.env_config_json) + remainder = stage_num + device_id = 0 + while remainder > 0: + if remainder > self.dryrun_lim: + dryrun_num = self.dryrun_lim + else: + dryrun_num = remainder + remainder -= dryrun_num + with Pool(processes=dryrun_num) as pool: + pool.map(self.run_rank, range(device_id, device_id + dryrun_num)) + device_id += dryrun_num + logger.info('pull dryrun of all stages!') + + def set_env(self, rank_size, env_config_json): + with open(env_config_json, 'r', encoding='utf-8') as f: + env_vars = json.load(f) + env_vars['RANK_SIZE'] = str(rank_size) + os.environ.update(env_vars) + + def start_dryrun(self, recompute_config, offset, num_layers, num_vpp, num_stage, rank_size, dense_layers, micro): + if self.config_file_type == 0: + name = pp_util.bulid_yaml(self.config_file, recompute_config, offset, + num_layers, num_vpp, num_stage, dense_layers, micro) + elif self.config_file_type == 1: + name = pp_util.bulid_shell(self.config_file, offset, num_layers, num_vpp, num_stage, dense_layers, micro) + else: + raise TypeError(dryrun_config_error) + self.config_file = name + self.dryrun(num_stage, rank_size) + + def run_rank(self, stage): + device_id = stage + rank_id = stage * self.rank_gap + cwd = os.getcwd() + log_file = os.path.join(cwd, self.log_file_name, f'rank_{rank_id}.log') + logger.info(f"start training for rank_{rank_id}, device_{device_id}, waiting a moment...") + if self.config_file_type == 0: + os.environ['ASCEND_RT_VISIBLE_DEVICES'] = str(device_id) + os.environ['RANK_ID'] = str(rank_id) + command = ['python', self.ms_adapter_file, '--register_path', + self.register_path, '--config', self.config_file] + with open(log_file, 'w', encoding='utf-8') as log: + subprocess.run(command, stdout=log, stderr=subprocess.STDOUT) + elif self.config_file_type == 1: + env = os.environ.copy() + env['RANK_ID'] = str(rank_id) + command = ['bash', self.config_file, str(device_id), self.ms_adapter_file, log_file] + subprocess.run(command, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + else: + raise TypeError(dryrun_config_error) + + def extract_memory_info(self, num_stage): + cwd = os.getcwd() + peak_mem = [] + for stage in range(num_stage): + rank_id = self.rank_gap * stage + log_file = os.path.join(cwd, self.log_file_name, f"rank_{rank_id}.log") + peak_mem.append(pp_util.extract_peak_memory(log_file)) + return peak_mem + + def extract_memory_info_act(self, num_stage): + cwd = os.getcwd() + peak_mem = [] + for stage in range(num_stage): + rank_id = self.rank_gap * stage + log_file = os.path.join( + cwd, self.log_file_name, f"rank_{rank_id}.log") + peak_mem.append(pp_util.extract_actual_peak_memory(log_file)) + return peak_mem + + +def one_rank_dryrun(stage, yaml_file, mindformer_file, output_file): + dry_run = DryRun(yaml_file, 
mindformer_file, output_file) + rank_size, pipeline_stage = pp_util.get_ranks_stages(yaml_file) + dry_run.rank_gap = rank_size // pipeline_stage + dry_run.set_env(rank_size, dry_run.env_config_json) + dry_run.run_rank(stage) + + +def all_rank_dryrun(config_file, ms_adapter_file, output_file): + dry_run = DryRun(config_file, ms_adapter_file, output_file) + if DryRun.config_file_type == 0: + rank_size, pipeline_stage = pp_util.get_ranks_stages(config_file) + elif DryRun.config_file_type == 1: + rank_size, pipeline_stage = pp_util.get_shell_ranks_stages(config_file) + else: + raise TypeError(dryrun_config_error) + dry_run.dryrun(pipeline_stage, rank_size) + print(dry_run.extract_memory_info(pipeline_stage)) + print(dry_run.extract_memory_info_act(pipeline_stage)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(prog='Yaml config dryrun', + description='Write config to the yaml file, and dryrun it', epilog='') + parser.add_argument('--yaml', '-y', type=str, default=None, + help="Path of training config (.yaml)") + parser.add_argument('--shell', '-sh', type=str, default=None, + help="Path of training config (.sh)") + parser.add_argument('--mindformers', '-mf', type=str, default=None, + help="Absolute path of run_mindformers (.py)") + parser.add_argument('--mindspeed', '-mp', type=str, default=None, + help="Absolute path of posttrain_gpt (.py)") + parser.add_argument('--output_file', '-f', type=str, default='dryrun_output', + help="The location to place the output files") + parser.add_argument('--offset', '-o', type=pp_util.str2list, + default=None, help="offset list") + parser.add_argument('--is_recompute', '-ir', type=pp_util.str2bool, + default=None, help="Whether to open recompute") + parser.add_argument('--recompute_layers', '-rl', type=pp_util.str2list, + default=None, help="recompute_layers list") + parser.add_argument('--is_select_recompute', '-is', type=pp_util.str2bool, + default=None, help="Whether to open select_recompute") + parser.add_argument('--select_recompute_layers', '-sl', type=pp_util.str2list, + default=None, help="select_recompute_layers list") + parser.add_argument('--env_config_json', '-e', type=str, required=True, + default='./config/boss_env_config.json', help="Path of environment config (.json)") + parser.add_argument('--register_path', '-rp', type=str, default='research/jiutian', + help="Path of register") + parser.add_argument('--dryrun_lim', '-dl', type=pp_util.str2int, default=16, + help="The number of dryrun at once") + args = parser.parse_args() + + if args.yaml and args.mindformers: + config_file = args.yaml + ms_adapter_file = args.mindformers + DryRun.config_file_type = 0 + elif args.shell and args.mindspeed: + config_file = args.shell + ms_adapter_file = args.mindspeed + DryRun.config_file_type = 1 + else: + raise TypeError(dryrun_config_error) + + output_file = args.output_file + DryRun.env_config_json = args.env_config_json + DryRun.register_path = args.register_path + DryRun.dryrun_lim = args.dryrun_lim + if args.recompute_layers and args.is_recompute is None: + args.is_recompute = True + if args.select_recompute_layers and args.is_select_recompute is None: + args.is_select_recompute = True + + if args.offset is None and args.is_select_recompute is None and args.is_recompute is None: + logger.info('Use old yaml config to dryrun') + elif DryRun.config_file_type == 0: + config_file = pp_util.build_new_config_yaml(args) + elif DryRun.config_file_type == 1: + config_file = pp_util.build_new_config_shell(args) + else: + raise 
TypeError(dryrun_config_error) + + all_rank_dryrun(config_file, ms_adapter_file, output_file) diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/fitting.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/fitting.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3268c2740c05dcd655fc4383d985132497211f --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/fitting.py @@ -0,0 +1,242 @@ +import numpy as np +from fast_tuner.utils.logger import logger +from fast_tuner.pipeline_conductor import solution +from fast_tuner.pipeline_conductor import dryrun +from fast_tuner.pipeline_conductor import pp_util +"""fitting file""" +class FitMem: + cur_peak_mem = [] + + def __init__(self, cur_solution: solution.Solution): + self.cur_solution = cur_solution + self.cur_peak_mem = cur_solution.check_peak_mem + self.peaks = cur_solution.peak_num + self.init_config = cur_solution.init_config + self.init_memory = cur_solution.init_config.memory + + def linear_fit(self): + if self.init_config.expert_input.is_select_recomp and self.init_config.expert_input.is_full_recomp: + length_x_lim = 11 + elif not self.init_config.expert_input.is_select_recomp and not self.init_config.expert_input.is_full_recomp: + length_x_lim = 7 + else: + length_x_lim = 9 + if self.init_config.pipeline_stage < length_x_lim: + length_x = self.init_config.pipeline_stage + else: + length_x = length_x_lim + if length_x < 4: + logger.warning(f'can not fit for pipeline_stage = {self.init_config.pipeline_stage}') + return + coe_a, array_b = self.form_coe_matrix_mem_array(length_x) + np.set_printoptions(suppress=True, precision=3) # 禁用科学计数法,保留 6 位小数 + modified_mem, res, rank, s = np.linalg.lstsq(coe_a, array_b, rcond=None) + logger.info(f'the residual = {np.linalg.norm(res)}') + logger.info(f'the rank = {rank}') + if np.linalg.norm(res) > 1e-3 or rank < self.init_config.pipeline_stage: + logger.warning(f'The distribution can not correct the memory!') + else: + modified_mem = list(np.round(np.array(modified_mem), decimals=1)) + self.correct_mem(modified_mem) + + def correct_mem(self, modify_mem): + self.init_memory.static_mem0 = modify_mem[0] + self.init_memory.static_mem = modify_mem[1] + self.init_memory.lm_head_mem = modify_mem[2] + self.init_memory.act_mem = modify_mem[3] + + if len(modify_mem) == 5: + self.init_memory.layer_mem = modify_mem[4] + if len(modify_mem) == 6: + self.init_memory.layer_mem = modify_mem[4] + self.init_memory.re_comp_mem = modify_mem[5] + if len(modify_mem) == 8: + self.init_memory.layer_mem = modify_mem[4] + self.init_memory.re_comp_mem = modify_mem[5] + self.init_memory.layer_mem012 = modify_mem[6] + self.init_memory.act_mem12 = modify_mem[7] + self.init_memory.act_mem0 = modify_mem[7] + if self.init_config.expert_input.is_select_recomp: + if len(modify_mem) == 9: + self.init_memory.layer_mem = modify_mem[4] + self.init_memory.re_comp_mem = modify_mem[5] + self.init_memory.layer_mem012 = modify_mem[6] + self.init_memory.act_mem12 = modify_mem[7] + self.init_memory.re_comp_mem12 = modify_mem[8] + self.init_memory.re_comp_mem0 = modify_mem[8] + if len(modify_mem) == 10: + self.init_memory.layer_mem = modify_mem[4] + self.init_memory.re_comp_mem = modify_mem[5] + self.init_memory.layer_mem012 = modify_mem[6] + self.init_memory.act_mem12 = modify_mem[7] + self.init_memory.re_comp_mem12 = modify_mem[8] + self.init_memory.re_comp_mem0 = modify_mem[8] + self.init_memory.select_mem = modify_mem[9] + if 
len(modify_mem) > 10: + self.init_memory.layer_mem = modify_mem[4] + self.init_memory.re_comp_mem = modify_mem[5] + self.init_memory.layer_mem012 = modify_mem[6] + self.init_memory.act_mem12 = modify_mem[7] + self.init_memory.re_comp_mem12 = modify_mem[8] + self.init_memory.re_comp_mem0 = modify_mem[8] + self.init_memory.select_mem = modify_mem[9] + self.init_memory.select_mem12 = modify_mem[10] + self.init_memory.select_mem0 = self.init_memory.select_mem12 + else: + if len(modify_mem) >= 9: + self.init_memory.layer_mem = modify_mem[4] + self.init_memory.re_comp_mem = modify_mem[5] + self.init_memory.layer_mem012 = modify_mem[6] + self.init_memory.act_mem12 = modify_mem[7] + self.init_memory.re_comp_mem12 = modify_mem[8] + self.init_memory.re_comp_mem0 = modify_mem[8] + self.init_memory.update_up_mem() + logger.info('The correct memory information:') + self.init_memory.print_mem() + + def form_coe_matrix_mem_array(self, length_x): + coe_a = np.empty((self.init_config.pipeline_stage, length_x), float) + array_b = np.empty(self.init_config.pipeline_stage, float) + for stage in range(self.init_config.pipeline_stage): + if stage == 0: + coe_a[stage][0] = 1 + coe_a[stage][1] = 0 + coe_a[stage][2] = 0 + elif stage == self.init_config.pipeline_stage - 1: + coe_a[stage][0] = 0 + coe_a[stage][1] = 0 + coe_a[stage][2] = 1 + else: + coe_a[stage][0] = 0 + coe_a[stage][1] = 1 + coe_a[stage][2] = 0 + coe_a[stage][3] = self.cur_solution.peak_num.peak_num_act_type2[stage] + if length_x == 4: + array_b[stage] = (self.cur_peak_mem[stage] - self.cur_solution.layer2_dis_stage[stage] * + self.init_memory.layer_mem - + self.peaks.peak_num_recompute_type2[stage] * self.init_memory.re_comp_mem - + self.cur_solution.layer1_dis_stage[stage] * self.init_memory.layer_mem012 - + self.peaks.peak_num_act_type1[stage] * self.init_memory.act_mem12 - + self.peaks.peak_num_recompute_type1[stage] * self.init_memory.re_comp_mem12 - + self.peaks.peak_num_select_recom_type2[stage] * self.init_memory.select_mem - + self.peaks.peak_num_select_recom_type1 * self.init_memory.select_mem12) + continue + if length_x == 6: + coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage] + coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage] + array_b[stage] = (self.cur_peak_mem[stage] - + self.cur_solution.layer1_dis_stage[stage] * self.init_memory.layer_mem012 - + self.peaks.peak_num_act_type1[stage] * self.init_memory.act_mem12 - + self.peaks.peak_num_recompute_type1[stage] * self.init_memory.re_comp_mem12 - + self.peaks.peak_num_select_recom_type2[stage] * self.init_memory.select_mem - + self.peaks.peak_num_select_recom_type1 * self.init_memory.select_mem12) + continue + if length_x == 8: + coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage] + coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage] + coe_a[stage][6] = self.cur_solution.layer1_dis_stage[stage] + coe_a[stage][7] = self.peaks.peak_num_act_type1[stage] + array_b[stage] = (self.cur_peak_mem[stage] - self.peaks.peak_num_recompute_type1[stage] * + self.init_memory.re_comp_mem12 - self.peaks.peak_num_select_recom_type2[stage] * + self.init_memory.select_mem - + self.peaks.peak_num_select_recom_type1[stage] * self.init_memory.select_mem12) + continue + if length_x == 9: + coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage] + coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage] + coe_a[stage][6] = self.cur_solution.layer1_dis_stage[stage] + coe_a[stage][7] = self.peaks.peak_num_act_type1[stage] + coe_a[stage][8] = 
self.peaks.peak_num_recompute_type1[stage] + array_b[stage] = (self.cur_peak_mem[stage] - self.peaks.peak_num_select_recom_type2[stage] * + self.init_memory.select_mem - + self.peaks.peak_num_select_recom_type1 * self.init_memory.select_mem12) + continue + if self.init_config.expert_input.is_select_recomp: + if length_x == 10: + coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage] + coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage] + coe_a[stage][6] = self.cur_solution.layer1_dis_stage[stage] + coe_a[stage][7] = self.peaks.peak_num_act_type1[stage] + coe_a[stage][8] = self.peaks.peak_num_recompute_type1[stage] + coe_a[stage][9] = self.peaks.peak_num_select_recom_type2[stage] + array_b[stage] = (self.cur_peak_mem[stage] - + self.peaks.peak_num_select_recom_type1 * self.init_memory.select_mem12) + continue + if length_x >= 11: + coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage] + coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage] + coe_a[stage][6] = self.cur_solution.layer1_dis_stage[stage] + coe_a[stage][7] = self.peaks.peak_num_act_type1[stage] + coe_a[stage][8] = self.peaks.peak_num_recompute_type1[stage] + coe_a[stage][9] = self.peaks.peak_num_select_recom_type2[stage] + coe_a[stage][10] = self.peaks.peak_num_select_recom_type1[stage] + array_b[stage] = self.cur_peak_mem[stage] + continue + else: + if length_x > 9: + coe_a[stage][4] = self.cur_solution.layer2_dis_stage[stage] + coe_a[stage][5] = self.peaks.peak_num_recompute_type2[stage] + coe_a[stage][6] = self.cur_solution.layer1_dis_stage[stage] + coe_a[stage][7] = self.peaks.peak_num_act_type1[stage] + coe_a[stage][8] = self.peaks.peak_num_recompute_type1[stage] + array_b[stage] = self.cur_peak_mem[stage] + continue + return coe_a, array_b + + def is_over_mem(self): + over_mem = self.get_over_mem(self.cur_solution) + if all(over_mem[stage] == 0 for stage in range(len(over_mem))): + if dryrun.DryRun.is_write_to_file: + if dryrun.DryRun.config_file_type == 0: + recompute_config = pp_util.build_recompute_config(True, True, self.cur_solution. + rs_dis.tolist(), self.cur_solution.ra_dis.tolist()) + pp_util.write_config_to_yaml(recompute_config, self.cur_solution.offset.tolist(), self.init_config. + config_file) + elif dryrun.DryRun.config_file_type == 1: + pp_util.write_config_to_shell(self.cur_solution.offset.tolist(), self.init_config. + config_file) + else: + raise TypeError(dryrun.dryrun_config_error) + logger.info(f'The result is available for training, config has write to ' + f'{self.init_config.config_file}!') + return over_mem, False + + return over_mem, True + + def reduce_mem_lim_for_fitting(self, over_mem, i): + if over_mem.keys().__contains__(0): + self.init_config.memory.mem_lim_stage0 -= over_mem[0] * (i + 1) + if over_mem.keys().__contains__(self.init_config.pipeline_stage - 1): + self.init_config.memory.mem_lim_last -= over_mem[self.init_config.pipeline_stage - 1] * (i + 1) + over_mem = {key: value for key, value in over_mem.items() if key != 0 and + key != self.init_config.pipeline_stage - 1} + if over_mem: + self.init_config.memory.mem_lim_others -= max(over_mem.values()) * (i + 1) + + def set_peak_mem_by_dryrun(self, cur_solution: solution.Solution): + init_config = cur_solution.init_config + expert_input = init_config.expert_input + dry_run = dryrun.DryRun(expert_input.config_file, expert_input.ms_adapter_file, + expert_input.double_check_dryrun_filename) + recompute_config = pp_util.build_recompute_config(True, True, cur_solution.rs_dis. 
+ tolist(), cur_solution.ra_dis.tolist()) + total_number = init_config.num_layers_type1 + init_config.num_layers_type2 + dry_run.start_dryrun(recompute_config, cur_solution.offset.tolist(), total_number, + init_config.pp_interleave_num, + init_config.pipeline_stage, init_config.rank_size, init_config.num_layers_type1, + init_config.micro_batch_num) + cur_solution.check_peak_mem = dry_run.extract_memory_info_act(init_config.pipeline_stage) + self.cur_peak_mem = cur_solution.check_peak_mem + logger.info(f'check peak_mem = {cur_solution.check_peak_mem}') + + def get_over_mem(self, cur_solution: solution.Solution): + self.set_peak_mem_by_dryrun(cur_solution) + over_mem = {} + init_config = cur_solution.init_config + for stage in range(init_config.pipeline_stage): + if cur_solution.check_peak_mem[stage] - init_config.mem_lim > 0: + logger.warning(f'The stage{stage} result is over memory limit! please check the offset!') + over_mem[stage] = cur_solution.check_peak_mem[stage] - init_config.mem_lim + else: + over_mem[stage] = 0 + return over_mem diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/math_model.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/math_model.py new file mode 100644 index 0000000000000000000000000000000000000000..57b3db0e342641c168e88197a489f473e69100a9 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/math_model.py @@ -0,0 +1,319 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""math model construct for pp""" +import numpy as np +import sys +from ortools.linear_solver import pywraplp +from fast_tuner.utils.logger import logger + +from fast_tuner.pipeline_conductor import pp_util +from fast_tuner.pipeline_conductor.start_service import InitConfig, HIGHS_NAME +from fast_tuner.pipeline_conductor import micro + + +class Model: + def __init__(self, model_input: InitConfig): + # 初始化输入 + self.input = model_input + self.memory = model_input.memory + self.expert_input = model_input.expert_input + # 前后计算层数、时间 + self.x_type1 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + self.x_type2 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + self.offset = None + self.f_duration = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + self.b_duration = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + # 当前stage,interleave是否存在某一个种type的层 + self.indicator_type1 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + self.indicator_type2 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + # 选择重计算、完全重计算、内存 + self.rs_type1 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + self.rs_type2 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + self.ra_type1 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + self.ra_type2 = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + self.rs_or_ra = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + self.mem = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num), dtype=object) + # 前后向开始时间、micro batch在各个part中分布 + magic_number = 5 + if self.input.micro_batch_num // self.input.pipeline_stage >= magic_number + 1: + dummy_mbn = self.input.micro_batch_num % self.input.pipeline_stage + magic_number * self.input.pipeline_stage + self.residue = self.input.micro_batch_num // self.input.pipeline_stage - magic_number + temp_parts = magic_number + self.input.parts = magic_number + else: + dummy_mbn = self.input.micro_batch_num + self.residue = 0 + temp_parts = model_input.parts + self.input.parts = temp_parts + self.forward_s = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num, self.input.parts, + dummy_mbn, self.input.seq_splits), dtype=object) + self.backward_s = np.empty((self.input.pipeline_stage, self.input.pp_interleave_num, self.input.parts, + dummy_mbn, self.input.seq_splits), dtype=object) + self.distribution = pp_util.construct_distribution(dummy_mbn, self.input.pipeline_stage) + # 形成前后向任务列表 + self.sort_micro = micro.SortMicro(temp_parts, model_input.pp_interleave_num, model_input.pipeline_stage, + self.distribution, self.expert_input.low_mem, model_input.seq_splits) + self.final_orders = self.sort_micro.final_orders + # 求解初始化 + self.solver = pywraplp.Solver.CreateSolver(model_input.expert_input.solver_name) + if not self.solver: + self.solver = pywraplp.Solver.CreateSolver(HIGHS_NAME) + self.layer_upper = 20 + self.b_duration_upper = self.layer_upper * 3 + self.mem_upper = 60000 + self.time_upper = 99999 + self.min_time = None + self.gap = None + self.model_status = None + self.solution_status = None + self.run_time = None + + def define_variables(self): + self.stable_dur = 
self.solver.NumVar(0, self.time_upper, f'stable_duration') + for stage in range(self.input.pipeline_stage): + for vpp in range(self.input.pp_interleave_num): + self.x_type1[stage][vpp] = self.solver.IntVar(0, self.layer_upper, f'x_type1_{stage}_{vpp}') + self.x_type2[stage][vpp] = self.solver.IntVar(0, self.layer_upper, f'x_type2_{stage}_{vpp}') + self.indicator_type1[stage][vpp] = self.solver.BoolVar(f'indicator1_{stage}_{vpp}') + self.indicator_type2[stage][vpp] = self.solver.BoolVar(f'indicator2_{stage}_{vpp}') + self.f_duration[stage][vpp] = self.solver.NumVar(0, self.layer_upper, f'f_duration_{stage}_{vpp}') + self.b_duration[stage][vpp] = self.solver.NumVar(0, self.b_duration_upper, f'b_duration_{stage}_{vpp}') + if self.input.expert_input.is_select_recomp: + self.rs_type1[stage][vpp] = self.solver.IntVar( + 0, self.layer_upper, f'rs_type1_{stage}_{vpp}') + self.rs_type2[stage][vpp] = self.solver.IntVar( + 0, self.layer_upper, f'rs_type2_{stage}_{vpp}') + else: + self.rs_type1[stage][vpp] = self.solver.IntVar(0, 0, f'rs_type1_{stage}_{vpp}') + self.rs_type2[stage][vpp] = self.solver.IntVar(0, 0, f'rs_type2_{stage}_{vpp}') + if self.input.expert_input.is_full_recomp: + self.ra_type1[stage][vpp] = self.solver.IntVar(0, self.layer_upper, f'ra_type1_{stage}_{vpp}') + self.ra_type2[stage][vpp] = self.solver.IntVar(0, self.layer_upper, f'ra_type2_{stage}_{vpp}') + else: + self.ra_type1[stage][vpp] = self.solver.IntVar(0, 0, f'ra_type1_{stage}_{vpp}') + self.ra_type2[stage][vpp] = self.solver.IntVar(0, 0, f'ra_type2_{stage}_{vpp}') + if not self.expert_input.is_support_ra_with_rs: + self.rs_or_ra[stage][vpp] = self.solver.IntVar(0, 1, f'rs_or_ra_{stage}_{vpp}') + self.mem[stage][vpp] = self.solver.NumVar(0, self.mem_upper, f'mem_{stage}_{vpp}') + + for stage in range(self.input.pipeline_stage): + for vpp in range(self.input.pp_interleave_num): + for part in range(self.input.parts): + for micro_id in range(self.distribution[part]): + for split in range(self.input.seq_splits): + self.forward_s[stage][vpp][part][micro_id][split] = ( + self.solver.NumVar(0, self.time_upper, + f'forward_s_{stage}_{vpp}_{part}_{micro_id}_{split}')) + self.backward_s[stage][vpp][part][micro_id][split] = ( + self.solver.NumVar(0, self.time_upper, + f'backward_s{stage}_{vpp}_{part}_{micro_id}_{split}')) + logger.info(f'Number of variables = {self.solver.NumVariables()}') + + def define_constraint(self): + # 先type1后type2的层分配约束 + self.layer_alloc_constraints() + + # 前后向计算时间约束 + self.forward_backward_time_constraints() + + # 同stage micro-batch之间的约束 + self.same_stage_micro_constraints() + # 同micro-batch,stage间的约束 + self.same_micro_stage_constraints() + + # 内存约束 + self.mem_constraints() + # layer约束 + self.layers_constraints() + if self.input.parts > 1: + for s in range(self.input.pipeline_stage): + self.solver.Add(self.stable_dur >= self.forward_s[s][0][-1][0][0] - self.forward_s[s][0][-2][0][0]) + logger.info(f"Number of constraints = {self.solver.NumConstraints()}") + + def layer_alloc_constraints(self): + for vpp in range(self.input.pp_interleave_num): + for stage in range(self.input.pipeline_stage): + self.solver.Add(self.x_type1[stage][vpp] <= self.layer_upper * self.indicator_type1[stage][vpp]) + self.solver.Add(self.x_type1[stage][vpp] >= self.indicator_type1[stage][vpp]) + self.solver.Add(self.x_type2[stage][vpp] <= self.layer_upper * self.indicator_type2[stage][vpp]) + self.solver.Add(self.x_type2[stage][vpp] >= self.indicator_type2[stage][vpp]) + if stage < self.input.pipeline_stage - 1: + 
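+                    # presence indicators must be monotone across consecutive stages: type1 layers
+                    # (the leading dense layers) may only occupy a prefix of the pipeline, type2 layers a suffix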
self.solver.Add(self.indicator_type1[stage][vpp] >= self.indicator_type1[stage + 1][vpp]) + self.solver.Add(self.indicator_type2[stage][vpp] <= self.indicator_type2[stage + 1][vpp]) + if vpp < self.input.pp_interleave_num - 1: + self.solver.Add( + self.indicator_type1[self.input.pipeline_stage - 1][vpp] >= self.indicator_type1[0][vpp + 1]) + self.solver.Add( + self.indicator_type2[self.input.pipeline_stage - 1][vpp] <= self.indicator_type2[0][vpp + 1]) + + def forward_backward_time_constraints(self): + if self.input.expert_input.is_head_loss_input: + head_loss = self.input.expert_input.head_loss + else: + try: + head_loss = ((self.input.vocab_size / 2 / self.input.hidden_size) / + ( + 1 + 1 + 3 * self.input.intermediate_size / 2 / self.input.hidden_size + self.input.seq_length + / self.input.hidden_size) * 1.6) + except TypeError: + head_loss = 1 + for stage in range(self.input.pipeline_stage): + for vpp in range(self.input.pp_interleave_num): + # self.expert_input.layer_ratio is an integer now (used to be a list of integers) + if stage == self.input.pipeline_stage - 1 and vpp == self.input.pp_interleave_num - 1: + self.solver.Add(self.f_duration[stage][vpp] == self.x_type2[stage][vpp] + + self.x_type1[stage][vpp] * self.expert_input.layer_ratio + head_loss) + else: + self.solver.Add(self.f_duration[stage][vpp] == self.x_type2[stage][vpp] + + self.x_type1[stage][vpp] * self.expert_input.layer_ratio) + self.solver.Add(self.b_duration[stage][vpp] == self.expert_input.backward_ratio * + self.f_duration[stage][vpp] + self.expert_input.srRatio * + (self.rs_type2[stage][vpp] + self.rs_type1[stage][vpp] * self.expert_input.layer_ratio) + + self.expert_input.recompute_ratio * self.ra_type2[stage][vpp] + + self.ra_type1[stage][vpp] * self.expert_input.layer_ratio) + + def same_stage_micro_constraints(self): + for stage in range(self.input.pipeline_stage): + stage_order = self.final_orders[stage] + for i in range(len(stage_order) - 1): + p0, vpp0, state0, id0, split0 = (stage_order[i].part, stage_order[i].vpp, stage_order[i].state, + stage_order[i].micro_id, stage_order[i].split) + p1, vpp1, state1, id1, split1 = (stage_order[i + 1].part, stage_order[i + 1].vpp, + stage_order[i + 1].state, stage_order[i + 1].micro_id, + stage_order[i].split) + if state0 == 'f': + if state1 == 'f': + self.solver.Add(self.forward_s[stage][vpp0][p0][id0][split0] + self.f_duration[stage][vpp0] / + self.input.seq_splits <= self.forward_s[stage][vpp1][p1][id1][split1]) + else: + self.solver.Add(self.forward_s[stage][vpp0][p0][id0][split0] + self.f_duration[stage][vpp0] / + self.input.seq_splits <= self.backward_s[stage][vpp1][p1][id1][split1]) + else: + if state1 == 'f': + self.solver.Add(self.backward_s[stage][vpp0][p0][id0][split0] + self.b_duration[stage][vpp0] / + self.input.seq_splits <= self.forward_s[stage][vpp1][p1][id1][split1]) + else: + self.solver.Add(self.backward_s[stage][vpp0][p0][id0][split0] + self.b_duration[stage][vpp0] / + self.input.seq_splits <= self.backward_s[stage][vpp1][p1][id1][split1]) + + def same_micro_stage_constraints(self): + for part in range(self.input.parts): + for micro_id in range(self.distribution[part]): + for split in range(self.input.seq_splits): + for vpp in range(self.input.pp_interleave_num): + for stage in range(self.input.pipeline_stage): + # 前向: + if stage != self.input.pipeline_stage - 1: + self.solver.Add( + self.forward_s[stage][vpp][part][micro_id][split] + self.f_duration[stage][vpp] / + self.input.seq_splits <= self.forward_s[stage + 1][vpp][part][micro_id][split]) + 
elif vpp != self.input.pp_interleave_num - 1: + self.solver.Add( + self.forward_s[stage][vpp][part][micro_id][split] + self.f_duration[stage][vpp] / + self.input.seq_splits <= self.forward_s[0][vpp + 1][part][micro_id][split]) + else: + self.solver.Add( + self.forward_s[stage][vpp][part][micro_id][split] + self.f_duration[stage][vpp] / + self.input.seq_splits <= self.backward_s[stage][vpp][part][micro_id][split]) + # 后向: + if stage != 0: + self.solver.Add( + self.backward_s[stage][vpp][part][micro_id][split] + self.b_duration[stage][vpp] / + self.input.seq_splits <= self.backward_s[stage - 1][vpp][part][micro_id][split]) + else: + if vpp != 0: + self.solver.Add( + self.backward_s[stage][vpp][part][micro_id][split] + self.b_duration[stage][vpp] + / self.input.seq_splits <= self.backward_s[ + self.input.pipeline_stage - 1][vpp - 1][part][micro_id][split]) + + def layers_constraints(self): + layers_type1 = 0 + layers_type2 = 0 + indicator_total = 0 + for stage in range(self.input.pipeline_stage): + for vpp in range(self.input.pp_interleave_num): + layers_type1 += self.x_type1[stage][vpp] + layers_type2 += self.x_type2[stage][vpp] + indicator_total += self.indicator_type1[stage][vpp] + self.indicator_type2[stage][vpp] + self.solver.Add(self.x_type1[stage][vpp] >= self.ra_type1[stage][vpp] + self.rs_type1[stage][vpp]) + self.solver.Add(self.x_type2[stage][vpp] >= self.ra_type2[stage][vpp] + self.rs_type2[stage][vpp]) + if not self.expert_input.is_support_ra_with_rs: + self.solver.Add(self.rs_type1[stage][vpp] <= self.layer_upper * self.rs_or_ra[stage][vpp]) + self.solver.Add(self.rs_type2[stage][vpp] <= self.layer_upper * self.rs_or_ra[stage][vpp]) + self.solver.Add(self.ra_type1[stage][vpp] <= self.layer_upper * (1 - self.rs_or_ra[stage][vpp])) + self.solver.Add(self.ra_type2[stage][vpp] <= self.layer_upper * (1 - self.rs_or_ra[stage][vpp])) + self.solver.Add(layers_type1 == self.input.num_layers_type1) + self.solver.Add(layers_type2 == self.input.num_layers_type2) + self.solver.Add(indicator_total >= self.input.pipeline_stage * self.input.pp_interleave_num) + self.solver.Add(indicator_total <= self.input.pipeline_stage * self.input.pp_interleave_num + 1) + + def mem_constraints(self): + for stage in range(self.input.pipeline_stage): + for vpp in range(self.input.pp_interleave_num): + self.solver.Add(self.mem[stage][vpp] == self.memory.act_mem * + (self.x_type2[stage][vpp] - self.rs_type2[stage][vpp] - self.ra_type2[stage][vpp]) + + self.memory.select_mem * self.rs_type2[stage][vpp] + + self.memory.re_comp_mem * self.ra_type2[stage][vpp] + + self.memory.act_mem0 * (self.x_type1[stage][vpp] - self.rs_type1[stage][vpp] - + self.ra_type1[stage][vpp]) + self.memory.select_mem0 * + self.rs_type1[stage][vpp] + self.memory.re_comp_mem0 * self.ra_type1[stage][vpp]) + for stage in range(self.input.pipeline_stage): + total_layer_mem = 0 + for vpp in range(self.input.pp_interleave_num): + total_layer_mem += (self.x_type1[stage][vpp] * self.memory.layer_mem012 + self.x_type2[stage][vpp] * + self.memory.layer_mem) + for sub in range(1, len(self.final_orders[stage]) + 1): + consume_mem = total_layer_mem + for i in range(sub): + part, vpp, state, micro_id, split = (self.final_orders[stage][i].part, + self.final_orders[stage][i].vpp, + self.final_orders[stage][i].state, + self.final_orders[stage][i].micro_id, + self.final_orders[stage][i].split) + # 计算时间均分,因此这里激活内存做均分处理 + if state == 'f': + consume_mem += self.mem[stage][vpp] / self.input.seq_splits + else: + consume_mem -= self.mem[stage][vpp] / 
self.input.seq_splits + if stage == 0: + self.solver.Add(consume_mem <= self.memory.mem_lim_stage0) + elif stage == self.input.pipeline_stage - 1: + self.solver.Add(consume_mem <= self.memory.mem_lim_last) + else: + self.solver.Add(consume_mem <= self.memory.mem_lim_others) + + def define_obj(self): + self.solver.Minimize(self.backward_s[0][0][-1][self.distribution[-1] - 1][0] + self.b_duration[0][0] / + self.input.seq_splits + self.residue * self.stable_dur) + + def output_model(self, mps_file): + with open(mps_file, 'w', encoding='utf-8') as file: + mps_text = self.solver.ExportModelAsMpsFormat(False, False) + file.write(mps_text) + logger.info(f'Had write to file: {mps_file}') + + def solve(self): + self.solver.EnableOutput() + logger.info(f'Solving with {self.solver.SolverVersion()}') + if self.expert_input.time_limit != sys.maxsize: + self.solver.SetTimeLimit(self.expert_input.time_limit) + self.solver.Solve() + self.min_time = self.solver.Objective().Value() + logger.info(f'The objective value = {self.min_time}') + + +if __name__ == '__main__': + yaml_file = 'C:\\working\\768_4k.yaml' diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/memory.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..b0324ed0a301fef28a130139dbc5c3032fb31d04 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/memory.py @@ -0,0 +1,74 @@ +from fast_tuner.utils.logger import logger +"""memory model""" + +class Memory: + select_mem0 = 76 + select_mem12 = 76 + select_mem = 77 + re_comp_mem0 = 3 + re_comp_mem12 = 3 + re_comp_mem = 5 + act_mem0 = 79 + act_mem12 = 79 + act_mem = 79 + layer_mem012 = 349 + layer_mem = 340 + static_mem0 = 734 + static_mem = 116 + lm_head_mem = 690 + + def __init__(self, mem_lim): + self.mem_lim = mem_lim + self.mem_lim_stage0 = mem_lim - self.static_mem0 + self.mem_lim_others = mem_lim - self.static_mem + self.mem_lim_last = mem_lim - self.lm_head_mem + + def update_up_mem(self): + self.mem_lim_stage0 = self.mem_lim - self.static_mem0 + self.mem_lim_others = self.mem_lim - self.static_mem + self.mem_lim_last = self.mem_lim - self.lm_head_mem + + def print_mem(self): + logger.info(f'select_mem0={self.select_mem0}, select_mem12={self.select_mem12}, select_mem={self.select_mem}, ' + f're_comp_mem0={self.re_comp_mem0}, re_comp_mem12={self.re_comp_mem12}, ' + f're_comp_mem={self.re_comp_mem}, ' + f'act_mem0={self.act_mem0}, act_mem12={self.act_mem12}, act_mem={self.act_mem}, ' + f'layer_mem012={self.layer_mem012}, layer_mem={self.layer_mem}, ' + f'static_mem0={self.static_mem0}, static_mem={self.static_mem}, ' + f'lm_head_mem={self.lm_head_mem}, mem_lim_stage0={self.mem_lim_stage0}, ' + f'mem_lim_others={self.mem_lim_others}, mem_lim_last={self.mem_lim_last}') + + def write_memory_to_file(self, mem_file): + with open(mem_file, 'w', encoding='utf-8') as file: + file.write(f'select_mem0={self.select_mem0}\n') + file.write(f'select_mem12={self.select_mem12}\n') + file.write(f'select_mem={self.select_mem}\n') + file.write(f're_comp_mem0={self.re_comp_mem0}\n') + file.write(f're_comp_mem12={self.re_comp_mem12}\n') + file.write(f're_comp_mem={self.re_comp_mem}\n') + file.write(f'act_mem0={self.act_mem0}\n') + file.write(f'act_mem12={self.act_mem12}\n') + file.write(f'act_mem={self.act_mem}\n') + file.write(f'layer_mem012={self.layer_mem012}\n') + file.write(f'layer_mem={self.layer_mem}\n') + 
file.write(f'static_mem0={self.static_mem0}\n') + file.write(f'static_mem={self.static_mem}\n') + file.write(f'lm_head_mem={self.lm_head_mem}\n') + logger.info(f'Write memory info to {mem_file}') + + def get_mem(self): + mem = (f'select_mem0={self.select_mem0}\n' + f'select_mem12={self.select_mem12}\n' + f'select_mem={self.select_mem}\n' + f're_comp_mem0={self.re_comp_mem0}\n' + f're_comp_mem12={self.re_comp_mem12}\n' + f're_comp_mem={self.re_comp_mem}\n' + f'act_mem0={self.act_mem0}\n' + f'act_mem12={self.act_mem12}\n' + f'act_mem={self.act_mem}\n' + f'layer_mem012={self.layer_mem012}\n' + f'layer_mem={self.layer_mem}\n' + f'static_mem0={self.static_mem0}\n' + f'static_mem={self.static_mem}\n' + f'lm_head_mem={self.lm_head_mem}') + return mem diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/micro.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/micro.py new file mode 100644 index 0000000000000000000000000000000000000000..281878a807f112103c1f33bbc9ae3854792fbfb3 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/micro.py @@ -0,0 +1,149 @@ +from fast_tuner.pipeline_conductor.memory import Memory +"""micro proc for pp""" + +class Micro: + part = int + vpp = int + state = 'f' + micro_id = int + split = int + + def __init__(self, part, vpp, state, micro_id, split): + self.part = part + self.vpp = vpp + self.state = state + self.micro_id = micro_id + self.split = split + + +class SortMicro: + parts = int + num_vpp = int + num_stage = int + distribution = [] + low_mem = bool + seq_split = int + is_f_then_b = False + + def __init__(self, parts, num_vpp, num_stage, distribution, low_mem, seq_split): + self.forward = [] + self.backward = [] + self.warmup_num = [] + self.final_orders = [] + self.parts = parts + self.num_vpp = num_vpp + self.num_stage = num_stage + self.distribution = distribution + self.low_mem = low_mem + self.seq_split = seq_split + self.build_f_b_sort() + self.set_warmup_num() + self.set_micro_sort() + + def build_f_b_sort(self): + for part in range(self.parts): + for vpp in range(self.num_vpp): + for micro_id in range(self.distribution[part]): + for split in range(self.seq_split): + micro = Micro(part, vpp, 'f', micro_id, split) + self.forward.append(micro) + for vpp in range(self.num_vpp - 1, -1, -1): + for micro_id in range(self.distribution[part]): + for split in range(self.seq_split - 1, -1, -1): + micro = Micro(part, vpp, 'b', micro_id, split) + self.backward.append(micro) + + def set_warmup_num(self): + for stage in range(self.num_stage): + if self.low_mem: + warmup = min(((self.num_vpp - 1) * self.distribution[0] + (self.num_stage - stage - 1)) * + self.seq_split, len(self.forward)) + else: + warmup = min(((self.num_vpp - 1) * self.distribution[0] + (self.num_stage - stage - 1) * 2) * + self.seq_split, len(self.forward)) + # 最后一个stage,第一个micro前向做完之后才能做后向 + if stage == self.num_stage - 1: + warmup = warmup + self.seq_split - 1 + self.warmup_num.append(warmup) + + def set_micro_sort(self): + for stage in range(self.num_stage): + stage_order = [] + stage_order += self.forward[: self.warmup_num[stage]] + for i in range(self.warmup_num[stage], len(self.forward)): + stage_order.append(self.forward[i]) + stage_order.append(self.backward[i - self.warmup_num[stage]]) + stage_order += self.backward[len(self.forward) - self.warmup_num[stage]:] + self.final_orders.append(stage_order) + + +class PeakNum: + sort_micro = SortMicro + + def __init__(self, sort_micro: SortMicro): + 
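+        # Per-stage bookkeeping filled by set_peak_act_recompute_num: how many activation,
+        # full-recompute and selective-recompute copies are live when each stage hits its memory peak.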
self.peak_num_recompute_type1 = {} + self.peak_num_recompute_type2 = {} + self.peak_num_select_recom_type1 = {} + self.peak_num_select_recom_type2 = {} + self.peak_num_act_type1 = {} + self.peak_num_act_type2 = {} + self.max_mem = {} + self.micro_num_of_max_mem = {} + self.sort_micro = sort_micro + self.num_stage = sort_micro.num_stage + self.num_vpp = sort_micro.num_vpp + + def set_peak_act_recompute_num(self, x_type2, rs_type2, ra_type2, x_type1, rs_type1, ra_type1, memory: Memory): + for stage in range(self.sort_micro.num_stage): + # 各个micro处的内存及激活、重计算份数 + self.max_mem[stage] = 0 + if stage == 0: + static_mem = memory.static_mem0 + elif stage == self.sort_micro.num_stage - 1: + static_mem = memory.lm_head_mem + else: + static_mem = memory.static_mem + layer_mem = (memory.layer_mem012 * sum(x_type1[vpp][stage] for vpp in range(self.num_vpp)) + memory. + layer_mem * sum(x_type2[vpp][stage] for vpp in range(self.num_vpp))) + mem_stage = static_mem + layer_mem + num_recom_type1_stage = 0 + num_recom_type2_stage = 0 + num_select_recom_type1_stage = 0 + num_select_recom_type2_stage = 0 + num_act_type1_stage = 0 + num_act_type2_stage = 0 + for i in range(len(self.sort_micro.final_orders[stage])): + micro_batch = self.sort_micro.final_orders[stage][i] + vpp = micro_batch.vpp + act_num_type1 = x_type1[vpp][stage] - rs_type1[vpp][stage] - ra_type1[vpp][stage] + act_num_type2 = x_type2[vpp][stage] - rs_type2[vpp][stage] - ra_type2[vpp][stage] + act_mem = memory.act_mem12 * act_num_type1 + memory.act_mem * act_num_type2 + ra_mem = memory.re_comp_mem12 * ra_type1[vpp][stage] + memory.re_comp_mem * ra_type2[vpp][stage] + rs_mem = memory.select_mem12 * rs_type1[vpp][stage] + memory.select_mem * rs_type2[vpp][stage] + # 计算时间切成seq_split份,可以看成数据切成seq_split份,即动态内存切分;layer_mem与static_mem不变 + total_mem = (act_mem + ra_mem + rs_mem) / self.sort_micro.seq_split + if micro_batch.state == 'f': + mem_stage += total_mem + num_recom_type1_stage += ra_type1[vpp][stage] + num_recom_type2_stage += ra_type2[vpp][stage] + num_select_recom_type1_stage += rs_type1[vpp][stage] + num_select_recom_type2_stage += rs_type2[vpp][stage] + num_act_type1_stage += act_num_type1 + num_act_type2_stage += act_num_type2 + else: + mem_stage -= total_mem + num_recom_type1_stage -= ra_type1[vpp][stage] + num_recom_type2_stage -= ra_type2[vpp][stage] + num_select_recom_type1_stage -= rs_type1[vpp][stage] + num_select_recom_type2_stage -= rs_type2[vpp][stage] + num_act_type1_stage -= act_num_type1 + num_act_type2_stage -= act_num_type2 + if mem_stage > self.max_mem[stage]: + self.max_mem[stage] = mem_stage + self.peak_num_recompute_type1[stage] = num_recom_type1_stage + self.peak_num_recompute_type2[stage] = num_recom_type2_stage + self.peak_num_select_recom_type1[stage] = num_select_recom_type1_stage + self.peak_num_select_recom_type2[stage] = num_select_recom_type2_stage + self.peak_num_act_type1[stage] = num_act_type1_stage + self.peak_num_act_type2[stage] = num_act_type2_stage + self.micro_num_of_max_mem[stage] = i + 1 diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pipeline_parallel.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pipeline_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..762632687be3a7f7202b438641506b9777972052 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pipeline_parallel.py @@ -0,0 +1,232 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""pp solver entry file""" +import sys +import os.path +import argparse +from fast_tuner.utils.logger import logger +from fast_tuner.pipeline_conductor import math_model, pp_util +from fast_tuner.pipeline_conductor.start_service import InitConfig, ExpertInput, HIGHS_NAME, pipeline_output_file +from fast_tuner.pipeline_conductor import solution +from fast_tuner.pipeline_conductor import fitting +from fast_tuner.utils.ppc_input import ParallelInput +from fast_tuner.pipeline_conductor.result_csv import ResultCsv +from fast_tuner.pipeline_conductor.dryrun import DryRun, dryrun_config_error + +from fast_tuner.utils.common import check_dryrun_parallel_number + +mps_dir_name = 'model_mps' +sol_dir_name = 'sol_output' + + +def pp_calculator(expert_input: ExpertInput, model_args =None) -> solution.Solution: + init_config = InitConfig(expert_input) + cur_solution = solve_problem(init_config) + if cur_solution.solution_status == 'None': + return cur_solution + if expert_input.is_double_check: + fit_mem = fitting.FitMem(cur_solution) + over_mem, is_over_mem = fit_mem.is_over_mem() + if not is_over_mem: + return cur_solution + if expert_input.fit_level == 0: + i = 0 + while i < 5 and is_over_mem: + fit_mem.reduce_mem_lim_for_fitting(over_mem, i) + logger.info(f'Correct the memory at the {i + 1} time!') + init_config.memory.print_mem() + cur_solution = solve_problem(init_config) + if cur_solution.solution_status == 'None': + logger.info('dryrun_check result is over memory limit') + return cur_solution + fit_mem = fitting.FitMem(cur_solution) + over_mem, is_over_mem = fit_mem.is_over_mem() + i += 1 + else: + fit_mem.linear_fit() + cur_solution = solve_problem(init_config) + return cur_solution + + +def pipeline_proc(pipeline_input: ParallelInput): + if len(pipeline_input.candidate_configs) == 0: + raise ValueError('There is no candidate configs!') + is_low_mem = pipeline_input.is_lowmem + solver_name = pipeline_input.solver_name + ms_adapter_file = pipeline_input.ms_adapter_file + DryRun.env_config_json = pipeline_input.env_config_json + DryRun.dryrun_lim = pipeline_input.dryrun_lim + ExpertInput.is_dryrun = pipeline_input.dryrun + ExpertInput.is_double_check = pipeline_input.check + result_csv = ResultCsv(pipeline_input.output_path, pipeline_output_file) + num_all = len(pipeline_input.candidate_configs) + num_cur = 0 + for candidate in pipeline_input.candidate_configs: + result_csv.config_to_csv(candidate, is_low_mem, solver_name) + for candidate in pipeline_input.candidate_configs: + candidate_input = ExpertInput(candidate.config_path, ms_adapter_file) + candidate_input.model_args = candidate.model_args + candidate_input.low_mem = is_low_mem + candidate_input.solver_name = solver_name + candidate_input.layer_ratio = candidate.profiling_info.dmratio + candidate_input.backward_ratio = candidate.profiling_info.bfratio + candidate_input.head_loss = candidate.profiling_info.hratio + candidate_input.recompute_ratio = 
candidate.profiling_info.re_grow_ration
+        num_cur += 1
+        logger.info(f'---------------------- Testing {num_cur}/{num_all}:{candidate.config_path} ----------------------')
+        try:
+            cur_solution = pp_calculator(candidate_input)
+            result_csv.result_to_csv(cur_solution)
+        except Exception as e:
+            logger.error(f'{candidate.config_path} error: {e}. Continue to the next one')
+
+
+def solve_problem(init_config: InitConfig):
+    origin_model = math_model.Model(init_config)
+    origin_model.define_variables()
+    origin_model.define_constraint()
+    origin_model.define_obj()
+    cur_solution = solution.Solution(init_config)
+
+    # output mps
+    mps_dir = os.path.join(init_config.expert_input.output_file_dir, mps_dir_name)
+    if not os.path.exists(mps_dir):
+        os.mkdir(mps_dir)
+    mps_file = os.path.join(mps_dir, init_config.mps_sol_filename + '.mps')
+    origin_model.output_model(mps_file)
+
+    # solve
+    sol_dir = os.path.join(init_config.expert_input.output_file_dir, sol_dir_name)
+    if not os.path.exists(sol_dir):
+        os.mkdir(sol_dir)
+    sol_file = os.path.join(sol_dir, init_config.mps_sol_filename + '.sol')
+    if not os.path.exists(mps_file):
+        logger.error('build model error!')
+    if init_config.expert_input.solver_name == HIGHS_NAME:
+        is_origin_solver = False
+        pp_util.highs_solve_mps(mps_file, sol_file, origin_model, init_config.expert_input.time_limit)
+    elif init_config.expert_input.solver_name == 'QIUQI':
+        is_origin_solver = False
+        solver_file = '/home/zhugelu/MIXSolver/bin/MIXSolver'  # change this to your local solver path
+        pp_util.qiuqi_solver_mps(solver_file, mps_file, sol_file, origin_model)
+    else:
+        is_origin_solver = True
+        origin_model.solve()
+
+    cur_solution.set_solution(origin_model, is_origin_solver, sol_file)
+    if cur_solution.solution_status != 'None':
+        cur_solution.solution_print()
+    return cur_solution
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(prog='TAYLOR AutoBalancing', description=(
+        'Balance layers onto pipeline stages, ' +
+        'considering recomputation and interleaving'),
+                                     epilog='')
+    # LLM model class
+    parser.add_argument('-llm', '--llm_class', type=int, default=0,
+                        help="0-deepseek, 1-boss")
+    # Training Yaml configuration
+    parser.add_argument('-yaml', '--train_yaml', type=str, default=None,
+                        help="Path of training config (.yaml)")
+    parser.add_argument('-shell', '--train_shell', type=str, default=None,
+                        help="Path of training config (.sh)")
+    # mindformers location
+    parser.add_argument('-mindformers', '--mindformers_loc', type=str, default=None,
+                        help="Absolute path of run_mindformers (.py)")
+    parser.add_argument('-mindspeed', '--mindspeed_loc', type=str, default=None,
+                        help="Absolute path of posttrain_gpt (.py)")
+    # solver name
+    parser.add_argument('-solver', '--solver_name', type=str, default=HIGHS_NAME,
+                        help="The solver name")
+    # layer ratio
+    parser.add_argument('-layer_ratio', '--layer_ratio', type=float, default=0.33,
+                        help="Time ratio of calculating of dense to moe")
+    # backward ratio
+    parser.add_argument('-b_ratio', '--backward_ratio', type=float, default=2.0,
+                        help="Time ratio of calculating of backward to forward")
+    # head_loss
+    parser.add_argument('-head_loss', '--head_loss', type=float, default=1.5,
+                        help="Extra time added for the head/loss layer on the last stage")
+    # recompute_ratio
+    parser.add_argument('-ra_ratio', '--recompute_ratio', type=float, default=1,
+                        help="Time ratio of recomputation to forward computing")
+    # Search time
+    parser.add_argument('-t', '--time_limit', type=int, default=sys.maxsize,
+                        help="Limitation on searching time")
+    # whether to dryrun automatically
+    parser.add_argument('-dryrun', '--dryrun', type=pp_util.str2bool, default=True,
+                        help="Is auto dryrun")
+    # whether to double check automatically
+    parser.add_argument('-check', '--check', type=pp_util.str2bool, default=True,
+                        help="Is double check")
+    parser.add_argument('-is_write', '--is_write', type=pp_util.str2bool, default=True,
+                        help="Whether to write the solution to the config file")
+    # fit level, 0: directly reduce the memory limit and re-solve when over memory;
+    # 1: fit the memory info by linear regression and re-solve when over memory
+    parser.add_argument('-fit', '--fit_level', type=int, default=0,
+                        help="Fit memory when the result is over the limit: 0-reduce the memory limit;"
+                             " 1 or >1-fit the memory info")
+    # whether to extract the solution info separately
+    parser.add_argument('-extract', '--extract', type=pp_util.str2bool, default=False,
+                        help="Extract solution file separately")
+    parser.add_argument('-solution', '--solution', default=None, help="The solution file")
+    # env_config_json
+    parser.add_argument('-env', '--env_config_json', type=str, required=True,
+                        default='./config/boss_env_config.json', help="Path of environment config (.json)")
+    parser.add_argument('-register', '--register_path', type=str, default='research/jiutian',
+                        help="Path of register")
+    parser.add_argument('-dryrun_lim', '--dryrun_lim', type=pp_util.str2int, default=16,
+                        help="The number of dryrun at once")
+
+    args = parser.parse_args()
+    check_dryrun_parallel_number(args.dryrun_lim)
+    if args.train_yaml and args.mindformers_loc:
+        config_file = args.train_yaml
+        ms_adapter_file = args.mindformers_loc
+        DryRun.config_file_type = 0
+        ExpertInput.is_full_recomp = True
+    elif args.train_shell and args.mindspeed_loc:
+        config_file = args.train_shell
+        ms_adapter_file = args.mindspeed_loc
+        DryRun.config_file_type = 1
+        ExpertInput.is_full_recomp = False
+    else:
+        raise TypeError(dryrun_config_error)
+
+    if args.extract:
+        solution.extract_solution_file(args.train_yaml, args.solution)
+        sys.exit(0)
+
+    expert_input = ExpertInput(config_file=config_file, ms_adapter_file=ms_adapter_file)
+    expert_input.solver_name = args.solver_name
+    expert_input.llm_class = int(args.llm_class)
+    expert_input.time_limit = int(args.time_limit)
+    if args.time_limit < sys.maxsize:
+        logger.warning(f'You have configured the time limit parameter! The solution may be not optimal!')
+    expert_input.is_dryrun = args.dryrun
+    expert_input.is_double_check = args.check
+    expert_input.fit_level = args.fit_level
+    expert_input.layer_ratio = args.layer_ratio
+    expert_input.backward_ratio = args.backward_ratio
+    expert_input.head_loss = args.head_loss
+    expert_input.recompute_ratio = args.recompute_ratio
+
+    DryRun.env_config_json = args.env_config_json
+    DryRun.register_path = args.register_path
+    DryRun.dryrun_lim = args.dryrun_lim
+    DryRun.is_write_to_file = args.is_write
+
+    pp_calculator(expert_input)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_simulator.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_simulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..675f8d7178627b5ed9771bfa8f1dc456e8e31353
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_simulator.py
@@ -0,0 +1,926 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""pp simulator""" +from __future__ import annotations +from collections.abc import Iterable +import copy +import sys +import time +from dataclasses import dataclass, field +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle, Polygon +from matplotlib import colors +from matplotlib.transforms import ScaledTranslation +sys.setrecursionlimit(8192) + +def format_2d_inputs(a, raw, col): + if isinstance(a, (int, float)): + return np.broadcast_to(a, (raw, col)) + if isinstance(a, (list, tuple)): + if all(isinstance(item, (list, tuple)) for item in a): + return np.array(a) + if all(isinstance(item, (int, float)) for item in a): + return np.array([a]) + raise ValueError(f"Unsupported inputs: {a}") + else: + raise ValueError(f"Unsupported inputs: {a}") + +def apply_color(l:list, c:list[str]): + for i in range(len(l)): + l[i] = f'{l[i]:.4f}' if isinstance(l[i], float) else l[i] + l[i] = f"\033[{c[i]}m{l[i]}\033[0m" + return l + +def apply_format(l:list): + s = f'{l[0]:^22}' + symbol = ['=', '+', '+', '+', '+', '+'] + for i in range(len(l) - 1): + s = f'{s}{symbol[i]}{l[i + 1]:^22}' + return s + +def color_mix(c1, c2, w1=0.5, w2=0.5): + rgb = (np.array(colors.to_rgba(c1, 1)) * w1 + np.array(colors.to_rgba(c2, 1)) * w2) / (w1 + w2) + return colors.to_rgba(rgb) + + +class CausalError(Exception): + def __init__(self, msg, blocks:list[list[MicroBlockSim]]=None, loop:list[BlockSim]=[]) -> None: + self.msg = msg + self.canvas = PlotMgr(num_plots=1, figsize=(12, 6)) + self.canvas.draw_loop(blocks, loop, 0, False, False, True) + self.canvas.ax[0].set_title("Block pipeline dependency") + super().__init__() + print(f"{self.canvas.msg}") + + def __str__(self): + plt.show() + return f"{self.msg}" + + +class CausalCommError(Exception): + def __init__(self, msg, blocks:list[list[MicroBlockSim]]=None, loop:list[BlockSim]=[]) -> None: + self.msg = msg + self.canvas = PlotMgr(num_plots=1, figsize=(12, 6)) + self.canvas.draw_comm_loop(blocks, loop, 0) + self.canvas.ax[0].set_title("Block comm pipeline dependency") + super().__init__() + print(f"{self.canvas.msg}") + + def __str__(self): + plt.show() + return f"{self.msg}" + +def dfs_builder(comm=False): + def decorator(func): + def wrapper(*args, **kwargs): + self = args[0] + pre, left = (self.depend_pre, self.depend_left) if comm else (self.pre, self.left) + if self.finish: + return + if pre is None or left is None: + raise NotImplementedError + if self.in_queue: + raise ValueError + self.in_queue = True + res = func(*args, **kwargs) + self.finish = True + self.in_queue = False + return res + return wrapper + return decorator + +def timer(func:function): + def wrapper(*args, **kwargs): + T0 = time.time() + res = func(*args, **kwargs) + T1 = time.time() - T0 + print(f"function `{func.__name__}` time used: {T1:.4f} s", flush=True) + return res + return wrapper + + +class PlotMgr: + def __init__(self, num_plots=2, ax_type='block', subplot_args=None, *args, **kwargs): + self.fig = plt.figure(figsize=kwargs.get('figsize', (12, 8))) + 
self.fig.subplots_adjust(wspace=0, hspace=0.4) + ax_type = ax_type if isinstance(ax_type, (list, tuple)) else [ax_type] * num_plots + self.ax = [] + for i in range(num_plots): + if subplot_args is None: + self.ax.append(self.fig.add_subplot(num_plots * 100 + 10 + i + 1)) + elif isinstance(subplot_args, Iterable) and len(subplot_args) >= num_plots: + self.ax.append(self.fig.add_subplot(subplot_args[i])) + else: + raise ValueError(f"Unsupported subplot_args format: {subplot_args}") + + def _set_block_ax(self, ax:plt.Axes, pp:int) -> plt.Axes: + ax.set_title("Pipeline Flow Timeline") + ax.set_yticks(range(pp), [f"stage {p}" for p in range(pp)]) + for tick in ax.get_yticklabels(): + tick.set_verticalalignment('top') + tick.set_transform(tick.get_transform() + ScaledTranslation(0, 0.05-1/pp, self.fig.dpi_scale_trans)) + tick.set_fontsize(12) + ax.set_ylim(0, pp) + ax.invert_yaxis() + + def _get_block_indices(self, blocks:list[list[MicroBlockSim]], mode='compact', equal_wide=False): + if mode not in ['compact', 'joint', 'timeline']: + raise ValueError(f"Get unsupported draw mode: {mode}") + if mode == 'timeline' and not blocks[-1][-1].finish: + raise ValueError(f"Block building should be finished before drawing timeline") + block_index = [] + for p in range(len(blocks)): + inds = [] + for block in blocks[p]: + if mode == 'compact': + if block._type == 'c': + inds.append(1 if equal_wide else block.time) + else: + inds.append(0) + elif mode == 'joint': + if block._type == 'c': + inds.append(1 if equal_wide else block.time) + else: + inds.append(block.time) + else: + inds.append(1) + inds.insert(0, 0) + inds = np.cumsum(inds) + block_index.append(inds) + return block_index + + def draw_block(self, block_index:list[list[float]], blocks:list[list[MicroBlockSim]], ax_index:int = 0, equal_wide=False, width=1, phase=False): + for p in range(len(blocks)): + for b, block in enumerate(blocks[p]): + if block._type == 'c': + block.draw(self.ax[ax_index], index=block_index[p][b], equal_wide=equal_wide, width=width, phase=phase) + return self + + def draw_comm(self, block_index:list[list[float]], blocks:list[list[MicroBlockSim]], ax_index:int = 0, equal_wide=False, mode='compact'): + for p in range(len(blocks)): + for b, block in enumerate(blocks[p]): + if block._type == 'c' and mode == 'compact': + if block.send_block: + block.send_block.draw(self.ax[ax_index], index=block_index[p][b], equal_wide=equal_wide) + if block.rec_block: + block.rec_block.draw(self.ax[ax_index], index=block_index[p][b], equal_wide=equal_wide) + elif block._type in ['s', 'r'] and mode in ['joint', 'timeline']: + block.draw(self.ax[ax_index], index=block_index[p][b], equal_wide=equal_wide, mode=mode) + return self + + def draw_connect(self, block_index:list[list[float]], blocks:list[list[MicroBlockSim]], ax_index:int = 0, equal_wide=False, mode='compact'): + for p in range(len(blocks)): + for b, block in enumerate(blocks[p]): + if block._type == 'c' and mode == 'compact' and block.send_block: + dual_p = block.send_block.dual._stage + dual_ind = blocks[dual_p].index(block.send_block.dual.host) + block.send_block.draw_comm(self.ax[ax_index], index_from=block_index[p][b], index_to=block_index[dual_p][dual_ind], equal_wide=equal_wide, mode=mode) + elif block._type == 's' and mode in ['joint', 'timeline']: + dual_p = block.dual._stage + dual_ind = blocks[dual_p].index(block.dual) + block.draw_comm(self.ax[ax_index], index_from=block_index[p][b], index_to=block_index[dual_p][dual_ind], equal_wide=equal_wide, mode=mode) + return self + + def 
draw(self, blocks:list[list[MicroBlockSim]], ax_index:int = 0, comm=False, connect=False, equal_wide=False, mode='compact', phase=False) -> PlotMgr: + pp = len(blocks) + block_index = self._get_block_indices(blocks, mode=mode, equal_wide=equal_wide) + width = max(np.max(block_index[p]) for p in range(pp)) if blocks[0][-1].end is None else max(blocks[p][-1].end for p in range(pp)) + self.draw_block(block_index, blocks, ax_index, equal_wide, width, phase=phase) + if comm: + self.draw_comm(block_index, blocks, ax_index, equal_wide, mode) + if connect: + self.draw_connect(block_index, blocks, ax_index, equal_wide, mode) + self._set_block_ax(self.ax[ax_index], pp) + self.ax[ax_index].set_xlim(0, width) + return self + + def draw_loop(self, blocks:list[list[MicroBlockSim]], loop:list[BlockSim], ax_index:int = 0, comm=False, connect=False, equal_wide=False) -> PlotMgr: + self.draw(blocks, ax_index, comm, connect, equal_wide, phase=True) + block_index = self._get_block_indices(blocks, equal_wide=equal_wide) + msg = 'dependency loop: ' + for b in range(len(loop) - 1): + p = loop[b]._stage + ind = blocks[p].index(loop[b]) + x1, y1, dx1, _ = loop[b].loc_size(block_index[p][ind], equal_wide) + p = loop[b + 1]._stage + ind = blocks[p].index(loop[b + 1]) + x2, y2, dx2, _ = loop[b + 1].loc_size(block_index[p][ind], equal_wide) + msg = f'{msg} {loop[b].color_label} -> ' + self.ax[ax_index].annotate(None, xy=(x1 + dx1 / 2, y1), xytext=(x2 + dx2 / 2, y2), + arrowprops=dict(fc='white', ec='r', arrowstyle='simple', shrinkA=5, shrinkB=5, connectionstyle="arc3,rad=-0.1")) + self.msg = f'{msg} {loop[len(loop) - 1].color_label}' + return self + + def draw_comm_loop(self, lines:list[list[BlockSim]], loop:list[BlockSim], ax_index:int = 0) -> PlotMgr: + self.draw(lines, ax_index, True, True, True, 'joint', phase=True) + block_index = self._get_block_indices(lines, mode='joint', equal_wide=True) + msg = 'dependency loop: ' + for b in range(len(loop) - 1): + p = loop[b]._stage + ind = lines[p].index(loop[b]) + x1, y1, dx1, _ = loop[b].loc_size(block_index[p][ind], True, 'joint') + p = loop[b + 1]._stage + ind = lines[p].index(loop[b + 1]) + x2, y2, dx2, _ = loop[b + 1].loc_size(block_index[p][ind], True, 'joint') + msg = f'{msg} {loop[b].color_label} -> ' + self.ax[ax_index].annotate(None, xy=(x1 + abs(dx1) / 2, y1), xytext=(x2 + abs(dx2) / 2, y2), size=10, + arrowprops=dict(fc='white', ec='r', arrowstyle='simple', shrinkA=3, shrinkB=3, connectionstyle="arc3,rad=-0.1", lw=0.8)) + self.msg = f'{msg} {loop[len(loop) - 1].color_label}' + return self + + def draw_mem(self, block_mem_list:list[np.ndarray], ax_index:int = 0) -> PlotMgr: + for p in range(len(block_mem_list)): + self.ax[ax_index].plot((block_mem_list[p].T)[0], (block_mem_list[p].T)[1], label=f"stage-{p}") + self.ax[ax_index].set_title("Block Memory Timeline") + self.ax[ax_index].set_xlim(0, max(np.max((block_mem.T)[0]) for block_mem in block_mem_list)) + + def draw_info(self, bubble_info:dict, mem_info:list): + info_list = [f'{k} bubble: {v:.4f}' for k, v in bubble_info.items()] + self.fig.text(0.5, 0.5, ', '.join(info_list), ha='center', va='center', + fontdict={'fontsize':13, 'weight':'medium'}, color='C3') + info_list = [f"{v:.2f}" for v in mem_info] + self.fig.text(0.5, 0.05, f"peak memory: {', '.join(info_list)}", ha='center', va='center', + fontdict={'fontsize':13, 'weight':'medium'}, color='C0') + + def show(self): + self.fig.legend(bbox_to_anchor=(0.22, 0.45)) + plt.show() + plt.savefig("figure.pdf") # 默认存储在当前目录 + + +@dataclass +class BlockSim: + 
_stage: int # p + _state: str # s + _id: int # m + _virtual: int # v + time: float + _type: str + start: float = None + end: float = None + pre: BlockSim = field(repr=False, default=None) + left: BlockSim = field(repr=False, default=None) + right: MicroBlockSim = field(repr=False, default=None) + depend_pre: BlockSim = field(repr=False, default=None) + depend_left: BlockSim = field(repr=False, default=None) + finish = False + in_queue = False + _flag = False + _color = '0;38' + father: BlockSim = field(repr=False, default=None) + @property + def label(self) -> tuple: + return (self._type, self._state, self._id, self._virtual, self._stage) + @property + def color_label(self) -> str: + return f"\033[{self._color}m{self.label}\033[0m" + @property + def repr(self) -> str: + raise NotImplementedError + def draw(self, ax:plt.Axes, *args, **kwargs): + raise NotImplementedError + @dfs_builder(False) + def build_without_comm(self) -> None: + r"""Build pipeline timeline without comm blocks and dependency.""" + self.pre.build_without_comm() + self.left.build_without_comm() + self.start = max(self.pre.end, self.left.end) + self.end = self.start + self.time + @dfs_builder(True) + def build_with_comm(self) -> None: + r"""Build pipeline timeline with comm blocks and dependency.""" + self.depend_pre.build_with_comm() + self.depend_left.build_with_comm() + self.start = max(self.depend_pre.end, self.depend_left.end) + self.end = self.start + self.time + def reset_time_recursive(self) -> None: + raise NotImplementedError + def reset_time(self) -> None: + self.start = None + self.end = None + self.finish = False + def loc_size(self, x:int = 0, equal_wide=False, mode='compact') -> tuple: + x = x if self.start is None else self.start + dx = 1 if equal_wide else self.time + return x, self._stage + 0.5, dx, 1 + def loop(self, comm=False) -> list[BlockSim]: + if self._flag and not self.in_queue: + return [] + l = [] + if self.in_queue: + loop = [self] + block = self.father + while block.father and block is not self: + block = block.father + loop.append(block) + return loop + self._flag = True + self.in_queue = True + depends = [self.depend_pre, self.depend_left] if comm else [self.pre, self.left] + for dep in depends: + if dep: + dep.father = self + l.extend(dep.loop(comm=comm)) + dep.father = None + self.in_queue = False + return l + + def comm_loop(self) -> list[BlockSim]: + return self.loop(True) + +@dataclass +class HeadBlockSim(BlockSim): + _stage: int # p + _type: str = 'h' + _id: int = field(repr=False, init=False) + _state: str = field(repr=False, init=False) + _virtual: int = field(repr=False, init=False) + time: float = 0. + start: float = 0. + end: float = 0. + finish = True + @property + def label(self) -> tuple: + return (self._type, self._stage) + @property + def repr(self) -> str: + s_list = [] + block = self + while block: + s_list.append(block.__repr__()) + block = block.right + return '\n'.join(s_list) + def draw(self, ax, *args, **kwargs): + return + def build_without_comm(self): + return + def build_with_comm(self): + return + def reset_time_recursive(self): + return + +@dataclass +class MicroBlockSim(BlockSim): + _type: str = 'c' + mem: float = 0. 
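+    # phase is tagged during pipeline construction (see PipelineBuild) as one of
+    # 'warmup', 'stable' or 'cooldown'; draw() uses it to pick the block edge colour.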
+ phase: str = None + send_block: SendBlockSim = field(repr=False, default=None) + rec_block: RecBlockSim = field(repr=False, default=None) + def __post_init__(self): + self._color = '1;34' if self._state == 'f' else '1;33' + + def draw(self, ax:plt.Axes, *args, **kwargs) -> None: + x, y, dx, dy = self.loc_size(kwargs.get('index', 0), kwargs.get('equal_wide', False)) + color = (167/255, 184/255, 231/255) if self._state == 'f' else (255/255, 213/255, 143/255) + mix_color = (240/255, 255/255, 245/255) if self._state == 'f' else (255/255, 240/255, 255/255) + color = color_mix(mix_color, color, w1=self._virtual / 3) + if self.phase == 'warmup' and kwargs.get('phase', False): + edgecolor = 'lightblue' + elif self.phase == 'cooldown' and kwargs.get('phase', False): + edgecolor = 'orange' + else: + edgecolor = 'black' + rect = Rectangle((x, y - dy / 2), dx, dy, facecolor=color, edgecolor=edgecolor, linewidth=0.4) + if dx > 0.008 * kwargs.get('width', 0): + ax.text(rect.xy[0] + dx / 2, rect.xy[1] + dy / 2, str(self._id), ha='center', va='center', color='black', fontdict={'fontsize':9}) + ax.add_patch(rect) + + def reset_time_recursive(self) -> None: + if self.finish: + self.pre.reset_time_recursive() + self.left.reset_time_recursive() + self.reset_time() + +@dataclass +class CommBlockSim(BlockSim): + host: MicroBlockSim = field(repr=False, default=None) + dual: CommBlockSim = field(repr=False, default=None) + def joint_loc(self, index:int = 0, equal_wide=False) -> tuple: + raise NotImplementedError + def get_triangle(self, x, y, dx, dy) -> tuple: + raise NotImplementedError + def draw(self, ax:plt.Axes, *args, **kwargs) -> None: + color = (167/255, 184/255, 231/255) if self._state == 'f' else (255/255, 213/255, 143/255) + mix_color = (240/255, 255/255, 255/255) if self._state == 'f' else (255/255, 240/255, 255/255) + color = color_mix(mix_color, color, w1=1.2*self._virtual/3) + index, equal_wide, mode = (kwargs.get('index', 0), kwargs.get('equal_wide', False), kwargs.get('mode', 'compact')) + x, y, dx, dy = self.loc_size(index, equal_wide, mode) + xy = self.get_triangle(x, y, dx, dy) + tri = Polygon(xy, closed=True, facecolor=color, edgecolor='black', linewidth=0.4) + ax.add_patch(tri) + +@dataclass +class SendBlockSim(CommBlockSim): + _type: str = 's' + _color = '35' + def loc_size(self, index:int = 0, equal_wide=False, mode='compact') -> tuple: + host_x, _, host_dx, _ = self.host.loc_size(index, equal_wide) + x, y, _, _ = super().loc_size(index, equal_wide) + _dx = self.time + _dy = min(np.sqrt(self.time) * 0.6, 0.6) + if mode == 'compact': + x = host_x + host_dx - _dx + return x, y, _dx, _dy + def get_triangle(self, x, y, dx, dy) -> tuple: + return [[x, y - dy / 2], [x, y + dy / 2], [x + dx, y]] + + def draw_comm(self, ax:plt.Axes, *args, **kwargs) -> None: + index_from, index_to = (kwargs.get('index_from', 0), kwargs.get('index_to', 0)) + equal_wide, mode = (kwargs.get('equal_wide', False), kwargs.get('mode', 'compact')) + x, y, dx, _ = self.loc_size(index_from, equal_wide, mode) + x_, y_, dx_, _ = self.dual.loc_size(index_to, equal_wide, mode) + ax.annotate(None, xy=(x_ - dx_/2, y_), xytext=(x + dx/2, y), arrowprops=dict(ec='grey', arrowstyle='->', shrinkA=2, shrinkB=2)) + + @dfs_builder(True) + def build_with_comm(self) -> None: + r"""Build pipeline timeline with comm blocks and dependency.""" + self.dual.depend_left.build_with_comm() + self.depend_left.build_with_comm() + self.start = max(self.depend_left.end, self.dual.depend_left.end) + self.end = self.start + self.time + + def loop(self, 
comm=False) -> list[BlockSim]: + if comm: + return self.comm_loop() + return super().loop(comm) + + def comm_loop(self) -> list[BlockSim]: + if self._flag and not self.in_queue: + return [] + l = [] + if self.in_queue: + loop = [self] + block = self.father + while block.father and block is not self: + block = block.father + loop.append(block) + return loop + self._flag = True + self.in_queue = True + depends = [self.dual.depend_left, self.depend_left] + for dep in depends: + if dep: + dep.father = self + l.extend(dep.comm_loop()) + dep.father = None + self.in_queue = False + return l + +@dataclass +class RecBlockSim(CommBlockSim): + _type: str = 'r' + _color = '32' + def loc_size(self, index:int = 0, equal_wide=False, mode='compact') -> tuple: + host_x, _, _, _ = self.host.loc_size(index, equal_wide) + x, y, _, _ = super().loc_size(index, equal_wide) + _dx = self.time + _dy = min(np.sqrt(self.time) * 0.6, 0.6) + if mode == 'compact': + x = host_x + return x, y, -_dx, -_dy + def get_triangle(self, x, y, dx, dy) -> tuple: + return [[x, y], [x - dx, y + dy / 2], [x - dx, y - dy / 2]] + + @dfs_builder(True) + def build_with_comm(self) -> None: + r"""Build pipeline timeline with comm blocks and dependency.""" + self.dual.build_with_comm() + self.depend_left.build_with_comm() + self.start = max(self.depend_left.end, self.dual.start) + self.end = self.start + self.time + +class PipelineBuild: + @staticmethod + def _inter_merge(a: list[MicroBlockSim], b: list[MicroBlockSim], delta: int = 0) -> list[MicroBlockSim]: + res = [] + if delta >= 0: + res.extend(a[:delta]) + a = a[delta:] + else: + res.extend(b[:-delta]) + b = b[-delta:] + stable_count = 0 + while len(a): + block = a.pop(0) + block.phase = 'stable' + res.append(block) + stable_count += 1 + if len(b): + block = b.pop(0) + block.phase = 'stable' + res.append(block) + stable_count += 1 + else: + break + if stable_count: + res[-1].phase = 'cooldown' + if len(a): + res.extend(a) + elif len(b): + res.extend(b) + return res + + @staticmethod + def _build_chain(line: list[MicroBlockSim], p: int) -> list[BlockSim]: + head = HeadBlockSim(p) + left = head + for item in line: + left.right = item + item.left = left + left = item + if p == 0: + head.right.pre = head + return line + + @staticmethod + def build_1f1b(pp, micro_num, p, forward_time, backward_time, block_mem) -> list[BlockSim]: + for_line = [MicroBlockSim(p, 'f', i, 0, forward_time, mem=block_mem) for i in range(micro_num)] + back_line = [MicroBlockSim(p, 'b', i, 0, backward_time, mem=block_mem) for i in range(micro_num)] + line = PipelineBuild._inter_merge(for_line, back_line, pp - p - 1) + return PipelineBuild._build_chain(line, p) + + @staticmethod + def build_virtualpipeline(pp, micro_num, vp, p, forward_time, backward_time, block_mem) -> list[BlockSim]: + for_line = [] + back_line = [] + r = micro_num % pp + for inter in range(micro_num // pp): + for i in range(vp): + if inter == 0: + for_line.extend([MicroBlockSim(p, 'f', m, i, forward_time[i], mem=block_mem[i], phase='warmup') for m in range(r)]) + back_line.extend([MicroBlockSim(p, 'b', m, i, backward_time[i], mem=block_mem[i], phase='cooldown') for m in range(r)]) + for_line.extend([MicroBlockSim(p, 'f', r + m + inter * pp, i, forward_time[i], mem=block_mem[i], phase='warmup') for m in range(pp)]) + back_line.extend([MicroBlockSim(p, 'b', r + m + inter * pp, i, backward_time[i], mem=block_mem[i], phase='cooldown') for m in range(pp)]) + else: + line = PipelineBuild._inter_merge(for_line, back_line, (vp + 1) * pp - 2 * p - 2 + r * 
(vp - 1)) + return PipelineBuild._build_chain(line, p) + + @staticmethod + def build_virtualpipeline2(pp, micro_num, vp, p, forward_time, backward_time, block_mem): + for_line = [] + back_line = [] + r = micro_num % pp + for inter in range(micro_num // pp): + for i in range(vp): + if inter == 0: + for_line.extend([MicroBlockSim(p, 'f', m, i, forward_time[i], mem=block_mem[i]) for m in range(r)]) + back_line.extend([MicroBlockSim(p, 'b', m, i, backward_time[i], mem=block_mem[i]) for m in range(r)]) + for_line.extend([MicroBlockSim(p, 'f', r + m + inter * pp, i, forward_time[i], mem=block_mem[i]) for m in range(pp)]) + back_line.extend([MicroBlockSim(p, 'b', r + m + inter * pp, i, backward_time[i], mem=block_mem[i]) for m in range(pp)]) + line = PipelineBuild._inter_merge(for_line, back_line, vp * pp - p - 1) + return PipelineBuild._build_chain(line, p) + + +class PipelineSimulator: + def __init__(self, block_time:list, micro_num:int, comm_time:float = 0.1, layer_recompute=False, block_mem=1, backward_ratio=2.): + self.init(block_time, micro_num, comm_time, layer_recompute, block_mem, backward_ratio) + + def init(self, block_time, micro_num, comm_time, layer_recompute, block_mem, backward_ratio, *args, **kwargs): + self.micro_num = micro_num + self.pp, self.vp = self._base_init(block_time) + self.block_num = 2 * self.vp * self.micro_num + self.comm_time = comm_time + self._input_format(block_time, layer_recompute, block_mem, backward_ratio) + self._statistic_init() + self._comm = True + self.adjust_func_list = [self.swap_send_rec] + # Construct pipeline blocks + if self.vp == 1: + self.blocks = [PipelineBuild.build_1f1b(self.pp, self.micro_num, p, + self.block_time[0, p], + self.backward_time[0, p], + self.block_mem[0, p]) for p in range(self.pp)] + else: + self.blocks = [PipelineBuild.build_virtualpipeline(self.pp, self.micro_num, self.vp, p, + self.block_time[:, p], + self.backward_time[:, p], + self.block_mem[:, p]) for p in range(self.pp)] + if self.micro_num >= self.pp: + self.adjust_func_list = [self.vpp_send_delay, self.residue_delay] + self.adjust_func_list + self._build_block() # create connection among compute blocks + self._build_comm_block() # create comm blocks for each compute block + + def run(self, comm=True, print_info=True) -> PipelineSimulator: + self._comm = comm + self._check_loop() + if comm: + self.lines = self._create_lines(*self.adjust_func_list) + self._check_comm_loop() + for b in range(self.block_num): + for p in range(self.pp): + self.blocks[p][b].build_with_comm() + self.lines[0][-1].build_with_comm() + else: + for p in range(self.pp): + for block in self.blocks[p]: + block.build_without_comm() + self._statistic_info() + if print_info: + self.print_info() + return self + + def show(self, comm=True, connect=None) -> PipelineSimulator: + self.canvas = PlotMgr(2, ['block', 'memory']) + if self._comm: + connect = True if connect is None else connect + self.canvas.draw(self.lines, 0, comm, connect, False, 'timeline') + else: + connect = False if connect is None else connect + self.canvas.draw(self.blocks, 0, comm, connect, False, 'timeline') + self.canvas.draw_mem(self.states['block_mem_list'], 1) + self.canvas.draw_info(self.bubbles, self.peak_memory) + self.canvas.show() + return self + + def print_info(self): + print('\033[1;37m' + '—'*13, f'pp:{self.pp:>2}, vp:{self.vp:>2}, micro:{self.micro_num:>3} ' + '—'*12 + '\033[0m') + print('-'*20, ' bubble ', '-'*20) + print(apply_format(apply_color(list(self.bubbles.keys()), ['1;33', '1;32', '1;31', '1;35', '1;36']))) + 
print(apply_format(apply_color(list(self.bubbles.values()), ['1;33', '1;32', '1;31', '1;35', '1;36']))) + print('-'*20, ' memory ', '-'*20) + print(f"peak memory: {', '.join([f'{v:.2f}' for v in self.peak_memory])}") + return self + + def _base_init(self, block_time) -> tuple: + if isinstance(block_time, (list, tuple)): + if all(isinstance(item, (list, tuple)) for item in block_time): + vp = len(block_time) + pp = len(block_time[0]) + elif all(isinstance(item, (int, float)) for item in block_time): + vp = 1 + pp = len(block_time) + else: + raise ValueError(f"Unsupported input format block_time: {block_time}") + else: + raise ValueError(f"Unsupported input format block_time: {block_time}") + if self.micro_num < pp: + raise ValueError(f" `micro_num`({self.micro_num}) should equal or larger than `pp`({pp})") + return pp, vp + + def _input_format(self, block_time, layer_recompute, block_mem, backward_ratio) -> None: + self.block_time = format_2d_inputs(block_time, self.vp, self.pp) + if isinstance(layer_recompute, bool): + self.layer_recompute = self.block_time if layer_recompute else format_2d_inputs(0, self.vp, self.pp) + else: + self.layer_recompute = format_2d_inputs(layer_recompute, self.vp, self.pp) + if isinstance(block_mem, (int, float)): + self.block_mem = self.block_time * block_mem + else: + self.block_mem = format_2d_inputs(block_mem, self.vp, self.pp) + self.backward_ratio = format_2d_inputs(backward_ratio, self.vp, self.pp) + + def _statistic_init(self) -> None: + self.forward_time = self.block_time + self.backward_time = np.flip(self.block_time * self.backward_ratio + self.layer_recompute, axis=0) + self.states = {'last_time': np.zeros(self.pp), + 'warmup_time': np.zeros(self.pp), + 'cooldown_time': np.zeros(self.pp), + 'stable_free_time': (np.zeros((self.vp, self.pp)), np.zeros((self.vp, self.pp))), + 'block_mem_list': [np.array([[0, 0]]) for _ in range(self.pp)]} + self.model_compute_time = (np.sum(self.forward_time) + np.sum(self.forward_time * self.backward_ratio)) * self.micro_num + self.hardware_compute_time = (np.sum(self.forward_time) + np.sum(self.backward_time)) * self.micro_num + self.bubbles = {'real': 0, + 'ideal': (self.pp - 1) / self.vp / self.micro_num, + 'imba': 0, + 'comm': 0} + if np.sum(self.layer_recompute) > 1e-5: + self.bubbles['recompute'] = self.hardware_compute_time / self.model_compute_time - 1 + p, v, m = self.pp, self.vp, self.micro_num + if self.vp == 1: + if self.pp == 2: + self.bubbles['comm'] = 4*m + elif self.pp %2 == 0: + self.bubbles['comm'] = 4*p*m + 4*p**2 - 14*p + else: + self.bubbles['comm'] = 4*p*m + 4*p**2 - 12*p + elif self.pp <= 5: + comm_coef_list = [[4, -2, 0], [6, -2, -6], [4, 0, 12], [6, -2, 40]] + self.bubbles['comm'] = np.dot(np.array([p*v*m, m*p, 1]), comm_coef_list[self.pp - 2]) + elif self.pp % 2 == 0: + self.bubbles['comm'] = 4*p*v*m + 4*p**2 - 13*p + else: + self.bubbles['comm'] = 6*p*v*m - 2*v*p**2 + 4*v*p - 2*p*m + 6*p**2 -16*p + + self.bubbles['comm'] *= self.comm_time / self.model_compute_time + + def _statistic_info(self) -> None: + for p in range(self.pp): + blocks = self.lines[p] if self._comm else self.blocks[p] + current_mem = 0 + for block in blocks: + if block._type == 'c' and block._state == 'f': + current_mem += block.mem + elif block._type == 'c' and block._state == 'b': + if not self._comm or not block.rec_block: + current_mem -= block.mem + elif block._type == 'r' and block.host._state == 'b': + current_mem -= block.host.mem + block = block.host + else: + continue + self.states['block_mem_list'][p] = 
np.append(self.states['block_mem_list'][p], np.array([[block.end, current_mem]]), axis=0) + self.states['block_mem_list'][p] = np.append(self.states['block_mem_list'][p], np.array([[blocks[-1].end, current_mem]]), axis=0) + self.peak_memory = [np.max((self.states['block_mem_list'][p].T)[1]) for p in range(self.pp)] + self.end_time = max(np.max((self.states['block_mem_list'][p].T)[0]) for p in range(self.pp)) + self.bubbles['real'] = (self.pp * self.end_time - self.model_compute_time) / self.model_compute_time + self.bubbles['imba'] = self.bubbles['real'] - self.bubbles['ideal'] + 1e-10 + if not self._comm: + self.bubbles.pop('comm') + else: + self.bubbles['imba'] -= self.bubbles['comm'] + if self.bubbles.get('recompute'): + self.bubbles['imba'] -= self.bubbles['recompute'] + + def _get_pre_label(self, label:tuple) -> tuple: + t, s, m, v, p = label + if (s, v, p) == ('f', 0, 0): + return ('h', p) + if (s, p) == ('f', 0): + return (t, s, m, v - 1, self.pp - 1) + if (s, p) == ('b', self.pp - 1): + if v == 0: + return (t, 'f', m, self.vp - 1, p) + return (t, s, m, v - 1, 0) + if s == 'f': + return (t, s, m, v, p - 1) + if s == 'b': + return (t, s, m, v, p + 1) + raise ValueError(f"Illegal label: {label}") + + def _build_block(self) -> None: + r"""Build `pre` relation for computation blocks.""" + books = {self.blocks[0][0].pre.label: self.blocks[0][0].pre} + for p in range(self.pp): + for item in self.blocks[p]: + books[item.label] = item + for p in range(self.pp): + block = self.blocks[p][0] + while block: + pre_label = self._get_pre_label(block.label) + block.pre = books[pre_label] + block = block.right + + def _build_comm_block(self) -> None: + r"""Build `send_block` and `rec_block` relation among a computation block and two comm blocks.""" + for p in range(self.pp): + block = self.blocks[p][0] + while block: + pre = block.pre + if pre._stage != block._stage: + block.rec_block = RecBlockSim(p, block._state, block._id, block._virtual, self.comm_time) + pre.send_block = SendBlockSim(pre._stage, pre._state, pre._id, pre._virtual, self.comm_time) + block.rec_block.host = block + block.rec_block.dual = pre.send_block + pre.send_block.host = pre + pre.send_block.dual = block.rec_block + block.depend_pre = block.rec_block + block.rec_block.depend_pre = pre.send_block + pre.send_block.depend_pre = pre + else: + block.depend_pre = pre + block = block.right + + def _check_loop(self) -> None: + loop = self.blocks[0][-1].loop() + if loop: + raise CausalError('Block dependency exist loops!', self.blocks, loop) + for p in range(self.pp): + for block in self.blocks[p]: + block._flag = False + + def _check_comm_loop(self) -> None: + loop = self.lines[0][-1].comm_loop() + if loop: + raise CausalCommError('Block comm dependency exist loops!', self.lines, loop) + for p in range(self.pp): + for block in self.lines[p]: + block._flag = False + + def _create_lines(self, *adjust_func) -> list[list[BlockSim]]: + lines = [copy.copy(self.blocks[p]) for p in range(self.pp)] + for p in range(self.pp): + for b in range(self.block_num): + block = self.blocks[p][b] + pre = block.pre + if block.rec_block: + lines[p].insert(lines[p].index(block), block.rec_block) + if pre._type == 'h': + lines[pre._stage].insert(0, pre.send_block) + else: + lines[pre._stage].insert(lines[pre._stage].index(pre) + 1, pre.send_block) + for func in adjust_func: + lines = func(lines) + for p in range(self.pp): + for b, block in enumerate(lines[p]): + if b == 0: + block.depend_left = block.left if block.left else block.host.left + else: + 
block.depend_left = lines[p][b - 1] + return lines + + def _get_block_phase(self, p:int, b:int) -> str: + r = self.micro_num % self.pp + if b < (self.vp + 1) * self.pp - 2 * p - 2 + r: + return 'warmup' + if b > self.block_num - (self.vp + 1) * self.pp + 2 * p: + return 'cooldown' + return 'stable' + + def _send_block_delay(self, lines, p:int, b:int, distance:int) -> None: + i_send = lines[p].index(self.blocks[p][b].send_block) + send_block = lines[p].pop(i_send) + i_new = lines[p].index(self.blocks[p][b + distance]) + 1 + lines[p].insert(i_new, send_block) + + def swap_send_rec(self, lines) -> list[list[BlockSim]]: + for p in range(self.pp): + for b, block in enumerate(self.blocks[p]): + if b >= len(self.blocks[p]) - 1: + continue + i_b = lines[p].index(block) + i_bn = lines[p].index(self.blocks[p][b + 1]) + if i_bn - i_b == 3: + if p % 2 == 0 and lines[p][i_b+1]._type == 'r' and lines[p][i_b+2]._type == 's': + lines[p][i_b+1], lines[p][i_b+2] = lines[p][i_b+2], lines[p][i_b+1] + if p % 2 == 1 and lines[p][i_b+1]._type == 's' and lines[p][i_b+2]._type == 'r': + if block.phase == 'warmup' and self.blocks[p][b + 1].phase == 'cooldown': + continue + lines[p][i_b+1], lines[p][i_b+2] = lines[p][i_b+2], lines[p][i_b+1] + if lines[p][i_b+1].dual._stage == lines[p][i_b+2].dual._stage: + pd = lines[p][i_b+1].dual._stage + j_b1 = lines[pd].index(lines[p][i_b+1].dual) + j_b2 = lines[pd].index(lines[p][i_b+2].dual) + if j_b1 > j_b2: + lines[p][i_b+1], lines[p][i_b+2] = lines[p][i_b+2], lines[p][i_b+1] + self.diff_4_swap(i_b, i_bn, lines, p) + return lines + + def diff_4_swap(self, i_b, i_bn, lines, p): + if i_bn - i_b == 4: + if lines[p][i_b + 1].dual._stage == lines[p][i_b + 2].dual._stage and lines[p][i_b + 2].dual._stage == \ + lines[p][i_b + 3].dual._stage: + if lines[p][i_b + 1]._type == 's' and lines[p][i_b + 2]._type == 's' and lines[p][i_b + 3]._type == 'r': + lines[p][i_b + 1], lines[p][i_b + 2] = lines[p][i_b + 2], lines[p][i_b + 1] + + def vpp_send_delay(self, lines) -> list[list[BlockSim]]: + if self.micro_num % self.pp != 0: + return lines + for p in range(self.pp): + for b, block in enumerate(self.blocks[p]): + if block.send_block is not None and block.phase == 'stable': + self._send_block_delay(lines, p, b, 1) + return lines + + def residue_delay(self, lines) -> list[list[BlockSim]]: + r = self.micro_num % self.pp + if not r: + return lines + for p in range(self.pp): + for b, block in enumerate(self.blocks[p]): + if block.send_block is None: + continue + if p == self.pp - 1 and block._id < self.pp + r and block._state == 'f': + self._send_block_delay(lines, -1, b, r + max(0, block._id - self.pp + 1)) + elif p == 0 and block._id < self.pp + r and block._state == 'b': + if self.micro_num // self.pp == 1: + self._send_block_delay(lines, 0, b, r) + else: + self._send_block_delay(lines, 0, b, r + self.pp) + elif block.phase == 'stable': + self._send_block_delay(lines, p, b, 1) + return lines + +def test_comm_imba_zero(): + for pp in range(2, 10): + for m in range(pp, 3*pp+1): # check vp=1 imba + sim = PipelineSimulator(np.ones(pp).tolist(), m, comm_time=0.1).run(print_info=False) + if sim.bubbles['imba'] > 0.001: + sim.print_info() + for vp in range(2, 6): # check vp>1 imba + for m in range(pp, 4*pp+1, pp): + sim = PipelineSimulator(np.ones((vp, pp)).tolist(), m, comm_time=0.1).run(print_info=False) + if sim.bubbles['imba'] > 0.001: + sim.print_info() + +''' +当前缺陷: + 1. 
when comm and vpp are both enabled, pp must evenly divide micro, otherwise a dependency loop may form;
+'''
+
+if __name__ == '__main__':
+    layer_dis = [[2*0.6, 1.6, 1, 3, 2, 3, 4, 2], [6, 6, 6, 4, 6, 5, 5, 5+1.5]]
+    recompute = [[2*0.6, 2, 1, 3, 2, 3, 4, 2], [6, 6, 6, 4, 6, 4, 5, 0]]
+    PipelineSimulator(layer_dis, 128, 0.07, recompute).run(False).show(False, False)
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_util.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c10871fcf16caedc907b37f4b072d42d63db237
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/pp_util.py
@@ -0,0 +1,521 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""pp utility"""
+import numpy as np
+import time
+import re
+import sys
+import subprocess
+import argparse
+import chardet
+import highspy
+import yaml
+import ast
+import os
+import shlex
+
+from fast_tuner.utils.logger import logger
+
+pipeline_output_file = 'pipeline_output'
+dryrun_yaml_dir = 'dryrun_yaml'
+dryrun_shell_dir = 'dryrun_shell'
+dryrun_error = 'Dryrun failed, please check the mindspore environment!'
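+
+# Usage sketch for update_yaml_value below (path, key and value are illustrative only):
+#     update_yaml_value('./config/pretrain.yaml', 'use_parallel', True)
+# Only a top-level key of the yaml is rewritten; nested keys are not traversed.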
+ +def update_yaml_value(yaml_file, key, value): + with open(yaml_file, 'r', encoding='utf-8') as file: + yaml_data = yaml.safe_load(file) + if key in yaml_data: + print(f"find the {key}, update the context") + else: + print(f"can not find {key}") + yaml_data[key] = value + with open(yaml_file, 'w', encoding='utf-8') as file: + yaml.dump(yaml_data, file, default_flow_style=False) + + +def highs_solve_mps(mps_file, solution_file, origin_model, time_limit): + highs_solver = highspy.Highs() + highs_solver.readModel(mps_file) + if time_limit != sys.maxsize: + highs_solver.setOptionValue('time_limit', time_limit) + # highs_solver.setOptionValue('threads', 8) + + highs_solver.run() + + info = highs_solver.getInfo() + origin_model.solution_status = highs_solver.solutionStatusToString(info.primal_solution_status) + logger.info(f'Solution status = {origin_model.solution_status}') + origin_model.model_status = highs_solver.modelStatusToString(highs_solver.getModelStatus()) + logger.info(f'Model status = {origin_model.model_status}') + origin_model.run_time = highs_solver.getRunTime() + logger.info(f'Run time = {origin_model.run_time}') + origin_model.gap = info.mip_gap + logger.info(f'Gap = {origin_model.gap * 100:.2f}%') + logger.info(f'Optimal objective = {info.objective_function_value}') + origin_model.min_time = info.objective_function_value + if origin_model.model_status == 'Time limit reached' and origin_model.gap > 0: + logger.info(f'Dual bound = {info.mip_dual_bound}') + if origin_model.solution_status == 'None': + if origin_model.model_status == 'Infeasible': + logger.error(f'{mps_file} is Infeasible, Please check the memory limit!') + elif origin_model.model_status == 'Time limit reached': + logger.error(f'{mps_file} is not finished!, Please check the time limit!') + else: + logger.error(f'{mps_file} is no solution, model_status = {origin_model.model_status}!') + highs_solver.writeSolution(solution_file, 0) + logger.info(f'Writing the solution to {solution_file}') + + +def qiuqi_solver_mps(solver_file, model_file, solution_file_name, origin_model): + command0 = ['chmod', '+x', solver_file] + subprocess.run(command0) + command1 = [solver_file, '-Help'] + subprocess.run(command1, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + command2 = [solver_file, f'-SolutionFile={solution_file_name}', model_file] + subprocess.run(command2) + logger.info(f'Writing the solution to {solution_file_name}') + with open(solution_file_name, 'r', encoding='utf-8') as file: + content = file.read() + result_content = re.search(r'# Objective: (\d+\.\d+)', content) + objective_function_value = result_content.group(1) + logger.info(f'Optimal objective = {objective_function_value}') + origin_model.min_time = objective_function_value + + +def write_config_to_yaml(recompute_config, offset, yaml_file): + with open(yaml_file, 'r', encoding='utf-8') as file: + data = yaml.safe_load(file) + data['recompute_config'] = recompute_config + data['model']['model_config']['offset'] = offset + with open(yaml_file, 'w', encoding='utf-8') as file: + yaml.dump(data, file, default_flow_style=False, indent=4) + + +def write_config_to_shell(offset, shell_file): + configs, unparses = parse_shell(shell_file) + layer_num = configs.get('NUM_LAYERS') // configs.get('PP') + layer_list = flatten_to_str(offset, layer_num) + configs['NUM_LAYERS_LIST'] = layer_list + configs_to_shell(shell_file, configs, unparses) + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif 
v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def str2list(v): + if isinstance(v, list): + return v + if isinstance(ast.literal_eval(v), list): + return ast.literal_eval(v) + else: + raise argparse.ArgumentTypeError('List value expected.') + + +def str2int(v): + if isinstance(v, int): + return v + if isinstance(int(v), int): + return int(v) + else: + raise argparse.ArgumentTypeError('Int value expected.') + + +def str2dict(v): + if isinstance(v, dict): + return v + if isinstance(ast.literal_eval(v), dict): + return ast.literal_eval(v) + else: + raise argparse.ArgumentTypeError('Dict value expected.') + + + +def build_new_config_yaml(args): + with open(args.yaml, 'r', encoding='utf-8') as file: + data = yaml.safe_load(file) + if args.offset: + data['model']['model_config']['offset'] = args.offset + logger.info(f'change offset to {args.offset}') + recompute_config = build_recompute_config(args.is_select_recompute, args.is_recompute, + args.select_recompute_layers, args.recompute_layers) + if args.is_select_recompute is None: + logger.info(f'is_select_recompute is None') + else: + data['recompute_config']['select_recompute'] = recompute_config['select_recompute'] + data['recompute_config']['select_comm_recompute'] = recompute_config['select_comm_recompute'] + logger.info(f'change is_select_recompute to {args.is_select_recompute}') + logger.info(f'change select_recompute_layers to {args.select_recompute_layers}') + if args.is_recompute is None: + logger.info(f'is_recompute is None') + else: + data['recompute_config']['recompute'] = recompute_config['recompute'] + logger.info(f'change is_recompute to {args.is_recompute}') + logger.info(f'change recompute_layers to {args.recompute_layers}') + file_name, file_ext = os.path.splitext(os.path.basename(args.yaml)) + yaml_path = os.path.dirname(os.path.abspath(args.yaml)) + timestamp = time.strftime("%Y%m%d%H%M%S") + new_yaml = os.path.join(yaml_path, f'{file_name}_{timestamp}.yaml') + with open(new_yaml, 'w', encoding='utf-8') as file: + yaml.dump(data, file, default_flow_style=False, indent=4) + logger.info(f'build new yaml {new_yaml}') + return new_yaml + + +def build_new_config_shell(args): + configs, unparses = parse_shell(args.shell) + layer_num = configs.get('NUM_LAYERS') // configs.get('PP') + layer_list = flatten_to_str(args.offset, layer_num) + configs['NUM_LAYERS_LIST'] = layer_list + logger.info(f'change offset to {args.offset}') + file_name, file_ext = os.path.splitext(os.path.basename(args.shell)) + shell_path = os.path.dirname(os.path.abspath(args.shell)) + timestamp = time.strftime("%Y%m%d%H%M%S") + new_shell = os.path.join(shell_path, f'{file_name}_{timestamp}.sh') + configs_to_shell(new_shell, configs, unparses) + logger.info(f'build new shell {new_shell}') + return new_shell + + +def build_recompute_config(is_select_recompute, is_recompute, select_recompute_layers, recompute_layers): + recompute_config = { + 'parallel_optimizer_comm_recompute': False, + 'mp_comm_recompute': True, + 'recompute_slice_activation': True + } + + if is_select_recompute: + # 选择对应层的算子进行重计算 + recompute_config['select_recompute'] = {} + recompute_config['select_recompute'][r'feed_forward\.null'] = select_recompute_layers + recompute_config['select_recompute'][r'feed_forward\.w1\.activation\.silu'] = select_recompute_layers + recompute_config['select_recompute'][r'feed_forward\.w1\.reshape'] = select_recompute_layers + 
recompute_config['select_recompute'][r'feed_forward\.w2\.reshape'] = select_recompute_layers + recompute_config['select_recompute'][r'add'] = select_recompute_layers + recompute_config['select_recompute'][r'cast_up'] = select_recompute_layers + # 选择对应层的算子进行通信重计算 + recompute_config['select_comm_recompute'] = {} + recompute_config['select_comm_recompute'][r'.*\.norm'] = select_recompute_layers + recompute_config['select_comm_recompute'][r'attention\.wq\.reshape'] = select_recompute_layers + recompute_config['select_comm_recompute'][r'attention\.wk\.reshape'] = select_recompute_layers + recompute_config['select_comm_recompute'][r'feed_forward\.w1\.reshape'] = select_recompute_layers + recompute_config['select_comm_recompute'][r'feed_forward\.w3\.reshape'] = select_recompute_layers + elif is_select_recompute is False: + recompute_config['select_recompute'] = False + recompute_config['select_comm_recompute'] = False + + if is_recompute: + recompute_config['recompute'] = recompute_layers + elif is_recompute is False: + recompute_config['recompute'] = False + + return recompute_config + + +def construct_distribution(num_micro, num_stage): + parts, remainder = num_micro // num_stage, num_micro % num_stage + distribution = [] + for part in range(parts): + if part == 0: + distribution.append(num_stage + remainder) + else: + distribution.append(num_stage) + return distribution + + +def sort_micro(parts, num_vpp, num_stage, distribution, low_mem, seq_split, is_f_then_b: bool = False): + forward = [] + backward = [] + final_orders = [] + for part in range(parts): + for vpp in range(num_vpp): + for micro_id in range(distribution[part]): + for split in range(seq_split): + forward.append((part, vpp, 'f', micro_id, split)) + for vpp in range(num_vpp - 1, -1, -1): + for micro_id in range(distribution[part]): + for split in range(seq_split - 1, -1, -1): + backward.append((part, vpp, 'b', micro_id, split)) + # f-then-b的调度规则,待启用 + if is_f_then_b: + for stage in range(num_stage): + stage_order = [] + for part in range(parts): + for micro_id in range(distribution[part]): + stage_order.append((part, 0, 'f', micro_id, 0)) + for micro_id in range(distribution[part]): + stage_order.append((part, 0, 'b', micro_id, 0)) + final_orders.append(stage_order) + return final_orders + + for stage in range(num_stage): + if low_mem: + warmup = min(((num_vpp - 1) * distribution[0] + (num_stage - stage - 1)) * seq_split, len(forward)) + else: + warmup = min(((num_vpp - 1) * distribution[0] + (num_stage - stage - 1) * 2) * seq_split, len(forward)) + # 最后一个stage,第一个micro前向做完之后才能做后向 + if stage == num_stage - 1: + warmup = warmup + seq_split - 1 + stage_order = [] + stage_order += forward[: warmup] + for i in range(warmup, len(forward)): + stage_order.append(forward[i]) + stage_order.append(backward[i - warmup]) + stage_order += backward[len(forward) - warmup:] + final_orders.append(stage_order) + return final_orders + + +def get_init_input_peak_mem(): + return [] + + +def get_state_input_peak_mem(): + return [] + + +def get_ranks_stages(yaml_file): + with open(yaml_file, 'r', encoding='utf-8') as file: + data = yaml.safe_load(file) + pipeline_stage = data['parallel_config']['pipeline_stage'] + model_parallel = data['parallel_config']['model_parallel'] + data_parallel = data['parallel_config']['data_parallel'] + rank_size = pipeline_stage * model_parallel * data_parallel + return rank_size, pipeline_stage + + +def get_shell_ranks_stages(shell_file): + configs, unparses = parse_shell(shell_file) + pipeline_stage = configs.get('PP') + 
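    # (Editor's note) The world size is reconstructed from the launch script's integer variables as
    # rank_size = PP * TP * DP. Illustratively, a script with TP=2, DP=4, PP=8 yields rank_size = 64
    # and pipeline_stage = 8; this assumes parse_shell can read those assignments as plain integers.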
model_parallel = configs.get('TP') + data_parallel = configs.get('DP') + rank_size = pipeline_stage * model_parallel * data_parallel + return rank_size, pipeline_stage + + +def get_layers_distribution(offset, num_layers, num_stages, num_vpp): + layers_stage_vpp = [[0 for _ in range(num_vpp)] for _ in range(num_stages)] + for stage in range(num_stages): + for vpp in range(num_vpp): + layers_stage_vpp[stage][vpp] = offset[vpp][stage] + num_layers // (num_stages * num_vpp) + return layers_stage_vpp + + +def find_most_times_stage(layer_num_of_stage): + # key 为各个stage的layer数,value 为layer数对应的stage编号List + frequency_dict = {} + layer_num = layer_num_of_stage[0] + max_time = 0 + for i in range(len(layer_num_of_stage)): + if layer_num_of_stage[i] not in frequency_dict: + frequency_dict[layer_num_of_stage[i]] = [] + frequency_dict[layer_num_of_stage[i]].append(i) + if len(frequency_dict[layer_num_of_stage[i]]) > max_time: + max_time = len((frequency_dict[layer_num_of_stage[i]])) + layer_num = layer_num_of_stage[i] + return frequency_dict[layer_num] + + +def extract_peak_memory(file_path): + with open(file_path, 'rb') as file: + raw_data = file.read() + result = chardet.detect(raw_data) + encoding = result['encoding'] + with open(file_path, 'r', encoding=encoding) as file: + content = file.read() + result_content = re.search(r'Used peak memory usage \(without fragments\):\s*(\d+)M', content) + if result_content: + return int(result_content.group(1)) + else: + raise ValueError(dryrun_error) + + +def extract_actual_peak_memory(file_path): + with open(file_path, 'rb') as file: + raw_data = file.read() + result = chardet.detect(raw_data) + encoding = result['encoding'] + with open(file_path, 'r', encoding=encoding) as file: + content = file.read() + result_content = re.search( + r'Actual peak memory usage \(with fragments\):\s*(\d+)M', content) + if result_content: + return int(result_content.group(1)) + else: + raise ValueError(dryrun_error) + + +def get_peak_batch(micro_batch_num, num_stage, num_vpp, low_mem, offset, num_layers): + distribution = construct_distribution(micro_batch_num, num_stage) + final_orders = sort_micro(micro_batch_num // num_stage, num_vpp, num_stage, distribution, low_mem, 1) + layers_stage_vpp = get_layers_distribution(offset, num_layers, num_stage, num_vpp) + # 求峰值内存时的激活份数 + peaks = [0] * num_stage + for stage in range(num_stage): + cur_mem = 0 + for micro in final_orders[stage]: + if micro[2] == 'f': + cur_mem += layers_stage_vpp[stage][micro[1]] + else: + cur_mem -= layers_stage_vpp[stage][micro[1]] + peaks[stage] = max(peaks[stage], cur_mem) + return peaks + + +def build_coe_array(peak_num_act, peak_num_select_recom, peak_num_recompute, x, stage_range): + num_stage = len(peak_num_act) + layer_dis = [sum(x[vpp][stage] for vpp in range(1)) for stage in range(num_stage)] + coe_a = np.empty((stage_range[-1] - stage_range[0], 5), float) + for stage in range(stage_range[0], stage_range[-1]): + coe_a[stage - stage_range[0]][0] = 1 + coe_a[stage - stage_range[0]][1] = peak_num_act[stage] + coe_a[stage - stage_range[0]][2] = layer_dis[stage] + coe_a[stage - stage_range[0]][3] = peak_num_recompute[stage] + coe_a[stage - stage_range[0]][4] = peak_num_select_recom[stage] + return coe_a + + +def bulid_yaml(old_yaml_file, recompute_config, offset, num_layers, num_vpp, num_stage, dense_layers, micro): + with open(old_yaml_file, 'r', encoding='utf-8') as file: + data = yaml.safe_load(file) + data['recompute_config'] = recompute_config + data['model']['model_config']['offset'] = offset + if 
'mtp_depth' in data['model']['model_config']: + mtp_depth = data['model']['model_config']['mtp_depth'] + num_layers -= mtp_depth + data['model']['model_config']['num_layers'] = num_layers + if 'moe_config' in data: + if 'first_k_dense_replace' in data['moe_config']: + data['moe_config']['first_k_dense_replace'] = dense_layers + data['parallel_config']['pipeline_stage'] = num_stage + data['parallel_config']['micro_batch_num'] = micro + if 'pp_interleave_num' in data['model']['model_config']: + data['model']['model_config']['pp_interleave_num'] = num_vpp + pipeline_output = os.path.join(os.getcwd(), pipeline_output_file) + if not os.path.exists(pipeline_output): + os.mkdir(pipeline_output) + new_file = os.path.join(pipeline_output, dryrun_yaml_dir) + if not os.path.exists(new_file): + os.mkdir(new_file) + timestamp = time.strftime("%Y%m%d%H%M%S") + new_yaml_name = os.path.join(new_file, f'{timestamp}.yaml') + with open(new_yaml_name, 'w', encoding='utf-8') as file: + yaml.dump(data, file, default_flow_style=False, indent=4) + return new_yaml_name + + +def bulid_shell(old_shell_file, offset, num_layers, num_vpp, num_stage, dense_layers, micro): + configs, unparses = parse_shell(old_shell_file) + layer_num = num_layers//num_stage + layer_list = flatten_to_str(offset, layer_num) + configs['NUM_LAYERS_LIST'] = layer_list + # 内存不够,重计算全开 + if 'RECOMPUTE_NUM_LAYERS' in configs: + configs['RECOMPUTE_NUM_LAYERS'] = max(layer_list) + if 'FIRST_K_DENSE_REPLACE' in configs: + configs['FIRST_K_DENSE_REPLACE'] = dense_layers + configs['PP'] = num_stage + configs['VPP'] = num_vpp + configs['MBS'] = 1 + configs['GBS'] = micro * configs.get('MBS') * configs.get('DP') + configs['NUM_LAYERS'] = num_layers + pipeline_output = os.path.join(os.getcwd(), pipeline_output_file) + if not os.path.exists(pipeline_output): + os.mkdir(pipeline_output) + new_file = os.path.join(pipeline_output, dryrun_shell_dir) + if not os.path.exists(new_file): + os.mkdir(new_file) + timestamp = time.strftime("%Y%m%d%H%M%S") + new_shell_name = os.path.join(new_file, f'{timestamp}.sh') + configs_to_shell(new_shell_name, configs, unparses) + return new_shell_name + + +def parse_shell(shell_file): + with open(shell_file, 'r', encoding='utf-8') as f: + content = f.read() + lexer = shlex.shlex(content, posix=True) + lexer.whitespace_split = True + lexer.whitespace = '\n' + lexer.escape = '' + configs = {} + unparses = '' + for token in lexer: + if '=' in token: + key, value = token.split('=', 1) + key = key.strip() + try: + value = int(value) + except ValueError: + pass + configs[key] = value + else: + unparses += token + '\n' + return configs, unparses + + +def parse_shell_config(config_value): + parts = config_value.split('--') + paras = {} + for part in parts: + part_split = part.strip().split(maxsplit=1) + if part_split: + key = part_split[0].strip() + value = part_split[1].strip(' \n\\') if len(part_split)>1 else '' + try: + value = int(value) + except ValueError: + pass + paras[key] = value + return paras + + +def change_shell_config(configs_dict, config, para, value): + config_value = configs_dict[config] + parse_config_value = parse_shell_config(config_value) + parse_config_value[para] = value + content = '\n' + for key, value in parse_config_value.items(): + content += f' --{key} {value} \\\n' + configs_dict[config] = content + + +def configs_to_shell(shell_name, configs, unparses): + with open(shell_name, 'w', encoding='utf-8') as f: + for key, value in configs.items(): + if isinstance(value, int) or value[0] == '$': + 
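                # (Editor's note) Integer values and values that expand a shell variable (leading '$')
                # are written back unquoted so the regenerated script still evaluates them; every other
                # value is re-emitted as a quoted string in the else branch below.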
f.write(f'{key}={value}\n') + else: + f.write(f'{key}="{value}"\n') + f.write(unparses) + + +def flatten_to_str(lst, num=0): + items = [] + for item in lst: + if isinstance(item, list): + items.append(flatten_to_str(item, num)) + elif isinstance(item, int): + items.append(str(item + num)) + else: + item.append(str(item)) + return ','.join(items) diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/result_csv.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/result_csv.py new file mode 100644 index 0000000000000000000000000000000000000000..1555002400d3c65f03a18b8b3a5194864a4d3337 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/result_csv.py @@ -0,0 +1,159 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""generate result to csv""" + +import os +import time +import csv +import yaml + +from fast_tuner.utils.logger import logger +from fast_tuner.pipeline_conductor import pp_util +from fast_tuner.pipeline_conductor.dryrun import DryRun, dryrun_config_error + +class ResultCsv: + """ + generate result csv + """ + def __init__(self, output_path, output_file='pipeline_output', name='test_result'): + self.header = ['test', 'layers', 'micro', 'dp', 'tp', 'pp', 'ep', 'vp', 'op/fsdp', 'dense:moe', '反向:正向', + '重计算增加比率', 'mtp+head', 'moe时长', 'low_mem', '目标值', 'cost', 'x', 'offset', + 'ra', '内存信息', '内存上限(GB)', '求解器', 'GAP', 'solution_status', 'model_status', + '求解耗时/s', 'dryrun_check'] + self.name = name + timestamp = time.strftime("%Y%m%d%H%M%S") + csv_dir = os.path.join(os.path.abspath(output_path), output_file) + if not os.path.exists(csv_dir): + os.mkdir(csv_dir) + self.path = os.path.join(csv_dir, f'{self.name}_{timestamp}.csv') + self.create_csv_file() + + def create_csv_file(self): + """ + create csv file + """ + with open(self.path, 'w', encoding='utf-8-sig') as file: + header = ','.join(self.header) + '\n' + file.write(header) + logger.info (f'Successfully created {self.path}') + + def config_to_csv(self, candidate, low_mem, solver_name): + new_row = ['']*len(self.header) + h = self.header + new_row[h.index('test')] = candidate.config_path + new_row[h.index('求解器')] = solver_name + new_row[h.index('low_mem')] = low_mem + new_row[h.index('dense:moe')] = candidate.profiling_info.dmratio + new_row[h.index('反向:正向')] = candidate.profiling_info.bfratio + new_row[h.index('重计算增加比率')] = candidate.profiling_info.re_grow_ration + new_row[h.index('mtp+head')] = candidate.profiling_info.hratio + new_row[h.index('moe时长')] = candidate.profiling_info.moe_fw + if DryRun.config_file_type == 0: + self.yaml_to_row(candidate.config_path, new_row) + elif DryRun.config_file_type == 1: + self.shell_to_row(candidate.config_path, new_row) + elif DryRun.config_file_type == 2: + self.toml_to_row(candidate, new_row) + else: + raise TypeError(dryrun_config_error) + with open(self.path, 'a', newline='', 
encoding='utf-8-sig') as file: + writer = csv.writer(file) + writer.writerows([new_row]) + + def yaml_to_row(self, yaml_file, row): + """ + trans yaml info to row of csv + """ + with open(yaml_file, 'r', encoding='utf-8') as file: + data = yaml.safe_load(file) + h = self.header + row[h.index('layers')] = data['model']['model_config']['num_layers'] + if 'mtp_depth' in data['model']['model_config']: + row[h.index('layers')] += data['model']['model_config']['mtp_depth'] + row[h.index('micro')] = data['parallel_config']['micro_batch_num'] + row[h.index('dp')] = data['parallel_config']['data_parallel'] + row[h.index('tp')] = data['parallel_config']['model_parallel'] + row[h.index('pp')] = data['parallel_config']['pipeline_stage'] + if 'expert_parallel' in data['parallel_config']: + row[h.index('ep')] = data['parallel_config']['expert_parallel'] + else: + row[h.index('ep')] = 1 + if 'pp_interleave_num' in data['model']['model_config']: + row[h.index('vp')] = data['model']['model_config']['pp_interleave_num'] + else: + row[h.index('vp')] = 1 + + def shell_to_row(self, shell_file, row): + h = self.header + configs, _ = pp_util.parse_shell(shell_file) + row[h.index('layers')] = configs.get('NUM_LAYERS') + row[h.index('micro')] = configs.get('GBS') // (configs.get('MBS') * configs.get('DP')) + row[h.index('dp')] = configs.get('DP') + row[h.index('tp')] = configs.get('TP') + row[h.index('pp')] = configs.get('PP') + row[h.index('ep')] = configs.get('EP', 1) + row[h.index('vp')] = configs.get('VPP', 1) + + def toml_to_row(self, candidate, row): + """ + trans toml info to row of csv + """ + model_args = candidate.model_args + h = self.header + row[h.index('layers')] = model_args.num_layers + row[h.index('micro')] = model_args.mbn + row[h.index('dp')] = model_args.dp + row[h.index('tp')] = model_args.tp + row[h.index('pp')] = model_args.pp + row[h.index('ep')] = model_args.ep + row[h.index('vp')] = 1 + row[h.index('op/fsdp')] = model_args.dp_shard_degree + + def result_to_csv(self, solution): + """ + fill solution result to csv + """ + row_cost =[] + row_no_cost = [] + with open(self.path, 'r', newline='', encoding='utf-8-sig') as file: + reader = csv.reader(file) + h = self.header + for row in reader: + if row[h.index('test')] == str(solution.init_config.config_file): + row[h.index('内存信息')] = solution.init_config.memory.get_mem() + row[h.index('内存上限(GB)')] = solution.init_config.memory.mem_lim / 1024 + row[h.index('求解耗时/s')] = solution.run_time + row[h.index('solution_status')] = solution.solution_status + row[h.index('model_status')] = solution.model_status + if solution.solution_status != 'None': + row[h.index('目标值')] = solution.object_value + row[h.index('x')] = solution.layer_dis.tolist() + row[h.index('offset')] = solution.offset.tolist() + row[h.index('ra')] = solution.ra_dis.tolist() + row[h.index('GAP')] = solution.gap + row[h.index('cost')] = solution.object_value * float(row[h.index('moe时长')]) + row[h.index('dryrun_check')] = solution.check_peak_mem + if row[h.index('cost')]: + row_cost.append(row) + else: + row_no_cost.append(row) + sorted_rows = sorted(row_cost[1:], key=lambda x: float(x[self.header.index('cost')])) + sorted_rows = [self.header] + sorted_rows + row_no_cost + with open(self.path, 'w', newline='', encoding='utf-8-sig') as file: + writer = csv.writer(file) + writer.writerows(sorted_rows) + +if __name__ == '__main__': + csv = ResultCsv('./output') diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/solution.py 
b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/solution.py new file mode 100644 index 0000000000000000000000000000000000000000..e1f79440dcb24abbd2960b3ecc10d979c25c01ea --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/solution.py @@ -0,0 +1,256 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' +This Module consists a couple of utility functions +for solution. Including parser, peak memory calculation +and validity check +''' + +import itertools +import re +import numpy as np +from fast_tuner.pipeline_conductor.start_service import InitConfig +from fast_tuner.pipeline_conductor.math_model import Model +from fast_tuner.pipeline_conductor import micro +from fast_tuner.pipeline_conductor.start_service import ExpertInput +from fast_tuner.utils.logger import logger + +class Solution: + ''' + parse and record various properties of + solutions, including a few simple analysis + ''' + x_type1 = [int] + x_type2 = [int] + layer_dis = [int] + layer1_dis_stage = [int] + layer2_dis_stage = [int] + offset = [int] + indicator_type1 = [int] + indicator_type2 = [int] + rs_type1 = [int] + rs_type2 = [int] + rs_dis = [int] + ra_type1 = [int] + ra_type2 = [int] + ra_dis = [int] + forward_s = [float] + backward_s = [float] + peak_num = micro.PeakNum + check_peak_mem = [] + object_value = float + gap = float + solution_status = str + model_status = str + run_time = float + sol_file = '' + + def __init__(self, init_config: InitConfig): + self.init_config = init_config + self.x_type1 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int) + self.x_type2 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int) + self.offset = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int) + self.indicator_type1 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int) + self.indicator_type2 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int) + self.rs_type1 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int) + self.rs_type2 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int) + self.ra_type1 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int) + self.ra_type2 = np.empty((self.init_config.pp_interleave_num, self.init_config.pipeline_stage), int) + self.forward_s = np.empty((self.init_config.pipeline_stage, self.init_config.pp_interleave_num, + self.init_config.parts, self.init_config.micro_batch_num, + self.init_config.seq_splits), dtype=float) + self.backward_s = np.empty((self.init_config.pipeline_stage, self.init_config.pp_interleave_num, + self.init_config.parts, self.init_config.micro_batch_num, + self.init_config.seq_splits), dtype=float) + + def set_solution(self, 
solved_model: Model, is_origin_solver, sol_file): + ''' + parse the raw MIP solution into pipeline + parallelism layer distributions + ''' + self.sol_file = sol_file + if is_origin_solver: + for stage in range(self.init_config.pipeline_stage): + for vpp in range(self.init_config.pp_interleave_num): + self.x_type1[vpp][stage] = solved_model.x_type1[stage][vpp].solution_value() + self.x_type2[vpp][stage] = solved_model.x_type2[stage][vpp].solution_value() + self.indicator_type1[vpp][stage] = solved_model.indicator_type1[stage][vpp].solution_value() + self.indicator_type2[vpp][stage] = solved_model.indicator_type2[stage][vpp].solution_value() + self.rs_type1[vpp][stage] = solved_model.rs_type1[stage][vpp].solution_value() + self.rs_type2[vpp][stage] = solved_model.rs_type2[stage][vpp].solution_value() + self.ra_type1[vpp][stage] = solved_model.ra_type1[stage][vpp].solution_value() + self.ra_type2[vpp][stage] = solved_model.ra_type2[stage][vpp].solution_value() + for part in range(self.init_config.parts): + for micro_ind, seq in itertools.product(range(solved_model.distribution[part]), + range(self.init_config.seq_splits)): + self.forward_s[stage][vpp][part][micro_ind][seq] = ( + solved_model.forward_s[stage][vpp][part][micro_ind][seq].solution_value()) + self.backward_s[stage][vpp][part][micro_ind][seq] = ( + solved_model.backward_s[stage][vpp][part][micro_ind][seq].solution_value()) + else: + with open(sol_file, 'r', encoding='utf-8') as file: + for line in file: + content1 = re.search(r'^x_type1_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)', + line) + content2 = re.search(r'^x_type2_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)', + line) + content3 = re.search(r'^indicator1_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)', + line) + content4 = re.search(r'^indicator2_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)', + line) + content5 = re.search(r'^rs_type1_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)', + line) + content6 = re.search(r'^rs_type2_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)', + line) + content7 = re.search(r'^ra_type1_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)', + line) + content8 = re.search(r'^ra_type2_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)', + line) + content9 = re.search( + r'^forward_s_(\d+)_(\d+)_(\d+)_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)', + line) + content10 = re.search( + r'^backward_s(\d+)_(\d+)_(\d+)_(\d+)_(\d+)\s+([-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?)', + line) + if content1: + self.x_type1[int(content1.group(2))][int(content1.group(1))] = round(float(content1.group(3))) + if content2: + self.x_type2[int(content2.group(2))][int(content2.group(1))] = round(float(content2.group(3))) + if content3: + self.indicator_type1[int(content3.group(2))][int(content3.group(1))] = round( + float(content3.group(3))) + if content4: + self.indicator_type2[int(content4.group(2))][int(content4.group(1))] = round( + float(content4.group(3))) + if content5: + self.rs_type1[int(content5.group(2))][int(content5.group(1))] = round(float(content5.group(3))) + if content6: + self.rs_type2[int(content6.group(2))][int(content6.group(1))] = round(float(content6.group(3))) + if content7: + self.ra_type1[int(content7.group(2))][int(content7.group(1))] = round(float(content7.group(3))) + if content8: + self.ra_type2[int(content8.group(2))][int(content8.group(1))] = round(float(content8.group(3))) + + if content9: + 
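                        # (Editor's note) forward_s_* / backward_s* hold the continuous start-time
                        # variables indexed as stage_vpp_part_micro_split; unlike the integer x/rs/ra
                        # variables above they are stored without rounding. Note that the backward_s
                        # pattern omits the underscore used by forward_s_, which looks inconsistent
                        # with the naming of the forward variables in the .sol file.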
self.forward_s[int(content9.group(1))][int(content9.group(2))][int(content9.group(3))][ + int(content9.group(4))][int(content9.group(5))] = float(content9.group(6)) + if content10: + self.backward_s[int(content10.group(1))][int(content10.group(2))][int(content10.group(3))][ + int(content10.group(4))][int(content10.group(5))] = float(content10.group(6)) + self.object_value = solved_model.min_time + self.gap = solved_model.gap + self.model_status = solved_model.model_status + self.solution_status = solved_model.solution_status + self.run_time = solved_model.run_time + if self.solution_status != 'None': + self.cal_peak_mem(solved_model) + self.set_total_dis() + self.check_time_list() + + def cal_peak_mem(self, solved_model: Model): + self.peak_num = micro.PeakNum(solved_model.sort_micro) + self.peak_num.set_peak_act_recompute_num(self.x_type2, self.rs_type2, self.ra_type2, + self.x_type1, self.rs_type1, self.ra_type1, + solved_model.memory) + + def set_total_dis(self): + self.layer_dis = self.x_type1 + self.x_type2 + self.offset = (self.layer_dis - (self.init_config.num_layers_type1 + self.init_config.num_layers_type2) // + (self.init_config.pp_interleave_num * self.init_config.pipeline_stage)) + self.rs_dis = self.rs_type1 + self.rs_type2 + self.ra_dis = self.ra_type1 + self.ra_type2 + self.layer1_dis_stage = [sum(self.x_type1[v][s] for v in range(self.init_config.pp_interleave_num)) for s in + range(self.init_config.pipeline_stage)] + self.layer2_dis_stage = [sum(self.x_type2[v][s] for v in range(self.init_config.pp_interleave_num)) for s in + range(self.init_config.pipeline_stage)] + + def check_time_list(self): + ''' + inner functions for testing whether + the starting times align with exec order + ''' + s_time = [[] for _ in range(self.init_config.pipeline_stage)] + for stage in range(self.init_config.pipeline_stage): + for i in range(len(self.peak_num.sort_micro.final_orders[stage])): + micro_batch = self.peak_num.sort_micro.final_orders[stage][i] + if micro_batch.state == 'f': + s_time[stage].append(self.forward_s[stage][micro_batch.vpp][micro_batch.part][micro_batch.micro_id] + [micro_batch.split]) + if micro_batch.state == 'b': + s_time[stage].append(self.backward_s[stage][micro_batch.vpp][micro_batch.part][micro_batch.micro_id] + [micro_batch.split]) + # 各stage的micro0batch start time单调递增 + for stage in range(self.init_config.pipeline_stage): + for i in range(len(self.peak_num.sort_micro.final_orders[stage]) - 1): + if s_time[stage][i] > s_time[stage][i + 1]: + raise ValueError('Time sequence error!') + + def solution_print(self): + '''print the information of the solution''' + logger.info('layer distribution: ') + logger.info(self.layer_dis.tolist()) + logger.info('layer of type1 distribution: ') + logger.info(self.x_type1.tolist()) + logger.info('the indicator of layer1: ') + logger.info(self.indicator_type1.tolist()) + + logger.info('layer of type2 distribution: ') + logger.info(self.x_type2.tolist()) + logger.info('the indicator of layer2: ') + logger.info(self.indicator_type2.tolist()) + logger.info('the offset: ') + logger.info(self.offset.tolist()) + + logger.info('rs distribution: ') + logger.info(self.rs_dis.tolist()) + logger.info('layer of type1 rs distribution: ') + logger.info(self.rs_type1.tolist()) + logger.info('layer of type2 rs distribution: ') + logger.info(self.rs_type2.tolist()) + + logger.info('ra distribution: ') + logger.info(self.ra_dis.tolist()) + logger.info('layer of type1 ra distribution: ') + logger.info(self.ra_type1.tolist()) + logger.info('layer of 
type2 ra distribution: ') + logger.info(self.ra_type2.tolist()) + + logger.info('the peak memory:') + logger.info(list(self.peak_num.max_mem.values())) + logger.info('the number of micro batch when peak memory') + logger.info(list(self.peak_num.micro_num_of_max_mem.values())) + + logger.info('the number of activations of the type of layer1') + logger.info(list(self.peak_num.peak_num_act_type1.values())) + logger.info('the number of select recomputes of the type of layer1') + logger.info(list(self.peak_num.peak_num_select_recom_type1.values())) + logger.info('the number of full recomputes of the type of layer1') + logger.info(list(self.peak_num.peak_num_recompute_type1.values())) + + logger.info('the number of activations of the type of layer2') + logger.info(list(self.peak_num.peak_num_act_type2.values())) + logger.info('the number of select recomputes of the type of layer2') + logger.info(list(self.peak_num.peak_num_select_recom_type2.values())) + logger.info('the number of full recomputes of the type of layer2') + logger.info(list(self.peak_num.peak_num_recompute_type2.values())) + +def extract_solution_file(yaml_file, sol_file): + expert_input = ExpertInput(yaml_file, '') + expert_input.is_dryrun = False + init_config = InitConfig(expert_input) + origin_model = Model(init_config) + file_solution = Solution(init_config) + file_solution.set_solution(origin_model, False, sol_file) + file_solution.solution_print() diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/start_service.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/start_service.py new file mode 100644 index 0000000000000000000000000000000000000000..8772d29cadc9323bd1693e30515243a908a8e833 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_conductor/start_service.py @@ -0,0 +1,487 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""start service""" +import os +import re +import numpy as np +import yaml + +from fast_tuner.utils.logger import logger +from fast_tuner.pipeline_conductor import dryrun, pp_util +from fast_tuner.pipeline_conductor import micro +from fast_tuner.pipeline_conductor.memory import Memory + +pipeline_output_file = 'pipeline_output' +init_dryrun_file = 'init_dryrun' +double_check_dryrun_filename = 'double_check_dryrun' +HIGHS_NAME = 'HIGHS' +default_low_mem = False +default_time_limit = 1e10 +# 0:deepseek模型;1:boss模型 +model_class = 0 + + +# 专家输入:专家可根据环境变化更改 +class ExpertInput: + config_file = '' + ms_adapter_file = '' + model_args = None + + solver_name = HIGHS_NAME + llm_class = model_class + time_limit = int(default_time_limit) + + is_dryrun = True + is_double_check = False + fit_level = 0 + + low_mem = default_low_mem + layer_ratio = 0.3 + backward_ratio = 2 + srRatio = 0.33 + is_support_ra_with_rs = False + is_select_recomp = False + is_full_recomp = True + + is_head_loss_input = True + head_loss = 1.493 + recompute_ratio = 0.246 + + output_file = pipeline_output_file + output_file_dir = os.path.join(os.getcwd(), output_file) + double_check_dryrun_filename = double_check_dryrun_filename + + def __init__(self, config_file, ms_adapter_file): + if not os.path.exists(self.output_file_dir): + os.mkdir(self.output_file_dir) + self.config_file = config_file + self.ms_adapter_file = ms_adapter_file + + +class InitDryrun: + layers_stage_vpp = [] + init_offset = [] + recompute_config = {} + parts = int + distribution = [] + x_type1 = [] + x_type2 = [] + rs_type1 = [] + rs_type2 = [] + ra_type1 = [] + ra_type2 = [] + dense_layers = None + + def __init__(self, data_parallel, model_parallel, expert_input: ExpertInput): + if expert_input.llm_class == 0: + self.init_stages = 16 + self.init_layers = self.init_stages + 2 + self.init_micro = 32 + self.parts = self.init_micro // self.init_stages + self.distribution = pp_util.construct_distribution(self.init_micro, self.init_stages) + self.init_vpp = 1 + self.dense_layers = 9 + self.deepseek_build_offset() + self.layers_stage_vpp = pp_util.get_layers_distribution(self.init_offset, self.init_layers, + self.init_stages, + self.init_vpp) + self.recompute_config = self.construct_rec_config_for_deepseek() + elif expert_input.llm_class == 1: + self.init_stages = 13 + self.init_layers = 72 + self.init_micro = self.init_stages * 3 + self.parts = self.init_micro // self.init_stages + self.distribution = pp_util.construct_distribution(self.init_micro, self.init_stages) + self.init_vpp = 1 + self.boss_build_offset() + self.layers_stage_vpp = pp_util.get_layers_distribution(self.init_offset, self.init_layers, + self.init_stages, + self.init_vpp) + self.x_type1 = np.zeros((self.init_vpp, self.init_stages), int) + self.x_type2 = np.zeros((self.init_vpp, self.init_stages), int) + for stage in range(0, 7): + self.x_type1[0][stage] = self.layers_stage_vpp[stage][0] + for stage in range(7, 13): + self.x_type2[0][stage] = self.layers_stage_vpp[stage][0] + self.recompute_config = self.construct_rec_config_for_boss() + else: + raise ValueError(f'can not support the model class number of {expert_input.llm_class}!') + + self.rank_size = data_parallel * model_parallel * self.init_stages + + self.config_file = expert_input.config_file + self.ms_adapter_file = expert_input.ms_adapter_file + self.dryrun_output = init_dryrun_file + + def deepseek_build_offset(self): + self.init_offset = 
np.zeros((self.init_vpp, self.init_stages), int).tolist() + self.init_offset[0][7] = 1 + self.init_offset[0][14] = 1 + + def boss_build_offset(self): + self.init_offset = [[1, 2, -1, 1, 0, 2, -4, 2, -1, 1, 1, 2, 1]] + + def construct_rec_config_for_boss(self): + is_select_recompute = True + is_re_comp = True + self.rs_type1 = np.zeros((self.init_vpp, self.init_stages), int) + self.rs_type2 = np.zeros((self.init_vpp, self.init_stages), int) + self.ra_type1 = np.zeros((self.init_vpp, self.init_stages), int) + self.ra_type2 = np.zeros((self.init_vpp, self.init_stages), int) + + select_recompute_layers = [[0, 0, 0, 0, 1, 4, 0, 1, 1, 0, 2, 4, 0]] + recompute_layers = [[6, 4, 2, 3, 0, 0, 0, 4, 2, 3, 0, 0, 6]] + for stage in range(0, 7): + self.rs_type1[0][stage] = select_recompute_layers[0][stage] + self.ra_type1[0][stage] = recompute_layers[0][stage] + for stage in range(7, 13): + self.rs_type2[0][stage] = select_recompute_layers[0][stage] + self.ra_type2[0][stage] = recompute_layers[0][stage] + + recompute_config = pp_util.build_recompute_config(is_select_recompute, is_re_comp, select_recompute_layers, + recompute_layers) + return recompute_config + + def construct_rec_config_for_deepseek(self): + is_select_recompute = True + is_re_comp = True + select_recompute_layers = np.zeros((self.init_vpp, self.init_stages), int).tolist() + select_recompute_layers[0][3] = self.layers_stage_vpp[3][0] + select_recompute_layers[0][4] = self.layers_stage_vpp[4][0] + select_recompute_layers[0][10] = self.layers_stage_vpp[10][0] + select_recompute_layers[0][11] = self.layers_stage_vpp[11][0] + + recompute_layers = np.zeros((self.init_vpp, self.init_stages), int).tolist() + recompute_layers[0][0] = self.layers_stage_vpp[0][0] + recompute_layers[0][1] = self.layers_stage_vpp[1][0] + recompute_layers[0][2] = self.layers_stage_vpp[2][0] + recompute_layers[0][7] = self.layers_stage_vpp[7][0] + recompute_layers[0][8] = self.layers_stage_vpp[8][0] + recompute_layers[0][9] = self.layers_stage_vpp[9][0] + recompute_layers[0][14] = self.layers_stage_vpp[14][0] + + recompute_layers[0][15] = self.layers_stage_vpp[15][0] + + recompute_config = pp_util.build_recompute_config(is_select_recompute, is_re_comp, select_recompute_layers, + recompute_layers) + return recompute_config + + def init_dryrun(self): + dry_run = dryrun.DryRun(self.config_file, self.ms_adapter_file, self.dryrun_output) + dry_run.start_dryrun(self.recompute_config, self.init_offset, self.init_layers, self.init_vpp, self.init_stages, + self.rank_size, self.dense_layers, self.init_micro) + peak_mem = [dry_run.extract_memory_info(self.init_stages), dry_run.extract_memory_info_act(self.init_stages)] + return peak_mem + + +class InitConfig: + pipeline_stage = int + micro_batch_num = int + model_parallel = int + data_parallel = int + expert_parallel = int + rank_size = int + num_layers_type1 = int + num_layers_type2 = int + pp_interleave_num = int(1) + seq_length = int + hidden_size = int + intermediate_size = int + vocab_size = int + mem_lim = float + seq_splits = int(1) + parts = int + mps_sol_filename = '' + + def __init__(self, expert_input: ExpertInput): + self.expert_input = expert_input + self.config_file = expert_input.config_file + self.ms_adapter_file = expert_input.ms_adapter_file + if dryrun.DryRun.config_file_type == 0: + self.get_yaml_config() + elif dryrun.DryRun.config_file_type == 1: + self.get_shell_config() + elif dryrun.DryRun.config_file_type == 2: + self.get_toml_config() + else: + raise TypeError(dryrun.dryrun_config_error) + 
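        # (Editor's note) DryRun.config_file_type selects the frontend parsed above: 0 reads a
        # MindFormers-style YAML, 1 parses a MindSpeed-style launch shell, and 2 takes TOML-derived
        # model_args supplied on expert_input (the Torchtitan path, by inference); any other value
        # raises TypeError(dryrun_config_error).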
self.rank_size = self.pipeline_stage * self.model_parallel * self.data_parallel + self.parts = self.micro_batch_num // self.pipeline_stage + if self.parts == 0: + raise ValueError(f'stage = {self.pipeline_stage} is greater than micro batch = {self.micro_batch_num}!, ' + f'please check the config file!') + self.mps_sol_filename = (f'layers{self.num_layers_type1}_{self.num_layers_type2}_micro{self.micro_batch_num}_' + f'dp{self.data_parallel}' + f'_tp{self.model_parallel}_pp{self.pipeline_stage}_vp{self.pp_interleave_num}' + f'_ep{self.expert_parallel}') + + if self.pp_interleave_num <= 1: + self.expert_input.low_mem = True + + self.memory = Memory(self.mem_lim) + self.set_memory(expert_input.is_dryrun) + + def set_memory(self, is_dryrun): + memory_dir = os.path.join(self.expert_input.output_file_dir, 'memory_info') + filename = (f'layers{self.num_layers_type1}_{self.num_layers_type2}_micro{self.micro_batch_num}_dp{self.data_parallel}' + f'_tp{self.model_parallel}_pp{self.pipeline_stage}_vp1' + f'_ep{self.expert_parallel}.txt') + memory_file_name = os.path.join(memory_dir, filename) + if is_dryrun: + self.mem_calculator_by_dryrun() + if not os.path.exists(memory_dir): + os.mkdir(memory_dir) + self.memory.write_memory_to_file(memory_file_name) + elif os.path.exists(memory_file_name): + logger.info(f'mem_lim = {self.mem_lim}') + with open(memory_file_name, 'r', encoding='utf-8') as file: + lines = file.readlines() + for line in lines: + key, value = line.split('=') + if hasattr(self.memory, key): + setattr(self.memory, key, float(value)) + else: + logger.error(f'NameError: {key}') + self.memory.mem_lim_stage0 = self.memory.mem_lim - self.memory.static_mem0 + self.memory.mem_lim_others = self.memory.mem_lim - self.memory.static_mem + self.memory.mem_lim_last = self.memory.mem_lim - self.memory.lm_head_mem + logger.info(f'Using the {filename} memory for vpp{self.pp_interleave_num}') + else: + logger.info(f'There is no memory file: {memory_file_name}! 
Using the default memory!') + self.memory.print_mem() + + def get_yaml_config(self): + with open(self.config_file, 'r', encoding='utf-8') as file: + data = yaml.safe_load(file) + self.pipeline_stage = data['parallel_config']['pipeline_stage'] + self.micro_batch_num = data['parallel_config']['micro_batch_num'] + self.model_parallel = data['parallel_config']['model_parallel'] + self.data_parallel = data['parallel_config']['data_parallel'] + if 'expert_parallel' in data['parallel_config']: + self.expert_parallel = data['parallel_config']['expert_parallel'] + else: + self.expert_parallel = 1 + self.num_layers_type2 = data['model']['model_config']['num_layers'] + if 'mtp_depth' in data['model']['model_config']: + self.num_layers_type2 += data['model']['model_config']['mtp_depth'] + if 'pp_interleave_num' in data['model']['model_config']: + self.pp_interleave_num = data['model']['model_config']['pp_interleave_num'] + else: + self.pp_interleave_num = 1 + self.num_layers_type1 = 0 + if 'moe_config' in data: + if 'first_k_dense_replace' in data['moe_config']: + self.num_layers_type1 = data['moe_config']['first_k_dense_replace'] + if self.expert_input.llm_class == 1: + self.num_layers_type1 = 36 + self.num_layers_type2 -= self.num_layers_type1 + if not self.expert_input.is_head_loss_input: + self.seq_length = data['model']['model_config']['seq_length'] + self.hidden_size = data['model']['model_config']['hidden_size'] + self.intermediate_size = data['model']['model_config']['intermediate_size'] + self.vocab_size = data['model']['model_config']['vocab_size'] + self.mem_lim = int(re.search(r'(\d+)GB', data['context']['max_device_memory']).group(1)) * 1024.0 + + + def get_shell_config(self): + configs, unparses = pp_util.parse_shell(self.config_file) + self.pipeline_stage = configs.get('PP') + self.micro_batch_num = configs.get('GBS') // configs.get('MBS') // configs.get('DP') + self.model_parallel = configs.get('TP') + self.data_parallel = configs.get('DP') + self.expert_parallel = configs.get('EP', 1) + self.pp_interleave_num = configs.get('VPP', 1) + self.num_layers_type1 = configs.get('FIRST_K_DENSE_REPLACE', 0) + if self.expert_input.llm_class == 1: + self.num_layers_type1 = 36 + self.num_layers_type2 = configs.get('NUM_LAYERS') - self.num_layers_type1 + if not self.expert_input.is_head_loss_input: + self.seq_length = configs.get('SEQ_LEN') + self.hidden_size = configs.get('HIDDEN_SIZE') + self.intermediate_size = configs.get('FFN_HIDDEN_SIZE') + if 'VOCAB_SIZE' in configs: + self.vocab_size = configs.get('VOCAB_SIZE') + self.mem_lim = configs.get('MAX_DEVICE_MEMORY', 58) * 1024.0 + + def get_toml_config(self): + model_args = self.expert_input.model_args + self.pipeline_stage = model_args.pp + self.micro_batch_num = model_args.mbn + self.model_parallel = model_args.tp + + self.data_parallel = model_args.dp + self.expert_parallel = model_args.ep + self.pp_interleave_num = 1 + + self.num_layers_type1 = model_args.first_k_dense_replace + self.num_layers_type2 = model_args.num_layers - self.num_layers_type1 + + self.seq_length = model_args.seq_len + self.hidden_size = model_args.hidden_size + self.intermediate_size = model_args.ffn_hidden_size + + self.vocab_size = model_args.padded_vocab_size + self.mem_lim = 58 * 1024.0 + + + def mem_calculator_by_dryrun(self): + self.set_cur_memory_info() + self.memory.mem_lim_stage0 = self.mem_lim - self.memory.static_mem0 + self.memory.mem_lim_others = self.mem_lim - self.memory.static_mem + self.memory.mem_lim_last = self.mem_lim - self.memory.lm_head_mem + + def 
set_cur_memory_info(self): + init_dryrun = InitDryrun(self.data_parallel, self.model_parallel, self.expert_input) + is_input_mem = False + if not is_input_mem: + peak_mem = init_dryrun.init_dryrun() + logger.info(f'The first initial dryrun: peak_mem={peak_mem}') + else: + logger.info('Using input peak memory for act_layer memory!') + peak_mem = pp_util.get_init_input_peak_mem() + if any(mem == 0 for mem in peak_mem[1]): + raise ValueError(pp_util.dryrun_error) + if self.expert_input.llm_class == 0: + self.update_deepseek_memory(peak_mem, init_dryrun) + elif self.expert_input.llm_class == 1: + self.update_boss_memory(peak_mem, init_dryrun) + else: + raise ValueError(f'can not support the model class number of {self.expert_input.llm_class}!') + + def update_deepseek_memory(self, peak_mem, init_dryrun: InitDryrun): + peaks = pp_util.get_peak_batch(init_dryrun.init_micro, init_dryrun.init_stages, init_dryrun.init_vpp, + True, init_dryrun.init_offset, init_dryrun.init_layers) + logger.info(f'peaks={peaks}') + self.memory.select_mem12 = (peak_mem[0][3] - peak_mem[0][4]) / (peaks[3] - peaks[4]) + if self.memory.select_mem12 < 0: + self.memory.select_mem12 = 0 + self.memory.select_mem0 = self.memory.select_mem12 + self.memory.select_mem = (peak_mem[0][10] - peak_mem[0][11]) / (peaks[10] - peaks[11]) + if self.memory.select_mem < 0: + self.memory.select_mem = 0 + self.memory.re_comp_mem12 = (peak_mem[0][1] - peak_mem[0][2]) / (peaks[1] - peaks[2]) + if self.memory.re_comp_mem12 < 0: + self.memory.re_comp_mem12 = 0 + self.memory.re_comp_mem0 = self.memory.re_comp_mem12 + self.memory.re_comp_mem = (peak_mem[0][8] - peak_mem[0][9]) / (peaks[8] - peaks[9]) + if self.memory.re_comp_mem < 0: + self.memory.re_comp_mem = 0 + self.memory.act_mem12 = (peak_mem[0][5] - peak_mem[0][6]) / (peaks[5] - peaks[6]) + if self.memory.act_mem12 < 0: + self.memory.act_mem12 = 0 + self.memory.act_mem0 = self.memory.act_mem12 + self.memory.act_mem = (peak_mem[0][12] - peak_mem[0][13]) / (peaks[12] - peaks[13]) + # 更正re_comp_mem + if dryrun.DryRun.config_file_type == 1: + self.memory.re_comp_mem = self.memory.act_mem + + self.memory.layer_mem012 = (((peak_mem[0][7] - peak_mem[0][1]) - self.memory.re_comp_mem12 * (peaks[7] - peaks[1])) + / (init_dryrun.layers_stage_vpp[7][0] - init_dryrun.layers_stage_vpp[1][0])) + self.memory.layer_mem = (((peak_mem[0][14] - peak_mem[0][9]) - self.memory.re_comp_mem * (peaks[14] - peaks[9])) + / (init_dryrun.layers_stage_vpp[14][0] - init_dryrun.layers_stage_vpp[9][0])) + + layers_stage = [sum(init_dryrun.layers_stage_vpp[stage][vpp] for vpp in range(init_dryrun.init_vpp)) + for stage in range(init_dryrun.init_stages)] + if init_dryrun.layers_stage_vpp[0][0] >= 3: + self.memory.static_mem0 = (peak_mem[1][0] - self.memory.layer_mem012 * 3 - self.memory.layer_mem * + (layers_stage[0] - 3) - self.memory.re_comp_mem12 * peaks[0] * 3 / + layers_stage[0] - self.memory.re_comp_mem * peaks[0] * (layers_stage[0] - 3) / + layers_stage[0]) + else: + self.memory.static_mem0 = (peak_mem[1][0] - self.memory.layer_mem012 * init_dryrun.layers_stage_vpp[0][0] - + self.memory.layer_mem * (layers_stage[0] - init_dryrun.layers_stage_vpp[0][0]) - + self.memory.re_comp_mem12 * peaks[0] * init_dryrun.layers_stage_vpp[0][0] / + layers_stage[0] - self.memory.re_comp_mem * peaks[0] * + (layers_stage[0] - init_dryrun.layers_stage_vpp[0][0]) / layers_stage[0]) + self.memory.static_mem = (peak_mem[1][-2] - self.memory.layer_mem * layers_stage[-2] + - self.memory.re_comp_mem * peaks[-2]) + self.memory.lm_head_mem = 
(peak_mem[1][-1] - self.memory.layer_mem * layers_stage[-1] + - self.memory.re_comp_mem * peaks[-1]) + + def update_boss_memory(self, peak_mem, init_dryrun: InitDryrun): + sort_micro = micro.SortMicro(init_dryrun.parts, init_dryrun.init_vpp, init_dryrun.init_stages, + init_dryrun.distribution, True, 1) + peaks = micro.PeakNum(sort_micro) + init_mem = Memory(self.mem_lim) + peaks.set_peak_act_recompute_num(init_dryrun.x_type2, init_dryrun.rs_type2, init_dryrun.ra_type2, + init_dryrun.x_type1, init_dryrun.rs_type1, init_dryrun.ra_type1, init_mem) + stage_range = [1, 6] + coe_a1 = pp_util.build_coe_array(peaks.peak_num_act_type1, peaks.peak_num_select_recom_type1, + peaks.peak_num_recompute_type1, init_dryrun.x_type1, stage_range) + mem_b1 = [peak_mem[0][stage] for stage in range(stage_range[0], stage_range[-1])] + mem_result, res, rank, s = np.linalg.lstsq(coe_a1, mem_b1, rcond=None) + mem_result = np.round(mem_result, decimals=1) + logger.info('the type of layer1:') + logger.info(f'the residual = {res}') + logger.info(f'the normal of residual = {np.linalg.norm(res)}') + logger.info(f'the rank = {rank}') + static1 = mem_result[0] + self.memory.act_mem12 = mem_result[1] + self.memory.act_mem0 = mem_result[1] + self.memory.layer_mem012 = mem_result[2] + self.memory.re_comp_mem12 = mem_result[3] + self.memory.re_comp_mem0 = mem_result[3] + self.memory.select_mem12 = mem_result[4] + self.memory.select_mem0 = mem_result[4] + + stage_range = [7, 12] + coe_a2 = pp_util.build_coe_array(peaks.peak_num_act_type2, peaks.peak_num_select_recom_type2, + peaks.peak_num_recompute_type2, init_dryrun.x_type2, stage_range) + mem_b2 = [peak_mem[0][stage] for stage in range(stage_range[0], stage_range[-1])] + mem_result, res, rank, s = np.linalg.lstsq(coe_a2, mem_b2, rcond=None) + mem_result = np.round(mem_result, decimals=1) + logger.info('the type of layer2:') + logger.info(f'the residual = {res}') + logger.info(f'the normal of residual = {np.linalg.norm(res)}') + logger.info(f'the rank = {rank}') + + static2 = mem_result[0] + self.memory.act_mem = mem_result[1] + self.memory.layer_mem = mem_result[2] + self.memory.re_comp_mem = mem_result[3] + self.memory.select_mem = mem_result[4] + + + static_mem0 = (peak_mem[1][0] - self.memory.layer_mem012 * init_dryrun.x_type1[0][0] - + self.memory.layer_mem * init_dryrun.x_type2[0][0] - + self.memory.act_mem12 * peaks.peak_num_act_type1[0] - + self.memory.act_mem * peaks.peak_num_act_type2[0] - + self.memory.re_comp_mem12 * peaks.peak_num_recompute_type1[0] - + self.memory.re_comp_mem * peaks.peak_num_recompute_type2[0] - + self.memory.select_mem12 * peaks.peak_num_select_recom_type1[0] - + self.memory.select_mem * peaks.peak_num_select_recom_type2[0]) + logger.info(f'static1 = {static1}') + logger.info(f'static2 = {static2}') + static_mem = ((static2 + static1) / 2) + lm_head_mem = (peak_mem[1][-1] - self.memory.layer_mem012 * init_dryrun.x_type1[0][-1] - + self.memory.layer_mem * init_dryrun.x_type2[0][-1] - + self.memory.act_mem12 * peaks.peak_num_act_type1[init_dryrun.init_stages - 1] - + self.memory.act_mem * peaks.peak_num_act_type2[init_dryrun.init_stages - 1] - + self.memory.re_comp_mem12 * peaks.peak_num_recompute_type1[init_dryrun.init_stages - 1] - + self.memory.re_comp_mem * peaks.peak_num_recompute_type2[init_dryrun.init_stages - 1] - + self.memory.select_mem12 * peaks.peak_num_select_recom_type1[init_dryrun.init_stages - 1] + - self.memory.select_mem * + peaks.peak_num_select_recom_type2[init_dryrun.init_stages - 1]) + self.memory.static_mem0 = 
np.round(static_mem0, decimals=1) + self.memory.static_mem = np.round(static_mem, decimals=1) + self.memory.lm_head_mem = np.round(lm_head_mem, decimals=1) + + +if __name__ == '__main__': + expert_input = ExpertInput(yaml_file='C:\\working\\768_4k.yaml', + mind_former_file='~/mindformer.py') + expert_input.is_dryrun = False + model_input = InitConfig(expert_input) + print(model_input) diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_tool.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_tool.py new file mode 100644 index 0000000000000000000000000000000000000000..602748901e4958576fbebd3dfbab884f003144b4 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/pipeline_tool.py @@ -0,0 +1,65 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""entry file for pipeline""" + +import argparse + +from fast_tuner.utils.common import check_dryrun_parallel_number +from fast_tuner.utils.ppc_input import ParallelInput +from fast_tuner.utils.logger import logger +from fast_tuner.pipeline_conductor.pipeline_parallel import pipeline_proc +from fast_tuner.pipeline_conductor import pp_util + + +if __name__ == '__main__': + logger.info('start to run pipeline tool') + # 用户输入profiling结果,候选配置等信息,流水线工具给出配置cost排序 + parser = argparse.ArgumentParser(description='Run taylor pipeline_search_tool with user input parameters') + parser.add_argument('--files_dir', type=str, default='./output/dryrun_yaml/', + help='Path to the YAML or SHELL file directory') + parser.add_argument('--yaml_path', type=str, default=None, + help='Path of training config (.yaml)') + parser.add_argument('--shell_path', type=str, default=None, + help="Path of training config (.sh)") + parser.add_argument('--mindformers_dir', type=str, default=None, + help='Directory of mindformers') + parser.add_argument('--mindspeed_path', type=str, default=None, + help="Absolute path of posttrain_gpt (.py)") + parser.add_argument('--profile_data_dir', type=str, default='./profile_data/', + help='Directory of profile data') + parser.add_argument('--solver_name', type=str, default='HIGHS', + help='Name of the solver') + parser.add_argument('--parser_result', type=str, default=None, + help='Profiling parser result file') + parser.add_argument('--nd_path', type=str, default='./config/nd_result.csv', + help='Path to nd result file') + parser.add_argument('--env_json', type=str, required=True, + default='./config/boss_env_config.json', help="Path of environment config (.json)") + parser.add_argument('--register_path', type=str,default='research/jiutian', + help="Path of register") + parser.add_argument('--parallel_num', type=pp_util.str2int, default=16, + help="The number of dryrun at once") + parser.add_argument('--dryrun', type=pp_util.str2bool, default=True, + help="Is auto dryrun") + parser.add_argument('--check', type=pp_util.str2bool, default=True, + help="Is double check") + 
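    # (Editor's note, illustrative invocation under the defaults declared here; paths are examples only)
    #   python fast_tuner/pipeline_tool.py --yaml_path ./config/pretrain.yaml \
    #       --env_json ./config/boss_env_config.json --parallel_num 16 --dryrun True --check True
    # Only --env_json is strictly required; the remaining options fall back to their declared defaults.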
parser.add_argument('--output_path', type=str, + default='./output/', + help='Directory of output info') + + args = parser.parse_args() + check_dryrun_parallel_number(args.parallel_num) + pipeline_input = ParallelInput(args) + pipeline_proc(pipeline_input) diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/__init__.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/common.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b9970957f0baba03147ad68be0e04ef459dbf7bc --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/common.py @@ -0,0 +1,402 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""common use functions""" +import os +import json + +from pathlib import Path +import shutil +import yaml +import toml +import numpy as np + +from fast_tuner.pipeline_conductor.pp_util import configs_to_shell, parse_shell +from fast_tuner.utils.logger import logger + + +def cal_model_layers_num(input_args): + if input_args.model.model_config.mtp_depth is None: + input_args.model.model_config.mtp_depth = 0 + return input_args.model.model_config.num_layers + input_args.model.model_config.mtp_depth + +GENERAL_TOML = 'DP{}_TP{}_PP{}_EP{}_CP{}_FSDP{}_pretrain.toml' +MOE_PATTERN_YAML = 'DP{}_TP{}_PP{}_EP{}_pretrain.yaml' +LLAMA_PATTERN_YAML = 'DP{}_TP{}_PP{}_pretrain.yaml' +MOE_PATTERN_SHELL = 'DP{}_TP{}_PP{}_EP{}_pretrain.sh' +LLAMA_PATTERN_SHELL = 'DP{}_TP{}_PP{}_pretrain.sh' +MODULE_PATTERN_REG_YAML = r'DP(\d+)_TP(\d+)_PP(\d+)_EP(\d+)_pretrain.yaml' +MOE_PATTERN_REG_SHELL = r'DP(\d+)_TP(\d+)_PP(\d+)_EP(\d+)_pretrain.sh' +LLAMA_PATTERN_REG_SHELL = r'DP(\d+)_TP(\d+)_PP(\d+)_pretrain.sh' + +def initial_offset(pipeline_stage, num_layers): + ''' + set offset in a memory friendly way such that + we can estimate the minimal memory requirement + ''' + + if pipeline_stage == 1 or num_layers % pipeline_stage == 0: + offset = 0 + return offset + + pp_interleave_num = 1 + offset = np.zeros((pp_interleave_num, pipeline_stage), dtype=int).tolist() + remainder = num_layers % (pp_interleave_num * pipeline_stage) + for vpp in range(pp_interleave_num): + offset[0][0] += 1 + remainder-=1 + for stage in range(pipeline_stage): + if remainder == 0: + break + offset[vpp][pipeline_stage - stage - 2] += 1 + remainder -= 1 + return offset[0] + +def offset_for_dualpipe(pipeline_stage, num_layers): + ''' + set offset in a memory friendly way such that + we can estimate the minimal memory requirement + ''' + pp_interleave_num = 2 + offset = np.zeros((pp_interleave_num, pipeline_stage), dtype=int).tolist() + new_layers = num_layers - 2 + remainder = new_layers % pipeline_stage + base = 
new_layers // pipeline_stage + origin_base = new_layers // pipeline_stage + + for stage in range(pipeline_stage): + # 获取当前stage的层数 + if remainder == 0: + cur_layer = base + else: + cur_layer = base + 1 + remainder -= 1 + # 给vpp分配层 + vpp1_layer = cur_layer // pp_interleave_num + if pipeline_stage - stage - 1 == 0: + offset[0][pipeline_stage - stage - 1] = vpp1_layer + 2 - origin_base + else: + offset[0][pipeline_stage - stage - 1] = vpp1_layer - origin_base + vpp2_layer = cur_layer - vpp1_layer + offset[1][stage] = vpp2_layer - origin_base + return offset + +def cal_world_size(input_args): + '''compute the worldsize from config''' + world_size = input_args.dp * input_args.tp * input_args.pp * input_args.cp + return world_size + +def generate_dryrun_yaml(destination_file, config, para): + """ + :param para: 用户输入参数 + :param destination_file: 修改并行参数的配置yaml文件 + :param config: [dp, tp, pp, ep] + :return: + """ + # 复制YAML文件 + shutil.copy2(para.YAML_PATH, destination_file) + + # 读取复制后的YAML文件, 修改config_para字段的值 + with open(destination_file, 'r', encoding='utf-8') as file: + yaml_data = yaml.safe_load(file) + yaml_data['parallel_config']['data_parallel'] = config[0] + yaml_data['parallel_config']['model_parallel'] = config[1] + yaml_data['parallel_config']['pipeline_stage'] = config[2] + yaml_data['parallel_config']['expert_parallel'] = config[3] + yaml_data['parallel_config']['micro_batch_num'] = para.GBS // config[0] + yaml_data['model']['model_config']['offset'] = config[4] + if yaml_data['parallel'].get('dataset_strategy') is not None: + strategy_size = len(yaml_data['parallel']['dataset_strategy']) + yaml_data['parallel']['dataset_strategy'] = [[config[0], 1] for _ in range(strategy_size)] + # 重计算设置为true + yaml_data['recompute_config']['recompute'] = True + + # todo: 适配tnd + yaml_data['train_dataset']['data_loader']['dataset_dir'] = para.DATASET + + + # 将修改后的数据写回YAML文件 + with open(destination_file, 'w', encoding='utf-8') as file: + yaml.dump(yaml_data, file, default_flow_style=False, allow_unicode=True) + + logger.info(f"The dryrun YAML file copied and modified, new file is: {destination_file}") + +def generate_dryrun_shell(destination_file, config, para): + """ + + :param para: 用户输入参数 + :param destination_file: 修改并行参数的配置shell文件 + :param config: [dp, tp, pp, ep, offset] or [dp, tp, pp] + :return: + """ + configs, unparses = parse_shell(para.SHELL_PATH) + configs['DP'] = config[0] + configs['TP'] = config[1] + configs['PP'] = config[2] + configs_to_shell(destination_file, configs, unparses) + logger.info(f'The dryrun SHELL file copied and modified, new file is {destination_file}') + +def generate_profile_yaml(destination_file, config, para): + """ + :param para: 用户输入参数 + + :param destination_file: + :param config: [dp, tp, pp, ep, offset, num_layers] + :return: + """ + # 复制YAML文件 + shutil.copy2(para.YAML_PATH, destination_file) + + # 读取复制后的YAML文件, 修改config_para字段的值 + with open(destination_file, 'r', encoding='utf-8') as file: + yaml_data = yaml.safe_load(file) + yaml_data['parallel_config']['data_parallel'] = config[0] + yaml_data['parallel_config']['model_parallel'] = config[1] + yaml_data['parallel_config']['pipeline_stage'] = config[2] + yaml_data['parallel_config']['expert_parallel'] = config[3] + yaml_data['parallel_config']['micro_batch_num'] = para.GBS // config[0] + yaml_data['model']['model_config']['offset'] = config[4] + yaml_data['recompute_config']['recompute'] = config[5] + yaml_data['model']['model_config']['num_layers'] = config[6] + + yaml_data['profile'] = True + 
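+ # The settings below enable the MindFormers profiler for this candidate config:
+ # traces are written to an "output" folder next to the generated YAML, and only
+ # steps 4-6 are captured at profile level 1, with communication, memory and
+ # op-time collection switched on.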
yaml_data['profile_output'] = os.path.join(Path(destination_file).parent, "output") + yaml_data['init_start_profile'] = True + yaml_data['profile_communication'] = True + yaml_data['profile_memory'] = True + yaml_data['op_time'] = True + yaml_data['profile_level'] = 1 + yaml_data['profile_start_step'] = 4 + yaml_data['profile_stop_step'] = 6 + + # 将修改后的数据写回YAML文件 + with open(destination_file, 'w', encoding='utf-8') as file: + yaml.dump(yaml_data, file, default_flow_style=False, allow_unicode=True) + + logger.info(f"The profile YAML file copied and modified, config_para is: {config}") + +def generate_profile_shell(destination_file, config, para): + """ + :param para: 用户输入参数 + + :param destination_file: + :param config: [dp, tp, pp, ep, offset, num_layers] + :return: + """ + configs, unparses = parse_shell(para.SHELL_PATH) + configs['DP'] = config[0] + configs['TP'] = config[1] + configs['PP'] = config[2] + if '_EP' in destination_file: + configs['EP'] = config[3] + configs['FIRST_K_DENSE_RAPLACE'] = 3 + else: + configs['EP'] = 0 # 或者这里不需要写入EP + # TODO 应该需要读NPUS_PER_NODE,当前硬编码为8 + profile_need_rank_num = configs['DP'] * configs['TP'] * configs['PP'] + if profile_need_rank_num < para.RANK_NUM: + configs['NNODES'] = profile_need_rank_num // 8 + configs['NPUS_PER_NODE'] = 8 + if configs['NNODES'] == 0: + configs['NNODES'] = 1 + configs['NPUS_PER_NODE'] = profile_need_rank_num + else: + configs['NNODES'] = para.RANK_NUM // 8 + configs['NPUS_PER_NODE'] = 8 + configs['NUM_LAYERS'] = config[-1] + configs['MINDSPEED_PATH'] = para.MINDSPEED_PATH + # 生成要分析的进程编号列表 + step = profile_need_rank_num // config[2] + profile_ranks = ' '.join(map(str, range(0, profile_need_rank_num, step))) + if '_EP' in destination_file: + output_dir = os.path.join(para.OUTPUT_PATH, "profile_result", + f"DP{config[0]}_TP{config[1]}_PP{config[2]}_EP{config[3]}_profile") + else: + output_dir = os.path.join(para.OUTPUT_PATH, "profile_result", + f"DP{config[0]}_TP{config[1]}_PP{config[2]}_profile") + # 使用f-string构建参数字符串,提高可读性 + profile_args = ( + "--profile " + "--profile-step-start 5 " + "--profile-step-end 7 " + "--profile-level level1 " + "--profile-with-stack " + "--profile-with-cpu " + f"--profile-ranks {profile_ranks} " # 确保参数之间有空格 + f"--profile-save-path {output_dir}" + ) + # todo: 需要适配原有的args內容 + if configs['EP'] == 0: + mem_args = ( + "--use-distributed-optimizer " + "--recompute-method block " + "--recompute-granularity full " + "--recompute-num-layers 1 " + "--num-layer-list 4,3 " + ) + else: + mem_args = ( + "--use-distributed-optimizer " + "--recompute-method block " + "--recompute-granularity full " + "--recompute-num-layers 1 " + "--num-layer-list 4,3 " + "--first-k-dense-replace 3 " + ) + + # 存入配置字典 + configs['PROFILE_ARGS'] = profile_args + configs['MEM_ARGS'] = mem_args + configs_to_shell(destination_file, configs, unparses) + insert_profile_args_final(destination_file) + logger.info(f'The profile SHELL file copied and modified, new file is {destination_file}') + return output_dir + +def generate_profile_toml(destination_file, config, para): + '''generate toml file for profile''' + def read_toml(file_path): + """读取 TOML 文件并返回字典""" + with open(file_path, 'r', encoding='utf-8') as f: + return toml.load(f) + + def write_toml(data, file_path): + """将字典写入 TOML 文件""" + with open(file_path, 'w', encoding='utf-8') as f: + toml.dump(data, f) + + dp, tp, pp, ep, cp, op = config[:6] + template_data = read_toml(para.TOML_PATH) + # [profiling] + template_data['profiling']['enable_profiling'] = True + 
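+ # The candidate's parallel degrees are written into the Torchtitan TOML below:
+ # dp is split into a replicate degree (dp // op) and an FSDP shard degree (op),
+ # the profiling run is capped at 20 steps, and traces are saved under
+ # <OUTPUT_PATH>/profile_result/DP*_TP*_PP*_EP*_CP*_OP*_trace.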
template_data['profiling']['profile_freq'] = 20 + output_trace_dir = os.path.join(os.path.abspath(para.OUTPUT_PATH), "profile_result", + f"DP{dp}_TP{tp}_PP{pp}_EP{ep}_CP{cp}_OP{op}_trace") + template_data['profiling']['save_traces_folder'] = output_trace_dir + # [model] + template_data['model']['flavor'] = 'tune' + # [training] + template_data['training']['steps'] = 20 + # [parallelism] + template_data['parallelism']['data_parallel_replicate_degree'] = dp // op + template_data['parallelism']['data_parallel_shard_degree'] = op + template_data['parallelism']['tensor_parallel_degree'] = tp + template_data['parallelism']['context_parallel_degree'] = cp + template_data['parallelism']['expert_parallel_degree'] = ep + template_data['parallelism']['pipeline_parallel_degree'] = pp + write_toml(template_data, destination_file) + return output_trace_dir + + + +def insert_profile_args_final(shell_file_path, profile_args="$PROFILE_ARGS \\"): + '''modify shell files for profile''' + with open(shell_file_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + + # 标记是否在 torchrun 命令块内 + in_torchrun_block = False + modified_lines = [] + + for line in lines: + stripped = line.strip() + + # 检测命令块开始 + if "$MEM_ARGS" in stripped: + in_torchrun_block = True + modified_lines.append(line) + modified_lines.append(f" {profile_args}\n") + continue + + # 检测命令块结束(最后一个参数行,以 \ 结尾) + if in_torchrun_block and stripped.endswith("\\"): + modified_lines.append(line) + continue + + # 命令块结束后的行 + if in_torchrun_block: + in_torchrun_block = False + + # 普通行 + modified_lines.append(line) + + # 写入修改后的内容 + with open(shell_file_path, 'w', encoding='utf-8') as f: + f.writelines(modified_lines) + + logger.info(f"insert PROFILE_ARGS to {shell_file_path}") + + +# 定义一个映射,将yaml_task和对应的生成函数关联起来 +TASK_FUNCTION_MAP = { + "dryrun_yaml": generate_dryrun_yaml, + "dryrun_shell": generate_dryrun_shell, + "profile_yaml": generate_profile_yaml, + "profile_shell": generate_profile_shell, + "profile_toml": generate_profile_toml, +} + +def generate_files(candidate_configs, des_file_directory, file_task, para, input_args): + '''generate config files w.r.t your training library''' + # 如果目录不存在,则创建它 + if not os.path.exists(des_file_directory): + os.makedirs(des_file_directory) + + # 检查file_task是否有效 + if file_task not in TASK_FUNCTION_MAP: + logger.error(f"Invalid file_task value: {file_task}. 
Please use one of {list(TASK_FUNCTION_MAP.keys())}.") + return None + + generate_function = TASK_FUNCTION_MAP[file_task] + if para.SHELL_PATH: + pattern = LLAMA_PATTERN_SHELL if input_args.expert_num is None else MOE_PATTERN_SHELL + elif para.YAML_PATH: + pattern = LLAMA_PATTERN_YAML if input_args.expert_num is None else MOE_PATTERN_YAML + else: + pattern = GENERAL_TOML + + file_path_list = [] + output_dir_list = [] + for config in candidate_configs: + # 生成输出文件路径,包括文件名 + destination_file = (des_file_directory + pattern).format(*config) + file_path_list.append(destination_file) + # 只有file_task=profile_shell时才会返回output_dir + output_dir = generate_function(destination_file, config, para) + output_dir_list.append(output_dir) + return file_path_list, output_dir_list + + +def is_dualpipe_open(input_args): + if input_args.mf_args is None: + return False + parallel_cfg = input_args.mf_args.parallel + use_zero_bubble_v = ( + hasattr(parallel_cfg, 'pipeline_config') and + hasattr(parallel_cfg.pipeline_config, 'pipeline_scheduler') and + parallel_cfg.pipeline_config.pipeline_scheduler == 'zero_bubble_v' + ) + return use_zero_bubble_v + +def check_dryrun_parallel_number(parallel_num): + if parallel_num > 16: + raise Exception(f"The parallel number {parallel_num} is too large.") + +def parse_args_from_json(args): + with open(args.config, 'r', encoding='utf-8') as f: + data = json.load(f) + for key, value in data.items(): + if key not in vars(args): + logger.error(f"The json contains invalid parameters: {key}") + else: + setattr(args, key, value) diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/dryrun_manage.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/dryrun_manage.py new file mode 100644 index 0000000000000000000000000000000000000000..b45a4528ef3a767388fab69103edf1dd00f32ebc --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/dryrun_manage.py @@ -0,0 +1,135 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""dryrun operation""" +import os +import subprocess +import re +import json +from multiprocessing import Pool +from fast_tuner.utils.common import cal_world_size, is_dualpipe_open, \ + LLAMA_PATTERN_REG_SHELL, MOE_PATTERN_REG_SHELL, MODULE_PATTERN_REG_YAML +from fast_tuner.utils.logger import logger + + +def env_update(rank_size, env_variable_json): + # 设置环境变量 + with open(env_variable_json, 'r', encoding='utf-8') as f: + env_vars = json.load(f) + env_vars['RANK_SIZE'] = str(rank_size) + os.environ.update(env_vars) + +def create_target_dir(file, target_directory): + output_dir = os.path.splitext(file)[0] + target_dir = os.path.join(target_directory, f"dryrun_{output_dir}") + return target_dir + +def execute_command(para, config_file_path, target_dir, rank_id, tp): + '''execute dryrun command''' + if para.YAML_PATH: + command = ['python', os.path.join(para.MINDFORMERS_DIR, 'run_mindformer.py'), '--config', config_file_path, + '--register_path', para.REGISTER_PATH] + else: + command = ['bash', config_file_path] + try: + # 构建日志文件路径 + os.environ['RANK_ID'] = str(rank_id) + os.environ['MODEL_PARALLEL'] = str(tp) + os.makedirs(target_dir, exist_ok=True) + log_file_path = os.path.join(target_dir, 'dryrun_new.log') + with open(log_file_path, 'w', encoding='utf-8') as log_file: + logger.info(f"Command: {' '.join(command)} rank {rank_id} start run") + subprocess.run(command, stdout=log_file, stderr=subprocess.STDOUT, check=False) + except Exception as e: + logger.error(f"The command execution failed.: {e}") + +def calculate_rank_id(input_args, match, layer_num, rank_size): + if is_dualpipe_open(input_args): + return rank_size - 1 + pp = int(match.group(3)) + x = pp - layer_num % pp + rank_id = x * rank_size // pp + if rank_id == rank_size: + rank_id -= 1 + return rank_id + +def read_dryrun_info(root_dir): + '''read the peak mem from dryrun log''' + logger.info(f"Reading dryrun info from {root_dir}") + result_list = [] + # 预编译正则表达式,提高匹配性能 + pattern = re.compile(r'DP(\d+)_TP(\d+)_PP(\d+)_EP(\d+)') + + for entry in os.scandir(root_dir): + if not entry.is_dir(): + continue + + log_file_path = os.path.join(entry.path, 'dryrun_new.log') + if not os.path.isfile(log_file_path): + continue + + with open(log_file_path, 'r', encoding='utf-8') as log_file: + for line in log_file: + if "Used peak memory usage (without fragments):" in line: + peak_value = line.split(': ')[-1].strip() + peak_value = int(peak_value.rstrip('M')) + break + else: + # 如果没有找到 peak 值,跳过当前目录 + continue + + match = pattern.search(entry.name) + if match: + dp_num, tp_num, pp_num, ep_num = map(int, match.groups()) + result_list.append([dp_num, tp_num, pp_num, ep_num, peak_value]) + + logger.info(f"read dryrun info done, result size {len(result_list)} format [dp, tp, pp, ep, peak_value]") + logger.info('\n'.join(str(item) for item in result_list)) + return result_list + +def get_file_pattern(root_dir, input_args, para): + logger.info(f"Reading file pattern from {root_dir}") + if para.SHELL_PATH: + if input_args.expert_num is None: + pattern = LLAMA_PATTERN_REG_SHELL + else: + pattern = MOE_PATTERN_REG_SHELL + else: + pattern = MODULE_PATTERN_REG_YAML + return pattern + +def launch_dryrun(input_args, dryrun_file_dir, dryrun_data_dir, para): + '''dryrun launcher''' + logger.info('start dryrun') + rank_size = cal_world_size(input_args) + env_update(rank_size, para.ENV_JSON) + tasks = [] + layer_num = input_args.num_layers + + # 遍历指定目录下的所有文件 + for root, _, files in 
os.walk(dryrun_file_dir): + for file in files: + pattern = get_file_pattern(root, input_args, para) + match = re.match(pattern, file) + if not match: + continue + rank_id = calculate_rank_id(input_args, match, layer_num, rank_size) + tp = int(match.group(2)) + target_dir = create_target_dir(file, dryrun_data_dir) + yaml_file_path = os.path.join(root, file) + tasks.append((para, yaml_file_path, target_dir, rank_id, tp)) + + with Pool(processes=para.DRYRUN_LIM) as pool: + pool.starmap(execute_command, tasks) + logger.info("all dryrun done") diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_config.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_config.py new file mode 100644 index 0000000000000000000000000000000000000000..38f2751809b09d91b044bf51bf557073bd87d7b0 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_config.py @@ -0,0 +1,87 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" Transformer-Config dict parse module """ + +import os +import yaml + +class InputConfig(dict): + """ + A class for configuration that inherits from Python's dict class. + Can parse configuration parameters from yaml files + + Args: + args (Any): Extensible parameter list, a yaml configuration file path + kwargs (Any): Extensible parameter dictionary, a yaml configuration file path + + Returns: + An instance of the class. + + Examples: + >>> cfg = InputConfig('./test.yaml') + >>> cfg.a + """ + + def __init__(self, *args): + super().__init__() + cfg_dict = {} + + for arg in args: + if isinstance(arg, str): + if arg.endswith('yaml') or arg.endswith('yml'): + raw_dict = InputConfig._file2dict(arg) + cfg_dict.update(raw_dict) + + InputConfig._dict2config(self, cfg_dict) + + def __getattr__(self, key): + if key not in self: + return None + return self[key] + + @staticmethod + def _file2dict(filename=None): + """Convert config file to dictionary. + + Args: + filename (str) : config file. + """ + if filename is None: + raise NameError(f'This {format(filename)} cannot be empty.') + + filepath = os.path.realpath(filename) + with open(filepath, encoding='utf-8') as fp: + # 文件指针重置到文件开头 + fp.seek(0) + cfg_dict = yaml.safe_load(fp) + + return cfg_dict + + @staticmethod + def _dict2config(config, dic): + """Convert dictionary to config. 
+ + Args: + config : Config object + dic (dict) : dictionary + """ + if isinstance(dic, dict): + for key, value in dic.items(): + if isinstance(value, dict): + sub_config = InputConfig() + dict.__setitem__(config, key, sub_config) + InputConfig._dict2config(sub_config, value) + else: + config[key] = dic[key] diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_param.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_param.py new file mode 100644 index 0000000000000000000000000000000000000000..25261961b4ef8731e775d6ff3f1edb93b6cf6254 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/input_param.py @@ -0,0 +1,60 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""inner para class""" +from fast_tuner.utils.logger import logger + + +class InputParam: + '''All information that we need to record''' + PARAM_MAPPING = { + "ENV_JSON" : "env_json", + "DATASET" : "dataset", + "YAML_PATH" : "yaml_path", + "SHELL_PATH" : "shell_path", + "TOML_PATH" : "toml_path", + "MINDFORMERS_DIR" : "mindformers_dir", + "MINDSPEED_PATH" : "mindspeed_path", + "TORCHTITAN_PATH": "torchtitan_path", + "DRYRUN_DATA_DIR" : "dryrun_data_dir", + "PROFILE_DATA_DIR" : "profile_data_dir", + "DRYRUN_LIM" : "dryrun_lim", + "RANK_NUM" : "rank_num", + "SOLVER_NAME" : "solver_name", + "MAX_EXPERT_PARALLEL": "max_expert_parallel", + "OUTPUT_PATH": "output_path", + "GBS": "gbs", + "SELECT_RECOMPUTE": "select_recompute", + "ALG_PHASE": "alg_phase", + "PARSER_RESULT": "parser_result", + "DRYRUN": "dryrun", + "CHECK": "check", + "CONFIG": "config", + "NPUS_PER_NODE": "npus_per_node", + "NNODES": "nnodes", + "STRATEGY": "strategy", + } + + def __init__(self, args): + self.args = args + + def __getattr__(self, name): + if name in self.PARAM_MAPPING: + param_name = self.PARAM_MAPPING[name] + return getattr(self.args, param_name) + raise AttributeError(f"'InputParam' object has no attribute '{name}'") + + def print_params(self): + for key, value in self.PARAM_MAPPING.items(): + logger.info(f"{key}: {value} = {getattr(self, key)}") diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/logger.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd4716beedb4c868444680c5f52568230e1dc0e --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/logger.py @@ -0,0 +1,40 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""logger file""" +import logging + +DEFAULT_STDOUT_FORMAT = '%(levelname)s %(asctime)s %(filename)s:%(lineno)d - %(message)s' +FORMATTER = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +OUTPUT_LEVEL_NUM = logging.WARNING +logging.addLevelName(OUTPUT_LEVEL_NUM, "OUTPUT") + + +def setup_logger(name: str, level: int = logging.DEBUG): + """setup a logger""" + ch = logging.StreamHandler() + ch.setLevel(level) + ch.setFormatter(FORMATTER) + + def output(self, message, *args): + self.warning(message, *args) + + logging.Logger.output = output + ppb_logger = logging.getLogger(name) + ppb_logger.setLevel(level) + ppb_logger.addHandler(ch) + + return ppb_logger + +logger = setup_logger('symphony', level=logging.INFO) diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/ppc_input.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/ppc_input.py new file mode 100644 index 0000000000000000000000000000000000000000..832e0f582da8e07c281dd33807b35c5a649ae869 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/ppc_input.py @@ -0,0 +1,139 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""interface between nd and pp """ +import os +import re +import csv +from typing import List +from pathlib import Path +from fast_tuner.utils.logger import logger +from fast_tuner.utils.common import GENERAL_TOML +from fast_tuner.utils.profiling.profile_info import ProfileInfo +from fast_tuner.pipeline_conductor.start_service import ExpertInput +from fast_tuner.pipeline_conductor.dryrun import DryRun, dryrun_config_error +from fast_tuner.ndsearch.para_for_nd_search import ParaForNd + +class ParallelConfig: + def __init__(self, gbs, config): + self.dp = config[0] + self.tp = config[1] + self.pp = config[2] + self.vp = 1 + self.ep = config[3] + self.micro = gbs / self.dp + +class PipelineInputConfig: + def __init__(self, profiling_info: ProfileInfo, config_path, model_args = None): + self.profiling_info = profiling_info + self.config_path = config_path + self.model_args = model_args + +class ParallelInput: + def __init__(self, para, profile_file_dir=None): + self.candidate_configs: List[PipelineInputConfig] = [] + + self.init_configs_info(para, profile_file_dir) + + self.is_lowmem = int(os.getenv('ENABLE_LESS_MEM_VPP', 0)) + args = para.args + self.solver_name = args.solver_name + self.env_config_json = args.env_json + self.dryrun_lim = args.dryrun_lim + self.dryrun = args.dryrun + self.check = args.check + self.output_path = args.output_path + if DryRun.config_file_type == 0: + self.ms_adapter_file = args.mindformers_dir + elif DryRun.config_file_type == 1: + self.ms_adapter_file = args.mindspeed_path + elif DryRun.config_file_type == 2: + self.ms_adapter_file = args.torchtitan_path + else: + raise TypeError(dryrun_config_error) + + @staticmethod + def parse_results_by_csv(csv_file): + result_dict = {} + with open(csv_file, mode='r', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + # key = "_".join([row['dp'], row['tp'], row['pp'], row['ep']]) + key_columns = list(row.keys())[:-5] # 获取前N-5列的列名 + key = "_".join([row[col] for col in key_columns]) + value = [float(row['dmratio']), float(row['bfratio']), float(row['re_grow_ration']), + float(row['hratio']), float(row['moe_fw'])] + result_dict[key] = value + return result_dict + + def get_model_files_dir(self, args, profile_file_dir=None): + if profile_file_dir is not None: + files_dir = Path(profile_file_dir) + elif args.files_dir is not None: + files_dir = Path(args.files_dir) + else: + raise RuntimeError('Must specify either files_dir or profile_file_dir') + + if args.yaml_path: + model_files = files_dir.glob('*.yaml') + elif args.shell_path: + model_files = files_dir.glob('*.sh') + elif args.toml_path: + model_files = files_dir.glob('*.toml') + else: + raise Exception("No yaml_path or shell_path specified") + return model_files + + def get_args_info(self, para, config_file): + if config_file.name.endswith('.toml'): + para.TOML_PATH = config_file + model_args = ParaForNd(para) + return model_args + else: + return None + + def init_configs_info(self, para, profile_file_dir=None): + args = para.args + model_files = self.get_model_files_dir(args, profile_file_dir) + csv_result = {} + # 若用户直接输入profiling解析信息---csv文件,则从文件中读入 + if args.parser_result is not None: + csv_result = self.parse_results_by_csv(args.parser_result) + + if args.yaml_path : + pattern = re.compile(r'DP(\d+)_TP(\d+)_PP(\d+)_EP(\d+)') + DryRun.config_file_type = 0 + elif args.shell_path: + pattern = re.compile(r'DP(\d+)_TP(\d+)_PP(\d+)_EP(\d+)') + 
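+ # config_file_type selects the training backend:
+ # 0 = MindFormers (YAML), 1 = MindSpeed (shell), 2 = Torchtitan (TOML).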
DryRun.config_file_type = 1 + elif args.toml_path: + pattern = re.compile(GENERAL_TOML.replace('{}', '(\d+)')) + DryRun.config_file_type = 2 + + for config_file in model_files: + match = pattern.search(config_file.name) + if match: + # dp_num, tp_num, pp_num, ep_num = map(int, match.groups()) + # config = [dp_num, tp_num, pp_num, ep_num] + config = [int(x) for x in match.groups()] + config_str = "_".join(map(str, config)) + if config_str in csv_result: + profile_list = csv_result[config_str] + else: + profile_list = [] + # todo: profiling解析的输入有待确认,例如rank + profile_info = ProfileInfo(args.profile_data_dir, profile_list) + input_args = self.get_args_info(para, config_file) + pipeline_input_config = PipelineInputConfig(profile_info, config_file, input_args) + self.candidate_configs.append(pipeline_input_config) diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_info.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_info.py new file mode 100644 index 0000000000000000000000000000000000000000..3dd0f246b34ba9a16a02c129eabb0eddc9826327 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_info.py @@ -0,0 +1,65 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""profile parser result interface""" +import csv +from fast_tuner.utils.logger import logger + +from fast_tuner.utils.logger import logger + + +class ProfileInfo: + def __init__(self, profiling_path, input_data): + # 若input_data为空,则需要解析profiling结果 + if not input_data: + logger.info('did not use the input data!') + input_data = self.parse_profiling_result(profiling_path) + self.dmratio = input_data[0] + self.bfratio = input_data[1] + self.re_grow_ration = input_data[2] + self.hratio = input_data[3] + self.moe_fw = input_data[4] + logger.info(f'{input_data}') + + @staticmethod + def generate_csv(): + # 定义 CSV 文件的列名 + headers = ['dp', 'tp', 'pp', 'ep', 'vp', 'dmratio', 'bfratio', 'hratio', 'moe_bw', 're_grow_ration'] + + # 示例数据,这里可以根据实际需求修改或添加更多行数据 + data = [ + [128, 1, 8, 16, 1, 0.1, 0.2, 0.3, 100, 0.34], + [128, 1, 8, 8, 1, 0.4, 0.5, 0.6, 200, 0.24] + ] + + # 定义要保存的 CSV 文件路径 + csv_file_path = './config/profiling_result.csv' + + # 打开文件并写入数据 + with open(csv_file_path, mode='w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + + # 写入列名 + writer.writerow(headers) + + # 写入数据行 + for row in data: + writer.writerow(row) + + logger.info(f"CSV 文件已生成,路径为: {csv_file_path}") + + def parse_profiling_result(self, profiling_path): + # todo: 待填充 + profiling_data = [0.1, 0.2, 0.3, 100, 2] + return profiling_data diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_launch.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_launch.py new file mode 100644 index 0000000000000000000000000000000000000000..437a580285ad3ca39209c4a2a041ee9b6687b006 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_launch.py @@ -0,0 +1,73 @@ +"""launch profile""" +from fast_tuner.utils.common import LLAMA_PATTERN_REG_SHELL, MOE_PATTERN_REG_SHELL, GENERAL_TOML +from fast_tuner.utils.logger import logger +import subprocess +import os +import re +import time +from pathlib import Path +class ProfileLaunch: + # 自动执行profile + def __init__(self, profile_configs, para): + self.profile_configs = profile_configs + self.para = para + + def profile_launch(self, profile_file_dir): + # 遍历指定目录下的所有文件 + for root, _, files in os.walk(profile_file_dir): + for file in files: + # if '_EP' in file: + # pattern = MOE_PATTERN_REG_SHELL + # else: + # pattern = LLAMA_PATTERN_REG_SHELL + # match = re.match(pattern, file) + # if not match: + # continue + profile_file_path = os.path.join(root, file) + if file.endswith(".toml"): + self.run_torchtitan(profile_file_path) + elif file.endswith(".sh"): + self.run_shell(profile_file_path) + else: + logger.error(f"file type: {file} is not supported.") + + def run_shell(self, profile_file_dir): + cmd = ["bash", profile_file_dir] + logger.info(f"profile command: {cmd}") + + process = subprocess.run( + cmd, + preexec_fn=os.setpgrp, + check=False, + ) + # 为避免profile子进程未结束生成profile文件失败,增加sleep + time.sleep(60) + return_code = process.returncode + logger.info("Last job returns %d.", return_code) + + + def run_torchtitan(self, profile_file_toml): + run_file = self.para.TORCHTITAN_PATH + regex_pattern = GENERAL_TOML.replace('{}', '(\d+)') + match = re.match(regex_pattern, Path(profile_file_toml).name) + if match: + dp = int(match.group(1)) + tp = int(match.group(2)) + pp = int(match.group(3)) + cp = int(match.group(5)) + world_size = dp * tp * pp * cp + else: + raise ValueError(f"Invalid profile_file_toml: 
{profile_file_toml}") + cmd = f'CONFIG_FILE={profile_file_toml} NGPU={world_size} {run_file}' + logger.info(f"run_torchtitan profile command: {cmd}") + + process = subprocess.run( + cmd, + preexec_fn=os.setpgrp, + shell=True, + check=False, + ) + # 为避免profile子进程未结束生成profile文件失败,增加sleep + time.sleep(20) + return_code = process.returncode + logger.info("Last job returns %d.", return_code) diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_parser.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9de91989fb850b8dcf9c7f360c28667dadbbdc2c --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_parser.py @@ -0,0 +1,681 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""profile parser""" +import json +import os +import statistics +import csv +from collections import defaultdict +from fast_tuner.pipeline_conductor.pp_util import parse_shell +from fast_tuner.utils.logger import logger +from fast_tuner.ndsearch.para_for_nd_search import ParaForNd +from fast_tuner.ndsearch.memory_model import compute_weight_and_optimizer_memory +from pathlib import Path + +encoding = 'utf-8' + +def find_file_by_name(directory, filename): + for root, _, files in os.walk(directory): + if filename in files: + return os.path.join(root, filename) + logger.error(f'No such file {filename} in directory {directory}') + return None + +def median_mean(durs): + if len(durs) == 0: + logger.info(f"get durs no values") + return 0 + median_durs = durs[len(durs)//4 : len(durs) - len(durs)//4] + return round(statistics.mean(median_durs)/1000, 3) + +class ProfileParser: + + ''' + The recommend config for different pp is: + pp2: [6,5] recomp [6,0], reduce laysers when OOM, layers at least [5,4] + if still OOM, do another profile wrt [3,1] no recomp w/ dense_replace = 1 + pp4: [3,3,3,2] recomp [3,3,0,0] + pp8: [3,1,1,1,1,2,2,2] recomp [3,1,1,1,1,2,0,2] (or try pp4 with layers [3,2,2,2] w/ recomp [3,2,0,2]) + a typical input is a string and a tuple of numbers. 2-tuple for pp2 and 4-tuple for higher pp, + which consists the rank for each stages. 
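+ For example, config_anal('DP128_TP1_PP2_EP16', (0, 7)) would load the trace of
+ rank 0 for stage 0 and rank 7 for stage 1 (the folder name and rank values here
+ are illustrative).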
+ ''' + + profile_data = None + config = '' + # recomp_bws, moe_fws1, moe_fws2, moe_bws, dense_fws, dense_bws, totals = [], [], [], [], [], [], [] + # moe_fw1, moe_fw2, moe_bw, recomp_bw, dense_fw, dense_bw, head_ratio, num_last_stage_layers, mtp = 0 + dense_fws, dense_bws, recomp_dense_bws, moe_fws, moe_bws, recomp_moe_bws, totals = [], [], [], [], [], [], [] + (dense_fw, dense_bw, recomp_dense_bw, moe_fw, moe_bw, + recomp_moe_bw, total, head_ratio, num_last_stage_layers, mtp) = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + att_tid, grad_att_tid = 2, 2 + + def __init__(self, input_args, para): + self.para = para + self.mbn = input_args.mbn + self.layer_num = input_args.num_layers + self.num_layer_list = input_args.num_layer_list + self.recompute_num_layers = input_args.recompute_num_layers + self.dense_num = input_args.first_k_dense_replace + self.step = input_args.profile_steps + self.model_name = input_args.model_name + + def parse_batch_profile_result(self, profile_configs_results): + # step1. 解析profile文件 + # ep填1, dmratrion:0.33, bfratio=dense_bw/dense_fw, re_grow_ration = 1,hratio=0.68 + # [dp, tp, pp, dmratio, bfratio, re_grow_ration, hratio, moe_fw] + profile_result = [] + dense_flag = False + #Todo 这里para.YAML_PATH的实现待修改 + if self.para.YAML_PATH: + for result in profile_configs_results: + profile_dir = result[-1] + pp = result[2] + self.layer_num = result[-2] + if len(result) > 6: + config = result[:4] + else: + config = result[:3] + config += [0.33, 0, 1, 0.68, 0] + self.config_anal_func(profile_dir, pp, config) + self.refresh() + profile_result.append(config) + elif self.para.SHELL_PATH: + for result in profile_configs_results: + profile_dir = result[-1] + pp = result[2] + update_input_args, _ = parse_shell(result[-2]) + self.mbn = update_input_args.get('GBS') // update_input_args.get('MBS') // update_input_args.get('DP') + self.layer_num = update_input_args.get('NUM_LAYERS') + if 'FIRST_K_DENSE_RAPLACE' in update_input_args: + self.dense_num = update_input_args.get('FIRST_K_DENSE_RAPLACE') + else: + self.dense_num = 3 + logger.info(f"mbn: {self.mbn}, layer_num: {self.layer_num}.") + if len(result) > 8: + config = result[:4] + else: + config = result[:3] + config += [0.33, 0, 1, 0.68, 0] + self.config_anal_func(profile_dir, pp, config) + self.refresh() + profile_result.append(config) + elif self.para.TOML_PATH: + for result in profile_configs_results: + profile_dir = result[-1] + pp = result[2] + config = result[:-2] + config += [0.33, 1, 1, 0.68, 16] + self.config_anal_func(profile_dir, pp, config) + self.refresh() + profile_result.append(config) + + # step2. 
生成csv文件 + # dense模型的result长度是8,moe模型的result长度是9 + if len(profile_result[0]) == 8: + dense_flag = True + self.write_result_to_csv(profile_result, dense_flag) + + # 把profile结果写入csv + def write_result_to_csv(self, profile_result, is_llama=False): + if self.para.YAML_PATH or self.para.SHELL_PATH: + if is_llama: + headers = ['dp', 'tp', 'pp', 'dmratio', 'bfratio', 're_grow_ration', 'hratio', 'moe_fw'] + else: + headers = ['dp', 'tp', 'pp', 'ep', 'dmratio', 'bfratio', 're_grow_ration', 'hratio', 'moe_fw'] + else: + headers = ['dp', 'tp', 'pp', 'ep', 'cp', 'op', 'dmratio', 'bfratio', 're_grow_ration', 'hratio', 'moe_fw'] + # 写入 CSV 文件 + try: + csv_path = os.path.join(os.path.abspath(self.para.OUTPUT_PATH), 'profile_parser_result.csv') + with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + # 写入表头 + writer.writerow(headers) + # 写入数据 + writer.writerows(profile_result) + logger.info(f"CSV file {csv_path} generate succ.") + self.para.args.parser_result = csv_path + except Exception as e: + print(f"write CSV file fail: {e}") + + def load_profile(self, file, rk): + global encoding + path = os.path.abspath('.\\profile_info') + '\\' + str(file) + path = find_file_by_name(path + '\\rank_' + str(rk), 'trace_view.json') + with open(path, 'r', encoding=encoding) as f: + self.profile_data = json.load(f) + + def load_profile_by_dir(self, rank_dir): + global encoding + path = find_file_by_name(rank_dir, 'trace_view.json') + with open(path, 'r', encoding=encoding) as f: + self.profile_data = json.load(f) + + def load_profile_by_kernel(self, file): + global encoding + path = find_file_by_name(file, 'kernel_details.csv') + if path is None: + logger.error(f"can not find kernel details file {file}") + return False + logger.info(f'path of kernel_details is {path}') + with open(path, 'r', encoding=encoding) as f: + reader = csv.DictReader(f) + self.profile_data = list(reader) + return True + + def load_profile_no_recompute(self, file, rk): + global encoding + path = os.path.abspath('.\\profile_info') + '\\' + str(file) + path = find_file_by_name(path + '\\rank_' + str(rk) + '_norecomp' , 'trace_view.json') + with open(path, 'r', encoding=encoding) as f: + self.profile_data = json.load(f) + + def refresh(self): + self.profile_data = [] + (self.dense_fws, self.dense_bws, self.recomp_dense_bws, self.moe_fws, self.moe_bws, + self.recomp_moe_bws, self.totals) = [ + [] for _ in range(7)] + (self.dense_fw, self.dense_bw, self.recomp_dense_bw, self.moe_fw, self.moe_bw, self.recomp_moe_bw, + self.total, self.head_ratio, self.num_last_stage_layers, self.mtp) = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + + def release(self): + self.profile_data = [] + + def set_tid(self, a, g): + self.att_tid, self.grad_att_tid = a, g + + def extract_atts(self): + atts, grad_atts = [], [] + for line in self.profile_data: + name = line['name'] + if line['tid'] == self.att_tid and name.endswith('FlashAttentionScore_FlashAttentionScore'): + atts.append(float(line['ts'])) + if line['tid'] == self.grad_att_tid and name.endswith('FlashAttentionScoreGrad_FlashAttentionScoreGrad'): + grad_atts.append(float(line['ts'])) + logger.info(f'num of atts is {len(atts)}, num of grad_atts is {len(grad_atts)}') + return atts, grad_atts + + def extract_atts_by_kernel(self): + atts, grad_atts = [], [] + try: + with open(self.profile_data, 'r', encoding='utf-8') as file: + reader = csv.DictReader(file) + for row in reader: + name_value = row.get('Name', '') + if isinstance(name_value, str): + if 
name_value.endswith('FlashAttentionScore_FlashAttentionScore'): + ts_value = float(row.get('Start Time(us)')) + atts.append(ts_value) + elif name_value.endswith('FlashAttentionScoreGrad_FlashAttentionScoreGrad'): + ts_value = float(row.get('Start Time(us)')) + grad_atts.append(ts_value) + logger.info(f'extract by kernel_details.csv num of atts is {len(atts)}, num of grad_atts') + return atts, grad_atts + except Exception as e: + logger.error(f'extract by kernel_details.csv fail: {e}') + return [], [] + + def extract_atts_by_loaded_data(self): + atts, grad_atts = [], [] + try: + for row in self.profile_data: + if row['Type'] == 'FlashAttentionScore': + atts.append(float(row['Start Time(us)'])) + if row['Type'] == 'FlashAttentionScoreGrad': + grad_atts.append(float(row['Start Time(us)'])) + logger.info(f'num of atts is {len(atts)}, num of grad_atts is {len(grad_atts)}') + return atts, grad_atts + except Exception as e: + logger.error(f'extract by loaded_data fail: {e}') + return [], [] + + def extract_atts_for_titan(self): + atts, grad_atts = [], [] + try: + for row in self.profile_data: + if row['Name'] == 'aclnnFlashAttentionScore_FlashAttentionScore_FlashAttentionScore': + atts.append(float(row['Start Time(us)'])) + if row['Name'] == 'aclnnFlashAttentionScoreGrad_FlashAttentionScoreGrad_FlashAttentionScoreGrad': + grad_atts.append(float(row['Start Time(us)'])) + logger.info(f'num of atts is {len(atts)}, num of grad_atts is {len(grad_atts)}') + return atts, grad_atts + except Exception as e: + logger.error(f'extract by loaded_data fail: {e}') + return [], [] + + #todo: 待修改 + def stage_anal(self, pp, stage, isRecompute): #assuming we profiled 2 steps w/ micro = 32 + atts, grad_atts = self.extract_atts_by_kernel() + layers, xlayers = len(grad_atts) // 64, 2 * len(grad_atts) // 64 + warm_up = (pp-1-stage) * layers + atts = atts[warm_up : 32*xlayers - warm_up] + atts[32*xlayers + warm_up : 32*xlayers + 32*xlayers-warm_up] + grad_atts = (grad_atts[warm_up : 32*layers - warm_up] + + grad_atts[32*layers + warm_up : 32*layers + 32*layers-warm_up]) + att_chunks = [atts[xlayers*i:xlayers*(i+1)] for i in range(len(atts)//xlayers-1)] + grad_chunks = [grad_atts[layers*i:layers*(i+1)] for i in range(len(grad_atts)//layers-1)] + if pp == 2: + self.pp2_analysis(att_chunks, grad_chunks, layers, pp, stage) + else: + self.pp_not2_analysis(att_chunks, grad_chunks, layers, pp, stage) + + def pp_not2_analysis(self, att_chunks, grad_chunks, layers, pp, stage): + if stage == 0: + for chunk in att_chunks: + for i in range(layers - 1): + if i <= 2: self.dense_fws.append(chunk[i + 1] - chunk[i]) + for chunk in grad_chunks: + for i in range(layers - 3, layers - 1): self.dense_bws.append(chunk[i + 1] - chunk[i]) + elif stage == pp - 3: + for chunk in att_chunks: + for i in range(layers - 1): + self.moe_fws1.append(chunk[i + 1] - chunk[i]) + for chunk in grad_chunks: + for i in range(layers - 3, layers - 1): self.moe_bws.append(chunk[i + 1] - chunk[i]) + elif stage == pp - 2: + for chunk in att_chunks: + for i in range(layers - 1): self.moe_fws2.append(chunk[i + 1] - chunk[i]) + for chunk in grad_chunks: + for i in range(layers - 3, layers - 1): self.moe_bws.append(chunk[i + 1] - chunk[i]) + else: + for i in range(len(att_chunks) - 1): self.totals.append(att_chunks[i + 1][0] - att_chunks[i][0]) + self.num_last_stage_layers = layers + + def pp2_analysis(self, att_chunks, grad_chunks, layers, pp, stage): + if stage == 0: + for chunk in att_chunks: + for i in range(layers - 1): + if i <= 2: + self.dense_fws.append(chunk[i 
+ 1] - chunk[i]) + else: + self.moe_fws1.append(chunk[i + 1] - chunk[i]) + for chunk in grad_chunks: + for i in range(layers - 3, layers - 1): self.dense_bws.append(chunk[i + 1] - chunk[i]) + for i in range(layers - 4): self.recomp_bws.append(chunk[i + 1] - chunk[i]) + for i in range(len(att_chunks) - 1): self.totals.append(att_chunks[i + 1][0] - att_chunks[i][0]) + if stage == pp - 1: + for chunk in att_chunks: + for i in range(layers - 2): self.moe_fws2.append(chunk[i + 1] - chunk[i]) + for chunk in grad_chunks: + for i in range(2, layers - 1): self.moe_bws.append(chunk[i + 1] - chunk[i]) + for i in range(len(att_chunks) - 1): self.totals.append(att_chunks[i + 1][0] - att_chunks[i][0]) + self.num_last_stage_layers = layers + + def stage_anal_for_pp1(self): + atts, grad_atts = self.extract_atts_for_titan() + layer_num = 10 # 与titan的__init__.py中的 + try: + import torchtitan.protocols.train_spec as train_spec_module + model_flavor = 'tune' + train_spec = train_spec_module.get_train_spec(self.model_name) + model_args = train_spec.model_args[model_flavor] + layer_num = model_args.n_layers + except Exception as e: + print(f'Error is: {e}, and do not get layer_num for profile parser.') + + atts_per_step = len(atts) // self.step + grad_atts_per_step = len(grad_atts) // self.step + att_chunks = [atts[atts_per_step * i:atts_per_step * (i + 1)] for i in range(self.step)] + grad_chunks = [grad_atts[grad_atts_per_step * i:grad_atts_per_step * (i + 1)] for i in range(self.step)] + + for chunk in att_chunks: + for i in range(layer_num - 1): + self.moe_fws.append(chunk[i + 1] - chunk[i]) + + for chunk in grad_chunks: + for i in range(layer_num - 1): + self.moe_bws.append(chunk[i + 1] - chunk[i]) + + def stage_anal_for_pp2(self, pp, stage): + atts, grad_atts = self.extract_atts_by_loaded_data() + layer_num = self.num_layer_list[stage] + warm_up = (pp - 1 - stage) * layer_num + atts_per_step = len(atts) // self.step + grad_atts_per_step = len(grad_atts) // self.step + + att_chunks = [atts[atts_per_step * i:atts_per_step * (i + 1)] for i in range(self.step)] + grad_chunks = [grad_atts[grad_atts_per_step * i:grad_atts_per_step * (i + 1)] for i in range(self.step)] + + if pp == 2: + if stage == 0: + for chunk in att_chunks: + for i in range(self.mbn - 1): + idx = warm_up + i * (layer_num + self.recompute_num_layers) + for j in range(self.dense_num): + self.dense_fws.append(chunk[idx + j + 1] - chunk[idx + j]) + + for chunk in grad_chunks: + for i in range(self.mbn - 2): + idx = warm_up + i * layer_num + self.dense_bws.append(chunk[idx + 2] - chunk[idx + 1]) + self.recomp_dense_bws.append(chunk[idx + 3] - chunk[idx + 2]) + + if stage == 1: + for chunk in att_chunks: + for i in range(self.mbn): + idx = warm_up + i * (layer_num + self.recompute_num_layers) + for j in range(layer_num - 1): + self.moe_fws.append(chunk[idx + j + 1] - chunk[idx + j]) + if i > 1: + idx_last = warm_up + (i - 1) * (layer_num + self.recompute_num_layers) + self.totals.append(chunk[idx] - chunk[idx_last]) + + for chunk in grad_chunks: + for i in range(self.mbn): + idx = warm_up + i * layer_num + self.moe_bws.append(chunk[idx + 1] - chunk[idx]) + self.recomp_moe_bws.append(chunk[idx + 2] - chunk[idx + 1]) + self.num_last_stage_layers = layer_num + + def config_anal(self, config, ranks): + self.config = config + pp = len(ranks) + for i in range(pp): + self.load_profile(config, ranks[i]) + self.stage_anal(pp, i) + self.release() + + def config_anal_func(self, profile_dir, pp, profile_result): + self.config = profile_dir + 
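+ # process_folders walks the per-rank profile folders, analyses each pipeline stage,
+ # then overwrites the last five slots of profile_result in place with
+ # [dmratio, bfratio, re_grow_ration, hratio, moe_fw] (see refined_data / fill_result).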
self.process_folders(profile_dir, pp, profile_result) + + def process_folders(self, profile_result_dir, pp, profile_result): + """ + 遍历根目录下的所有文件夹,并使用 parser 函数解析每个文件夹中的内容 + + 参数: + profile_result_dir (str): 根目录路径 + """ + root_path = Path(profile_result_dir) + if not root_path.is_dir(): + logger.info(f"profile_result_dir:{profile_result_dir} not exist") + return + + # 获取所有子文件夹--每个rank的文件夹 + folders = [f for f in root_path.iterdir() if f.is_dir()] + folders.sort() + i = 0 + for rank_folder in folders: + logger.info(f"parsing: {rank_folder}") + # Todo para.YAML_PATH的分支待修改 + if self.para.YAML_PATH: + self.load_profile_by_dir(rank_folder) + self.stage_anal(pp, i, True) + else: + find_file = self.load_profile_by_kernel(rank_folder) + if find_file == False: + continue + if self.para.SHELL_PATH: + self.stage_anal_for_pp2(pp, i) + elif self.para.TOML_PATH: + self.stage_anal_for_pp1() + self.release() + i += 1 + self.refined_data() + self.fill_result(profile_result) + + def data_display(self): + logger.info(f'moe_fws1 are {self.moe_fws1}') + logger.info(f'moe_fws2 are {self.moe_fws2}') + logger.info(f'moe_bws is {self.moe_bws}') + logger.info(f'recomp_bws is {self.recomp_bws}') + logger.info(f'dense_fw is {self.dense_fw}') + logger.info(f'dense_bw is {self.dense_bws}') + logger.info(f'totals is {self.totals}') + + def refined_data(self): + if self.para.YAML_PATH or self.para.SHELL_PATH: + if len(self.dense_fws) == 0 or len(self.dense_bws) == 0: + logger.info(f"dense_fws len {len(self.dense_fws)} " + f"or dense_bws {len(self.dense_fws)} is empty") + return + self.dense_fw = median_mean(self.dense_fws) + self.dense_bw = median_mean(self.dense_bws) + self.recomp_dense_bw = median_mean(self.recomp_dense_bws) + self.moe_fw = median_mean(self.moe_fws) + self.moe_bw = median_mean(self.moe_bws) + self.recomp_moe_bw = median_mean(self.recomp_moe_bws) + self.total = median_mean(self.totals) + self.mtp = (self.total - self.num_last_stage_layers * self.moe_fw - + (self.num_last_stage_layers - self.recompute_num_layers) * self.moe_bw - self.recomp_moe_bw) + logger.info( + f'config is {self.config}, dense forward is {self.dense_fw}, ' + f'dense backward is {self.dense_bw}, recompute dense backward is {self.recomp_dense_bw}, ' + f'moe forward is {self.moe_fw}, moe backward is {self.moe_bw}, ' + f'recompute moe backward is {self.recomp_moe_bw}, lmhead and MTP is {self.mtp}') + elif self.para.TOML_PATH: + if len(self.moe_fws) == 0 or len(self.moe_bws) == 0: + logger.info(f"fws len {len(self.moe_fws)} " + f"or bws {len(self.moe_bws)} is empty") + return + self.moe_fw = median_mean(self.moe_fws) + self.moe_bw = median_mean(self.moe_bws) + logger.info( + f'config is {self.config}, forward is {self.moe_fw}, backward is {self.moe_bw}') + + def fill_result(self, profile_result): + if self.para.YAML_PATH or self.para.SHELL_PATH: + if len(self.dense_fws) == 0 or len(self.dense_bws) == 0: + logger.info(f"config {self.config} parse profile result is empty ") + return + profile_result[-1] = self.moe_fw + profile_result[-2] = self.mtp / self.moe_fw + profile_result[-3] = (self.recomp_moe_bw - self.moe_bw) / self.moe_fw + profile_result[-4] = self.moe_bw / self.moe_fw + profile_result[-5] = self.dense_fw / self.moe_fw + elif self.para.TOML_PATH: + if len(self.moe_fws) == 0 or len(self.moe_bws) == 0: + logger.info(f"config {self.config} parse profile result is empty ") + return + profile_result[-1] = self.moe_fw + profile_result[-2] = 1.5 + profile_result[-3] = 1.0 + profile_result[-4] = self.moe_bw / self.moe_fw + 
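+ # For the Torchtitan path only moe_fw and bfratio (= moe_bw / moe_fw) are measured;
+ # hratio (1.5), re_grow_ration (1.0) and dmratio (0.3, set below) are fixed defaults.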
profile_result[-5] = 0.3 + +class MemoryInfo: + select_mem0 = 76 + select_mem12 = 76 + select_mem = 77 + re_comp_mem0 = 3 + re_comp_mem12 = 3 + re_comp_mem = 5 + act_mem0 = 79 + act_mem12 = 79 + act_mem = 79 + layer_mem012 = 349 + layer_mem = 340 + static_mem0 = 734 + static_mem = 116 + lm_head_mem = 690 + + def write_to_txt(self, memory_file_name): + with open(memory_file_name, 'w', encoding='utf-8') as file: + file.write(f'select_mem0={self.select_mem0}\n') + file.write(f'select_mem12={self.select_mem12}\n') + file.write(f'select_mem={self.select_mem}\n') + file.write(f're_comp_mem0={self.re_comp_mem0}\n') + file.write(f're_comp_mem12={self.re_comp_mem12}\n') + file.write(f're_comp_mem={self.re_comp_mem}\n') + file.write(f'act_mem0={self.act_mem0}\n') + file.write(f'act_mem12={self.act_mem12}\n') + file.write(f'act_mem={self.act_mem}\n') + file.write(f'layer_mem012={self.layer_mem012}\n') + file.write(f'layer_mem={self.layer_mem}\n') + file.write(f'static_mem0={self.static_mem0}\n') + file.write(f'static_mem={self.static_mem}\n') + file.write(f'lm_head_mem={self.lm_head_mem}\n') + logger.info(f'Write memory info to {memory_file_name}') + +class ProfileMemParser: + pipeline_output_file = 'pipeline_output' + memory_info_dir = 'memory_info' + def __init__(self, input_args, para): + self.profile_mem_data = None + self.input_args = input_args + self.para = para + self.peak_memory_usage = 0.0 # MB + self.static_memory_usage = 0.0 # MB + + def write_mem_info_to_txt(self, config, mem_info): + ''' + + Args: + config: [dp, tp, pp, ep, offset, recompute, num_layers] + + Returns: + + ''' + memory_dir = os.path.join(os.getcwd(), self.pipeline_output_file, self.memory_info_dir) + os.makedirs(memory_dir, exist_ok=True) + dense_layer_num = self.input_args.first_k_dense_replace + moe_layer_num = self.input_args.num_layers - dense_layer_num + + dp, tp, pp, ep = config[:4] + mbn = self.input_args.gbs // self.input_args.mbs // dp + filename = ( + f'layers{dense_layer_num}_{moe_layer_num}_micro{mbn}_dp{dp}' + f'_tp{tp}_pp{pp}_vp1' + f'_ep{ep}.txt') + memory_file_name = os.path.join(memory_dir, filename) + mem_info.write_to_txt(memory_file_name) + + def set_static_mem(self, mem_info, profile_file): + self.para.TOML_PATH = profile_file + input_args = ParaForNd(self.para) + dense_size, moe_size, vocab_size = compute_weight_and_optimizer_memory(input_args) + mem_info.layer_mem012 = dense_size + mem_info.layer_mem = moe_size + mem_info.static_mem0 = vocab_size + mem_info.static_mem = 0 + mem_info.lm_head_mem = 1.5 * moe_size + return input_args + + def cal_dynamic_mem(self, mem_info, cur_input_args): + dynamic_mem_total = self.peak_memory_usage - self.static_memory_usage + if cur_input_args.pp == 1: + dynamic_mem_layer = dynamic_mem_total // cur_input_args.num_layers + mem_info.act_mem0 = dynamic_mem_layer + mem_info.act_mem12 = dynamic_mem_layer + mem_info.act_mem = dynamic_mem_layer + + + + def manage_mem_infos(self, need_cumlate=True): + sorted_times = sorted(self.real_memory_changes.keys()) + cumulative_mem = 0 + times = [] + memory_usage = [] + for time in sorted_times: + if need_cumlate: + cumulative_mem += self.real_memory_changes[time] + times.append(time) + memory_usage.append(cumulative_mem) + else: + times.append(time) + memory_usage.append(self.real_memory_changes[time]) + max_index = memory_usage.index(max(memory_usage)) + peak_time = times[max_index] + peak_memory_usage = memory_usage[max_index] + logger.info(f'Peak memory usage: {peak_memory_usage}, peak time: {peak_time}') + self.peak_memory_usage 
= peak_memory_usage / 1024
+
+    def parse_memory_block(self, filename):
+        real_memory_changes = defaultdict(int)
+        static_memory = 0
+        flag = False
+        with open(filename, mode='r', encoding='utf-8') as file:
+            reader = csv.DictReader(file)
+            # real memory
+            for row in reader:
+                alloc_size = row['Size(KB)']
+                if not alloc_size:
+                    logger.debug('Size(KB) is none')
+                    continue
+                size = float(alloc_size)
+
+                alloc_time = row['Allocation Time(us)']
+                if not alloc_time:
+                    logger.debug(f'Allocation Time is none, size is {size}')
+                    continue
+                start = int(alloc_time.strip().split('.')[0])
+
+                release_time = row['Active Release Time(us)']
+                if not release_time:
+                    logger.debug(f'Active Release Time is none, size is {size}')
+                    continue
+                end = int(release_time.strip().split('.')[0])
+
+                real_memory_changes[start] += size  # memory grows at the allocation time
+                real_memory_changes[end - 1] += 0  # for plt
+                real_memory_changes[end] -= size  # and shrinks at the release time
+
+                if not flag:
+                    static_memory = int(row['Allocation Total Allocated(MB)'].strip().split('.')[0])
+
+                    flag = True
+                    real_memory_changes[start] += (static_memory * 1024)
+
+        print(f'static mem is {static_memory} M')
+        self.static_memory_usage = static_memory
+        self.real_memory_changes = real_memory_changes
+        self.manage_mem_infos()
+
+    def parse_memory_file(self, profile_dir):
+        # step1. parse operator_memory.csv
+        root_path = Path(profile_dir)
+        if not root_path.is_dir():
+            logger.info(f"profile_result_dir: {profile_dir} does not exist")
+            return
+        folders = [f for f in root_path.iterdir() if f.is_dir()]
+        folders.sort()
+        for rank_folder in folders:
+            logger.info(f"parsing: {rank_folder}")
+            path = find_file_by_name(rank_folder, 'operator_memory.csv')
+            if not path:
+                logger.error(f"No operator memory csv file in profile_dir {profile_dir} rank_dir {rank_folder}")
+                continue
+            self.parse_memory_block(path)
+
+    def parse_mem(self, profile_dir, profile_file):
+        '''
+        Parse one profiling result and build its memory info.
+
+        Args:
+            profile_dir: path of the profiling result
+            profile_file: config file used for this profiling run
+
+        Returns:
+            MemoryInfo with the static and dynamic estimates filled in
+        '''
+        mem_info = MemoryInfo()
+        self.parse_memory_file(profile_dir)
+        cur_input_args = self.set_static_mem(mem_info, profile_file)
+        self.cal_dynamic_mem(mem_info, cur_input_args)
+        return mem_info
+
+    def mem_parser(self, profile_configs):
+        '''
+        Parse every profiling result and dump its memory info to a txt file.
+
+        Args:
+            profile_configs: [dp, tp, pp, ep, offset, recompute, num_layers, profile_file, profile_result_dir]
+        '''
+        for config in profile_configs:
+            profile_dir = config[-1]
+            profile_file = config[-2]
+            mem_info = self.parse_mem(profile_dir, profile_file)
+            self.write_mem_info_to_txt(config, mem_info)
+        return
diff --git a/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_prepare.py b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_prepare.py
new file mode 100644
index 0000000000000000000000000000000000000000..18caaeadb1f80cb1708f6b9780b7472822a25b41
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/fast_tuner/utils/profiling/profile_prepare.py
@@ -0,0 +1,262 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""profile prepare work"""
+import os
+from fast_tuner.utils.logger import logger
+from fast_tuner.utils.common import generate_files, cal_world_size
+
+
+def find_file_by_name(directory, filename):
+    for root, dirs, files in os.walk(directory):
+        if filename in files:
+            return os.path.join(root, filename)
+
+def decide_pp_dp(config, available_devices):
+    """
+    Trim a candidate config so it can be profiled on the available devices.
+
+    :param config: [dp, tp, pp, ep, cp, op, ...]
+    :param available_devices: number of devices available for profiling
+    :return: [dp, tp, pp, ep, cp, op]
+    """
+    dp, tp, origin_pp, ep, cp, op = config[:6]
+    world_size = dp * tp * origin_pp * cp
+    device_multiple = world_size // available_devices
+    min_pp = 2  # pp is at least 2
+
+    # Check that there are enough devices; dp * tp >= ep must hold
+    if max(min_pp * tp * cp, min_pp * ep) > available_devices:
+        print(f'not enough devices for config: dp: {dp}, tp: {tp}, pp: 2, ep: {ep}')
+
+    # Trim pp
+    if device_multiple >= origin_pp // min_pp:
+        pp = 2
+    else:
+        if device_multiple % 3 == 0:
+            if origin_pp % 3 == 0:
+                pp = origin_pp // 3
+            else:
+                pp = origin_pp // (device_multiple // 3)
+        else:
+            pp = origin_pp // device_multiple
+
+    # Shrink dp to fit the device budget
+    cur_multiple = pp * dp * tp * cp // available_devices
+    if cur_multiple > 1:
+        dp //= cur_multiple
+    return [dp, tp, pp, ep, cp, op]
+
+def decide_dp(config, available_devices):
+    """
+    Fix pp to 2 and shrink dp so the config can be profiled on the available devices.
+
+    :param config: [dp, tp, pp, ep, ...]
+    :param available_devices: number of devices available for profiling
+    :return: [dp, tp, pp, ep]
+    """
+    dp, tp, origin_pp, ep = config[:4]
+    pp = 2  # fix pp to 2
+
+    # Check that there are enough devices; dp * tp >= ep must hold
+    if max(pp * tp, pp * ep) > available_devices:
+        print(f'not enough devices for config: dp: {dp}, tp: {tp}, pp: {pp}, ep: {ep}')
+        return None
+
+    # Shrink dp to fit the device budget
+    cur_multiple = pp * dp * tp // available_devices
+    if cur_multiple > 1:
+        dp //= cur_multiple
+        if dp % ep != 0:
+            print(
+                f'dp_size ({dp}) is not divisible by ep_size({ep}), for config: dp: {dp}, tp: {tp}, pp: {pp}, ep: {ep}')
+            return None
+    return [dp, tp, pp, ep]
+
+def decide_dp_for_titan(config, available_devices):
+    dp, tp, pp, ep, cp, op = config[:6]
+    world_size = dp * tp * pp * cp
+    pp = 1
+    if max(pp * tp, pp * ep) > available_devices:
+        logger.info(f'not enough devices for config: dp: {dp}, tp: {tp}, pp: {pp}, ep: {ep}, cp: {cp}, fsdp: {op}')
+        return None
+    if world_size <= 8:
+        if world_size != available_devices:
+            logger.info(f'Currently, we assume world size {world_size} '
+                        f'equals the number of available devices {available_devices} within a single node')
+            return None
+        else:
+            return config[:6]
+    # Shrink dp (and fsdp) to fit the device budget
+    cur_multiple = pp * dp * cp * tp // available_devices
+    if cur_multiple > 1:
+        if op // cur_multiple == 0:
+            logger.info(f'pruned because of fsdp: ddp: {dp//op}, tp: {tp}, pp: {pp}, ep: {ep}, cp: {cp}, fsdp: {op}')
+            return None
+        op //= cur_multiple
+        dp //= cur_multiple
+    return [dp, tp, pp, ep, cp, op]
+
+def decide_pp_dp_llama(config, available_devices):
+    """
+    Trim pp and dp for dense (LLaMA-style) models so the config fits the available devices.
+
+    :param config: [dp, tp, pp, ...]
+    :param available_devices: number of devices available for profiling
+    :return: [dp, tp, pp]
+    """
+    dp, tp, origin_pp = config[:3]
+    world_size = dp * tp * origin_pp
+    device_multiple = world_size // available_devices
+    min_pp = 1  # pp is at least 1
+
+    # Check that there are enough devices
+    if min_pp * tp > available_devices:
+        print(f'not enough devices for config: dp: {dp}, tp: {tp}, pp: {min_pp}')
+        return None
+
+    # Trim pp
+    if device_multiple >= origin_pp // min_pp:
+        pp = min_pp
+    else:
+        if device_multiple % 3 == 0:
+            if origin_pp % 3 == 0:
+                pp = origin_pp // 3
+            else:
+                pp = origin_pp // (device_multiple // 3)
+        else:
+            pp = origin_pp // device_multiple
+
+    # Shrink dp to fit the device budget
+    cur_multiple = pp * dp * tp // available_devices
+    if cur_multiple > 1:
+        dp //= cur_multiple
+    return [dp, tp, pp]
+
+def trans_config_satisfy_rank_num(mem_prune_space, para, input_args):
+    """
+    Convert the pruned candidates into configs that fit the available profiling devices.
+
+    :param mem_prune_space: [dp, tp, pp, ep, cp, op]
+    :param para: user config para for tool
+    :param input_args: model args and train args from config file
+    :return: list[[dp, tp, pp, ep, cp, op, offset, num_layers]]
+    """
+    profile_configs = []
+    rank_num = min(para.RANK_NUM, input_args.world_size)
+    for config in mem_prune_space:
+        trans_config = budget_profile_config_generator(config, para, rank_num)
+        if trans_config is not None and trans_config not in profile_configs:
+            profile_configs.append(trans_config)
+    logger.info(f"profile configs len: {len(profile_configs)} by {rank_num} devices")
+    # Keep at most 10 candidates: drop pp > 16, then prefer a small tp and cp and a large ep
+    if len(profile_configs) > 10:
+        # pp <= 16
+        profile_configs = [cand for cand in profile_configs if cand[2] <= 16]
+        # sort by tp and cp ascending (smaller is better), ep descending
+        profile_configs = sorted(profile_configs, key=lambda x: (x[1], x[4], -x[3]))
+        profile_configs = profile_configs[:10]
+    print(*(config for config in profile_configs), sep='\n')
+    return profile_configs
+
+def budget_profile_config_generator(config, para, available_devices):
+    """
+    Convert config into a parallel configuration that the current devices can run for profiling:
+    trim pp first (minimum pp is 2), then shrink dp.
+    A single profiling run collects both the no-recompute and the full-recompute information.
+
+    :param config: [dp, tp, pp, ep, cp, op, evaluate_peak_mem]
+    :param para: user config para for tool
+    :param available_devices: number of devices available for profiling
+    :return: [dp, tp, pp, ep, offset, recompute, num_layers], min(pp)=2, layer count is num_layers (+MTP)
+    """
+    if para.YAML_PATH:
+        if len(config) < 4:
+            basic_config = decide_pp_dp_llama(config, available_devices)
+        else:
+            basic_config = decide_pp_dp(config, available_devices)
+        pp = basic_config[2]
+        if len(config) < 4:
+            offset = 0
+            recompute = 0
+            num_layers = 32
+        else:
+            if pp == 2:
+                offset = [1, 0]
+                recompute = [6, 0]
+                num_layers = 11
+            elif pp == 4:
+                offset = [1, 1, 1, 0]
+                recompute = [3, 3, 0, 0]
+                num_layers = 11
+            elif pp == 8:
+                offset = [2, 0, 0, 0, 0, 1, 1, 1]
+                recompute = [3, 1, 1, 1, 1, 2, 0, 2]
+                num_layers = 13
+            else:
+                print(f'pp {pp} not supported')
+                return None
+
+        config_num_layers = num_layers if para.SHELL_PATH else (num_layers - 1)
+        basic_config.extend([offset, recompute, config_num_layers])
+        return basic_config
+    elif para.SHELL_PATH:
+        if len(config) < 4:
+            # TODO: LLaMA-style models are not adapted yet
+            basic_config = decide_pp_dp_llama(config, available_devices)
+        else:
+            basic_config = decide_dp(config, available_devices)
+        if basic_config is None:
+            return None
+        num_layer_list = [4, 3]
+        recompute_num_layers = 1
+        num_layers = 7
+        basic_config.extend([num_layer_list, recompute_num_layers, num_layers])
+        return basic_config
+    elif para.TOML_PATH:
+        basic_config = decide_dp_for_titan(config, available_devices)
+        return basic_config
+    else:
+        print('Error: No config file path!')
+        return None
+
+def taylor_pp_adaptor(profile_info):
+    layer_ratio = (profile_info['dense_fw'] + profile_info['dense_bw']) / (profile_info['moe_fw'] + profile_info['moe_bw'])
+    backward_ratio = (profile_info['moe_bw'] / profile_info['moe_fw'])
+    return layer_ratio, backward_ratio
+
+def sapp_adaptor(profile_info):
+    body_dense = (profile_info['dense_fw'] + profile_info['dense_bw']) / 3
+    body_moe = (profile_info['moe_fw'] + profile_info['moe_bw']) / 3
+    tail = profile_info['head'] / 3
+    return profile_info['embed'], body_dense, body_moe, tail
+
+def profile_prepare(mem_prune_space, para, input_args):
+    """
+    prepare for profiling
+
+    :param para: user input args for tool
+    :param mem_prune_space: [dp, tp, pp, ep, cp, op, evaluate_peak_mem] or [dp, tp, pp]
+    :param input_args: model args and train args from config file
+    :return: ordered [dp, tp, pp, ep, cost]
+    """
+    # Convert each candidate into a config that can be profiled with the available devices
+    profile_configs = trans_config_satisfy_rank_num(mem_prune_space, para, input_args)
+    profile_dir = 'profile_yaml' if para.YAML_PATH else 'profile_shell' if para.SHELL_PATH else 'profile_toml'
+    profile_file_dir = os.path.abspath(para.OUTPUT_PATH) + os.sep + profile_dir + os.sep
+    file_task = "profile_yaml" if para.YAML_PATH else "profile_shell" if para.SHELL_PATH else 'profile_toml'
+    file_path_list, output_dir_list = generate_files(profile_configs, profile_file_dir, file_task, para, input_args)
+
+    result = []
+    for config, shell_file_path, profile_dir in zip(profile_configs, file_path_list, output_dir_list):
+        # Copy the sub-list before appending so the original config list is not modified
+        new_config = config.copy()
+        new_config.append(shell_file_path)
+        new_config.append(profile_dir)
+        result.append(new_config)
+    return result, profile_file_dir
diff --git a/hyper_parallel/auto_parallel/fast-tuner/pyproject.toml b/hyper_parallel/auto_parallel/fast-tuner/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..b8fff1c0f345f2f8bc7962e3c630cc59c8f033ab
--- /dev/null
+++ b/hyper_parallel/auto_parallel/fast-tuner/pyproject.toml
@@ -0,0 +1,71 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "fast-tuner"
+version = "0.0.1dev0"
+description = "Fast Algorithm Supported by Theories. AI Framework hyper-parameter tuner."
+readme = "" +requires-python = ">=3.9" +authors = [ + {name = "Taylor Lab"} +] +maintainers = [ + {name = "Taylor Lab"} +] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "License :: Other/Proprietary License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dependencies = [ + "pyyaml", + "numpy<2.0.0", + "ortools", + "matplotlib", + "chardet", + "highspy", + "toml", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "black>=22.0", + "flake8>=4.0", + "mypy>=0.900", +] + +[project.urls] +Homepage = "http://example.com" +Repository = "http://example.com" +Issues = "http://example.com" + +[project.scripts] +fast-tuner-parallel = "fast_tuner.parallel_tool:main" + +[tool.setuptools.packages.find] +where = ["."] +include = ["fast_tuner*"] + +[tool.black] +line-length = 120 +target-version = ['py39'] + +[tool.mypy] +python_version = "3.9" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] diff --git a/hyper_parallel/auto_parallel/fast-tuner/test/units/nd_search_unit_test.py b/hyper_parallel/auto_parallel/fast-tuner/test/units/nd_search_unit_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dbee4a279c11aed266cb2deba75fc6ebd5fbaa8c --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/test/units/nd_search_unit_test.py @@ -0,0 +1,109 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import unittest + +from fast_tuner.ndsearch.build_initial_spaces import build_initial_spaces +from fast_tuner.ndsearch.expert_filter_configs import expert_filter_configs +from fast_tuner.ndsearch.memory_model import grey_box_memory_prune +from fast_tuner.utils.input_config import InputConfig +from fast_tuner.utils.profiling.profile_parser import ProfileParser + + +# ND搜索单元测试用例 +class NdTestCase(unittest.TestCase): + # 测试根据dryrun的内存信息对候选配置剪枝功能 + def test_memory_prune(self): + # input construct + input_args = InputConfig('../../config/pretrain_deepseek3_671b.yaml') + dryrun_info_init = [ + [32, 2, 16, 8, 142786], + [128, 2, 4, 8, 188944], + [16, 2, 32, 4, 153480], + [16, 8, 8, 8, 96950], + [64, 8, 2, 8, 825586], + [64, 2, 8, 4, 221281], + [8, 4, 32, 8, 81943], + [64, 4, 4, 4, 283005], + [64, 1, 16, 8, 241679], + [32, 8, 4, 4, 270048], + [8, 4, 32, 4, 98967], + [256, 2, 2, 4, 1656206], + [64, 4, 4, 8, 162045], + [8, 8, 16, 8, 68632], + [4, 8, 32, 8, 54686], + [128, 1, 8, 4, 299846], + [8, 8, 16, 4, 101784], + [32, 8, 4, 8, 148595], + [256, 1, 4, 4, 363705], + [32, 4, 8, 8, 116591], + [4, 8, 32, 4, 71710], + [32, 4, 8, 4, 181999], + [32, 2, 16, 4, 175938], + [16, 2, 32, 8, 136456], + [64, 8, 2, 4, 1635569], + [64, 2, 8, 8, 155874], + [128, 1, 8, 8, 234458], + [16, 4, 16, 8, 93350], + [256, 2, 2, 8, 847944], + [128, 4, 2, 8, 832465], + [16, 8, 8, 4, 162358], + [32, 1, 32, 8, 245502], + [512, 1, 2, 8, 927251], + [64, 1, 16, 4, 274811], + [16, 4, 16, 4, 126502], + [128, 2, 4, 4, 309904], + [32, 1, 32, 4, 262506], + [256, 1, 4, 8, 242746], + [128, 4, 2, 4, 1642448], + [512, 1, 2, 4, 1687161] + ] + test_ep = [4, 8] + max_expert_parallel = 64 + # execute prune + memory_aware_configs = grey_box_memory_prune(input_args, dryrun_info_init, test_ep, max_expert_parallel) + print(f"Dryrun Search space len: {len(memory_aware_configs)}, format: [dp, tp, pp, ep, evaluate_peak_mem]") + # output check + assert len(memory_aware_configs) == 10 + for sub_array in memory_aware_configs: + assert len(sub_array) == 5 + assert sub_array[3] <= 64 + + + def test_initial_expert_space(self): + input_args = InputConfig('../../config/pretrain_deepseek3_671b.yaml') + initial_configs = build_initial_spaces(input_args) + # 生成搜索空间test + assert len(initial_configs) == 576030 + + # 专家剪枝test + expert_prune_search_space = expert_filter_configs(initial_configs, input_args, 1024) + assert len(expert_prune_search_space) == 8753 + + # dryrun yaml生成test + + # parser test + def test_profile_parser(self): + test = ProfileParser() + test.config_anal('dp8tp4pp4ep32', [0, 32, 64, 127]) + test.refined_data() + test.refresh() + test.config_anal('dp8tp4pp2ep32', [0, 127]) + test.refined_data() + +if __name__ == "__main__": + suite = unittest.TestLoader().loadTestsFromTestCase(NdTestCase) + runner = unittest.TextTestRunner(verbosity=2) + runner.run(suite) diff --git a/hyper_parallel/auto_parallel/fast-tuner/test/units/pipeline_alg_unit_test.py b/hyper_parallel/auto_parallel/fast-tuner/test/units/pipeline_alg_unit_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0b9489c8a7bf987ecb563dbf0d1852aa5b4207 --- /dev/null +++ b/hyper_parallel/auto_parallel/fast-tuner/test/units/pipeline_alg_unit_test.py @@ -0,0 +1,248 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import sys +import os +from pathlib import Path +sys.path.append(os.getcwd()) +import unittest + +from fast_tuner.pipeline_conductor import pipeline_parallel +from fast_tuner.pipeline_conductor.start_service import InitConfig, ExpertInput + +#流水线负载均衡算法单元测试用例 +class PipeTestCase(unittest.TestCase): + + def get_yaml_path(self, filename): + yaml_file_path = Path(__file__).resolve().parents[2] / 'config' / filename + return yaml_file_path + + def test_micro64_dp16_tp4_pp16_vp1_ep8(self): + yaml_file = self.get_yaml_path('test_62_16_1_64_1_0.yaml') + mind_former_file = '' + expert_input = ExpertInput(yaml_file, mind_former_file) + expert_input.solver_name = 'HIGHS' + expert_input.is_dryrun = False + expert_input.backward_ratio = 1.839 + expert_input.layer_ratio = 0.397 + expert_input.head_loss = 1.5 + expert_input.recompute_ratio = 1 + model_input = InitConfig(expert_input) + model_input.memory.select_mem0 = 400.0 + model_input.memory.select_mem12 = 400.0 + model_input.memory.select_mem = 765.0 + model_input.memory.re_comp_mem0 = 30.0 + model_input.memory.re_comp_mem12 = 30.0 + model_input.memory.re_comp_mem = 30.0 + model_input.memory.act_mem0 = 479.0 + model_input.memory.act_mem12 = 479.0 + model_input.memory.act_mem = 765.0 + model_input.memory.layer_mem012 = 866.0 + model_input.memory.layer_mem = 10691.0 + model_input.memory.static_mem0 = 3383.0 + model_input.memory.static_mem = 2382.0 + model_input.memory.lm_head_mem = 8675.0 + model_input.memory.mem_lim_stage0 = 51913.0 + model_input.memory.mem_lim_others = 52914.0 + model_input.memory.mem_lim_last = 46621.0 + model_input.memory.print_mem() + cur_solution = pipeline_parallel.solve_problem(model_input) + object_value = 1132.66 + self.assertLessEqual(cur_solution.object_value, object_value + 0.01) + self.assertGreaterEqual(cur_solution.object_value, object_value - 0.01) + + def test768_4k_gp(self): + yaml_file = self.get_yaml_path('768die4k_gp.yaml') + mind_former_file = '~/mindformers/run_mindformer.py' + expert_input = ExpertInput(yaml_file, mind_former_file) + expert_input.solver_name = 'HIGHS' + expert_input.is_dryrun = False + expert_input.is_double_check = False + expert_input.time_limit = 1 * 60 + expert_input.layer_ratio = 0.71 + expert_input.backward_ratio = 1.627 + expert_input.recompute_ratio = 0.246 + expert_input.head_loss = 1.493 + model_input = InitConfig(expert_input) + model_input.memory.select_mem0 = 416.0 + model_input.memory.select_mem12 = 416.0 + model_input.memory.select_mem = 674.0 + model_input.memory.re_comp_mem0 = 46.0 + model_input.memory.re_comp_mem12 = 46.0 + model_input.memory.re_comp_mem = 158.0 + model_input.memory.act_mem0 = 494.0 + model_input.memory.act_mem12 = 494.0 + model_input.memory.act_mem = 673.0 + model_input.memory.layer_mem012 = 1010.0 + model_input.memory.layer_mem = 2644.0 + model_input.memory.static_mem0 = 4243.0 + model_input.memory.static_mem = 2279.0 + model_input.memory.lm_head_mem = 8468.0 + model_input.memory.mem_lim_stage0 = 53101.0 + model_input.memory.mem_lim_others = 55065.0 + 
model_input.memory.mem_lim_last = 48876.0 + model_input.memory.print_mem() + cur_solution = pipeline_parallel.solve_problem(model_input) + object_value = 3603.39 + self.assertLessEqual(cur_solution.object_value, object_value + 0.01) + self.assertGreaterEqual(cur_solution.object_value, object_value - 0.01) + + def test512_8k(self): + yaml_file = self.get_yaml_path('layers62_micro960_dp4_tp8_pp16_vp1_ep32.yaml') + mind_former_file = '' + expert_input = ExpertInput(yaml_file, mind_former_file) + expert_input.solver_name = 'HIGHS' + expert_input.is_dryrun = False + expert_input.is_double_check = False + expert_input.time_limit = 3 * 60 + expert_input.layer_ratio = 0.71 + expert_input.backward_ratio = 1.627 + expert_input.recompute_ratio = 0.246 + expert_input.head_loss = 1.493 + model_input = InitConfig(expert_input) + model_input.memory.select_mem0 = 480.0 + model_input.memory.select_mem12 = 480.0 + model_input.memory.select_mem = 737.0 + model_input.memory.re_comp_mem0 = 78.0 + model_input.memory.re_comp_mem12 = 78.0 + model_input.memory.re_comp_mem = 190.0 + model_input.memory.act_mem0 = 614.0 + model_input.memory.act_mem12 = 614.0 + model_input.memory.act_mem = 738.0 + model_input.memory.layer_mem012 = 130.0 + model_input.memory.layer_mem = 7230.0 + model_input.memory.static_mem0 = 5789.0 + model_input.memory.static_mem = 3210.0 + model_input.memory.lm_head_mem = 7939.0 + model_input.memory.mem_lim_stage0 = 51555.0 + model_input.memory.mem_lim_others = 54134.0 + model_input.memory.mem_lim_last = 49405.0 + model_input.memory.print_mem() + cur_solution = pipeline_parallel.solve_problem(model_input) + object_value = 10943.91 + self.assertLessEqual(cur_solution.object_value, object_value + 0.01) + self.assertGreaterEqual(cur_solution.object_value, object_value - 0.01) + + def test512_8k_no_swap(self): + yaml_file = self.get_yaml_path('512_8k_no_swap.yaml') + mind_former_file = '~/mindformers/run_mindformer.py' + expert_input = ExpertInput(yaml_file, mind_former_file) + expert_input.solver_name = 'HIGHS' + expert_input.is_dryrun = False + expert_input.is_double_check = False + expert_input.time_limit = 5 * 60 + expert_input.layer_ratio = 0.71 + expert_input.backward_ratio = 1.627 + expert_input.recompute_ratio = 0.246 + expert_input.head_loss = 1.493 + model_input = InitConfig(expert_input) + model_input.memory.select_mem0 = 832.0 + model_input.memory.select_mem12 = 832.0 + model_input.memory.select_mem = 1347.0 + model_input.memory.re_comp_mem0 = 92.0 + model_input.memory.re_comp_mem12 = 92.0 + model_input.memory.re_comp_mem = 316.0 + model_input.memory.act_mem0 = 988.0 + model_input.memory.act_mem12 = 988.0 + model_input.memory.act_mem = 1346.0 + model_input.memory.layer_mem012 = 1109.0 + model_input.memory.layer_mem = 3670.0 + model_input.memory.static_mem0 = 5618.0 + model_input.memory.static_mem = 2667.0 + model_input.memory.lm_head_mem = 12405.0 + model_input.memory.mem_lim_stage0 = 49678.0 + model_input.memory.mem_lim_others = 52629.0 + model_input.memory.mem_lim_last = 42891.0 + model_input.memory.print_mem() + cur_solution = pipeline_parallel.solve_problem(model_input) + object_value = 5661.55 + self.assertLessEqual(cur_solution.object_value, object_value + 0.01) + self.assertGreaterEqual(cur_solution.object_value, object_value - 0.01) + + def test512_4k(self): + yaml_file = self.get_yaml_path('layers62_micro60_dp32_tp4_pp4_vp1_ep128.yaml') + mind_former_file = '~/mindformers/run_mindformer.py' + expert_input = ExpertInput(yaml_file, mind_former_file) + expert_input.solver_name = 
'HIGHS' + expert_input.is_dryrun = False + expert_input.is_double_check = False + expert_input.time_limit = 3 * 60 + expert_input.backward_ratio = 1.627 + expert_input.layer_ratio = 0.71 + expert_input.head_loss = 1.5 + expert_input.recompute_ratio = 0.71 + model_input = InitConfig(expert_input) + model_input.memory.select_mem0 = 415.0 + model_input.memory.select_mem12 = 415.0 + model_input.memory.select_mem = 672.0 + model_input.memory.re_comp_mem0 = 30.0 + model_input.memory.re_comp_mem12 = 30.0 + model_input.memory.re_comp_mem = 142.0 + model_input.memory.act_mem0 = 493.0 + model_input.memory.act_mem12 = 493.0 + model_input.memory.act_mem = 673.0 + model_input.memory.layer_mem012 = 869.0 + model_input.memory.layer_mem = 2111.0 + model_input.memory.static_mem0 = 6704.0 + model_input.memory.static_mem = 2377.0 + model_input.memory.lm_head_mem = 7986.0 + model_input.memory.mem_lim_stage0 = 48592.0 + model_input.memory.mem_lim_others = 52919.0 + model_input.memory.mem_lim_last = 47310.0 + model_input.memory.print_mem() + cur_solution = pipeline_parallel.solve_problem(model_input) + object_value = 5868.49 + self.assertLessEqual(cur_solution.object_value, object_value + 0.01) + self.assertGreaterEqual(cur_solution.object_value, object_value - 0.01) + + # 此用例运行时间较长,大概40分钟 + def test512_8k_swap(self): + yaml_file = self.get_yaml_path('512_8k_swap.yaml') + mind_former_file = '~/mindformers/run_mindformer.py' + expert_input = ExpertInput(yaml_file, mind_former_file) + expert_input.solver_name = 'HIGHS' + expert_input.is_dryrun = False + expert_input.is_double_check = False + expert_input.time_limit = 1 * 60 + expert_input.layer_ratio = 0.71 + expert_input.backward_ratio = 1.627 + expert_input.recompute_ratio = 0.246 + expert_input.head_loss = 1.493 + model_input = InitConfig(expert_input) + model_input.memory.select_mem0 = 832.0 + model_input.memory.select_mem12 = 832.0 + model_input.memory.select_mem = 1347.0 + model_input.memory.re_comp_mem0 = 92.0 + model_input.memory.re_comp_mem12 = 92.0 + model_input.memory.re_comp_mem = 316.0 + model_input.memory.act_mem0 = 988.0 + model_input.memory.act_mem12 = 988.0 + model_input.memory.act_mem = 1346.0 + model_input.memory.layer_mem012 = 1109.0 + model_input.memory.layer_mem = 3670.0 + model_input.memory.static_mem0 = 5438.0 + model_input.memory.static_mem = 2664.0 + model_input.memory.lm_head_mem = 12402.0 + model_input.memory.mem_lim_stage0 = 49858.0 + model_input.memory.mem_lim_others = 52632.0 + model_input.memory.mem_lim_last = 42894.0 + model_input.memory.print_mem() + cur_solution = pipeline_parallel.solve_problem(model_input) + object_value = 5562.66 + self.assertLessEqual(cur_solution.object_value, object_value + 20) + self.assertGreaterEqual(cur_solution.object_value, object_value - 20) + +if __name__ == '__main__': + unittest.main() diff --git a/setup.py b/setup.py index 49290d95acf748bf35c052955a73e5b66734c7e6..112170c3ae3cec1f01247fdac3cfb015b54598a6 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ def get_description(): os_info = get_platform() cpu_info = platform.machine().strip() - return 'hyper_parallel platform: %s, cpu: %s' % (os_info, cpu_info) + return f'hyper_parallel platform: {os_info}, cpu: {cpu_info}' def get_install_requires(): @@ -69,7 +69,7 @@ def get_install_requires(): Returns: list, list of dependent packages. 
""" - with open('requirements.txt') as file: + with open('requirements.txt', encoding='utf-8') as file: return file.read().strip().splitlines() @@ -151,7 +151,9 @@ if __name__ == '__main__': long_description=get_readme_content(), long_description_content_type="text/markdown", test_suite="tests", - packages=find_packages(exclude=["*tests*"]), + packages=find_packages(exclude=["*tests*", + "hyper_parallel.auto_parallel.fast-tuner", + "hyper_parallel.auto_parallel.fast-tuner.*"]), platforms=[get_platform()], include_package_data=True, package_data=package_data,