import { Flex } from '@chakra-ui/react';
import { useAssayContext } from '@resistapp/client/contexts/assay-context';
import { useSampleDataContext } from '@resistapp/client/contexts/sample-data-context';
import { useOverviewContext } from '@resistapp/client/contexts/use-overview-context/use-overview-context';
import { OverviewDatum } from '@resistapp/client/data-utils/plot-data/build-overview-line-data';
import { getGenesAndCopyNumbers } from '@resistapp/client/data-utils/plot-data/process-overview-line-datum';
import { useContainerWidth } from '@resistapp/client/hooks/use-container-width';
import { getMetricColor } from '@resistapp/client/utils/metric-utils';
import { L1Target, L1Targets, L2TargetOrAssay, L2Targets, l2TargetsByL1 } from '@resistapp/common/assays';
import { MinorTarget } from '@resistapp/common/assays-temp-96-gene-minor-targets';
import { friendlyCopyNumber, friendlyL2Target } from '@resistapp/common/friendly';
import { LOQ_COPY_NUMBER_IN_L } from '@resistapp/common/statistics/fold-change';
import { ChartUnit, MetricMode, ProcessMode } from '@resistapp/common/types';
import { AxisBottom, AxisLeft } from '@visx/axis';
import { GridColumns } from '@visx/grid';
import { Group } from '@visx/group';
import { scaleBand, scaleLog } from '@visx/scale';
import { Line } from '@visx/shape';
import { useTooltipInPortal } from '@visx/tooltip';
import { chain, get } from 'lodash';
import { RefObject, useMemo, useState } from 'react';
import { getGroupColor, getGroupTextColor } from '../../../shared/palettes';
import { theme } from '../../../shared/theme';
import { PlotTooltip, usePlotTooltip } from '../../../tooltips/plot-tooltip';
import { graphMargins } from '../consts';
import { friendlyGeneName } from '../site-details-utils';
import { getBarDatum } from './get-bar-datum';
import { ReductionBar } from './reduction-bar';
import { L2TargetTooltipContent } from './reduction-tooltip';
import { SingleBarDatum, TargetType } from './types';

interface Props {
  /** Element whose measured width sizes the chart when no explicit `width` is given. */
  containerRef: RefObject<HTMLDivElement>;
  /** Site whose BEFORE and AFTER (treatment) samples are compared in the chart. */
  selectedSiteDatum: OverviewDatum;
  /**
   * Target used to choose the bars: its L2 targets when `l2TargetsByL1` has an entry for it,
   * otherwise the assays present in the data. `ARG` switches the y-axis to L2-target labels.
   */
  groupingTarget: TargetType;
  /** Explicit SVG width; when omitted (or 0) the container width minus the graph margins is used. */
  width?: number;
  /** SVG height of the chart area (the component defaults it to 300). */
  height?: number;
}

/**
 * Plot-area insets. `horizontal` is the left offset of the plot group (room for the y-axis
 * labels) and is subtracted from the width; `vertical` is subtracted from the height.
 */
const margins = { horizontal: 140, vertical: 40 };

/** Extra SVG height added below the chart so the x-axis labels are not clipped. */
const yMarginForxAxisLegend = 5;

/** Exported padding constant for chart text (not referenced by the code in this file). */
export const textPadding = 8;

/**
 * Approximate rendered widths of antibiotic-class labels, used to size the coloured highlight box
 * drawn behind the sole selected y-axis label. Names missing here fall back to a default width.
 */
const antibioticWidths: Record<string, number> = {
  Aminoglycoside: 105,
  'Beta-Lactam': 83,
  MDR: 26,
  MLSB: 34,
  Phenicol: 55,
  Quinolone: 65,
  Sulfonamide: 81,
  Tetracycline: 79,
  Vancomycin: 77,
};

/**
 * Formats a friendly gene name for the y-axis. It shortens "resistance" to "res." and drops a
 * trailing two-character `_X` suffix (e.g. `_1`).
 */
function abbreviateGeneLabel(geneName: string): string {
  const abbreviated = geneName.replace('resistance', 'res.');
  return abbreviated[abbreviated.length - 2] === '_' ? abbreviated.slice(0, -2) : abbreviated;
}

/**
 * Horizontal bar chart comparing the BEFORE and AFTER (treatment) copy numbers of each target for
 * the selected site. The x-axis is log10 and includes a dashed LOQ reference line.
 *
 * - Returns `null` until the container width has been measured.
 * - Renders a "Data not available" message when no target has data, or when the active chart unit
 *   is not copies per litre.
 * - Y-axis labels are clickable and toggle the corresponding target in the query filters.
 */
export function TreatedAndReducedBarGraph({
  containerRef,
  height = 300,
  width = 0,
  selectedSiteDatum,
  groupingTarget,
}: Props) {
  const { metricMode, activeChartUnit, activeOverviewConfiguration } = useOverviewContext();
  const { getGroup, allGeneGroups, allAssays } = useAssayContext();
  const geneNameByAssay = useMemo(
    () =>
      chain(allAssays)
        .keyBy(a => a.assay)
        .mapValues(a => friendlyGeneName(a.assay, a.gene))
        .value(),
    [allAssays],
  );
  const { queryFilters } = useSampleDataContext();
  const containerWidth = useContainerWidth(containerRef);
  const [hoveredTarget, setHoveredTarget] = useState<L2TargetOrAssay | null>(null);

  const isArg = (groupingTarget as L1Target) === L1Targets.ARG;

  const tooltipStuff = useTooltipInPortal({
    scroll: false,
    detectBounds: true,
  });
  const { tooltipData, tooltipProps } = usePlotTooltip<SingleBarDatum>(tooltipStuff);
  const tooltipRef = tooltipStuff.containerRef;

  // All hooks are called above; the early returns below must stay after them.
  if (!containerWidth) {
    return null;
  }

  const effectiveWidth = width || containerWidth - graphMargins;

  // Get data for both before and after samples
  const beforeGenesAndCopyNumbers = getGenesAndCopyNumbers(
    selectedSiteDatum,
    groupingTarget,
    metricMode,
    ProcessMode.BEFORE,
    getGroup,
    true,
  );
  const afterGenesAndCopyNumbers = getGenesAndCopyNumbers(
    selectedSiteDatum,
    groupingTarget,
    metricMode,
    ProcessMode.AFTER,
    getGroup,
    true,
  );

  // Targets to show: the L2 targets of the grouping target, or else the assays present in the data.
  const targetList = get(
    l2TargetsByL1,
    groupingTarget,
    chain(beforeGenesAndCopyNumbers)
      .map(g => g.assay)
      .uniq()
      .value(),
  );

  // One bar per target with data, ordered by the AFTER total, largest first.
  const barData = chain(targetList)
    .map(target =>
      getBarDatum(
        groupingTarget,
        target,
        beforeGenesAndCopyNumbers,
        afterGenesAndCopyNumbers,
        LOQ_COPY_NUMBER_IN_L,
        selectedSiteDatum,
      ),
    )
    .filter(datum => !!datum.beforeTotal || !!datum.afterTotal)
    .sort((a, b) => (b.afterTotal || 0) - (a.afterTotal || 0))
    .value();

  if (barData.length === 0 || activeChartUnit !== ChartUnit.COPIES_PER_L) {
    return (
      <Flex justifyContent="center" alignItems="center" height="100%" width="100%">
        Data not available for site or sample type
      </Flex>
    );
  }

  const xMax = effectiveWidth - margins.horizontal;
  const yMax = height - margins.vertical;
  const treatedBarColor = getMetricColor(-3.5, MetricMode.REDUCTION, activeChartUnit);
  const increaseBarColor = getMetricColor(3, MetricMode.REDUCTION, activeChartUnit);
  const reducedBarColor = getMetricColor(-1, MetricMode.REDUCTION, activeChartUnit);
  const lodColor = theme.colors.neutral700;

  // barData is non-empty here, so Math.max receives at least one argument.
  const maxTotal = Math.max(...barData.map(d => Math.max(d.beforeTotal || 0, d.afterTotal || 0)));

  const yScale = scaleBand<L2Targets | string>({
    range: [0, yMax],
    round: true,
    domain: barData.map(d => d.barIdentifier),
    padding: 0.4,
  });

  // A log scale cannot show 0, so the axis starts one decade below the LOQ and zero totals are
  // drawn at that minimum.
  const minValue = LOQ_COPY_NUMBER_IN_L / 10;
  const xScale = scaleLog<number>({
    range: [0, xMax],
    round: true,
    domain: [minValue, (maxTotal || LOQ_COPY_NUMBER_IN_L) * 18],
    base: 10,
  });

  // Log-scale ticks restricted to whole powers of ten; shared by the grid and the bottom axis.
  const decadeTicks = xScale.ticks(6).filter(value => Math.log10(value) % 1 === 0);

  const lodX = xScale(LOQ_COPY_NUMBER_IN_L);

  // Display label for a bar identifier: an L2 target in ARG mode, otherwise an assay.
  const labelForIdentifier = (identifier: string) =>
    isArg
      ? friendlyL2Target(identifier as L2Targets)
      : abbreviateGeneLabel(get(geneNameByAssay, identifier, identifier));

  return (
    <>
      <PlotTooltip {...tooltipProps} backgroundColor={theme.colors.blue50}>
        {tooltipData && (
          <L2TargetTooltipContent
            tooltipData={tooltipData}
            activeOverviewConfiguration={activeOverviewConfiguration}
            allGeneGroups={allGeneGroups}
          />
        )}
      </PlotTooltip>
      <svg width={effectiveWidth} height={height + yMarginForxAxisLegend} overflow="visible" ref={tooltipRef}>
        <Group top={0} left={margins.horizontal}>
          <GridColumns
            left={0}
            height={height - 34}
            scale={xScale}
            stroke={theme.colors.neutral300}
            tickValues={decadeTicks}
          />
          {/* LOD line */}
          <Line
            from={{ x: lodX, y: 0 }}
            to={{ x: lodX, y: height - 34 }}
            stroke={lodColor}
            strokeOpacity={1}
            strokeWidth={2}
            strokeDasharray="2, 5"
          />
          <text
            x={lodX}
            y={yMax + 45}
            fontSize={12}
            fill={lodColor}
            fontWeight={theme.fontWeight.bold}
            textAnchor="middle"
          >
            (LOQ)
          </text>
          <AxisBottom
            scale={xScale}
            top={yMax}
            tickFormat={n => friendlyCopyNumber(n.valueOf())}
            tickValues={decadeTicks}
            hideTicks
            hideAxisLine
            tickLabelProps={{
              fontSize: 14,
              color: theme.colors.neutral700,
              fontWeight: theme.fontWeight.bold,
            }}
          />
          {barData.map((data: SingleBarDatum) => {
            const barHeight = yScale.bandwidth();
            const barY = yScale(data.barIdentifier) || 0;
            // Handle zero values by using minValue
            const beforeWidth = xScale(data.beforeTotal || minValue);
            const afterWidth = xScale(data.afterTotal || minValue);

            return (
              <ReductionBar
                key={data.barIdentifier}
                datum={data}
                barY={barY}
                barHeight={barHeight}
                beforeWidth={data.isDecrease ? beforeWidth : afterWidth}
                afterWidth={data.isDecrease ? afterWidth : beforeWidth}
                baseBarColor={data.isDecrease ? treatedBarColor : increaseBarColor}
                changeBarColor={data.isDecrease ? reducedBarColor : increaseBarColor}
                tooltipStuff={tooltipStuff}
                hoveredTarget={hoveredTarget}
                setHoveredTarget={setHoveredTarget}
                xScale={xScale}
              />
            );
          })}
          <AxisLeft
            scale={yScale}
            // Pass the raw bar identifier through unchanged. The custom tickComponent below derives
            // the display label itself. Reverse-mapping a display label back to a target is
            // ambiguous when several targets share a label (e.g. `gene_1` / `gene_2`).
            tickFormat={identifier => String(identifier)}
            tickLabelProps={{
              fontSize: 14,
              color: theme.colors.neutral700,
              fontWeight: theme.fontWeight.bold,
              cursor: isArg ? 'pointer' : 'default',
              dy: '.33em',
              textAnchor: 'end',
              x: -10,
            }}
            tickComponent={({ x, y, formattedValue }) => {
              // formattedValue is the raw bar identifier (see tickFormat above).
              if (!formattedValue) return null;
              const label = labelForIdentifier(formattedValue);
              if (!label) return null;
              const { selectedTargets } = queryFilters.filters;

              // Non-ARG labels are assays (or other targets) and are toggled by their identifier.
              if (!isArg) {
                const targetIdentifier = barData.find(d => d.barIdentifier === formattedValue)?.barIdentifier;
                if (!targetIdentifier) return null;

                const isNonArgSelected =
                  selectedTargets.length === 0 || selectedTargets.includes(targetIdentifier as L2Targets);
                const isNonArgTheOnlySelected =
                  selectedTargets.length === 1 && selectedTargets[0] === (targetIdentifier as L2Targets);

                return (
                  <text
                    x={x}
                    y={y}
                    fontSize={14}
                    fontWeight={isNonArgTheOnlySelected ? theme.fontWeight.extraHeavy : theme.fontWeight.bold}
                    fontStyle={groupingTarget === MinorTarget.PATHOGEN ? 'italic' : 'normal'}
                    fill={isNonArgSelected ? theme.colors.neutral700 : theme.colors.neutral500}
                    dy=".33em"
                    textAnchor="end"
                    cursor="pointer"
                    onClick={() => {
                      queryFilters.toggleSingleTarget(targetIdentifier);
                    }}
                  >
                    {label}
                  </text>
                );
              }

              // ARG labels are L2 targets; identifiers that are not L2 targets are shown as selected
              // but are not clickable.
              const l2Target = Object.values(L2Targets).find(target => target === formattedValue);
              const isSelected = !l2Target || selectedTargets.length === 0 || selectedTargets.includes(l2Target);
              const isTheOnlySelected = l2Target && selectedTargets.length === 1 && selectedTargets[0] === l2Target;

              const textProps = {
                fontSize: 14,
                fontWeight: theme.fontWeight.bold,
                cursor: 'pointer',
              };

              // The only selected antibiotic class is drawn on a box in its group colour.
              if (isTheOnlySelected) {
                const groupColor = getGroupColor(l2Target, 'antibiotic', allGeneGroups);
                const groupTextColor = getGroupTextColor(l2Target, 'antibiotic', allGeneGroups);
                const padding = 6;
                const textMetrics = {
                  width: antibioticWidths[label] || 70, // fallback width if not found
                  height: 18,
                };

                return (
                  <g
                    onClick={() => {
                      queryFilters.toggleSingleTarget(l2Target);
                    }}
                  >
                    <rect
                      x={x - textMetrics.width - padding * 3}
                      y={y - textMetrics.height / 2 - padding / 2}
                      width={textMetrics.width + padding * 3}
                      height={textMetrics.height + padding}
                      fill={groupColor}
                      rx={theme.borders.radius.small}
                    />
                    <text
                      x={x - padding}
                      y={y}
                      fontSize={14}
                      fontWeight={theme.fontWeight.extraHeavy}
                      cursor="pointer"
                      fill={groupTextColor}
                      dy=".33em"
                      textAnchor="end"
                    >
                      {label}
                    </text>
                  </g>
                );
              }

              return (
                <text
                  x={x}
                  y={y}
                  {...textProps}
                  fill={isSelected ? theme.colors.neutral700 : theme.colors.neutral500}
                  dy=".33em"
                  textAnchor="end"
                  onClick={() => {
                    if (l2Target) {
                      queryFilters.toggleSingleTarget(l2Target);
                    }
                  }}
                >
                  {label}
                </text>
              );
            }}
            hideTicks
            hideAxisLine
          />
        </Group>
      </svg>
    </>
  );
}
