sklearn KMeans 聚类实战：对沪深300的每日涨跌进行聚类

栏目: 数据库 · 发布时间: 5年前

# ohlc_clustering.py

import copy
import datetime
import pymysql

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# from matplotlib.finance import candlestick_ohlc
import matplotlib.dates as mdates
from matplotlib.dates import (
    DateFormatter, WeekdayLocator, DayLocator, MONDAY
)
import mpl_finance as mpf
import numpy as np
import pandas as pd
import pandas_datareader.data as web
from sklearn.cluster import KMeans

def get_open_normalised_prices(con=None):
    """
    Build a DataFrame of open-normalised prices for index code '399300'.

    Runs a query against the `tmp_stock` table (rows with
    code = '399300' and date >= '2004-6-01', ordered by date ascending)
    and derives three ratio columns from the result:
    "H/O" (High/Open), "L/O" (Low/Open) and "C/O" (Close/Open).

    Parameters
    ----------
    con : DB-API connection, optional
        Connection handed to ``pandas.read_sql``. When omitted, a
        connection to the local MySQL `blog` database is opened here
        and closed again before returning. A connection supplied by
        the caller is left open.

    Returns
    -------
    pandas.DataFrame
        One row per row returned by the query, containing only the
        "H/O", "L/O" and "C/O" columns.
    """
    select_sql_300 = "select date as Date,open as Open,high as High,low as Low,adj_close as Close from `tmp_stock` where code ='399300' and date >= '2004-6-01'  order by date asc"

    owns_connection = con is None
    if owns_connection:
        con = pymysql.connect(
            host='127.0.0.1',
            db='blog',
            user='root',
            passwd='123456',
            charset='utf8',
            use_unicode=True
        )
    try:
        df = pd.read_sql(select_sql_300, con=con)
    finally:
        # Only close what was opened here so a caller-supplied
        # connection stays usable.
        if owns_connection:
            con.close()

    df["H/O"] = df["High"] / df["Open"]
    df["L/O"] = df["Low"] / df["Open"]
    df["C/O"] = df["Close"] / df["Open"]

    # Keep only the ratio columns; the raw prices and the date are
    # not part of the result.
    return df.drop(columns=["Open", "High", "Low", "Close", "Date"])

def plot_candlesticks(data):
    """
    Show a candlestick chart of the OHLC prices in *data*.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain "Date", "Open", "High", "Low" and "Close"
        columns. The frame is copied, so the caller's object is not
        modified.
    """
    # Work on a copy so the helper column added below does not leak
    # into the caller's frame.
    df = data.copy()
    df.reset_index(inplace=True)

    # Normalise the column with pd.to_datetime first so that values
    # without a .to_pydatetime() method (plain datetime.date objects,
    # strings) are handled as well as pandas Timestamps.
    df['date_fmt'] = pd.to_datetime(df['Date']).apply(
        lambda date: mdates.date2num(date.to_pydatetime())
    )

    fig, ax = plt.subplots(figsize=(16, 4))
    fig.subplots_adjust(bottom=0.2)

    # Plot the candlestick OHLC chart using red for up days and
    # green for down days.
    mpf.candlestick_ohlc(
        ax,
        df[['date_fmt', 'Open', 'High', 'Low', 'Close']].values,
        width=0.6,
        colorup='r', colordown='green'
    )
    # Interpret the numeric x values as dates when drawing tick labels.
    ax.xaxis_date()
    plt.show()


def plot_cluster(data):
    """
    Scatter-plot the closing price over time, coloured by cluster.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain "Date", "Close" and "Cluster" columns. Rows whose
        "Cluster" value is 0, 1, 2 or 3 are drawn; the frame is copied,
        so the caller's object is not modified.
    """
    df = data.copy()
    df.reset_index(inplace=True)

    # Normalise the column with pd.to_datetime first so that values
    # without a .to_pydatetime() method (plain datetime.date objects,
    # strings) are handled as well as pandas Timestamps.
    df['date_fmt'] = pd.to_datetime(df['Date']).apply(
        lambda date: mdates.date2num(date.to_pydatetime())
    )

    fig, ax = plt.subplots(figsize=(16, 4))
    fig.subplots_adjust(bottom=0.2)

    # (cluster id, marker colour, legend label), drawn in this order.
    # NOTE(review): the label names are hard-wired to cluster ids;
    # presumably they were chosen by inspecting one particular fit --
    # verify them whenever the clustering is re-run with other settings.
    cluster_styles = (
        (0, 'y', "Small Rise"),
        (1, 'g', "Big Down"),
        (2, 'r', "Big Rise"),
        (3, 'b', "Small Down"),
    )

    size = 1.2
    for cluster_id, colour, label in cluster_styles:
        members = df.loc[df["Cluster"] == cluster_id]
        ax.scatter(
            members['date_fmt'], members['Close'],
            s=size, c=colour, marker='o', label=label
        )

    # Interpret the numeric x values as dates when drawing tick labels.
    ax.xaxis_date()
    plt.xlabel('Date')
    plt.ylabel('Close')
    plt.legend(loc='upper right')
    plt.show()

def plot_3d_normalised_candles(data):
    """
    Plot a 3D scatter chart of the open-normalised bars,
    highlighting the separate clusters by colour.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain "H/O", "L/O" and "C/O" columns. If a "Cluster"
        column is present it supplies the point colours; otherwise the
        module-level ``labels`` array is used, as in the original
        script.
    """
    fig = plt.figure(figsize=(12, 9))
    # Create the 3D axes through the figure so they are actually
    # attached to it; constructing Axes3D(fig, ...) directly does not
    # add the axes to the figure on current Matplotlib releases.
    ax = fig.add_subplot(111, projection='3d')
    ax.view_init(elev=21, azim=-136)

    # Prefer the labels stored alongside the data over a hidden global.
    cluster_ids = data["Cluster"] if "Cluster" in data else labels

    ax.scatter(
        data["H/O"], data["L/O"], data["C/O"],
        # The `np.float` alias was removed from NumPy; the builtin
        # float is the equivalent dtype.
        c=np.asarray(cluster_ids, dtype=float)
    )
    ax.set_xlabel('High/Open')
    ax.set_ylabel('Low/Open')
    ax.set_zlabel('Close/Open')
    plt.show()

def plot_cluster_ordered_candles(data):
    """
    Draw the candles grouped by cluster rather than by date.

    Rows of *data* are sorted by their "Cluster" value and plotted
    against their position in that ordering; a dashed blue vertical
    line is drawn wherever the cluster value differs from the
    previous row (including the very first row).
    """
    figure, axis = plt.subplots(figsize=(16, 4))

    # Monday major ticks, daily minor ticks, and an empty format
    # string for the major tick labels.
    x_axis = axis.xaxis
    x_axis.set_major_locator(WeekdayLocator(MONDAY))
    x_axis.set_minor_locator(DayLocator())
    x_axis.set_major_formatter(DateFormatter(""))

    # Sort a private copy by cluster and remember each row's position
    # in the sorted order.
    ordered = copy.deepcopy(data).sort_values(by="Cluster").reset_index()
    ordered["clust_index"] = ordered.index
    ordered["clust_change"] = ordered["Cluster"].diff()

    # Positions at which the cluster value changes; the first row's
    # diff is NaN, which also compares unequal to 0.
    boundaries = ordered.loc[ordered["clust_change"] != 0, "clust_index"]

    # Candles positioned by cluster order instead of by date.
    ohlc_columns = ["clust_index", 'Open', 'High', 'Low', 'Close']
    mpf.candlestick_ohlc(
        axis, ordered[ohlc_columns].values,
        width=0.6, colorup='#000000', colordown='#ff0000'
    )

    # Mark every cluster boundary with a dashed blue line.
    for boundary in boundaries:
        plt.axvline(boundary, linestyle="dashed", c="blue")

    plt.xlim(0, len(ordered))
    tick_labels = plt.gca().get_xticklabels()
    plt.setp(tick_labels, rotation=45, horizontalalignment='right')
    plt.show()

def create_follow_cluster_matrix(data, n_clusters=None):
    """
    Build, print and return the k x k cluster follow-on matrix.

    Entry (i, j) is the percentage of rows in which cluster i is
    immediately followed by cluster j on the next row.

    Side effects on *data* (kept from the original implementation):
    a "ClusterTomorrow" column and a "ClusterMatrix" column of
    (today, tomorrow) pairs are added, and rows containing any NaN --
    which includes the last row, as it has no following row -- are
    dropped in place.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain an integer-valued "Cluster" column, in time order.
    n_clusters : int, optional
        Size k of the matrix. When omitted it is taken as the largest
        value in "Cluster" plus one (0 for an empty frame).

    Returns
    -------
    numpy.ndarray
        The k x k matrix of percentages.
    """
    # Determine k before rows are dropped, so a cluster that only
    # occurs on the final row still gets a row/column in the matrix.
    if n_clusters is None:
        n_clusters = 0 if data.empty else int(data["Cluster"].max()) + 1

    data["ClusterTomorrow"] = data["Cluster"].shift(-1)
    data.dropna(inplace=True)
    data["ClusterTomorrow"] = data["ClusterTomorrow"].astype(int)
    data["ClusterMatrix"] = list(
        zip(data["Cluster"], data["ClusterTomorrow"])
    )

    # Count each (today, tomorrow) pair and convert to a percentage of
    # all remaining rows.
    pair_counts = data["ClusterMatrix"].value_counts()
    clust_mat = np.zeros((n_clusters, n_clusters))
    for pair, count in pair_counts.items():
        clust_mat[pair] = count * 100.0 / len(data)

    print("Cluster Follow-on Matrix:")
    print(clust_mat)
    return clust_mat


if __name__ == "__main__":
    # Load the daily OHLC rows for index code 399300 from the local
    # MySQL `blog` database.
    connect = pymysql.connect(
        host='127.0.0.1',
        db='blog',
        user='root',
        passwd='123456',
        charset='utf8',
        use_unicode=True
    )
    select_sql_300 = "select date as Date,open as Open,high as High,low as Low,adj_close as Close from `tmp_stock` where code ='399300' and date >= '2004-6-01'  order by date asc"
    try:
        hs300 = pd.read_sql(select_sql_300, con=connect)
    finally:
        # The frame is fully loaded at this point; do not leak the
        # connection.
        connect.close()

    # Plot the price "candles" for every loaded row
    plot_candlesticks(hs300)

    # Carry out K-Means clustering with four clusters on the
    # three-dimensional data H/O, L/O and C/O.
    # `k` and `labels` stay module-level names because helper
    # functions in this file may read them as globals.
    hs300_norm = get_open_normalised_prices()
    k = 4
    km = KMeans(n_clusters=k, random_state=42)
    km.fit(hs300_norm)
    labels = km.labels_
    hs300_norm["Cluster"] = labels
    # hs300_norm comes from a second run of the same query, so the
    # labels are assumed to line up row-for-row with hs300.
    hs300["Cluster"] = labels

    # Plot the 3D normalised candles using H/O, L/O, C/O
    plot_3d_normalised_candles(hs300_norm)

    # Create and output the cluster follow-on matrix
    create_follow_cluster_matrix(hs300)

    # Plot the closing prices over time, coloured by cluster
    plot_cluster(hs300)

sklearn kMeans 分类实战,对沪深300的每日涨跌进行分类

sklearn kMeans 分类实战,对沪深300的每日涨跌进行分类

sklearn kMeans 分类实战,对沪深300的每日涨跌进行分类

http://www.waitingfy.com/archives/5039

参考:

https://zhuanlan.zhihu.com/p/43872533

https://www.quantstart.com/articles/k-means-clustering-of-daily-ohlc-bar-data

Post Views: 0

5039


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

计算群体智能基础

计算群体智能基础

恩格尔伯里特 / 谭营 / 2009-10 / 69.00元

《计算群体智能基础》全面系统地介绍了计算群体智能中的粒子群优化(PSO)和蚁群优化(ACO)的基本概念、基本模型、理论分析及其应用。在简要介绍基本优化理论和总结各类优化问题之后,重点介绍了社会网络结构如何在个体间交换信息以及个体聚集行为如何形成一个功能强大的有机体。在概述了进化计算后,重点论述了粒子群优化和蚁群优化的基本模型及其各种变体,给出了分析粒子群优化模型的一种通用方法,证明了基于蚂蚁行为实......一起来看看 《计算群体智能基础》 这本书的介绍吧!

JS 压缩/解压工具
JS 压缩/解压工具

在线压缩/解压 JS 代码

UNIX 时间戳转换
UNIX 时间戳转换

UNIX 时间戳转换

正则表达式在线测试
正则表达式在线测试

正则表达式在线测试