From 7c9a27d14e62d6e924675da7c767d79dc2c798b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AD=8F=E7=90=A2=E8=89=BA?= Date: Tue, 30 Dec 2025 10:45:45 +0800 Subject: [PATCH] mint update --- .../test_cross_entropy_loss/__init__.py | 0 .../test_cross_entropy_loss/data_gen_utils.py | 87 +++++++++++ .../run_cross_entropy_loss.py | 98 +++++++++++++ .../test_cross_entropy_loss.py | 136 ++++++++++++++++++ 4 files changed, 321 insertions(+) create mode 100644 tests/st/test_ut/test_pynative/test_cross_entropy_loss/__init__.py create mode 100644 tests/st/test_ut/test_pynative/test_cross_entropy_loss/data_gen_utils.py create mode 100644 tests/st/test_ut/test_pynative/test_cross_entropy_loss/run_cross_entropy_loss.py create mode 100644 tests/st/test_ut/test_pynative/test_cross_entropy_loss/test_cross_entropy_loss.py diff --git a/tests/st/test_ut/test_pynative/test_cross_entropy_loss/__init__.py b/tests/st/test_ut/test_pynative/test_cross_entropy_loss/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/st/test_ut/test_pynative/test_cross_entropy_loss/data_gen_utils.py b/tests/st/test_ut/test_pynative/test_cross_entropy_loss/data_gen_utils.py new file mode 100644 index 000000000..0ae79fbcf --- /dev/null +++ b/tests/st/test_ut/test_pynative/test_cross_entropy_loss/data_gen_utils.py @@ -0,0 +1,87 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================
"""Data generation utilities for pynative CrossEntropyLoss tests with random data."""
import copy

import numpy as np

import mindspore as ms

from mindformers.parallel_core.training_graph.loss_func import CrossEntropyLoss
from mindformers.parallel_core.transformer_config import default_transformer_config


def get_init_params(batch_size, seq_length, vocab_size):
    """Generate deterministic random inputs for the CrossEntropyLoss tests.

    Args:
        batch_size (int): number of sequences per batch.
        seq_length (int): number of tokens per sequence.
        vocab_size (int): vocabulary size (number of logit classes).

    Returns:
        dict: {"logits": float32 (batch*seq, vocab),
               "target": int32 (batch*seq,),
               "input_mask": float32 (batch*seq,) with 0/1 entries}.
    """
    # Fixed seed so every consumer (pynative run, graph run, CPU golden)
    # sees identical data.
    np.random.seed(42)

    num_tokens = batch_size * seq_length
    # Small-magnitude logits keep the softmax well-conditioned for comparison.
    logits = 0.01 * np.random.randn(num_tokens, vocab_size).astype(np.float32)

    target = np.random.randint(0, vocab_size, size=(num_tokens,)).astype(np.int32)
    input_mask = np.random.randint(0, 2, size=(num_tokens,)).astype(np.float32)

    # Guarantee at least one valid token; otherwise the loss denominator
    # would be (nearly) zero.
    if input_mask.size > 0 and np.sum(input_mask) == 0:
        input_mask[0] = 1.0

    return {
        "logits": logits,
        "target": target,
        "input_mask": input_mask,
    }


def get_static_output(logits, target, input_mask):
    """Run CrossEntropyLoss in GRAPH_MODE and return loss parts plus gradient.

    Args:
        logits (ms.Tensor): float32 (tokens, vocab) logits.
        target (ms.Tensor): int32 (tokens,) class indices.
        input_mask (ms.Tensor): (tokens,) validity mask.

    Returns:
        dict: "numerator", "denominator", "loss" (= numerator / denominator)
        and "grad" (gradient w.r.t. ``logits``).
    """
    ms.set_context(mode=0)  # 0 == GRAPH_MODE
    # Work on a copy: mutating the shared module-level default config would
    # leak ``calculate_per_token_loss=True`` into every other test that
    # imports ``default_transformer_config``.
    config = copy.deepcopy(default_transformer_config)
    config.calculate_per_token_loss = True
    net = CrossEntropyLoss(config)
    grad_fn = ms.value_and_grad(net, grad_position=0)
    result, grad = grad_fn(logits, target, input_mask)
    numerator, denominator = result
    loss = numerator / denominator
    return {
        "numerator": numerator,
        "denominator": denominator,
        "loss": loss,
        "grad": grad
    }


def get_cpu_output(logits, target, input_mask):
    """Compute the reference (numpy) cross-entropy loss and gradient on CPU.

    Args:
        logits (np.ndarray): float32 (tokens, vocab) logits.
        target (np.ndarray): int (tokens,) class indices.
        input_mask (np.ndarray): float (tokens,); 1.0 marks valid tokens.

    Returns:
        dict: "numerator", "denominator" (mask sum + 1e-8), "loss" and "grad"
        (d(loss)/d(logits), shape (tokens, vocab)).
    """
    # Forward: numerically stable log-softmax via the max-subtraction trick.
    logit_max = np.max(logits, 1, keepdims=True)
    logit_sub = logits - logit_max
    logit_exp = np.exp(logit_sub)
    exp_sum = np.sum(logit_exp, -1, keepdims=True)
    log_exp_sum = np.log(exp_sum)
    logit_neg_logsoftmax = log_exp_sum - logit_sub
    # Per-token negative log-likelihood of the target class.
    loss_reduce = logit_neg_logsoftmax[np.arange(logits.shape[0]), target]
    numerator = (loss_reduce * input_mask).sum()
    denominator = input_mask.sum() + 1.e-8  # epsilon guards an all-zero mask
    loss = numerator / denominator
    # Backward: d(loss)/d(logits) = (softmax - onehot(target)) * mask / sum(mask).
    # NOTE(review): the 1e-8 epsilon is omitted here, so the gradient is not
    # the exact derivative of ``loss`` above; the difference is far below the
    # comparison tolerance, but confirm this is intentional.
    dout_reduce = input_mask / input_mask.sum()
    logits_softmax = logit_exp / exp_sum
    logits_softmax[np.arange(logits.shape[0]), target] -= 1
    grad = logits_softmax * dout_reduce.reshape(-1, 1)
    return {
        "numerator": numerator,
        "denominator": denominator,
        "loss": loss,
        "grad": grad
    }

# ----------------------------------------------------------------------------
# (patch continues: run_cross_entropy_loss.py, preceded by the same Apache
#  License 2.0 header as above)
# ============================================================================
"""Run CrossEntropyLoss accuracy test with configurable parameters via args."""
import os
import argparse
from pathlib import Path
import numpy as np

from data_gen_utils import get_init_params, get_cpu_output, get_static_output

import mindspore as ms

from mindformers.pynative.loss import CrossEntropyLoss


SCRIPT_DIR = Path(__file__).parent.resolve()


class CrossEntropyLossRunner:
    """Builds test data, runs pynative CrossEntropyLoss and saves all outputs."""

    def __init__(self, args_from_parser):
        """Generate inputs once and precompute the CPU (numpy) reference.

        Args:
            args_from_parser (argparse.Namespace): parsed CLI arguments with
                vocab_size, batch_size, seq_length, output_path and
                calculate_per_token_loss.
        """
        self.args = args_from_parser
        self.calculate_per_token_loss = self.args.calculate_per_token_loss

        self.vocab_size = self.args.vocab_size
        self.batch_size = self.args.batch_size
        self.seq_length = self.args.seq_length

        # Unpack by explicit key instead of relying on dict insertion order.
        params = get_init_params(self.batch_size, self.seq_length, self.vocab_size)
        logits = params["logits"]
        target = params["target"]
        input_mask = params["input_mask"]
        self.output_cpu = get_cpu_output(logits, target, input_mask)

        # Arrays are already flat (batch * seq,); the original
        # reshape((b, s)).reshape(-1) round-trip was a no-op and is dropped.
        self.logits = ms.Tensor(logits, dtype=ms.float32)
        self.target = ms.Tensor(target, dtype=ms.int32)
        # NOTE(review): the CPU golden path uses a float32 mask but int32 is
        # kept here to preserve the original behaviour — confirm the loss op
        # accepts an integer mask, or align the dtypes.
        self.input_mask = ms.Tensor(input_mask, dtype=ms.int32)

    def run(self):
        """Run pynative and graph paths, then save pynative/static/CPU .npz files."""
        ms.set_context(mode=1)  # 1 == PYNATIVE_MODE
        net = CrossEntropyLoss(
            calculate_per_token_loss=self.calculate_per_token_loss,
        )

        grad_fn = ms.value_and_grad(net, grad_position=0)
        result, grad = grad_fn(self.logits, self.target, self.input_mask)

        output_pynative = {}
        if self.calculate_per_token_loss:
            # Per-token mode returns the (numerator, denominator) pair.
            numerator, denominator = result
            output_pynative["numerator"] = numerator
            output_pynative["denominator"] = denominator
        else:
            output_pynative["loss"] = result
        output_pynative["grad"] = grad
        output_pynative = {
            k: v.asnumpy().astype(np.float32)
            for k, v in output_pynative.items()
            if v is not None
        }

        output_static = get_static_output(self.logits, self.target, self.input_mask)

        output_path = self.args.output_path
        np.savez(os.path.join(output_path, "output_pynative_loss.npz"), **output_pynative)
        np.savez(os.path.join(output_path, "output_static_loss.npz"), **output_static)
        np.savez(os.path.join(output_path, "output_cpu_loss.npz"), **self.output_cpu)


def main():
    """Parse CLI arguments, fix seeds, and run the accuracy check."""
    parser = argparse.ArgumentParser(description="Run CrossEntropyLoss test")
    parser.add_argument("--vocab_size", type=int, default=1024)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--seq_length", type=int, default=8)
    parser.add_argument("--output_path", type=str, default="./")
    # "true"/"false" (case-insensitive); any other value parses as False.
    parser.add_argument("--calculate_per_token_loss",
                        type=lambda x: x.lower() == "true", default="false")

    args = parser.parse_args()

    # Deterministic kernels + fixed seed so reruns are bit-comparable.
    ms.context.set_context(deterministic="ON")
    ms.set_seed(42)

    runner = CrossEntropyLossRunner(args)
    runner.run()


if __name__ == "__main__":
    main()

# ----------------------------------------------------------------------------
# (patch continues: test_cross_entropy_loss.py, preceded by the same Apache
#  License 2.0 header as above)
# ============================================================================
"""Test CrossEntropyLoss with various configurations."""

from pathlib import Path
import subprocess
import pytest
import numpy as np
from tests.utils.double_benchmark import DoubleBenchmarkStandard, DoubleBenchmarkComparator


TOTAL_VOCAB_SIZE = 1024
BATCH_SIZE = 4
SEQ_LENGTH = 8

SINGLE_CARD_TEST_PARAM = "model_args, data_keys"
SINGLE_CARD_TEST_CASES = [
    # Test Case 1: Single Card, calculate_per_token=False
    (
        {"calculate_per_token_loss": False},
        {"loss": "loss", "grad": "grad"},
    ),
    # Test Case 2: Single Card, calculate_per_token=True
    (
        {"calculate_per_token_loss": True},
        {"numerator": "numerator", "denominator": "denominator"},
    ),
]


def build_msrun_command_list(
        run_script_path,
        vocab_size,
        batch_size,
        seq_length,
        calculate_per_token_loss,
        output_path_param,
    ):
    """Build the single-process ``python`` command for the CrossEntropyLoss run script.

    (Despite the historical name, this builds a plain ``python`` invocation,
    not an ``msrun`` one — the test runs on a single card.)

    Args:
        run_script_path (Path): path to run_cross_entropy_loss.py.
        vocab_size (int): vocabulary size to test with.
        batch_size (int): batch size to test with.
        seq_length (int): sequence length to test with.
        calculate_per_token_loss (bool): forwarded as "true"/"false".
        output_path_param: directory where the run script writes its .npz files.

    Returns:
        list[str]: argv list suitable for ``subprocess.run``.
    """
    cmd_list = [
        "python",
        str(run_script_path),
        f"--vocab_size={vocab_size}",
        f"--batch_size={batch_size}",
        f"--seq_length={seq_length}",
        f"--calculate_per_token_loss={str(calculate_per_token_loss).lower()}",
        f"--output_path={output_path_param}",
    ]
    print(f"Equivalent shell command for debugging (approximate): {' '.join(cmd_list)}")
    return cmd_list


class TestCrossEntropyLoss:
    """Test class for CrossEntropyLoss with different configurations."""

    OUTPUT_PYNATIVE_FILENAME = "output_pynative_loss.npz"
    OUTPUT_STATIC_FILENAME = "output_static_loss.npz"
    OUTPUT_CPU_FILENAME = "output_cpu_loss.npz"

    def setup_method(self):
        """Resolve the path of the companion run script next to this file."""
        self.sh_path = Path(__file__).parent.resolve()
        self.run_script_path = self.sh_path / "run_cross_entropy_loss.py"

    def check_acc(self, output_pynative_dict, output_static_dict, output_cpu_dict, data_keys):
        """Compare pynative (NPU) output against static-graph and CPU references.

        ``data_keys`` maps the key in the pynative dict to the key used by
        both reference dicts (static = golden, cpu = "gpu" slot).
        """
        standard = DoubleBenchmarkStandard(dtype="float32")

        for key, data_key in data_keys.items():
            assert key in output_pynative_dict, f"Key '{key}' not found in MindSpore output."
            assert data_key in output_cpu_dict, f"Golden data key '{data_key}' not found."
            # Guard the static dict too — a missing key would silently pass
            # ``None`` into the comparator.
            assert data_key in output_static_dict, (
                f"Static data key '{data_key}' not found."
            )
            npu_data = output_pynative_dict.get(key)
            golden_data = output_static_dict.get(data_key)
            gpu_data = output_cpu_dict.get(data_key)

            DoubleBenchmarkComparator.check_pass_or_not(
                npu_data=npu_data, gpu_data=gpu_data, golden_data=golden_data, standard=standard
            )

    def run_test(
            self,
            model_args,
            data_keys,
            tmp_path,
        ):
        """Run the script in a subprocess, load its .npz outputs and compare."""
        output_file_path = tmp_path

        cmd_list = build_msrun_command_list(
            run_script_path=self.run_script_path,
            vocab_size=TOTAL_VOCAB_SIZE,
            batch_size=BATCH_SIZE,
            seq_length=SEQ_LENGTH,
            calculate_per_token_loss=model_args["calculate_per_token_loss"],
            output_path_param=output_file_path,
        )

        # shell=False + argv list: no shell-injection surface.
        result = subprocess.run(cmd_list, shell=False, capture_output=True, text=True, check=False)

        assert result.returncode == 0, (
            f"Test script failed with non-zero exit code: "
            f"{result.returncode}.\nStdout:\n{result.stdout}\nStderr:\n{result.stderr}"
        )
        # Check the actual output files: the original asserted only that the
        # pytest-provided tmp_path directory existed, which is always true.
        for filename in (
                self.OUTPUT_PYNATIVE_FILENAME,
                self.OUTPUT_STATIC_FILENAME,
                self.OUTPUT_CPU_FILENAME,
        ):
            assert (tmp_path / filename).exists(), (
                f"Output file {tmp_path / filename} was not created."
            )
        output_pynative_dict = np.load(tmp_path / self.OUTPUT_PYNATIVE_FILENAME)
        output_static_dict = np.load(tmp_path / self.OUTPUT_STATIC_FILENAME)
        output_cpu_dict = np.load(tmp_path / self.OUTPUT_CPU_FILENAME)
        self.check_acc(output_pynative_dict, output_static_dict, output_cpu_dict, data_keys)

    @pytest.mark.level0
    @pytest.mark.platform_arm_ascend910b_training
    @pytest.mark.env_onecard
    @pytest.mark.parametrize(
        SINGLE_CARD_TEST_PARAM,
        SINGLE_CARD_TEST_CASES
    )
    def test_single_card_cases(self, model_args, data_keys, tmp_path):
        """Test single card with various configurations."""
        self.run_test(
            model_args=model_args,
            data_keys=data_keys,
            tmp_path=tmp_path,
        )