From 509e12c346233b2d9cf4d3edacb6161a71ba9295 Mon Sep 17 00:00:00 2001 From: lingyunli63 Date: Fri, 9 Apr 2021 21:30:15 +0800 Subject: [PATCH] support BatchMatMul --- src/composite/composite_topi.cc | 76 +++++++++++++++++------- src/composite/optimize/optimize.cc | 3 + src/composite/optimize/rename_matmul.cc | 39 ++++++++++++ src/composite/optimize/rename_matmul.h | 28 +++++++++ src/composite/optimize/reshape_tensor.cc | 34 ++++++----- 5 files changed, 143 insertions(+), 37 deletions(-) create mode 100644 src/composite/optimize/rename_matmul.cc create mode 100644 src/composite/optimize/rename_matmul.h diff --git a/src/composite/composite_topi.cc b/src/composite/composite_topi.cc index 6f954ae1..eb42cfc9 100644 --- a/src/composite/composite_topi.cc +++ b/src/composite/composite_topi.cc @@ -624,7 +624,7 @@ TVM_REGISTER_GLOBAL("BroadcastTo").set_body([](TVMArgs args, TVMRetValue *rv) { } }); -TVM_REGISTER_GLOBAL("BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("cuda_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { CHECK_GE(args.size(), 2); auto inputs = args[0].operator Array(); auto attrs = args[1].operator OpAttr(); @@ -718,7 +718,7 @@ TVM_REGISTER_GLOBAL("BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { }); // only support fractal_zN: [ko mo mi ki] * [no ko ki ni] = [no mo mi ni] -TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv) { +TVM_REGISTER_GLOBAL("aicore_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { CHECK_GE(args.size(), 2); auto attrs = args[1].operator OpAttr(); CHECK(attrs.count("transpose_a")); @@ -743,7 +743,7 @@ TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv) auto left_shape = left_matrix->shape; auto right_shape = right_matrix->shape; CHECK_EQ(left_shape.size(), right_shape.size()); - CHECK_EQ(left_shape.size(), 4); + CHECK_GE(left_shape.size(), 4); auto type_checker = [](const Tensor &input_data, const std::string name, const air::DataType type) { if (input_data->dtype != type) { @@ -757,26 +757,33 @@ TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv) Array output_shape; Array k; auto compute_mnk = [&output_shape, &k, &left_shape, &right_shape, transpose_a, transpose_b]() { + size_t dim = left_shape.size(); Expr mo, mi, no, ni, ko, ki; if (transpose_a) { - mo = left_shape[0]; - ko = left_shape[1]; - ki = left_shape[2]; - mi = left_shape[3]; + mo = left_shape[dim - 4]; + ko = left_shape[dim - 3]; + ki = left_shape[dim - 2]; + mi = left_shape[dim - 1]; } else { - ko = left_shape[0]; - mo = left_shape[1]; - mi = left_shape[2]; - ki = left_shape[3]; + ko = left_shape[dim - 4]; + mo = left_shape[dim - 3]; + mi = left_shape[dim - 2]; + ki = left_shape[dim - 1]; } if (transpose_b) { - no = right_shape[1]; - ni = right_shape[2]; + no = right_shape[dim - 3]; + ni = right_shape[dim - 2]; } else { - no = right_shape[0]; - ni = right_shape[3]; + no = right_shape[dim - 4]; + ni = right_shape[dim - 1]; } - output_shape = {no, mo, mi, ni}; + for (size_t i = 0; i < dim - 4; ++i) { + output_shape.push_back(left_shape[i]); + } + output_shape.push_back(no); + output_shape.push_back(mo); + output_shape.push_back(mi); + output_shape.push_back(ni); k = {ko, ki}; }; @@ -795,22 +802,47 @@ TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv) IterVar reduce_ki = air::reduce_axis(Range(0, k[1]), "ki"); Array reduces = {reduce_ko, reduce_ki}; - auto fcompute = [&left_matrix, &right_matrix, &transpose_a, &transpose_b, &reduces, &output_shape, + auto fcompute = [&left_matrix, &right_matrix, &transpose_a, &transpose_b, &reduces, &Mmad](const Array &indices) { - Array left_indice = {reduces[0], indices[1], indices[2], reduces[1]}; - Array right_indice = {indices[0], reduces[0], reduces[1], indices[3]}; + size_t dim = indices.size(); + Array left_indice; + for (size_t i = 0; i < dim - 4; ++i) { + left_indice.push_back(indices[i]); + } if (transpose_a) { - left_indice = {indices[1], reduces[0], reduces[1], indices[2]}; + left_indice.push_back(indices[dim - 3]); + left_indice.push_back(reduces[0]); + left_indice.push_back(reduces[1]); + left_indice.push_back(indices[dim - 2]); + } else { + left_indice.push_back(reduces[0]); + left_indice.push_back(indices[dim - 3]); + left_indice.push_back(indices[dim - 2]); + left_indice.push_back(reduces[1]); + } + + Array right_indice; + for (size_t i = 0; i < dim - 4; ++i) { + right_indice.push_back(indices[i]); } if (transpose_b) { - right_indice = {reduces[0], indices[0], indices[3], reduces[1]}; + right_indice.push_back(reduces[0]); + right_indice.push_back(indices[dim - 4]); + right_indice.push_back(indices[dim - 1]); + right_indice.push_back(reduces[1]); + } else { + right_indice.push_back(indices[dim - 4]); + right_indice.push_back(reduces[0]); + right_indice.push_back(reduces[1]); + right_indice.push_back(indices[dim - 1]); } + Expr res = Mmad(Cast::make(Float(32), left_matrix(left_indice) * right_matrix(right_indice)), reduces); return res; }; // set output name - auto name = "T_matmul_" + left_matrix->op->name + "_" + right_matrix->op->name; + auto name = "T_batchmatmul_" + left_matrix->op->name + "_" + right_matrix->op->name; // set compute attrs auto set_compute_attrs_zN = [&left_matrix, &right_matrix, &inputs, transpose_a, transpose_b, attrs]() { diff --git a/src/composite/optimize/optimize.cc b/src/composite/optimize/optimize.cc index 791c5901..74a499ff 100644 --- a/src/composite/optimize/optimize.cc +++ b/src/composite/optimize/optimize.cc @@ -15,6 +15,7 @@ */ #include "composite/optimize/optimize.h" #include +#include "composite/optimize/rename_matmul.h" #include "composite/optimize/reshape_tensor.h" #include "composite/optimize/elim_transform_op.h" #include "composite/optimize/inplace_assign_mutator.h" @@ -51,6 +52,8 @@ Stmt Optimize(Stmt &s, BuildInfo &info) { if (info.opt.target == "aicore") { pm.RegisterPass(std::make_shared()); } + // rename MatMul to BatchMatMul + pm.RegisterPass(std::make_shared()); s = pm.Run(s); return s; } diff --git a/src/composite/optimize/rename_matmul.cc b/src/composite/optimize/rename_matmul.cc new file mode 100644 index 00000000..af6eb699 --- /dev/null +++ b/src/composite/optimize/rename_matmul.cc @@ -0,0 +1,39 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "composite/optimize/rename_matmul.h" + +namespace akg { +// rename MatMul to BatchMatMul +class RenameMatmulMutator : public IRMutator { + public: + explicit RenameMatmulMutator() {} + ~RenameMatmulMutator() override = default; + + Stmt Mutate_(const Provide *op, const Stmt &s) { + auto call = op->value.as(); + if (call == nullptr || call->name != "MatMul") { + return IRMutator::Mutate_(op, s); + } + return Provide::make(op->func, 0, + Call::make(op->value.type(), "BatchMatMul", call->args, Call::CallType::PureIntrinsic), + op->args); + } +}; + +Stmt RenameMatmul::Run(const Stmt &s) { + return RenameMatmulMutator().Mutate(s); +} +} // namespace akg diff --git a/src/composite/optimize/rename_matmul.h b/src/composite/optimize/rename_matmul.h new file mode 100644 index 00000000..389e06ef --- /dev/null +++ b/src/composite/optimize/rename_matmul.h @@ -0,0 +1,28 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef COMPOSITE_OPTIMIZE_RENAME_MATMUL_H_ +#define COMPOSITE_OPTIMIZE_RENAME_MATMUL_H_ +#include "composite/optimize/optimize.h" + +namespace akg { +class RenameMatmul : public CompositeOptPass { + public: + RenameMatmul() { pass_name_ = __FUNCTION__; } + ~RenameMatmul() = default; + Stmt Run(const Stmt &s) override; +}; +} // namespace akg +#endif // COMPOSITE_OPTIMIZE_RENAME_MATMUL_H_ diff --git a/src/composite/optimize/reshape_tensor.cc b/src/composite/optimize/reshape_tensor.cc index 6788a70c..aae4540c 100644 --- a/src/composite/optimize/reshape_tensor.cc +++ b/src/composite/optimize/reshape_tensor.cc @@ -52,7 +52,8 @@ class ReshapeTensorMutator : public IRMutator { } Stmt Mutate_(const Provide *op, const Stmt &s) { - static std::unordered_set check_list = {"TensorAdd", "Add", "RealDiv", "Mul", "Minimum", "Maximum", "Sub"}; + static std::unordered_set check_list = {"TensorAdd", "Add", "RealDiv", "Mul", + "Minimum", "Maximum", "Sub"}; auto call = op->value.as(); if (call == nullptr || check_list.find(call->name) == check_list.end()) { return IRMutator::Mutate_(op, s); @@ -212,7 +213,7 @@ class ReshapeTensorMutator : public IRMutator { } auto call = op->value.as(); return Provide::make(op->func, 0, Call::make(op->value.type(), call->name, input, Call::CallType::PureIntrinsic), - op->args); + op->args); } Stmt ModifyAttrMap(const AttrStmt *op, const Stmt &stmt, const Map &attr_map) { @@ -252,9 +253,9 @@ class ReshapeTensorMutator : public IRMutator { for (const auto &it : reshape_) { auto arg = Call::make(it.first->dtype, it.first->op->name, it.first->shape, Call::CallType::Halide, it.first->op); - auto reshape_stmt = Provide::make( - it.second->op, 0, Call::make(it.first->dtype, "Reshape", {arg}, Call::CallType::PureIntrinsic), - it.second->shape); + auto reshape_stmt = + Provide::make(it.second->op, 0, Call::make(it.first->dtype, "Reshape", {arg}, Call::CallType::PureIntrinsic), + it.second->shape); Map attrs; attrs.Set("shape", it.second->shape); auto reshape_attr = AttrStmt::make(attrs, "attrs", Expr(1), reshape_stmt); @@ -353,12 +354,11 @@ class ReshapeTensorMutator : public IRMutator { } return std::make_tuple(shape_long, shape_tmp, shape_out); } - }; // When Matmul has DefaultFormat bias, reshape bias to FRACTAL_NZ format // If bias need pad, do pad as -// input_2_reshape(1,1,1,16) = Reshape(input_2(2)):float16:PI +// input_2_reshape(1,1,1,16) = Reshape(input_2(2)):float16:PI class ReshapeMatmul : public ReshapeTensorMutator { public: explicit ReshapeMatmul() {} @@ -383,7 +383,7 @@ class ReshapeMatmul : public ReshapeTensorMutator { } Stmt Mutate_(const Provide *op, const Stmt &s) { - static std::unordered_set check_list = {"MatMul"}; + static std::unordered_set check_list = {"MatMul", "BatchMatMul"}; auto call = op->value.as(); if (call == nullptr || check_list.find(call->name) == check_list.end()) { return IRMutator::Mutate_(op, s); @@ -468,9 +468,9 @@ class ReshapeMatmul : public ReshapeTensorMutator { return orig_shape; } - Array InferShapeToFractalNz(const Array &shape0, const Array &shape1, - const Array &shape_out, const Array &shape_fractal, - const std::string &op_name, const Array &shape_default) override { + Array InferShapeToFractalNz(const Array &shape0, const Array &shape1, const Array &shape_out, + const Array &shape_fractal, const std::string &op_name, + const Array &shape_default) override { auto dims = shape_out.size(); auto batch = dims - 2; Array shape_new; @@ -491,8 +491,8 @@ class ReshapeMatmul : public ReshapeTensorMutator { shape_new.push_back(shape_fractal[shape_fractal.size() - 1]); } } else { - LOG(FATAL) << "[" << op_name << "] " << shape_fractal << " (FRACTAL_NZ) and " << shape_default - << " (DefaultFormat) may need data format transformation for "; + LOG(FATAL) << "[" << op_name << "] " << shape_fractal << " (FRACTAL_NZ) and " << shape_default + << " (DefaultFormat) may need data format transformation for "; } return shape_new; } @@ -512,9 +512,13 @@ class ReshapeMatmul : public ReshapeTensorMutator { std::stack transpose_b; void PadBias(Array &shape_default) { - if (shape_default.size() != 1) { return; } + if (shape_default.size() != 1) { + return; + } auto bias_length = (shape_default[0].as())->value; - if (bias_length % 16 == 0) { return; } + if (bias_length % 16 == 0) { + return; + } int64_t pad_length = (bias_length / 16) * 16 + 16; shape_default.Set(0, Expr(pad_length)); LOG(INFO) << "Pad bias length from " << bias_length << " to " << pad_length; -- Gitee