import { useAssayContext } from '@resistapp/client/contexts/assay-context';
import { useSampleDataContext } from '@resistapp/client/contexts/sample-data-context';
import { useOverviewContext } from '@resistapp/client/contexts/use-overview-context/use-overview-context';
import {
  getBeforeOrAfterAbundances,
  OverviewDatum,
} from '@resistapp/client/data-utils/plot-data/build-overview-line-data';
import {
  hasAbundanceByAssay,
  targetAbundancesByAssay,
} from '@resistapp/client/data-utils/plot-data/process-overview-line-datum';
import { L2TargetOrAssay } from '@resistapp/common/assays';
import { LOQ_COPY_NUMBER_IN_L } from '@resistapp/common/statistics/fold-change';
import { PartialAbundance } from '@resistapp/common/statistics/resistance-index';
import { ProcessMode } from '@resistapp/common/types';
import { Group } from '@visx/group';
import { scaleLinear } from '@visx/scale';
import { AreaClosed, LinePath } from '@visx/shape';
import { curveLinear } from '@vx/curve';
import { EventType } from '@vx/event/lib/types';
import { chain, get, isNil, reduce, some, sortBy, values } from 'lodash';
import { Fragment, useCallback, useMemo, useState } from 'react';
import { theme } from '../../../shared/theme';

/**
 * Hover callback signature receiving a `@vx/event` event.
 * NOTE(review): not referenced within this file; presumably consumed by other modules.
 */
export type OnHover = (event: EventType) => void;

/**
 * Overview datum as received by GeneTrendLines, extended with optional per-point extras.
 */
type OverviewDatumWithPrecalculatedValue = OverviewDatum & {
  // Not read in this file; the click handlers here always pass it as `undefined`.
  precalculatedValue: number | undefined;
  // Not read in this file. NOTE(review): presumably marks below-LOQ data for callers; confirm.
  isLOQ?: boolean;
};

/** One plotted point of a before- or after-stage trend line. */
interface DataPoint {
  // Source datum this point was derived from
  datum: OverviewDatum;
  // Pixel x, i.e. timeScale(datum.date); 0 when the scale returns undefined
  x: number;
  // Pixel y, i.e. valueScale(value); 0 when the scale returns undefined
  y: number;
  // Mean detected copiesPerL over the selected targets' abundances, or LOQ_COPY_NUMBER_IN_L when none were detected
  value: number;
  // True when none of the selected targets' abundances had a detected copiesPerL
  isLOQ?: boolean;
  // Only set by the (currently disabled) interpolation code
  isInterpolated?: boolean;
}

/** The before- and after-stage points for one date; either side is null when that stage has no point for the date. */
interface PointPair {
  before: DataPoint | null;
  after: DataPoint | null;
}

/** Vertex of the shaded area between the before line (y0) and the after line (y1), in pixel coordinates. */
interface AreaPoint {
  // Pixel x shared by both lines at this vertex
  x: number;
  // Pixel y of the before line
  y0: number;
  // Pixel y of the after line
  y1: number;
  // Reduction flag for the vertex; always false at line-intersection vertices (where y0 === y1).
  // Only consulted for y0 === y1 vertices, to tell intersections apart from equal sample points.
  isReduction: boolean;
}

/** Props for GeneTrendLines. */
interface Props {
  // Data of one series; each datum yields at most one point per stage (before/after)
  data: OverviewDatumWithPrecalculatedValue[];
  // Maps a date to pixel x; undefined results are drawn at 0
  timeScale: (value: Date | number) => number | undefined;
  // Maps a copies-per-litre value to pixel y; undefined results are drawn at 0
  valueScale: (value: number) => number | undefined;
  // Switches line/circle colour from neutral400 to neutral900 and enlarges circles
  selected: boolean;
  // Called with the whole `data` array when the pointer moves over a line or circle
  mouseMoveHandler: (data: OverviewDatumWithPrecalculatedValue[], event: React.MouseEvent<SVGElement>) => void;
  // Called with the clicked circle's datum (with precalculatedValue set to undefined)
  mouseClickHandler: (data: OverviewDatumWithPrecalculatedValue, event?: React.MouseEvent) => void;
  // Enables enlarging a circle while it is hovered
  siteSelected: boolean;
  // DURING draws both stages with shading between them; BEFORE / AFTER draw that single stage
  processMode: ProcessMode;
}

// Semi-transparent fills for the areas between the before and after lines (change chart only)
const TRANSPARENT_BLUE = 'rgba(86, 133, 219, 0.2)'; // Based on blue500
const TRANSPARENT_RED = 'rgba(225, 83, 147, 0.2)'; // Based on red500

// AreaClosed requires a yScale, but the y0/y1 accessors used below already return pixel coordinates
// (outputs of valueScale), so an identity scale is passed. Built once instead of on every render.
const IDENTITY_SCALE = scaleLinear({
  domain: [0, 1],
  range: [0, 1],
});

/**
 * Draws the abundance trend of one series as SVG lines and circles.
 *
 * - `ProcessMode.BEFORE` / `ProcessMode.AFTER`: a single solid line with circles for that stage.
 * - `ProcessMode.DURING` (the "change chart"): the before (raw) line dashed, the after (treated) line
 *   solid, plus translucent fills between them. Blue is used where the before line has the smaller
 *   pixel y (is drawn higher) or the after value is at LOQ. Red is used where the before line is drawn lower.
 *
 * Each point is the mean detected `copiesPerL` over the abundances of the currently selected targets.
 * When none of them has a detected value, the point sits at `LOQ_COPY_NUMBER_IN_L` and is drawn as a
 * white-filled circle. Data entries without abundances for the selected targets produce no point.
 * When no month is selected, the last rendered circle of each line is drawn enlarged.
 */
export function GeneTrendLines(props: Props) {
  const { data, timeScale, valueScale, selected, mouseMoveHandler, mouseClickHandler, siteSelected, processMode } =
    props;

  const { selectedMonth } = useOverviewContext();
  const { queryFilters } = useSampleDataContext();
  const { getGroup } = useAssayContext();
  // Hovered circle (index within its own line + which line); only updated while siteSelected
  const [isHoveringCircle, setIsHoveringCircle] = useState<{ index: number; type: 'before' | 'after' } | null>(null);

  // Check if we're in reduction mode where we show both before and after lines
  const isChangeChart = processMode === ProcessMode.DURING;

  // Colour shared by every line, every circle stroke and every non-LOQ circle fill
  const strokeColor = selected ? theme.colors.neutral900 : theme.colors.neutral400;

  // Get selected targets from the query filters
  const selectedTargets: L2TargetOrAssay[] = useMemo(() => {
    return queryFilters.filters.selectedTargets;
  }, [queryFilters.filters.selectedTargets]);

  // Keep only abundances belonging to the selected targets, de-duplicated by assay
  const filterAbundancesByTargets = useCallback(
    (abundances: PartialAbundance[] | undefined) => {
      if (!abundances?.length) return [];
      return chain(selectedTargets)
        .filter(target => hasAbundanceByAssay(abundances, target, getGroup))
        .flatMap(target => {
          const grouped = targetAbundancesByAssay(abundances, target, getGroup);
          return chain(grouped).values().flatten().value();
        })
        .uniqBy(abundance => abundance.assay)
        .value();
    },
    [getGroup, selectedTargets],
  );

  // Builds the points of one stage (before/after). Data entries with no abundances for the selected
  // targets are dropped, so the result can be shorter than `data` and its indices need not match.
  const processPoints = useCallback(
    (mode: ProcessMode) => {
      return chain(data)
        .map(d => {
          const allAbundances = getBeforeOrAfterAbundances(d, mode);
          const filteredAbundances = filterAbundancesByTargets(allAbundances);
          if (!filteredAbundances.length) return null;
          const hasDetectedValues = some(filteredAbundances, a => !isNil(a.copiesPerL));
          // Average detected values only; with nothing detected, plot at the LOQ copy number
          const avgCopiesPerL = hasDetectedValues
            ? chain(filteredAbundances)
                .map(a => a.copiesPerL)
                .filter(val => !isNil(val))
                .mean()
                .value()
            : LOQ_COPY_NUMBER_IN_L;
          return {
            datum: d,
            x: timeScale(new Date(d.date)) ?? 0,
            y: valueScale(avgCopiesPerL) ?? 0,
            value: avgCopiesPerL,
            isLOQ: !hasDetectedValues,
          };
        })
        .filter(point => point !== null)
        .map(point => point as DataPoint)
        .value();
    },
    [data, timeScale, valueScale, filterAbundancesByTargets],
  );

  const beforePoints = useMemo(
    () => (isChangeChart || processMode === ProcessMode.BEFORE ? processPoints(ProcessMode.BEFORE) : []),
    [processPoints, isChangeChart, processMode],
  );

  const afterPoints = useMemo(
    () => (isChangeChart || processMode === ProcessMode.AFTER ? processPoints(ProcessMode.AFTER) : []),
    [processPoints, isChangeChart, processMode],
  );

  // NOTE: INTERPOLATION IS DISABLED FOR NOW, SINCE WE CAN ASSUME OVERVIEW DATA TO HAVE BEFORE AND AFTER POINTS FOR ALL DATES.
  // HOWEVER, IT IS KEPT FOR FUTURE REFERENCE IN CASE WE NEED TO SUPPORT MISALIGNED BEFORE & AFTER TIMEPOINTS, OR SITUATIONS WHERE
  // BEFORE OR AFTER POINTS ARE MISSING FOR SOME DATES.

  // Helper function to interpolate between two points
  // const interpolatePoint = useCallback(
  //   (earlierPoint: DataPoint, laterPoint: DataPoint, currentDate: Date): DataPoint => {
  //     const currentTimestamp = currentDate.getTime();
  //     const earlierTimestamp = ensureDate(earlierPoint.datum.date).getTime();
  //     const laterTimestamp = ensureDate(laterPoint.datum.date).getTime();
  //     const ratio = (currentTimestamp - earlierTimestamp) / (laterTimestamp - earlierTimestamp);
  //     const interpolatedY = earlierPoint.y + ratio * (laterPoint.y - earlierPoint.y);
  //     const interpolatedValue = earlierPoint.value + ratio * (laterPoint.value - earlierPoint.value);
  //     return {
  //       ...earlierPoint,
  //       y: interpolatedY,
  //       value: interpolatedValue,
  //       isInterpolated: true,
  //     };
  //   },
  //   [],
  // );

  // Pair before and after points by date; only dates having both are used for shading
  const combinedPoints = useMemo(() => {
    if (!isChangeChart) return [];

    // Add before points to the record
    const withBeforePoints = reduce(
      beforePoints,
      (acc, point) => {
        const date = point.datum.date;
        acc[date] = { before: point, after: null };
        return acc;
      },
      {} as Record<string, PointPair>,
    );

    // Add after points to the record
    const withAllPoints = afterPoints.reduce((acc, point) => {
      const date = point.datum.date;
      if (!get(acc, date, undefined)) {
        acc[date] = { before: null, after: point };
      } else {
        acc[date] = { ...acc[date], after: point };
      }
      return acc;
    }, withBeforePoints);

    // const beforeDates = beforePoints.map(point => point.datum.date);
    // const afterDates = afterPoints.map(point => point.datum.date);
    // const allDates = uniq([...beforeDates, ...afterDates]);
    // const sortedDates = sortBy(allDates, dateStr => new Date(dateStr).getTime());
    const processedPointsByDate = values(withAllPoints);

    // Keep only pairs where we have both before and after points
    return processedPointsByDate.filter(pair => pair.before && pair.after) as Array<{
      before: DataPoint;
      after: DataPoint;
    }>;
  }, [beforePoints, afterPoints, isChangeChart]);

  // Returns the crossing vertex of the before and after line segments between two consecutive date
  // pairs, or null when the lines do not strictly cross inside that interval.
  const createIntersectionPoint = useCallback(
    (
      pair: { before: DataPoint; after: DataPoint },
      nextPair: { before: DataPoint; after: DataPoint },
    ): AreaPoint | null => {
      const beforeY = pair.before.y;
      const afterY = pair.after.y;
      const nextBeforeY = nextPair.before.y;
      const nextAfterY = nextPair.after.y;
      const x = pair.before.x;
      const nextX = nextPair.before.x;
      const beforeDiff = beforeY - afterY;
      const nextDiff = nextBeforeY - nextAfterY;
      // Lines intersect if the difference changes sign
      if ((beforeDiff > 0 && nextDiff < 0) || (beforeDiff < 0 && nextDiff > 0)) {
        // Solve beforeY + t*(nextBeforeY - beforeY) === afterY + t*(nextAfterY - afterY) for t
        const t = (afterY - beforeY) / (nextBeforeY - beforeY - (nextAfterY - afterY));
        // Ensure t is strictly between 0 and 1
        if (t > 0 && t < 1) {
          const intersectionX = x + t * (nextX - x);
          const intersectionY = beforeY + t * (nextBeforeY - beforeY);
          return {
            x: intersectionX,
            y0: intersectionY,
            y1: intersectionY,
            isReduction: false, // At intersection, reduction = increase = false
          };
        }
      }
      return null;
    },
    [],
  );

  // Vertices of the area between the two lines: one per paired date (sorted by x), plus a
  // zero-height vertex wherever the lines cross between consecutive dates.
  const areaData = useMemo(() => {
    if (!isChangeChart || combinedPoints.length < 2) return [];

    const sortedPoints = sortBy(combinedPoints, pair => pair.before.x);
    return sortedPoints.flatMap((pair, i) => {
      const beforeY = pair.before.y;
      const afterY = pair.after.y;
      const x = pair.before.x;
      const result: AreaPoint[] = [
        {
          x,
          y0: beforeY,
          y1: afterY,
          // Pixel space: a smaller y is drawn higher, matching the blue "reduction" fill (y0 < y1) below.
          // Only read when y0 === y1, where it separates sample points from intersection vertices.
          isReduction: beforeY <= afterY,
        },
      ];
      // Check for intersection with next point
      if (i < sortedPoints.length - 1) {
        const nextPair = sortedPoints[i + 1];
        const intersection = createIntersectionPoint(pair, nextPair);
        if (intersection) {
          result.push(intersection);
        }
      }
      return result;
    });
  }, [combinedPoints, createIntersectionPoint, isChangeChart]);

  // True when an after-stage point at (approximately) this x was plotted at LOQ
  const isPointLOQ = useCallback(
    (x: number) => some(afterPoints, p => Math.abs(p.x - x) < 0.001 && p.isLOQ),
    [afterPoints],
  );

  // Red fill: before line drawn lower than the after line (y0 > y1), excluding after-LOQ vertices.
  // Intersection vertices (y0 === y1, isReduction false) are included so the polygon pinches shut there.
  const increaseData = useMemo(
    () =>
      areaData.filter(d => {
        const isAfterLOQ = isPointLOQ(d.x);
        return (d.y0 > d.y1 && !isAfterLOQ) || (!d.isReduction && d.y0 === d.y1 && !isAfterLOQ);
      }),
    [areaData, isPointLOQ],
  );
  // Blue fill: before line drawn higher than the after line (y0 < y1), every after-LOQ vertex,
  // and the intersection vertices.
  const reductionData = useMemo(
    () =>
      areaData.filter(d => {
        const isAfterLOQ = isPointLOQ(d.x);
        return d.y0 < d.y1 || isAfterLOQ || (!d.isReduction && d.y0 === d.y1);
      }),
    [areaData, isPointLOQ],
  );

  return (
    <Group className="gene-trend-lines">
      {/* Reduction area (blue) - where raw > treated - only in reduction mode */}
      {isChangeChart && reductionData.length > 1 && (
        <AreaClosed
          data={reductionData}
          x={d => d.x}
          y0={d => d.y0}
          y1={d => d.y1}
          yScale={IDENTITY_SCALE}
          curve={curveLinear}
          fill={TRANSPARENT_BLUE}
          opacity={selected ? 1 : 0.8}
        />
      )}

      {/* Increase area (red) - where treated > raw - only in reduction mode */}
      {isChangeChart && increaseData.length > 1 && (
        <AreaClosed
          data={increaseData}
          x={d => d.x}
          y0={d => d.y0}
          y1={d => d.y1}
          yScale={IDENTITY_SCALE}
          curve={curveLinear}
          fill={TRANSPARENT_RED}
          opacity={selected ? 1 : 0.8}
        />
      )}

      {/* Raw (before) abundance line - dashed in the change chart, solid otherwise */}
      {beforePoints.length > 0 && (
        <LinePath
          curve={curveLinear}
          data={beforePoints}
          x={(d: DataPoint) => d.x}
          y={(d: DataPoint) => d.y}
          stroke={strokeColor}
          strokeWidth={3}
          strokeOpacity={1}
          strokeDasharray={isChangeChart ? '5,5' : undefined}
          style={{ pointerEvents: 'visibleStroke' }}
          onMouseMove={event => {
            mouseMoveHandler(data, event);
          }}
        />
      )}

      {/* Treated (after) abundance line - solid */}
      {afterPoints.length > 0 && (
        <LinePath
          curve={curveLinear}
          data={afterPoints}
          x={(d: DataPoint) => d.x}
          y={(d: DataPoint) => d.y}
          stroke={strokeColor}
          strokeWidth={3}
          strokeOpacity={1}
          style={{ pointerEvents: 'visibleStroke' }}
          onMouseMove={event => {
            mouseMoveHandler(data, event);
          }}
        />
      )}

      {/* Raw abundance circles */}
      {beforePoints.map((point, j) => {
        const isLOQPoint = point.isLOQ;
        const isHovering = isHoveringCircle && isHoveringCircle.index === j && isHoveringCircle.type === 'before';
        // Compare against the rendered points, not `data`: filtered-out data entries shorten this array
        const isLastRendered = selectedMonth === null && j === beforePoints.length - 1;
        const pointRadius =
          isHovering && siteSelected ? 10 : getRadius(point.datum, selected, selectedMonth, isLastRendered);

        return (
          <Fragment key={`before-${j}`}>
            <circle
              r={pointRadius}
              cx={point.x}
              cy={point.y}
              fill={isLOQPoint ? 'white' : strokeColor}
              stroke={strokeColor}
              strokeWidth={2}
              strokeDasharray={isLOQPoint ? '2,2' : '1,1'}
              onMouseMove={event => {
                mouseMoveHandler(data, event);
              }}
              onClick={event => {
                // Convert OverviewDatum to OverviewDatumWithPrecalculatedValue
                mouseClickHandler(
                  {
                    ...point.datum,
                    precalculatedValue: undefined,
                  },
                  event,
                );
              }}
              onMouseOver={() => {
                siteSelected && setIsHoveringCircle({ index: j, type: 'before' });
              }}
              onMouseOut={() => {
                siteSelected && setIsHoveringCircle(null);
              }}
              style={{ cursor: 'pointer' }}
            />
          </Fragment>
        );
      })}

      {/* Treated abundance circles */}
      {afterPoints.map((point, j) => {
        const isLOQPoint = point.isLOQ;
        const isHovering = isHoveringCircle && isHoveringCircle.index === j && isHoveringCircle.type === 'after';
        // Compare against the rendered points, not `data`: filtered-out data entries shorten this array
        const isLastRendered = selectedMonth === null && j === afterPoints.length - 1;
        const pointRadius =
          isHovering && siteSelected ? 10 : getRadius(point.datum, selected, selectedMonth, isLastRendered);

        return (
          <Fragment key={`after-${j}`}>
            <circle
              r={pointRadius}
              cx={point.x}
              cy={point.y}
              fill={isLOQPoint ? 'white' : strokeColor}
              stroke={strokeColor}
              strokeWidth={isLOQPoint ? 2 : 0}
              onMouseMove={event => {
                mouseMoveHandler(data, event);
              }}
              onClick={event => {
                // Convert OverviewDatum to OverviewDatumWithPrecalculatedValue
                mouseClickHandler(
                  {
                    ...point.datum,
                    precalculatedValue: undefined,
                  },
                  event,
                );
              }}
              onMouseOver={() => {
                siteSelected && setIsHoveringCircle({ index: j, type: 'after' });
              }}
              onMouseOut={() => {
                siteSelected && setIsHoveringCircle(null);
              }}
              style={{ cursor: 'pointer' }}
            />
          </Fragment>
        );
      })}
    </Group>
  );
}

/**
 * Radius of a trend-line circle.
 *
 * The base radius is 5 for a selected line and 4 otherwise. The base is multiplied by 1.6 when the
 * datum's `date` equals `selectedMonth.toISOString()`, or when the caller flags the point as the last one.
 *
 * @param data - datum whose `date` is compared against the selected month's ISO string
 * @param selected - whether the owning line is selected
 * @param selectedMonth - currently selected month, or null when none is selected
 * @param isLastPoint - caller-computed flag that forces the enlarged radius
 * @returns the circle radius in pixels
 */
function getRadius(data: OverviewDatum, selected: boolean, selectedMonth: Date | null, isLastPoint: boolean) {
  const baseRadius = selected ? 5 : 4;
  // The month comparison is evaluated first, as before, so toISOString() is always called when a month is set
  const emphasized = selectedMonth?.toISOString() === data.date || isLastPoint;
  if (!emphasized) {
    return baseRadius;
  }
  return baseRadius * 1.6;
}

// NOTE: INTERPOLATION IS DISABLED FOR NOW, SINCE WE CAN ASSUME OVERVIEW DATA TO HAVE BEFORE AND AFTER POINTS FOR ALL DATES.
// HOWEVER, IT IS KEPT FOR FUTURE REFERENCE IN CASE WE NEED TO SUPPORT MISALIGNED BEFORE & AFTER TIMEPOINTS, OR SITUATIONS WHERE
// BEFORE OR AFTER POINTS ARE MISSING FOR SOME DATES.

// /**
//  * Interpolates missing points for all dates in a non-mutating way
//  *
//  * @param pointsByDate Initial map of date strings to point pairs
//  * @param sortedDates Sorted array of all dates
//  * @param interpolatePoint Function to interpolate between two points
//  * @returns A new map with interpolated points
//  */
// function interpolateAllPoints(
//   pointsByDate: Record<string, PointPair>,
//   sortedDates: string[],
//   interpolatePoint: (earlierPoint: DataPoint, laterPoint: DataPoint, currentDate: Date) => DataPoint,
// ): PointPair[] {
//   const retMap = sortedDates.reduce(
//     (accMap, currentDate) => {
//       const newPair = processDatePoints(accMap, currentDate, sortedDates, interpolatePoint);
//       if (newPair) {
//         accMap.set(currentDate, newPair);
//       }
//       return accMap;
//     },
//     new Map(entries(pointsByDate)),
//   );
//   return Array.from(retMap.values());
// }

// /**
//  * Processes a single date and interpolates missing points if needed
//  *
//  * @param pointsByDate Map of date strings to point pairs
//  * @param currentDate The date to process
//  * @param sortedDates Sorted array of all dates
//  * @param interpolatePoint Function to interpolate between two points
//  * @returns A new point pair for the current date or null if no change
//  */
// function processDatePoints(
//   pointsByDate: Map<string, PointPair>,
//   currentDate: string,
//   sortedDates: string[],
//   interpolatePoint: (earlierPoint: DataPoint, laterPoint: DataPoint, currentDate: Date) => DataPoint,
// ): PointPair | null {
//   const currentPair = pointsByDate.get(currentDate);
//   if (!currentPair) return null;

//   // If we have both points, no changes needed
//   if (currentPair.before && currentPair.after) return null;

//   // Find missing points using a reusable function
//   const before =
//     !currentPair.before && currentPair.after
//       ? findPoint(pointsByDate, currentDate, sortedDates, 'before', interpolatePoint)
//       : currentPair.before;

//   const after =
//     currentPair.before && !currentPair.after
//       ? findPoint(pointsByDate, currentDate, sortedDates, 'after', interpolatePoint)
//       : currentPair.after;

//   // Only return a new pair if something changed
//   if (before !== currentPair.before || after !== currentPair.after) {
//     return { before, after };
//   }

//   return null;
// }

// /**
//  * Finds a missing point by either interpolating between two points or using a neighboring point
//  *
//  * @param pointsByDate Map of date strings to point pairs
//  * @param currentDate The date to process
//  * @param sortedDates Sorted array of all dates
//  * @param pointType Which type of point to find ('before' or 'after')
//  * @param interpolatePoint Function to interpolate between two points
//  * @returns The found or interpolated DataPoint, or null if not possible
//  */
// function findPoint(
//   pointsByDate: Map<string, PointPair>,
//   currentDate: string,
//   sortedDates: string[],
//   pointType: keyof PointPair, // 'before' or 'after'
//   interpolatePoint: (earlierPoint: DataPoint, laterPoint: DataPoint, currentDate: Date) => DataPoint,
// ): DataPoint | null {
//   const currentTime = new Date(currentDate).getTime();

//   // Find earlier and later dates
//   const earlierDates = sortedDates.filter(d => ensureDate(d).getTime() < currentTime);
//   const laterDates = sortedDates.filter(d => ensureDate(d).getTime() > currentTime);

//   // Find earlier and later points of the requested type
//   const earlierWithPoint = find(earlierDates.reverse(), d => {
//     const pair = pointsByDate.get(d);
//     return pair && pair[pointType];
//   });

//   const laterWithPoint = find(laterDates, d => {
//     const pair = pointsByDate.get(d);
//     return pair && pair[pointType];
//   });

//   // Handle different cases based on available points
//   if (earlierWithPoint && laterWithPoint && isString(earlierWithPoint) && isString(laterWithPoint)) {
//     // Interpolate between two points
//     const earlierPair = pointsByDate.get(earlierWithPoint);
//     const laterPair = pointsByDate.get(laterWithPoint);
//     if (earlierPair?.[pointType] && laterPair?.[pointType]) {
//       return interpolatePoint(earlierPair[pointType], laterPair[pointType], new Date(currentDate));
//     }
//   } else if (earlierWithPoint && isString(earlierWithPoint)) {
//     // Use earlier point
//     const earlierPair = pointsByDate.get(earlierWithPoint);
//     if (earlierPair?.[pointType]) {
//       return { ...earlierPair[pointType], isInterpolated: true };
//     }
//   } else if (laterWithPoint && isString(laterWithPoint)) {
//     // Use later point
//     const laterPair = pointsByDate.get(laterWithPoint);
//     if (laterPair?.[pointType]) {
//       return { ...laterPair[pointType], isInterpolated: true };
//     }
//   }
//   return null;
// }
