Count Good Triplets in an Array

Problem Statement

You are given two 0-indexed arrays nums1 and nums2, both of length n, each of which is a permutation of the integers from 0 to n - 1.

A good triplet is a set of 3 distinct values which are present in increasing order by position both in nums1 and nums2. In other words, if we consider pos1v as the index of value v in nums1 and pos2v as the index of value v in nums2, then a good triplet will be a set (x, y, z) where 0 <= x, y, z <= n - 1, such that pos1x < pos1y < pos1z and pos2x < pos2y < pos2z.

Return the total number of good triplets.

Example 1:

Input: nums1 = [2,0,1,3], nums2 = [0,1,2,3]
Output: 1
Explanation: 
There are 4 triplets (x,y,z) such that pos1x < pos1y < pos1z. 
They are (2,0,1), (2,0,3), (2,1,3), and (0,1,3). 
Out of those triplets, only the triplet (0,1,3) satisfies 
pos2x < pos2y < pos2z. Hence, there is only 1 good triplet.
      

Example 2:

Input: nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]
Output: 4
Explanation: The 4 good triplets are (4,0,3), (4,0,2), 
(4,1,3), and (4,1,2).
      

Constraints:

Approach 1: Brute Force

Explanation: Check all possible triplets and verify if they maintain increasing order in both arrays.

Time Complexity: O(n³)

Space Complexity: O(n) for position mapping

function goodTriplets(nums1, nums2):
    n = length of nums1
    pos1 = map value to index in nums1
    pos2 = map value to index in nums2
    
    count = 0
    for x = 0 to n-1:
        for y = 0 to n-1:
            if pos1[x] >= pos1[y] or pos2[x] >= pos2[y]:
                continue
            for z = 0 to n-1:
                if pos1[y] < pos1[z] and pos2[y] < pos2[z]:
                    count++
    
    return count
      

💡 Think: The triple nested loop checks all combinations but is too slow for large n.
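The brute-force pseudocode above can be sketched in C++ roughly as follows. The function name goodTripletsBrute is an illustrative choice, not part of the original text, and the sketch assumes both inputs are valid permutations of 0 to n - 1.

```cpp
#include <vector>

// Brute-force count of good triplets: O(n^3) time, O(n) extra space.
// goodTripletsBrute is a hypothetical name used only for this sketch.
long long goodTripletsBrute(const std::vector<int>& nums1,
                            const std::vector<int>& nums2) {
    const int n = static_cast<int>(nums1.size());
    std::vector<int> pos1(n), pos2(n);
    for (int i = 0; i < n; ++i) {
        pos1[nums1[i]] = i;  // index of each value in nums1
        pos2[nums2[i]] = i;  // index of each value in nums2
    }

    long long count = 0;
    for (int x = 0; x < n; ++x) {
        for (int y = 0; y < n; ++y) {
            // x must come before y in both arrays
            if (pos1[x] >= pos1[y] || pos2[x] >= pos2[y]) continue;
            for (int z = 0; z < n; ++z) {
                // y must come before z in both arrays
                if (pos1[y] < pos1[z] && pos2[y] < pos2[z]) ++count;
            }
        }
    }
    return count;
}
```

This is mainly useful as a reference implementation for checking the faster approaches on small inputs.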

Approach 2: Binary Indexed Tree (BIT) - Optimal

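Explanation: Treat each element as the middle of a triplet and count:
• Elements to its left in nums1 that also appear before it in nums2
• Elements to its right in nums1 that also appear after it in nums2
After mapping each value of nums1 to its position in nums2, a BIT answers both counts in O(log n) per element: smaller values to the left and larger values to the right. The product of the two counts is the number of good triplets with that element in the middle.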

Time Complexity: O(n log n)

Space Complexity: O(n)

function goodTriplets(nums1, nums2):
    n = length of nums1
    pos2 = map nums2[i] to index i
    
    // Transform nums1 based on positions in nums2
    arr = [pos2[nums1[i]] for i in range(n)]
    
    // Count smaller elements to the left for each position
    left = array of size n
    bit1 = BinaryIndexedTree(n)
    for i = 0 to n-1:
        left[i] = bit1.query(arr[i])
        bit1.update(arr[i] + 1, 1)
    
    // Count larger elements to the right for each position
    right = array of size n
    bit2 = BinaryIndexedTree(n)
    for i = n-1 down to 0:
        right[i] = bit2.query(n) - bit2.query(arr[i] + 1)
        bit2.update(arr[i] + 1, 1)
    
    // Count good triplets with i as middle element
    count = 0
    for i = 0 to n-1:
        count += left[i] * right[i]
    
    return count

class BinaryIndexedTree:
    function __init__(size):
        this.size = size   // stored so update() knows where to stop
        tree = array of size+1 initialized to 0
    
    function update(index, delta):
        while index <= size:
            tree[index] += delta
            index += index & (-index)
    
    function query(index):
        sum = 0
        while index > 0:
            sum += tree[index]
            index -= index & (-index)
        return sum
      

💡 Think: Transform to position array, use BIT to count valid left/right elements, multiply counts.
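A minimal C++ sketch of the BIT approach, following the two-pass pseudocode above. The names Fenwick and goodTripletsBIT are assumptions made for this example; the tree is 1-indexed, so a position p is stored at index p + 1.

```cpp
#include <vector>

// Fenwick tree (BIT) over indices 1..n: point add and prefix sum.
// Fenwick and goodTripletsBIT are hypothetical names for this sketch.
struct Fenwick {
    std::vector<long long> tree;
    explicit Fenwick(int n) : tree(n + 1, 0) {}

    void update(int index, long long delta) {
        for (; index < static_cast<int>(tree.size()); index += index & -index)
            tree[index] += delta;
    }

    // Sum of entries at indices 1..index.
    long long query(int index) const {
        long long sum = 0;
        for (; index > 0; index -= index & -index) sum += tree[index];
        return sum;
    }
};

long long goodTripletsBIT(const std::vector<int>& nums1,
                          const std::vector<int>& nums2) {
    const int n = static_cast<int>(nums1.size());
    std::vector<int> pos2(n), arr(n);
    for (int i = 0; i < n; ++i) pos2[nums2[i]] = i;
    for (int i = 0; i < n; ++i) arr[i] = pos2[nums1[i]];

    // left[i]: elements before i whose position in nums2 is smaller.
    std::vector<long long> left(n), right(n);
    Fenwick bit1(n);
    for (int i = 0; i < n; ++i) {
        left[i] = bit1.query(arr[i]);
        bit1.update(arr[i] + 1, 1);
    }

    // right[i]: elements after i whose position in nums2 is larger.
    Fenwick bit2(n);
    for (int i = n - 1; i >= 0; --i) {
        right[i] = bit2.query(n) - bit2.query(arr[i] + 1);
        bit2.update(arr[i] + 1, 1);
    }

    long long count = 0;
    for (int i = 0; i < n; ++i) count += left[i] * right[i];
    return count;
}
```

The counts are kept in long long because the number of triplets grows cubically with n and can exceed the range of a 32-bit integer.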

Approach 3: Segment Tree with Coordinate Compression

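Explanation: Similar to the BIT approach, but using a segment tree for range-sum queries. Map each value of nums1 to its position in nums2; since these positions already lie in 0 to n - 1, no further coordinate compression is needed. Then, for each element, count the smaller elements to its left and the larger elements to its right.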

Time Complexity: O(n log n)

Space Complexity: O(n)

// C++: count good triplets using a segment tree over positions in nums2
#include <vector>
using namespace std;

struct SegTree {
  int n; vector<int> t;
  SegTree(int _n=0){ init(_n); }
  void init(int _n){ n=_n; t.assign(4*n,0); }

  void update(int pos,int val,int idx,int l,int r){
    if(l==r){ t[idx]+=val; return; }
    int mid=(l+r)/2;
    if(pos <= mid) update(pos,val,2*idx+1,l,mid);
    else update(pos,val,2*idx+2,mid+1,r);
    t[idx]=t[2*idx+1]+t[2*idx+2];
  }

  int query(int ql,int qr,int idx,int l,int r){
    if(ql > r || qr < l) return 0;
    if(ql <= l && r <= qr) return t[idx];
    int mid=(l+r)/2;
    return query(ql,qr,2*idx+1,l,mid) + query(ql,qr,2*idx+2,mid+1,r);
  }
};

class Solution {
public:
  long long goodTriplets(vector<int>& nums1, vector<int>& nums2){
    int n = nums1.size();
    vector<int> pos2(n);
    for(int i=0;i<n;i++) pos2[nums2[i]] = i;

    // map nums1 values to positions in nums2
    vector<int> arr(n);
    for(int i=0;i<n;i++) arr[i] = pos2[nums1[i]];

    SegTree seg(n);
    vector<long long> left(n,0), right(n,0);

    // count elements smaller than arr[i] to the left
    for(int i=0;i<n;i++){
      if(arr[i]-1 >= 0) left[i] = seg.query(0, arr[i]-1, 0, 0, n-1);
      seg.update(arr[i], 1, 0, 0, n-1);
    }

    // reset segment tree
    seg.init(n);
    for(int i=n-1;i>=0;i--){
      if(arr[i]+1 <= n-1) right[i] = seg.query(arr[i]+1, n-1, 0, 0, n-1);
      seg.update(arr[i], 1, 0, 0, n-1);
    }

    long long ans = 0;
    for(int i=0;i<n;i++) ans += left[i] * right[i];
    return ans;
  }
};

💡 Think: A segment tree provides the same functionality as a BIT, with more explicit range queries.

Approach 4: Merge Sort Based Counting

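Explanation: Build the same position array as in Approach 2, then use a modified merge sort to count, for each element, how many smaller elements precede it. Because the array is a permutation, the number of larger elements that follow can be derived from that count.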

Time Complexity: O(n log n)

Space Complexity: O(n)

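function goodTriplets(nums1, nums2):
    n = length of nums1
    pos2 = {nums2[i]: i for i in range(n)}
    arr = [pos2[nums1[i]] for i in range(n)]
    
    // left[i] = number of j < i with arr[j] < arr[i], filled in by merge sort
    left = array of size n initialized to 0
    idx = [0, 1, ..., n-1]
    mergeCount(arr, idx, left, 0, n)
    
    count = 0
    for i = 0 to n-1:
        // arr is a permutation, so n-1-arr[i] values are larger than arr[i];
        // i - left[i] of them lie to the left, the rest lie to the right
        right = (n - 1 - arr[i]) - (i - left[i])
        count += left[i] * right
    
    return count

// Sorts idx[lo..hi) by arr value and accumulates left[] during the merge
function mergeCount(arr, idx, left, lo, hi):
    if hi - lo < 2:
        return
    mid = (lo + hi) / 2
    mergeCount(arr, idx, left, lo, mid)
    mergeCount(arr, idx, left, mid, hi)
    
    merged = empty list
    i = lo, j = mid
    while i < mid or j < hi:
        if j == hi or (i < mid and arr[idx[i]] < arr[idx[j]]):
            append idx[i] to merged; i++
        else:
            // left-half elements already merged are all smaller
            left[idx[j]] += i - lo
            append idx[j] to merged; j++
    copy merged into idx[lo..hi)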
      

💡 Think: Merge sort counts, for each element, the smaller elements before it; multiply the left and right counts to get the triplet count.
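One way to realize the merge sort idea in C++ is sketched below. It sorts indices by their value in the position array and, during each merge, credits every element taken from the right half with the number of left-half elements already merged. The names countSmallerLeft and goodTripletsMergeSort are assumptions for this example, and the count of larger elements to the right is derived from the permutation property rather than computed by a second sort.

```cpp
#include <numeric>
#include <vector>

// Sorts idx[lo, hi) by arr value; left[k] accumulates the number of
// elements before position k that are smaller than arr[k].
// Both function names here are hypothetical, chosen for this sketch.
void countSmallerLeft(const std::vector<int>& arr, std::vector<int>& idx,
                      std::vector<int>& tmp, std::vector<long long>& left,
                      int lo, int hi) {
    if (hi - lo < 2) return;
    const int mid = lo + (hi - lo) / 2;
    countSmallerLeft(arr, idx, tmp, left, lo, mid);
    countSmallerLeft(arr, idx, tmp, left, mid, hi);

    int i = lo, j = mid, k = lo;
    while (i < mid || j < hi) {
        if (j == hi || (i < mid && arr[idx[i]] < arr[idx[j]])) {
            tmp[k++] = idx[i++];
        } else {
            // The i - lo left-half elements merged so far are all smaller.
            left[idx[j]] += i - lo;
            tmp[k++] = idx[j++];
        }
    }
    for (int t = lo; t < hi; ++t) idx[t] = tmp[t];
}

long long goodTripletsMergeSort(const std::vector<int>& nums1,
                                const std::vector<int>& nums2) {
    const int n = static_cast<int>(nums1.size());
    std::vector<int> pos2(n), arr(n), idx(n), tmp(n);
    for (int i = 0; i < n; ++i) pos2[nums2[i]] = i;
    for (int i = 0; i < n; ++i) arr[i] = pos2[nums1[i]];
    std::iota(idx.begin(), idx.end(), 0);

    std::vector<long long> left(n, 0);
    countSmallerLeft(arr, idx, tmp, left, 0, n);

    long long count = 0;
    for (int i = 0; i < n; ++i) {
        // n - 1 - arr[i] values are larger overall; i - left[i] of them
        // sit to the left, so the remainder sit to the right.
        const long long right = (n - 1 - arr[i]) - (i - left[i]);
        count += left[i] * right;
    }
    return count;
}
```

Deriving the right-hand count arithmetically avoids a second merge pass; it relies on arr being a permutation of 0 to n - 1, which the problem statement guarantees.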
