How to track down data usage in LLVM IR

Hi, I ran into a data-tracking problem while trying to change/wrap the functions that use a given piece of data as an argument. This is an LTO optimization pass, and I am struggling to reliably track down all uses/modifications of the target data (in fact, it is a pointer).

Problem Abstract:
I want to replace the specific functions that use or modify a specific data pointer p (which can be a local variable or passed in as a function argument). How can I track down all the related data and the functions involved?

Sample:
Let’s take a look at the following sample code:

// Original full-precision routine: works on the array `a` in fp32.
void funcA(float * a)
{
        //do some fp32 calculation on array a
}

// Wrapper: forwards its pointer argument unchanged to funcA.
void funcB(float * b)
{
        funcA(b);
}

// Mixed-precision routine that works on the array `c`.
void funcC(float * c)
{
        //do some mixed-precision calculation on array c
}

// Driver: allocates p, initialises it, then passes the same buffer to funcA
// directly (p), through the alias q, and indirectly through the wrapper funcB.
int main()
{
        float * p = (float*)malloc(sizeof(float)*100);
        //initialise the data in array p
        init_func(p);
        //compute on p
        funcA(p);
        float * q = p;
        //compute on q (q aliases p, so this touches the same buffer)
        funcA(q);
        //funcB is a wrapper around funcA
        funcB(p);
}

The optimization pass we are using replaces funcA with a mixed-precision (e.g. fp16) version, funcC.

Now assume that we have some precision-error information about array c in funcC, and that we are now dealing with the call to funcA inside funcB (assuming no inlining): we must decide whether to replace it with funcC according to the precision-loss threshold we want.

So, let’s say we are now facing the following optimized code:

// Original full-precision routine: works on the array `a` in fp32.
void funcA(float * a)
{
        //do some fp32 calculation on array a
}

// Wrapper: still forwards its pointer argument to the original funcA.
void funcB(float * b)
{
        funcA(b);
}

// Mixed-precision routine that works on the array `c`.
void funcC(float * c)
{
        //do some mixed-precision calculation on array c
}

// Driver after the replacement: the two direct calls now go to funcC,
// while the call through the wrapper funcB is unchanged.
int main()
{
        float * p = (float*)malloc(sizeof(float)*100);
        //initialise the data in array p
        init_func(p);
        //compute on p
        funcC(p);
        float * q = p;
        //compute on q (q aliases p, so this touches the same buffer)
        funcC(q);
        //funcB is a wrapper around funcA
        funcB(p);
}

If I find out that the precision loss is too big to accept, I want to revert all replaced calls (now funcC) back to the originally called function (i.e. funcA(p) and funcA(q)). So how can I recognize all functions/instructions that modify the data via the pointer p?

Any discussion is highly appreciated. I think this is a pretty difficult problem that needs a lot of consideration. Thank you.

Over the past few days, I implemented a DFS method to find all related data. But now I have run into a problem: I cannot correctly recognize the related data/functions. What I need to do is find the functions that produce the input for the current function. In this sample, arguments 7, 10 and 14 of GemmEx (counting from zero) are the inputs; argument 14 is also the output of GemmEx.
Sample:

; ModuleID = '/mnt/data/home/mzw/workspace/test_space/llvm_test/gemm_pass2_test/dataflow_test/short_test.hip'
source_filename = "/mnt/data/home/mzw/workspace/test_space/llvm_test/gemm_pass2_test/dataflow_test/short_test.hip"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%"class.std::ios_base::Init" = type { i8 }

@_ZStL8__ioinit = internal global %"class.std::ios_base::Init" zeroinitializer, align 1
@__dso_handle = external hidden global i8
@llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 65535, void ()* @_GLOBAL__sub_I_short_test.hip, i8* null }]

declare dso_local void @_ZNSt8ios_base4InitC1Ev(%"class.std::ios_base::Init"* nonnull dereferenceable(1)) unnamed_addr #0

; Function Attrs: nounwind
declare dso_local void @_ZNSt8ios_base4InitD1Ev(%"class.std::ios_base::Init"* nonnull dereferenceable(1)) unnamed_addr #1

; Function Attrs: nofree nounwind
declare dso_local i32 @__cxa_atexit(void (i8*)*, i8*, i8*) local_unnamed_addr #2

; Function Attrs: norecurse uwtable mustprogress
define dso_local i32 @main() local_unnamed_addr #3 {
  %1 = alloca float*, align 8
  %2 = alloca float*, align 8
  %3 = alloca float*, align 8
  %4 = alloca i8*, align 8
  %5 = alloca float, align 4
  %6 = alloca float, align 4
  %7 = bitcast float** %1 to i8*
  call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %7) #6
  %8 = bitcast float** %2 to i8*
  call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %8) #6
  %9 = bitcast float** %3 to i8*
  call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %9) #6
  %10 = bitcast float** %1 to i8**
  %11 = call i32 @hipMalloc(i8** nonnull %10, i64 4)
  %12 = bitcast float** %2 to i8**
  %13 = call i32 @hipMalloc(i8** nonnull %12, i64 4)
  %14 = bitcast float** %3 to i8**
  %15 = call i32 @hipMalloc(i8** nonnull %14, i64 4)
  %16 = bitcast i8** %4 to i8*
  call void @llvm.lifetime.start.p0i8(i64 8, i8* nonnull %16) #6
  store i8* null, i8** %4, align 8, !tbaa !2
  %17 = call i32 @hipblasCreate(i8** nonnull %4)
  %18 = bitcast float* %5 to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %18) #6
  %19 = bitcast float* %6 to i8*
  call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %19) #6
  %20 = load i8*, i8** %4, align 8, !tbaa !2
  %21 = load i8*, i8** %10, align 8, !tbaa !2
  %22 = load i8*, i8** %12, align 8, !tbaa !2
  %23 = load i8*, i8** %14, align 8, !tbaa !2
  %24 = call i32 @hipblasGemmEx(i8* %20, i32 111, i32 111, i32 1, i32 1, i32 1, i8* nonnull %18, i8* %21, i32 151, i32 1, i8* %22, i32 151, i32 1, i8* nonnull %19, i8* %23, i32 151, i32 1, i32 151, i32 160)
  %25 = load i8*, i8** %4, align 8, !tbaa !2
  %26 = load i8*, i8** %12, align 8, !tbaa !2
  %27 = load i8*, i8** %10, align 8, !tbaa !2
  %28 = call i32 @hipblasGemmEx(i8* %25, i32 111, i32 111, i32 1, i32 1, i32 1, i8* nonnull %18, i8* %26, i32 151, i32 1, i8* %26, i32 151, i32 1, i8* nonnull %19, i8* %27, i32 151, i32 1, i32 151, i32 160)
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %19) #6
  call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %18) #6
  call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %16) #6
  call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %9) #6
  call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %8) #6
  call void @llvm.lifetime.end.p0i8(i64 8, i8* nonnull %7) #6
  ret i32 0
}

; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #4

declare dso_local i32 @hipMalloc(i8**, i64) local_unnamed_addr #0

declare dso_local i32 @hipblasCreate(i8**) local_unnamed_addr #0

declare dso_local i32 @hipblasGemmEx(i8*, i32, i32, i32, i32, i32, i8*, i8*, i32, i32, i8*, i32, i32, i8*, i8*, i32, i32, i32, i32) local_unnamed_addr #0

; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #4

; Function Attrs: uwtable
define internal amdgpu_kernel void @_GLOBAL__sub_I_short_test.hip() #5 section ".text.startup" {
  tail call void @_ZNSt8ios_base4InitC1Ev(%"class.std::ios_base::Init"* nonnull dereferenceable(1) @_ZStL8__ioinit)
  %1 = tail call i32 @__cxa_atexit(void (i8*)* bitcast (void (%"class.std::ios_base::Init"*)* @_ZNSt8ios_base4InitD1Ev to void (i8*)*), i8* getelementptr inbounds (%"class.std::ios_base::Init", %"class.std::ios_base::Init"* @_ZStL8__ioinit, i64 0, i32 0), i8* nonnull @__dso_handle) #6
  ret void
}

attributes #0 = { "frame-pointer"="none" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
attributes #1 = { nounwind "frame-pointer"="none" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
attributes #2 = { nofree nounwind }
attributes #3 = { norecurse uwtable mustprogress "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
attributes #4 = { argmemonly nofree nosync nounwind willreturn }
attributes #5 = { uwtable "device-init" "frame-pointer"="none" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" }
attributes #6 = { nounwind }

!llvm.module.flags = !{!0}
!llvm.ident = !{!1}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{!"clang version 13.0.0 (https://github.com/RadeonOpenCompute/llvm-project roc-4.3.0 21295 f2943f684437d2c1143a56e418d29fc6b3314072)"}
!2 = !{!3, !3, i64 0}
!3 = !{!"any pointer", !4, i64 0}
!4 = !{!"omnipotent char", !5, i64 0}
!5 = !{!"Simple C++ TBAA"}

My core DFS code is shown below:

struct dataflow : public ModulePass{
        static char ID;

        dataflow();
        ~dataflow(){};

        bool runOnModule(Module & M) override;
        void dfs(Value *,Function *, int,std::unordered_map<Value*,bool>&);
        void jump_out_of_parent(Value *, Function *, int,std::unordered_map<Value*,bool>&);
        template<typename T>bool is_in_list(std::vector<T>,T);

        int gemmex_id;
        std::unordered_map<int,Instruction*> gemm_id_call_inst_map;
        std::unordered_map<Instruction*,int> gemm_call_inst_int_map;
        std::unordered_map<int,std::vector<int>> related_gemm_id_map;
        //below one is for general call inst, above are for gemm
        std::vector<CallInst*> call_inst_list;
        std::string modify_matrix_mem_func_name;

    };

    //Registers the pass with its ID, starts the GemmEx counter at zero and sets the
    //callee name that dfs compares against in its "memory-modifying function" branch.
    dataflow::dataflow()
        : ModulePass(ID),
          gemmex_id(0),
          modify_matrix_mem_func_name("dududu")
    {
    }

    //Linear search: reports whether `target` compares equal to any element of `l`.
    //Returns false for an empty list.
    template<typename T>
    bool dataflow::is_in_list(std::vector<T> l, T target)
    {
        auto cursor = l.begin();
        const auto stop = l.end();
        while(cursor != stop)
        {
            if(target == *cursor) return true;
            ++cursor;
        }
        return false;
    }

    //NOTE: The related_gemm_id_map we get here contains all related gemm without consindering 
    //1)whether is runned acutally      2) whether it's runned before it(we only consider the before gemm every time we optimize)
    void dataflow::dfs(Value * called_arg, Function * caller_func, int target_id, std::unordered_map<Value*,bool> & dfsed_value_map)
    {
        //TO.DO.: How to avoid repeating the same instruction?                      //DONE
        if(dfsed_value_map.find(called_arg) != dfsed_value_map.end() &&  dfsed_value_map[called_arg]) return;
        else
            dfsed_value_map[called_arg] = true;
        errs()<<"Current target arg is "<<*called_arg<<"\n";
        //"user" means that this argument is used as argument/operand somewhere
        for(auto user = called_arg->user_begin(), user_end = called_arg->user_end();
            user != user_end; user++)           //User means that this arg is used as operand in these instructions
        {
            //If this arg is just right as the argument of GemmEx
            if(Instruction * inst = dyn_cast<Instruction>(*user))
            {
                errs()<<*inst<<"\n";
                //errs()<<"Facing instruction of "<<*inst<<"\n";
                if(isa<CallInst>(inst))
                {
                    CallInst * call_inst = dyn_cast<CallInst>(inst);
                    Function * called_func = call_inst->getCalledFunction();
                    if(called_func && called_func->getName() == "hipblasGemmEx")
                    {
                        int cur_id = gemm_call_inst_int_map[inst];
                        errs()<<"We found the call_inst of GemmEx: "<<*call_inst<<"\n";
                        //we only care about those gemmex that accept this arg as output
                        //NOTE: This makes us wont add gemmex itself into its dependence list
                        if(called_arg==call_inst->getOperand(14))
                        {
                            //TO.DO.: Avoid searching the same gemm                        //DONE
                            if(is_in_list<int>(related_gemm_id_map[target_id],cur_id))
                            {
                                //do nothing
                                //errs()<<"We occur the same GemmEx with id "<<cur_id<<"\n";
                            }
                            else
                            {
                                errs()<<"The "<<target_id<<"th GemmEx depends on "<<cur_id<<"th GemmEx\n";
                                related_gemm_id_map[target_id].push_back(cur_id);
                                //dfs(call_inst->getOperand(7),caller_func, cur_id, dfsed_value_map);
                                //dfs(call_inst->getOperand(10),caller_func, cur_id, dfsed_value_map);
                                //dfs(call_inst->getOperand(14),caller_func, cur_id, dfsed_value_map);
                            }
                        }
                        else
                            continue;

                    }
                    else if(called_func && called_func->getName() == modify_matrix_mem_func_name)
                    {
                        //TO.DO.: when we run into something like ReadFile() that can modify the Matrix to be a new data matrix
                        //what should we do?
                        
                    }
                    else
                    {
                        //we dont care about other functions
                        //10-26:But if we met a function containing the GemmEx, we wont dig in. 
                        //Oppositely, it will start from the contained GemmEx and jump out to find this GemmEx
                        //QUES.: But in this way, we cannot find the dependency from called func to current GemmEx
                        //like {testfunc(AAC),GemmEx(ABC)} we cannot know the GemmEx depends on the one in testfunc
                        //we only can know testfunc depends on GemmEx's C
                        continue;
                    }
                }
                else
                {
                    //we assume we only have load/store in this branch
                    //errs()<<"Now we have met the load/store inst\n";
                    Value * ret_v = dynamic_cast<Value*>(inst);
                    errs()<<*ret_v<<"\n";
                    dfs(ret_v,caller_func, target_id, dfsed_value_map);
                }
            }
        }
        //"use" means that this argument is def/not as an operand somewhere.
        for(auto use = called_arg->use_begin(), use_end = called_arg->use_end(); use != use_end; use++)
        {
            if(Instruction * inst = dyn_cast<Instruction>(*use))
            {
                for(auto i = 0; i < inst->getNumOperands(); i++)
                {
                    Value * related_v = inst->getOperand(i);
                    if(related_v == called_arg) continue;
                    else dfs(related_v,caller_func, target_id, dfsed_value_map);
                }
            }
        }

        //For those whose argument is passed through the arguments of parent functions
        //In fact, we should check whether it is in parent's arguments whenever we are handling a new Value
        //So that we can jump out of parent function, get the all coresponding call_inst of parent function
        //and dfs on the coresponding passed arguments of call_inst
        //TO.DO.: 
        jump_out_of_parent(called_arg,caller_func,target_id,dfsed_value_map);
    }

    //Continues the traversal in the callers of parent_func.
    //When target_arg is one of parent_func's formal arguments, every recorded call
    //instruction that directly calls parent_func is inspected and dfs is restarted
    //on the actual argument passed in the same position, inside the calling function.
    //Does nothing when target_arg is not a formal argument of parent_func.
    void dataflow::jump_out_of_parent(Value * target_arg, Function * parent_func, int target_id, std::unordered_map<Value*,bool>& dfsed_value_map)
    {
        //NOTE: the formal argument count comes from arg_size(); getNumOperands()
        //does not describe the argument list of a function definition.
        const size_t formal_count = parent_func->arg_size();
        for(size_t formal_idx = 0; formal_idx != formal_count; ++formal_idx)
        {
            if(parent_func->getArg(formal_idx) != target_arg) continue;

            //Visit every direct call to parent_func collected by runOnModule.
            for(CallInst * call_site : call_inst_list)
            {
                if(call_site->getCalledFunction() != parent_func) continue;

                Value * actual_arg = call_site->getArgOperand(formal_idx);
                Function * calling_func = call_site->getParent()->getParent();
                dfs(actual_arg,calling_func,target_id, dfsed_value_map);
            }
        }
    }

    //Pass entry point.
    //Phase 1 records every call instruction of the module and numbers the
    //hipblasGemmEx calls (ids start at 1, in module iteration order).
    //Phase 2 walks the module in the same order and, for each GemmEx, runs dfs from
    //its operands 7, 10 and 14 with a fresh visited set for each operand.
    //Finally the collected relation is printed. The IR is not modified, so the
    //function returns false.
    bool dataflow::runOnModule(Module &M)
    {
        //TODO: in the official version only functions outside the tool library
        //should be considered.
        //Phase 1: ids must exist for all calls before any dfs starts, because dfs
        //does not visit instructions in program order.
        for(auto & fn : M)
        {
            for(auto & block : fn)
            {
                for(auto & ins : block)
                {
                    CallInst * call_inst = dyn_cast<CallInst>(&ins);
                    if(!call_inst) continue;

                    call_inst_list.push_back(call_inst);
                    Function * callee = call_inst->getCalledFunction();
                    if(!callee || !(callee->getName() == "hipblasGemmEx")) continue;

                    errs()<<"We get the "<<++gemmex_id<<"th called GemmEx function in "<<*call_inst<<"\n";
                    gemm_id_call_inst_map[gemmex_id]=call_inst;
                    gemm_call_inst_int_map[call_inst] = gemmex_id;
                }
            }
        }

        errs()<<"Now we finish collecting all call_inst(including gemmex)\n";

        //Phase 2: recount from zero so that the ids match the ones given in phase 1.
        const unsigned matrix_operand_idx[] = {7, 10, 14};
        gemmex_id = 0;
        for(auto & fn : M)
        {
            for(auto & block : fn)
            {
                for(auto & ins : block)
                {
                    //only GemmEx calls are traversal roots
                    CallInst * call_inst = dyn_cast<CallInst>(&ins);
                    if(!call_inst) continue;
                    Function * callee = call_inst->getCalledFunction();
                    if(!callee || !(callee->getName() == "hipblasGemmEx")) continue;

                    gemmex_id++;
                    errs()<<"Now we use "<<gemmex_id<<"th GemmEx as target GemmEx\n";
                    Function * caller_func = &fn;
                    for(unsigned operand_idx : matrix_operand_idx)
                    {
                        //each matrix operand is explored with its own visited set
                        std::unordered_map<Value*,bool> dfsed_value_map;
                        Value * matrix_argv = call_inst->getArgOperand(operand_idx);
                        dfs(matrix_argv,caller_func,gemmex_id,dfsed_value_map);
                    }
                }
            }
        }

        //The map now holds all data-related GemmEx calls; entries that are never
        //executed or that run after the target still have to be filtered out.
        for(const auto & entry : related_gemm_id_map)
        {
            std::cout<<"The related GemmEx id of "<<entry.first<<" contains: ";
            for(int related_id : entry.second) std::cout<<related_id<<" ";
            std::cout<<std::endl;
        }


        return false;
    }

For this sample, we don’t need the function jump_out_of_parent. In the end, I got the following output:

The related GemmEx id of 2 contains: 1 2 
The related GemmEx id of 1 contains: 2 1 

Noticing that %24 = call i32 @hipblasGemmEx(i8* %20, i32 111, i32 111, i32 1, i32 1, i32 1, i8* nonnull %18, i8* %21, i32 151, i32 1, i8* %22, i32 151, i32 1, i8* nonnull %19, i8* %23, i32 151, i32 1, i32 151, i32 160) is GemmEx 1,
and %28 = call i32 @hipblasGemmEx(i8* %25, i32 111, i32 111, i32 1, i32 1, i32 1, i8* nonnull %18, i8* %26, i32 151, i32 1, i8* %26, i32 151, i32 1, i8* nonnull %19, i8* %27, i32 151, i32 1, i32 151, i32 160) is GemmEx 2.

The corresponding C++ source code is:

hipblasGemmEx( blas_handle, HIPBLAS_OP_N, HIPBLAS_OP_N, 
                                    m, n, k, &alpha, 
                                    A, HIPBLAS_R_32F, k,
                                    B, HIPBLAS_R_32F, n,
                                    &beta, C, HIPBLAS_R_32F, m, 
                                    HIPBLAS_R_32F, HIPBLAS_GEMM_DEFAULT);
    hipblasGemmEx( blas_handle, HIPBLAS_OP_N, HIPBLAS_OP_N, 
                                    m, n, k, &alpha, 
                                    B, HIPBLAS_R_32F, k,
                                    B, HIPBLAS_R_32F, n,
                                    &beta, A, HIPBLAS_R_32F, m, 
                                    HIPBLAS_R_32F, HIPBLAS_GEMM_DEFAULT);

So the ideal output should be

The related GemmEx id of 2 contains: 2 
The related GemmEx id of 1 contains: 1 2