diff --git a/mindformers/tools/pipeline_balance/layers/llama2_70b_prof_RR_full.json b/mindformers/tools/pipeline_balance/layers/llama2_70b_prof_RR_full.json
index 98aa4bb7213329b841f3da76f3827cfb502e07c3..2f38b98dd8be2dcf293dd6d1e150d31248eaf82c 100644
--- a/mindformers/tools/pipeline_balance/layers/llama2_70b_prof_RR_full.json
+++ b/mindformers/tools/pipeline_balance/layers/llama2_70b_prof_RR_full.json
@@ -1,39 +1,40 @@
 {
-  "name": "llama2_70b_prof",
-  "pre_defined_layer": {
-    "LlamaEmbedding": 0,
-    "LlamaRMSNorm": -1
-  },
-  "auto_partition_layer": {
-    "LLamaDecodeLayer": 96
-  },
-  "layers_description": [
-    {
-      "name": "LlamaEmbedding",
-      "model_name": "llama2_70b_prof",
-      "type": "HEAD",
-      "time": 30,
-      "nb_layer": 1,
-      "memory_parameter": 9785
+    "name": "llama2_70b_prof",
+    "pre_defined_layer": {
+        "LlamaEmbedding": 0,
+        "LlamaRMSNorm": -1
     },
-    {
-      "name": "LLamaDecodeLayer",
-      "model_name": "llama2_70b_prof",
-      "type": "BODY",
-      "time": 180,
-      "nb_layer": 96,
-      "memory_activation": 822,
-      "memory_parameter": 1562,
-      "memory_recompute": 32,
-      "memory_select_comm": 498
+    "auto_partition_layer": {
+        "LLamaDecodeLayer": 96
     },
-    {
-      "name": "LlamaRMSNorm",
-      "model_name": "llama2_70b_prof",
-      "type": "TAIL",
-      "time": 90,
-      "nb_layer": 1,
-      "memory_parameter": 2869
-    }
-  ]
-}
\ No newline at end of file
+    "layers_description": [
+        {
+            "name": "LlamaEmbedding",
+            "model_name": "llama2_70b_prof",
+            "type": "HEAD",
+            "time": 30,
+            "nb_layer": 1,
+            "memory_parameter": 11000
+        },
+        {
+            "name": "LLamaDecodeLayer",
+            "model_name": "llama2_70b_prof",
+            "type": "BODY",
+            "time": 180,
+            "nb_layer": 96,
+            "memory_parameter": 1583,
+            "memory_recompute": 32,
+            "memory_activation": 826,
+            "memory_select_rec": 614,
+            "memory_select_comm": 426
+        },
+        {
+            "name": "LlamaRMSNorm",
+            "model_name": "llama2_70b_prof",
+            "type": "TAIL",
+            "time": 90,
+            "nb_layer": 1,
+            "memory_parameter": 2473
+        }
+    ]
+}
diff --git a/mindformers/tools/pipeline_balance/layers/template.json b/mindformers/tools/pipeline_balance/layers/template.json
index 46ae22278ca3920d4026641da8a70938c845b1da..9896915381cc0c3d5f852960367ff9cb93d923e7 100644
--- a/mindformers/tools/pipeline_balance/layers/template.json
+++ b/mindformers/tools/pipeline_balance/layers/template.json
@@ -1,22 +1,22 @@
 {
-    "name": "template",
+    "name": "model_name",
     "pre_defined_layer": {
-        "headLayer": 0,
-        "tailLayer": -1
+        "headLayerName_AtBeginRepresentedByZero": 0,
+        "tailLayerName_AtEndRepresentedByMinusOne": -1
     },
     "auto_partition_layer": {
-        "bodyLayer": 4
+        "NumberOfLayers": 4
     },
     "layers_description": [
         {
-            "name": "headLayer",
+            "name": "headLayerName",
             "type": "HEAD",
             "time": 0,
             "nb_layer": 1,
             "memory_parameter": 0
         },
         {
-            "name": "bodyLayer",
+            "name": "bodyLayerName",
             "type": "BODY",
             "time": 0,
             "nb_layer": 1,
@@ -25,7 +25,7 @@
             "memory_recompute": 0
         },
         {
-            "name": "tailLayer",
+            "name": "tailLayerName",
             "type": "TAIL",
             "time": 0,
             "nb_layer": 1,
diff --git a/mindformers/tools/pipeline_balance/main.py b/mindformers/tools/pipeline_balance/main.py
index fb0b8c43a9bbea59a4b32b5b489bfab95bfdf93e..976b288d1792cfe7d3f14c45f8864d166e2e47a7 100644
--- a/mindformers/tools/pipeline_balance/main.py
+++ b/mindformers/tools/pipeline_balance/main.py
@@ -17,7 +17,7 @@ import sys
 import argparse
 import json
 
-import mindformers.tools.pipeline_balance.utils.interactive
+import mindformers.tools.pipeline_balance.utils.interactive as Interactive
 from mindformers.tools.pipeline_balance.utils.layer import generate_layers_list
 from mindformers.tools.pipeline_balance.utils.compute_memory import compute_memories
 from mindformers.tools.pipeline_balance.sapp.sapp_pipeline import SappPipeline, choose_interleave
@@ -49,6 +49,12 @@ if __name__ == "__main__":
                         default=56000,
                         help="Maximum memory available (MB)")
 
+    parser.add_argument('-lm', '--less_memory',
+                        type=lambda x: (str(x).lower() in [
+                            'true', '1', 'yes']),
+                        default=False,
+                        help="Compute Memory with 'Less Memory interleave' option")
+
     parser.add_argument('-o', '--overlap',
                         type=list,
                         default=[1],
@@ -59,7 +65,7 @@ if __name__ == "__main__":
                         type=str,
                         default="Llama_special",
                         help="")
-    
+
     # Model info
     parser.add_argument('-t', '--time_limit',
                         type=int,
@@ -111,7 +117,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     if len(sys.argv) == 1:
-        interactive.main()
+        Interactive.main()
         sys.exit(0)
 
     layer_folder = args.layer_folder
@@ -121,6 +127,8 @@ if __name__ == "__main__":
     number_of_stage = args.stage
     number_of_micro_batch = args.micro_batch
     time_limit = args.time_limit
+    less_memory = args.less_memory
+
     # overlap_coeff = [1.75, 1, 1]
     overlap_coeff = args.overlap
 
@@ -159,7 +167,8 @@ if __name__ == "__main__":
                             num_of_micro_batch=number_of_micro_batch,
                             max_memory=max_memory,
                             layers=layers,
-                            num_of_interleave=interleave_degree)
+                            num_of_interleave=interleave_degree,
+                            vpp_less_memory=less_memory)
 
     pipe.construct_problem(solver="pulp")
     pipe.solve_problem(time_limit=time_limit)
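Reviewer note on the new --less_memory flag: argparse's type=bool would call bool() on the raw string, so any non-empty value (including "false") would become True; hence the lambda above. A minimal standalone sketch of the same parsing rule (the str2bool name is illustrative, not part of this patch):

import argparse

def str2bool(x: str) -> bool:
    # Same rule as the lambda in main.py: only "true", "1", "yes" map to True.
    return str(x).lower() in ['true', '1', 'yes']

parser = argparse.ArgumentParser()
parser.add_argument('-lm', '--less_memory', type=str2bool, default=False,
                    help="Compute Memory with 'Less Memory interleave' option")

args = parser.parse_args(['--less_memory', 'false'])
assert args.less_memory is False  # with type=bool, "false" would have parsed as True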
diff --git a/mindformers/tools/pipeline_balance/sapp/sapp_pipeline.py b/mindformers/tools/pipeline_balance/sapp/sapp_pipeline.py
index 98b2e08ee9e956cdfc79052fd364a18a286cdd8f..84b92d48c1f82c7962f67252fbd6dba6ae571081 100644
--- a/mindformers/tools/pipeline_balance/sapp/sapp_pipeline.py
+++ b/mindformers/tools/pipeline_balance/sapp/sapp_pipeline.py
@@ -20,12 +20,15 @@ from mindformers.tools.pipeline_balance.utils.layer import Layer, filter_layer_t
 import mindformers.tools.pipeline_balance.utils.recompute as Recompute
 
 
 class SappPipeline:
-    def __init__(self, model_name: str, num_of_stage: int, num_of_micro_batch: int, max_memory: int, layers: list[Layer], num_of_interleave: int = 1):
+    def __init__(self, model_name: str, num_of_stage: int, num_of_micro_batch: int,
+                 max_memory: int, layers: list[Layer], vpp_less_memory: bool = False,
+                 num_of_interleave: int = 1):
         self.model_name_ = model_name
         self.num_of_stage_ = num_of_stage
         self.num_of_micro_batch_ = num_of_micro_batch
         self.num_of_interleave_ = num_of_interleave
         self.max_memory_ = max_memory
+        self.vpp_less_memory_ = vpp_less_memory
         self.problem_ = None
         self.layers_ = layers
         self.layers_sorted_ = {
@@ -40,6 +43,7 @@ class SappPipeline:
             num_of_micro_batch=self.num_of_micro_batch_,
             num_of_interleave=self.num_of_interleave_,
             max_memory=self.max_memory_,
+            vpp_less_memory=self.vpp_less_memory_,
             layers=self.layers_,
             layers_sorted=self.layers_sorted_
         )
@@ -57,10 +61,10 @@ class SappPipeline:
         if solver == "pulp":
             self.problem_ = self._construct_problem_pulp_()
         elif solver == "other":
-            print("No other solver available..., automatically switch to pulp!!!")
+            print("[WARNING] No other solver available..., automatically switching to pulp!!!")
             self.problem_ = self._construct_problem_pulp_()
         else:
-            print("No other solver available..., automatically switch to pulp!!!")
+            print("[WARNING] No other solver available..., automatically switching to pulp!!!")
             self.problem_ = self._construct_problem_pulp_()
 
     def solve_problem(self, time_limit=90):
@@ -132,23 +136,27 @@ class SappPipeline:
                         recomputes[r][inter][stage] += int(
                             self.problem_.variables_[layer_name][Recompute.Type.FULL][inter][stage].varValue)
-        print(f"layer-to-stage assignment baseline is {semi_layer_per_stage}")
-        print("\nTo put in yaml configuration:")
+        print(f"[INFO] layer-to-stage assignment baseline is {semi_layer_per_stage}")
+        print("\n[INFO] To put in yaml configuration:")
         if self.num_of_interleave_ == 1:
             offset = flatten(offset)
-        print(f"\toffset: {offset}")
+        print(f"[INFO] \toffset: {offset}")
         for r in Recompute.Type:
             if self._recompute_considered()[r] and r is not Recompute.Type.NONE:
                 recompute_layers = recomputes[r]
                 if self.num_of_interleave_ == 1:
                     recompute_layers = flatten(recompute_layers)
-                print(f"\t{Recompute.YamlName[r]}: {recompute_layers}")
-        print(f"\tpp_interleave_num: {self.num_of_interleave_}")
+                print(f"[INFO] \t{Recompute.YamlName[r]}: {recompute_layers}")
+        print(f"[INFO] \tpp_interleave_num: {self.num_of_interleave_}")
+
 
     def get_naive_memory_activation(self, all_recompute=False, interleave_num=1) -> list[float]:
         """Give the activation memory per stage for a naive layer assignment
         without interleave for simulator."""
         memory_active = []
         for layer in self.layers_sorted_[Layer.Type.BODY]:
+            for r in Recompute.Type:
+                if layer.recompute_considered_[r]:
+                    rec = r
             lyr_per_stg = self.naive_layer_per_stage(
                 layer.nb_layer_, interleave_num)
             if self.has_some_memory_info():
@@ -158,7 +166,7 @@ class SappPipeline:
                     if all_recompute:
                         memory_active[inter].append(
                             lyr_per_stg[inter][stage] *
-                            layer.memory_activation_rec_[Recompute.Type.FULL])
+                            layer.memory_activation_rec_[rec])
                     else:
                         memory_active[inter].append(
                             lyr_per_stg[inter][stage] *
@@ -177,7 +185,6 @@ class SappPipeline:
             for stage in range(self.num_of_stage_):
                 memory_param[inter].append(lyr_per_stg[inter][stage] *
                                            layer.memory_parameter_)
-
         for head in self.layers_sorted_[Layer.Type.HEAD]:
             if head.memory_parameter_ is not None:
                 memory_param[0][0] += head.memory_parameter_
@@ -190,25 +197,32 @@ class SappPipeline:
     def get_naive_time(self, all_recompute=False, interleave_num=1) -> list[float]:
         """Give the time per stage for a naive layer assignment
         without interleave for simulator."""
         time = []
+        rec = Recompute.Type.FULL
         for i in range(interleave_num):
             time.append([])
             for s in range(self.num_of_stage_):
                 time[i].append(0)
                 for layer in self.layers_sorted_[Layer.Type.BODY]:
+                    for r in Recompute.Type:
+                        if layer.recompute_considered_[r]:
+                            rec = r
                     lyr_per_stg = self.naive_layer_per_stage(
                         layer.nb_layer_, interleave_num)
                     if not all_recompute:
                         time[i][s] += lyr_per_stg[i][s] * layer.time_
                     else:
                         time[i][s] += lyr_per_stg[i][s] * (
-                            layer.forward_time_ + layer.backward_time_rec_[Recompute.Type.FULL])
+                            layer.forward_time_ + layer.backward_time_rec_[rec])
         for head in self.layers_sorted_[Layer.Type.HEAD]:
             time[0][0] += head.time_
         for tail in self.layers_sorted_[Layer.Type.TAIL]:
             time[interleave_num-1][self.num_of_stage_-1] += tail.time_
+        if all_recompute:
+            print("[INFO] even partitioning computed with", Recompute.YamlName[rec])
         return time
+
 
     def simulate(self, show=True, file_name=None):
         """Use simulator to visualize output."""
         if self.has_some_memory_info():
@@ -262,7 +276,7 @@ def choose_interleave(model_name: str, number_of_stage: int, number_of_micro_bat
         pipe.construct_problem(solver="pulp")
         pipe.solve_problem()
         time = pipe.simulate(show=False)
-        print(f"for interleave {inter}, time = {time}")
+        print(f"[INFO] for interleave {inter}, time = {time}")
         if time < best_time:
             best_time = time
             best_inter = inter
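End to end, the flag travels main.py -> SappPipeline -> PulpSolver. A hypothetical driver, mirroring the call sites in main.py and choose_interleave() above (the stage, micro-batch, and memory values are made-up examples; only the signatures come from this patch):

from mindformers.tools.pipeline_balance.utils.layer import generate_layers_list
from mindformers.tools.pipeline_balance.sapp.sapp_pipeline import SappPipeline

layers = generate_layers_list("./layers", "llama2_70b_prof_RR_full")
pipe = SappPipeline(model_name="llama2_70b_prof", num_of_stage=8,
                    num_of_micro_batch=64, max_memory=56000, layers=layers,
                    num_of_interleave=2, vpp_less_memory=True)
pipe.construct_problem(solver="pulp")
pipe.solve_problem(time_limit=90)
time = pipe.simulate(show=False)  # simulated pipeline time, as in choose_interleave()
print(f"[INFO] simulated pipeline time = {time}")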
diff --git a/mindformers/tools/pipeline_balance/sapp/sapp_solver.py b/mindformers/tools/pipeline_balance/sapp/sapp_solver.py
index a82320a85d53eb9657cdd4f83841a8ab4a570951..13356f06eb87ad534c7bf129dd1205d7716e09f6 100644
--- a/mindformers/tools/pipeline_balance/sapp/sapp_solver.py
+++ b/mindformers/tools/pipeline_balance/sapp/sapp_solver.py
@@ -24,14 +24,16 @@ import mindformers.tools.pipeline_balance.utils.recompute as Recompute
 
 
 class PulpSolver:
-    def __init__(self, num_of_stage: int, num_of_interleave: int, num_of_micro_batch: int, max_memory: int,
-                 layers: list[Layer], layers_sorted: dict[Layer.Type, list[Layer]],
+    def __init__(self, num_of_stage: int, num_of_interleave: int, num_of_micro_batch: int,
+                 max_memory: int, layers: list[Layer],
+                 layers_sorted: dict[Layer.Type, list[Layer]], vpp_less_memory: bool = False,
                  description: str = "Pipeline_execution_time_minimize"):
         self.num_of_stage_ = num_of_stage
         self.num_of_interleave_ = num_of_interleave
         self.num_of_micro_batch_ = num_of_micro_batch
         self.max_memory_ = max_memory
+        self.vpp_less_memory_ = vpp_less_memory
 
         self.layers_ = layers
         self.layers_sorted_ = layers_sorted
 
@@ -103,79 +105,82 @@ class PulpSolver:
     ############################################
     #             Time Constraint              #
     ############################################
-    def add_max_stage_constraint(self, prob, variables, layers_sorted, num_of_stage, num_of_interleave, max_stage_time):
-        """Constraints on sub-main-part of a stage that it may take (for all stage)"""
-        # TODO (very) low priority: improve constraint to add for first and last stage
-        # by adding temp variable that are strictly positive to decide when to add or not the head/tail
-        def _max_stage_bound_i_(variables, layers_sorted, stage_id, inter_rec_ids, num_of_stage, num_of_interleave):
-            """Compute bound for stage_id, inter_norecompute_id and inter_recompute_id """
-            bound = lpSolver.LpAffineExpression()
-            for layer in layers_sorted[Layer.Type.BODY]:
-                for rec in Recompute.Type:
-                    if self.recompute_considered_[rec]:
-                        bound += (variables[layer.name_][rec][inter_rec_ids[rec]][stage_id]
-                                  * (layer.forward_time_ + layer.backward_time_rec_[rec]))
+    def _max_stage_bound_i_fp(self, variables, layers_sorted, stage_id, inter_f):
+        bound = lpSolver.LpAffineExpression()
+        for layer in layers_sorted[Layer.Type.BODY]:
+            for rec in Recompute.Type:
+                if self.recompute_considered_[rec]:
+                    bound += (variables[layer.name_][rec][inter_f][stage_id]
+                              * layer.forward_time_)
+        return bound
 
-            if stage_id == 0:
-                put_head = False
-                for rec in Recompute.Type:
-                    if self.recompute_considered_[rec]:
-                        if inter_rec_ids[rec] == 0:
-                            put_head = True
-                if put_head:
-                    for head in layers_sorted[Layer.Type.HEAD]:
-                        bound += (variables[head.name_][0] * head.time_)
-            if stage_id == num_of_stage-1:
-                put_tail = False
-                for rec in Recompute.Type:
-                    if self.recompute_considered_[rec]:
-                        if inter_rec_ids[rec] == num_of_interleave-1:
-                            put_tail = True
-                if put_tail:
-                    for tail in layers_sorted[Layer.Type.TAIL]:
-                        bound += (variables[tail.name_][0] * tail.time_)
-            return bound
-
-        indexes = Recompute.make_all_indexes(
-            self.recompute_considered_, num_of_interleave)
-        recomputes = Recompute.recomputes_from_indexes(
-            self.recompute_considered_, indexes)
+    def _max_stage_bound_i_bp(self, variables, layers_sorted, stage_id, inter_b):
+        bound = lpSolver.LpAffineExpression()
+        for layer in layers_sorted[Layer.Type.BODY]:
+            for rec in Recompute.Type:
+                if self.recompute_considered_[rec]:
+                    bound += (variables[layer.name_][rec][inter_b][stage_id]
+                              * layer.backward_time_rec_[rec])
+        return bound
+
+    def _max_stage_bound_head_tail(self, variables, layers_sorted, stage_id,
+                                   inter_f, inter_b):
+        bound = lpSolver.LpAffineExpression()
+        if stage_id == 0:
+            if inter_f == 0:
+                for head in layers_sorted[Layer.Type.HEAD]:
+                    bound += variables[head.name_][0] * head.time_
+            if inter_b == 0:
+                for head in layers_sorted[Layer.Type.HEAD]:
+                    bound += variables[head.name_][0] * head.time_ * 2
+        if stage_id == self.num_of_stage_-1:
+            if inter_f == self.num_of_interleave_-1:
+                for tail in layers_sorted[Layer.Type.TAIL]:
+                    bound += variables[tail.name_][0] * tail.time_
+            if inter_b == self.num_of_interleave_-1:
+                for tail in layers_sorted[Layer.Type.TAIL]:
+                    bound += variables[tail.name_][0] * tail.time_ * 2
+        return bound
+
+    def _total_sum(self, variables, layers_sorted):
+        bound = lpSolver.LpAffineExpression()
+        for layer in layers_sorted[Layer.Type.BODY]:
+            for rec in Recompute.Type:
+                if self.recompute_considered_[rec]:
+                    for inter in range(self.num_of_interleave_):
+                        for stage in range(self.num_of_stage_):
+                            bound += (variables[layer.name_][rec][inter][stage]
+                                      * (layer.forward_time_ + layer.backward_time_rec_[rec]))
+        return bound
+
+    def add_max_stage_constraint(self, prob, variables, layers_sorted,
+                                 num_of_stage, num_of_interleave, max_stage_time):
+        """Constraints on sub-main-part of a stage that it may take (for all stages)"""
         for i_stage in range(num_of_stage):
-            for rec_index in recomputes:
-                prob += (max_stage_time >= _max_stage_bound_i_(
-                    variables, layers_sorted, i_stage, rec_index, num_of_stage, num_of_interleave))
-
-        sum_FPi_BPi = lpSolver.LpVariable(
-            "sum_FPi_BPi", 0, None, lpSolver.LpContinuous)
-        var_sum_FPi_BPi = lpSolver.LpVariable(
-            "var_sum_FPi_BPi", 0, None, lpSolver.LpContinuous)
+            for inter_f in range(self.num_of_interleave_):
+                for inter_b in range(self.num_of_interleave_):
+                    prob += (max_stage_time >=
+                             self._max_stage_bound_i_fp(variables, layers_sorted,
+                                                        i_stage, inter_f) +
+                             self._max_stage_bound_i_bp(variables, layers_sorted,
+                                                        i_stage, inter_b) +
+                             self._max_stage_bound_head_tail(variables, layers_sorted,
+                                                             i_stage, inter_f, inter_b))
+
         pipeline_total_time = lpSolver.LpVariable(
             "pipeline_total_time", 0, None, lpSolver.LpContinuous)
-        for i_stage in range(num_of_stage):
-            for i_inter in range(num_of_interleave):
-                sum_FPi_BPi += _max_stage_bound_i_(
-                    variables, layers_sorted, i_stage, {
-                        r: i_inter for r in Recompute.Type},
-                    num_of_stage, num_of_interleave)
-
-        prob += var_sum_FPi_BPi >= sum_FPi_BPi
-        prob += pipeline_total_time >= var_sum_FPi_BPi + max_stage_time * (
-            64
-            # self.num_of_micro_batch_ - 2
-        )
-
-        # max_stages = lpSolver.LpVariable("max_stages", 0, None, lpSolver.LpContinuous)
-        # min_stages = lpSolver.LpVariable("min_stages", 0, None, lpSolver.LpContinuous)
-
-        # for i_stage in range(num_of_stage):
-        #     for i_inter in range(num_of_interleave):
-        #         prob += (max_stages >= _max_stage_bound_i_(variables, layers_sorted, i_stage,
-        #                                                    i_inter, i_inter, i_inter,
-        #                                                    num_of_stage, num_of_interleave))
-        #         prob += (min_stages <= _max_stage_bound_i_(variables, layers_sorted, i_stage,
-        #                                                    i_inter, i_inter, i_inter,
-        #                                                    num_of_stage, num_of_interleave));
+        # Only max time
+        prob += pipeline_total_time >= max_stage_time
+
+        # Max time & sum time
+        # var_sum_FPi_BPi = lpSolver.LpVariable(
+        #     "var_sum_FPi_BPi", 0, None, lpSolver.LpContinuous)
+
+        # prob += var_sum_FPi_BPi >= self._total_sum(variables, layers_sorted)
+        # prob += pipeline_total_time >= var_sum_FPi_BPi + max_stage_time * (
+        #     self.num_of_micro_batch_ - 2
+        # )
 
         return prob
 
@@ -198,7 +203,7 @@ class PulpSolver:
                 bound += variables[tail.name_][0] * tail.memory_parameter_
         return bound
 
-    def compute_activation_nums(self, num_of_stage: int, num_of_interleave: int) -> list[list[int]]:
+    def compute_activation_nums(self, num_of_stage: int, num_of_interleave: int, micro_batch: int) -> list[list[int]]:
         activation_nums = []
         if num_of_interleave > 1:
             for i in range(num_of_interleave):
@@ -210,6 +215,30 @@ class PulpSolver:
             for s in range(num_of_stage):
                 activation_nums[num_of_interleave - 1][s] += min(0, num_of_stage - 2 * s - 1)
+            for i in range(num_of_interleave):
+                for s in range(num_of_stage):
+                    activation_nums[i][s] = min(activation_nums[i][s],
+                                                micro_batch)
+
+        else:
+            for i in range(num_of_interleave):
+                activation_nums.append([])
+                for s in range(num_of_stage):
+                    activation_nums[i].append(num_of_stage-s)
+
+        print("[INFO] number of activations to count: ", activation_nums)
+        return activation_nums
+
+    def compute_less_activation_nums(self, num_of_stage: int, num_of_interleave: int, micro_batch: int) -> list[list[int]]:
+        activation_nums = []
+        if num_of_interleave > 1:
+            for i in range(num_of_interleave):
+                activation_nums.append([])
+                for _ in range(num_of_stage):
+                    activation_nums[i].append(num_of_stage)
+            for s in range(num_of_stage):
+                activation_nums[num_of_interleave - 1][s] -= s
         else:
             for i in range(num_of_interleave):
                 activation_nums.append([])
@@ -256,9 +285,11 @@ class PulpSolver:
         # bound += abs(inter_a - inter_b) * self.num_of_micro_batch_
         return bound
 
-    def add_pipeline_memory_constraint(self, prob, variables, layers_sorted, num_of_stage, num_of_interleave, memory_limit):
-        activation_nums = self.compute_activation_nums(
-            num_of_stage, num_of_interleave)
+    def add_pipeline_memory_constraint(self, prob, variables, layers_sorted, num_of_stage, num_of_interleave, micro_batch, memory_limit):
+        if self.vpp_less_memory_:
+            activation_nums = self.compute_less_activation_nums(num_of_stage, num_of_interleave, micro_batch)
+        else:
+            activation_nums = self.compute_activation_nums(num_of_stage, num_of_interleave, micro_batch)
         for s in range(num_of_stage):
             prob += memory_limit >= (
                 self.stage_param_memory(variables, layers_sorted, s, num_of_stage, num_of_interleave) +
@@ -290,7 +321,7 @@ class PulpSolver:
 
         if self.has_some_memory_info():
             self.add_pipeline_memory_constraint(prob, self.variables_, layers_sorted, num_of_stage,
-                                                num_of_interleave, max_memory)
+                                                num_of_interleave, num_of_micro_batch, max_memory)
 
         return prob
 
@@ -317,22 +348,22 @@ class PulpSolver:
             layer_name = layer.name_
             print("For layer:", layer_name)
             print("=========")
-            print("  Forward Prop time: ", layer.time_)
-            print("  Backward Prop time: ",
-                  layer.backward_time_rec_[Recompute.Type.NONE])
-            print("  Back SelectRec time: ",
-                  layer.backward_time_rec_[Recompute.Type.SLCT])
-            print("  Back Recompute time: ",
-                  layer.backward_time_rec_[Recompute.Type.FULL])
+            print("  Forward Prop time: ", layer.forward_time_)
+            for rec in Recompute.Type:
+                if layer.recompute_considered_[rec]:
+                    print("  Backward Prop", Recompute.YamlName[rec], "time: ",
+                          layer.backward_time_rec_[rec])
             for inter in range(self.num_of_interleave_):
                 for stage in range(self.num_of_stage_):
                     print(
-                        "    Assign",
-                        layer_name, " ", end="\t")
+                        "    Assign", layer_name, end=": ")
                     for rec in Recompute.Type:
                         if self.recompute_considered_[rec]:
-                            print(str(self.variables_[layer_name][rec][inter][stage].varValue)
-                                  + " " + rec.name, end=" + ")
+                            if rec is Recompute.Type.NONE:
+                                print(str(int(self.variables_[layer_name][rec][inter][stage].varValue)), end=" ")
+                            else:
+                                print("+", str(int(self.variables_[layer_name][rec][inter][stage].varValue))
+                                      + " " + rec.name, end=" ")
                     print(" of chunk " + str(inter) if self.num_of_interleave_ != 1 else "",
                           " to stage " + str(stage))
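The effect of vpp_less_memory on the memory constraint is easiest to see numerically. The standalone sketch below (not part of the patch) replays the counting loops of compute_activation_nums and compute_less_activation_nums from sapp_solver.py above:

def activation_nums(num_of_stage, num_of_interleave, micro_batch, less_memory):
    if num_of_interleave <= 1:
        # Without interleave, stage s keeps num_of_stage - s activations in flight.
        return [[num_of_stage - s for s in range(num_of_stage)]]
    nums = [[num_of_stage] * num_of_stage for _ in range(num_of_interleave)]
    if less_memory:
        # 'Less Memory interleave': the last chunk releases one activation per stage rank.
        for s in range(num_of_stage):
            nums[num_of_interleave - 1][s] -= s
    else:
        # Standard interleave: last-chunk correction, then clamp by the micro-batch count.
        for s in range(num_of_stage):
            nums[num_of_interleave - 1][s] += min(0, num_of_stage - 2 * s - 1)
        nums = [[min(n, micro_batch) for n in row] for row in nums]
    return nums

print(activation_nums(4, 2, 8, less_memory=False))  # [[4, 4, 4, 4], [4, 4, 3, 1]]
print(activation_nums(4, 2, 8, less_memory=True))   # [[4, 4, 4, 4], [4, 3, 2, 1]]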
diff --git a/mindformers/tools/pipeline_balance/utils/compute_memory.py b/mindformers/tools/pipeline_balance/utils/compute_memory.py
index 6651018489185965f549d225a43636d79b850bd5..62e1b07db79c2354c6d573f3a5455a02c1118ea0 100644
--- a/mindformers/tools/pipeline_balance/utils/compute_memory.py
+++ b/mindformers/tools/pipeline_balance/utils/compute_memory.py
@@ -68,8 +68,8 @@ class ComputeMemory:
         for stage1 in stages:
             for stage2 in stages:
                 if not stage1.same_global_config(stage2):
-                    print(
-                        "ERROR: Cannot set stagesA, all element don't have the same configuration")
+                    print("[ERROR]",
+                          "Cannot set stagesA, all elements don't have the same configuration")
                     self.stagesA_ = []
                     return
         self.stagesA_ = stages
@@ -78,14 +78,14 @@ class ComputeMemory:
         for stage1 in stages:
             for stage2 in stages:
                 if not stage1.same_global_config(stage2):
-                    print(
-                        "ERROR: Cannot set stagesB, all element don't have the same configuration")
+                    print("[ERROR]",
+                          "Cannot set stagesB, all elements don't have the same configuration")
                     self.stagesB_ = []
                     return
                 for stageA in self.stagesB_:
                     if stage1.same_global_config(stageA):
-                        print(
-                            "ERROR: Cannot set stagesB, an element have the same configuration than stagesA")
+                        print("[ERROR]",
+                              "Cannot set stagesB, an element has the same configuration as stagesA")
                         self.stagesB_ = []
                         return
         self.stagesB_ = stages
@@ -114,12 +114,12 @@ class ComputeMemory:
                     res /= stage1.nb_layer_
                 return res
             else:
-                print(
-                    "ERROR: stage with same characteristic, BUT SAME ID too, cannot compute memory_parameter")
+                print("[ERROR]",
+                      "stages with same characteristics BUT same ID too, cannot compute memory_parameter")
                 return 0
         else:
-            print(
-                "ERROR: stage with different characteristic, cannot compute memory_parameter")
+            print("[ERROR]",
+                  "stages with different characteristics, cannot compute memory_parameter")
            return 0
 
    def _compute_memory_parameter_(self, multi_run=False) -> float:
@@ -152,7 +152,7 @@ class ComputeMemory:
 
         elif self._compute_memories_layers_():
             return self.memory_parameter_
         else:
-            print("ERROR when computing _compute_memory_parameter_!!!")
+            print("[ERROR] Issue with _compute_memory_parameter_!!!")
             return 0
 
@@ -160,14 +160,13 @@ class ComputeMemory:
     def _compute_memory_activation_(self, rec, multi_run=False) -> float:
         """
         return: memory_activation
         """
         # TODO very low priority
-        print("_compute_memory_activation_")
         if multi_run or (len(self.stagesA_) < 5 and len(self.stagesB_) < 5):
             # look at solution 4 stages
-            print("Not implemented yet!!!")
+            print("[ERROR] Not implemented yet!!!")
         elif self._compute_memories_layers_():
             return self.memory_activation_rec_[rec]
         else:
-            print("ERROR when compute _compute_memory_activation_!!!")
+            print("[ERROR] Issue with _compute_memory_activation_!!!")
         return 0
 
@@ -176,8 +175,8 @@ def _compute_memories_layers_(self) -> bool:
         stage_num = len(self.stagesA_)
         if stage_num == used_rec_num + 3:
             return self._compute_memories_layer_bodies_()
-        print(f"{stage_num} stages found and ({used_rec_num}) recomputation considered",
-              " is not coherent. There should be 3 more stages than recomputation considered")
+        print(f"[ERROR] {stage_num} stages found and ({used_rec_num}) recomputation types considered",
+              "is not coherent. There should be 3 more stages than recomputation types considered")
         return False
 
@@ -186,13 +185,9 @@ def _compute_memories_layer_bodies_local_(self, unused_rec: list[Recompute.Type],
         Required at least 3 Stages different from first and last stage
         return (memory_parameter, memory_recompute, memory_activation)
         """
-        # if len(stages) < 5:
-        #     print("ERROR: Not enough stage information to determine memories information");
-        #     return (0, 0, 0);
         variable_factor_list = []
         constant_memory_list = []
         unused_rec.sort(reverse=True)
-        print(f"unused_rec: {unused_rec}")
         for stage in stages:
             if stage.id_ not in [0, self.number_of_stage_-1]:
                 variable_factor_list.append(stage.get_index_memory_var())
@@ -247,7 +242,7 @@ class ComputeMemory:
             self.memory_parameter_ = memory_parameter_b
             self.memory_activation_rec_ = memory_recompute_b
         else:
-            print("ERROR: _compute_memories_layer_3body_ failed to compute memories")
+            print("[ERROR] failed to compute memories")
             return False
         return True
 
@@ -307,9 +302,7 @@ class ComputeMemory:
 
 def compute_memories(layers: list[Layer], memory_folder: str, number_of_stage: int) -> list[Layer]:
     # TODO priority high
-    print("compute_memories")
     filename = ""
-    print("Not implemented yet!!!")
     # Put some meta information in a predefine .json file like layers info?
     with open(memory_folder + filename) as file:
         pass
@@ -366,10 +359,10 @@ if __name__ == "__main__":
     cm = ComputeMemory(number_of_stage=num_stage, stagesA=stagesA)
 
-    print("memory_head =", int(cm.get_memory_head()))
-    print("memory_parameter =", int(cm.get_memory_parameter()))
+    print("[INFO] memory_head =", int(cm.get_memory_head()))
+    print("[INFO] memory_parameter =", int(cm.get_memory_parameter()))
     for r in Recompute.Type:
         if cm.recompute_considered_[r]:
-            print(Recompute.JsonMemoryName[r], "=", int(
+            print("[INFO]", Recompute.JsonMemoryName[r], "=", int(
                 cm.get_memory_activation(r)))
-    print("memory_tail =", int(cm.get_memory_tail()))
+    print("[INFO] memory_tail =", int(cm.get_memory_tail()))
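For context on the stage_num == used_rec_num + 3 check above: each recomputation type considered adds one unknown activation-memory value, and the remaining three unknowns are presumably head, tail, and per-layer parameter memory, judging from the getters used in the __main__ block. A toy restatement (hypothetical helper, not in the patch):

def enough_stages(stage_num: int, used_rec_num: int) -> bool:
    # One distinct stage per considered recomputation type, plus three more
    # (assumed: head, tail, and per-layer parameter memory).
    return stage_num == used_rec_num + 3

assert enough_stages(5, 2)      # e.g. NONE + FULL considered -> 5 stages needed
assert not enough_stages(4, 2)  # one stage short: the memories cannot be solved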
diff --git a/mindformers/tools/pipeline_balance/utils/layer.py b/mindformers/tools/pipeline_balance/utils/layer.py
index fd10f261cb47622b014d6bd7bb467cd5b6319499..4e46fe80a86ffc87f945422f2ab7979355c4261b 100644
--- a/mindformers/tools/pipeline_balance/utils/layer.py
+++ b/mindformers/tools/pipeline_balance/utils/layer.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 
+import os
 import json
 from enum import Enum
 
@@ -21,8 +22,7 @@ from mindformers.tools.pipeline_balance.utils.computation_analyzer import Comput
 
 
 class Layer:
-    """Layer Class to describe layer information
-
+    """
     Mandatory parameter:
         name_ (str): name of the layer
         type_ (LayerType): type of the layer 'HEAD', 'BODY', 'TAIL'
@@ -31,19 +31,12 @@ class Layer:
 
     Optional (auto-compute) parameter:
         forward_time_ (float): forward time for the layer (1/3 of time)
-        backward_time_ (float): backward time for the layer (2/3 of time)
-        recompute_time_ (float): backward time with recomputation (=forward_time_+backward_time_)
-        select_rec_time (float): backward time with selective recomputation (=backward_time_+10%)
+        backward_time_rec_ (dict[Recompute.Type, float]): backward time (2/3 of time) per recomputation
+        recompute_considered_ (dict[Recompute.Type, bool]): recomputation types considered
 
     Optional memory parameter (for recompute):
         memory_parameter_ (float): memory used by the layer (all kind)
-        memory_activation_ (float): memory used to activate and use the layer (BODY)
-        memory_recompute_ (float): memory used by the layer if recompute (BODY)
-            (instead of activation)
-            memory_recompute_ < memory_activation_ (else recompute never used)
-        memory_select_rec_ (float): memory used by the layer if selective recomputation used (BODY)
-            (instead of activation)
-            memory_recompute_ < memory_activation_ (else recompute never used)
+        memory_activation_rec_ (dict[Recompute.Type, float]): memory used by activation per recomputation
 
     Not manage yet parameter (for multimodal):
         model_name_ (str): name of the model the layer be part of (for multimodal)
@@ -62,29 +55,19 @@ class Layer:
     recompute_considered_: dict[Recompute.Type, bool]
 
     def __init__(self, model_name: str = "misc", name: str = "misc",
-                 ltype: Type = Type.UNKNOWN, nb_layer: int = 0, time: float = 0,
+                 ltype: Type = Type.UNKNOWN, nb_layer: int = 0, time: float = 0.0,
                  backward_time_rec: dict[Recompute.Type, float] = {
                      r: 0 for r in Recompute.Type},
-                 forward_time: float = None, memory_parameter: float = 0,
-                 memory_activation_rec: dict[Recompute.Type, float] = {r: 0 for r in Recompute.Type}):
+                 forward_time: float = 0.0, memory_parameter: float = 0.0,
+                 memory_activation_rec: dict[Recompute.Type, float] = {r: 0.0 for r in Recompute.Type}):
         self.name_ = name
         self.model_name_ = model_name
         self.type_ = ltype
         self.nb_layer_ = nb_layer
         self.time_ = time
         self.memory_activation_rec_ = memory_activation_rec
-        # {
-        #     Recompute.Type.NONE: memory_activation,
-        #     Recompute.Type.SLCT: memory_select_rec,
-        #     Recompute.Type.FULL: memory_recompute
-        # }
         self.memory_parameter_ = memory_parameter
         self.backward_time_rec_ = backward_time_rec
-        # {
-        #     Recompute.Type.NONE: backward_time,
-        #     Recompute.Type.SLCT: select_rec_time,
-        #     Recompute.Type.FULL: recompute_time
-        # };
         self.forward_time_ = forward_time
         self.recompute_considered_ = self.find_recompute_considered()
         self.compute_internal_time()
@@ -133,12 +116,7 @@ class Layer:
     def compute_internal_time(self, back_ratio: float = (BackwardDefaultRatio),
                               force_FB: bool = False, force_recompute: bool = False,
                               force_select_rec: bool = False):
-        """Auto compute internal time if not already present
-
-        forward_time_ = (1-back_ratio) * time_
-        backward_time_ = back_ratio * time_
-        recompute_time_ = forward_time_ + backward_time_
-        """
+        """Auto compute internal time if not already present."""
         if force_FB or self.forward_time_ is None:
             self.forward_time_ = (1-back_ratio) * self.time_
             self.backward_time_ = back_ratio * self.time_
@@ -170,7 +148,7 @@ class Layer:
 def generate_layers_list(layer_folder: str, model_name: str) -> list[Layer]:
     """Parse layer_folder/model_name.json to generate a list of layers"""
     layers = []
-    json_layer = layer_folder + '/' + model_name + '.json'
+    json_layer = os.path.join(layer_folder, model_name + '.json')
     with open(json_layer) as json_file:
         layer_data_json = json.load(json_file)
     if "layers_description" in layer_data_json:
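Finally, a quick sketch (hypothetical paths) of how a layer-description file such as template.json or llama2_70b_prof_RR_full.json is consumed after the os.path.join change in generate_layers_list():

import os
import json

layer_folder = "mindformers/tools/pipeline_balance/layers"
model_name = "llama2_70b_prof_RR_full"

# Same path construction as generate_layers_list() above.
with open(os.path.join(layer_folder, model_name + '.json')) as json_file:
    layer_data_json = json.load(json_file)

for layer in layer_data_json["layers_description"]:
    print(layer["type"], layer["name"], "time:", layer["time"])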