add partial derivative

This commit is contained in:
2024-08-25 23:45:56 +08:00
parent 90fcee2ff9
commit cc06c34967
6 changed files with 204 additions and 15 deletions

19
mbcp/mp_math/const.py Normal file
View File

@@ -0,0 +1,19 @@
# -*- coding: utf-8 -*-
"""
Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@Time : 2024/8/25 下午9:45
@Author : snowykami
@Email : snowykami@outlook.com
@File : const.py
@Software: PyCharm
"""
import math
PI = math.pi
E = math.e
GOLDEN_RATIO = (1 + math.sqrt(5)) / 2
GAMMA = 0.57721566490153286060651209008240243104215933593992
EPSILON = 1e-8

View File

@@ -8,13 +8,14 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@File : equation.py
@Software: PyCharm
"""
import numpy as np
from .point import Point3
from .mp_math_typing import ONE_VARIABLE_FUNC, TWO_VARIABLES_FUNC, THREE_VARIABLES_FUNC
from mbcp.mp_math.mp_math_typing import OneVarFunc, Var, MultiVarFunc, Number
from mbcp.mp_math.point import Point3
from mbcp.mp_math.const import EPSILON
class CurveEquation:
def __init__(self, x_func: ONE_VARIABLE_FUNC, y_func: ONE_VARIABLE_FUNC, z_func: ONE_VARIABLE_FUNC):
def __init__(self, x_func: OneVarFunc, y_func: OneVarFunc, z_func: OneVarFunc):
"""
曲线方程。
:param x_func:
@@ -25,14 +26,45 @@ class CurveEquation:
self.y_func = y_func
self.z_func = z_func
def __call__(self, *t: float) -> "Point3" | tuple["Point3"]:
def __call__(self, *t: Var) -> Point3 | tuple[Point3, ...]:
"""
计算曲线上的点。
Args:
*t:
Returns:
"""
if len(t) == 1:
return Point3(self.x_func(t[0]), self.y_func(t[0]), self.z_func(t[0]))
else:
# np加速
...
return tuple([Point3(x, y, z) for x, y, z in zip(self.x_func(t), self.y_func(t), self.z_func(t))])
def __str__(self):
return "CurveEquation()"
def get_partial_derivative_func(func: MultiVarFunc, var: int | tuple[int, ...], epsilon: Number = EPSILON) -> MultiVarFunc:
"""
求N元函数偏导函数。
Args:
func: 函数
var: 变量位置,可为整数(一阶偏导)或整数元组(高阶偏导)
epsilon: 偏移量
Returns:
偏导函数
"""
if isinstance(var, int):
def partial_derivative_func(*args: Var) -> Var:
args_list_plus = list(args)
args_list_plus[var] += epsilon
args_list_minus = list(args)
args_list_minus[var] -= epsilon
return (func(*args_list_plus) - func(*args_list_minus)) / (2 * epsilon)
return partial_derivative_func
elif isinstance(var, tuple):
for i in var:
func = get_partial_derivative_func(func, i, epsilon)
return func
else:
raise ValueError("Invalid var type")

View File

@@ -8,11 +8,26 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@File : mp_math_typing.py
@Software: PyCharm
"""
from typing import Callable, Iterable, TypeAlias
from typing import Callable, Iterable, TypeAlias, TypeVar
"""自变量"""
VAR: TypeAlias = float | Iterable[float] # 为后期支持多维矢量化做准备
RealNumber: TypeAlias = int | float
Number: TypeAlias = RealNumber | complex
SingleVar = TypeVar("SingleVar", bound=Number)
ArrayVar = TypeVar("ArrayVar", bound=Iterable[Number])
Var: TypeAlias = SingleVar | ArrayVar
ONE_VARIABLE_FUNC: TypeAlias = Callable[[VAR], float]
TWO_VARIABLES_FUNC: TypeAlias = Callable[[VAR, VAR], float]
THREE_VARIABLES_FUNC: TypeAlias = Callable[[VAR, VAR, VAR], float]
OneSingleVarFunc: TypeAlias = Callable[[SingleVar], SingleVar]
OneArrayFunc: TypeAlias = Callable[[ArrayVar], ArrayVar]
OneVarFunc: TypeAlias = OneSingleVarFunc | OneArrayFunc
TwoSingleVarFunc: TypeAlias = Callable[[SingleVar, SingleVar], SingleVar]
TwoArrayFunc: TypeAlias = Callable[[ArrayVar, ArrayVar], ArrayVar]
TwoVarFunc: TypeAlias = TwoSingleVarFunc | TwoArrayFunc
ThreeSingleVarFunc: TypeAlias = Callable[[SingleVar, SingleVar, SingleVar], SingleVar]
ThreeArrayFunc: TypeAlias = Callable[[ArrayVar, ArrayVar, ArrayVar], ArrayVar]
ThreeVarFunc: TypeAlias = ThreeSingleVarFunc | ThreeArrayFunc
MultiSingleVarFunc: TypeAlias = Callable[..., SingleVar]
MultiArrayFunc: TypeAlias = Callable[..., ArrayVar]
MultiVarFunc: TypeAlias = MultiSingleVarFunc | MultiArrayFunc

View File

@@ -8,6 +8,9 @@ Copyright (C) 2020-2024 LiteyukiStudio. All Rights Reserved
@File : utils.py
@Software: PyCharm
"""
from typing import overload
from mbcp.mp_math.mp_math_typing import RealNumber
def clamp(x: float, min_: float, max_: float) -> float:
@@ -22,3 +25,34 @@ def clamp(x: float, min_: float, max_: float) -> float:
限制后的值
"""
return max(min(x, max_), min_)
class Approx(float):
"""
用于近似比较浮点数的类。
"""
epsilon = 0.001
"""全局近似值。"""
def __new__(cls, x: RealNumber):
return super().__new__(cls, x)
def __eq__(self, other):
return abs(self - other) < Approx.epsilon
def __ne__(self, other):
return not self.__eq__(other)
def approx(x: float, y: float = 0.0, epsilon: float = 0.0001) -> bool:
"""
判断两个数是否近似相等。或包装一个实数用于判断是否近似于0。
Args:
x:
y:
epsilon:
Returns:
是否近似相等
"""
return abs(x - y) < epsilon