#include<stdcpp.h>
using namespace std;

/// Node of a binary tree holding an integer payload.
///
/// Child pointers are plain, non-owning raw pointers: the struct never
/// allocates or frees anything itself, so whoever builds the tree is
/// responsible for releasing its nodes.
struct TreeNode {
    int val = 0;                 ///< Value stored in this node.
    TreeNode *left = nullptr;    ///< Left child, or nullptr when absent.
    TreeNode *right = nullptr;   ///< Right child, or nullptr when absent.

    /// Builds a childless node holding 0.
    TreeNode() {}

    /// Builds a childless node holding `x`.
    TreeNode(int x) : val(x) {}

    /// Builds a node holding `x` with the given children (either may be nullptr).
    TreeNode(int x, TreeNode *left, TreeNode *right) : val(x), left(left), right(right) {}
};

/// Converts a binary search tree into a "greater sum tree": every node's
/// value becomes its original value plus the sum of all values that are
/// greater than it in the tree.
class Solution {
public:
    // Scratch state filled by traverse(). Kept public for backward
    // compatibility; bstToGst() does not read it and only clears it on exit.
    std::priority_queue<int, std::vector<int>, std::greater<int> > vals;
    int cnt = 0;

    /// Pre-order walk that pushes every node value into `vals` and counts the
    /// visited nodes in `cnt`. A null `node` is a no-op.
    void traverse(TreeNode * node){
        if (node == nullptr) {
            return;
        }
        vals.push(node->val);
        cnt++;
        traverse(node->left);
        traverse(node->right);
    }

    /// Rewrites the tree rooted at `root` in place into its greater sum tree.
    ///
    /// @param root Root of a binary search tree with distinct keys; may be
    ///             nullptr for an empty tree.
    /// @return     The same `root` pointer that was passed in.
    ///
    /// The tree is visited in reverse in-order (right, node, left), i.e. from
    /// the largest key to the smallest, while a running total of the values
    /// seen so far is written back into each node. This needs no value-indexed
    /// lookup tables, so there is no limit on the node count or on the range
    /// of the values (as long as the total fits in an int), and nothing is
    /// printed to stdout.
    TreeNode* bstToGst(TreeNode* root) {
        std::stack<TreeNode*> path;   // ancestors whose left side is still pending
        TreeNode* cur = root;
        int running = 0;              // sum of every value visited so far

        while (cur != nullptr || !path.empty()) {
            // Dive to the largest not-yet-visited key.
            while (cur != nullptr) {
                path.push(cur);
                cur = cur->right;
            }
            cur = path.top();
            path.pop();

            running += cur->val;
            cur->val = running;

            // Everything smaller than this node lives in its left subtree.
            cur = cur->left;
        }

        // Leave the scratch members empty after the call, as callers of the
        // previous implementation could rely on.
        cnt = 0;
        std::priority_queue<int, std::vector<int>, std::greater<int> > empty;
        vals.swap(empty);
        return root;
    }
};

/// Recursively attaches children to `cur` from an array that stores a binary
/// tree in heap layout: the children of slot `i` live in slots `2*i + 1` and
/// `2*i + 2`.
///
/// @param val  Array of node values; the value -1 marks a missing node.
/// @param cnt  Number of elements in `val`.
/// @param cur  Already-allocated node that corresponds to slot `idx`.
/// @param idx  Slot of `cur` inside `val`.
///
/// Child nodes are allocated with `new`; the caller owns the resulting tree.
/// A child is created only when its slot is inside the array (`< cnt`) and
/// does not hold -1, so the array is never read past its end.
void create(int * val, int cnt, TreeNode *cur,int idx){
    if (val == nullptr || cur == nullptr) {
        return;
    }

    const int left_idx = 2 * idx + 1;
    const int right_idx = left_idx + 1;

    // Start from a leaf; children are attached below only if they exist.
    cur->left = nullptr;
    cur->right = nullptr;

    if (left_idx < cnt && val[left_idx] != -1) {
        cur->left = new TreeNode(val[left_idx]);
        create(val, cnt, cur->left, left_idx);
    }
    if (right_idx < cnt && val[right_idx] != -1) {
        cur->right = new TreeNode(val[right_idx]);
        create(val, cnt, cur->right, right_idx);
    }
}
/// Prints the values of the tree rooted at `node` to stdout in pre-order
/// (node, left subtree, right subtree), each value followed by one space.
///
/// @param node Root of the subtree to print; nullptr prints nothing.
void traverse(TreeNode* node){
    // An empty (sub)tree has nothing to print; this also ends the recursion.
    if (node == nullptr) {
        return;
    }
    std::cout << node->val << ' ';
    traverse(node->left);
    traverse(node->right);
}
/// Builds a binary tree from an array in heap layout and prints it.
///
/// @param val  Array of node values; slot 0 is the root. The helper `create`
///             treats -1 in any other slot as "no node here".
/// @param cnt  Number of elements in `val`.
/// @return     Root of the newly allocated tree, or nullptr when `val` is
///             null or `cnt` is not positive. The caller owns the nodes.
///
/// For a non-empty input this prints "create a tree", the pre-order listing
/// of the tree, and "done" to stdout.
TreeNode* createTree(int * val, int cnt){
    // Without at least one element there is no root value to read.
    if (val == nullptr || cnt <= 0) {
        return nullptr;
    }

    TreeNode * root = new TreeNode(val[0]);
    std::cout << "create a tree" << std::endl;
    create(val, cnt, root, 0);
    traverse(root);
    std::cout << "\ndone" << std::endl;
    return root;
}
int main(){
    Solution sol;

    //create example 1
    const int ex1_length = 15;
    int ex1_data[ex1_length] = {4,1,6,0,2,5,7,-1,-1,-1,3,-1,-1,-1,8};
    TreeNode * ex1 = createTree(ex1_data, ex1_length);
    sol.bstToGst(ex1);
    traverse(ex1);
    cout<<endl;

    const int ex2_length = 3;
    int ex2_data[ex2_length] = {0,-1,1};
    TreeNode * ex2 = createTree(ex2_data,ex2_length);
    sol.bstToGst(ex2);
    traverse(ex2);
    cout<<endl;


    return 0;
}