import React, { useEffect, useState, useRef } from "react";
import { graphql } from "gatsby";
import {
  Row,
  Col,
  Divider,
  Card,
  Collapse,
  Tag,
  Statistic,
  Radio,
  Tooltip,
  Checkbox,
} from "antd";
import {
  GithubOutlined,
  Html5Outlined,
  GoogleOutlined,
  ChromeOutlined,
} from "@ant-design/icons";
import {
  EditOutlined,
  EllipsisOutlined,
  SettingOutlined,
  TrophyOutlined,
} from "@ant-design/icons";
import BlogPostChrome from "../../components/BlogPostChrome";
import Img from "gatsby-image";
import { CaretRightOutlined, CalendarOutlined } from "@ant-design/icons";
import Layout from "../../components/Layouts";
import { DataSet, Network } from "vis-network/standalone";
import cardData from "./cards.json";
import "./Cifar10.css";

// antd's collapsible section component, used by GroupbyFramework.
const Panel = Collapse.Panel;

// Post metadata. This is presumably extracted at build time (e.g. by
// gatsby-plugin-javascript-frontmatter) and exposed as the
// `javascriptFrontmatter` node that `pageQuery` fetches and the page spreads
// into <BlogPostChrome> — TODO confirm.
// NOTE(review): keep this a plain object literal of static values; a static
// extractor may not evaluate computed expressions.
export const frontmatter = {
  title: `Build CIFAR-10 Classifier using various frameworks`,
  // The page header also hard-codes this date ("14 Nov, 2020"); keep them in sync.
  written: `2020-11-14`,
  updated: `2020-12-20`,
  layoutType: `post`,
  contentType: "blog",
  path: `/cifar10-notebooks/`,
  category: `Deep learning`,
  // Presumably resolved relative to this file — verify against the image pipeline.
  image: "./poster.png",
  description: `Build CIFAR10 classifier(s) using different deep learning frameworks and hardware accelerators`,
};

// Brand colour per framework, keyed by the lower-cased label used in
// cards.json. An optional `opacity` overrides the default graph-node opacity.
// Lookups use Array.prototype.find, so the first matching entry wins.
const colorsMap = [
  {
    name: "tensorflow",
    color: "#ff6f00",
  },
  {
    name: "pytorch",
    color: "#ee4c2c",
  },
  {
    name: "pytorch lightning",
    color: "#792fe4",
    opacity: 0.7,
  },
];

function RenderTreeChart(props) {
  const domNode = useRef(null);
  const network = useRef(null);

  function addNode(record, column, returnData) {
    const source = record[column];
    // add nodes
    if (!returnData["nodes"].find((item) => item.label === source)) {
      if (source && source.length > 0) {
        let color = "#97C2FC";
        let opacity = 1.0;
        const found = colorsMap.find(
          (item) => item.name === source.toLowerCase()
        );
        if (found) {
          color = found.color;
          opacity = found.opacity || opacity;
        }

        returnData["nodes"].push({ id: source, label: source, color, opacity });
      }
    }

    return returnData;
  }

  function addEdges(record, column, scanFields, returnData) {
    const fieldValue = record[column];
    scanFields
      .filter((item) => item !== column)
      .map((field) => {
        return cardData
          .filter((item) => item[column] === fieldValue)
          .map((item) => {
            let color = "blue";
            const found = Object.keys(colorsMap).find(
              (item) => item === fieldValue.toLowerCase()
            );
            if (found) {
              color = found.color;
            }
            return { from: fieldValue, to: item[field], color: { color } };
          })
          .reduce((_, currentEdge) => {
            if (currentEdge.from !== currentEdge.to) {
              if (
                !returnData["edges"].find(
                  (edgeItem) =>
                    edgeItem.from === currentEdge.from &&
                    edgeItem.to === currentEdge.to
                ) &&
                !returnData["edges"].find(
                  (edgeItem) =>
                    edgeItem.from === currentEdge.to &&
                    edgeItem.to === currentEdge.from
                )
              ) {
                returnData["edges"].push(currentEdge);
              }
            }
          }, {});
      });
    return returnData;
  }

  // Things to track when creating graph
  // card_type, framework, accelerator_type, accelerator, experiment_framework
  const cardGraph = cardData.reduce(
    (acc, current) => {
      // const { card_type, framework, accelerator_type, accelerator, experiment_framework } = current;
      // if (!acc["nodes"].includes(card_type)) {
      //   acc["nodes"].push({id: card_type, label: card_type })
      // }
      const fields = [
        "title",
        "framework",
        "accelerator_type",
        "accelerator",
        "experiment_framework",
      ];
      for (let i = 0; i < fields.length; i++) {
        const field = fields[i];
        acc = addNode(current, field, acc);
        const { id } = current;
        const fieldValue = current[field];

        const edges = cardData
          .filter((item) => item[field] === fieldValue)
          .map((item) => {
            const toValue =
              i + 1 < fields.length ? item[fields[i + 1]] : fieldValue;
            // let color = "blue";
            // const found = Object.keys(colorsMap).find(item => item === toValue.toLowerCase());
            // if (found) {
            //   color = found.color;
            // }
            const found = colorsMap.find(
              (item) =>
                item.name === fieldValue.toLowerCase() ||
                item.name === toValue.toLowerCase()
            );
            let color = { inherit: "from" };
            if (found) {
              color = found.color;
            }

            return {
              from: fieldValue,
              to: toValue,
              // color: { inherit: "from"}
              color: color,
            };
          })
          .reduce((_, currentEdge) => {
            if (currentEdge.from !== currentEdge.to) {
              if (
                !acc["edges"].find(
                  (edgeItem) =>
                    edgeItem.from === currentEdge.from &&
                    edgeItem.to === currentEdge.to
                )
              ) {
                acc["edges"].push(currentEdge);
              }
            }
          }, {});
        // acc = addEdges(current, field, fields, acc);
        console.log(acc);
      }
      // acc = addNode(card_type, acc);
      // acc = addNode(framework, acc);
      // acc = addNode(accelerator_type, acc);
      // acc = addNode(accelerator, acc);
      // acc = addNode(experiment_framework, acc);
      return acc;
    },
    { nodes: [], edges: [] }
  );

  // create an array with nodes
  const nodes = new DataSet([
    { id: 1, label: "Node 1" },
    { id: 2, label: "Node 2" },
    { id: 3, label: "Node 3" },
    { id: 4, label: "Node 4" },
    { id: 5, label: "Node 5" },
  ]);

  // create an array with edges
  const edges = new DataSet([
    { from: 1, to: 3 },
    { from: 1, to: 2 },
    { from: 2, to: 4 },
    { from: 2, to: 5 },
    { from: 3, to: 3 },
  ]);

  const data = {
    nodes: nodes,
    edges: edges,
  };

  const options = {};

  // useEffect(() => {
  //   network.current = new Network(domNode.current, data, options);
  // }, [domNode, network, data, options])
  useEffect(() => {
    network.current = new Network(domNode.current, cardGraph, options);
  }, [domNode, network, cardGraph, options]);

  return <div className="network" ref={domNode}></div>;
}

function LinksPanel(props) {
  return (
    <Row gutter={[16, 24]}>
      <Col span={10}>
        <a href="#">🌍 HTML</a>
      </Col>
      <Col span={4}> | </Col>
      <Col span={10}>
        <a href="#">
          <GithubOutlined />
          &nbsp;GitHub
        </a>
      </Col>
    </Row>
  );
}

function NotebookCard(props) {
  const {
    title,
    card_type = "",
    framework,
    accelerator = "",
    accelerator_type = "",
    model,
    train_time,
    num_epochs = 50,
    train_acc,
    test_acc,
    links = [],
    experiment_framework = "",
    updated_on,
    grouped = false,
    maxTestAcc = 100,
  } = props;

  function create_tags() {
    const tags = [<Tag color="green">{accelerator_type.toUpperCase()}</Tag>];
    if (test_acc !== 0 && test_acc >= maxTestAcc) {
      tags.push(
        <Tag color="purple">
          <TrophyOutlined />
          ACCURACY
        </Tag>
      );
    }
    if (card_type && card_type === "Error") {
      tags.push(<Tag color="red">ERROR</Tag>);
    }
    if (card_type && card_type === "Planned") {
      tags.push(
        <Tag color="blue">
          {" "}
          <CalendarOutlined /> PLANNED
        </Tag>
      );
    }
    if (grouped === false) {
      tags.push(<Tag color="blue">{framework.toUpperCase()}</Tag>);
    }
    tags.push(<Tag color="orange">{experiment_framework.toUpperCase()}</Tag>);
    // is it the best
    return <div>{tags}</div>;
  }

  function typeIcon(type) {
    switch (type) {
      case "github": {
        return (
          <Tooltip title="GitHub">
            <GithubOutlined />
          </Tooltip>
        );
      }
      case "colab": {
        return (
          <Tooltip title="Google Colab">
            <GoogleOutlined />
          </Tooltip>
        );
      }
      default: {
        return (
          <Tooltip title="HTML">
            <ChromeOutlined />
          </Tooltip>
        );
      }
    }
  }

  function create_links() {
    return links.map((item) => {
      const { type, url } = item;
      return (
        <>
          <a href={url}>{typeIcon(type)}</a>
          <Divider type="vertical" />
        </>
      );
    });
  }

  function getTitle() {
    if (card_type && card_type === "Error") {
      return `${title} ⚠️`;
    } else {
      return title;
    }
  }

  return (
    <Card
      className="card-item"
      title={getTitle()}
      bordered={true}
      style={{ width: 300 }}
      extra={create_links()}
      // actions={[
      //   <a href="#">HTML</a>,
      //   <a href="#">GitHub</a>,
      // ]}
    >
      <p>
        <Statistic title={model} value={accelerator} />
      </p>
      <Row gutter={[16, 24]}>
        <Col span={12}>
          <Statistic title="Train Acc" value={train_acc} />
        </Col>
        <Col span={12}>
          <Statistic title="Test Acc" value={test_acc} />
        </Col>
      </Row>
      <Row gutter={[16, 24]}>
        <Col span={12}>
          <Statistic title="Epochs" value={num_epochs} />
        </Col>
        <Col span={12}>
          <Statistic title="Training Time" value={train_time} />
        </Col>
      </Row>
      <Row>
        <Col span={24}>
          <p style={{ fontSize: ".6em" }}>Last Updated: {updated_on}</p>
        </Col>
      </Row>
      <Row>
        <Col span={24}>{create_tags()}</Col>
      </Row>
    </Card>
  );
}

function GroupbyFramework(props) {
  const items = cardData.reduce((acc, item) => {
    const { framework } = item;
    if (!acc.includes(framework)) {
      acc.push(framework);
    }
    return acc;
  }, []);
  return (
    <div>
      {items &&
        items.map((item, idx) => {
          const key = `c-${idx}`;
          const activeKey = `c-0`;
          let maxTestAcc = cardData.reduce((acc, current) => {
            if (current.framework === item && current.test_acc > acc) {
              acc = current.test_acc;
            }
            return acc;
          }, 0);
          return (
            <Collapse
              defaultActiveKey={[activeKey]}
              bordered={false}
              expandIcon={({ isActive }) => (
                <CaretRightOutlined rotate={isActive ? 90 : 0} />
              )}
              style={{
                background: "white",
                border: "0px",
                marginBottom: "24px",
                borderRadius: "2px",
              }}
            >
              <Panel header={item.toUpperCase()} key={key}>
                <div className="cards-container">
                  {cardData
                    .filter((record) => record.framework === item)
                    .map((row) => {
                      return (
                        <NotebookCard
                          {...row}
                          grouped
                          maxTestAcc={maxTestAcc}
                        />
                      );
                    })}
                </div>
              </Panel>
            </Collapse>
          );
        })}
    </div>
  );
}

function ShowAllCards(props) {
  let maxTestAcc = cardData.reduce((acc, current) => {
    if (current.test_acc > acc) {
      acc = current.test_acc;
    }
    return acc;
  }, 0);

  return (
    <div className="cards-container">
      {cardData.map((item) => {
        return <NotebookCard {...item} maxTestAcc={maxTestAcc} />;
      })}
    </div>
  );
}

class Cifar10Classifier extends React.Component {
  constructor(props) {
    super(props);
    this.state = {
      mode: "top",
      radioValue: 2,
    };
  }

  onRadioChange = (e) => {
    console.log("radio checked", e.target.value);
    this.setState({
      radioValue: e.target.value,
    });
  };

  onShowGraph = (e) => {
    this.setState({
      showGraph: e.target.checked,
    });
  };

  handleModeChange = (e) => {
    const mode = e.target.value;
    this.setState({ mode });
  };

  render() {
    const { mode, showGraph, radioValue } = this.state;
    return (
      <Layout data={this.props.data} location={this.props.location}>
        <BlogPostChrome {...this.props.data.javascriptFrontmatter}>
        <h1 style={{ textAlign: "center"}}>
        CIFAR-10 CLASSIFIER USING VARIOUS FRAMEWORKS
        </h1>
        <p
          className="header-subtitle"
          style={{ marginTop: 20, marginBottom: 10 }}
        >          
          14 Nov, 2020
        </p>
          <article>
            <p>
              In this post, I share my collection of notebooks demonstrating how
              to build a{" "}
              <a href="https://www.cs.toronto.edu/~kriz/cifar.html">CIFAR-10</a>{" "}
              classifier in various deep learning frameworks. The objectives are
              to show how to:
            </p>
            <ul>
              <li>create a simple classifier using CNNs</li>
              <li>track experiments</li>
              <li>use hyperparameter tuning frameworks</li>
            </ul>
            <p>
              I started working on this post to update my old notebooks on{" "}
              <a href="/native_colab">training a classifier on GPU vs TPU</a>.
              In 2018, When I first wrote about GPU vs TPU, I wanted to find out
              the complexity involved in converting the code to switch from one
              accelerator (GPU) to another (TPUs). And whether it provided any
              benefit out of the box. A lot has changed since then.
            </p>
            <p>
              <strong>TensorFlow 2.0</strong> has made things a lot simpler. The
              eager mode is gentle on my brain, the Keras API, as always, is fun
              to work with. The introduction of <code>tf.data</code> API makes
              the construction of input pipelines easy. The features such as
              Autotune, cache, and prefetch take care of optimizing the
              pipeline. The <code>tf.distribute.Strategy</code> makes it simpler
              to switch between the accelerators (GPU, TPU).
            </p>
            <p>
              This time around I decided to cover <strong>PyTorch</strong>,{" "}
              <strong>PyTorch Lightning</strong>, and <strong>JAX</strong> as
              well. While I do have some experience working with PyTorch and
              Lightning, JAX is mainly there because I wanted a reason to make
              something in JAX 😀.
            </p>
            <p>
              Each card gives you some information about the notebook, training
              time, train and test accuracy, etc. I would advise you not to pay
              too much attention to the accuracy metrics because there is a
              slight difference in some notebooks' augmentation pipeline. Also,
              It is not my intention to perform any comparison between the
              frameworks. They all work great and may have pros and cons.
            </p>
            <h4>Update: Nov 3rd, 2020</h4>
            <p>
              My primary workstation, the one with GTX 1080TI in the cards
              below, is dead. I cannot continue with the following planned
              notebooks for now:
              <ul>
                <li>Optuna: PyTorch & PyTorch Lightning on GTX 1080TI</li>
                <li>Ray Tune: PyTorch Lightning on GTX 1080TI</li>
                <li>JAX * on GTX 1080TI</li>
              </ul>
            </p>
          </article>
          <Divider />
          {/* {/* <Checkbox onChange={this.onShowGraph}>Show Graph</Checkbox>            */}

          <Radio.Group
            onChange={this.onRadioChange}
            value={this.state.radioValue}
          >
            <Radio.Button value={1}>Show All Cards</Radio.Button>
            <Radio.Button value={2}>Group by Framework</Radio.Button>
            <Radio.Button value={3}>Graph View</Radio.Button>
          </Radio.Group>

          {radioValue === 1 && <ShowAllCards />}
          <Divider />

          {radioValue === 2 && <GroupbyFramework />}
          {radioValue === 3 && <RenderTreeChart />}
        </BlogPostChrome>
      </Layout>
    );
  }
}

export default Cifar10Classifier;

// Inline style objects; `row` is a wrapping flex container.
// NOTE(review): `styles` is not referenced anywhere in this file — candidate for removal.
const styles = {
  row: {
    display: `flex`,
    flexWrap: `wrap`,
    margin: `8px -4px 1rem`,
  },
};

// Page query: fetches this post's `javascriptFrontmatter` node by `$slug`,
// using the `JSBlogPost_data` fragment (defined elsewhere). The result reaches
// the component as `this.props.data.javascriptFrontmatter`.
// `$slug` is presumably supplied through the page context — verify in gatsby-node.
// NOTE(review): Gatsby extracts this tagged template at build time, so keep it
// a static literal with no interpolations.
export const pageQuery = graphql`
  query cifar10classifierQuery($slug: String!) {
    javascriptFrontmatter(fields: { slug: { eq: $slug } }) {
      ...JSBlogPost_data
    }
  }
`;
