1 change: 1 addition & 0 deletions include/infinicore/ops.hpp
@@ -14,4 +14,5 @@
#include "ops/rms_norm.hpp"
#include "ops/rope.hpp"
#include "ops/silu.hpp"
#include "ops/silu_and_mul.hpp"
#include "ops/swiglu.hpp"
14 changes: 14 additions & 0 deletions include/infinicore/ops/silu_and_mul.hpp
@@ -0,0 +1,14 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_CLASS(SiluAndMul, Tensor, Tensor);

Tensor silu_and_mul(Tensor x);
void silu_and_mul_(Tensor out, Tensor x);

} // namespace infinicore::op
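
A minimal caller-side sketch of the two entry points declared above; `Tensor::empty(shape, dtype, device)` appears later in this PR, and the assumption that `Tensor` lives in the `infinicore` namespace is mine:

#include "infinicore/ops/silu_and_mul.hpp"

using namespace infinicore;

// `packed` packs the gate and up projections along the last dim: [..., 2*d].
Tensor fused_gate_up(Tensor packed) {
    // Allocating form: returns a freshly allocated [..., d] tensor.
    Tensor out = op::silu_and_mul(packed);

    // Destination form: overwrite an existing [..., d] tensor instead.
    op::silu_and_mul_(out, packed);
    return out;
}
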
1 change: 1 addition & 0 deletions include/infiniop.h
@@ -26,6 +26,7 @@
#include "infiniop/ops/rope.h"
#include "infiniop/ops/sigmoid.h"
#include "infiniop/ops/silu.h"
#include "infiniop/ops/silu_and_mul.h"
#include "infiniop/ops/softmax.h"
#include "infiniop/ops/softplus.h"
#include "infiniop/ops/sub.h"
29 changes: 29 additions & 0 deletions include/infiniop/ops/silu_and_mul.h
@@ -0,0 +1,29 @@
#ifndef __INFINIOP_SILU_AND_MUL_API_H__
#define __INFINIOP_SILU_AND_MUL_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopSiluAndMulDescriptor_t;

__C __export infiniStatus_t infiniopCreateSiluAndMulDescriptor(
    infiniopHandle_t handle,
    infiniopSiluAndMulDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t output,
    infiniopTensorDescriptor_t input);

__C __export infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(
    infiniopSiluAndMulDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopSiluAndMul(
    infiniopSiluAndMulDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *output,
    const void *input,
    void *stream);

__C __export infiniStatus_t infiniopDestroySiluAndMulDescriptor(
    infiniopSiluAndMulDescriptor_t desc);

#endif // __INFINIOP_SILU_AND_MUL_API_H__
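
A hedged sketch of the intended call sequence for this C API: handle, tensor descriptors, device buffers, and the stream are assumed to come from elsewhere, and `INFINI_STATUS_SUCCESS` is assumed to be the library's success code.

#include "infiniop.h"

// Drives one SiluAndMul launch; all handles and buffers are caller-provided.
infiniStatus_t run_silu_and_mul(infiniopHandle_t handle,
                                infiniopTensorDescriptor_t y_desc,
                                infiniopTensorDescriptor_t x_desc,
                                void *y_data, const void *x_data,
                                void *workspace, size_t workspace_size,
                                void *stream) {
    infiniopSiluAndMulDescriptor_t desc;
    infiniStatus_t status = infiniopCreateSiluAndMulDescriptor(handle, &desc, y_desc, x_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }

    size_t required = 0;
    infiniopGetSiluAndMulWorkspaceSize(desc, &required);
    // The caller-provided workspace must be at least `required` bytes.

    status = infiniopSiluAndMul(desc, workspace, workspace_size, y_data, x_data, stream);
    infiniopDestroySiluAndMulDescriptor(desc);
    return status;
}
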
2 changes: 2 additions & 0 deletions python/infinicore/nn/functional/__init__.py
@@ -5,6 +5,7 @@
from .rms_norm import rms_norm
from .rope import RopeAlgo, rope
from .silu import silu
from .silu_and_mul import silu_and_mul
from .swiglu import swiglu

__all__ = [
@@ -17,4 +18,5 @@
    "embedding",
    "rope",
    "RopeAlgo",
    "silu_and_mul",
]
17 changes: 17 additions & 0 deletions python/infinicore/nn/functional/silu_and_mul.py
@@ -0,0 +1,17 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def silu_and_mul(input: Tensor, out=None) -> Tensor:
    r"""Apply the SiLU and Mul (SwiGLU) function.

    Formula: output = SiLU(input_gate) * input_up
    Input shape: [..., 2*d], output shape: [..., d]
    """

    if out is None:
        return Tensor(_infinicore.silu_and_mul(input._underlying))

    _infinicore.silu_and_mul_(out._underlying, input._underlying)

    return out
35 changes: 35 additions & 0 deletions src/infinicore/ops/silu_and_mul/silu_and_mul.cc
@@ -0,0 +1,35 @@
#include "infinicore/ops/silu_and_mul.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(SiluAndMul);

SiluAndMul::SiluAndMul(Tensor out, Tensor x) {
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, x);
    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(), out, x);
}

void SiluAndMul::execute(Tensor out, Tensor x) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(SiluAndMul, out, x);
}

Tensor silu_and_mul(Tensor x) {
    Shape shape = x->shape();
    size_t ndim = x->ndim();

    if (shape[ndim - 1] % 2 != 0) {
        throw std::runtime_error("SiluAndMul input last dim must be even.");
    }
    shape[ndim - 1] /= 2;

    auto out = Tensor::empty(shape, x->dtype(), x->device());
    silu_and_mul_(out, x);
    return out;
}

void silu_and_mul_(Tensor out, Tensor x) {
    SiluAndMul::execute(out, x);
}

} // namespace infinicore::op
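
The arithmetic itself is delegated to the infiniop backend, so for clarity here is a scalar reference of what the op is expected to compute (float only; the gate-in-first-half layout is inferred from the Python docstring in this PR, not verified against the kernels):

#include <cmath>
#include <cstddef>

// out[b, i] = SiLU(x[b, i]) * x[b, d + i], with SiLU(g) = g * sigmoid(g).
void silu_and_mul_reference(const float *x, float *out, size_t batch, size_t d) {
    for (size_t b = 0; b < batch; ++b) {
        const float *row = x + b * 2 * d;
        for (size_t i = 0; i < d; ++i) {
            float gate = row[i];     // first half of the last dimension
            float up = row[d + i];   // second half of the last dimension
            out[b * d + i] = (gate / (1.0f + std::exp(-gate))) * up;
        }
    }
}
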
50 changes: 50 additions & 0 deletions src/infinicore/ops/silu_and_mul/silu_and_mul_infiniop.cc
@@ -0,0 +1,50 @@
#include "../infiniop_impl.hpp"
#include "infinicore/ops/silu_and_mul.hpp"

namespace infinicore::op::silu_and_mul_impl::infiniop {

INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, SiluAndMul, 100);

struct PlannedMeta {
    std::shared_ptr<Descriptor> descriptor;
    graph::GraphTensor workspace, output, input;
};

void *plan(Tensor output, Tensor input) {
    size_t seed = hash_combine(output, input);

    INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE(
        Descriptor, descriptor, SiluAndMul,
        seed, output->desc(), input->desc());

    INFINIOP_WORKSPACE_TENSOR(workspace, SiluAndMul, descriptor);

    auto planned = new PlannedMeta{
        descriptor,
        graph::GraphTensor(workspace),
        graph::GraphTensor(output),
        graph::GraphTensor(input)};

    return planned;
}

void run(void *planned_meta) {
    auto planned = reinterpret_cast<PlannedMeta *>(planned_meta);

    INFINICORE_CHECK_ERROR(infiniopSiluAndMul(
        planned->descriptor->desc,
        planned->workspace->data(),
        planned->workspace->numel(),
        planned->output->data(),
        planned->input->data(),
        context::getStream()));
}

void cleanup(void **planned_meta_ptr) {
    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(SiluAndMul, &plan, &run, &cleanup);

} // namespace infinicore::op::silu_and_mul_impl::infiniop
2 changes: 2 additions & 0 deletions src/infinicore/pybind11/ops.hpp
@@ -18,6 +18,7 @@
#include "ops/rms_norm.hpp"
#include "ops/rope.hpp"
#include "ops/silu.hpp"
#include "ops/silu_and_mul.hpp"
#include "ops/swiglu.hpp"

namespace py = pybind11;
@@ -42,6 +43,7 @@ inline void bind(py::module &m) {
    bind_swiglu(m);
    bind_rope(m);
    bind_embedding(m);
    bind_silu_and_mul(m);
}

} // namespace infinicore::ops
29 changes: 29 additions & 0 deletions src/infinicore/pybind11/ops/silu_and_mul.hpp
@@ -0,0 +1,29 @@
#pragma once

#include <pybind11/pybind11.h>

#include "infinicore/ops/silu_and_mul.hpp"

namespace py = pybind11;

namespace infinicore::ops {

inline void bind_silu_and_mul(py::module &m) {
    m.def("silu_and_mul",
          &op::silu_and_mul,
          py::arg("input"),
          R"doc(
SiLU and Mul (SwiGLU) activation function.
Input should be [..., 2*d], output will be [..., d].
)doc");

    m.def("silu_and_mul_",
          &op::silu_and_mul_,
          py::arg("output"),
          py::arg("input"),
          R"doc(
SiLU and Mul (SwiGLU) activation function that writes the result into a caller-provided output tensor.
)doc");
}

} // namespace infinicore::ops
54 changes: 54 additions & 0 deletions src/infiniop/ops/silu_and_mul/info.h
@@ -0,0 +1,54 @@
#ifndef __SILU_AND_MUL_INFO_H__
#define __SILU_AND_MUL_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>

namespace op::silu_and_mul {

class SiluAndMulInfo {
    SiluAndMulInfo() = default;

public:
    infiniDtype_t dtype;
    size_t batch_size;
    size_t out_hidden_dim;

    static utils::Result<SiluAndMulInfo> create(infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc) {
        auto dtype = y_desc->dtype();

        auto x_shape = x_desc->shape();
        auto y_shape = y_desc->shape();
        auto ndim = x_desc->ndim();

        if (ndim != y_desc->ndim()) {
            return INFINI_STATUS_BAD_PARAM;
        }

        if (x_shape[ndim - 1] != 2 * y_shape[ndim - 1]) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        size_t batch = 1;
        for (int i = 0; i < (int)ndim - 1; ++i) {
            if (x_shape[i] != y_shape[i]) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
            batch *= y_shape[i];
        }

        return utils::Result<SiluAndMulInfo>(SiluAndMulInfo{
            dtype,
            batch,
            y_shape[ndim - 1]});
    }

private:
    SiluAndMulInfo(infiniDtype_t dtype, size_t batch, size_t hidden)
        : dtype(dtype), batch_size(batch), out_hidden_dim(hidden) {}
};

} // namespace op::silu_and_mul

#endif // __SILU_AND_MUL_INFO_H__
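
To make the validation in `SiluAndMulInfo::create` concrete, a few illustrative shapes and the status it would return (example values only):

// x: [4, 128, 8192], y: [4, 128, 4096] -> OK (batch_size = 512, out_hidden_dim = 4096)
// x: [4, 128, 8192], y: [4, 128, 2048] -> INFINI_STATUS_BAD_TENSOR_SHAPE (last dim of x != 2 * last dim of y)
// x: [4, 128, 8192], y: [4, 8192]      -> INFINI_STATUS_BAD_PARAM (ndim mismatch)
// x: [4, 128, 8192], y: [2, 128, 4096] -> INFINI_STATUS_BAD_TENSOR_SHAPE (leading dims differ)
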
8 changes: 8 additions & 0 deletions src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.h
@@ -0,0 +1,8 @@
#ifndef __SILU_AND_MUL_MOORE_API_H__
#define __SILU_AND_MUL_MOORE_API_H__

#include "../silu_and_mul.h"

DESCRIPTOR(moore)

#endif // __SILU_AND_MUL_MOORE_API_H__